diff --git a/docs/providers/dmr/index.md b/docs/providers/dmr/index.md index 71f6e46e4..e24902e2a 100644 --- a/docs/providers/dmr/index.md +++ b/docs/providers/dmr/index.md @@ -64,29 +64,111 @@ models: model: ai/qwen3 max_tokens: 8192 provider_opts: - runtime_flags: ["--ngl=33", "--top-p=0.9"] + runtime_flags: ["--threads", "8"] ``` Runtime flags also accept a single string: ```yaml provider_opts: - runtime_flags: "--ngl=33 --top-p=0.9" + runtime_flags: "--threads 8" ``` -## Parameter Mapping +Use only flags your Model Runner backend allows (see `docker model configure --help` and backend docs). **Do not** put sampling parameters (`temperature`, `top_p`, penalties) in `runtime_flags` — set them on the model (`temperature`, `top_p`, etc.); they are sent **per request** via the OpenAI-compatible chat API. -docker-agent model config fields map to llama.cpp flags automatically: +## Context size -| Config | llama.cpp Flag | -| ------------------- | --------------------- | -| `temperature` | `--temp` | -| `top_p` | `--top-p` | -| `frequency_penalty` | `--frequency-penalty` | -| `presence_penalty` | `--presence-penalty` | -| `max_tokens` | `--context-size` | +`max_tokens` controls the **maximum output tokens** per chat completion request. To set the engine's **total context window**, use `provider_opts.context_size`: -`runtime_flags` always take priority over derived flags on conflict. +```yaml +models: + local: + provider: dmr + model: ai/qwen3 + max_tokens: 4096 # max output tokens (per-request) + provider_opts: + context_size: 32768 # total context window (sent via _configure) +``` + +If `context_size` is omitted, Model Runner uses its default. `max_tokens` is **not** used as the context window. + +## Thinking / reasoning budget + +When using the **llama.cpp** backend, `thinking_budget` is sent as structured `llamacpp.reasoning-budget` on `_configure` (maps to `--reasoning-budget`). String efforts use the same token mapping as other providers; `adaptive` maps to unlimited (`-1`). + +When using the **vLLM** backend, `thinking_budget` is sent as `thinking_token_budget` in each chat completion request. Effort levels map to token counts using the same scale as other providers; `adaptive` maps to unlimited (`-1`). + +```yaml +models: + local: + provider: dmr + model: ai/qwen3 + thinking_budget: medium # llama.cpp: reasoning-budget=8192; vLLM: thinking_token_budget=8192 +``` + +On **MLX** and **SGLang** backends, `thinking_budget` is silently ignored — those engines do not currently expose a per-request reasoning token budget knob. + +## vLLM-specific configuration + +When running a model on the **vLLM** backend, additional engine-level settings can be passed via `provider_opts` and are forwarded to model-runner's `_configure` endpoint: + +- `gpu_memory_utilization` — fraction of GPU memory (0.0–1.0) vLLM may use. Values outside this range are rejected. +- `hf_overrides` — map of Hugging Face config overrides applied when vLLM loads the model. + +```yaml +models: + vllm-local: + provider: dmr + model: ai/some-model-safetensors + provider_opts: + gpu_memory_utilization: 0.9 + hf_overrides: + max_model_len: 8192 + dtype: bfloat16 +``` + +`hf_overrides` keys (including nested ones) must match `^[a-zA-Z_][a-zA-Z0-9_]*$` — the same rule model-runner enforces server-side to block injection via flags. Invalid keys are rejected at client creation time so you fail fast instead of after a round-trip. + +These options are ignored on non-vLLM backends. 
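+For example, an `hf_overrides` key that looks like a CLI flag fails the key check above and is rejected at client creation time (illustrative config — the first key is deliberately invalid):
+
+```yaml
+models:
+  vllm-local:
+    provider: dmr
+    model: ai/some-model-safetensors
+    provider_opts:
+      hf_overrides:
+        "--max-model-len": 8192  # rejected: keys must match ^[a-zA-Z_][a-zA-Z0-9_]*$
+        max_model_len: 8192      # accepted
+```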
+ +## Keeping models resident in memory (`keep_alive`) + +By default model-runner unloads idle models after a few minutes. Override the idle timeout via `provider_opts.keep_alive`: + +```yaml +models: + sticky: + provider: dmr + model: ai/qwen3 + provider_opts: + keep_alive: "30m" # duration string + # keep_alive: "0" # unload immediately after each request + # keep_alive: "-1" # keep loaded forever +``` + +Accepted values: any Go duration string (`"30s"`, `"5m"`, `"1h"`, `"2h30m"`), `"0"` (immediate unload), or `"-1"` (never unload). Invalid values are rejected before the configure request is sent. + +## Operating mode (`mode`) + +Model-runner normally infers the backend mode from the request path. You can pin it explicitly via `provider_opts.mode`: + +```yaml +provider_opts: + mode: embedding # one of: completion, embedding, reranking, image-generation +``` + +Most agents don't need this — leave it unset unless you know you need it. + +## Raw runtime flags (`raw_runtime_flags`) + +`runtime_flags` (a list) is the preferred way to pass flags. If you have a pre-built command-line string you'd rather ship verbatim, use `raw_runtime_flags` instead: + +```yaml +provider_opts: + raw_runtime_flags: "--threads 8 --batch-size 512" +``` + +Model-runner parses the string with shell-style word splitting. `runtime_flags` and `raw_runtime_flags` are mutually exclusive — setting both is an error. ## Speculative Decoding diff --git a/pkg/model/provider/dmr/client.go b/pkg/model/provider/dmr/client.go index 439e31918..bfbe52988 100644 --- a/pkg/model/provider/dmr/client.go +++ b/pkg/model/provider/dmr/client.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "log/slog" + "maps" "net/http" "os" "time" @@ -54,6 +55,7 @@ type Client struct { client openai.Client baseURL string httpClient *http.Client + engine string } // NewClient creates a new DMR client from the provided configuration @@ -103,18 +105,28 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt clientOptions = append(clientOptions, option.WithBaseURL(baseURL), option.WithAPIKey("")) // DMR doesn't need auth - // Build runtime flags from ModelConfig and engine - contextSize, providerRuntimeFlags, specOpts := parseDMRProviderOpts(cfg) - configFlags := buildRuntimeFlagsFromModelConfig(engine, cfg) - finalFlags, warnings := mergeRuntimeFlagsPreferUser(configFlags, providerRuntimeFlags) - for _, w := range warnings { - slog.Warn(w) + parsed, err := parseDMRProviderOpts(engine, cfg) + if err != nil { + slog.Error("DMR provider_opts invalid", "error", err, "model", cfg.Model) + return nil, err } - slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", contextSize, "runtime_flags", finalFlags, "speculative_opts", specOpts, "engine", engine) + backendCfg := buildConfigureBackendConfig(parsed.contextSize, parsed.runtimeFlags, parsed.specOpts, parsed.llamaCpp, parsed.vllm, parsed.keepAlive) + slog.Debug("DMR provider_opts parsed", + "model", cfg.Model, + "engine", engine, + "context_size", derefInt64(parsed.contextSize), + "runtime_flags", parsed.runtimeFlags, + "raw_runtime_flags", parsed.rawRuntimeFlags, + "mode", derefString(parsed.mode), + "keep_alive", derefString(parsed.keepAlive), + "speculative_opts", parsed.specOpts, + "llamacpp", parsed.llamaCpp, + "vllm", parsed.vllm, + ) // Skip model configuration when generating titles to avoid reconfiguring the model // with different settings (e.g., smaller max_tokens) that would affect the main agent. 
if !globalOptions.GeneratingTitle() { - if err := configureModel(ctx, httpClient, baseURL, cfg.Model, contextSize, finalFlags, specOpts); err != nil { + if err := configureModel(ctx, httpClient, baseURL, cfg.Model, backendCfg, parsed.mode, parsed.rawRuntimeFlags); err != nil { slog.Debug("model configure via API skipped or failed", "error", err) } } @@ -129,6 +141,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt client: openai.NewClient(clientOptions...), baseURL: baseURL, httpClient: httpClient, + engine: engine, }, nil } @@ -214,6 +227,43 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat } } + // Collect per-request extra JSON fields. SetExtraFields replaces the map + // wholesale, so merge all contributors before a single Set call. + extraFields := map[string]any{} + + // NoThinking: disable reasoning at the chat-template level. llama.cpp and + // vLLM both honor chat_template_kwargs.enable_thinking=false for Qwen3 / + // Hermes / DeepSeek-R1 style templates; other engines ignore unknown keys. + // + // When the caller has also set a small MaxTokens (e.g. session title + // generation sets max_tokens=20), raise it to noThinkingMinOutputTokens + // so any residual reasoning tokens the engine/template still emits can't + // starve the visible output. The nil-guard is intentional: if MaxTokens + // is unset the caller has imposed no cap, so there is nothing to floor + // and we leave max_tokens off the request (letting the engine use its + // own output budget). Mirrors the OpenAI provider (see + // pkg/model/provider/openai/client.go). + if c.ModelOptions.NoThinking() { + extraFields["chat_template_kwargs"] = map[string]any{"enable_thinking": false} + if c.ModelConfig.MaxTokens != nil && *c.ModelConfig.MaxTokens < noThinkingMinOutputTokens { + params.MaxTokens = openai.Int(noThinkingMinOutputTokens) + slog.Debug("DMR NoThinking: bumped max_tokens floor", + "from", *c.ModelConfig.MaxTokens, "to", noThinkingMinOutputTokens) + } + } + + // vLLM-specific per-request fields (e.g. thinking_token_budget). 
+ if c.engine == engineVLLM { + if fields := buildVLLMRequestFields(&c.ModelConfig); fields != nil { + maps.Copy(extraFields, fields) + } + } + + if len(extraFields) > 0 { + params.SetExtraFields(extraFields) + slog.Debug("DMR extra request fields applied", "fields", extraFields) + } + // Log the request in JSON format for debugging if requestJSON, err := json.Marshal(params); err == nil { slog.Debug("DMR chat completion request", "request", string(requestJSON)) @@ -222,7 +272,7 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat } if structuredOutput := c.ModelOptions.StructuredOutput(); structuredOutput != nil { - slog.Debug("Adding structured output to DMR request", "structured_output", structuredOutput) + slog.Debug("Adding structured output to DMR request", "name", structuredOutput.Name, "strict", structuredOutput.Strict) params.ResponseFormat.OfJSONSchema = &openai.ResponseFormatJSONSchemaParam{ JSONSchema: openai.ResponseFormatJSONSchemaJSONSchemaParam{ diff --git a/pkg/model/provider/dmr/client_test.go b/pkg/model/provider/dmr/client_test.go index cfe9de28c..12beb63e6 100644 --- a/pkg/model/provider/dmr/client_test.go +++ b/pkg/model/provider/dmr/client_test.go @@ -8,12 +8,15 @@ import ( "os" "path/filepath" "runtime" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/config/latest" + "github.com/docker/docker-agent/pkg/model/provider/options" ) func TestNewClientWithExplicitBaseURL(t *testing.T) { @@ -201,30 +204,40 @@ func TestBuildConfigureRequest(t *testing.T) { acceptanceRate: 0.8, } contextSize := int64(8192) + backendCfg := buildConfigureBackendConfig(&contextSize, []string{"--threads", "8"}, specOpts, nil, nil, nil) - req := buildConfigureRequest("ai/qwen3:14B-Q6_K", &contextSize, []string{"--temp", "0.7", "--top-p", "0.9"}, specOpts) + req := buildConfigureRequest("ai/qwen3:14B-Q6_K", backendCfg, nil, "") assert.Equal(t, "ai/qwen3:14B-Q6_K", req.Model) require.NotNil(t, req.ContextSize) assert.Equal(t, int32(8192), *req.ContextSize) - assert.Equal(t, []string{"--temp", "0.7", "--top-p", "0.9"}, req.RuntimeFlags) + assert.Equal(t, []string{"--threads", "8"}, req.RuntimeFlags) require.NotNil(t, req.Speculative) assert.Equal(t, "ai/qwen3:1B", req.Speculative.DraftModel) assert.Equal(t, 5, req.Speculative.NumTokens) assert.InEpsilon(t, 0.8, req.Speculative.MinAcceptanceRate, 0.001) + assert.Nil(t, req.Mode) + assert.Empty(t, req.RawRuntimeFlags) + assert.Nil(t, req.KeepAlive) + assert.Nil(t, req.VLLM) }) t.Run("without speculative options", func(t *testing.T) { t.Parallel() contextSize := int64(4096) + backendCfg := buildConfigureBackendConfig(&contextSize, []string{"--threads", "8"}, nil, nil, nil, nil) - req := buildConfigureRequest("ai/qwen3:14B-Q6_K", &contextSize, []string{"--threads", "8"}, nil) + req := buildConfigureRequest("ai/qwen3:14B-Q6_K", backendCfg, nil, "") assert.Equal(t, "ai/qwen3:14B-Q6_K", req.Model) require.NotNil(t, req.ContextSize) assert.Equal(t, int32(4096), *req.ContextSize) assert.Equal(t, []string{"--threads", "8"}, req.RuntimeFlags) assert.Nil(t, req.Speculative) + assert.Nil(t, req.Mode) + assert.Empty(t, req.RawRuntimeFlags) + assert.Nil(t, req.KeepAlive) + assert.Nil(t, req.VLLM) }) t.Run("without context size", func(t *testing.T) { @@ -233,8 +246,9 @@ func TestBuildConfigureRequest(t *testing.T) { draftModel: "ai/qwen3:1B", numTokens: 5, } + backendCfg := buildConfigureBackendConfig(nil, 
nil, specOpts, nil, nil, nil) - req := buildConfigureRequest("ai/qwen3:14B-Q6_K", nil, nil, specOpts) + req := buildConfigureRequest("ai/qwen3:14B-Q6_K", backendCfg, nil, "") assert.Equal(t, "ai/qwen3:14B-Q6_K", req.Model) assert.Nil(t, req.ContextSize) @@ -242,16 +256,41 @@ func TestBuildConfigureRequest(t *testing.T) { require.NotNil(t, req.Speculative) assert.Equal(t, "ai/qwen3:1B", req.Speculative.DraftModel) assert.Equal(t, 5, req.Speculative.NumTokens) + assert.Nil(t, req.Mode) + assert.Empty(t, req.RawRuntimeFlags) + assert.Nil(t, req.KeepAlive) + assert.Nil(t, req.VLLM) }) t.Run("minimal config", func(t *testing.T) { t.Parallel() - req := buildConfigureRequest("ai/qwen3:14B-Q6_K", nil, nil, nil) + backendCfg := buildConfigureBackendConfig(nil, nil, nil, nil, nil, nil) + req := buildConfigureRequest("ai/qwen3:14B-Q6_K", backendCfg, nil, "") assert.Equal(t, "ai/qwen3:14B-Q6_K", req.Model) assert.Nil(t, req.ContextSize) assert.Nil(t, req.RuntimeFlags) assert.Nil(t, req.Speculative) + assert.Nil(t, req.LlamaCpp) + assert.Nil(t, req.Mode) + assert.Empty(t, req.RawRuntimeFlags) + assert.Nil(t, req.KeepAlive) + assert.Nil(t, req.VLLM) + }) + + t.Run("with llama.cpp reasoning budget", func(t *testing.T) { + t.Parallel() + rb := int32(16384) + llama := &llamaCppConfig{ReasoningBudget: &rb} + backendCfg := buildConfigureBackendConfig(nil, nil, nil, llama, nil, nil) + req := buildConfigureRequest("ai/qwen3:14B-Q6_K", backendCfg, nil, "") + require.NotNil(t, req.LlamaCpp) + require.NotNil(t, req.LlamaCpp.ReasoningBudget) + assert.Equal(t, int32(16384), *req.LlamaCpp.ReasoningBudget) + assert.Nil(t, req.Mode) + assert.Empty(t, req.RawRuntimeFlags) + assert.Nil(t, req.KeepAlive) + assert.Nil(t, req.VLLM) }) } @@ -289,15 +328,16 @@ func TestConfigureModelViaAPI(t *testing.T) { numTokens: 5, acceptanceRate: 0.8, } + backendCfg := buildConfigureBackendConfig(&contextSize, []string{"--threads", "8"}, specOpts, nil, nil, nil) - err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", &contextSize, []string{"--temp", "0.7"}, specOpts) + err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", backendCfg, nil, "") require.NoError(t, err) // Verify request body assert.Equal(t, "ai/qwen3:14B", receivedRequest.Model) require.NotNil(t, receivedRequest.ContextSize) assert.Equal(t, int32(8192), *receivedRequest.ContextSize) - assert.Equal(t, []string{"--temp", "0.7"}, receivedRequest.RuntimeFlags) + assert.Equal(t, []string{"--threads", "8"}, receivedRequest.RuntimeFlags) require.NotNil(t, receivedRequest.Speculative) assert.Equal(t, "ai/qwen3:1B", receivedRequest.Speculative.DraftModel) assert.Equal(t, 5, receivedRequest.Speculative.NumTokens) @@ -314,7 +354,7 @@ func TestConfigureModelViaAPI(t *testing.T) { defer server.Close() baseURL := server.URL + "/engines/v1/" - err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", nil, nil, nil) + err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", buildConfigureBackendConfig(nil, nil, nil, nil, nil, nil), nil, "") require.Error(t, err) assert.Contains(t, err.Error(), "500") assert.Contains(t, err.Error(), "internal error") @@ -330,71 +370,20 @@ func TestConfigureModelViaAPI(t *testing.T) { defer server.Close() baseURL := server.URL + "/engines/v1/" - err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", nil, nil, nil) + err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", buildConfigureBackendConfig(nil, nil, nil, nil, nil, nil), 
nil, "") require.Error(t, err) assert.Contains(t, err.Error(), "409") assert.Contains(t, err.Error(), "runner already active") }) } -func TestBuildRuntimeFlagsFromModelConfig_LlamaCpp(t *testing.T) { - t.Parallel() - - flags := buildRuntimeFlagsFromModelConfig("llama.cpp", &latest.ModelConfig{ - Temperature: new(0.6), - TopP: new(0.95), - FrequencyPenalty: new(0.2), - PresencePenalty: new(0.1), - }) - - assert.Equal(t, []string{"--temp", "0.6", "--top-p", "0.95", "--frequency-penalty", "0.2", "--presence-penalty", "0.1"}, flags) -} - -func TestIntegrateFlagsWithProviderOptsOrder(t *testing.T) { - t.Parallel() - - cfg := &latest.ModelConfig{ - Temperature: new(0.6), - TopP: new(0.9), - MaxTokens: new(int64(4096)), - ProviderOpts: map[string]any{ - "runtime_flags": []string{"--threads", "6"}, - }, - } - // derive config flags first, then merge provider opts (simulating NewClient path) - derived := buildRuntimeFlagsFromModelConfig("llama.cpp", cfg) - // provider opts should be appended after derived flags so they can override by order - merged := append(derived, []string{"--threads", "6"}...) - - req := buildConfigureRequest("ai/qwen3:14B-Q6_K", cfg.MaxTokens, merged, nil) - assert.Equal(t, "ai/qwen3:14B-Q6_K", req.Model) - require.NotNil(t, req.ContextSize) - assert.Equal(t, int32(4096), *req.ContextSize) - assert.Equal(t, []string{"--temp", "0.6", "--top-p", "0.9", "--threads", "6"}, req.RuntimeFlags) -} - -func TestMergeRuntimeFlagsPreferUser_WarnsAndPrefersUser(t *testing.T) { - t.Parallel() - - // Derived suggests temp/top-p, user overrides both and adds threads - derived := []string{"--temp", "0.5", "--top-p", "0.8"} - user := []string{"--temp", "0.7", "--threads", "8"} - - merged, warnings := mergeRuntimeFlagsPreferUser(derived, user) - - // Expect 1 warnings for --temp overriding - require.Len(t, warnings, 1) - - // Derived conflicting flags should be dropped, user ones kept and appended - assert.Equal(t, []string{"--top-p", "0.8", "--temp", "0.7", "--threads", "8"}, merged) -} - func TestParseDMRProviderOptsWithSpeculativeDecoding(t *testing.T) { t.Parallel() cfg := &latest.ModelConfig{ MaxTokens: new(int64(4096)), ProviderOpts: map[string]any{ + "context_size": int64(16384), "speculative_draft_model": "ai/qwen3:1B", "speculative_num_tokens": "5", "speculative_acceptance_rate": "0.75", @@ -402,14 +391,18 @@ func TestParseDMRProviderOptsWithSpeculativeDecoding(t *testing.T) { }, } - contextSize, runtimeFlags, specOpts := parseDMRProviderOpts(cfg) + res, err := parseDMRProviderOpts("llama.cpp", cfg) + require.NoError(t, err) + contextSize, runtimeFlags, specOpts, llamaCpp := res.contextSize, res.runtimeFlags, res.specOpts, res.llamaCpp - assert.Equal(t, int64(4096), *contextSize) + require.NotNil(t, contextSize) + assert.Equal(t, int64(16384), *contextSize) assert.Equal(t, []string{"--threads", "8"}, runtimeFlags) require.NotNil(t, specOpts) assert.Equal(t, "ai/qwen3:1B", specOpts.draftModel) assert.Equal(t, 5, specOpts.numTokens) assert.InEpsilon(t, 0.75, specOpts.acceptanceRate, 0.001) + assert.Nil(t, llamaCpp) } func TestParseDMRProviderOptsWithoutSpeculativeDecoding(t *testing.T) { @@ -422,11 +415,461 @@ func TestParseDMRProviderOptsWithoutSpeculativeDecoding(t *testing.T) { }, } - contextSize, runtimeFlags, specOpts := parseDMRProviderOpts(cfg) + res, err := parseDMRProviderOpts("llama.cpp", cfg) + require.NoError(t, err) + contextSize, runtimeFlags, specOpts, llamaCpp := res.contextSize, res.runtimeFlags, res.specOpts, res.llamaCpp - assert.Equal(t, int64(4096), *contextSize) + 
assert.Nil(t, contextSize, "context_size not in provider_opts, should be nil regardless of max_tokens") assert.Equal(t, []string{"--threads", "8"}, runtimeFlags) assert.Nil(t, specOpts) + assert.Nil(t, llamaCpp) +} + +func TestParseDMRProviderOptsContextSizeFromProviderOpts(t *testing.T) { + t.Parallel() + + cfg := &latest.ModelConfig{ + MaxTokens: new(int64(4096)), + ProviderOpts: map[string]any{ + "context_size": int64(32768), + }, + } + + res, err := parseDMRProviderOpts("llama.cpp", cfg) + require.NoError(t, err) + contextSize, rf, spec, ll := res.contextSize, res.runtimeFlags, res.specOpts, res.llamaCpp + require.NotNil(t, contextSize) + assert.Equal(t, int64(32768), *contextSize) + assert.Nil(t, rf) + assert.Nil(t, spec) + assert.Nil(t, ll) +} + +func TestParseDMRProviderOptsContextSizeNeitherSet(t *testing.T) { + t.Parallel() + + cfg := &latest.ModelConfig{ + Provider: "dmr", + Model: "ai/qwen3", + } + + res, err := parseDMRProviderOpts("llama.cpp", cfg) + require.NoError(t, err) + contextSize, rf, spec, ll := res.contextSize, res.runtimeFlags, res.specOpts, res.llamaCpp + assert.Nil(t, contextSize) + assert.Nil(t, rf) + assert.Nil(t, spec) + assert.Nil(t, ll) +} + +func TestParseDMRProviderOptsThinkingBudget(t *testing.T) { + t.Parallel() + + t.Run("llama.cpp: effort maps to token budget", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "medium"}, + } + res, err := parseDMRProviderOpts("llama.cpp", cfg) + require.NoError(t, err) + llamaCpp := res.llamaCpp + require.NotNil(t, llamaCpp) + require.NotNil(t, llamaCpp.ReasoningBudget) + assert.Equal(t, int32(8192), *llamaCpp.ReasoningBudget) + }) + + t.Run("llama.cpp: explicit tokens", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Tokens: 2048}, + } + res, err := parseDMRProviderOpts("llama.cpp", cfg) + require.NoError(t, err) + llamaCpp := res.llamaCpp + require.NotNil(t, llamaCpp) + require.NotNil(t, llamaCpp.ReasoningBudget) + assert.Equal(t, int32(2048), *llamaCpp.ReasoningBudget) + }) + + t.Run("llama.cpp: disabled", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "none"}, + } + res, err := parseDMRProviderOpts("llama.cpp", cfg) + require.NoError(t, err) + llamaCpp := res.llamaCpp + require.NotNil(t, llamaCpp) + require.NotNil(t, llamaCpp.ReasoningBudget) + assert.Equal(t, int32(0), *llamaCpp.ReasoningBudget) + }) + + t.Run("empty engine defaults to llama.cpp", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Tokens: 4096}, + } + res, err := parseDMRProviderOpts("", cfg) + require.NoError(t, err) + llamaCpp := res.llamaCpp + require.NotNil(t, llamaCpp) + require.NotNil(t, llamaCpp.ReasoningBudget) + assert.Equal(t, int32(4096), *llamaCpp.ReasoningBudget) + }) + + t.Run("vllm engine: no llamacpp config (thinking handled per-request)", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "high"}, + } + res, err := parseDMRProviderOpts("vllm", cfg) + require.NoError(t, err) + llamaCpp := res.llamaCpp + assert.Nil(t, llamaCpp, "vllm engine should not produce llamacpp config; thinking_budget is sent per-request instead") + }) +} + +func TestParseVLLMConfig(t *testing.T) { + t.Parallel() + + t.Run("nil opts returns nil", func(t *testing.T) { + t.Parallel() + got, err := parseVLLMConfig(nil) + require.NoError(t, err) + 
assert.Nil(t, got) + }) + + t.Run("empty opts returns nil", func(t *testing.T) { + t.Parallel() + got, err := parseVLLMConfig(map[string]any{}) + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("unrelated opts returns nil", func(t *testing.T) { + t.Parallel() + got, err := parseVLLMConfig(map[string]any{"foo": "bar"}) + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("gpu_memory_utilization as float", func(t *testing.T) { + t.Parallel() + got, err := parseVLLMConfig(map[string]any{"gpu_memory_utilization": 0.9}) + require.NoError(t, err) + require.NotNil(t, got) + require.NotNil(t, got.GPUMemoryUtilization) + assert.InEpsilon(t, 0.9, *got.GPUMemoryUtilization, 0.001) + assert.Nil(t, got.HFOverrides) + }) + + t.Run("gpu_memory_utilization as string", func(t *testing.T) { + t.Parallel() + got, err := parseVLLMConfig(map[string]any{"gpu_memory_utilization": "0.75"}) + require.NoError(t, err) + require.NotNil(t, got) + require.NotNil(t, got.GPUMemoryUtilization) + assert.InEpsilon(t, 0.75, *got.GPUMemoryUtilization, 0.001) + }) + + t.Run("gpu_memory_utilization with invalid type is ignored", func(t *testing.T) { + t.Parallel() + got, err := parseVLLMConfig(map[string]any{"gpu_memory_utilization": []int{1, 2}}) + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("gpu_memory_utilization out of range returns error", func(t *testing.T) { + t.Parallel() + for _, val := range []float64{-0.1, 1.5, 2.0} { + _, err := parseVLLMConfig(map[string]any{"gpu_memory_utilization": val}) + assert.Error(t, err, "expected error for gpu_memory_utilization=%v", val) + } + }) + + t.Run("hf_overrides as map", func(t *testing.T) { + t.Parallel() + overrides := map[string]any{"max_model_len": 4096, "dtype": "bfloat16"} + got, err := parseVLLMConfig(map[string]any{"hf_overrides": overrides}) + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, overrides, got.HFOverrides) + assert.Nil(t, got.GPUMemoryUtilization) + }) + + t.Run("hf_overrides with non-map value is ignored", func(t *testing.T) { + t.Parallel() + got, err := parseVLLMConfig(map[string]any{"hf_overrides": "not-a-map"}) + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("hf_overrides with invalid key returns error", func(t *testing.T) { + t.Parallel() + _, err := parseVLLMConfig(map[string]any{ + "hf_overrides": map[string]any{"--malicious": "bad"}, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid hf_overrides key") + }) + + t.Run("hf_overrides with invalid nested key returns error", func(t *testing.T) { + t.Parallel() + _, err := parseVLLMConfig(map[string]any{ + "hf_overrides": map[string]any{ + "good_key": map[string]any{"--bad": 1}, + }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid hf_overrides nested key") + }) + + t.Run("hf_overrides with valid nested values is accepted", func(t *testing.T) { + t.Parallel() + overrides := map[string]any{ + "rope_scaling": map[string]any{ + "type": "yarn", + "factor": 2.0, + }, + "tags": []any{"v1", "v2"}, + } + got, err := parseVLLMConfig(map[string]any{"hf_overrides": overrides}) + require.NoError(t, err) + require.NotNil(t, got) + assert.Equal(t, overrides, got.HFOverrides) + }) + + t.Run("both options together", func(t *testing.T) { + t.Parallel() + got, err := parseVLLMConfig(map[string]any{ + "gpu_memory_utilization": 0.85, + "hf_overrides": map[string]any{"dtype": "float16"}, + }) + require.NoError(t, err) + require.NotNil(t, got) + require.NotNil(t, got.GPUMemoryUtilization) + assert.InEpsilon(t, 0.85, 
*got.GPUMemoryUtilization, 0.001) + assert.Equal(t, "float16", got.HFOverrides["dtype"]) + }) +} + +func TestParseDMRProviderOptsVLLMEngine(t *testing.T) { + t.Parallel() + + t.Run("vllm engine populates vllm config", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "gpu_memory_utilization": 0.9, + "hf_overrides": map[string]any{"max_model_len": 8192}, + }, + } + res, err := parseDMRProviderOpts("vllm", cfg) + require.NoError(t, err) + llamaCpp, vllm := res.llamaCpp, res.vllm + assert.Nil(t, llamaCpp, "llamacpp config should not be set for vllm engine") + require.NotNil(t, vllm) + require.NotNil(t, vllm.GPUMemoryUtilization) + assert.InEpsilon(t, 0.9, *vllm.GPUMemoryUtilization, 0.001) + assert.Equal(t, 8192, vllm.HFOverrides["max_model_len"]) + }) + + t.Run("llama.cpp engine ignores vllm opts", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "gpu_memory_utilization": 0.9, + "hf_overrides": map[string]any{"dtype": "float16"}, + }, + } + res, err := parseDMRProviderOpts("llama.cpp", cfg) + require.NoError(t, err) + vllm := res.vllm + assert.Nil(t, vllm, "vllm config should not be set for llama.cpp engine") + }) + + t.Run("vllm engine flows end-to-end into configure request JSON", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "gpu_memory_utilization": 0.85, + "hf_overrides": map[string]any{"test_key": "test-value"}, + }, + } + res, err := parseDMRProviderOpts("vllm", cfg) + require.NoError(t, err) + contextSize, runtimeFlags, specOpts, llamaCpp, vllm := res.contextSize, res.runtimeFlags, res.specOpts, res.llamaCpp, res.vllm + backendCfg := buildConfigureBackendConfig(contextSize, runtimeFlags, specOpts, llamaCpp, vllm, nil) + req := buildConfigureRequest("ai/vllm-model", backendCfg, nil, "") + + data, err := json.Marshal(req) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + vllmParsed, ok := parsed["vllm"].(map[string]any) + require.True(t, ok, "vllm key should be present in JSON") + assert.InEpsilon(t, 0.85, vllmParsed["gpu-memory-utilization"].(float64), 0.001) + hfOverrides, ok := vllmParsed["hf-overrides"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "test-value", hfOverrides["test_key"]) + }) +} + +func TestBuildVLLMRequestFields(t *testing.T) { + t.Parallel() + + t.Run("nil config returns nil", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(nil) + assert.Nil(t, fields) + }) + + t.Run("nil budget returns nil", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{}) + assert.Nil(t, fields) + }) + + t.Run("disabled (effort none) returns 0", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "none"}, + }) + require.NotNil(t, fields) + assert.Equal(t, int64(0), fields["thinking_token_budget"]) + }) + + t.Run("explicit token count", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Tokens: 4096}, + }) + require.NotNil(t, fields) + assert.Equal(t, int64(4096), fields["thinking_token_budget"]) + }) + + t.Run("effort medium maps to 8192", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "medium"}, + }) + require.NotNil(t, fields) + 
assert.Equal(t, int64(8192), fields["thinking_token_budget"]) + }) + + t.Run("effort high maps to 16384", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "high"}, + }) + require.NotNil(t, fields) + assert.Equal(t, int64(16384), fields["thinking_token_budget"]) + }) + + t.Run("adaptive returns -1 (unlimited)", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "adaptive"}, + }) + require.NotNil(t, fields) + assert.Equal(t, int64(-1), fields["thinking_token_budget"]) + }) +} + +func TestResolveReasoningBudget(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input *latest.ThinkingBudget + wantBudget int64 + wantOK bool + }{ + { + name: "nil → (0, false)", + input: nil, + wantBudget: 0, + wantOK: false, + }, + { + name: "disabled via Tokens:0 → (0, true)", + input: &latest.ThinkingBudget{Tokens: 0}, + wantBudget: 0, + wantOK: true, + }, + { + name: "disabled via Effort:none → (0, true)", + input: &latest.ThinkingBudget{Effort: "none"}, + wantBudget: 0, + wantOK: true, + }, + { + name: "explicit Tokens:4096 → (4096, true)", + input: &latest.ThinkingBudget{Tokens: 4096}, + wantBudget: 4096, + wantOK: true, + }, + { + name: "explicit Tokens:-1 (dynamic) → (-1, true)", + input: &latest.ThinkingBudget{Tokens: -1}, + wantBudget: -1, + wantOK: true, + }, + { + name: "Effort:minimal → (1024, true)", + input: &latest.ThinkingBudget{Effort: "minimal"}, + wantBudget: 1024, + wantOK: true, + }, + { + name: "Effort:low → (2048, true)", + input: &latest.ThinkingBudget{Effort: "low"}, + wantBudget: 2048, + wantOK: true, + }, + { + name: "Effort:medium → (8192, true)", + input: &latest.ThinkingBudget{Effort: "medium"}, + wantBudget: 8192, + wantOK: true, + }, + { + name: "Effort:high → (16384, true)", + input: &latest.ThinkingBudget{Effort: "high"}, + wantBudget: 16384, + wantOK: true, + }, + { + name: "Effort:adaptive → (-1, true)", + input: &latest.ThinkingBudget{Effort: "adaptive"}, + wantBudget: -1, + wantOK: true, + }, + { + name: "Effort:adaptive/low → (-1, true)", + input: &latest.ThinkingBudget{Effort: "adaptive/low"}, + wantBudget: -1, + wantOK: true, + }, + { + name: "Effort:unknown → (-1, true)", + input: &latest.ThinkingBudget{Effort: "unknown"}, + wantBudget: -1, + wantOK: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + gotBudget, gotOK := resolveReasoningBudget(tt.input) + assert.Equal(t, tt.wantBudget, gotBudget) + assert.Equal(t, tt.wantOK, gotOK) + }) + } } func TestConfigureRequestJSONSerialization(t *testing.T) { @@ -435,14 +878,18 @@ func TestConfigureRequestJSONSerialization(t *testing.T) { t.Run("full request serializes correctly", func(t *testing.T) { t.Parallel() contextSize := int32(8192) + reasoning := int32(-1) req := configureRequest{ - Model: "ai/qwen3:14B", - ContextSize: &contextSize, - RuntimeFlags: []string{"--temp", "0.7"}, - Speculative: &speculativeDecodingRequest{ - DraftModel: "ai/qwen3:1B", - NumTokens: 5, - MinAcceptanceRate: 0.8, + Model: "ai/qwen3:14B", + configureBackendConfig: configureBackendConfig{ + ContextSize: &contextSize, + RuntimeFlags: []string{"--keep-alive", "5m"}, + Speculative: &speculativeDecodingRequest{ + DraftModel: "ai/qwen3:1B", + NumTokens: 5, + MinAcceptanceRate: 0.8, + }, + LlamaCpp: &llamaCppConfig{ReasoningBudget: &reasoning}, }, } @@ -455,13 +902,17 @@ func 
TestConfigureRequestJSONSerialization(t *testing.T) { assert.Equal(t, "ai/qwen3:14B", parsed["model"]) assert.InEpsilon(t, float64(8192), parsed["context-size"].(float64), 0.001) - assert.Equal(t, []any{"--temp", "0.7"}, parsed["runtime-flags"]) + assert.Equal(t, []any{"--keep-alive", "5m"}, parsed["runtime-flags"]) spec, ok := parsed["speculative"].(map[string]any) require.True(t, ok) assert.Equal(t, "ai/qwen3:1B", spec["draft_model"]) assert.InEpsilon(t, float64(5), spec["num_tokens"].(float64), 0.001) assert.InEpsilon(t, 0.8, spec["min_acceptance_rate"].(float64), 0.001) + + llama, ok := parsed["llamacpp"].(map[string]any) + require.True(t, ok) + assert.InEpsilon(t, float64(-1), llama["reasoning-budget"].(float64), 0.001) }) t.Run("minimal request omits nil fields", func(t *testing.T) { @@ -484,5 +935,320 @@ func TestConfigureRequestJSONSerialization(t *testing.T) { assert.False(t, hasRuntimeFlags, "runtime-flags should be omitted when nil") _, hasSpeculative := parsed["speculative"] assert.False(t, hasSpeculative, "speculative should be omitted when nil") + _, hasLlamaCpp := parsed["llamacpp"] + assert.False(t, hasLlamaCpp, "llamacpp should be omitted when nil") + _, hasMode := parsed["mode"] + assert.False(t, hasMode, "mode should be omitted when nil") + _, hasRawRuntimeFlags := parsed["raw-runtime-flags"] + assert.False(t, hasRawRuntimeFlags, "raw-runtime-flags should be omitted when empty") + _, hasKeepAlive := parsed["keep_alive"] + assert.False(t, hasKeepAlive, "keep_alive should be omitted when nil") + _, hasVLLM := parsed["vllm"] + assert.False(t, hasVLLM, "vllm should be omitted when nil") + }) + + t.Run("schema parity fields serialize with expected keys", func(t *testing.T) { + t.Parallel() + mode := "completion" + keepAlive := "5m" + gpu := 0.9 + req := configureRequest{ + Model: "ai/qwen3:14B", + Mode: &mode, + RawRuntimeFlags: "--foo --bar", + configureBackendConfig: configureBackendConfig{ + KeepAlive: &keepAlive, + VLLM: &vllmConfig{ + HFOverrides: map[string]any{"foo": "bar"}, + GPUMemoryUtilization: &gpu, + }, + }, + } + + data, err := json.Marshal(req) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, "completion", parsed["mode"]) + assert.Equal(t, "--foo --bar", parsed["raw-runtime-flags"]) + assert.Equal(t, "5m", parsed["keep_alive"]) + + vllm, ok := parsed["vllm"].(map[string]any) + require.True(t, ok) + hfOverrides, ok := vllm["hf-overrides"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "bar", hfOverrides["foo"]) + assert.InEpsilon(t, 0.9, vllm["gpu-memory-utilization"].(float64), 0.001) }) } + +func TestParseKeepAlive(t *testing.T) { + t.Parallel() + + t.Run("nil opts returns nil", func(t *testing.T) { + t.Parallel() + got, err := parseKeepAlive(nil) + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("unset returns nil", func(t *testing.T) { + t.Parallel() + got, err := parseKeepAlive(map[string]any{"other": 1}) + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("valid durations", func(t *testing.T) { + t.Parallel() + for _, v := range []string{"5m", "1h", "30s", "2h30m", "0", "-1"} { + got, err := parseKeepAlive(map[string]any{"keep_alive": v}) + require.NoErrorf(t, err, "value %q should be valid", v) + require.NotNil(t, got) + assert.Equal(t, v, *got) + } + }) + + t.Run("non-string rejected", func(t *testing.T) { + t.Parallel() + _, err := parseKeepAlive(map[string]any{"keep_alive": 300}) + require.Error(t, err) + assert.Contains(t, 
err.Error(), "must be a string") + }) + + t.Run("empty string rejected", func(t *testing.T) { + t.Parallel() + _, err := parseKeepAlive(map[string]any{"keep_alive": " "}) + require.Error(t, err) + }) + + t.Run("bad duration rejected", func(t *testing.T) { + t.Parallel() + _, err := parseKeepAlive(map[string]any{"keep_alive": "not-a-duration"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid keep_alive") + }) +} + +func TestParseMode(t *testing.T) { + t.Parallel() + + t.Run("nil opts returns nil", func(t *testing.T) { + t.Parallel() + got, err := parseMode(nil) + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("unset returns nil", func(t *testing.T) { + t.Parallel() + got, err := parseMode(map[string]any{"other": 1}) + require.NoError(t, err) + assert.Nil(t, got) + }) + + t.Run("valid modes accepted", func(t *testing.T) { + t.Parallel() + for _, m := range []string{"completion", "embedding", "reranking", "image-generation"} { + got, err := parseMode(map[string]any{"mode": m}) + require.NoErrorf(t, err, "mode %q should be valid", m) + require.NotNil(t, got) + assert.Equal(t, m, *got) + } + }) + + t.Run("unknown mode rejected", func(t *testing.T) { + t.Parallel() + _, err := parseMode(map[string]any{"mode": "nonsense"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid") + }) + + t.Run("non-string rejected", func(t *testing.T) { + t.Parallel() + _, err := parseMode(map[string]any{"mode": 1}) + require.Error(t, err) + assert.Contains(t, err.Error(), "must be a string") + }) +} + +func TestParseRawRuntimeFlags(t *testing.T) { + t.Parallel() + + t.Run("unset returns empty", func(t *testing.T) { + t.Parallel() + got, err := parseRawRuntimeFlags(nil) + require.NoError(t, err) + assert.Empty(t, got) + }) + + t.Run("set returns value", func(t *testing.T) { + t.Parallel() + got, err := parseRawRuntimeFlags(map[string]any{"raw_runtime_flags": "--foo bar"}) + require.NoError(t, err) + assert.Equal(t, "--foo bar", got) + }) + + t.Run("whitespace only returns empty", func(t *testing.T) { + t.Parallel() + got, err := parseRawRuntimeFlags(map[string]any{"raw_runtime_flags": " "}) + require.NoError(t, err) + assert.Empty(t, got) + }) + + t.Run("non-string rejected", func(t *testing.T) { + t.Parallel() + _, err := parseRawRuntimeFlags(map[string]any{"raw_runtime_flags": 123}) + require.Error(t, err) + }) +} + +func TestParseDMRProviderOptsKeepAliveAndMode(t *testing.T) { + t.Parallel() + + cfg := &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "keep_alive": "10m", + "mode": "embedding", + }, + } + res, err := parseDMRProviderOpts("llama.cpp", cfg) + require.NoError(t, err) + require.NotNil(t, res.keepAlive) + assert.Equal(t, "10m", *res.keepAlive) + require.NotNil(t, res.mode) + assert.Equal(t, "embedding", *res.mode) +} + +func TestParseDMRProviderOptsRejectsBothRuntimeFlagVariants(t *testing.T) { + t.Parallel() + + cfg := &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "runtime_flags": []string{"--threads", "8"}, + "raw_runtime_flags": "--threads 8", + }, + } + _, err := parseDMRProviderOpts("llama.cpp", cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot set both") +} + +func TestParseDMRProviderOptsPropagatesValidationError(t *testing.T) { + t.Parallel() + + t.Run("bad keep_alive fails parseDMRProviderOpts", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ProviderOpts: map[string]any{"keep_alive": "banana"}, + } + _, err := parseDMRProviderOpts("llama.cpp", cfg) + require.Error(t, err) + }) + + t.Run("bad 
mode fails parseDMRProviderOpts", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ProviderOpts: map[string]any{"mode": "banana"}, + } + _, err := parseDMRProviderOpts("llama.cpp", cfg) + require.Error(t, err) + }) + + t.Run("bad hf_overrides fails parseDMRProviderOpts (vllm engine)", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "hf_overrides": map[string]any{"--bad": 1}, + }, + } + _, err := parseDMRProviderOpts("vllm", cfg) + require.Error(t, err) + }) +} + +func TestConfigureRequestTopLevelFieldsEndToEnd(t *testing.T) { + t.Parallel() + + cfg := &latest.ModelConfig{ + ProviderOpts: map[string]any{ + "keep_alive": "-1", + "mode": "completion", + "raw_runtime_flags": "--threads 8 --ctx 4096", + }, + } + res, err := parseDMRProviderOpts("llama.cpp", cfg) + require.NoError(t, err) + + backendCfg := buildConfigureBackendConfig(res.contextSize, res.runtimeFlags, res.specOpts, res.llamaCpp, res.vllm, res.keepAlive) + req := buildConfigureRequest("ai/qwen3", backendCfg, res.mode, res.rawRuntimeFlags) + + data, err := json.Marshal(req) + require.NoError(t, err) + + var parsed map[string]any + require.NoError(t, json.Unmarshal(data, &parsed)) + + assert.Equal(t, "ai/qwen3", parsed["model"]) + assert.Equal(t, "-1", parsed["keep_alive"]) + assert.Equal(t, "completion", parsed["mode"]) + assert.Equal(t, "--threads 8 --ctx 4096", parsed["raw-runtime-flags"]) +} + +func TestNoThinkingSetsChatTemplateKwargsAndBumpsMaxTokens(t *testing.T) { + t.Parallel() + + var captured []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/chat/completions") { + body, _ := io.ReadAll(r.Body) + captured = body + // Return a minimal streaming response so the SDK is happy. + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n")) + _, _ = w.Write([]byte("data: [DONE]\n\n")) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[]}`)) + })) + defer server.Close() + + maxTokens := int64(20) + cfg := &latest.ModelConfig{ + Provider: "dmr", + Model: "ai/qwen3", + BaseURL: server.URL + "/engines/v1/", + MaxTokens: &maxTokens, + } + client, err := NewClient(t.Context(), cfg, options.WithNoThinking()) + require.NoError(t, err) + + stream, err := client.CreateChatCompletionStream(t.Context(), []chat.Message{ + {Role: chat.MessageRoleUser, Content: "hi"}, + }, nil) + require.NoError(t, err) + for { + if _, err := stream.Recv(); err != nil { + break + } + } + stream.Close() + + require.NotEmpty(t, captured, "chat/completions should have been called") + + var req map[string]any + require.NoError(t, json.Unmarshal(captured, &req)) + + // max_tokens floor (20 -> 256). + assert.EqualValues(t, noThinkingMinOutputTokens, req["max_tokens"]) + + // chat_template_kwargs.enable_thinking=false on every engine. 
+ ct, ok := req["chat_template_kwargs"].(map[string]any) + require.True(t, ok, "chat_template_kwargs must be present") + assert.Equal(t, false, ct["enable_thinking"]) +} diff --git a/pkg/model/provider/dmr/configure.go b/pkg/model/provider/dmr/configure.go index 17573013f..2c718626a 100644 --- a/pkg/model/provider/dmr/configure.go +++ b/pkg/model/provider/dmr/configure.go @@ -2,27 +2,61 @@ package dmr import ( "bytes" - "cmp" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" "net/http" "net/url" + "regexp" "strconv" "strings" + "time" "github.com/docker/docker-agent/pkg/config/latest" ) -// configureRequest mirrors the model-runner's scheduling.ConfigureRequest structure. -// It specifies per-model runtime configuration options sent via POST /engines/_configure. +// configureRequest mirrors model-runner's scheduling.ConfigureRequest. type configureRequest struct { - Model string `json:"model"` + configureBackendConfig + + Model string `json:"model"` + Mode *string `json:"mode,omitempty"` + RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"` +} + +// configureBackendConfig mirrors model-runner's inference.BackendConfiguration. +type configureBackendConfig struct { ContextSize *int32 `json:"context-size,omitempty"` RuntimeFlags []string `json:"runtime-flags,omitempty"` Speculative *speculativeDecodingRequest `json:"speculative,omitempty"` + KeepAlive *string `json:"keep_alive,omitempty"` + VLLM *vllmConfig `json:"vllm,omitempty"` + LlamaCpp *llamaCppConfig `json:"llamacpp,omitempty"` +} + +// vllmConfig mirrors model-runner's inference.VLLMConfig for POST /engines/_configure. +type vllmConfig struct { + HFOverrides map[string]any `json:"hf-overrides,omitempty"` + GPUMemoryUtilization *float64 `json:"gpu-memory-utilization,omitempty"` +} + +// llamaCppConfig mirrors model-runner's inference.LlamaCppConfig for POST /engines/_configure. +type llamaCppConfig struct { + ReasoningBudget *int32 `json:"reasoning-budget,omitempty"` +} + +func (c *llamaCppConfig) LogValue() slog.Value { + if c == nil { + return slog.AnyValue(nil) + } + var rb any + if c.ReasoningBudget != nil { + rb = *c.ReasoningBudget + } + return slog.GroupValue(slog.Any("reasoning-budget", rb)) } // speculativeDecodingRequest mirrors model-runner's inference.SpeculativeDecodingConfig. @@ -38,14 +72,25 @@ type speculativeDecodingOpts struct { acceptanceRate float64 } +func (so *speculativeDecodingOpts) LogValue() slog.Value { + if so == nil { + return slog.AnyValue(nil) + } + return slog.GroupValue( + slog.String("draft-model", so.draftModel), + slog.Int("num-tokens", so.numTokens), + slog.Float64("acceptance-rate", so.acceptanceRate), + ) +} + // configureModel sends model configuration to Model Runner via POST /engines/_configure. 
-func configureModel(ctx context.Context, httpClient *http.Client, baseURL, model string, contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) error { +func configureModel(ctx context.Context, httpClient *http.Client, baseURL, model string, backend configureBackendConfig, mode *string, rawRuntimeFlags string) error { if httpClient == nil { httpClient = &http.Client{} } configureURL := buildConfigureURL(baseURL) - reqData, err := json.Marshal(buildConfigureRequest(model, contextSize, runtimeFlags, specOpts)) + reqData, err := json.Marshal(buildConfigureRequest(model, backend, mode, rawRuntimeFlags)) if err != nil { return fmt.Errorf("failed to marshal configure request: %w", err) } @@ -62,9 +107,14 @@ func configureModel(ctx context.Context, httpClient *http.Client, baseURL, model slog.Debug("Sending model configure request", "model", model, "url", configureURL, - "context_size", contextSize, - "runtime_flags", runtimeFlags, - "speculative_opts", specOpts) + "context_size", derefInt32(backend.ContextSize), + "runtime_flags", backend.RuntimeFlags, + "raw_runtime_flags", rawRuntimeFlags, + "mode", derefString(mode), + "speculative_opts", backend.Speculative, + "llamacpp", backend.LlamaCpp, + "keep_alive", derefString(backend.KeepAlive), + "vllm", backend.VLLM) resp, err := httpClient.Do(req) if err != nil { @@ -97,116 +147,37 @@ func buildConfigureURL(baseURL string) string { return u.String() } -// buildConfigureRequest constructs the JSON request body for POST /engines/_configure. -func buildConfigureRequest(model string, contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) configureRequest { - req := configureRequest{ - Model: model, +func buildConfigureBackendConfig(contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts, llamaCpp *llamaCppConfig, vllm *vllmConfig, keepAlive *string) configureBackendConfig { + cfg := configureBackendConfig{ RuntimeFlags: runtimeFlags, + LlamaCpp: llamaCpp, + VLLM: vllm, + KeepAlive: keepAlive, } - if contextSize != nil { cs := int32(*contextSize) - req.ContextSize = &cs + cfg.ContextSize = &cs } - if specOpts != nil { - req.Speculative = &speculativeDecodingRequest{ + cfg.Speculative = &speculativeDecodingRequest{ DraftModel: specOpts.draftModel, NumTokens: specOpts.numTokens, MinAcceptanceRate: specOpts.acceptanceRate, } } - - return req + return cfg } -// mergeRuntimeFlagsPreferUser merges derived engine flags (from model config fields like -// `temperature`) and user-provided runtime flags (from `provider_opts.runtime_flags`). -// When both specify the same flag key (e.g. --temp), the user value wins and a warning -// is returned. Order: non-conflicting derived flags first, then all user flags. -func mergeRuntimeFlagsPreferUser(derived, user []string) (merged, warnings []string) { - // parsedFlag holds a parsed flag token (e.g. "--temp 0.5" → key="--temp", tokens=["--temp","0.5"]). 
- type parsedFlag struct { - key string - tokens []string - } - - parse := func(args []string) []parsedFlag { - var out []parsedFlag - for i := 0; i < len(args); i++ { - tok := args[i] - if !strings.HasPrefix(tok, "-") { - out = append(out, parsedFlag{key: tok, tokens: []string{tok}}) - continue - } - // --key=value - if k, _, found := strings.Cut(tok, "="); found { - out = append(out, parsedFlag{key: k, tokens: []string{tok}}) - continue - } - // --key value (next token is the value if it doesn't start with -) - if i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") { - out = append(out, parsedFlag{key: tok, tokens: []string{tok, args[i+1]}}) - i++ - } else { - out = append(out, parsedFlag{key: tok, tokens: []string{tok}}) - } - } - return out - } - - derFlags := parse(derived) - usrFlags := parse(user) - - // Build a set of flag keys the user explicitly provides. - userKeys := make(map[string]bool, len(usrFlags)) - for _, f := range usrFlags { - if strings.HasPrefix(f.key, "-") { - userKeys[f.key] = true - } - } - - // Emit non-conflicting derived flags; warn on conflicts. - for _, f := range derFlags { - if strings.HasPrefix(f.key, "-") && userKeys[f.key] { - warnings = append(warnings, "Overriding runtime flag "+f.key+" with value from provider_opts.runtime_flags") - continue - } - merged = append(merged, f.tokens...) - } - for _, f := range usrFlags { - merged = append(merged, f.tokens...) - } - return merged, warnings -} - -// buildRuntimeFlagsFromModelConfig converts standard ModelConfig fields into backend-specific -// runtime flags that the model-runner understands when launching the engine. -// Currently supports "llama.cpp". Unknown engines produce no flags. -func buildRuntimeFlagsFromModelConfig(engine string, cfg *latest.ModelConfig) []string { - if cfg == nil { - return nil - } - - eng := cmp.Or(strings.TrimSpace(engine), "llama.cpp") - if eng != "llama.cpp" { - return nil - } - - var flags []string - if cfg.Temperature != nil { - flags = append(flags, "--temp", strconv.FormatFloat(*cfg.Temperature, 'f', -1, 64)) - } - if cfg.TopP != nil { - flags = append(flags, "--top-p", strconv.FormatFloat(*cfg.TopP, 'f', -1, 64)) - } - if cfg.FrequencyPenalty != nil { - flags = append(flags, "--frequency-penalty", strconv.FormatFloat(*cfg.FrequencyPenalty, 'f', -1, 64)) - } - if cfg.PresencePenalty != nil { - flags = append(flags, "--presence-penalty", strconv.FormatFloat(*cfg.PresencePenalty, 'f', -1, 64)) +// buildConfigureRequest constructs the JSON request body for POST /engines/_configure. +// mode and rawRuntimeFlags are top-level ConfigureRequest fields (not part of +// BackendConfiguration); pass nil / "" to omit. +func buildConfigureRequest(model string, backend configureBackendConfig, mode *string, rawRuntimeFlags string) configureRequest { + return configureRequest{ + Model: model, + Mode: mode, + RawRuntimeFlags: rawRuntimeFlags, + configureBackendConfig: backend, } - return flags } // parseFloat64 attempts to parse a value as float64 from various types. @@ -240,25 +211,388 @@ func parseInt(v any) (int, bool) { return 0, false } -// parseDMRProviderOpts extracts DMR-specific provider options from the model config: -// context size, runtime flags, and speculative decoding settings. -func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) { +// parseInt64Value parses an int64 from YAML/JSON-decoded values (int, float64, string). 
+func parseInt64Value(v any) (int64, bool) { + switch t := v.(type) { + case int64: + return t, true + case int: + return int64(t), true + case float64: + return int64(t), true + case string: + s := strings.TrimSpace(t) + if s == "" { + return 0, false + } + n, err := strconv.ParseInt(s, 10, 64) + return n, err == nil + default: + return 0, false + } +} + +// parseContextSize extracts context_size from provider_opts. +// Returns nil when unset, letting model-runner use its default. +func parseContextSize(opts map[string]any) *int64 { + if len(opts) == 0 { + return nil + } + v, ok := opts["context_size"] + if !ok { + return nil + } + if n, ok := parseInt64Value(v); ok { + return &n + } + return nil +} + +// resolveReasoningBudget normalizes a ThinkingBudget to a token count understood by model-runner backends: +// - nil → (0, false) — budget unset, caller should omit the field entirely +// - disabled → (0, true) — budget explicitly disabled, caller should send 0 +// - tokens > 0 → (n, true) — explicit token count +// - adaptive / unknown effort → (-1, true) — unlimited +// - named effort → mapped token count +func resolveReasoningBudget(tb *latest.ThinkingBudget) (budget int64, ok bool) { + if tb == nil { + return 0, false + } + if tb.IsDisabled() { + return 0, true + } + if tb.Tokens != 0 || tb.Effort == "" { + return int64(tb.Tokens), true + } + if tb.IsAdaptive() { + return -1, true + } + if tok, ok := tb.EffortTokens(); ok { + return int64(tok), true + } + return -1, true // unknown effort → unlimited +} + +// buildLlamaCppConfig constructs the llamacpp engine configuration from the model config. +// Currently maps thinking_budget to model-runner's llamacpp.reasoning-budget. +// Returns nil when no relevant config is set. +func buildLlamaCppConfig(cfg *latest.ModelConfig) *llamaCppConfig { + if cfg == nil { + return nil + } + budget, ok := resolveReasoningBudget(cfg.ThinkingBudget) + if !ok { + return nil + } + v := int32(budget) + return &llamaCppConfig{ReasoningBudget: &v} +} + +// buildVLLMRequestFields constructs per-request extra fields for the vLLM engine. +// Currently maps thinking_budget to vLLM's thinking_token_budget sampling parameter. +// Returns nil when no extra fields are needed. +func buildVLLMRequestFields(cfg *latest.ModelConfig) map[string]any { + if cfg == nil { + return nil + } + budget, ok := resolveReasoningBudget(cfg.ThinkingBudget) + if !ok { + return nil + } + return map[string]any{"thinking_token_budget": budget} +} + +func derefInt32(p *int32) any { + if p == nil { + return nil + } + return *p +} + +// derefString safely dereferences a *string for logging. +func derefString(p *string) any { + if p == nil { + return nil + } + return *p +} + +// derefInt64 safely dereferences a *int64 for logging. Returns nil for nil pointers +// so slog renders "" instead of a memory address. +func derefInt64(p *int64) any { + if p == nil { + return nil + } + return *p +} + +const ( + engineLlamaCpp = "llama.cpp" + engineVLLM = "vllm" +) + +// noThinkingMinOutputTokens is the floor we enforce for NoThinking requests +// that also supply a small MaxTokens cap (e.g. session title generation sets +// max_tokens=20). Even with chat_template_kwargs.enable_thinking=false, some +// engines/templates still emit a few reasoning tokens before visible output, +// so a tiny cap can leave the visible text starved. 
The floor only raises a +// user-supplied cap; if MaxTokens is unset the caller has imposed no cap and +// there is nothing to floor (see client.go for the nil-guarded application +// site). Mirrors the OpenAI provider's 256-token floor (see +// pkg/model/provider/openai/client.go). +const noThinkingMinOutputTokens int64 = 256 + +// dmrParseResult bundles every piece of model-runner configuration that can be +// derived from a ModelConfig. Returning a struct (rather than 5+ positional +// values) keeps the public surface ergonomic as we add more fields. +type dmrParseResult struct { + contextSize *int64 + runtimeFlags []string + rawRuntimeFlags string + specOpts *speculativeDecodingOpts + llamaCpp *llamaCppConfig + vllm *vllmConfig + keepAlive *string + mode *string +} + +// parseDMRProviderOpts extracts DMR-specific provider options from the model +// config: context size, runtime flags, speculative decoding settings, +// backend-specific structured options, and top-level ConfigureRequest fields +// (mode, keep_alive, raw_runtime_flags). +// +// engine is the active model-runner backend (e.g. "llama.cpp", "vllm", "mlx", +// "sglang"). +// +// Any validation error on a user-supplied field is returned so the caller can +// fail fast rather than round-tripping the server and reading the 4xx body. +func parseDMRProviderOpts(engine string, cfg *latest.ModelConfig) (dmrParseResult, error) { + var res dmrParseResult if cfg == nil { - return nil, nil, nil + return res, nil + } + + res.contextSize = parseContextSize(cfg.ProviderOpts) + + if engine == "" || engine == engineLlamaCpp { + res.llamaCpp = buildLlamaCppConfig(cfg) + } + + if engine == engineVLLM { + vllm, err := parseVLLMConfig(cfg.ProviderOpts) + if err != nil { + return res, err + } + res.vllm = vllm + } + + ka, err := parseKeepAlive(cfg.ProviderOpts) + if err != nil { + return res, err + } + res.keepAlive = ka + + mode, err := parseMode(cfg.ProviderOpts) + if err != nil { + return res, err } + res.mode = mode - contextSize = cfg.MaxTokens + if raw, err := parseRawRuntimeFlags(cfg.ProviderOpts); err != nil { + return res, err + } else { + res.rawRuntimeFlags = raw + } - slog.Debug("DMR provider opts", "provider_opts", cfg.ProviderOpts) + slog.Debug("DMR provider opts", "provider_opts", cfg.ProviderOpts, "engine", engine) if len(cfg.ProviderOpts) == 0 { - return contextSize, nil, nil + return res, nil + } + + res.runtimeFlags = parseRuntimeFlags(cfg.ProviderOpts) + res.specOpts = parseSpeculativeOpts(cfg.ProviderOpts) + + if len(res.runtimeFlags) > 0 && res.rawRuntimeFlags != "" { + return res, errors.New("provider_opts: cannot set both runtime_flags and raw_runtime_flags; pick one") + } + + return res, nil +} + +// parseVLLMConfig extracts vLLM-specific configuration from provider_opts. +// Currently supports "gpu_memory_utilization" and "hf_overrides" keys. +// Returns nil when none of the keys are present or all values are invalid. +// hf_overrides is validated client-side with the same key rules model-runner +// enforces (see ../model-runner/pkg/inference/hf_overrides.go). 
+func parseVLLMConfig(opts map[string]any) (*vllmConfig, error) {
+	if len(opts) == 0 {
+		return nil, nil
+	}
+
+	var vllm *vllmConfig
+
+	if gpuMem, ok := opts["gpu_memory_utilization"]; ok {
+		if val, ok := parseFloat64(gpuMem); ok {
+			if val < 0 || val > 1 {
+				return nil, fmt.Errorf("provider_opts.gpu_memory_utilization must be between 0.0 and 1.0, got %v", val)
+			}
+			if vllm == nil {
+				vllm = &vllmConfig{}
+			}
+			vllm.GPUMemoryUtilization = &val
+		}
+	}
+
+	if hfOverrides, ok := opts["hf_overrides"]; ok {
+		if overrides, ok := hfOverrides.(map[string]any); ok {
+			if err := validateHFOverrides(overrides); err != nil {
+				return nil, err
+			}
+			if vllm == nil {
+				vllm = &vllmConfig{}
+			}
+			vllm.HFOverrides = overrides
+		}
+	}
+
+	return vllm, nil
+}
+
+// validHFOverridesKeyRegex mirrors model-runner's regex: keys must be valid Go
+// identifier-ish tokens to prevent injection via keys like "--malicious-flag".
+var validHFOverridesKeyRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
+
+// validateHFOverrides mirrors inference.HFOverrides.Validate() from model-runner
+// so the client can fail fast on bad input instead of waiting for a 400.
+func validateHFOverrides(overrides map[string]any) error {
+	for key, value := range overrides {
+		if !validHFOverridesKeyRegex.MatchString(key) {
+			return fmt.Errorf("invalid hf_overrides key %q: must contain only alphanumeric characters and underscores, and start with a letter or underscore", key)
+		}
+		if err := validateHFOverridesValue(key, value); err != nil {
+			return err
+		}
 	}
+	return nil
+}

-	runtimeFlags = parseRuntimeFlags(cfg.ProviderOpts)
-	specOpts = parseSpeculativeOpts(cfg.ProviderOpts)
+func validateHFOverridesValue(key string, value any) error {
+	switch v := value.(type) {
+	case string, bool, float64, float32, nil,
+		int, int8, int16, int32, int64,
+		uint, uint8, uint16, uint32, uint64:
+		return nil
+	case []any:
+		for i, elem := range v {
+			if err := validateHFOverridesValue(fmt.Sprintf("%s[%d]", key, i), elem); err != nil {
+				return err
+			}
+		}
+		return nil
+	case map[string]any:
+		for nestedKey, nestedValue := range v {
+			if !validHFOverridesKeyRegex.MatchString(nestedKey) {
+				return fmt.Errorf("invalid hf_overrides nested key %q in %q: must contain only alphanumeric characters and underscores, and start with a letter or underscore", nestedKey, key)
+			}
+			if err := validateHFOverridesValue(key+"."+nestedKey, nestedValue); err != nil {
+				return err
+			}
+		}
+		return nil
+	default:
+		return fmt.Errorf("invalid hf_overrides value for key %q: unsupported type %T", key, value)
+	}
+}

-	return contextSize, runtimeFlags, specOpts
+// parseKeepAlive extracts keep_alive from provider_opts and validates it using
+// the same rules as model-runner's inference.ParseKeepAlive:
+// - Go duration strings: "5m", "1h", "30s"
+// - "0" to unload immediately
+// - Any negative value ("-1", "-1m") to keep loaded forever
+//
+// Returns nil when unset, letting model-runner use its default (5 minutes).
+func parseKeepAlive(opts map[string]any) (*string, error) {
+	if len(opts) == 0 {
+		return nil, nil
+	}
+	v, ok := opts["keep_alive"]
+	if !ok {
+		return nil, nil
+	}
+	s, ok := v.(string)
+	if !ok {
+		return nil, fmt.Errorf(`provider_opts.keep_alive must be a string (e.g. "5m", "1h", "-1"), got %T`, v)
"5m", "1h", "-1"), got %T`, v) + } + s = strings.TrimSpace(s) + if s == "" { + return nil, errors.New("provider_opts.keep_alive must not be empty") + } + if err := validateKeepAlive(s); err != nil { + return nil, err + } + return &s, nil +} + +// validateKeepAlive enforces the same rules as model-runner's inference.ParseKeepAlive. +func validateKeepAlive(s string) error { + if s == "0" || s == "-1" { + return nil + } + if _, err := time.ParseDuration(s); err != nil { + return fmt.Errorf("invalid keep_alive duration %q: %w", s, err) + } + return nil +} + +// validModes mirrors the set accepted by model-runner's ParseBackendMode. +var validModes = map[string]struct{}{ + "completion": {}, + "embedding": {}, + "reranking": {}, + "image-generation": {}, +} + +// parseMode extracts mode from provider_opts. When unset the scheduler auto- +// detects mode from the request path, so nil is the safe default. +func parseMode(opts map[string]any) (*string, error) { + if len(opts) == 0 { + return nil, nil + } + v, ok := opts["mode"] + if !ok { + return nil, nil + } + s, ok := v.(string) + if !ok { + return nil, fmt.Errorf("provider_opts.mode must be a string, got %T", v) + } + s = strings.TrimSpace(s) + if _, ok := validModes[s]; !ok { + return nil, fmt.Errorf("provider_opts.mode %q is invalid; must be one of: completion, embedding, reranking, image-generation", s) + } + return &s, nil +} + +// parseRawRuntimeFlags extracts raw_runtime_flags as a single shell-style string. +// Model-runner parses this via shellwords; keep user validation minimal and +// reject empty/whitespace-only values. +func parseRawRuntimeFlags(opts map[string]any) (string, error) { + if len(opts) == 0 { + return "", nil + } + v, ok := opts["raw_runtime_flags"] + if !ok { + return "", nil + } + s, ok := v.(string) + if !ok { + return "", fmt.Errorf("provider_opts.raw_runtime_flags must be a string, got %T", v) + } + if strings.TrimSpace(s) == "" { + return "", nil + } + return s, nil } // parseRuntimeFlags extracts the "runtime_flags" key from provider opts. diff --git a/pkg/model/provider/openai/client.go b/pkg/model/provider/openai/client.go index 381350ac6..c9c24e57e 100644 --- a/pkg/model/provider/openai/client.go +++ b/pkg/model/provider/openai/client.go @@ -280,6 +280,11 @@ func (c *Client) CreateChatCompletionStream( // reasoning tokens don't exhaust the max_completion_tokens budget. // We use "low" instead of "minimal" because older models (o3-mini, o1) // only accept low/medium/high. + // + // If the caller also supplied a small MaxTokens cap, raise it to + // noThinkingMinOutputTokens so residual hidden reasoning can't starve + // visible output. The nil-guard is intentional: when MaxTokens is unset + // the caller has imposed no cap, so there is nothing to floor. if isOpenAIReasoningModel(c.ModelConfig.Model) { if c.ModelOptions.NoThinking() { params.ReasoningEffort = shared.ReasoningEffort("low") @@ -408,6 +413,11 @@ func (c *Client) CreateResponseStream( // Those hidden reasoning tokens still count against max_output_tokens, // so with a small budget (e.g. title generation) the model can exhaust // all tokens on reasoning and return empty visible text. + // + // If the caller also supplied a small MaxTokens cap, raise it to + // noThinkingMinOutputTokens so residual hidden reasoning can't starve + // visible output. The nil-guard is intentional: when MaxTokens is unset + // the caller has imposed no cap, so there is nothing to floor. 
 	if isOpenAIReasoningModel(c.ModelConfig.Model) {
 		if c.ModelOptions.NoThinking() {
 			// Use low effort so the model spends as few output tokens as
@@ -1074,12 +1084,14 @@ func isOpenAIReasoningModel(model string) bool {
 		strings.HasPrefix(m, "gpt-5")
 }

-// noThinkingMinOutputTokens is the minimum max-output-token budget for
-// reasoning models when NoThinking is set. Even with low reasoning effort
-// the model still produces hidden reasoning tokens that count against
-// max_output_tokens / max_completion_tokens. A small budget (e.g. 20)
-// gets entirely consumed by reasoning, leaving nothing for visible text.
-// 256 tokens is enough for low-effort reasoning plus a short visible response.
+// noThinkingMinOutputTokens is the minimum output-token budget we enforce for
+// reasoning models when NoThinking is set and the caller has also supplied a
+// smaller MaxTokens cap. Even with low reasoning effort the model still
+// produces hidden reasoning tokens that count against max_output_tokens /
+// max_completion_tokens, so a tiny cap (e.g. 20) can get entirely consumed
+// by reasoning and leave nothing for visible text. The floor only raises an
+// explicit cap; if MaxTokens is unset the caller has imposed no cap and there
+// is nothing to floor.
 const noThinkingMinOutputTokens int64 = 256

 // openAIReasoningEffort validates a ThinkingBudget effort string for the