Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion internal/guard/app/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ func TestProcessHookEventEnsuresDaemonObservedSession(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if session.Source != "daemon_observed" || session.Status != "open" || session.Agent != "claude" {
if session.Source != "daemon_observed" ||
session.Status != "open" ||
session.AgentProvider != "anthropic" ||
session.Agent != "claude_code" {
t.Fatalf("session = %+v, want daemon-observed local session", session)
}
events, err := store.Events(context.Background(), "local")
Expand Down
14 changes: 8 additions & 6 deletions internal/guard/store/sqlite/githubdryrun.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ func githubDryRunActionValues(actionID, sessionID string, event risk.HookEvent,
"request_summary": riskEvent.RequestSummary,
"path_class": riskEvent.PathClass,
})
// The managed endpoint's trusted identity is the service account +
// installation; hook payloads are session telemetry, not human identity.
identityJSON, identityHash := mustHashJSON(map[string]any{
"agent": event.Agent,
"principal_kind": "service_account",
})
agentProvider, canonicalAgent := hostedAgentIdentity(event.Agent)
identityPayload := map[string]any{
"agent": canonicalAgent,
}
if agentProvider != "" {
identityPayload["agent_provider"] = agentProvider
}
identityJSON, identityHash := mustHashJSON(identityPayload)

githubContext := map[string]any{}
if owner, repo, ok := splitRepoSlug(evaluation.Request.Resource); ok {
Expand Down
16 changes: 14 additions & 2 deletions internal/guard/store/sqlite/githubdryrun_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func TestSaveDecisionRecordsGithubDryRunRows(t *testing.T) {
}

var dryRun LedgerRecord
var runtimeDecision LedgerRecord
decidedRows := 0
for _, action := range batch.Actions {
if action["canonical_event_type"] != "request.decided" {
Expand All @@ -63,6 +64,8 @@ func TestSaveDecisionRecordsGithubDryRunRows(t *testing.T) {
decidedRows++
if action["decision_category"] == "dry_run" {
dryRun = action
} else {
runtimeDecision = action
}
}
if decidedRows != 2 {
Expand All @@ -71,6 +74,9 @@ func TestSaveDecisionRecordsGithubDryRunRows(t *testing.T) {
if dryRun == nil {
t.Fatal("no dry_run request.decided row found")
}
if runtimeDecision == nil {
t.Fatal("no runtime request.decided row found")
}

expectations := map[string]any{
"decision_result": "deny",
Expand Down Expand Up @@ -115,8 +121,14 @@ func TestSaveDecisionRecordsGithubDryRunRows(t *testing.T) {
}

identityJSON, _ := dryRun["identity_context_json"].(map[string]any)
if identityJSON["principal_kind"] != "service_account" {
t.Fatalf("identity_context_json = %v, want service_account principal", identityJSON)
if identityJSON["agent_provider"] != "anthropic" || identityJSON["agent"] != "claude_code" {
t.Fatalf("identity_context_json = %v, want canonical Claude identity", identityJSON)
}
if identityJSON["principal_kind"] != nil {
t.Fatalf("identity_context_json = %v, want agent identity only", identityJSON)
}
if dryRun["identity_hash"] != runtimeDecision["identity_hash"] {
t.Fatalf("dry-run identity_hash = %v, runtime identity_hash = %v", dryRun["identity_hash"], runtimeDecision["identity_hash"])
}

if decisionAt, _ := dryRun["decision_at"].(string); decisionAt == "" {
Expand Down
2 changes: 1 addition & 1 deletion internal/guard/store/sqlite/ledger.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (s *Store) AgentSessions(ctx context.Context, ids []string) ([]LedgerRecord
args = append(args, id)
}
return queryLedgerRecords(ctx, s.db, fmt.Sprintf(`
select id, runtime_kind, runtime_instance_id, adapter_kind, adapter_version, agent,
select id, runtime_kind, runtime_instance_id, adapter_kind, adapter_version, agent_provider, agent,
conversation_id, trace_id, principal_id, identity_context_json, identity_hash,
policy_version, policy_hash, cwd, source, status, external_id, closed_at,
mode, created_at, updated_at
Expand Down
72 changes: 49 additions & 23 deletions internal/guard/store/sqlite/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,17 @@ type SessionSummary struct {
}

type SessionRecord struct {
ID string `json:"id"`
Agent string `json:"agent,omitempty"`
CWD string `json:"cwd,omitempty"`
Source string `json:"source,omitempty"`
Status string `json:"status,omitempty"`
ExternalID string `json:"external_id,omitempty"`
Mode string `json:"mode,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ClosedAt *time.Time `json:"closed_at,omitempty"`
ID string `json:"id"`
AgentProvider string `json:"agent_provider,omitempty"`
Agent string `json:"agent,omitempty"`
CWD string `json:"cwd,omitempty"`
Source string `json:"source,omitempty"`
Status string `json:"status,omitempty"`
ExternalID string `json:"external_id,omitempty"`
Mode string `json:"mode,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ClosedAt *time.Time `json:"closed_at,omitempty"`
}

func OpenStore(path string) (*Store, error) {
Expand Down Expand Up @@ -116,6 +117,7 @@ func (s *Store) migrate(ctx context.Context) error {
runtime_instance_id text,
adapter_kind text,
adapter_version text,
agent_provider text,
agent text,
conversation_id text,
trace_id text,
Expand Down Expand Up @@ -278,6 +280,7 @@ func (s *Store) migrate(ctx context.Context) error {
{name: "runtime_instance_id", def: "text"},
{name: "adapter_kind", def: "text"},
{name: "adapter_version", def: "text"},
{name: "agent_provider", def: "text"},
{name: "conversation_id", def: "text"},
{name: "trace_id", def: "text"},
{name: "principal_id", def: "text"},
Expand Down Expand Up @@ -612,13 +615,15 @@ func (s *Store) OpenSession(ctx context.Context, sessionID, agent, cwd, source,
func (s *Store) OpenSessionWithMode(ctx context.Context, sessionID, agent, cwd, source, externalID, mode string) (SessionRecord, error) {
now := time.Now().UTC()
sessionID = normalizeSessionID(sessionID)
agentProvider, canonicalAgent := hostedAgentIdentity(agent)
if source == "" {
source = "daemon_observed"
}
_, err := s.db.ExecContext(ctx, `
insert into agent_sessions(id, agent, cwd, source, status, external_id, mode, closed_at, created_at, updated_at)
values(?, ?, ?, ?, 'open', ?, ?, null, ?, ?)
insert into agent_sessions(id, agent_provider, agent, cwd, source, status, external_id, mode, closed_at, created_at, updated_at)
values(?, ?, ?, ?, ?, 'open', ?, ?, null, ?, ?)
on conflict(id) do update set
agent_provider = coalesce(nullif(excluded.agent_provider, ''), agent_sessions.agent_provider),
agent = coalesce(nullif(excluded.agent, ''), agent_sessions.agent),
cwd = coalesce(nullif(excluded.cwd, ''), agent_sessions.cwd),
source = case
Expand All @@ -631,7 +636,7 @@ on conflict(id) do update set
mode = coalesce(nullif(excluded.mode, ''), agent_sessions.mode),
closed_at = null,
updated_at = excluded.updated_at
`, sessionID, agent, cwd, source, externalID, mode, now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano))
`, sessionID, agentProvider, canonicalAgent, cwd, source, externalID, mode, now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano))
if err != nil {
return SessionRecord{}, err
}
Expand All @@ -645,10 +650,12 @@ func (s *Store) EnsureObservedSession(ctx context.Context, sessionID, agent, cwd
func (s *Store) EnsureObservedSessionWithMode(ctx context.Context, sessionID, agent, cwd, mode string) (SessionRecord, error) {
now := time.Now().UTC()
sessionID = normalizeSessionID(sessionID)
agentProvider, canonicalAgent := hostedAgentIdentity(agent)
_, err := s.db.ExecContext(ctx, `
insert into agent_sessions(id, agent, cwd, source, status, mode, created_at, updated_at)
values(?, ?, ?, 'daemon_observed', 'open', ?, ?, ?)
insert into agent_sessions(id, agent_provider, agent, cwd, source, status, mode, created_at, updated_at)
values(?, ?, ?, ?, 'daemon_observed', 'open', ?, ?, ?)
on conflict(id) do update set
agent_provider = coalesce(nullif(excluded.agent_provider, ''), agent_sessions.agent_provider),
agent = coalesce(nullif(excluded.agent, ''), agent_sessions.agent),
cwd = coalesce(nullif(excluded.cwd, ''), agent_sessions.cwd),
mode = coalesce(nullif(excluded.mode, ''), agent_sessions.mode),
Expand All @@ -661,7 +668,7 @@ on conflict(id) do update set
else null
end,
updated_at = excluded.updated_at
`, sessionID, agent, cwd, mode, now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano))
`, sessionID, agentProvider, canonicalAgent, cwd, mode, now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano))
if err != nil {
return SessionRecord{}, err
}
Expand Down Expand Up @@ -693,7 +700,7 @@ where source = 'daemon_observed'

func (s *Store) Session(ctx context.Context, sessionID string) (SessionRecord, error) {
row := s.db.QueryRowContext(ctx, `
select id, coalesce(agent, ''), coalesce(cwd, ''), source, status, coalesce(external_id, ''),
select id, coalesce(agent_provider, ''), coalesce(agent, ''), coalesce(cwd, ''), source, status, coalesce(external_id, ''),
coalesce(mode, ''), created_at, updated_at, closed_at
from agent_sessions
where id = ?
Expand All @@ -712,14 +719,16 @@ func (s *Store) SaveDecision(ctx context.Context, event risk.HookEvent, decision
defer func() {
_ = tx.Rollback()
}()
agentProvider, canonicalAgent := hostedAgentIdentity(event.Agent)
_, err = tx.ExecContext(ctx, `
insert into agent_sessions(id, agent, cwd, source, status, created_at, updated_at)
values(?, ?, ?, 'daemon_observed', 'open', ?, ?)
insert into agent_sessions(id, agent_provider, agent, cwd, source, status, created_at, updated_at)
values(?, ?, ?, ?, 'daemon_observed', 'open', ?, ?)
on conflict(id) do update set
agent_provider = coalesce(nullif(excluded.agent_provider, ''), agent_sessions.agent_provider),
agent = coalesce(nullif(excluded.agent, ''), agent_sessions.agent),
cwd = coalesce(nullif(excluded.cwd, ''), agent_sessions.cwd),
updated_at = excluded.updated_at
`, sessionID, event.Agent, event.CWD, now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano))
`, sessionID, agentProvider, canonicalAgent, event.CWD, now.Format(time.RFC3339Nano), now.Format(time.RFC3339Nano))
if err != nil {
return DecisionRecord{}, err
}
Expand Down Expand Up @@ -828,9 +837,14 @@ func actionValues(actionID, sessionID string, event risk.HookEvent, decision ris
"request_summary": riskEvent.RequestSummary,
"path_class": riskEvent.PathClass,
})
identityJSON, identityHash := mustHashJSON(map[string]any{
"agent": event.Agent,
})
agentProvider, canonicalAgent := hostedAgentIdentity(event.Agent)
identityPayload := map[string]any{
"agent": canonicalAgent,
}
if agentProvider != "" {
identityPayload["agent_provider"] = agentProvider
}
identityJSON, identityHash := mustHashJSON(identityPayload)
Comment thread
michiosw marked this conversation as resolved.
contextPayload := map[string]any{
"cwd": event.CWD,
"hook_event_name": event.HookEventName,
Expand Down Expand Up @@ -1573,6 +1587,7 @@ func scanSession(scanner interface{ Scan(...any) error }) (SessionRecord, error)
var closed sql.NullString
if err := scanner.Scan(
&record.ID,
&record.AgentProvider,
&record.Agent,
&record.CWD,
&record.Source,
Expand Down Expand Up @@ -1605,6 +1620,17 @@ func scanSession(scanner interface{ Scan(...any) error }) (SessionRecord, error)
return record, nil
}

func hostedAgentIdentity(agent string) (string, string) {
switch strings.ToLower(strings.TrimSpace(agent)) {
case "claude", "claude-code", "claude_code":
return "anthropic", "claude_code"
case "cowork", "claude-cowork", "claude_cowork":
return "anthropic", "claude_cowork"
default:
return "", strings.TrimSpace(agent)
}
}

func parseSessionSummaryTimes(item *SessionSummary, latest, created, updated string, closed sql.NullString) error {
latestAt, err := parseStoredTime("session latest_at", latest)
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion internal/guard/store/sqlite/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,9 @@ func TestLedgerBatchExportsSessionsActionsAndReceipts(t *testing.T) {
if batch.Sessions[0]["mode"] != "observe" {
t.Fatalf("session export mode = %q, want observe", batch.Sessions[0]["mode"])
}
if batch.Sessions[0]["agent_provider"] != "anthropic" || batch.Sessions[0]["agent"] != "claude_code" {
t.Fatalf("session export identity = provider %q agent %q, want anthropic claude_code", batch.Sessions[0]["agent_provider"], batch.Sessions[0]["agent"])
}
decided := ledgerRecordByID(batch.Actions, record.ID)
if decided == nil ||
decided["canonical_event_type"] != "request.decided" ||
Expand Down Expand Up @@ -1378,7 +1381,8 @@ func TestOpenAndCloseSessionRecordsLifecycle(t *testing.T) {
t.Fatal(err)
}
if opened.ID != "session-123" ||
opened.Agent != "claude" ||
opened.AgentProvider != "anthropic" ||
opened.Agent != "claude_code" ||
opened.CWD != "/tmp/project" ||
opened.Source != "wrapper_owned" ||
opened.Status != "open" ||
Expand Down
Loading