@@ -1,2 +1,2 @@
-vllm==0.3.3
+vllm==0.4.0.post1
 pydantic>=2.0
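
The version bump matters for the rest of this diff: from vLLM 0.4.x onward, each per-token entry in CompletionOutput.logprobs is a Logprob object rather than a plain float, so values have to be unwrapped via .logprob before they can be JSON-serialized. A minimal sketch of the difference, assuming vLLM >= 0.4 (values are illustrative, not part of the diff):

    from vllm.sequence import Logprob

    old_style = [{4: -0.1}]             # vLLM 0.3.x: token_id -> float
    new_style = [{4: Logprob(-0.1)}]    # vLLM 0.4.x: token_id -> Logprob object
    assert new_style[0][4].logprob == -0.1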
23 changes: 18 additions & 5 deletions model-engine/model_engine_server/inference/vllm/vllm_server.py
@@ -4,7 +4,7 @@
 import signal
 import subprocess
 import traceback
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Dict, List, Optional

 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
@@ -13,7 +13,9 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest
 from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
+from vllm.outputs import CompletionOutput
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import Logprob
 from vllm.utils import random_uuid

 TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -76,13 +78,12 @@ async def generate(request: Request) -> Response:
     async def stream_results() -> AsyncGenerator[str, None]:
         last_output_text = ""
         async for request_output in results_generator:
+            log_probs = format_logprobs(request_output)
             ret = {
                 "text": request_output.outputs[-1].text[len(last_output_text) :],
                 "count_prompt_tokens": len(request_output.prompt_token_ids),
                 "count_output_tokens": len(request_output.outputs[0].token_ids),
-                "log_probs": (
-                    request_output.outputs[0].logprobs[-1] if sampling_params.logprobs else None
-                ),
+                "log_probs": log_probs[-1] if log_probs and sampling_params.logprobs else None,
                 "finished": request_output.finished,
             }
             last_output_text = request_output.outputs[-1].text
@@ -116,7 +117,7 @@ async def abort_request() -> None:
         "text": final_output.outputs[0].text,
         "count_prompt_tokens": len(final_output.prompt_token_ids),
         "count_output_tokens": len(final_output.outputs[0].token_ids),
-        "log_probs": final_output.outputs[0].logprobs,
+        "log_probs": format_logprobs(final_output),
         "tokens": tokens,
     }
     return Response(content=json.dumps(ret))
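
For the non-streaming path, format_logprobs unwraps the Logprob objects so the body handed to json.dumps contains only JSON-serializable floats. An illustrative sketch of the resulting dict, with made-up values (not part of the diff); note that json.dumps renders the integer token ids in log_probs as string keys on the wire:

    # Hypothetical response payload for a two-token completion with logprobs requested.
    ret = {
        "text": " world",
        "count_prompt_tokens": 3,
        "count_output_tokens": 2,
        "log_probs": [{4: -0.1}, {5: -0.2}],  # plain floats after format_logprobs
        "tokens": [" wor", "ld"],
    }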
@@ -166,6 +167,18 @@ def debug(sig, frame):
     i.interact(message)


+def format_logprobs(request_output: CompletionOutput) -> Optional[List[Dict[int, float]]]:
+    """Given a request output, format the logprobs if they exist."""
+    output_logprobs = request_output.outputs[0].logprobs
+    if output_logprobs is None:
+        return None
+
+    def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]:
+        return {k: v.logprob for k, v in logprobs.items()}
+
+    return [extract_logprobs(logprobs) for logprobs in output_logprobs]
+
+
 if __name__ == "__main__":
     check_unknown_startup_memory_usage()
     parser = argparse.ArgumentParser()
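A minimal usage sketch of the new helper, using a hypothetical SimpleNamespace stand-in for the RequestOutput that AsyncLLMEngine actually produces (illustrative only, not part of the diff):

    from types import SimpleNamespace

    from vllm.sequence import Logprob

    # Hypothetical stand-in: only the attributes format_logprobs reads are provided.
    fake_output = SimpleNamespace(
        outputs=[SimpleNamespace(logprobs=[{4: Logprob(-0.1)}, {5: Logprob(-0.2)}])]
    )
    assert format_logprobs(fake_output) == [{4: -0.1}, {5: -0.2}]
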
14 changes: 11 additions & 3 deletions model-engine/tests/unit/inference/conftest.py
@@ -58,29 +58,37 @@ def create_batch_completions_request_content():

 @pytest.fixture
 def create_vllm_request_outputs():
+    class Logprob:
+        """mock, from https://github.com/vllm-project/vllm/blob/v0.4.1/vllm/sequence.py#L18"""
+
+        def __init__(self, logprob: float):
+            self.logprob = logprob
+
     mock_vllm_request_output1 = MagicMock()
     mock_vllm_request_output1.outputs = [
         MagicMock(text="text1"),
     ]
     mock_vllm_request_output1.prompt_token_ids = [1, 2, 3]
     mock_vllm_request_output1.outputs[0].token_ids = [4]
-    mock_vllm_request_output1.outputs[0].logprobs = [{4: 0.1}]
+    mock_vllm_request_output1.outputs[0].logprobs = [{4: Logprob(0.1)}]

     mock_vllm_request_output2 = MagicMock()
     mock_vllm_request_output2.outputs = [
         MagicMock(text="text1 text2"),
     ]
     mock_vllm_request_output2.prompt_token_ids = [1, 2, 3]
     mock_vllm_request_output2.outputs[0].token_ids = [4, 5]
-    mock_vllm_request_output2.outputs[0].logprobs = [{4: 0.1, 5: 0.2}]
+    mock_vllm_request_output2.outputs[0].logprobs = [{4: Logprob(0.1), 5: Logprob(0.2)}]

     mock_vllm_request_output3 = MagicMock()
     mock_vllm_request_output3.outputs = [
         MagicMock(text="text1 text2 text3"),
     ]
     mock_vllm_request_output3.prompt_token_ids = [1, 2, 3]
     mock_vllm_request_output3.outputs[0].token_ids = [4, 5, 6]
-    mock_vllm_request_output3.outputs[0].logprobs = [{4: 0.1, 5: 0.2, 6: 0.3}]
+    mock_vllm_request_output3.outputs[0].logprobs = [
+        {4: Logprob(0.1), 5: Logprob(0.2), 6: Logprob(0.3)}
+    ]
     return [mock_vllm_request_output1, mock_vllm_request_output2, mock_vllm_request_output3]

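Since format_logprobs only reads the .logprob attribute of each value, the fixture's local Logprob mock is enough to exercise it. A hypothetical test consuming this fixture might look like the following sketch (test name and import path are illustrative, not part of the diff):

    def test_format_logprobs(create_vllm_request_outputs):
        # Import path assumed from the file layout shown in this PR.
        from model_engine_server.inference.vllm.vllm_server import format_logprobs

        outputs = create_vllm_request_outputs
        assert format_logprobs(outputs[0]) == [{4: 0.1}]
        assert format_logprobs(outputs[2]) == [{4: 0.1, 5: 0.2, 6: 0.3}]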