From 20a607bb287605945a55f7bf0284640e240e21fc Mon Sep 17 00:00:00 2001 From: David Gageot Date: Sun, 26 Apr 2026 09:38:09 +0200 Subject: [PATCH 1/6] feat(skills): allow fork skills to override the model Fork skills (context: fork) previously inherited the parent agent's model with no way to switch. Add a model field to SKILL.md frontmatter that overrides the model used while the sub-session runs and is restored when the skill completes. - Parse `model:` from SKILL.md frontmatter into Skill.Model. - Add Agent.ModelOverrides() returning a defensive copy so callers can save and restore the active override around a sub-session. - In handleRunSkill, resolve the skill's model ref (named, inline provider/model, or comma-separated alloy) via the runtime's model switcher, apply it as the agent's override, and defer restoration. Falls back to the parent's model with a warning if resolution fails or the runtime has no model switcher configured. - Document the new field and add an example in the skills guide. - Cover parsing and the not-configured runtime path with tests, and extend TestModelOverride for the new ModelOverrides() getter and save/restore round-trip. 
Assisted-By: docker-agent --- docs/features/skills/index.md | 32 +++++++++++++++++ pkg/agent/agent.go | 12 +++++++ pkg/agent/agent_test.go | 17 +++++++++ pkg/runtime/skill_runner.go | 62 ++++++++++++++++++++++++++++++++ pkg/runtime/skill_runner_test.go | 25 +++++++++++++ pkg/skills/skills.go | 7 ++++ pkg/skills/skills_test.go | 37 +++++++++++++++++++ 7 files changed, 192 insertions(+) create mode 100644 pkg/runtime/skill_runner_test.go diff --git a/docs/features/skills/index.md b/docs/features/skills/index.md index 81feb516b..73df9a625 100644 --- a/docs/features/skills/index.md +++ b/docs/features/skills/index.md @@ -98,6 +98,7 @@ When asked to create a Dockerfile: | `name` | Yes | Unique skill identifier | | `description` | Yes | Short description shown to the agent for skill matching | | `context` | No | Set to `fork` to run the skill as an isolated sub-agent (see below) | +| `model` | No | Override the model used while running the skill as a sub-agent (fork only) | | `allowed-tools` | No | List of tools the skill needs (YAML list or comma-separated string) | | `license` | No | License identifier (e.g. `Apache-2.0`) | | `compatibility` | No | Free-text compatibility notes | @@ -138,6 +139,37 @@ When the agent encounters a task that matches a `context: fork` skill, it uses t +### Overriding the model for a fork skill + +Fork skills can declare a `model` field in their frontmatter to use a +different model than the parent agent for the duration of the sub-session. +This is useful when a skill is best handled by a faster, cheaper, or more +specialised model — for example a powerful reasoning model for refactors, +or a fast model for routine bookkeeping work. The override only applies +while the skill is running; the parent agent keeps its own model. 
+ +The `model` value accepts either a named model from the agent config or +an inline `provider/model` reference (and the same comma-separated alloy +syntax as the rest of the agent config): + + +```yaml +--- +name: bump-go-dependencies +description: Update Go module dependencies one by one +context: fork +model: openai/gpt-4o-mini +--- + +# Bump Dependencies + +1. ... +``` + +If the model reference cannot be resolved (unknown name, missing +credentials, runtime not configured for model switching, …) the skill +falls back to the parent agent's default model and a warning is logged. + ## Search Paths Skills are discovered from these locations (later overrides earlier): diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 154f104fd..d33f2036d 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -180,6 +180,18 @@ func (a *Agent) HasModelOverride() bool { return overrides != nil && len(*overrides) > 0 } +// ModelOverrides returns the currently active model override providers, +// or nil when no override is set. The returned slice is a copy so it can +// be safely retained by the caller (e.g. to save/restore the override +// around a sub-session). +func (a *Agent) ModelOverrides() []provider.Provider { + overrides := a.modelOverrides.Load() + if overrides == nil || len(*overrides) == 0 { + return nil + } + return append([]provider.Provider(nil), (*overrides)...) +} + // ConfiguredModels returns the originally configured models for this agent. // This is useful for listing available models in the TUI picker. 
func (a *Agent) ConfiguredModels() []provider.Provider { diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 63790795b..a93c92fcb 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -148,10 +148,13 @@ func TestModelOverride(t *testing.T) { model := a.Model() assert.Equal(t, "openai/gpt-4o", model.ID()) assert.False(t, a.HasModelOverride()) + assert.Nil(t, a.ModelOverrides()) // Set an override a.SetModelOverride(overrideModel) assert.True(t, a.HasModelOverride()) + require.Len(t, a.ModelOverrides(), 1) + assert.Equal(t, "anthropic/claude-sonnet-4-0", a.ModelOverrides()[0].ID()) // Now Model() should return the override model = a.Model() @@ -162,9 +165,23 @@ func TestModelOverride(t *testing.T) { require.Len(t, configuredModels, 1) assert.Equal(t, "openai/gpt-4o", configuredModels[0].ID()) + // Mutating the slice returned by ModelOverrides must not affect the agent + snapshot := a.ModelOverrides() + snapshot[0] = defaultModel + require.Len(t, a.ModelOverrides(), 1) + assert.Equal(t, "anthropic/claude-sonnet-4-0", a.ModelOverrides()[0].ID()) + + // Save / restore round-trip using ModelOverrides + prev := a.ModelOverrides() + a.SetModelOverride(defaultModel) + assert.Equal(t, "openai/gpt-4o", a.Model().ID()) + a.SetModelOverride(prev...) 
+ assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model().ID()) + // Clear the override a.SetModelOverride(nil) assert.False(t, a.HasModelOverride()) + assert.Nil(t, a.ModelOverrides()) // Model() should return the default again model = a.Model() diff --git a/pkg/runtime/skill_runner.go b/pkg/runtime/skill_runner.go index d900e010b..5687b3ea1 100644 --- a/pkg/runtime/skill_runner.go +++ b/pkg/runtime/skill_runner.go @@ -3,12 +3,16 @@ package runtime import ( "context" "encoding/json" + "errors" "fmt" "log/slog" + "strings" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/tools/builtin" @@ -67,6 +71,24 @@ func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session "task", params.Task, ) + // If the skill declares a model override, apply it for the duration of + // the sub-session and restore the previous override when done. The + // parent agent loop is blocked while the sub-session runs, so this + // save/restore is safe. + if skill.Model != "" { + restore, err := r.applySkillModelOverride(ctx, a, skill.Model) + if err != nil { + slog.Warn("Failed to apply skill model override; using default model", + "agent", ca, + "skill", params.Name, + "model", skill.Model, + "error", err, + ) + } else { + defer restore() + } + } + cfg := SubSessionConfig{ Task: params.Task, SystemMessage: skillContent, @@ -80,3 +102,43 @@ func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session s := newSubSession(sess, cfg, a) return r.runSubSessionForwarding(ctx, sess, s, span, evts, ca) } + +// applySkillModelOverride resolves modelRef and applies it as the model +// override on a, returning a restore func that reinstates the previous +// override (or clears it if none was set). 
modelRef can be a named model +// from the config, an inline "provider/model" spec, or an inline alloy +// (comma-separated). The runtime must be configured with model switching +// for this to succeed. +func (r *LocalRuntime) applySkillModelOverride(ctx context.Context, a *agent.Agent, modelRef string) (func(), error) { + if r.modelSwitcherCfg == nil { + return nil, errors.New("model switching is not configured for this runtime") + } + + var providers []provider.Provider + var err error + if strings.Contains(modelRef, ",") { + providers, err = r.resolveModelRefs(ctx, modelRef) + } else { + var p provider.Provider + p, err = r.resolveModelRef(ctx, modelRef) + if err == nil { + providers = []provider.Provider{p} + } + } + if err != nil { + return nil, fmt.Errorf("resolve model %q: %w", modelRef, err) + } + + prev := a.ModelOverrides() + a.SetModelOverride(providers...) + slog.Debug("Applied skill model override", "agent", a.Name(), "model", modelRef, "count", len(providers)) + + return func() { + if len(prev) == 0 { + a.SetModelOverride() + } else { + a.SetModelOverride(prev...) + } + slog.Debug("Restored skill model override", "agent", a.Name()) + }, nil +} diff --git a/pkg/runtime/skill_runner_test.go b/pkg/runtime/skill_runner_test.go new file mode 100644 index 000000000..f31e17045 --- /dev/null +++ b/pkg/runtime/skill_runner_test.go @@ -0,0 +1,25 @@ +package runtime + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/agent" +) + +// TestApplySkillModelOverride_NotConfigured verifies that the helper +// returns an error (and does not modify the agent state) when the runtime +// has no model switcher configured. 
+func TestApplySkillModelOverride_NotConfigured(t *testing.T) { + t.Parallel() + + r := &LocalRuntime{} + a := agent.New("root", "test") + + restore, err := r.applySkillModelOverride(t.Context(), a, "openai/gpt-4o") + require.Error(t, err) + assert.Nil(t, restore) + assert.False(t, a.HasModelOverride(), "agent override must not be applied on error") +} diff --git a/pkg/skills/skills.go b/pkg/skills/skills.go index 4d61179aa..264a7fc6c 100644 --- a/pkg/skills/skills.go +++ b/pkg/skills/skills.go @@ -26,6 +26,11 @@ type Skill struct { Metadata map[string]string AllowedTools []string Context string // "fork" to run the skill as an isolated sub-agent + // Model is an optional model override applied while the skill runs as + // a sub-agent (context: fork). It accepts either a named model from the + // agent config or an inline "provider/model" reference (e.g. + // "openai/gpt-4o-mini"). It is ignored for non-fork skills. + Model string } // IsFork returns true when the skill should be executed in an isolated @@ -358,6 +363,8 @@ func parseFrontmatter(content string) (Skill, bool) { skill.Compatibility = unquote(value) case "context": skill.Context = unquote(value) + case "model": + skill.Model = unquote(value) case "metadata": currentKey = "metadata" case "allowed-tools": diff --git a/pkg/skills/skills_test.go b/pkg/skills/skills_test.go index 289f7b387..6de0734e4 100644 --- a/pkg/skills/skills_test.go +++ b/pkg/skills/skills_test.go @@ -159,6 +159,42 @@ Body`, }, wantOK: true, }, + { + name: "model override (named)", + content: `--- +name: model-skill +description: A skill that overrides the model +context: fork +model: my_fast_model +--- + +Body`, + want: Skill{ + Name: "model-skill", + Description: "A skill that overrides the model", + Context: "fork", + Model: "my_fast_model", + }, + wantOK: true, + }, + { + name: "model override (inline provider/model)", + content: `--- +name: inline-model-skill +description: Skill with inline provider/model override +context: fork 
+model: openai/gpt-4o-mini +--- + +Body`, + want: Skill{ + Name: "inline-model-skill", + Description: "Skill with inline provider/model override", + Context: "fork", + Model: "openai/gpt-4o-mini", + }, + wantOK: true, + }, { name: "allowed-tools list with quoted items", content: "---\nname: quoted-tools\ndescription: Skill with quoted tool items\nallowed-tools:\n - \"Bash(git:*)\"\n - 'Read'\n---\n\nBody", @@ -192,6 +228,7 @@ Body`, assert.Equal(t, tt.want.Metadata, got.Metadata) assert.Equal(t, tt.want.AllowedTools, got.AllowedTools) assert.Equal(t, tt.want.Context, got.Context) + assert.Equal(t, tt.want.Model, got.Model) } }) } From 7e027927d36959b4b88ef1221337bcbaae24bb8b Mon Sep 17 00:00:00 2001 From: David Gageot Date: Sun, 26 Apr 2026 09:41:17 +0200 Subject: [PATCH 2/6] refactor(runtime): reuse SetAgentModel for skill model overrides The applySkillModelOverride helper duplicated a less complete subset of the resolution logic already implemented by SetAgentModel (named model, alloy from config, inline provider/model, inline alloy spec). Replace the 40-line helper with a 10-line save/restore around a direct SetAgentModel call: capture the previous override via ModelOverrides(), apply the new one, and on success defer SetModelOverride(prev...). SetModelOverride filters nil providers and clears the override when the resulting slice is empty, so the same call handles both "restore previous override" and "restore default" without an explicit branch. This drops the now-unused agent / provider / errors / strings imports and the standalone helper test (the no-switcher error path is already covered by SetAgentModel itself). 
Assisted-By: docker-agent --- pkg/runtime/skill_runner.go | 54 ++++---------------------------- pkg/runtime/skill_runner_test.go | 25 --------------- 2 files changed, 6 insertions(+), 73 deletions(-) delete mode 100644 pkg/runtime/skill_runner_test.go diff --git a/pkg/runtime/skill_runner.go b/pkg/runtime/skill_runner.go index 5687b3ea1..6c721722c 100644 --- a/pkg/runtime/skill_runner.go +++ b/pkg/runtime/skill_runner.go @@ -3,16 +3,12 @@ package runtime import ( "context" "encoding/json" - "errors" "fmt" "log/slog" - "strings" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - "github.com/docker/docker-agent/pkg/agent" - "github.com/docker/docker-agent/pkg/model/provider" "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/tools" "github.com/docker/docker-agent/pkg/tools/builtin" @@ -74,10 +70,12 @@ func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session // If the skill declares a model override, apply it for the duration of // the sub-session and restore the previous override when done. The // parent agent loop is blocked while the sub-session runs, so this - // save/restore is safe. + // save/restore is safe. SetAgentModel handles every accepted form + // (named model, alloy, inline provider/model, inline alloy); on + // failure we just log a warning and fall back to the parent's model. if skill.Model != "" { - restore, err := r.applySkillModelOverride(ctx, a, skill.Model) - if err != nil { + prev := a.ModelOverrides() + if err := r.SetAgentModel(ctx, ca, skill.Model); err != nil { slog.Warn("Failed to apply skill model override; using default model", "agent", ca, "skill", params.Name, @@ -85,7 +83,7 @@ func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session "error", err, ) } else { - defer restore() + defer a.SetModelOverride(prev...) 
} } @@ -102,43 +100,3 @@ func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session s := newSubSession(sess, cfg, a) return r.runSubSessionForwarding(ctx, sess, s, span, evts, ca) } - -// applySkillModelOverride resolves modelRef and applies it as the model -// override on a, returning a restore func that reinstates the previous -// override (or clears it if none was set). modelRef can be a named model -// from the config, an inline "provider/model" spec, or an inline alloy -// (comma-separated). The runtime must be configured with model switching -// for this to succeed. -func (r *LocalRuntime) applySkillModelOverride(ctx context.Context, a *agent.Agent, modelRef string) (func(), error) { - if r.modelSwitcherCfg == nil { - return nil, errors.New("model switching is not configured for this runtime") - } - - var providers []provider.Provider - var err error - if strings.Contains(modelRef, ",") { - providers, err = r.resolveModelRefs(ctx, modelRef) - } else { - var p provider.Provider - p, err = r.resolveModelRef(ctx, modelRef) - if err == nil { - providers = []provider.Provider{p} - } - } - if err != nil { - return nil, fmt.Errorf("resolve model %q: %w", modelRef, err) - } - - prev := a.ModelOverrides() - a.SetModelOverride(providers...) - slog.Debug("Applied skill model override", "agent", a.Name(), "model", modelRef, "count", len(providers)) - - return func() { - if len(prev) == 0 { - a.SetModelOverride() - } else { - a.SetModelOverride(prev...) 
- } - slog.Debug("Restored skill model override", "agent", a.Name()) - }, nil -} diff --git a/pkg/runtime/skill_runner_test.go b/pkg/runtime/skill_runner_test.go deleted file mode 100644 index f31e17045..000000000 --- a/pkg/runtime/skill_runner_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package runtime - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/docker/docker-agent/pkg/agent" -) - -// TestApplySkillModelOverride_NotConfigured verifies that the helper -// returns an error (and does not modify the agent state) when the runtime -// has no model switcher configured. -func TestApplySkillModelOverride_NotConfigured(t *testing.T) { - t.Parallel() - - r := &LocalRuntime{} - a := agent.New("root", "test") - - restore, err := r.applySkillModelOverride(t.Context(), a, "openai/gpt-4o") - require.Error(t, err) - assert.Nil(t, restore) - assert.False(t, a.HasModelOverride(), "agent override must not be applied on error") -} From 2d1ffd986012691f313c8fe71a9d87b9dfb15ccc Mon Sep 17 00:00:00 2001 From: David Gageot Date: Sun, 26 Apr 2026 09:52:50 +0200 Subject: [PATCH 3/6] fix(skills): preserve concurrent model picks via CAS restore The previous save/restore around a fork-skill sub-session was racy against the TUI model picker: a user switching model mid-skill had their choice silently reverted by the deferred restore. Replace the naive ModelOverrides() snapshot + SetModelOverride(prev...) restore with a pointer-identity compare-and-swap: - Add an opaque ModelOverrideSnapshot type plus SnapshotModelOverride and RestoreModelOverride methods on Agent. RestoreModelOverride CAS-swaps on the underlying atomic.Pointer, so the restore is a no-op if any other caller mutated the override since it was captured. - handleRunSkill now snapshots before and after applying the skill's override and defers a CAS-restore. 
If the user switches model via the TUI (or any other caller wins the race), their choice is kept and the skill's override scope ends cleanly. - Document that ModelOverrides() is for read-only inspection only and that the snapshot/restore primitive must be used for scoped overrides. - Clarify the docs: fallback is to the agent's currently-active model (configured default OR a previously-set override), and concurrent TUI model switches during a fork skill are preserved. Validated with race-detector tests covering the happy path, restore to a pre-existing override, and both concurrent-change and concurrent-clear scenarios. Assisted-By: docker-agent --- docs/features/skills/index.md | 10 ++++- pkg/agent/agent.go | 43 +++++++++++++++++++- pkg/agent/agent_test.go | 74 +++++++++++++++++++++++++++++++++++ pkg/runtime/skill_runner.go | 21 ++++++---- 4 files changed, 137 insertions(+), 11 deletions(-) diff --git a/docs/features/skills/index.md b/docs/features/skills/index.md index 73df9a625..362e305de 100644 --- a/docs/features/skills/index.md +++ b/docs/features/skills/index.md @@ -168,7 +168,15 @@ model: openai/gpt-4o-mini If the model reference cannot be resolved (unknown name, missing credentials, runtime not configured for model switching, …) the skill -falls back to the parent agent's default model and a warning is logged. +falls back to the agent's currently-active model (its configured +default, or any override the user previously set via the model picker) +and a warning is logged. + +When the skill completes, the agent's previous model is restored — but +only if no one else changed the model in the meantime. If the user +switches the model via the TUI model picker while the fork skill is +running, their choice is preserved (the deferred restore becomes a +no-op). 
## Search Paths diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index d33f2036d..539bc6f24 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -182,8 +182,14 @@ func (a *Agent) HasModelOverride() bool { // ModelOverrides returns the currently active model override providers, // or nil when no override is set. The returned slice is a copy so it can -// be safely retained by the caller (e.g. to save/restore the override -// around a sub-session). +// be safely retained by the caller for read-only inspection. +// +// Do NOT use this for save/restore around a temporary override: the +// returned slice is a snapshot of the contents at call time, so naive +// save+restore can clobber a concurrent change made by another caller +// (e.g. the TUI model picker switching the model while a skill +// sub-session is running). Use SnapshotModelOverride / RestoreModelOverride +// for safe scoped overrides instead. func (a *Agent) ModelOverrides() []provider.Provider { overrides := a.modelOverrides.Load() if overrides == nil || len(*overrides) == 0 { @@ -192,6 +198,39 @@ func (a *Agent) ModelOverrides() []provider.Provider { return append([]provider.Provider(nil), (*overrides)...) } +// ModelOverrideSnapshot is an opaque token that captures the agent's model +// override at a point in time. Pass it to RestoreModelOverride to undo a +// scoped override safely. +type ModelOverrideSnapshot struct { + // ptr is the raw atomic pointer value at snapshot time. It is used for + // pointer-identity compare-and-swap, never dereferenced by callers. + ptr *[]provider.Provider +} + +// SnapshotModelOverride captures the agent's current model override. The +// returned snapshot is opaque; pass it to RestoreModelOverride later to +// restore the captured value. 
+func (a *Agent) SnapshotModelOverride() ModelOverrideSnapshot { + return ModelOverrideSnapshot{ptr: a.modelOverrides.Load()} +} + +// RestoreModelOverride atomically restores the override to the value +// captured by `prev`, but only if the current override is still the one +// captured by `current` (pointer identity). If another caller has changed +// the override since `current` was captured, the restore is a no-op so +// that the concurrent change wins. +// +// This is the safe primitive for applying a temporary override around a +// scope (e.g. a skill sub-session) without clobbering changes made by +// concurrent callers such as the TUI model picker. +func (a *Agent) RestoreModelOverride(prev, current ModelOverrideSnapshot) { + if a.modelOverrides.CompareAndSwap(current.ptr, prev.ptr) { + slog.Debug("Restored model override", "agent", a.name) + } else { + slog.Debug("Model override changed concurrently; skipping restore", "agent", a.name) + } +} + // ConfiguredModels returns the originally configured models for this agent. // This is useful for listing available models in the TUI picker. 
func (a *Agent) ConfiguredModels() []provider.Provider { diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index a93c92fcb..796cff797 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -188,6 +188,80 @@ func TestModelOverride(t *testing.T) { assert.Equal(t, "openai/gpt-4o", model.ID()) } +func TestSnapshotAndRestoreModelOverride(t *testing.T) { + t.Parallel() + + defaultModel := &mockProvider{id: "openai/gpt-4o"} + skillModel := &mockProvider{id: "openai/gpt-4o-mini"} + userModel := &mockProvider{id: "anthropic/claude-sonnet-4-0"} + + t.Run("restores when no concurrent change", func(t *testing.T) { + t.Parallel() + a := New("root", "test", WithModel(defaultModel)) + + prev := a.SnapshotModelOverride() + a.SetModelOverride(skillModel) + ours := a.SnapshotModelOverride() + assert.Equal(t, "openai/gpt-4o-mini", a.Model().ID()) + + a.RestoreModelOverride(prev, ours) + assert.False(t, a.HasModelOverride()) + assert.Equal(t, "openai/gpt-4o", a.Model().ID()) + }) + + t.Run("restores back to a pre-existing override", func(t *testing.T) { + t.Parallel() + a := New("root", "test", WithModel(defaultModel)) + a.SetModelOverride(userModel) + + prev := a.SnapshotModelOverride() + a.SetModelOverride(skillModel) + ours := a.SnapshotModelOverride() + assert.Equal(t, "openai/gpt-4o-mini", a.Model().ID()) + + a.RestoreModelOverride(prev, ours) + assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model().ID()) + }) + + t.Run("keeps a concurrent change instead of restoring", func(t *testing.T) { + // This is the TUI-while-skill-runs scenario: another caller + // changes the override between SnapshotModelOverride and + // RestoreModelOverride. The deferred restore must NOT clobber + // that change. + t.Parallel() + a := New("root", "test", WithModel(defaultModel)) + + prev := a.SnapshotModelOverride() + a.SetModelOverride(skillModel) + ours := a.SnapshotModelOverride() + + // Simulate concurrent TUI model switch. 
+ a.SetModelOverride(userModel) + + a.RestoreModelOverride(prev, ours) + require.True(t, a.HasModelOverride(), "user's model choice must be preserved") + assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model().ID()) + }) + + t.Run("keeps a concurrent clear instead of restoring", func(t *testing.T) { + // Same as above but the concurrent caller clears the override + // (e.g. user revert via TUI). The restore must respect that. + t.Parallel() + a := New("root", "test", WithModel(defaultModel)) + a.SetModelOverride(userModel) + + prev := a.SnapshotModelOverride() + a.SetModelOverride(skillModel) + ours := a.SnapshotModelOverride() + + // Simulate concurrent TUI revert. + a.SetModelOverride() + + a.RestoreModelOverride(prev, ours) + assert.False(t, a.HasModelOverride(), "user's revert must be preserved") + }) +} + func TestModel_LogsSelection(t *testing.T) { t.Parallel() diff --git a/pkg/runtime/skill_runner.go b/pkg/runtime/skill_runner.go index 6c721722c..20f13f5e5 100644 --- a/pkg/runtime/skill_runner.go +++ b/pkg/runtime/skill_runner.go @@ -68,22 +68,27 @@ func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session ) // If the skill declares a model override, apply it for the duration of - // the sub-session and restore the previous override when done. The - // parent agent loop is blocked while the sub-session runs, so this - // save/restore is safe. SetAgentModel handles every accepted form - // (named model, alloy, inline provider/model, inline alloy); on - // failure we just log a warning and fall back to the parent's model. + // the sub-session and restore the previous override when done. + // SetAgentModel handles every accepted form (named model, alloy, inline + // provider/model, inline alloy); on failure we just log a warning and + // fall back to the agent's currently-active model. + // + // We snapshot before and after the apply so the deferred restore can + // CAS back to the previous value only if no concurrent caller (e.g. 
the + // TUI model picker) changed the override in the meantime — otherwise + // the user's choice would be silently reverted on skill completion. if skill.Model != "" { - prev := a.ModelOverrides() + prev := a.SnapshotModelOverride() if err := r.SetAgentModel(ctx, ca, skill.Model); err != nil { - slog.Warn("Failed to apply skill model override; using default model", + slog.Warn("Failed to apply skill model override; using current model", "agent", ca, "skill", params.Name, "model", skill.Model, "error", err, ) } else { - defer a.SetModelOverride(prev...) + ours := a.SnapshotModelOverride() + defer a.RestoreModelOverride(prev, ours) } } From d9a3f2b5e6ae3e5dbcd3265d512551db1ac2bc9d Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 12:56:27 +0200 Subject: [PATCH 4/6] refactor(runtime): introduce WithAgentModel, drop dead ModelOverrides MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two follow-up simplifications now that the skill model-override feature is in place. - Encapsulate the snapshot+SetAgentModel+snapshot+restore-closure idiom in a new LocalRuntime.WithAgentModel(ctx, name, ref) (restore, err) helper. The skill runner shrinks from a dozen lines (with two manual SnapshotModelOverride calls and a hand-written defer) to a 6-line if/else: get a restore closure or log and skip. The CAS semantics live in one place. - Drop the now-unused Agent.ModelOverrides() getter. It was added for the original naive save/restore in handleRunSkill, but the CAS refactor moved off slice-copy save/restore. The only remaining caller was the unit test, which now asserts via Model().ID() and HasModelOverride() — same coverage with less ceremony. Snapshot / restore is still tested by TestSnapshotAndRestoreModelOverride. No feature change. Validated with build, vet, golangci-lint, and the race-detector test suite for pkg/agent, pkg/runtime, pkg/skills, pkg/tools/builtin, pkg/tools/builtin/agent. 
Assisted-By: docker-agent --- pkg/agent/agent.go | 18 ------------------ pkg/agent/agent_test.go | 32 ++++---------------------------- pkg/runtime/model_switcher.go | 23 +++++++++++++++++++++++ pkg/runtime/skill_runner.go | 19 ++++++------------- 4 files changed, 33 insertions(+), 59 deletions(-) diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 539bc6f24..bfb1061d6 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -180,24 +180,6 @@ func (a *Agent) HasModelOverride() bool { return overrides != nil && len(*overrides) > 0 } -// ModelOverrides returns the currently active model override providers, -// or nil when no override is set. The returned slice is a copy so it can -// be safely retained by the caller for read-only inspection. -// -// Do NOT use this for save/restore around a temporary override: the -// returned slice is a snapshot of the contents at call time, so naive -// save+restore can clobber a concurrent change made by another caller -// (e.g. the TUI model picker switching the model while a skill -// sub-session is running). Use SnapshotModelOverride / RestoreModelOverride -// for safe scoped overrides instead. -func (a *Agent) ModelOverrides() []provider.Provider { - overrides := a.modelOverrides.Load() - if overrides == nil || len(*overrides) == 0 { - return nil - } - return append([]provider.Provider(nil), (*overrides)...) -} - // ModelOverrideSnapshot is an opaque token that captures the agent's model // override at a point in time. Pass it to RestoreModelOverride to undo a // scoped override safely. 
diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 796cff797..c37b7042f 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -145,47 +145,23 @@ func TestModelOverride(t *testing.T) { a := New("root", "test", WithModel(defaultModel)) // Initially should return the default model - model := a.Model() - assert.Equal(t, "openai/gpt-4o", model.ID()) + assert.Equal(t, "openai/gpt-4o", a.Model().ID()) assert.False(t, a.HasModelOverride()) - assert.Nil(t, a.ModelOverrides()) // Set an override a.SetModelOverride(overrideModel) assert.True(t, a.HasModelOverride()) - require.Len(t, a.ModelOverrides(), 1) - assert.Equal(t, "anthropic/claude-sonnet-4-0", a.ModelOverrides()[0].ID()) - - // Now Model() should return the override - model = a.Model() - assert.Equal(t, "anthropic/claude-sonnet-4-0", model.ID()) + assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model().ID()) - // ConfiguredModels should still return the original models + // ConfiguredModels still reflects the originally configured models configuredModels := a.ConfiguredModels() require.Len(t, configuredModels, 1) assert.Equal(t, "openai/gpt-4o", configuredModels[0].ID()) - // Mutating the slice returned by ModelOverrides must not affect the agent - snapshot := a.ModelOverrides() - snapshot[0] = defaultModel - require.Len(t, a.ModelOverrides(), 1) - assert.Equal(t, "anthropic/claude-sonnet-4-0", a.ModelOverrides()[0].ID()) - - // Save / restore round-trip using ModelOverrides - prev := a.ModelOverrides() - a.SetModelOverride(defaultModel) - assert.Equal(t, "openai/gpt-4o", a.Model().ID()) - a.SetModelOverride(prev...) 
- assert.Equal(t, "anthropic/claude-sonnet-4-0", a.Model().ID()) - // Clear the override a.SetModelOverride(nil) assert.False(t, a.HasModelOverride()) - assert.Nil(t, a.ModelOverrides()) - - // Model() should return the default again - model = a.Model() - assert.Equal(t, "openai/gpt-4o", model.ID()) + assert.Equal(t, "openai/gpt-4o", a.Model().ID()) } func TestSnapshotAndRestoreModelOverride(t *testing.T) { diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index 88fad7bc1..bc4b0208e 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -154,6 +154,29 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st return nil } +// WithAgentModel applies modelRef as a model override on the named agent +// and returns a function that restores the previous override safely. +// +// The returned restore func uses pointer-identity compare-and-swap on the +// agent's override, so a concurrent change made between the apply and the +// restore (e.g. by the TUI model picker) is preserved instead of being +// clobbered. +// +// If modelRef cannot be resolved, the agent is left untouched, restore is +// nil, and an error is returned. +func (r *LocalRuntime) WithAgentModel(ctx context.Context, agentName, modelRef string) (restore func(), err error) { + a, err := r.team.Agent(agentName) + if err != nil { + return nil, fmt.Errorf("agent not found: %w", err) + } + prev := a.SnapshotModelOverride() + if err := r.SetAgentModel(ctx, agentName, modelRef); err != nil { + return nil, err + } + ours := a.SnapshotModelOverride() + return func() { a.RestoreModelOverride(prev, ours) }, nil +} + // resolveModelRef resolves a model reference to a single provider. // The reference can be a named model from the config or an inline // "provider/model" spec (e.g. "openai/gpt-4o-mini"). 
diff --git a/pkg/runtime/skill_runner.go b/pkg/runtime/skill_runner.go index 20f13f5e5..358de724d 100644 --- a/pkg/runtime/skill_runner.go +++ b/pkg/runtime/skill_runner.go @@ -68,18 +68,12 @@ func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session ) // If the skill declares a model override, apply it for the duration of - // the sub-session and restore the previous override when done. - // SetAgentModel handles every accepted form (named model, alloy, inline - // provider/model, inline alloy); on failure we just log a warning and - // fall back to the agent's currently-active model. - // - // We snapshot before and after the apply so the deferred restore can - // CAS back to the previous value only if no concurrent caller (e.g. the - // TUI model picker) changed the override in the meantime — otherwise - // the user's choice would be silently reverted on skill completion. + // the sub-session. WithAgentModel handles every accepted form (named + // model, alloy, inline provider/model, inline alloy) and returns a + // CAS-safe restore func; on failure we log a warning and fall back to + // the agent's currently-active model. 
if skill.Model != "" { - prev := a.SnapshotModelOverride() - if err := r.SetAgentModel(ctx, ca, skill.Model); err != nil { + if restore, err := r.WithAgentModel(ctx, ca, skill.Model); err != nil { slog.Warn("Failed to apply skill model override; using current model", "agent", ca, "skill", params.Name, @@ -87,8 +81,7 @@ func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session "error", err, ) } else { - ours := a.SnapshotModelOverride() - defer a.RestoreModelOverride(prev, ours) + defer restore() } } From 61bc8ba13e5415d7ac114e2e33820387b22f980b Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:04:59 +0200 Subject: [PATCH 5/6] fix(runtime): always-non-nil restore from WithAgentModel + add tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address two reviewer findings on the WithAgentModel helper. - Make the returned restore closure always non-nil. On error it is a no-op closure; on success it performs the CAS-restore. Callers can now defer restore() unconditionally without nil-checking, removing a footgun for future callers. The skill runner is simplified accordingly: a single defer restore() right after the call instead of an if/else around it. - Add dedicated unit tests for WithAgentModel covering: * agent not found — error + safe no-op restore * nil modelSwitcherCfg — error + safe no-op restore + agent untouched * invalid model ref — error + safe no-op restore + agent untouched * apply clears existing override; restore puts it back * restore is idempotent — second call is a CAS no-op * concurrent change is preserved by restore (TUI scenario at the runtime layer) The "apply" tests deliberately use the empty-string model ref form, which short-circuits inside SetAgentModel to a.SetModelOverride() and therefore needs no provider resolution. This keeps the tests hermetic while still exercising the snapshot+apply+CAS-restore composition. 
Validated with go build, go vet, golangci-lint (0 offenses), and the full pkg/agent + pkg/runtime + pkg/skills + pkg/tools/builtin test suites under -race. Assisted-By: docker-agent --- pkg/runtime/model_switcher.go | 18 ++-- pkg/runtime/skill_runner.go | 10 +- pkg/runtime/with_agent_model_test.go | 145 +++++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 14 deletions(-) create mode 100644 pkg/runtime/with_agent_model_test.go diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index bc4b0208e..2cb525019 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -157,21 +157,21 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st // WithAgentModel applies modelRef as a model override on the named agent // and returns a function that restores the previous override safely. // -// The returned restore func uses pointer-identity compare-and-swap on the -// agent's override, so a concurrent change made between the apply and the -// restore (e.g. by the TUI model picker) is preserved instead of being -// clobbered. -// -// If modelRef cannot be resolved, the agent is left untouched, restore is -// nil, and an error is returned. +// The returned restore func is always non-nil. On success it uses +// pointer-identity compare-and-swap on the agent's override, so a +// concurrent change made between the apply and the restore (e.g. by the +// TUI model picker) is preserved instead of being clobbered. On error +// the agent is left untouched and restore is a no-op, so callers can +// always defer it without nil-checking. 
func (r *LocalRuntime) WithAgentModel(ctx context.Context, agentName, modelRef string) (restore func(), err error) { + noop := func() {} a, err := r.team.Agent(agentName) if err != nil { - return nil, fmt.Errorf("agent not found: %w", err) + return noop, fmt.Errorf("agent not found: %w", err) } prev := a.SnapshotModelOverride() if err := r.SetAgentModel(ctx, agentName, modelRef); err != nil { - return nil, err + return noop, err } ours := a.SnapshotModelOverride() return func() { a.RestoreModelOverride(prev, ours) }, nil diff --git a/pkg/runtime/skill_runner.go b/pkg/runtime/skill_runner.go index 358de724d..99de35b3b 100644 --- a/pkg/runtime/skill_runner.go +++ b/pkg/runtime/skill_runner.go @@ -70,18 +70,18 @@ func (r *LocalRuntime) handleRunSkill(ctx context.Context, sess *session.Session // If the skill declares a model override, apply it for the duration of // the sub-session. WithAgentModel handles every accepted form (named // model, alloy, inline provider/model, inline alloy) and returns a - // CAS-safe restore func; on failure we log a warning and fall back to - // the agent's currently-active model. + // CAS-safe restore func that is always non-nil; on failure we log a + // warning and fall back to the agent's currently-active model. 
if skill.Model != "" { - if restore, err := r.WithAgentModel(ctx, ca, skill.Model); err != nil { + restore, err := r.WithAgentModel(ctx, ca, skill.Model) + defer restore() + if err != nil { slog.Warn("Failed to apply skill model override; using current model", "agent", ca, "skill", params.Name, "model", skill.Model, "error", err, ) - } else { - defer restore() } } diff --git a/pkg/runtime/with_agent_model_test.go b/pkg/runtime/with_agent_model_test.go new file mode 100644 index 000000000..7b4ca44d1 --- /dev/null +++ b/pkg/runtime/with_agent_model_test.go @@ -0,0 +1,145 @@ +package runtime + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/agent" + "github.com/docker/docker-agent/pkg/team" +) + +// TestWithAgentModel covers the LocalRuntime.WithAgentModel helper. The +// helper is the public entry point for "apply a temporary model override +// for a scope, then CAS-restore it"; it composes resolution (SetAgentModel) +// with the agent-level snapshot/restore primitives. 
+func TestWithAgentModel(t *testing.T) { + t.Parallel() + + t.Run("agent not found returns no-op restore and error", func(t *testing.T) { + t.Parallel() + tm := team.New(team.WithAgents(agent.New("root", "test"))) + r := &LocalRuntime{ + team: tm, + modelSwitcherCfg: &ModelSwitcherConfig{}, + } + + restore, err := r.WithAgentModel(t.Context(), "missing", "openai/gpt-4o") + require.Error(t, err) + assert.Contains(t, err.Error(), "agent not found") + require.NotNil(t, restore, "restore must always be non-nil") + assert.NotPanics(t, restore, "restore on error must be a safe no-op") + }) + + t.Run("nil modelSwitcherCfg returns no-op restore and error", func(t *testing.T) { + t.Parallel() + root := agent.New("root", "test") + tm := team.New(team.WithAgents(root)) + r := &LocalRuntime{team: tm} // modelSwitcherCfg is nil + + restore, err := r.WithAgentModel(t.Context(), "root", "openai/gpt-4o") + require.Error(t, err) + require.NotNil(t, restore) + assert.NotPanics(t, restore) + assert.False(t, root.HasModelOverride(), "agent state must not be touched on error") + }) + + t.Run("invalid model ref returns no-op restore and error", func(t *testing.T) { + t.Parallel() + root := agent.New("root", "test") + tm := team.New(team.WithAgents(root)) + r := &LocalRuntime{ + team: tm, + modelSwitcherCfg: &ModelSwitcherConfig{}, + } + + // "invalid" has no slash → not an inline spec, and no named config + // matches → SetAgentModel returns an error. + restore, err := r.WithAgentModel(t.Context(), "root", "invalid") + require.Error(t, err) + require.NotNil(t, restore) + assert.NotPanics(t, restore) + assert.False(t, root.HasModelOverride(), "agent state must not be touched on error") + }) + + t.Run("apply clears existing override; restore puts it back", func(t *testing.T) { + t.Parallel() + // Pre-existing override (e.g. set by the user via the model picker + // before the skill ran). 
+ userPick := &mockProvider{id: "user/pick"} + root := agent.New("root", "test", agent.WithModel(&mockProvider{id: "default/model"})) + root.SetModelOverride(userPick) + require.Equal(t, "user/pick", root.Model().ID()) + + tm := team.New(team.WithAgents(root)) + r := &LocalRuntime{ + team: tm, + modelSwitcherCfg: &ModelSwitcherConfig{}, + } + + // Empty modelRef clears the override (handled inside SetAgentModel + // without requiring any provider resolution). + restore, err := r.WithAgentModel(t.Context(), "root", "") + require.NoError(t, err) + require.NotNil(t, restore) + + // Inside the scope: override is cleared. + assert.False(t, root.HasModelOverride()) + assert.Equal(t, "default/model", root.Model().ID()) + + // After restore: user's pick is back. + restore() + assert.True(t, root.HasModelOverride()) + assert.Equal(t, "user/pick", root.Model().ID()) + }) + + t.Run("restore is idempotent", func(t *testing.T) { + t.Parallel() + root := agent.New("root", "test", agent.WithModel(&mockProvider{id: "default/model"})) + userPick := &mockProvider{id: "user/pick"} + root.SetModelOverride(userPick) + + tm := team.New(team.WithAgents(root)) + r := &LocalRuntime{ + team: tm, + modelSwitcherCfg: &ModelSwitcherConfig{}, + } + + restore, err := r.WithAgentModel(t.Context(), "root", "") + require.NoError(t, err) + + restore() + assert.Equal(t, "user/pick", root.Model().ID()) + // Second call is a CAS no-op (the state is already restored). + assert.NotPanics(t, restore) + assert.Equal(t, "user/pick", root.Model().ID()) + }) + + t.Run("concurrent change is preserved by restore", func(t *testing.T) { + t.Parallel() + // This is the TUI-while-skill-runs scenario at the runtime layer: + // after the skill applies its override, another caller (e.g. the + // model picker) sets a different override before the deferred + // restore runs. The restore must NOT clobber that change. 
+ root := agent.New("root", "test", agent.WithModel(&mockProvider{id: "default/model"})) + tm := team.New(team.WithAgents(root)) + r := &LocalRuntime{ + team: tm, + modelSwitcherCfg: &ModelSwitcherConfig{}, + } + + // Apply: clears any override (none was set). + restore, err := r.WithAgentModel(t.Context(), "root", "") + require.NoError(t, err) + + // Concurrent caller wins between apply and restore. + userPick := &mockProvider{id: "user/pick"} + root.SetModelOverride(userPick) + + // Restore must be a no-op because the override changed. + restore() + require.True(t, root.HasModelOverride(), "concurrent change must be preserved") + assert.Equal(t, "user/pick", root.Model().ID()) + }) +} From cd7f9aed21a6c04b322b997b4be1aa6f820f9d8a Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:18:30 +0200 Subject: [PATCH 6/6] fix(runtime): close race window in WithAgentModel snapshot capture MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer flagged a real (if narrow) race in WithAgentModel: prev := a.SnapshotModelOverride() if err := r.SetAgentModel(ctx, name, ref); err != nil { ... } ours := a.SnapshotModelOverride() // <-- race window here Between SetAgentModel returning and the post-apply snapshot being captured, a concurrent caller (TUI model picker, etc.) could store its own override. `ours` would then refer to that caller's pointer instead of the one we just stored, and the deferred CAS-restore would incorrectly succeed and clobber the user's choice. Fix the race by capturing the snapshot atomically with the store itself: - SetModelOverride now returns ModelOverrideSnapshot — the pointer it just stored. Existing callers that don't care about the snapshot ignore the return value (Go allows this), so no source changes are needed beyond the agent itself. - Refactor SetAgentModel into a thin wrapper over a new setAgentModelInternal that returns (snapshot, error). 
Each branch now threads the snapshot from a.SetModelOverride(...) through the return. - WithAgentModel no longer calls SnapshotModelOverride after the apply: it uses the snapshot returned by setAgentModelInternal, which holds exactly the pointer the agent stored, regardless of any concurrent SetModelOverride call. The CAS-restore now correctly fails (preserving the concurrent change) instead of incorrectly succeeding. Tests: - TestSetModelOverride_ReturnsSnapshotOfStoredValue proves the new property: a concurrent SetModelOverride between our store and a later restore is preserved, because oursSnap holds the pointer we stored, not the load-after-store value. - TestSetModelOverride_ClearReturnsZeroSnapshot verifies the clear-and-restore round-trip via the returned snapshot. Validated with go build, go vet, golangci-lint (0 issues), and the race-detector test suite for pkg/agent, pkg/runtime, pkg/skills, pkg/tools/builtin, pkg/tools/builtin/agent, and the full mise test suite. Assisted-By: docker-agent --- pkg/agent/agent.go | 13 ++++++-- pkg/agent/agent_test.go | 50 +++++++++++++++++++++++++++++ pkg/runtime/model_switcher.go | 60 +++++++++++++++++++++++------------ 3 files changed, 100 insertions(+), 23 deletions(-) diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index bfb1061d6..8c7bf354f 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -152,7 +152,13 @@ func (a *Agent) Model() provider.Provider { // The override(s) take precedence over the configured models. // For alloy models, multiple providers can be passed and one will be randomly selected. // Pass no arguments or nil providers to clear the override. -func (a *Agent) SetModelOverride(models ...provider.Provider) { +// +// SetModelOverride returns a snapshot of the value that was just stored. 
+// Callers performing a scoped override (apply now, restore later) should +// keep this snapshot and pass it as `current` to RestoreModelOverride so +// the deferred restore can detect concurrent changes via CAS. Callers +// that only need the side-effect can ignore the return value. +func (a *Agent) SetModelOverride(models ...provider.Provider) ModelOverrideSnapshot { // Filter out nil providers var validModels []provider.Provider for _, m := range models { @@ -161,17 +167,20 @@ func (a *Agent) SetModelOverride(models ...provider.Provider) { } } + var ptr *[]provider.Provider if len(validModels) == 0 { a.modelOverrides.Store(nil) slog.Debug("Cleared model override", "agent", a.name) } else { - a.modelOverrides.Store(&validModels) + ptr = &validModels + a.modelOverrides.Store(ptr) ids := make([]string, len(validModels)) for i, m := range validModels { ids[i] = m.ID() } slog.Debug("Set model override", "agent", a.name, "models", ids) } + return ModelOverrideSnapshot{ptr: ptr} } // HasModelOverride returns true if a model override is currently set. diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index c37b7042f..ce5b59f02 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -164,6 +164,56 @@ func TestModelOverride(t *testing.T) { assert.Equal(t, "openai/gpt-4o", a.Model().ID()) } +func TestSetModelOverride_ReturnsSnapshotOfStoredValue(t *testing.T) { + // SetModelOverride must return a snapshot of the value it just stored, + // not what a subsequent SnapshotModelOverride() would load. This is the + // guarantee that closes the race window for scoped overrides: if a + // concurrent caller stores a different override after our store but + // before we capture our snapshot, our snapshot must still refer to + // what we stored, so the deferred CAS-restore will fail (concurrent + // change wins) instead of incorrectly succeeding. 
+ t.Parallel() + + defaultModel := &mockProvider{id: "default"} + oursModel := &mockProvider{id: "ours"} + othersModel := &mockProvider{id: "others"} + + a := New("root", "test", WithModel(defaultModel)) + + // Capture the snapshot returned by SetModelOverride. + prev := a.SnapshotModelOverride() + oursSnap := a.SetModelOverride(oursModel) + + // Simulate a concurrent caller storing a different override _after_ we + // stored ours but _before_ a hypothetical post-store SnapshotModelOverride. + a.SetModelOverride(othersModel) + require.Equal(t, "others", a.Model().ID()) + + // The deferred restore must be a no-op because oursSnap holds the + // pointer we stored, not the current pointer. + a.RestoreModelOverride(prev, oursSnap) + assert.Equal(t, "others", a.Model().ID(), + "concurrent override must be preserved; the snapshot returned by SetModelOverride captures the stored pointer") +} + +func TestSetModelOverride_ClearReturnsZeroSnapshot(t *testing.T) { + t.Parallel() + + a := New("root", "test", WithModel(&mockProvider{id: "default"})) + + // Calling SetModelOverride with no providers (or nil) clears the override. + // The returned snapshot should round-trip cleanly through RestoreModelOverride. + cleared := a.SetModelOverride() + assert.False(t, a.HasModelOverride()) + + // Now set an override and restore using `cleared` as `prev`. 
+ oursSnap := a.SetModelOverride(&mockProvider{id: "ours"}) + require.True(t, a.HasModelOverride()) + + a.RestoreModelOverride(cleared, oursSnap) + assert.False(t, a.HasModelOverride(), "restoring a cleared snapshot must clear the override") +} + func TestSnapshotAndRestoreModelOverride(t *testing.T) { t.Parallel() diff --git a/pkg/runtime/model_switcher.go b/pkg/runtime/model_switcher.go index 2cb525019..43578bbd7 100644 --- a/pkg/runtime/model_switcher.go +++ b/pkg/runtime/model_switcher.go @@ -8,6 +8,7 @@ import ( "slices" "strings" + "github.com/docker/docker-agent/pkg/agent" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider" @@ -93,20 +94,34 @@ type ModelSwitcherConfig struct { // SetAgentModel implements ModelSwitcher for LocalRuntime. func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef string) error { + _, err := r.setAgentModelInternal(ctx, agentName, modelRef) + return err +} + +// setAgentModelInternal applies modelRef as the agent's model override and +// returns a snapshot of the value that was just stored. The snapshot is +// captured atomically with the store (it is the pointer returned by +// SetModelOverride itself), so there is no window where another caller +// could intervene and the snapshot would refer to a different value. +// +// SetAgentModel is a thin wrapper that discards the snapshot; callers that +// want to do a CAS-based restore (see WithAgentModel) use this method +// directly to keep the snapshot. 
+func (r *LocalRuntime) setAgentModelInternal(ctx context.Context, agentName, modelRef string) (agent.ModelOverrideSnapshot, error) { if r.modelSwitcherCfg == nil { - return errors.New("model switching not configured for this runtime") + return agent.ModelOverrideSnapshot{}, errors.New("model switching not configured for this runtime") } a, err := r.team.Agent(agentName) if err != nil { - return fmt.Errorf("agent not found: %w", err) + return agent.ModelOverrideSnapshot{}, fmt.Errorf("agent not found: %w", err) } // Empty modelRef means clear the override (use agent's default) if modelRef == "" { - a.SetModelOverride() + snap := a.SetModelOverride() slog.Info("Cleared agent model override (using default)", "agent", agentName) - return nil + return snap, nil } // Check if modelRef is a named model from config @@ -116,20 +131,20 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st if isAlloyModelConfig(modelConfig) { providers, err := r.resolveModelRefs(ctx, modelConfig.Model) if err != nil { - return fmt.Errorf("failed to create alloy model from config: %w", err) + return agent.ModelOverrideSnapshot{}, fmt.Errorf("failed to create alloy model from config: %w", err) } - a.SetModelOverride(providers...) + snap := a.SetModelOverride(providers...) 
slog.Info("Set agent model override (alloy)", "agent", agentName, "config_name", modelRef, "model_count", len(providers)) - return nil + return snap, nil } prov, err := r.createProviderFromConfig(ctx, &modelConfig) if err != nil { - return fmt.Errorf("failed to create model from config: %w", err) + return agent.ModelOverrideSnapshot{}, fmt.Errorf("failed to create model from config: %w", err) } - a.SetModelOverride(prov) + snap := a.SetModelOverride(prov) slog.Info("Set agent model override", "agent", agentName, "model", prov.ID(), "config_name", modelRef) - return nil + return snap, nil } // Check if this is an inline alloy spec (comma-separated provider/model specs) @@ -137,21 +152,21 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st if isInlineAlloySpec(modelRef) { providers, err := r.resolveModelRefs(ctx, modelRef) if err != nil { - return fmt.Errorf("failed to create inline alloy model: %w", err) + return agent.ModelOverrideSnapshot{}, fmt.Errorf("failed to create inline alloy model: %w", err) } - a.SetModelOverride(providers...) + snap := a.SetModelOverride(providers...) slog.Info("Set agent model override (inline alloy)", "agent", agentName, "model_count", len(providers)) - return nil + return snap, nil } // Try single inline spec (provider/model) prov, err := r.resolveModelRef(ctx, modelRef) if err != nil { - return fmt.Errorf("failed to resolve model %q: %w", modelRef, err) + return agent.ModelOverrideSnapshot{}, fmt.Errorf("failed to resolve model %q: %w", modelRef, err) } - a.SetModelOverride(prov) + snap := a.SetModelOverride(prov) slog.Info("Set agent model override (inline)", "agent", agentName, "model", prov.ID()) - return nil + return snap, nil } // WithAgentModel applies modelRef as a model override on the named agent @@ -160,9 +175,12 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st // The returned restore func is always non-nil. 
On success it uses // pointer-identity compare-and-swap on the agent's override, so a // concurrent change made between the apply and the restore (e.g. by the -// TUI model picker) is preserved instead of being clobbered. On error -// the agent is left untouched and restore is a no-op, so callers can -// always defer it without nil-checking. +// TUI model picker) is preserved instead of being clobbered. The +// post-apply snapshot is captured atomically with the store inside +// SetModelOverride, so there is no window where a concurrent change +// could be misattributed to this scope. On error the agent is left +// untouched and restore is a no-op, so callers can always defer it +// without nil-checking. func (r *LocalRuntime) WithAgentModel(ctx context.Context, agentName, modelRef string) (restore func(), err error) { noop := func() {} a, err := r.team.Agent(agentName) @@ -170,10 +188,10 @@ func (r *LocalRuntime) WithAgentModel(ctx context.Context, agentName, modelRef s return noop, fmt.Errorf("agent not found: %w", err) } prev := a.SnapshotModelOverride() - if err := r.SetAgentModel(ctx, agentName, modelRef); err != nil { + ours, err := r.setAgentModelInternal(ctx, agentName, modelRef) + if err != nil { return noop, err } - ours := a.SnapshotModelOverride() return func() { a.RestoreModelOverride(prev, ours) }, nil }