diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt
index 3c1cf8512..d0e331f4c 100644
--- a/model-engine/model_engine_server/inference/vllm/requirements.txt
+++ b/model-engine/model_engine_server/inference/vllm/requirements.txt
@@ -1,2 +1,2 @@
-vllm==0.3.3
+vllm==0.4.0.post1
 pydantic>=2.0
diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py
index d9b502efc..c7ef4b434 100644
--- a/model-engine/model_engine_server/inference/vllm/vllm_server.py
+++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py
@@ -4,7 +4,7 @@
 import signal
 import subprocess
 import traceback
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Dict, List, Optional
 
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
@@ -13,7 +13,9 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest
 from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
+from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import Logprob
 from vllm.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -76,13 +78,12 @@ async def generate(request: Request) -> Response:
     async def stream_results() -> AsyncGenerator[str, None]:
         last_output_text = ""
         async for request_output in results_generator:
+            log_probs = format_logprobs(request_output)
             ret = {
                 "text": request_output.outputs[-1].text[len(last_output_text) :],
                 "count_prompt_tokens": len(request_output.prompt_token_ids),
                 "count_output_tokens": len(request_output.outputs[0].token_ids),
-                "log_probs": (
-                    request_output.outputs[0].logprobs[-1] if sampling_params.logprobs else None
-                ),
+                "log_probs": log_probs[-1] if log_probs and sampling_params.logprobs else None,
                 "finished": request_output.finished,
             }
             last_output_text = request_output.outputs[-1].text
@@ -116,7 +117,7 @@ async def abort_request() -> None:
         "text": final_output.outputs[0].text,
         "count_prompt_tokens": len(final_output.prompt_token_ids),
         "count_output_tokens": len(final_output.outputs[0].token_ids),
-        "log_probs": final_output.outputs[0].logprobs,
+        "log_probs": format_logprobs(final_output),
         "tokens": tokens,
     }
     return Response(content=json.dumps(ret))
@@ -166,6 +167,18 @@ def debug(sig, frame):
     i.interact(message)
 
 
+def format_logprobs(request_output: RequestOutput) -> Optional[List[Dict[int, float]]]:
+    """Given a request output, format the logprobs if they exist."""
+    output_logprobs = request_output.outputs[0].logprobs
+    if output_logprobs is None:
+        return None
+
+    def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]:
+        return {k: v.logprob for k, v in logprobs.items()}
+
+    return [extract_logprobs(logprobs) for logprobs in output_logprobs]
+
+
 if __name__ == "__main__":
     check_unknown_startup_memory_usage()
     parser = argparse.ArgumentParser()
diff --git a/model-engine/tests/unit/inference/conftest.py b/model-engine/tests/unit/inference/conftest.py
index 26a3a0a3a..20c4aae85 100644
--- a/model-engine/tests/unit/inference/conftest.py
+++ b/model-engine/tests/unit/inference/conftest.py
@@ -58,13 +58,19 @@ def create_batch_completions_request_content():
 
 @pytest.fixture
 def create_vllm_request_outputs():
+    class Logprob:
+        """mock, from https://github.com/vllm-project/vllm/blob/v0.4.1/vllm/sequence.py#L18"""
+
+        def __init__(self, logprob: float):
+            self.logprob = logprob
+
     mock_vllm_request_output1 = MagicMock()
     mock_vllm_request_output1.outputs = [
         MagicMock(text="text1"),
     ]
     mock_vllm_request_output1.prompt_token_ids = [1, 2, 3]
     mock_vllm_request_output1.outputs[0].token_ids = [4]
-    mock_vllm_request_output1.outputs[0].logprobs = [{4: 0.1}]
+    mock_vllm_request_output1.outputs[0].logprobs = [{4: Logprob(0.1)}]
 
     mock_vllm_request_output2 = MagicMock()
     mock_vllm_request_output2.outputs = [
@@ -72,7 +78,7 @@ def create_vllm_request_outputs():
     ]
     mock_vllm_request_output2.prompt_token_ids = [1, 2, 3]
     mock_vllm_request_output2.outputs[0].token_ids = [4, 5]
-    mock_vllm_request_output2.outputs[0].logprobs = [{4: 0.1, 5: 0.2}]
+    mock_vllm_request_output2.outputs[0].logprobs = [{4: Logprob(0.1), 5: Logprob(0.2)}]
 
     mock_vllm_request_output3 = MagicMock()
     mock_vllm_request_output3.outputs = [
@@ -80,7 +86,9 @@ def create_vllm_request_outputs():
     ]
     mock_vllm_request_output3.prompt_token_ids = [1, 2, 3]
     mock_vllm_request_output3.outputs[0].token_ids = [4, 5, 6]
-    mock_vllm_request_output3.outputs[0].logprobs = [{4: 0.1, 5: 0.2, 6: 0.3}]
+    mock_vllm_request_output3.outputs[0].logprobs = [
+        {4: Logprob(0.1), 5: Logprob(0.2), 6: Logprob(0.3)}
+    ]
 
     return [mock_vllm_request_output1, mock_vllm_request_output2, mock_vllm_request_output3]
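
Reviewer note (not part of the patch): in vllm 0.4.x, CompletionOutput.logprobs holds one Dict[int, Logprob] per generated token instead of the plain Dict[int, float] that vllm 0.3.3 returned, so the raw values are no longer JSON-serializable and format_logprobs has to unwrap them before the server can json.dumps the response. A minimal sketch of that unwrapping, using a hypothetical FakeLogprob stand-in rather than the real vllm classes:

    from typing import Dict, List, Optional
    from unittest.mock import MagicMock


    class FakeLogprob:
        """Hypothetical stand-in for vllm.sequence.Logprob; only .logprob matters here."""

        def __init__(self, logprob: float):
            self.logprob = logprob


    def format_logprobs(request_output) -> Optional[List[Dict[int, float]]]:
        """Same unwrapping logic as the format_logprobs added in vllm_server.py."""
        output_logprobs = request_output.outputs[0].logprobs
        if output_logprobs is None:
            return None
        return [{k: v.logprob for k, v in lp.items()} for lp in output_logprobs]


    # One Dict[int, Logprob] per generated token, as vllm 0.4.x produces.
    fake = MagicMock()
    fake.outputs = [MagicMock()]
    fake.outputs[0].logprobs = [{4: FakeLogprob(-0.1)}, {5: FakeLogprob(-0.2)}]
    assert format_logprobs(fake) == [{4: -0.1}, {5: -0.2}]  # plain floats, JSON-safe

    fake.outputs[0].logprobs = None  # logprobs were not requested
    assert format_logprobs(fake) is None

This is also why the conftest mocks wrap their values in a Logprob stand-in: the tests exercise the same float-unwrapping path the server now depends on.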