diff --git a/scripts/throughput_benchmarks.py b/scripts/throughput_benchmarks.py
index c689d8cc5..d67614a5f 100644
--- a/scripts/throughput_benchmarks.py
+++ b/scripts/throughput_benchmarks.py
@@ -221,6 +221,34 @@ def generate_output_token_counts(mean, std, num, input_token_count):
     return output
 
 
+def generate_output_token_counts_from_existing(
+    distribution: List[int], num: int, input_token_count: int
+):
+    assert len(distribution) > 0, "Can't have a distribution with 0 tokens"
+    output = []
+    # Sample without replacement so that we don't have as much variance
+    for _ in range(num // len(distribution)):
+        random.shuffle(distribution)
+        output.extend(distribution)
+    random.shuffle(distribution)
+    output.extend(distribution[: num % len(distribution)])
+    assert len(output) == num
+
+    for i in range(len(output)):
+        output[i] = min(output[i], MAX_CONTEXT_WINDOW - input_token_count)
+    return output
+
+
+def read_distribution_from_file(fpath: str):
+    # Assumes the distribution is some json-formatted string that represents a list
+    try:
+        with open(fpath, "r") as fin:
+            return json.load(fin)
+    except FileNotFoundError:
+        print("File not found. Exiting.")
+        raise
+
+
 def run_benchmark(
     model: str,
     framework: InferenceFramework,
@@ -231,17 +259,23 @@ def run_benchmark(
     concurrency: int,
     verbose: bool,
     local_port: int,
+    response_token_count_distribution: Optional[List] = None,
 ):
     prompt = generate_prompt(config.input_token_count, hf_model)
 
     prompt_num_tokens = config.input_token_count
 
-    output_token_counts = generate_output_token_counts(
-        config.output_token_count_mean,
-        config.output_token_count_std,
-        num_trials,
-        config.input_token_count,
-    )
+    if response_token_count_distribution is not None:
+        output_token_counts = generate_output_token_counts_from_existing(
+            response_token_count_distribution, num_trials, config.input_token_count
+        )
+    else:
+        output_token_counts = generate_output_token_counts(
+            config.output_token_count_mean,
+            config.output_token_count_std,
+            num_trials,
+            config.input_token_count,
+        )
 
     start = time.time()
     results = send_requests(
@@ -352,10 +386,18 @@ def run_benchmarks(
     verbose: bool = False,
     hf_model: Optional[str] = None,
     local_port: int = 5005,
+    response_token_count_distribution_file: Optional[str] = None,
 ):
     """Run benchmarks."""
     all_statistics = []
     config = BenchmarkConfig(input_token_count, output_token_count_mean)
+
+    response_token_count_distribution = None
+    if response_token_count_distribution_file is not None:
+        response_token_count_distribution = read_distribution_from_file(
+            response_token_count_distribution_file
+        )
+
     try:
         if verbose:
             print(f"Running benchmark for config {config}")
@@ -375,6 +417,7 @@ def run_benchmarks(
             concurrency,
             verbose,
             local_port,
+            response_token_count_distribution,
         )
         all_statistics.append(statistics)
     except Exception:
@@ -404,6 +447,7 @@ def run_benchmarks_concurrency_range(
     verbose: bool = False,
     hf_model: Optional[str] = None,
     local_port: int = 5005,
+    response_token_count_distribution_file: Optional[str] = None,
 ):
     if output_file is not None:
         # Create empty file
@@ -422,6 +466,7 @@ def run_benchmarks_concurrency_range(
             verbose,
             hf_model,
             local_port,
+            response_token_count_distribution_file,
         )
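
For context on how the new sampler behaves, here is a small standalone sketch (not part of the patch) of the same tile-and-shuffle approach used by generate_output_token_counts_from_existing; the file name and the numeric values are illustrative assumptions, and a fixed constant stands in for MAX_CONTEXT_WINDOW from the benchmark script:

    import json
    import random

    # Illustrative stand-ins; the real values come from the benchmark config.
    MAX_CONTEXT_WINDOW = 4096
    input_token_count = 512
    num_trials = 8

    # The distribution file is assumed to be a JSON list of output token counts,
    # e.g. collected from real traffic. The file name here is hypothetical.
    with open("response_token_counts.json", "r") as fin:
        distribution = json.load(fin)  # e.g. [128, 256, 512, 384, 96]

    # Tile the distribution to cover num_trials, shuffling on each pass so every
    # value appears roughly equally often (sampling without replacement).
    output_token_counts = []
    for _ in range(num_trials // len(distribution)):
        random.shuffle(distribution)
        output_token_counts.extend(distribution)
    random.shuffle(distribution)
    output_token_counts.extend(distribution[: num_trials % len(distribution)])

    # Clamp each count so prompt + response never exceeds the context window.
    output_token_counts = [
        min(count, MAX_CONTEXT_WINDOW - input_token_count)
        for count in output_token_counts
    ]
    print(output_token_counts)

In the patched script this path is reached by passing response_token_count_distribution_file to run_benchmarks or run_benchmarks_concurrency_range, which reads the JSON list via read_distribution_from_file and forwards it to run_benchmark in place of the mean/std-based generator.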