Skip to content

Commit a6f47c3

Browse files
authored
chore: use toolcalls (#116)
Signed-off-by: Sertac Ozercan <[email protected]>
1 parent 0043908 commit a6f47c3

File tree

3 files changed

+39
-32
lines changed

3 files changed

+39
-32
lines changed

cmd/cli/completion.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@ import (
1212

1313
openai "github.com/sashabaranov/go-openai"
1414
"github.com/sethvargo/go-retry"
15+
log "github.com/sirupsen/logrus"
1516
)
1617

18+
const maxRetries = 10
19+
1720
type oaiClients struct {
1821
openAIClient openai.Client
1922
}
@@ -72,20 +75,18 @@ func gptCompletion(ctx context.Context, client oaiClients, prompts []string) (st
7275

7376
var resp string
7477
var err error
75-
r := retry.WithMaxRetries(10, retry.NewExponential(1*time.Second))
78+
r := retry.WithMaxRetries(maxRetries, retry.NewExponential(1*time.Second))
7679
if err := retry.Do(ctx, r, func(ctx context.Context) error {
7780
resp, err = client.openaiGptChatCompletion(ctx, &prompt, temp)
7881

79-
requestErr := &openai.RequestError{}
82+
requestErr := &openai.APIError{}
8083
if errors.As(err, &requestErr) {
8184
switch requestErr.HTTPStatusCode {
82-
case http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusServiceUnavailable:
85+
case http.StatusTooManyRequests, http.StatusRequestTimeout, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
86+
log.Debugf("retrying due to status code %d: %s", requestErr.HTTPStatusCode, requestErr.Message)
8387
return retry.RetryableError(err)
8488
}
8589
}
86-
if err != nil {
87-
return err
88-
}
8990
return nil
9091
}); err != nil {
9192
return "", err

cmd/cli/functions.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,17 @@ func (s *schema) Run() (content string, err error) {
6969
return string(schemaBytes), nil
7070
}
7171

72-
func funcCall(call *openai.FunctionCall) (string, error) {
73-
switch call.Name {
72+
func callTool(toolCall openai.ToolCall) (string, error) {
73+
switch toolCall.Function.Name {
7474
case findSchemaNames.Name:
7575
var f schemaNames
76-
if err := json.Unmarshal([]byte(call.Arguments), &f); err != nil {
76+
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &f); err != nil {
7777
return "", err
7878
}
7979
return f.Run()
8080
case getSchema.Name:
8181
var f schema
82-
if err := json.Unmarshal([]byte(call.Arguments), &f); err != nil {
82+
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &f); err != nil {
8383
return "", err
8484
}
8585
return f.Run()

cmd/cli/openai.go

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,24 @@ import (
99
log "github.com/sirupsen/logrus"
1010
)
1111

12-
type functionCallType string
12+
type toolChoiceType string
1313

1414
const (
15-
fnCallAuto functionCallType = "auto"
16-
fnCallNone functionCallType = "none"
15+
toolChoiceAuto toolChoiceType = "auto"
16+
toolChoiceNone toolChoiceType = "none"
1717
)
1818

1919
func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt *strings.Builder, temp float32) (string, error) {
2020
var (
21-
resp openai.ChatCompletionResponse
22-
req openai.ChatCompletionRequest
23-
funcName *openai.FunctionCall
24-
content string
25-
err error
21+
resp openai.ChatCompletionResponse
22+
req openai.ChatCompletionRequest
23+
content string
24+
err error
2625
)
2726

2827
// if we are using the k8s API, we need to call the functions
29-
fnCallType := fnCallAuto
28+
toolChoiseType := toolChoiceAuto
29+
3030
for {
3131
prompt.WriteString(content)
3232
log.Debugf("prompt: %s", prompt.String())
@@ -44,30 +44,36 @@ func (c *oaiClients) openaiGptChatCompletion(ctx context.Context, prompt *string
4444
}
4545

4646
if *usek8sAPI {
47-
// TODO: migrate to tools api
48-
req.Functions = []openai.FunctionDefinition{ // nolint:staticcheck
49-
findSchemaNames,
50-
getSchema,
47+
req.Tools = []openai.Tool{
48+
{
49+
Type: "function",
50+
Function: &findSchemaNames,
51+
},
52+
{
53+
Type: "function",
54+
Function: &getSchema,
55+
},
5156
}
52-
req.FunctionCall = fnCallType // nolint:staticcheck
57+
req.ToolChoice = toolChoiseType
5358
}
5459

5560
resp, err = c.openAIClient.CreateChatCompletion(ctx, req)
5661
if err != nil {
5762
return "", err
5863
}
5964

60-
funcName = resp.Choices[0].Message.FunctionCall
61-
// if there is no function call, we are done
62-
if funcName == nil {
65+
if len(resp.Choices[0].Message.ToolCalls) == 0 {
6366
break
6467
}
65-
log.Debugf("calling function: %s", funcName.Name)
6668

67-
// if there is a function call, we need to call it and get the result
68-
content, err = funcCall(funcName)
69-
if err != nil {
70-
return "", err
69+
for _, tool := range resp.Choices[0].Message.ToolCalls {
70+
log.Debugf("calling tool: %s", tool.Function.Name)
71+
72+
// if there is a tool call, we need to call it and get the result
73+
content, err = callTool(tool)
74+
if err != nil {
75+
return "", err
76+
}
7177
}
7278
}
7379

0 commit comments

Comments
 (0)