Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,18 @@ func NewMessageContentSource(
}

type MessageContentToolUse struct {
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input json.RawMessage `json:"input,omitempty"`
ID string `json:"id"`
Name string `json:"name"`
Input json.RawMessage `json:"input"`
}

func NewMessageContentToolUse(
toolUseId, name string,
input json.RawMessage,
) *MessageContentToolUse {
if input == nil {
input = json.RawMessage("{}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: prefer backticks ` to a quotes " for strings in go.

}
return &MessageContentToolUse{
ID: toolUseId,
Name: name,
Expand Down
3 changes: 3 additions & 0 deletions message_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ func (c *Client) CreateMessagesStream(
if len(response.Content) > d.Index {
stopContent = response.Content[d.Index]
if stopContent.Type == MessagesContentTypeToolUse {
if stopContent.Input == nil {
stopContent.Input = json.RawMessage("{}")
}
stopContent.Input = json.RawMessage(*stopContent.PartialJson)
stopContent.PartialJson = nil
response.Content[d.Index] = stopContent
Expand Down
171 changes: 157 additions & 14 deletions message_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
emptyMessagesLimit := 100
server := test.NewTestServer()
server.RegisterHandler("/v1/messages",
handlerMessagesStreamEmptyMessages(emptyMessagesLimit-1, "fake: {}"),

Check failure on line 126 in message_stream_test.go

View workflow job for this annotation

GitHub Actions / Sanity check

undefined: handlerMessagesStreamEmptyMessages
)

ts := server.AnthropicTestServer()
Expand Down Expand Up @@ -153,7 +153,7 @@
server := test.NewTestServer()
server.RegisterHandler(
"/v1/messages",
handlerMessagesStreamEmptyMessages(emptyMessagesLimit, "fake: {}"),

Check failure on line 156 in message_stream_test.go

View workflow job for this annotation

GitHub Actions / Sanity check

undefined: handlerMessagesStreamEmptyMessages
)

ts := server.AnthropicTestServer()
Expand Down Expand Up @@ -288,6 +288,104 @@
}
}

func TestMessagesStreamToolUseWithoutParameters(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/messages", handlerMessagesStreamToolUseWithoutParameters)

ts := server.AnthropicTestServer()
ts.Start()
defer ts.Close()

baseUrl := ts.URL + "/v1"
cli := anthropic.NewClient(
test.GetTestToken(),
anthropic.WithBaseURL(baseUrl),
)

request := anthropic.MessagesStreamRequest{
MessagesRequest: anthropic.MessagesRequest{
Model: anthropic.ModelClaude3Opus20240229,
Messages: []anthropic.Message{
anthropic.NewUserTextMessage("What is the weather like in San Francisco?"),
},
MaxTokens: 1000,
Tools: []anthropic.ToolDefinition{
{
Name: "get_weather",
Description: "Get the current weather in a given location",
InputSchema: jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"location": {
Type: jsonschema.String,
Description: "The city and state, e.g. San Francisco, CA",
},
},
Required: []string{"location"},
},
},
},
},
OnContentBlockStop: func(data anthropic.MessagesEventContentBlockStopData, content anthropic.MessageContent) {
t.Logf("content block stop, index: %d", data.Index)
switch content.Type {
case anthropic.MessagesContentTypeText:
t.Logf("content block stop, text: %s", content.GetText())
case anthropic.MessagesContentTypeToolUse:
t.Logf("content blog stop, tool_use: %+v, input: %s",
*content.MessageContentToolUse,
content.MessageContentToolUse.Input,
)
}
},
}

resp, err := cli.CreateMessagesStream(context.Background(), request)
if err != nil {
t.Fatal(err)
}

request.Messages = append(request.Messages, anthropic.Message{
Role: anthropic.RoleAssistant,
Content: resp.Content,
})

var toolUse *anthropic.MessageContentToolUse

for _, m := range resp.Content {
if m.Type == anthropic.MessagesContentTypeToolUse {
toolUse = m.MessageContentToolUse
}
}

if toolUse == nil {
t.Fatalf("tool use not found")
}

request.Messages = append(
request.Messages,
anthropic.NewToolResultsMessage(toolUse.ID, "65 degrees", false),
)

resp, err = cli.CreateMessagesStream(context.Background(), request)
if err != nil {
t.Fatal(err)
}

var hasDegrees bool
for _, m := range resp.Content {
if m.Type == anthropic.MessagesContentTypeText {
if strings.Contains(m.GetText(), "65 degrees") {
hasDegrees = true
break
}
}
}
if !hasDegrees {
t.Fatalf("Expected response to contain '65 degrees', got: %+v", resp.Content)
}
}

func handlerMessagesStream(w http.ResponseWriter, r *http.Request) {
request, err := getRequest[anthropic.MessagesRequest](r)
if err != nil {
Expand Down Expand Up @@ -438,31 +536,76 @@
_, _ = w.Write(dataBytes)
}

func handlerMessagesStreamEmptyMessages(numEmptyMessages int, payload string) test.Handler {
return func(w http.ResponseWriter, r *http.Request) {
_, err := getRequest[anthropic.MessagesRequest](r)
if err != nil {
http.Error(w, "request error", http.StatusBadRequest)
return
func handlerMessagesStreamToolUseWithoutParameters(w http.ResponseWriter, r *http.Request) {
messagesReq, err := getRequest[anthropic.MessagesRequest](r)
if err != nil {
http.Error(w, "request error", http.StatusBadRequest)
return
}

var hasToolResult bool

for _, m := range messagesReq.Messages {
for _, c := range m.Content {
if c.Type == anthropic.MessagesContentTypeToolResult {
hasToolResult = true
break
}
}
}

w.Header().Set("Content-Type", "text/event-stream")

var dataBytes []byte

dataBytes = append(dataBytes, []byte("event: message_start\n")...)
dataBytes = append(
dataBytes,
[]byte(
`data: {"type":"message_start","message":{"id":"123333","type":"message","role":"assistant","model":"claude-3-opus-20240229","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":844,"output_tokens":2}}}`+"\n\n",
)...)

w.Header().Set("Content-Type", "text/event-stream")
if hasToolResult {
dataBytes = append(dataBytes, []byte("event: content_block_start\n")...)
dataBytes = append(
dataBytes,
[]byte(
`data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`+"\n\n",
)...)

var dataBytes []byte
dataBytes = append(dataBytes, []byte("event: content_block_delta\n")...)
dataBytes = append(
dataBytes,
[]byte(
`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"The current weather in San Francisco is 65 degrees Fahrenheit. It's a nice, moderate temperature typical of the San Francisco Bay Area climate."}}`+"\n\n",
)...)

dataBytes = append(dataBytes, []byte("event: message_start\n")...)
dataBytes = append(dataBytes, []byte("event: content_block_stop\n")...)
dataBytes = append(
dataBytes,
[]byte(`data: {"type":"content_block_stop","index":0}`+"\n\n")...)

dataBytes = append(dataBytes, []byte("event: message_delta\n")...)
dataBytes = append(
dataBytes,
[]byte(
`data: {"type":"message_start","message":{"id":"123333","type":"message","role":"assistant","model":"claude-3-opus-20240229","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":844,"output_tokens":2}}}`+"\n\n",
`data: {"type":"message_delta","delta":{"stop_reason":"end_return","stop_sequence":null},"usage":{"output_tokens":9}}`+"\n\n",
)...)
} else {
dataBytes = append(dataBytes, []byte("event: content_block_start\n")...)
dataBytes = append(dataBytes, []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_019ktsPEWabjtYw1iGdjT2Qy","name":"get_weather","input":{}}}`+"\n\n")...)

for i := 0; i < numEmptyMessages; i++ {
dataBytes = append(dataBytes, []byte(payload+"\n")...)
}
dataBytes = append(dataBytes, []byte("event: content_block_stop\n")...)
dataBytes = append(dataBytes, []byte(`data: {"type":"content_block_stop","index":0}`+"\n\n")...)

_, _ = w.Write(dataBytes)
dataBytes = append(dataBytes, []byte("event: message_delta\n")...)
dataBytes = append(dataBytes, []byte(`data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":9}}`+"\n\n")...)
}

dataBytes = append(dataBytes, []byte("event: message_stop\n")...)
dataBytes = append(dataBytes, []byte(`data: {"type":"message_stop"}`+"\n\n")...)

_, _ = w.Write(dataBytes)
}

func TestVertexMessagesStream(t *testing.T) {
Expand Down
Loading