Skip to content

Commit 601d753

Browse files
committed
introduce pkg/ai as core LLM completion layer
Add new pkg/ai package that extracts and centralizes model interaction logic from runtime. The package reuses existing types from chat, tools, and provider packages without moving them.
1 parent e20da46 commit 601d753

35 files changed

+2153
-473
lines changed

pkg/ai/completion.go

Lines changed: 474 additions & 0 deletions
Large diffs are not rendered by default.

pkg/ai/completion_test.go

Lines changed: 486 additions & 0 deletions
Large diffs are not rendered by default.

pkg/ai/generate.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package ai
2+
3+
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"iter"

	"github.com/docker/docker-agent/pkg/chat"
)
12+
13+
// StreamValue represents a single value yielded during streaming.
// Exactly one side is meaningful at a time: while streaming is in
// progress (Done == false) only Chunk is populated; once the stream
// completes (Done == true) Value and Response carry the final result.
type StreamValue[Out, Stream any] struct {
	Done     bool
	Chunk    Stream         // valid if Done is false
	Value    Out            // valid if Done is true
	Response *ModelResponse // valid if Done is true
}

// ModelStreamValue is a stream value for a model response.
// Out is never set because the value is already available in the Response field.
type ModelStreamValue = StreamValue[struct{}, chat.MessageStreamResponse]
24+
25+
// GenerateStream generates a model response and streams the output.
26+
// It returns an iterator that yields streaming results.
27+
func GenerateStream(ctx context.Context, opts ...Option) iter.Seq2[*ModelStreamValue, error] {
28+
return func(yield func(*ModelStreamValue, error) bool) {
29+
c := &completion{
30+
yield: func(resp chat.MessageStreamResponse) bool {
31+
return yield(&ModelStreamValue{
32+
Done: false,
33+
Chunk: resp,
34+
}, nil)
35+
},
36+
}
37+
38+
c = c.applyOptions(opts...)
39+
40+
res, err := c.generate(ctx)
41+
if errors.Is(err, io.EOF) {
42+
return
43+
}
44+
45+
if err != nil {
46+
yield(nil, err)
47+
return
48+
}
49+
50+
yield(&ModelStreamValue{
51+
Done: true,
52+
Response: res,
53+
}, nil)
54+
}
55+
}
56+
57+
// Generate runs a completion and returns the final model response.
58+
// It handles retry, fallback, tool execution, and streaming internally.
59+
func Generate(ctx context.Context, opts ...Option) (*ModelResponse, error) {
60+
return new(completion).applyOptions(opts...).generate(ctx)
61+
}
62+
63+
// GenerateText is a convenience wrapper around Generate that returns
64+
// only the text content from the model response.
65+
func GenerateText(ctx context.Context, opts ...Option) (string, error) {
66+
res, err := Generate(ctx, opts...)
67+
if err != nil {
68+
return "", err
69+
}
70+
71+
return res.Content, nil
72+
}
73+
74+
// GenerateValue runs a completion and unmarshals the model's response
75+
// content into the provided type. Use with structured output to get
76+
// typed responses from the model.
77+
func GenerateValue[Out any](ctx context.Context, opts ...Option) (*Out, error) {
78+
res, err := Generate(ctx, opts...)
79+
if err != nil {
80+
return nil, err
81+
}
82+
83+
var out Out
84+
if err := json.Unmarshal([]byte(res.Content), &out); err != nil {
85+
return nil, err
86+
}
87+
88+
return &out, nil
89+
}

pkg/ai/generate_test.go

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
package ai
2+
3+
import (
4+
"errors"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/docker/docker-agent/pkg/chat"
10+
)
11+
12+
func TestGenerateStream(t *testing.T) {
13+
t.Parallel()
14+
15+
tests := []struct {
16+
name string
17+
p *mockProvider
18+
err string
19+
expContent string
20+
}{
21+
{
22+
name: "happy path yields chunks then done",
23+
p: &mockProvider{
24+
id: "test",
25+
msgs: []chat.MessageStreamResponse{
26+
{
27+
Choices: []chat.MessageStreamChoice{
28+
{Delta: chat.MessageDelta{Content: "hello"}},
29+
},
30+
},
31+
{
32+
Choices: []chat.MessageStreamChoice{
33+
{Delta: chat.MessageDelta{Content: " world"}},
34+
},
35+
},
36+
{
37+
Choices: []chat.MessageStreamChoice{
38+
{FinishReason: chat.FinishReasonStop},
39+
},
40+
Usage: &chat.Usage{InputTokens: 10},
41+
},
42+
},
43+
},
44+
expContent: "hello world",
45+
},
46+
{
47+
name: "error yields error",
48+
p: &mockProvider{
49+
id: "test",
50+
err: errors.New("model failed"),
51+
},
52+
err: "model failed",
53+
},
54+
}
55+
56+
for _, tt := range tests {
57+
t.Run(tt.name, func(t *testing.T) {
58+
opts := []Option{
59+
WithModels(tt.p),
60+
WithMessages(chat.Message{Role: "user", Content: "test"}),
61+
}
62+
63+
var (
64+
chunks int
65+
res *ModelResponse
66+
)
67+
68+
for sv, err := range GenerateStream(t.Context(), opts...) {
69+
if err != nil {
70+
require.ErrorContains(t, err, tt.err)
71+
return
72+
}
73+
74+
if sv.Done {
75+
res = sv.Response
76+
break
77+
}
78+
79+
chunks++
80+
}
81+
82+
if tt.err != "" {
83+
t.Fatal("expected error but got none")
84+
}
85+
86+
require.NotNil(t, res)
87+
require.Equal(t, tt.expContent, res.Content)
88+
require.Positive(t, chunks)
89+
})
90+
}
91+
}
92+
93+
func TestGenerateText(t *testing.T) {
94+
t.Parallel()
95+
96+
tests := []struct {
97+
name string
98+
p *mockProvider
99+
err string
100+
expContent string
101+
}{
102+
{
103+
name: "returns text content",
104+
p: &mockProvider{
105+
id: "test",
106+
msgs: []chat.MessageStreamResponse{
107+
{
108+
Choices: []chat.MessageStreamChoice{
109+
{Delta: chat.MessageDelta{Content: "hello"}},
110+
},
111+
},
112+
{
113+
Choices: []chat.MessageStreamChoice{
114+
{FinishReason: chat.FinishReasonStop},
115+
},
116+
},
117+
},
118+
},
119+
expContent: "hello",
120+
},
121+
{
122+
name: "error returns empty string",
123+
p: &mockProvider{
124+
id: "test",
125+
err: errors.New("model failed"),
126+
},
127+
err: "model failed",
128+
},
129+
}
130+
131+
for _, tt := range tests {
132+
t.Run(tt.name, func(t *testing.T) {
133+
text, err := GenerateText(t.Context(),
134+
WithModels(tt.p),
135+
WithMessages(chat.Message{Role: "user", Content: "test"}),
136+
)
137+
138+
if tt.err != "" {
139+
require.ErrorContains(t, err, tt.err)
140+
require.Empty(t, text)
141+
return
142+
}
143+
144+
require.NoError(t, err)
145+
require.Equal(t, tt.expContent, text)
146+
})
147+
}
148+
}
149+
150+
func TestGenerateValue(t *testing.T) {
151+
t.Parallel()
152+
153+
type Person struct {
154+
Name string `json:"name"`
155+
Age int `json:"age"`
156+
}
157+
158+
tests := []struct {
159+
name string
160+
p *mockProvider
161+
err string
162+
exp *Person
163+
}{
164+
{
165+
name: "unmarshals json response",
166+
p: &mockProvider{
167+
id: "test",
168+
msgs: []chat.MessageStreamResponse{
169+
{
170+
Choices: []chat.MessageStreamChoice{
171+
{Delta: chat.MessageDelta{Content: `{"name":"Alice","age":30}`}},
172+
},
173+
},
174+
{
175+
Choices: []chat.MessageStreamChoice{
176+
{FinishReason: chat.FinishReasonStop},
177+
},
178+
},
179+
},
180+
},
181+
exp: &Person{Name: "Alice", Age: 30},
182+
},
183+
{
184+
name: "invalid json returns error",
185+
p: &mockProvider{
186+
id: "test",
187+
msgs: []chat.MessageStreamResponse{
188+
{
189+
Choices: []chat.MessageStreamChoice{
190+
{Delta: chat.MessageDelta{Content: "not json"}},
191+
},
192+
},
193+
{
194+
Choices: []chat.MessageStreamChoice{
195+
{FinishReason: chat.FinishReasonStop},
196+
},
197+
},
198+
},
199+
},
200+
err: "invalid character",
201+
},
202+
{
203+
name: "model error returns error",
204+
p: &mockProvider{
205+
id: "test",
206+
err: errors.New("model failed"),
207+
},
208+
err: "model failed",
209+
},
210+
}
211+
212+
for _, tt := range tests {
213+
t.Run(tt.name, func(t *testing.T) {
214+
result, err := GenerateValue[Person](t.Context(),
215+
WithModels(tt.p),
216+
WithMessages(chat.Message{Role: "user", Content: "test"}),
217+
)
218+
219+
if tt.err != "" {
220+
require.ErrorContains(t, err, tt.err)
221+
require.Nil(t, result)
222+
return
223+
}
224+
225+
require.NoError(t, err)
226+
require.Equal(t, tt.exp, result)
227+
})
228+
}
229+
}

pkg/ai/interceptor.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package ai
2+
3+
import (
4+
"context"
5+
6+
"github.com/docker/docker-agent/pkg/chat"
7+
"github.com/docker/docker-agent/pkg/model/provider"
8+
"github.com/docker/docker-agent/pkg/tools"
9+
)
10+
11+
// StreamRequest holds the parameters for a single model stream call.
// It is passed through the interceptor chain and can be inspected or
// modified by interceptors before reaching the actual model call.
type StreamRequest struct {
	Model    provider.Provider // provider receiving this call
	Messages []chat.Message    // messages sent with the request
	Tools    []tools.Tool      // tools offered to the model for this call
}

// StreamInterceptor wraps a stream call, allowing callers to observe,
// modify, or short-circuit the request before and after it reaches the
// model. The interceptor receives the request and a handler to call the
// next step in the chain — either another interceptor or the actual
// model call. Returning without calling the handler skips the model call.
//
// Example:
//
//	func logInterceptor(ctx context.Context, r *StreamRequest, h StreamHandler) (*ModelResponse, error) {
//		// before: inspect or modify request
//		res, err := h(ctx, r)
//		// after: inspect response, record telemetry, etc.
//		return res, err
//	}
type StreamInterceptor func(context.Context, *StreamRequest, StreamHandler) (*ModelResponse, error)

// StreamHandler is the function signature for the next step in the
// interceptor chain. Call it to proceed with the stream request.
type StreamHandler func(context.Context, *StreamRequest) (*ModelResponse, error)

// ToolCallInterceptor wraps an individual tool call execution.
// The interceptor is responsible for calling tool.Handler and can
// add behavior before and after (permissions, logging, telemetry).
// It receives the model response that requested the call, the specific
// ToolCall, and the resolved Tool to execute.
type ToolCallInterceptor func(context.Context, *ModelResponse, tools.ToolCall, tools.Tool) (*tools.ToolCallResult, error)

0 commit comments

Comments
 (0)