# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import gc
import itertools
import logging
import os
import sys
import time
import numpy as np
import onnx
import psutil
import torch
from benchmark_helper import measure_memory, setup_logger
from dist_settings import get_rank, get_size
from llama_inputs import (
add_io_bindings_as_ortvalues,
get_merged_sample_with_past_kv_inputs,
get_msft_sample_inputs,
get_sample_inputs,
get_sample_with_past_kv_inputs,
verify_ort_inputs,
)
from optimum.onnxruntime import ORTModelForCausalLM
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import onnxruntime as ort
logger = logging.getLogger(__name__)
# For determining whether the ONNX model can do both prompt generation and token generation or only one of the two
def get_ort_model_inputs_len(args, model):
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
return 0
if args.benchmark_type == "hf-ort":
try:
# New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268)
return len(model.inputs_names)
except Exception:
# Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54)
return len(model.decoder.input_names)
return len(model.get_inputs())
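# init_inputs model the prompt step (full sequence length, empty KV cache);
# iter_inputs model the per-token generation step (seq_len=1 with a populated KV cache).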
def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
init_inputs, iter_inputs = None, None
# For past_present_share_buffer:
# Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
# Set max_seq_len to config value for other models
max_seq_len = 2048 if args.benchmark_type == "ort-msft" else args.config.max_position_embeddings
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
init_inputs = get_sample_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
return_dict=True,
)
iter_inputs = get_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
elif args.benchmark_type in {"hf-ort"}:
if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids]
# Using split models in Optimum (e.g. created by Optimum export)
init_inputs = get_sample_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
return_dict=True,
)
iter_inputs = get_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
else:
# Using merged model in Optimum (e.g. created by convert_to_onnx export)
init_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="pt",
return_dict=True,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="pt",
return_dict=True,
)
elif args.benchmark_type == "ort-convert-to-onnx":
# Microsoft export from convert_to_onnx
init_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="ort",
return_dict=True,
world_size=args.world_size,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="ort",
return_dict=True,
world_size=args.world_size,
)
elif args.benchmark_type == "ort-msft":
# Microsoft export from https://github.com/microsoft/Llama-2-Onnx
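        # More than 5 model inputs indicates that the KV cache is passed as separate (split) key/value
        # tensors rather than the single k_cache/v_cache pair of the original export.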
split_kv = ort_model_inputs_len > 5 # original inputs: [x, attn_mask, k_cache, v_cache, pos]
init_inputs = get_msft_sample_inputs(
args.config,
args.batch_size,
past_seq_len=0,
seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
split_kv=split_kv,
)
iter_inputs = get_msft_sample_inputs(
args.config,
args.batch_size,
past_seq_len=args.sequence_length,
seq_len=1,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
split_kv=split_kv,
)
else:
raise Exception("Unable to auto-detect inputs for provided model")
return init_inputs, iter_inputs
def get_model(args: argparse.Namespace):
model, sess_options = None, None
start_time, end_time = None, None
# There are multiple sources that the model could come from:
# 1) Benchmark LLaMA-2 from unofficial source on Hugging Face
# 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token
# 3) Benchmark LLaMA-2 from local download of model
# 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx)
# 5) Benchmark LLaMA-2 from convert_to_onnx
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name
start_time = time.time()
model = AutoModelForCausalLM.from_pretrained(
source,
torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
use_auth_token=args.auth,
trust_remote_code=args.auth,
use_cache=True,
cache_dir=args.cache_dir,
).to(args.target_device)
end_time = time.time()
if args.benchmark_type == "hf-pt-compile":
model = torch.compile(model)
elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}:
sess_options = ort.SessionOptions()
sess_options.enable_profiling = args.profile
if args.verbose:
sess_options.log_verbosity_level = 1
sess_options.log_severity_level = 1
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
if args.benchmark_type == "hf-ort":
# Optimum export or convert_to_onnx.py export
provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
decoder_file_name = None
decoder_with_past_file_name = None
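        # Scan the export directory for decoder ONNX files. Optimum may produce separate
        # decoder_model / decoder_with_past_model files, a single decoder_merged_model, or a single model.onnx.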
for filename in os.listdir(args.hf_ort_dir_path):
if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename:
continue
if "decoder_model" in filename or filename == "model.onnx":
decoder_file_name = filename
if "decoder_with_past_model" in filename:
decoder_with_past_file_name = filename
if "decoder_merged_model" in filename:
decoder_file_name = filename
decoder_with_past_file_name = filename
start_time = time.time()
model = ORTModelForCausalLM.from_pretrained(
args.hf_ort_dir_path,
decoder_file_name=decoder_file_name,
decoder_with_past_file_name=decoder_with_past_file_name,
use_auth_token=args.auth,
trust_remote_code=args.auth,
use_io_binding=True, # Large perf gain even for cpu due to avoiding output copy.
use_merged=(True if decoder_file_name == "model.onnx" else None),
provider=provider,
provider_options=provider_options,
session_options=sess_options,
)
end_time = time.time()
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
# Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
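        # args.ort_model_path may contain a "{}" placeholder that is filled with the rank for sharded (multi-GPU) models.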
logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}")
start_time = time.time()
model = ort.InferenceSession(
args.ort_model_path.format(args.rank),
sess_options,
providers=[args.execution_provider],
)
end_time = time.time()
logger.info(f"Loaded model in {end_time - start_time} s")
return model
def time_fn(args, fn, inputs):
# Warm up
warmup_range = (
range(args.warmup_runs)
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
)
if args.verbose:
outputs = fn(inputs)
logger.info(outputs)
    # Synchronize before and after each run so GPU timings are accurate.
    # ORT benchmarks synchronize through the IO binding, PyTorch benchmarks synchronize CUDA directly,
    # and CPU runs need no synchronization.
    def input_sync(*_):
        if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
            args.io_binding.synchronize_inputs()  # ORT synchronize
        elif args.device != "cpu" and torch.cuda.is_available():
            torch.cuda.synchronize()  # PyTorch synchronize

    def output_sync(*_):
        if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
            args.io_binding.synchronize_outputs()  # ORT synchronize
        elif args.device != "cpu" and torch.cuda.is_available():
            torch.cuda.synchronize()  # PyTorch synchronize
for _ in warmup_range:
input_sync()
fn(inputs)
output_sync()
# Benchmark
total_time = 0
bench_range = (
range(args.num_runs)
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
)
for _ in bench_range:
input_sync()
start_time = time.time()
fn(inputs)
output_sync()
end_time = time.time()
total_time += end_time - start_time
    # Print a blank line after trange so the metrics below start on their own line instead of after the progress bar
if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}:
logger.info("")
latency = total_time / args.num_runs
throughput = args.batch_size / latency
if args.rank == 0:
logger.info(f"Batch Size: {args.batch_size}")
logger.info(f"Sequence Length: {args.sequence_length}")
logger.info(f"Latency: {latency} s")
logger.info(f"Throughput: {throughput} tps")
return
def profile_fn(args, fn, inputs, inputs_type):
# Filename prefix format:
# "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
filename = None
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
# Profile PyTorch kernels
with profile( # noqa: SIM117
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
) as prof:
with record_function("model_inference"):
fn(inputs)
prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
filename = os.path.join(args.log_folder, f"{prefix}.log")
with open(filename, "w") as f:
f.write(prof_data)
else:
# Profile ORT kernels
fn(inputs)
# Set new log name for ORT profile log generated
filename = f"{prefix}.json"
return filename
def measure_fn(args, fn, inputs):
# Measure CPU usage
pid = os.getpid()
process = psutil.Process(pid)
process.cpu_percent(interval=0.1)
fn(inputs)
if args.rank == 0:
logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%")
# Measure memory usage
gc.collect()
torch.cuda.empty_cache()
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))
# Flush output so memory usage is printed
sys.stdout.flush()
def run_hf_inference(args, init_inputs, iter_inputs, model):
# Inference steps to measure
def get_logits(inputs):
# Inference pass without decoding
outputs = model(**inputs)
return outputs
# Examples of other inference steps that can be measured:
# To use, uncomment the function and assign it to `generate_fn`
# def get_pred_ids(inputs):
# # Inference pass with predicted token ids generation
# predicted_ids = model.generate(**inputs)
# return predicted_ids
# def gen_and_dec(inputs):
# # Inference pass with generation and decoding
# predicted_ids = get_pred_ids(inputs)
# transcription = []
# for bs in range(args.batch_size):
# for rs in range(args.num_return_sequences):
# transcription.append(
# args.tokenizer.batch_decode(
# predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True
# )[0]
# )
# return transcription
generate_fn = get_logits
if args.benchmark_type == "hf-pt-compile":
# Run forward pass once with each set of inputs to process through Dynamo
generate_fn(init_inputs)
generate_fn(iter_inputs)
if args.profile:
new_logname = profile_fn(args, generate_fn, init_inputs, "prompt")
if args.benchmark_type == "hf-ort":
# Turn profiling off to stop appending to log
old_logname = model.decoder.session.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
new_logname = profile_fn(args, generate_fn, iter_inputs, "token")
if args.benchmark_type == "hf-ort":
# Turn profiling off to stop appending to log
old_logname = model.decoder_with_past.session.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# PyTorch evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
time_fn(args, generate_fn, init_inputs)
measure_fn(args, generate_fn, init_inputs)
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
time_fn(args, generate_fn, iter_inputs)
measure_fn(args, generate_fn, iter_inputs)
def run_ort_inference(args, init_inputs, iter_inputs, model):
def prepare_ort_inputs(inputs, kv_cache_ortvalues):
# Verify model inputs
inputs = verify_ort_inputs(model, inputs)
# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues
)
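            # Stash the IO binding on args so time_fn can synchronize inputs/outputs around each timed run.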
setattr(args, "io_binding", io_binding) # noqa: B010
return io_binding, kv_cache_ortvalues
return inputs, kv_cache_ortvalues
def with_io_binding(io_binding):
# Inference pass with IO binding
model.run_with_iobinding(io_binding)
def without_io_binding(inputs):
# Inference pass without IO binding
outputs = model.run(None, inputs)
return outputs
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
kv_cache_ortvalues = {}
if args.profile:
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
# Turn profiling off to stop appending to log file
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
# Re-initialize model for new log file instead of appending to old log file
model = get_model(args)
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
# Turn profiling off to stop appending to log
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# ORT evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_init_inputs)
measure_fn(args, generate_fn, ort_init_inputs)
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_iter_inputs)
measure_fn(args, generate_fn, ort_iter_inputs)
def run_inference(args, init_inputs, iter_inputs, model):
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
run_hf_inference(args, init_inputs, iter_inputs, model)
elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
run_ort_inference(args, init_inputs, iter_inputs, model)
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
def get_args(rank=0):
parser = argparse.ArgumentParser()
parser.add_argument(
"-bt",
"--benchmark-type",
type=str,
required=True,
choices=[
"hf-pt-eager",
"hf-pt-compile",
"hf-ort",
"ort-msft",
"ort-convert-to-onnx",
],
)
parser.add_argument(
"-m",
"--model-name",
type=str,
required=True,
help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
)
parser.add_argument(
"-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model"
)
# Args for choosing the model
parser.add_argument(
"-p",
"--precision",
required=True,
type=str,
default="fp32",
choices=["int4", "int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)
parser.add_argument(
"--hf-pt-dir-path",
type=str,
default="",
help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
)
parser.add_argument(
"--hf-ort-dir-path",
type=str,
default="",
help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)",
)
parser.add_argument(
"--ort-model-path",
type=str,
default="",
help="Path to ONNX model",
)
# Args for running and evaluating the model
parser.add_argument(
"-b",
"--batch-sizes",
default="1 2",
)
parser.add_argument(
"-s",
"--sequence-lengths",
default="32 64 128 256 512",
)
parser.add_argument(
"-d",
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cpu", "cuda", "rocm"],
)
parser.add_argument("-id", "--device-id", type=int, default=0)
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
parser.add_argument("-n", "--num-runs", type=int, default=10)
parser.add_argument("--seed", type=int, default=2)
# Args for decoding logic
parser.add_argument("--max-length", type=int, default=32)
parser.add_argument("--num-return-sequences", type=int, default=1)
# Args for accessing detailed info
parser.add_argument("--profile", default=False, action="store_true")
parser.add_argument(
"--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
)
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
parser.add_argument("--verbose", default=False, action="store_true")
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
parser.add_argument(
"--cache-dir",
type=str,
required=True,
default="./model_cache",
help="Cache dir where Hugging Face files are stored",
)
args = parser.parse_args()
# Set seed properties
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Set runtime properties
if "ort" in args.benchmark_type:
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": rank})
elif args.execution_provider == "ROCMExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": rank})
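            # PyTorch uses the "cuda" device string for ROCm builds as well.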
args.device = "cuda"
# Check that paths have been specified for any benchmarking with ORT
if args.benchmark_type == "hf-ort":
assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
args.batch_sizes = args.batch_sizes.split(" ")
args.sequence_lengths = args.sequence_lengths.split(" ")
    # Use FP32 precision for FP32 models, INT8 models, and INT4 CPU models; use FP16 precision for FP16 models and INT4 GPU models
args.precision = (
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
)
# Check that only one (batch_size, sequence_length) combination is set for profiling
if args.profile:
assert (
len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1
), "Please provide only one (batch_size, sequence_length) combination for profiling"
return args
def main():
rank = get_rank()
world_size = get_size()
args = get_args(rank)
setup_logger(args.verbose)
logger.info(args.__dict__)
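    # Let cuDNN autotune and cache the fastest kernels for the fixed input shapes used here.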
torch.backends.cudnn.benchmark = True
args.rank = rank
args.world_size = world_size
tokenizer = AutoTokenizer.from_pretrained(
args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
)
config = AutoConfig.from_pretrained(
args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
)
target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
use_fp16 = args.precision == "fp16"
setattr(args, "tokenizer", tokenizer) # noqa: B010
setattr(args, "config", config) # noqa: B010
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "use_fp16", use_fp16) # noqa: B010
# Get model and model info
model = get_model(args)
ort_model_inputs_len = get_ort_model_inputs_len(args, model)
# Check if past_present_share_buffer can be enabled (only for FP16 models with GQA)
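    # Buffer sharing pre-allocates the KV cache at max_seq_len so each generation step reuses the same
    # past/present buffers instead of allocating new present tensors.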
if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False)
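        # External weight data is not needed just to inspect the graph for GroupQueryAttention nodes.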
gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))
use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
setattr(args, "use_buffer_share", use_buffer_share) # noqa: B010
else:
setattr(args, "use_buffer_share", False) # noqa: B010
# Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
if args.rank == 0:
logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
setattr(args, "batch_size", int(batch_size)) # noqa: B010
setattr(args, "sequence_length", int(sequence_length)) # noqa: B010
init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len)
run_inference(args, init_inputs, iter_inputs, model)
if __name__ == "__main__":
main()
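# Example invocations (illustrative; adjust model names, paths, and devices for your environment):
#   python benchmark.py -bt hf-pt-eager -m meta-llama/Llama-2-7b-hf -p fp16 -d cuda --cache-dir ./model_cache
#   python benchmark.py -bt ort-convert-to-onnx -m meta-llama/Llama-2-7b-hf -p fp16 -d cuda \
#       --ort-model-path ./llama2-7b-fp16/rank_{}_llama2-7b-fp16.onnx --cache-dir ./model_cache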