-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcommands_do.go
More file actions
211 lines (178 loc) · 6.14 KB
/
commands_do.go
File metadata and controls
211 lines (178 loc) · 6.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
sdk "github.com/modelrelay/modelrelay/sdk/go"
"github.com/modelrelay/modelrelay/sdk/go/llm"
"github.com/spf13/cobra"
)
// newDoCmd constructs the `mrl do` subcommand. The task text (one or more
// positional args) is handed to runDo together with the parsed flag values;
// flag values are read when the command runs, after cobra has parsed them.
func newDoCmd() *cobra.Command {
	// All flag destinations live in one struct so the RunE closure and the
	// flag registrations below share a single source of truth.
	var opts struct {
		model    string
		system   string
		allowAll bool
		allow    []string
		maxTurns int
		trace    bool
	}

	cmd := &cobra.Command{
		Use:   "do <task>",
		Short: "Execute a task using AI with bash tools",
		Long: `Execute a task using AI that can run bash commands.
Examples:
mrl do "commit my changes"
mrl do "list all TODO comments in this repo"
mrl do "run tests and fix any failures" --allow-all
mrl do "show git status" --allow "git "
By default, no commands are allowed. Use --allow to whitelist
command prefixes, or --allow-all to permit any command.
Permissions can also be set in config:
mrl config set --allow-all
mrl config set --allow "git " --allow "npm "`,
		Args: cobra.MinimumNArgs(1),
		RunE: func(c *cobra.Command, positional []string) error {
			return runDo(c, positional, opts.model, opts.system, opts.allow, opts.allowAll, opts.maxTurns, opts.trace)
		},
	}

	flags := cmd.Flags()
	flags.StringVar(&opts.model, "model", "", "Model ID (overrides profile default)")
	flags.StringVar(&opts.system, "system", "", "System prompt")
	flags.StringSliceVar(&opts.allow, "allow", nil, "Allow bash command prefix (repeatable)")
	flags.BoolVar(&opts.allowAll, "allow-all", false, "Allow all bash commands (use with care)")
	flags.IntVar(&opts.maxTurns, "max-turns", 50, "Max tool loop turns")
	flags.BoolVar(&opts.trace, "trace", false, "Print tool calls as they execute")

	return cmd
}
// runDo implements `mrl do`: it resolves the model and bash permissions from
// CLI flags and the runtime config, then runs the tool loop for the task.
//
// Merging rules (CLI takes precedence):
//   - --allow replaces cfg.Allow when at least one prefix is given;
//   - --allow-all and --trace are OR-ed with their config counterparts.
//
// It returns an error when no model can be resolved, when no bash permission
// is granted at all, when the task text (args joined by spaces) is blank, or
// whatever error the client construction or the tool loop reports.
func runDo(cmd *cobra.Command, args []string, modelFlag, system string, allowFlag []string, allowAllFlag bool, maxTurns int, traceFlag bool) error {
	cfg, err := runtimeConfigFrom(cmd)
	if err != nil {
		return err
	}

	model := resolveModel(modelFlag, cfg)
	if model == "" {
		return errors.New("model is required (set via --model, MODELRELAY_MODEL, or mrl config set --model)")
	}

	// Merge CLI flags with config (CLI takes precedence).
	allowAll := allowAllFlag || cfg.AllowAll
	allow := allowFlag
	if len(allow) == 0 {
		allow = cfg.Allow
	}
	trace := traceFlag || cfg.Trace
	if !allowAll && len(allow) == 0 {
		return errors.New("bash permissions required: use --allow <prefix>, --allow-all, or set allow_all in config")
	}

	// An argument-count check alone still lets an empty or whitespace-only
	// task through (e.g. `mrl do ""`); reject it before creating a client so
	// no model call is spent on a blank prompt.
	prompt := strings.Join(args, " ")
	if strings.TrimSpace(prompt) == "" {
		return errors.New("task is required: pass a non-empty task, e.g. mrl do \"show git status\"")
	}

	client, err := newPromptClient(cfg)
	if err != nil {
		return err
	}

	ctx, cancel := contextWithTimeout(cfg.Timeout)
	defer cancel()

	return runDoLoop(ctx, client, model, system, prompt, allow, allowAll, maxTurns, trace)
}
// runDoLoop drives the model/tool conversation for `mrl do`.
//
// Each turn sends the accumulated conversation (system prompt, user task, and
// all prior assistant tool calls and tool results) to the model with the bash
// tool attached. If the reply contains no tool calls, its assistant text is
// printed and the loop ends successfully. Otherwise every tool call is
// executed through the local bash tool registry, tool output is printed, and
// the results are appended to the conversation for the next turn.
//
// Parameters:
//   - ctx is passed to every model request.
//   - model is the model ID; system replaces the built-in system prompt when non-empty.
//   - prompt is the user's task text.
//   - allow lists permitted command prefixes; allowAll permits any command.
//   - maxTurns caps the number of model requests and must be at least 1.
//   - trace prints each tool call before it runs and a usage summary at the end.
//
// It returns an error if maxTurns < 1, if a request cannot be built or sent
// (wrapped with the turn number), or if the model is still requesting tools
// after maxTurns turns.
func runDoLoop(ctx context.Context, client *sdk.Client, model, system, prompt string, allow []string, allowAll bool, maxTurns int, trace bool) error {
	// Without this guard a non-positive limit makes zero requests and then
	// misleadingly reports "max turns reached without completion".
	if maxTurns < 1 {
		return fmt.Errorf("max turns must be at least 1, got %d", maxTurns)
	}

	// Local bash tool configuration: 30s timeout, 64,000-byte output cap,
	// inherited environment, plus whichever permissions were granted.
	bashOpts := []sdk.LocalBashOption{
		sdk.WithLocalBashTimeout(30 * time.Second),
		sdk.WithLocalBashMaxOutputBytes(64_000),
		sdk.WithLocalBashInheritEnv(),
	}
	if allowAll {
		bashOpts = append(bashOpts, sdk.WithLocalBashAllowAllCommands())
	}
	if len(allow) > 0 {
		rules := make([]sdk.BashCommandRule, 0, len(allow))
		for _, prefix := range allow {
			rules = append(rules, sdk.BashCommandPrefix(prefix))
		}
		bashOpts = append(bashOpts, sdk.WithLocalBashAllowRules(rules...))
	}

	// The registry executes tool calls locally; bashTool is the definition
	// advertised to the model. "." is passed as the tool pack's root
	// (presumably the working directory for commands — confirm in the SDK).
	registry := sdk.NewToolRegistry()
	sdk.NewLocalBashToolPack(".", bashOpts...).RegisterInto(registry)
	bashTool := sdk.MustFunctionToolFromType[bashToolArgs](sdk.ToolNameBash, "Execute a shell command")
	tools := []llm.Tool{bashTool}

	sysPrompt := system
	if sysPrompt == "" {
		sysPrompt = `You are an agent that completes tasks by executing shell commands. Use the bash tool to run commands. Do not explain how to do things - actually do them. Be concise - when done, just say what you did in one short sentence.
When writing git commits:
- Write clear, descriptive commit messages that explain what changed and why
- Use conventional commit format when appropriate (feat:, fix:, docs:, refactor:, etc.)
- Look at the actual diff to understand what changed before writing the message`
	}
	messages := []llm.InputItem{
		llm.NewSystemText(sysPrompt),
		llm.NewUserText(prompt),
	}

	modelID := sdk.NewModelID(model)

	// usage is accumulated across turns; previously it was computed and then
	// discarded, so it is now surfaced as a summary line when tracing.
	var usage sdk.AgentUsage
	reportUsage := func() {
		if !trace {
			return
		}
		fmt.Printf("\033[2m[usage] %d LLM calls, %d tool calls, %d input / %d output / %d total tokens\033[0m\n",
			usage.LLMCalls, usage.ToolCalls, usage.InputTokens, usage.OutputTokens, usage.TotalTokens)
	}

	for turn := 1; turn <= maxTurns; turn++ {
		req, callOpts, err := client.Responses.New().
			Model(modelID).
			Input(messages).
			Tools(tools).
			Build()
		if err != nil {
			return fmt.Errorf("building request (turn %d): %w", turn, err)
		}

		resp, err := client.Responses.Create(ctx, req, callOpts...)
		if err != nil {
			return fmt.Errorf("calling model (turn %d): %w", turn, err)
		}

		usage.LLMCalls++
		usage.InputTokens += resp.Usage.InputTokens
		usage.OutputTokens += resp.Usage.OutputTokens
		usage.TotalTokens += resp.Usage.TotalTokens

		toolCalls := resp.ToolCalls()
		if len(toolCalls) == 0 {
			// No tool calls: the model considers the task done.
			if text := resp.AssistantText(); text != "" {
				fmt.Println(text)
			}
			reportUsage()
			return nil
		}
		usage.ToolCalls += len(toolCalls)

		// Record the assistant turn (text + tool calls) before the results so
		// the next request sees the calls and their outputs in order.
		messages = append(messages, sdk.AssistantMessageWithToolCalls(resp.AssistantText(), toolCalls))

		if trace {
			for _, tc := range toolCalls {
				if tc.Function == nil {
					continue
				}
				// Show the bare command when the arguments decode as bash
				// args; otherwise fall back to the raw name and arguments.
				var args bashToolArgs
				if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err == nil && args.Command != "" {
					fmt.Printf("\033[1;36m→ %s\033[0m\n", args.Command)
				} else {
					fmt.Printf("\033[1;36m→ %s %s\033[0m\n", tc.Function.Name, tc.Function.Arguments)
				}
			}
		}

		results := registry.ExecuteAll(toolCalls)
		messages = append(messages, registry.ResultsToMessages(results)...)

		// Tool output is always echoed, independent of trace. A nil Result
		// matches no case in the type switch, so no separate nil check is needed.
		for _, result := range results {
			switch r := result.Result.(type) {
			case sdk.BashResult:
				if r.Output != "" {
					fmt.Printf("\033[2m%s\033[0m\n", r.Output)
				}
				if r.Error != "" {
					fmt.Printf("\033[31merror: %s\033[0m\n", r.Error)
				}
			case string:
				if r != "" {
					fmt.Println(r)
				}
			}
			if result.Error != nil {
				fmt.Printf("\033[31merror: %s\033[0m\n", result.Error)
			}
		}
	}

	reportUsage()
	return fmt.Errorf("max turns (%d) reached without completion", maxTurns)
}