Skip to content

Commit 8fea3de

Browse files
authored
llms/openai: add web search tool support
Fixes #1454.
1 parent 509308f commit 8fea3de

File tree

5 files changed

+326
-0
lines changed

5 files changed

+326
-0
lines changed

llms/openai/internal/openaiclient/chat.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ type ChatRequest struct {
8383

8484
// Metadata allows you to specify additional information that will be passed to the model.
8585
Metadata map[string]any `json:"metadata,omitempty"`
86+
87+
// WebSearchOptions configures web search behavior for search-enabled models
88+
// like gpt-4o-search-preview and gpt-4o-mini-search-preview.
89+
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
8690
}
8791

8892
// MarshalJSON ensures that only one of MaxTokens or MaxCompletionTokens is sent.
@@ -151,6 +155,39 @@ const (
151155
ToolTypeFunction ToolType = "function"
152156
)
153157

158+
// WebSearchOptions configures web search behavior for OpenAI models.
159+
// This is used with search-enabled models like gpt-4o-search-preview.
160+
type WebSearchOptions struct {
161+
// SearchContextSize controls how much context is gathered from web search.
162+
// Valid values: "low", "medium", "high". Higher values provide more context
163+
// but increase latency and cost.
164+
SearchContextSize string `json:"search_context_size,omitempty"`
165+
166+
// UserLocation provides approximate user location for localized search results.
167+
UserLocation *UserLocation `json:"user_location,omitempty"`
168+
}
169+
170+
// UserLocation represents the user's approximate location for web search.
171+
type UserLocation struct {
172+
// Type must be "approximate" for user-provided location.
173+
Type string `json:"type"`
174+
175+
// Approximate contains the approximate location details.
176+
Approximate *ApproximateLocation `json:"approximate,omitempty"`
177+
}
178+
179+
// ApproximateLocation contains approximate location information.
180+
type ApproximateLocation struct {
181+
// Country is the two-letter ISO country code (e.g., "US", "GB").
182+
Country string `json:"country,omitempty"`
183+
184+
// City is the city name (e.g., "San Francisco", "London").
185+
City string `json:"city,omitempty"`
186+
187+
// Region is the region or state (e.g., "California", "London").
188+
Region string `json:"region,omitempty"`
189+
}
190+
154191
// Tool is a tool to use in a chat request.
155192
type Tool struct {
156193
Type ToolType `json:"type"`

llms/openai/internal/openaiclient/marshal_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,116 @@ func TestChatRequest_TemperatureMarshalJSON(t *testing.T) {
167167
}
168168
}
169169

170+
func TestChatRequest_WebSearchOptionsMarshalJSON(t *testing.T) {
171+
tests := []struct {
172+
name string
173+
request ChatRequest
174+
want map[string]interface{}
175+
}{
176+
{
177+
name: "no web search options",
178+
request: ChatRequest{
179+
Model: "gpt-4o-search-preview",
180+
},
181+
want: nil,
182+
},
183+
{
184+
name: "empty web search options",
185+
request: ChatRequest{
186+
Model: "gpt-4o-search-preview",
187+
WebSearchOptions: &WebSearchOptions{},
188+
},
189+
want: map[string]interface{}{},
190+
},
191+
{
192+
name: "web search with search context size",
193+
request: ChatRequest{
194+
Model: "gpt-4o-search-preview",
195+
WebSearchOptions: &WebSearchOptions{
196+
SearchContextSize: "high",
197+
},
198+
},
199+
want: map[string]interface{}{
200+
"search_context_size": "high",
201+
},
202+
},
203+
{
204+
name: "web search with user location",
205+
request: ChatRequest{
206+
Model: "gpt-4o-search-preview",
207+
WebSearchOptions: &WebSearchOptions{
208+
SearchContextSize: "medium",
209+
UserLocation: &UserLocation{
210+
Type: "approximate",
211+
Approximate: &ApproximateLocation{
212+
Country: "US",
213+
City: "San Francisco",
214+
Region: "California",
215+
},
216+
},
217+
},
218+
},
219+
want: map[string]interface{}{
220+
"search_context_size": "medium",
221+
"user_location": map[string]interface{}{
222+
"type": "approximate",
223+
"approximate": map[string]interface{}{
224+
"country": "US",
225+
"city": "San Francisco",
226+
"region": "California",
227+
},
228+
},
229+
},
230+
},
231+
}
232+
233+
for _, tt := range tests {
234+
t.Run(tt.name, func(t *testing.T) {
235+
data, err := json.Marshal(tt.request)
236+
if err != nil {
237+
t.Fatalf("failed to marshal: %v", err)
238+
}
239+
240+
var result map[string]interface{}
241+
if err := json.Unmarshal(data, &result); err != nil {
242+
t.Fatalf("failed to unmarshal: %v", err)
243+
}
244+
245+
webSearchOpts, hasWebSearch := result["web_search_options"]
246+
if tt.want == nil {
247+
if hasWebSearch {
248+
t.Errorf("expected no web_search_options, got %v", webSearchOpts)
249+
}
250+
} else {
251+
if !hasWebSearch {
252+
t.Fatal("expected web_search_options to be present")
253+
}
254+
// Check that it's properly serialized
255+
webSearchMap, ok := webSearchOpts.(map[string]interface{})
256+
if !ok {
257+
t.Fatalf("web_search_options is not a map: %T", webSearchOpts)
258+
}
259+
if tt.want["search_context_size"] != nil {
260+
if webSearchMap["search_context_size"] != tt.want["search_context_size"] {
261+
t.Errorf("search_context_size: got %v, want %v",
262+
webSearchMap["search_context_size"], tt.want["search_context_size"])
263+
}
264+
}
265+
if tt.want["user_location"] != nil {
266+
userLoc, ok := webSearchMap["user_location"].(map[string]interface{})
267+
if !ok {
268+
t.Fatalf("user_location is not a map: %T", webSearchMap["user_location"])
269+
}
270+
wantUserLoc := tt.want["user_location"].(map[string]interface{})
271+
if userLoc["type"] != wantUserLoc["type"] {
272+
t.Errorf("user_location.type: got %v, want %v", userLoc["type"], wantUserLoc["type"])
273+
}
274+
}
275+
}
276+
})
277+
}
278+
}
279+
170280
func TestIsReasoningModel(t *testing.T) {
171281
tests := []struct {
172282
model string

llms/openai/openaillm.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
293293
FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior),
294294
Seed: opts.Seed,
295295
Metadata: apiMetadata,
296+
WebSearchOptions: webSearchOptionsFromCallOptions(opts.WebSearchOptions),
296297
}
297298
if opts.JSONMode {
298299
req.ResponseFormat = ResponseFormatJSON
@@ -498,3 +499,26 @@ func toolCallFromToolCall(tc llms.ToolCall) openaiclient.ToolCall {
498499
},
499500
}
500501
}
502+
503+
// webSearchOptionsFromCallOptions converts llms.WebSearchOptions to openaiclient.WebSearchOptions.
504+
func webSearchOptionsFromCallOptions(opts *llms.WebSearchOptions) *openaiclient.WebSearchOptions {
505+
if opts == nil {
506+
return nil
507+
}
508+
result := &openaiclient.WebSearchOptions{
509+
SearchContextSize: opts.SearchContextSize,
510+
}
511+
if opts.UserLocation != nil {
512+
result.UserLocation = &openaiclient.UserLocation{
513+
Type: opts.UserLocation.Type,
514+
}
515+
if opts.UserLocation.Approximate != nil {
516+
result.UserLocation.Approximate = &openaiclient.ApproximateLocation{
517+
Country: opts.UserLocation.Approximate.Country,
518+
City: opts.UserLocation.Approximate.City,
519+
Region: opts.UserLocation.Approximate.Region,
520+
}
521+
}
522+
}
523+
return result
524+
}

llms/openai/options_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,108 @@ func TestWithLegacyMaxTokensField(t *testing.T) {
7373
t.Error("expected openai:use_legacy_max_tokens to be true")
7474
}
7575
}
76+
77+
func TestWithWebSearch(t *testing.T) {
78+
// Test with nil options (default behavior)
79+
opts := &llms.CallOptions{}
80+
llms.WithWebSearch(nil)(opts)
81+
if opts.WebSearchOptions == nil {
82+
t.Fatal("expected WebSearchOptions to be initialized")
83+
}
84+
85+
// Test with custom search context size
86+
opts2 := &llms.CallOptions{}
87+
llms.WithWebSearch(&llms.WebSearchOptions{
88+
SearchContextSize: "high",
89+
})(opts2)
90+
if opts2.WebSearchOptions == nil {
91+
t.Fatal("expected WebSearchOptions to be set")
92+
}
93+
if opts2.WebSearchOptions.SearchContextSize != "high" {
94+
t.Errorf("expected SearchContextSize=high, got %s", opts2.WebSearchOptions.SearchContextSize)
95+
}
96+
97+
// Test with user location
98+
opts3 := &llms.CallOptions{}
99+
llms.WithWebSearch(&llms.WebSearchOptions{
100+
SearchContextSize: "medium",
101+
UserLocation: &llms.UserLocation{
102+
Type: "approximate",
103+
Approximate: &llms.ApproximateLocation{
104+
Country: "US",
105+
City: "San Francisco",
106+
Region: "California",
107+
},
108+
},
109+
})(opts3)
110+
if opts3.WebSearchOptions == nil {
111+
t.Fatal("expected WebSearchOptions to be set")
112+
}
113+
if opts3.WebSearchOptions.UserLocation == nil {
114+
t.Fatal("expected UserLocation to be set")
115+
}
116+
if opts3.WebSearchOptions.UserLocation.Type != "approximate" {
117+
t.Errorf("expected Type=approximate, got %s", opts3.WebSearchOptions.UserLocation.Type)
118+
}
119+
if opts3.WebSearchOptions.UserLocation.Approximate == nil {
120+
t.Fatal("expected Approximate to be set")
121+
}
122+
if opts3.WebSearchOptions.UserLocation.Approximate.Country != "US" {
123+
t.Errorf("expected Country=US, got %s", opts3.WebSearchOptions.UserLocation.Approximate.Country)
124+
}
125+
if opts3.WebSearchOptions.UserLocation.Approximate.City != "San Francisco" {
126+
t.Errorf("expected City=San Francisco, got %s", opts3.WebSearchOptions.UserLocation.Approximate.City)
127+
}
128+
if opts3.WebSearchOptions.UserLocation.Approximate.Region != "California" {
129+
t.Errorf("expected Region=California, got %s", opts3.WebSearchOptions.UserLocation.Approximate.Region)
130+
}
131+
}
132+
133+
func TestWebSearchOptionsConversion(t *testing.T) {
134+
// Test nil conversion
135+
result := webSearchOptionsFromCallOptions(nil)
136+
if result != nil {
137+
t.Error("expected nil result for nil input")
138+
}
139+
140+
// Test basic conversion
141+
opts := &llms.WebSearchOptions{
142+
SearchContextSize: "high",
143+
}
144+
result = webSearchOptionsFromCallOptions(opts)
145+
if result == nil {
146+
t.Fatal("expected non-nil result")
147+
}
148+
if result.SearchContextSize != "high" {
149+
t.Errorf("expected SearchContextSize=high, got %s", result.SearchContextSize)
150+
}
151+
152+
// Test full conversion with user location
153+
opts2 := &llms.WebSearchOptions{
154+
SearchContextSize: "medium",
155+
UserLocation: &llms.UserLocation{
156+
Type: "approximate",
157+
Approximate: &llms.ApproximateLocation{
158+
Country: "GB",
159+
City: "London",
160+
Region: "London",
161+
},
162+
},
163+
}
164+
result2 := webSearchOptionsFromCallOptions(opts2)
165+
if result2 == nil {
166+
t.Fatal("expected non-nil result")
167+
}
168+
if result2.UserLocation == nil {
169+
t.Fatal("expected UserLocation to be set")
170+
}
171+
if result2.UserLocation.Type != "approximate" {
172+
t.Errorf("expected Type=approximate, got %s", result2.UserLocation.Type)
173+
}
174+
if result2.UserLocation.Approximate == nil {
175+
t.Fatal("expected Approximate to be set")
176+
}
177+
if result2.UserLocation.Approximate.Country != "GB" {
178+
t.Errorf("expected Country=GB, got %s", result2.UserLocation.Approximate.Country)
179+
}
180+
}

llms/options.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ type CallOptions struct {
6969
// Supported MIME types are: text/plain: (default) Text output.
7070
// application/json: JSON response in the response candidates.
7171
ResponseMIMEType string `json:"response_mime_type,omitempty"`
72+
73+
// WebSearchOptions configures web search behavior for models that support it.
74+
// Currently supported by OpenAI models like gpt-4o-search-preview.
75+
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
7276
}
7377

7478
// Tool is a tool that can be used by the model.
@@ -109,6 +113,39 @@ type FunctionReference struct {
109113
// FunctionCallBehavior is the behavior to use when calling functions.
110114
type FunctionCallBehavior string
111115

116+
// WebSearchOptions configures web search behavior for models that support web search.
117+
// This is currently supported by OpenAI models like gpt-4o-search-preview.
118+
type WebSearchOptions struct {
119+
// SearchContextSize controls how much context is gathered from web search.
120+
// Valid values: "low", "medium", "high". Higher values provide more context
121+
// but increase latency and cost.
122+
SearchContextSize string `json:"search_context_size,omitempty"`
123+
124+
// UserLocation provides approximate user location for localized search results.
125+
UserLocation *UserLocation `json:"user_location,omitempty"`
126+
}
127+
128+
// UserLocation represents the user's approximate location for web search.
129+
type UserLocation struct {
130+
// Type must be "approximate" for user-provided location.
131+
Type string `json:"type"`
132+
133+
// Approximate contains the approximate location details.
134+
Approximate *ApproximateLocation `json:"approximate,omitempty"`
135+
}
136+
137+
// ApproximateLocation contains approximate location information.
138+
type ApproximateLocation struct {
139+
// Country is the two-letter ISO country code (e.g., "US", "GB").
140+
Country string `json:"country,omitempty"`
141+
142+
// City is the city name (e.g., "San Francisco", "London").
143+
City string `json:"city,omitempty"`
144+
145+
// Region is the region or state (e.g., "California", "London").
146+
Region string `json:"region,omitempty"`
147+
}
148+
112149
const (
113150
// FunctionCallBehaviorNone will not call any functions.
114151
FunctionCallBehaviorNone FunctionCallBehavior = "none"
@@ -291,3 +328,16 @@ func WithResponseMIMEType(responseMIMEType string) CallOption {
291328
o.ResponseMIMEType = responseMIMEType
292329
}
293330
}
331+
332+
// WithWebSearch enables web search for models that support it.
333+
// Use with OpenAI models like gpt-4o-search-preview and gpt-4o-mini-search-preview.
334+
// Pass nil for default web search behavior, or provide WebSearchOptions to customize.
335+
func WithWebSearch(options *WebSearchOptions) CallOption {
336+
return func(o *CallOptions) {
337+
if options == nil {
338+
o.WebSearchOptions = &WebSearchOptions{}
339+
} else {
340+
o.WebSearchOptions = options
341+
}
342+
}
343+
}

0 commit comments

Comments
 (0)