Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
from .hf_api_endpoint import HFAPIEndpoint
from .hf_transformers import HFTransformers
from .llm import LLM
from .minimax import MiniMax
from .mistral_ai import MistralAI
from .openai import OpenAI
from .openai_assistant import OpenAIAssistant
Expand All @@ -105,6 +106,7 @@
"HFAPIEndpoint",
"Together",
"MistralAI",
"MiniMax",
"Anthropic",
"Cohere",
"AI21",
Expand Down
149 changes: 149 additions & 0 deletions src/llms/minimax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import os
from functools import cached_property
from typing import Any, Callable

import openai

from ..utils import ring_utils as ring
from .llm import (
DEFAULT_BATCH_SIZE,
_check_max_new_tokens_possible,
_check_temperature_and_top_p,
)
from .openai import OpenAI

# MiniMax model context lengths and max output lengths
# Total context window (prompt + completion) per model, in tokens.
# Consumed by MiniMax.get_max_context_length(); unknown models fall back
# to 204800 there.
_MINIMAX_CONTEXT_LENGTHS: dict[str, int] = {
    "MiniMax-M2.7": 1000000,
    "MiniMax-M2.7-highspeed": 1000000,
    "MiniMax-M2.5": 204800,
    "MiniMax-M2.5-highspeed": 204800,
}

# Maximum number of tokens a single completion may generate, per model.
# Consumed by MiniMax._get_max_output_length(); unknown models fall back
# to 8192 there.
_MINIMAX_MAX_OUTPUT_LENGTHS: dict[str, int] = {
    "MiniMax-M2.7": 16384,
    "MiniMax-M2.7-highspeed": 16384,
    "MiniMax-M2.5": 8192,
    "MiniMax-M2.5-highspeed": 8192,
}

# Base URL of MiniMax's OpenAI-compatible API endpoint.
_MINIMAX_BASE_URL = "https://api.minimax.io/v1"


class MiniMax(OpenAI):
    """A MiniMax LLM provider that uses MiniMax's OpenAI-compatible API.

    MiniMax provides large language models accessible via an OpenAI-compatible
    API endpoint. Supported models include ``MiniMax-M2.7``,
    ``MiniMax-M2.7-highspeed``, ``MiniMax-M2.5``, and
    ``MiniMax-M2.5-highspeed``.

    Args:
        model_name: The name of the MiniMax model to use.
        system_prompt: An optional system prompt to use.
        api_key: The MiniMax API key. If ``None``, the ``MINIMAX_API_KEY``
            environment variable will be used.
        retry_on_fail: Whether to retry on failure.
        cache_folder_path: The path to the cache folder.
        **kwargs: Additional keyword arguments passed to the OpenAI client.
    """

    def __init__(
        self,
        model_name: str,
        system_prompt: None | str = None,
        api_key: None | str = None,
        retry_on_fail: bool = True,
        cache_folder_path: None | str = None,
        **kwargs,
    ):
        super().__init__(
            model_name=model_name,
            # Always supply a system prompt; fall back to a generic one when
            # the caller does not provide any.
            system_prompt=system_prompt or "You are a helpful assistant.",
            api_key=api_key or os.environ.get("MINIMAX_API_KEY"),
            base_url=_MINIMAX_BASE_URL,
            retry_on_fail=retry_on_fail,
            cache_folder_path=cache_folder_path,
            **kwargs,
        )

    @cached_property
    def client(self) -> openai.OpenAI:
        """An ``openai.OpenAI`` client pointed at the MiniMax endpoint.

        Returns:
            A lazily-constructed, cached client instance.
        """
        other_kwargs: dict[str, Any] = {}
        # Only forward the key when one was resolved, so we do not override
        # the OpenAI client's own argument handling with an explicit ``None``.
        if self.api_key:
            other_kwargs["api_key"] = self.api_key
        return openai.OpenAI(
            base_url=_MINIMAX_BASE_URL,
            **other_kwargs,
            **self.kwargs,
        )

    @ring.lru(maxsize=128)
    def get_max_context_length(self, max_new_tokens: int) -> int:
        """Gets the maximum context length for the model.

        Args:
            max_new_tokens: The maximum number of tokens that can be generated.

        Returns:
            The maximum number of prompt tokens available after reserving room
            for generation and chat-format overhead.
        """
        # Use known context lengths for MiniMax models; default to the
        # smallest known window (204800) for unrecognized model names.
        max_context_length = _MINIMAX_CONTEXT_LENGTHS.get(self.model_name, 204800)
        # Account for chat format tokens (system prompt + message framing).
        # The 4 * 3 term budgets framing tokens per message — TODO confirm
        # against the tokenizer's actual chat template overhead.
        format_tokens = 4 * 3 + self.count_tokens(self.system_prompt or "")
        return max_context_length - max_new_tokens - format_tokens

    def _get_max_output_length(self) -> None | int:
        """Returns the model's maximum output length in tokens (8192 for
        unrecognized model names)."""
        return _MINIMAX_MAX_OUTPUT_LENGTHS.get(self.model_name, 8192)

    def _run_batch(
        self,
        max_length_func: Callable[[list[str]], int],
        inputs: list[str],
        max_new_tokens: None | int = None,
        temperature: float = 1.0,
        top_p: float = 0.0,
        n: int = 1,
        stop: None | str | list[str] = None,
        repetition_penalty: None | float = None,
        logit_bias: None | dict[int, float] = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
        seed: None | int = None,
        **kwargs,
    ) -> list[str] | list[list[str]]:
        """Runs a batch of prompts, clamping ``temperature`` into MiniMax's
        accepted range of (0.0, 1.0] before delegating to the parent
        implementation.
        """
        # MiniMax requires temperature in (0.0, 1.0]. Clamp anything at or
        # below zero (not just exactly 0.0 — a negative value would otherwise
        # be sent through unmodified) up to 0.01, and cap values above 1.0.
        if temperature <= 0.0:
            temperature = 0.01
        elif temperature > 1.0:
            temperature = 1.0

        return super()._run_batch(
            max_length_func=max_length_func,
            inputs=inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            n=n,
            stop=stop,
            repetition_penalty=repetition_penalty,
            logit_bias=logit_bias,
            batch_size=batch_size,
            seed=seed,
            **kwargs,
        )

    @cached_property
    def model_card(self) -> None | str:
        # URL of MiniMax's published model documentation.
        return "https://platform.minimaxi.com/document/Models"

    @cached_property
    def license(self) -> None | str:
        # URL of MiniMax's terms of service.
        return "https://platform.minimaxi.com/document/Terms%20of%20service"

    @cached_property
    def citation(self) -> None | list[str]:
        # No published citation for MiniMax models.
        return None


__all__ = ["MiniMax"]
216 changes: 216 additions & 0 deletions src/tests/llms/test_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
GoogleAIStudio,
HFAPIEndpoint,
HFTransformers,
MiniMax,
MistralAI,
OpenAI,
OpenAIAssistant,
Expand Down Expand Up @@ -3162,6 +3163,221 @@ def chat_mocked(**kwargs):
assert "client" not in llm.__dict__ and "tokenizer" not in llm.__dict__


class TestMiniMax:
    def test_init(self, create_datadreamer):
        with create_datadreamer():
            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
            assert llm.model_name == "MiniMax-M2.7"
            assert llm.base_url == "https://api.minimax.io/v1"
            assert llm.system_prompt == "You are a helpful assistant."

    def test_metadata(self, create_datadreamer):
        # Construct inside a DataDreamer session for consistency with every
        # other test in this class (the original constructed the LLM outside
        # the ``with create_datadreamer():`` context).
        with create_datadreamer():
            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
            assert llm.model_card == "https://platform.minimaxi.com/document/Models"
            assert (
                llm.license
                == "https://platform.minimaxi.com/document/Terms%20of%20service"
            )
            assert llm.citation is None

    def test_count_tokens(self, create_datadreamer):
        with create_datadreamer():
            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
            token_count = llm.count_tokens("This is a test.")
            assert isinstance(token_count, int)
            assert token_count > 0

    def test_get_max_context_length(self, create_datadreamer):
        with create_datadreamer():
            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
            ctx = llm.get_max_context_length(max_new_tokens=0)
            # M2.7 has 1M context, minus format tokens
            assert ctx > 999000
            assert ctx < 1000000

            llm2 = MiniMax("MiniMax-M2.5", api_key="fake-key")
            ctx2 = llm2.get_max_context_length(max_new_tokens=0)
            assert ctx2 > 204000
            assert ctx2 < 204800

    def test_get_max_output_length(self, create_datadreamer):
        with create_datadreamer():
            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
            assert llm._get_max_output_length() == 16384

            llm2 = MiniMax("MiniMax-M2.5", api_key="fake-key")
            assert llm2._get_max_output_length() == 8192

            llm3 = MiniMax("MiniMax-M2.5-highspeed", api_key="fake-key")
            assert llm3._get_max_output_length() == 8192

            llm4 = MiniMax("MiniMax-M2.7-highspeed", api_key="fake-key")
            assert llm4._get_max_output_length() == 16384

    def test_temperature_clamping(self, create_datadreamer, mocker):
        from unittest.mock import MagicMock

        with create_datadreamer():
            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")

            # Create a mock response
            mock_choice = MagicMock()
            mock_choice.message.content = "Test response"
            mock_response = MagicMock()
            mock_response.choices = [mock_choice]

            mocker.patch.object(
                llm.client.chat.completions,
                "create",
                return_value=mock_response,
            )

            # Run with temperature=0.0 (should be clamped to 0.01)
            llm.run(
                ["Test prompt"],
                max_new_tokens=10,
                temperature=0.0,
                top_p=1.0,
                n=1,
                batch_size=1,
            )
            call_kwargs = llm.client.chat.completions.create.call_args_list[
                0
            ].kwargs
            assert call_kwargs["temperature"] == 0.01

    @typing.no_type_check
    def test_run(self, create_datadreamer, mocker):
        from unittest.mock import MagicMock

        with create_datadreamer():
            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")

            # Create a mock response factory
            def chat_mocked(**kwargs):
                prompt = kwargs["messages"][-1]["content"]
                mock_choice = MagicMock()
                mock_choice.message.content = f"Response to: {prompt}"
                mock_response = MagicMock()
                mock_response.choices = [mock_choice]
                return mock_response

            mocker.patch.object(
                llm.client.chat.completions,
                "create",
                side_effect=chat_mocked,
            )

            # Simple run
            generated_texts = llm.run(
                ["What color is the sky?", "What color are trees?"],
                max_new_tokens=25,
                temperature=0.3,
                top_p=1.0,
                n=1,
                stop=None,
                repetition_penalty=None,
                logit_bias=None,
                batch_size=2,
            )
            assert generated_texts == [
                "Response to: What color is the sky?",
                "Response to: What color are trees?",
            ]

            # Test return_generator
            generated_texts_generator = llm.run(
                ["What color is the sky?", "What color are trees?"],
                max_new_tokens=25,
                temperature=0.3,
                top_p=1.0,
                n=1,
                stop=None,
                repetition_penalty=None,
                logit_bias=None,
                batch_size=2,
                return_generator=True,
            )
            assert isinstance(generated_texts_generator, GeneratorType)
            assert list(generated_texts_generator) == generated_texts

            # Test unload model
            assert "client" in llm.__dict__
            llm.unload_model()
            assert "client" not in llm.__dict__

    def test_custom_system_prompt(self, create_datadreamer):
        with create_datadreamer():
            llm = MiniMax(
                "MiniMax-M2.7",
                system_prompt="You are a coding assistant.",
                api_key="fake-key",
            )
            assert llm.system_prompt == "You are a coding assistant."

    def test_display_name(self, create_datadreamer):
        with create_datadreamer():
            llm = MiniMax("MiniMax-M2.7", api_key="fake-key")
            assert "MiniMax-M2.7" in llm.display_name

    @pytest.mark.skipif(
        "MINIMAX_API_KEY" not in os.environ, reason="requires MiniMax API key"
    )
    def test_integration_run(self, create_datadreamer):
        """Integration test that runs against the real MiniMax API."""
        with create_datadreamer():
            llm = MiniMax("MiniMax-M2.7")
            generated_texts = llm.run(
                ["Say hello in one word."],
                max_new_tokens=10,
                temperature=0.01,
                top_p=1.0,
                n=1,
                batch_size=1,
            )
            assert len(generated_texts) == 1
            assert isinstance(generated_texts[0], str)
            assert len(generated_texts[0]) > 0

    @pytest.mark.skipif(
        "MINIMAX_API_KEY" not in os.environ, reason="requires MiniMax API key"
    )
    def test_integration_streaming(self, create_datadreamer):
        """Integration test for generator-based output."""
        with create_datadreamer():
            llm = MiniMax("MiniMax-M2.7")
            results = llm.run(
                ["What is 2+2? Answer with just the number."],
                max_new_tokens=5,
                temperature=0.01,
                top_p=1.0,
                n=1,
                batch_size=1,
                return_generator=True,
            )
            results_list = list(results)
            assert len(results_list) == 1
            assert "4" in results_list[0]

    @pytest.mark.skipif(
        "MINIMAX_API_KEY" not in os.environ, reason="requires MiniMax API key"
    )
    def test_integration_m27_highspeed(self, create_datadreamer):
        """Integration test for the M2.7-highspeed model."""
        with create_datadreamer():
            llm = MiniMax("MiniMax-M2.7-highspeed")
            generated_texts = llm.run(
                ["Say 'yes' or 'no'."],
                max_new_tokens=5,
                temperature=0.01,
                top_p=1.0,
                n=1,
                batch_size=1,
            )
            assert len(generated_texts) == 1
            assert isinstance(generated_texts[0], str)


class TestPetals:
pydantic_version = None
bitsandbytes_version = None
Expand Down
Loading
Loading