scripts/throughput_benchmarks.py (51 additions, 6 deletions)
@@ -221,6 +221,34 @@ def generate_output_token_counts(mean, std, num, input_token_count):
     return output
 
 
+def generate_output_token_counts_from_existing(
+    distribution: List[int], num: int, input_token_count: int
+):
+    assert len(distribution) > 0, "Can't have a distribution with 0 tokens"
+    output = []
+    # Sample without replacement so that we don't have as much variance
+    for _ in range(num // len(distribution)):
+        random.shuffle(distribution)
+        output.extend(distribution)
+    random.shuffle(distribution)
+    output.extend(distribution[: num % len(distribution)])
+    assert len(output) == num
+
+    for i in range(len(output)):
+        output[i] = min(output[i], MAX_CONTEXT_WINDOW - input_token_count)
+    return output
+
+
+def read_distribution_from_file(fpath: str):
+    # Assumes the distribution is some json-formatted string that represents a list
+    try:
+        with open(fpath, "r") as fin:
+            return json.load(fin)
+    except FileNotFoundError:
+        print("File not found. Exiting.")
+        raise
+
+
 def run_benchmark(
     model: str,
     framework: InferenceFramework,
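The new sampler tiles whole shuffled copies of the distribution rather than drawing each value independently, which keeps every observed length close to its empirical frequency. A standalone sketch of that scheme (the values are made up, and the real function additionally clamps each count to MAX_CONTEXT_WINDOW - input_token_count):

import random

# Hypothetical output token counts observed in a prior run.
distribution = [100, 400, 900]
num = 8  # number of trials to generate

output = []
for _ in range(num // len(distribution)):  # 8 // 3 = 2 full shuffled passes
    random.shuffle(distribution)
    output.extend(distribution)
random.shuffle(distribution)
output.extend(distribution[: num % len(distribution)])  # 8 % 3 = 2 leftovers
assert len(output) == 8
# Every length appears two or three times, so no single length can dominate
# the way i.i.d. sampling with replacement occasionally would.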
@@ -231,17 +259,23 @@ def run_benchmark(
     concurrency: int,
     verbose: bool,
     local_port: int,
+    response_token_count_distribution: Optional[List] = None,
 ):
     prompt = generate_prompt(config.input_token_count, hf_model)
 
     prompt_num_tokens = config.input_token_count
 
-    output_token_counts = generate_output_token_counts(
-        config.output_token_count_mean,
-        config.output_token_count_std,
-        num_trials,
-        config.input_token_count,
-    )
+    if response_token_count_distribution is not None:
+        output_token_counts = generate_output_token_counts_from_existing(
+            response_token_count_distribution, num_trials, config.input_token_count
+        )
+    else:
+        output_token_counts = generate_output_token_counts(
+            config.output_token_count_mean,
+            config.output_token_count_std,
+            num_trials,
+            config.input_token_count,
+        )
 
     start = time.time()
     results = send_requests(
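read_distribution_from_file expects the path to point at a JSON array of token counts. A hypothetical producer for such a file (the file name and values are placeholders):

import json

# Persist per-request output token counts from an earlier benchmark run so a
# later run can replay the same response-length profile.
response_token_counts = [128, 512, 512, 900, 1300]
with open("response_lengths.json", "w") as fout:
    json.dump(response_token_counts, fout)
# The script reads this back via read_distribution_from_file("response_lengths.json")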
@@ -352,10 +386,18 @@ def run_benchmarks(
     verbose: bool = False,
     hf_model: Optional[str] = None,
     local_port: int = 5005,
+    response_token_count_distribution_file: Optional[str] = None,
 ):
     """Run benchmarks."""
     all_statistics = []
     config = BenchmarkConfig(input_token_count, output_token_count_mean)
+
+    response_token_count_distribution = None
+    if response_token_count_distribution_file is not None:
+        response_token_count_distribution = read_distribution_from_file(
+            response_token_count_distribution_file
+        )
+
     try:
         if verbose:
             print(f"Running benchmark for config {config}")
@@ -375,6 +417,7 @@ def run_benchmarks(
             concurrency,
             verbose,
             local_port,
+            response_token_count_distribution,
         )
         all_statistics.append(statistics)
     except Exception:
@@ -404,6 +447,7 @@ def run_benchmarks_concurrency_range(
     verbose: bool = False,
     hf_model: Optional[str] = None,
     local_port: int = 5005,
+    response_token_count_distribution_file: Optional[str] = None,
 ):
     if output_file is not None:
         # Create empty file
@@ -422,6 +466,7 @@ def run_benchmarks_concurrency_range(
             verbose,
             hf_model,
             local_port,
+            response_token_count_distribution_file,
         )
 
 
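One top-level behavior worth noting: a bad distribution path fails fast, because read_distribution_from_file prints a message and re-raises instead of letting the benchmark fall back to the Gaussian generator. A quick check, assuming the function is importable from the script:

from scripts.throughput_benchmarks import read_distribution_from_file  # assumed import path

try:
    read_distribution_from_file("does_not_exist.json")  # hypothetical path
except FileNotFoundError:
    pass  # expected: "File not found. Exiting." is printed, then the error propagates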