Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
ccc7cf7
fix up some issues I encountered while trying Mixtral
seanshi-scale Apr 10, 2024
2bd356e
add changes to tput bm script to get trt llm to work locally
seanshi-scale Apr 10, 2024
5ff1693
add num completion tokens for trt llm
seanshi-scale Apr 10, 2024
a1a1d16
try updating the trt triton_model_repo code to tag: v0.8.0 from tenso…
seanshi-scale Apr 26, 2024
258bc99
change some outputs
seanshi-scale Apr 27, 2024
28fc2b5
throw in random values for tensorrt_llm_bls
seanshi-scale Apr 27, 2024
eb173f8
at this point, the built triton image respects stop sequences, but no…
seanshi-scale Apr 27, 2024
99795d5
add in a hack to get streaming requests to not drop spaces hopefully
seanshi-scale Apr 29, 2024
563b128
try out another somewhat-hack to get the tokenizer to respect stop se…
seanshi-scale Apr 30, 2024
0121939
add supported model
seanshi-scale May 2, 2024
0bc9376
fix bug with new trt version
seanshi-scale May 3, 2024
3ba38d0
new test + add a new case for TRT response handling
seanshi-scale May 3, 2024
f5bcd4c
comment + handle float output_log_probs
seanshi-scale May 3, 2024
2b2075f
add test
seanshi-scale May 3, 2024
902a2c9
black
seanshi-scale May 9, 2024
c5b0bd1
readme
seanshi-scale May 9, 2024
fd0130f
ruff
seanshi-scale May 9, 2024
62b19d4
add some comments
seanshi-scale May 9, 2024
10f6e4e
Merge branch 'main' into seanshi/20240409-tensorrtllm-improvements
seanshi-scale May 9, 2024
498fca7
don't need to decode the entire thing twice
seanshi-scale May 14, 2024
3ef8fca
Merge branch 'main' into seanshi/20240409-tensorrtllm-improvements
seanshi-scale May 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@
"llama-2-70b-chat",
]
),
LLMInferenceFramework.TENSORRT_LLM: set(["llama-2-7b"]),
LLMInferenceFramework.TENSORRT_LLM: set(
["llama-2-7b", "mixtral-8x7b", "mixtral-8x7b-instruct"]
),
}

_SUPPORTED_QUANTIZATIONS: Dict[LLMInferenceFramework, List[Quantization]] = {
Expand Down Expand Up @@ -1467,11 +1469,28 @@ def model_output_to_completion_output(
num_prompt_tokens = count_tokens(
prompt, model_content.model_name, self.tokenizer_repository
)
return CompletionOutput(
if "token_ids" in model_output:
# TensorRT 23.10 has this field, TensorRT 24.03 does not
# For backwards compatibility with pre-2024/05/02
num_completion_tokens = len(model_output["token_ids"]) - num_prompt_tokens
# Output is "<s> prompt output"
text=model_output["text_output"][(len(prompt) + 4) :],
text = model_output["text_output"][(len(prompt) + 4) :]
elif "output_log_probs" in model_output:
# TensorRT 24.01 + surrounding code.
# For some reason TRT returns output_log_probs as either a list or a float
# Also the log probs don't look right, so returning log-probs is still broken
num_completion_tokens = (
len(model_output["output_log_probs"])
if type(model_output["output_log_probs"]) == list
else 1
)
# Output is just "output". See `exclude_input_in_output` inside of
# inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt
text = model_output["text_output"]
return CompletionOutput(
text=text,
num_prompt_tokens=num_prompt_tokens,
num_completion_tokens=len(model_output["token_ids"]) - num_prompt_tokens,
num_completion_tokens=num_completion_tokens,
)
else:
raise EndpointUnsupportedInferenceTypeException(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM nvcr.io/nvidia/tritonserver:24.03-trtllm-python-py3

COPY requirements.txt /workspace/requirements.txt
WORKDIR /workspace
Expand Down
14 changes: 14 additions & 0 deletions model-engine/model_engine_server/inference/tensorrt-llm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Preparing the model weights/tokenizers

Our TensorRT-LLM docker image expects weights to live in s3/other blob store with the following directory structure:

root/
model_tokenizer/
<everything in a HF directory other than the weights themselves>
model_weights/
config.json
rank<i>.engine

You can obtain `model_weights` by building a TRT-LLM engine via the directions found on Nvidia's site (e.g. https://github.com/NVIDIA/TensorRT-LLM/blob/main/README.md#installation, https://github.com/NVIDIA/TensorRT-LLM/blob/v0.8.0/examples/llama/convert_checkpoint.py)

The inference image is built via the Dockerfile in the same directory as this readme.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ def parse_arguments():
"--world_size", type=int, default=1, help="world size, only support tensor parallelism now"
)
parser.add_argument("--tritonserver", type=str, default="/opt/tritonserver/bin/tritonserver")
parser.add_argument(
"--http-address",
type=str,
default="ipv6:[::1]",
help="Default HTTP address to ipv6:[::1].",
)
parser.add_argument(
"--http-port",
type=int,
Expand All @@ -20,14 +26,16 @@ def parse_arguments():
return parser.parse_args()


def get_cmd(world_size, tritonserver, model_repo, http_address, http_port):
    """Build the ``mpirun`` command line that launches one tritonserver rank
    per model-parallel worker.

    Args:
        world_size: number of tensor-parallel ranks; one tritonserver process
            is launched per rank under a single mpirun invocation.
        tritonserver: path to the tritonserver binary.
        model_repo: path to the Triton model repository.
        http_address: address tritonserver binds its HTTP endpoint to
            (e.g. ``ipv6:[::1]``).
        http_port: port for the HTTP endpoint.

    Returns:
        The full shell command string (trailing `` : `` separators are how
        mpirun chains multiple application contexts).
    """
    cmd = "mpirun --allow-run-as-root "
    for i in range(world_size):
        # Each rank gets a distinct shm-region prefix so the Python backend
        # instances of different ranks do not collide on shared memory.
        cmd += f" -n 1 {tritonserver} --model-repository={model_repo} --http-address {http_address} --http-port {http_port} --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix{i}_ : "
    return cmd


if __name__ == "__main__":
    args = parse_arguments()
    # Build the mpirun launch command from CLI args; http_address was made
    # configurable (previously hard-coded to ipv6:[::1] inside get_cmd).
    cmd = get_cmd(
        int(args.world_size), args.tritonserver, args.model_repo, args.http_address, args.http_port
    )
    # shell=True is required: the command string contains mpirun's " : "
    # application-context separators and is constructed from trusted CLI args.
    subprocess.call(cmd, shell=True)
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
sentencepiece==0.1.99
protobuf==4.24.4
torch==2.2.2
Loading