# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- import argparse import datetime import gc import itertools import logging import os import sys import time import numpy as np import onnx import psutil import torch from benchmark_helper import measure_memory, setup_logger from dist_settings import get_rank, get_size from llama_inputs import ( add_io_bindings_as_ortvalues, get_merged_sample_with_past_kv_inputs, get_msft_sample_inputs, get_sample_inputs, get_sample_with_past_kv_inputs, verify_ort_inputs, ) from optimum.onnxruntime import ORTModelForCausalLM from torch.profiler import ProfilerActivity, profile, record_function from tqdm import trange from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer import onnxruntime as ort logger = logging.getLogger(__name__) # For determining whether the ONNX model can do both prompt generation and token generation or only one of the two def get_ort_model_inputs_len(args, model): if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: return 0 if args.benchmark_type == "hf-ort": try: # New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268) return len(model.inputs_names) except Exception: # Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54) return len(model.decoder.input_names) return len(model.get_inputs()) def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): init_inputs, iter_inputs = None, None # For past_present_share_buffer: # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported # Set max_seq_len to config value for other models max_seq_len = 2048 if args.benchmark_type == "ort-msft" else args.config.max_position_embeddings if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: init_inputs = get_sample_inputs( args.config, args.target_device, args.batch_size, args.sequence_length, return_dict=True, ) iter_inputs = get_sample_with_past_kv_inputs( args.config, args.target_device, args.batch_size, args.sequence_length, use_fp16=args.use_fp16, return_dict=True, ) elif args.benchmark_type in {"hf-ort"}: if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids] # Using split models in Optimum (e.g. created by Optimum export) init_inputs = get_sample_inputs( args.config, args.target_device, args.batch_size, args.sequence_length, return_dict=True, ) iter_inputs = get_sample_with_past_kv_inputs( args.config, args.target_device, args.batch_size, args.sequence_length, use_fp16=args.use_fp16, return_dict=True, ) else: # Using merged model in Optimum (e.g. created by convert_to_onnx export) init_inputs = get_merged_sample_with_past_kv_inputs( args.config, args.target_device, args.batch_size, seq_len=args.sequence_length, past_seq_len=0, max_seq_len=max_seq_len, use_fp16=args.use_fp16, use_buffer_share=args.use_buffer_share, engine="pt", return_dict=True, ) iter_inputs = get_merged_sample_with_past_kv_inputs( args.config, args.target_device, args.batch_size, seq_len=1, past_seq_len=args.sequence_length, max_seq_len=max_seq_len, use_fp16=args.use_fp16, use_buffer_share=args.use_buffer_share, engine="pt", return_dict=True, ) elif args.benchmark_type == "ort-convert-to-onnx": # Microsoft export from convert_to_onnx init_inputs = get_merged_sample_with_past_kv_inputs( args.config, args.target_device, args.batch_size, seq_len=args.sequence_length, past_seq_len=0, max_seq_len=max_seq_len, use_fp16=args.use_fp16, use_buffer_share=args.use_buffer_share, engine="ort", return_dict=True, world_size=args.world_size, ) iter_inputs = get_merged_sample_with_past_kv_inputs( args.config, args.target_device, args.batch_size, seq_len=1, past_seq_len=args.sequence_length, max_seq_len=max_seq_len, use_fp16=args.use_fp16, use_buffer_share=args.use_buffer_share, engine="ort", return_dict=True, world_size=args.world_size, ) elif args.benchmark_type == "ort-msft": # Microsoft export from https://github.com/microsoft/Llama-2-Onnx split_kv = ort_model_inputs_len > 5 # original inputs: [x, attn_mask, k_cache, v_cache, pos] init_inputs = get_msft_sample_inputs( args.config, args.batch_size, past_seq_len=0, seq_len=args.sequence_length, max_seq_len=max_seq_len, use_fp16=args.use_fp16, use_buffer_share=args.use_buffer_share, split_kv=split_kv, ) iter_inputs = get_msft_sample_inputs( args.config, args.batch_size, past_seq_len=args.sequence_length, seq_len=1, max_seq_len=max_seq_len, use_fp16=args.use_fp16, use_buffer_share=args.use_buffer_share, split_kv=split_kv, ) else: raise Exception("Unable to auto-detect inputs for provided model") return init_inputs, iter_inputs def get_model(args: argparse.Namespace): model, sess_options = None, None start_time, end_time = None, None # There are multiple sources that the model could come from: # 1) Benchmark LLaMA-2 from unofficial source on Hugging Face # 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token # 3) Benchmark LLaMA-2 from local download of model # 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx) # 5) Benchmark LLaMA-2 from convert_to_onnx if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name start_time = time.time() model = AutoModelForCausalLM.from_pretrained( source, torch_dtype=torch.float16 if args.use_fp16 else torch.float32, use_auth_token=args.auth, trust_remote_code=args.auth, use_cache=True, cache_dir=args.cache_dir, ).to(args.target_device) end_time = time.time() if args.benchmark_type == "hf-pt-compile": model = torch.compile(model) elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}: sess_options = ort.SessionOptions() sess_options.enable_profiling = args.profile if args.verbose: sess_options.log_verbosity_level = 1 sess_options.log_severity_level = 1 else: raise Exception(f"Cannot recognize {args.benchmark_type}") if args.benchmark_type == "hf-ort": # Optimum export or convert_to_onnx.py export provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None decoder_file_name = None decoder_with_past_file_name = None for filename in os.listdir(args.hf_ort_dir_path): if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename: continue if "decoder_model" in filename or filename == "model.onnx": decoder_file_name = filename if "decoder_with_past_model" in filename: decoder_with_past_file_name = filename if "decoder_merged_model" in filename: decoder_file_name = filename decoder_with_past_file_name = filename start_time = time.time() model = ORTModelForCausalLM.from_pretrained( args.hf_ort_dir_path, decoder_file_name=decoder_file_name, decoder_with_past_file_name=decoder_with_past_file_name, use_auth_token=args.auth, trust_remote_code=args.auth, use_io_binding=True, # Large perf gain even for cpu due to avoiding output copy. use_merged=(True if decoder_file_name == "model.onnx" else None), provider=provider, provider_options=provider_options, session_options=sess_options, ) end_time = time.time() if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}") start_time = time.time() model = ort.InferenceSession( args.ort_model_path.format(args.rank), sess_options, providers=[args.execution_provider], ) end_time = time.time() logger.info(f"Loaded model in {end_time - start_time} s") return model def time_fn(args, fn, inputs): # Warm up warmup_range = ( range(args.warmup_runs) if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} else trange(args.warmup_runs, file=sys.stdout, desc="Warm up") ) if args.verbose: outputs = fn(inputs) logger.info(outputs) input_sync = lambda *kwargs: ( # noqa: E731 args.io_binding.synchronize_inputs() if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize else lambda *kwargs: ( torch.cuda.synchronize() if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize else lambda *kwargs: None ) ) # no-op function output_sync = lambda *kwargs: ( # noqa: E731 args.io_binding.synchronize_outputs() if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize else lambda *kwargs: ( torch.cuda.synchronize() if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize else lambda *kwargs: None ) ) # no-op function for _ in warmup_range: input_sync() fn(inputs) output_sync() # Benchmark total_time = 0 bench_range = ( range(args.num_runs) if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} else trange(args.num_runs, file=sys.stdout, desc="Benchmark") ) for _ in bench_range: input_sync() start_time = time.time() fn(inputs) output_sync() end_time = time.time() total_time += end_time - start_time # Newline print after trange in order to print metrics on new lines without progress bar on same line if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}: logger.info("") latency = total_time / args.num_runs throughput = args.batch_size / latency if args.rank == 0: logger.info(f"Batch Size: {args.batch_size}") logger.info(f"Sequence Length: {args.sequence_length}") logger.info(f"Latency: {latency} s") logger.info(f"Throughput: {throughput} tps") return def profile_fn(args, fn, inputs, inputs_type): # Filename prefix format: # "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>" prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}" filename = None if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: # Profile PyTorch kernels with profile( # noqa: SIM117 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True ) as prof: with record_function("model_inference"): fn(inputs) prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows) filename = os.path.join(args.log_folder, f"{prefix}.log") with open(filename, "w") as f: f.write(prof_data) else: # Profile ORT kernels fn(inputs) # Set new log name for ORT profile log generated filename = f"{prefix}.json" return filename def measure_fn(args, fn, inputs): # Measure CPU usage pid = os.getpid() process = psutil.Process(pid) process.cpu_percent(interval=0.1) fn(inputs) if args.rank == 0: logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%") # Measure memory usage gc.collect() torch.cuda.empty_cache() measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs)) # Flush output so memory usage is printed sys.stdout.flush() def run_hf_inference(args, init_inputs, iter_inputs, model): # Inference steps to measure def get_logits(inputs): # Inference pass without decoding outputs = model(**inputs) return outputs # Examples of other inference steps that can be measured: # To use, uncomment the function and assign it to `generate_fn` # def get_pred_ids(inputs): # # Inference pass with predicted token ids generation # predicted_ids = model.generate(**inputs) # return predicted_ids # def gen_and_dec(inputs): # # Inference pass with generation and decoding # predicted_ids = get_pred_ids(inputs) # transcription = [] # for bs in range(args.batch_size): # for rs in range(args.num_return_sequences): # transcription.append( # args.tokenizer.batch_decode( # predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True # )[0] # ) # return transcription generate_fn = get_logits if args.benchmark_type == "hf-pt-compile": # Run forward pass once with each set of inputs to process through Dynamo generate_fn(init_inputs) generate_fn(iter_inputs) if args.profile: new_logname = profile_fn(args, generate_fn, init_inputs, "prompt") if args.benchmark_type == "hf-ort": # Turn profiling off to stop appending to log old_logname = model.decoder.session.end_profiling() logger.warning(f"Renaming {old_logname} to {new_logname}") os.rename(old_logname, os.path.join(args.log_folder, new_logname)) new_logname = profile_fn(args, generate_fn, iter_inputs, "token") if args.benchmark_type == "hf-ort": # Turn profiling off to stop appending to log old_logname = model.decoder_with_past.session.end_profiling() logger.warning(f"Renaming {old_logname} to {new_logname}") os.rename(old_logname, os.path.join(args.log_folder, new_logname)) return # PyTorch evaluations logger.info("\nEvaluating `model(inputs)` step to get past_key_values") time_fn(args, generate_fn, init_inputs) measure_fn(args, generate_fn, init_inputs) logger.info("\nEvaluating `model(inputs)` step with past_key_values") time_fn(args, generate_fn, iter_inputs) measure_fn(args, generate_fn, iter_inputs) def run_ort_inference(args, init_inputs, iter_inputs, model): def prepare_ort_inputs(inputs, kv_cache_ortvalues): # Verify model inputs inputs = verify_ort_inputs(model, inputs) # Add IO bindings for non-CPU execution providers if args.device != "cpu": io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues( model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues ) setattr(args, "io_binding", io_binding) # noqa: B010 return io_binding, kv_cache_ortvalues return inputs, kv_cache_ortvalues def with_io_binding(io_binding): # Inference pass with IO binding model.run_with_iobinding(io_binding) def without_io_binding(inputs): # Inference pass without IO binding outputs = model.run(None, inputs) return outputs generate_fn = with_io_binding if args.device != "cpu" else without_io_binding kv_cache_ortvalues = {} if args.profile: ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt") # Turn profiling off to stop appending to log file old_logname = model.end_profiling() logger.warning(f"Renaming {old_logname} to {new_logname}") os.rename(old_logname, os.path.join(args.log_folder, new_logname)) # Re-initialize model for new log file instead of appending to old log file model = get_model(args) ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token") # Turn profiling off to stop appending to log old_logname = model.end_profiling() logger.warning(f"Renaming {old_logname} to {new_logname}") os.rename(old_logname, os.path.join(args.log_folder, new_logname)) return # ORT evaluations logger.info("\nEvaluating `model(inputs)` step to get past_key_values") ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) time_fn(args, generate_fn, ort_init_inputs) measure_fn(args, generate_fn, ort_init_inputs) logger.info("\nEvaluating `model(inputs)` step with past_key_values") ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) time_fn(args, generate_fn, ort_iter_inputs) measure_fn(args, generate_fn, ort_iter_inputs) def run_inference(args, init_inputs, iter_inputs, model): if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}: run_hf_inference(args, init_inputs, iter_inputs, model) elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: run_ort_inference(args, init_inputs, iter_inputs, model) else: raise Exception(f"Cannot recognize {args.benchmark_type}") def get_args(rank=0): parser = argparse.ArgumentParser() parser.add_argument( "-bt", "--benchmark-type", type=str, required=True, choices=[ "hf-pt-eager", "hf-pt-compile", "hf-ort", "ort-msft", "ort-convert-to-onnx", ], ) parser.add_argument( "-m", "--model-name", type=str, required=True, help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')", ) parser.add_argument( "-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model" ) # Args for choosing the model parser.add_argument( "-p", "--precision", required=True, type=str, default="fp32", choices=["int4", "int8", "fp16", "fp32"], help="Precision for model. For ONNX models, the model's precision should be set before running this script.", ) parser.add_argument( "--hf-pt-dir-path", type=str, default="", help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)", ) parser.add_argument( "--hf-ort-dir-path", type=str, default="", help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)", ) parser.add_argument( "--ort-model-path", type=str, default="", help="Path to ONNX model", ) # Args for running and evaluating the model parser.add_argument( "-b", "--batch-sizes", default="1 2", ) parser.add_argument( "-s", "--sequence-lengths", default="32 64 128 256 512", ) parser.add_argument( "-d", "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", choices=["cpu", "cuda", "rocm"], ) parser.add_argument("-id", "--device-id", type=int, default=0) parser.add_argument("-w", "--warmup-runs", type=int, default=5) parser.add_argument("-n", "--num-runs", type=int, default=10) parser.add_argument("--seed", type=int, default=2) # Args for decoding logic parser.add_argument("--max-length", type=int, default=32) parser.add_argument("--num-return-sequences", type=int, default=1) # Args for accessing detailed info parser.add_argument("--profile", default=False, action="store_true") parser.add_argument( "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by" ) parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display") parser.add_argument("--verbose", default=False, action="store_true") parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files") parser.add_argument( "--cache-dir", type=str, required=True, default="./model_cache", help="Cache dir where Hugging Face files are stored", ) args = parser.parse_args() # Set seed properties np.random.seed(args.seed) torch.manual_seed(args.seed) # Set runtime properties if "ort" in args.benchmark_type: setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010 if args.execution_provider == "CUDAExecutionProvider": args.execution_provider = (args.execution_provider, {"device_id": rank}) elif args.execution_provider == "ROCMExecutionProvider": args.execution_provider = (args.execution_provider, {"device_id": rank}) args.device = "cuda" # Check that paths have been specified for any benchmarking with ORT if args.benchmark_type == "hf-ort": assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`" if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: assert args.ort_model_path, "Please specify a path to `--ort-model-path`" args.batch_sizes = args.batch_sizes.split(" ") args.sequence_lengths = args.sequence_lengths.split(" ") # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models args.precision = ( "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16" ) # Check that only one (batch_size, sequence_length) combination is set for profiling if args.profile: assert ( len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1 ), "Please provide only one (batch_size, sequence_length) combination for profiling" return args def main(): rank = get_rank() world_size = get_size() args = get_args(rank) setup_logger(args.verbose) logger.info(args.__dict__) torch.backends.cudnn.benchmark = True args.rank = rank args.world_size = world_size tokenizer = AutoTokenizer.from_pretrained( args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth ) config = AutoConfig.from_pretrained( args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth ) target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device use_fp16 = args.precision == "fp16" setattr(args, "tokenizer", tokenizer) # noqa: B010 setattr(args, "config", config) # noqa: B010 setattr(args, "target_device", target_device) # noqa: B010 setattr(args, "use_fp16", use_fp16) # noqa: B010 # Get model and model info model = get_model(args) ort_model_inputs_len = get_ort_model_inputs_len(args, model) # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA) if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}: onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False) gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node)) use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu" setattr(args, "use_buffer_share", use_buffer_share) # noqa: B010 else: setattr(args, "use_buffer_share", False) # noqa: B010 # Measure prompt cost (init_inputs) and generated token cost (iter_inputs) for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths): if args.rank == 0: logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") setattr(args, "batch_size", int(batch_size)) # noqa: B010 setattr(args, "sequence_length", int(sequence_length)) # noqa: B010 init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len) run_inference(args, init_inputs, iter_inputs, model) if __name__ == "__main__": main()
Memory