Skip to content
Merged
29 changes: 29 additions & 0 deletions model-engine/model_engine_server/common/dtos/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,30 @@ class CreateBatchCompletionsModelConfig(BaseModel):
"""


class ToolConfig(BaseModel):
"""
Configuration for tool use.
NOTE: this config is highly experimental and signature will change significantly in future iterations.
"""

name: str
"""
Name of the tool to use for the batch inference.
"""
max_iterations: Optional[int] = 10
"""
Maximum number of iterations to run the tool.
"""
execution_timeout_seconds: Optional[int] = 60
"""
Maximum runtime of the tool in seconds.
"""
should_retry_on_error: Optional[bool] = True
"""
Whether to retry the tool on error.
"""


class CreateBatchCompletionsRequest(BaseModel):
"""
Request object for batch completions.
Expand Down Expand Up @@ -456,6 +480,11 @@ class CreateBatchCompletionsRequest(BaseModel):
"""
Maximum runtime of the batch inference in seconds. Default to one day.
"""
tool_config: Optional[ToolConfig] = None
"""
Configuration for tool use.
NOTE: this config is highly experimental and signature will change significantly in future iterations.
"""


class CreateBatchCompletionsResponse(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2275,6 +2275,11 @@ async def execute(
hardware.gpus = max(hardware.gpus, request.model_config.num_shards)
request.model_config.num_shards = hardware.gpus

if request.tool_config and request.tool_config.name != "code_evaluator":
raise ObjectHasInvalidValueException(
"Only code_evaluator tool is supported for batch completions."
)

batch_bundle = await self.create_batch_job_bundle(user, request, hardware)

validate_resource_requests(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import json

COMPLETION_PROMPT1 = """\
FYI: you can write code like this:
```python
import math
print(math.sqrt(2))
```
1.41...
>>>

For reference, the third digit of 4.32 is 2. Also, use "Final Answer: X" to indicate your final answer.

### Problem:

What is the 4th digit of pi?

### Answer:
```python
import math
print(math.pi)
```
3.141592653589793
>>>

Final Answer: 1

### Problem:

What is the 4th digit of the square root of 2?

### Answer:
"""

COMPLETION_PROMPT2 = """\
FYI: you can write code like this:
```python
import math
print(math.sqrt(2))
```
1.41...
>>>

For reference, the third digit of 4.32 is 2. Also, use "Final Answer: X" to indicate your final answer.

### Problem:

What is the 4th digit of pi?

### Answer:
```python
import math
print(math.pi)
```
3.141592653589793
>>>

Final Answer: 1

### Problem:

What is the 5th digit of the square root of 2?

### Answer:
"""

data = {
"prompts": [
COMPLETION_PROMPT1,
COMPLETION_PROMPT2,
"what is deep learning",
],
"max_new_tokens": 100,
"temperature": 0.0,
"return_token_log_probs": True,
"stop_sequences": ["</s>", "\n### Problem:\n", ">>>\n"],
}

json.dump(data, open("sample_data_tool.json", "w"))
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ vllm==0.2.5
pydantic==1.10.13
boto3==1.34.15
smart-open==6.4.0
ddtrace==2.4.0
ddtrace==2.4.0
docker==7.0.0
func-timeout==4.3.5
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"input_data_path":"./sample_data_tool.json",
"output_data_path":"./sample_output_tool.json",
"model_config":{
"model":"mistral-7b",
"checkpoint_path":"s3://scale-ml/models/mistral-7b",
"num_shards": 1,
"labels": {"team": "my_team"}
},
"data_parallelism":2,
"tool_config": {
"name": "code_evaluator"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"prompts": [
"FYI: you can write code like this: \n```python\nimport math\nprint(math.sqrt(2))\n```\n1.41...\n>>>\n\nFor reference, the third digit of 4.32 is 2. Also, use \"Final Answer: X\" to indicate your final answer.\n\n### Problem:\n\nWhat is the 4th digit of pi?\n\n### Answer:\n```python\nimport math\nprint(math.pi)\n```\n3.141592653589793\n>>>\n\nFinal Answer: 1\n\n### Problem:\n\nWhat is the 4th digit of the square root of 2?\n\n### Answer: \n",
"FYI: you can write code like this: \n```python\nimport math\nprint(math.sqrt(2))\n```\n1.41...\n>>>\n\nFor reference, the third digit of 4.32 is 2. Also, use \"Final Answer: X\" to indicate your final answer.\n\n### Problem:\n\nWhat is the 4th digit of pi?\n\n### Answer:\n```python\nimport math\nprint(math.pi)\n```\n3.141592653589793\n>>>\n\nFinal Answer: 1\n\n### Problem:\n\nWhat is the 5th digit of the square root of 2?\n\n### Answer: \n",
"what is deep learning"
],
"max_new_tokens": 100,
"temperature": 0.0,
"return_token_log_probs": true,
"stop_sequences": [
"</s>",
"\n### Problem:\n",
">>>\n"
]
}
Loading