Skip to content

Commit c8ecaf4

Browse files
qiaodevcopybara-github
authored andcommitted
feat!: Support SendClientContent/SendRealtimeInput/SendToolResponse methods in Session struct and remove Send method
PiperOrigin-RevId: 744855891
1 parent aebbdaa commit c8ecaf4

File tree

2 files changed

+204
-8
lines changed

2 files changed

+204
-8
lines changed

live.go

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ import (
2020
"context"
2121
"encoding/json"
2222
"fmt"
23+
"log"
2324
"net/http"
2425
"net/url"
26+
"sync"
2527

2628
"github.com/gorilla/websocket"
2729
)
@@ -43,10 +45,18 @@ type Session struct {
4345
apiClient *apiClient
4446
}
4547

48+
var (
49+
experimentalWarningLiveConnect sync.Once
50+
)
51+
4652
// Connect establishes a realtime connection to the specified model with given configuration.
4753
// It returns a Session object representing the connection or an error if the connection fails.
4854
// The live module is experimental.
4955
func (r *Live) Connect(context context.Context, model string, config *LiveConnectConfig) (*Session, error) {
56+
experimentalWarningLiveConnect.Do(func() {
57+
log.Println("Warning: The Live API is experimental and may change in future versions.")
58+
})
59+
5060
httpOptions := r.apiClient.clientConfig.HTTPOptions
5161
if httpOptions.APIVersion == "" {
5262
return nil, fmt.Errorf("live module requires APIVersion to be set. You can set APIVersion to v1beta1 for BackendVertexAI or v1apha for BackendGeminiAPI")
@@ -130,10 +140,67 @@ func (r *Live) Connect(context context.Context, model string, config *LiveConnec
130140
return s, nil
131141
}
132142

143+
// LiveClientContentInput is the input for [SendClientContent].
144+
type LiveClientContentInput struct {
145+
// The content appended to the current conversation with the model.
146+
// For single-turn queries, this is a single instance. For multi-turn
147+
// queries, this is a repeated field that contains conversation history and
148+
// latest request.
149+
turns []*Content
150+
// TurnComplete is default to true, indicating that the server content generation should
151+
// start with the currently accumulated prompt. If set to false, the server will await
152+
// additional messages, accumulating the prompt, and start generation until received a
153+
// TurnComplete true message.
154+
TurnComplete *bool `json:"turnComplete,omitempty"`
155+
}
156+
157+
// SendClientContent transmits a [LiveClientContent] over the established connection.
158+
// It returns an error if sending the message fails.
159+
// The live module is experimental.
160+
func (s *Session) SendClientContent(input LiveClientContentInput) error {
161+
if input.TurnComplete == nil {
162+
input.TurnComplete = Ptr(true)
163+
}
164+
clientMessage := &LiveClientMessage{
165+
ClientContent: &LiveClientContent{Turns: input.turns, TurnComplete: *input.TurnComplete},
166+
}
167+
return s.send(clientMessage)
168+
}
169+
170+
// LiveRealtimeInput is the input for [SendRealtimeInput].
171+
type LiveRealtimeInput struct {
172+
media *Blob
173+
}
174+
175+
// SendRealtimeInput transmits a [LiveClientRealtimeInput] over the established connection.
176+
// It returns an error if sending the message fails.
177+
// The live module is experimental.
178+
func (s *Session) SendRealtimeInput(input LiveRealtimeInput) error {
179+
clientMessage := &LiveClientMessage{
180+
RealtimeInput: &LiveClientRealtimeInput{MediaChunks: []*Blob{input.media}},
181+
}
182+
return s.send(clientMessage)
183+
}
184+
185+
// LiveToolResponseInput is the input for [SendToolResponse].
186+
type LiveToolResponseInput struct {
187+
FunctionResponses []*FunctionResponse
188+
}
189+
190+
// SendToolResponse transmits a [LiveClientToolResponse] over the established connection.
191+
// It returns an error if sending the message fails.
192+
// The live module is experimental.
193+
func (s *Session) SendToolResponse(input LiveToolResponseInput) error {
194+
clientMessage := &LiveClientMessage{
195+
ToolResponse: &LiveClientToolResponse{FunctionResponses: input.FunctionResponses},
196+
}
197+
return s.send(clientMessage)
198+
}
199+
133200
// Send transmits a LiveClientMessage over the established connection.
134201
// It returns an error if sending the message fails.
135202
// The live module is experimental.
136-
func (s *Session) Send(input *LiveClientMessage) error {
203+
func (s *Session) send(input *LiveClientMessage) error {
137204
if input.Setup != nil {
138205
return fmt.Errorf("message SetUp is not supported in Send(). Use Connect() instead")
139206
}

live_test.go

Lines changed: 136 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ func TestLiveConnect(t *testing.T) {
195195
})
196196
}
197197

198-
t.Run("Send and Receive", func(t *testing.T) {
198+
t.Run("SendClientContent and Receive", func(t *testing.T) {
199199
sendReceiveTests := []struct {
200200
desc string
201201
client *Client
@@ -206,19 +206,19 @@ func TestLiveConnect(t *testing.T) {
206206
{
207207
desc: "send clientContent to Google AI",
208208
client: mldevClient,
209-
wantRequestBodySlice: []string{`{"setup":{"model":"models/test-model"}}`, `{"clientContent":{"turns":[{"parts":[{"text":"client test message"}],"role":"user"}]}}`},
209+
wantRequestBodySlice: []string{`{"setup":{"model":"models/test-model"}}`, `{"clientContent":{"turnComplete":true,"turns":[{"parts":[{"text":"client test message"}],"role":"user"}]}}`},
210210
fakeResponseBodySlice: []string{`{"setupComplete":{}}`, `{"serverContent":{"modelTurn":{"parts":[{"text":"server test message"}],"role":"user"}}}`},
211211
},
212212
{
213213
desc: "send clientContent to Vertex AI",
214214
client: vertexClient,
215-
wantRequestBodySlice: []string{`{"setup":{"model":"projects/test-project/locations/test-location/publishers/google/models/test-model"}}`, `{"clientContent":{"turns":[{"parts":[{"text":"client test message"}],"role":"user"}]}}`},
215+
wantRequestBodySlice: []string{`{"setup":{"model":"projects/test-project/locations/test-location/publishers/google/models/test-model"}}`, `{"clientContent":{"turnComplete":true,"turns":[{"parts":[{"text":"client test message"}],"role":"user"}]}}`},
216216
fakeResponseBodySlice: []string{`{"setupComplete":{}}`, `{"serverContent":{"modelTurn":{"parts":[{"text":"server test message"}],"role":"user"}}}`},
217217
},
218218
{
219219
desc: "received error in response",
220220
client: mldevClient,
221-
wantRequestBodySlice: []string{`{"setup":{"model":"models/test-model"}}`, `{"clientContent":{"turns":[{"parts":[{"text":"client test message"}],"role":"user"}]}}`},
221+
wantRequestBodySlice: []string{`{"setup":{"model":"models/test-model"}}`, `{"clientContent":{"turnComplete":true,"turns":[{"parts":[{"text":"client test message"}],"role":"user"}]}}`},
222222
fakeResponseBodySlice: []string{`{"setupComplete":{}}`, `{"error":{"code":400,"message":"test error message","status":"INVALID_ARGUMENT"}}`},
223223
wantErr: true,
224224
},
@@ -239,12 +239,141 @@ func TestLiveConnect(t *testing.T) {
239239
defer session.Close()
240240

241241
// Construct a test message
242-
clientMessage := &LiveClientMessage{
243-
ClientContent: &LiveClientContent{Turns: Text("client test message")},
242+
243+
// Test sending the message
244+
err = session.SendClientContent(LiveClientContentInput{turns: Text("client test message")})
245+
if err != nil {
246+
t.Errorf("Send failed : %v", err)
247+
}
248+
249+
// Construct the expected response
250+
serverMessage := &LiveServerMessage{ServerContent: &LiveServerContent{ModelTurn: Text("server test message")[0]}}
251+
// Test receiving the response
252+
gotMessage, err := session.Receive()
253+
if err != nil {
254+
if tt.wantErr {
255+
return
256+
}
257+
t.Errorf("Receive failed: %v", err)
244258
}
259+
if diff := cmp.Diff(gotMessage, serverMessage); diff != "" {
260+
t.Errorf("Response message mismatch (-want +got):\n%s", diff)
261+
}
262+
})
263+
}
264+
})
265+
266+
t.Run("SendRealtimeInput and Receive", func(t *testing.T) {
267+
sendReceiveTests := []struct {
268+
desc string
269+
client *Client
270+
wantRequestBodySlice []string
271+
fakeResponseBodySlice []string
272+
wantErr bool
273+
}{
274+
{
275+
desc: "send realtimeInput to Google AI",
276+
client: mldevClient,
277+
wantRequestBodySlice: []string{`{"setup":{"model":"models/test-model"}}`, `{"realtimeInput":{"mediaChunks":[{"data":"dGVzdCBkYXRh","mimeType":"image/png"}]}}`},
278+
fakeResponseBodySlice: []string{`{"setupComplete":{}}`, `{"serverContent":{"modelTurn":{"parts":[{"text":"server test message"}],"role":"user"}}}`},
279+
},
280+
{
281+
desc: "send realtimeInput to Vertex AI",
282+
client: vertexClient,
283+
wantRequestBodySlice: []string{`{"setup":{"model":"projects/test-project/locations/test-location/publishers/google/models/test-model"}}`, `{"realtimeInput":{"mediaChunks":[{"data":"dGVzdCBkYXRh","mimeType":"image/png"}]}}`},
284+
fakeResponseBodySlice: []string{`{"setupComplete":{}}`, `{"serverContent":{"modelTurn":{"parts":[{"text":"server test message"}],"role":"user"}}}`},
285+
},
286+
{
287+
desc: "received error in response",
288+
client: mldevClient,
289+
wantRequestBodySlice: []string{`{"setup":{"model":"models/test-model"}}`, `{"realtimeInput":{"mediaChunks":[{"data":"dGVzdCBkYXRh","mimeType":"image/png"}]}}`},
290+
fakeResponseBodySlice: []string{`{"setupComplete":{}}`, `{"error":{"code":400,"message":"test error message","status":"INVALID_ARGUMENT"}}`},
291+
wantErr: true,
292+
},
293+
}
294+
295+
for _, tt := range sendReceiveTests {
296+
t.Run(tt.desc, func(t *testing.T) {
297+
ts := setupTestWebsocketServer(t, tt.wantRequestBodySlice, tt.fakeResponseBodySlice)
298+
defer ts.Close()
299+
300+
tt.client.Live.apiClient.clientConfig.HTTPOptions.BaseURL = strings.Replace(ts.URL, "http", "ws", 1)
301+
tt.client.Live.apiClient.clientConfig.HTTPClient = ts.Client()
302+
303+
session, err := tt.client.Live.Connect(ctx, "test-model", &LiveConnectConfig{})
304+
if err != nil {
305+
t.Fatalf("Connect failed: %v", err)
306+
}
307+
defer session.Close()
308+
309+
// Test sending the message
310+
err = session.SendRealtimeInput(LiveRealtimeInput{&Blob{Data: []byte("test data"), MIMEType: "image/png"}})
311+
if err != nil {
312+
t.Errorf("Send failed : %v", err)
313+
}
314+
315+
// Construct the expected response
316+
serverMessage := &LiveServerMessage{ServerContent: &LiveServerContent{ModelTurn: Text("server test message")[0]}}
317+
// Test receiving the response
318+
gotMessage, err := session.Receive()
319+
if err != nil {
320+
if tt.wantErr {
321+
return
322+
}
323+
t.Errorf("Receive failed: %v", err)
324+
}
325+
if diff := cmp.Diff(gotMessage, serverMessage); diff != "" {
326+
t.Errorf("Response message mismatch (-want +got):\n%s", diff)
327+
}
328+
})
329+
}
330+
})
331+
332+
t.Run("SendToolResponse and Receive", func(t *testing.T) {
333+
sendReceiveTests := []struct {
334+
desc string
335+
client *Client
336+
wantRequestBodySlice []string
337+
fakeResponseBodySlice []string
338+
wantErr bool
339+
}{
340+
{
341+
desc: "send realtimeInput to Google AI",
342+
client: mldevClient,
343+
wantRequestBodySlice: []string{`{"setup":{"model":"models/test-model"}}`, `{"toolResponse":{"functionResponses":[{"name":"test-function"}]}}`},
344+
fakeResponseBodySlice: []string{`{"setupComplete":{}}`, `{"serverContent":{"modelTurn":{"parts":[{"text":"server test message"}],"role":"user"}}}`},
345+
},
346+
{
347+
desc: "send realtimeInput to Vertex AI",
348+
client: vertexClient,
349+
wantRequestBodySlice: []string{`{"setup":{"model":"projects/test-project/locations/test-location/publishers/google/models/test-model"}}`, `{"toolResponse":{"functionResponses":[{"name":"test-function"}]}}`},
350+
fakeResponseBodySlice: []string{`{"setupComplete":{}}`, `{"serverContent":{"modelTurn":{"parts":[{"text":"server test message"}],"role":"user"}}}`},
351+
},
352+
{
353+
desc: "received error in response",
354+
client: mldevClient,
355+
wantRequestBodySlice: []string{`{"setup":{"model":"models/test-model"}}`, `{"toolResponse":{"functionResponses":[{"name":"test-function"}]}}`},
356+
fakeResponseBodySlice: []string{`{"setupComplete":{}}`, `{"error":{"code":400,"message":"test error message","status":"INVALID_ARGUMENT"}}`},
357+
wantErr: true,
358+
},
359+
}
360+
361+
for _, tt := range sendReceiveTests {
362+
t.Run(tt.desc, func(t *testing.T) {
363+
ts := setupTestWebsocketServer(t, tt.wantRequestBodySlice, tt.fakeResponseBodySlice)
364+
defer ts.Close()
365+
366+
tt.client.Live.apiClient.clientConfig.HTTPOptions.BaseURL = strings.Replace(ts.URL, "http", "ws", 1)
367+
tt.client.Live.apiClient.clientConfig.HTTPClient = ts.Client()
368+
369+
session, err := tt.client.Live.Connect(ctx, "test-model", &LiveConnectConfig{})
370+
if err != nil {
371+
t.Fatalf("Connect failed: %v", err)
372+
}
373+
defer session.Close()
245374

246375
// Test sending the message
247-
err = session.Send(clientMessage)
376+
err = session.SendToolResponse(LiveToolResponseInput{FunctionResponses: []*FunctionResponse{{Name: "test-function"}}})
248377
if err != nil {
249378
t.Errorf("Send failed : %v", err)
250379
}

0 commit comments

Comments
 (0)