Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pkg/cli/codemod_dependabot_permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package cli

import (
"fmt"
"sort"
"strings"

"github.com/github/gh-aw/pkg/logger"
Expand Down Expand Up @@ -234,10 +233,9 @@ func findPermissionsInsertIndex(lines []string) int {

func sortedMissingPermissionKeys(missing map[workflow.PermissionScope]workflow.PermissionLevel) []string {
keys := make([]string, 0, len(missing))
for scope := range missing {
for _, scope := range sliceutil.SortedKeys(missing) {
keys = append(keys, string(scope))
}
sort.Strings(keys)
return keys
}

Expand Down
38 changes: 3 additions & 35 deletions pkg/cli/compile_update_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,31 +127,7 @@ func shouldRunCompileUpdateCheck(noCheckUpdate bool) bool {
}

lastCheckFile := getCompileUpdateCheckFilePath()
if lastCheckFile == "" {
compileUpdateCheckLog.Print("Could not determine compile update check file path")
return false
}

data, err := os.ReadFile(lastCheckFile)
if err != nil {
if !os.IsNotExist(err) {
compileUpdateCheckLog.Printf("Error reading compile update check file: %v", err)
}
return true
}

lastCheck, err := time.Parse(time.RFC3339, strings.TrimSpace(string(data)))
if err != nil {
compileUpdateCheckLog.Printf("Error parsing compile update check time: %v", err)
return true
}

elapsed := time.Since(lastCheck)
if elapsed < compileUpdateCheckInterval {
compileUpdateCheckLog.Printf("Last compile update check was %v ago, skipping", elapsed)
return false
}
return true
return shouldRunUpdateCheckAtPath(lastCheckFile, compileUpdateCheckInterval, "compile update check", compileUpdateCheckLog)
}

func waitForCompileUpdateNotification(ctx context.Context, results <-chan *compileUpdateNotification, timeout time.Duration) *compileUpdateNotification {
Expand Down Expand Up @@ -292,19 +268,11 @@ func getCompileUpdateCheckFilePath() string {
}

func getCompileUpdateCheckFilePathImpl() string {
return getLastCheckFilePathFor(compileUpdateCheckFileName)
return getUpdateCheckFilePathFor(compileUpdateCheckFileName, compileUpdateCheckLog)
}

func updateCompileUpdateCheckTime() {
lastCheckFile := getCompileUpdateCheckFilePath()
if lastCheckFile == "" {
return
}

timestamp := time.Now().Format(time.RFC3339)
if err := os.WriteFile(lastCheckFile, []byte(timestamp), constants.FilePermSensitive); err != nil {
compileUpdateCheckLog.Printf("Error writing compile update check time: %v", err)
}
writeUpdateCheckTime(getCompileUpdateCheckFilePath(), constants.FilePermSensitive, "compile update check", compileUpdateCheckLog)
}

func isMinorVersionBehind(currentVersion string, latestVersion string) bool {
Expand Down
3 changes: 2 additions & 1 deletion pkg/cli/logs_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/github/gh-aw/pkg/console"
"github.com/github/gh-aw/pkg/constants"
"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/repoutil"
"github.com/github/gh-aw/pkg/workflow"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -448,7 +449,7 @@ Downloaded artifacts include (when using --artifacts all):
// to the same repository that is checked out locally.
func repoIsLocal(repo string) bool {
// Strip optional HOST/ prefix (e.g. "github.com/owner/repo" → "owner/repo")
ownerRepo, _ := normalizeRepoForAPI(repo)
ownerRepo, _ := repoutil.NormalizeRepoForAPI(repo)

// Fast path: GITHUB_REPOSITORY is always the current repo in MCP server containers.
if envRepo := os.Getenv("GITHUB_REPOSITORY"); envRepo != "" {
Expand Down
21 changes: 5 additions & 16 deletions pkg/cli/outcome_eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/github/gh-aw/pkg/github"
"github.com/github/gh-aw/pkg/intent"
"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/repoutil"
"github.com/github/gh-aw/pkg/workflow"
)

Expand Down Expand Up @@ -220,21 +221,9 @@ func ComputeOutcomeSummary(reports []OutcomeReport, mapping *github.ObjectiveMap
return s
}

// normalizeRepoForAPI splits a repo string of the form "[HOST/]owner/repo" into
// the owner/repo portion and an optional host. Most callers pass plain "owner/repo",
// but GHES and Proxima installs may supply "HOST/owner/repo".
func normalizeRepoForAPI(repo string) (ownerRepo string, host string) {
parts := strings.SplitN(repo, "/", 3)
if len(parts) == 3 {
// HOST/owner/repo
return parts[1] + "/" + parts[2], parts[0]
}
return repo, ""
}

// ghAPIGet calls the GitHub REST API via gh cli and returns the parsed JSON.
func ghAPIGet(endpoint string, repo string) (map[string]any, error) {
ownerRepo, host := normalizeRepoForAPI(repo)
ownerRepo, host := repoutil.NormalizeRepoForAPI(repo)
outcomeEvalLog.Printf("gh api GET: repo=%s, endpoint=%s, host=%q", ownerRepo, endpoint, host)
args := []string{"api", fmt.Sprintf("repos/%s/%s", ownerRepo, endpoint)}
var output []byte
Expand All @@ -257,7 +246,7 @@ func ghAPIGet(endpoint string, repo string) (map[string]any, error) {

// ghAPIGetArray calls the GitHub REST API and returns a JSON array.
func ghAPIGetArray(endpoint string, repo string) ([]map[string]any, error) {
ownerRepo, host := normalizeRepoForAPI(repo)
ownerRepo, host := repoutil.NormalizeRepoForAPI(repo)
args := []string{"api", fmt.Sprintf("repos/%s/%s", ownerRepo, endpoint)}
var output []byte
var err error
Expand All @@ -278,7 +267,7 @@ func ghAPIGetArray(endpoint string, repo string) ([]map[string]any, error) {

// ghAPIGraphQL calls the GitHub GraphQL API via gh cli and returns the parsed JSON.
func ghAPIGraphQL(query string, repo string) (map[string]any, error) {
ownerRepo, host := normalizeRepoForAPI(repo)
ownerRepo, host := repoutil.NormalizeRepoForAPI(repo)
args := []string{"api", "graphql", "-f", "query=" + query}
var output []byte
var err error
Expand Down Expand Up @@ -474,7 +463,7 @@ func resolvePullRequestIntent(report OutcomeReport, repo string, resolver intent

func loadPullRequestIntentData(report OutcomeReport, repo string) (intent.PullRequestData, error) {
prNumber := report.ObjectNumber
ownerRepo, _ := normalizeRepoForAPI(repo)
ownerRepo, _ := repoutil.NormalizeRepoForAPI(repo)
owner, name, found := strings.Cut(ownerRepo, "/")
if !found || owner == "" || name == "" {
return intent.PullRequestData{}, fmt.Errorf("invalid repo for root tracing: %s", repo)
Expand Down
3 changes: 2 additions & 1 deletion pkg/cli/outcome_eval_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"

"github.com/github/gh-aw/pkg/github"
"github.com/github/gh-aw/pkg/repoutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -114,7 +115,7 @@ func TestNormalizeRepoForAPI(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ownerRepo, host := normalizeRepoForAPI(tt.repo)
ownerRepo, host := repoutil.NormalizeRepoForAPI(tt.repo)
assert.Equal(t, tt.wantOwnerRepo, ownerRepo, "owner/repo portion")
assert.Equal(t, tt.wantHost, host, "host portion")
})
Expand Down
59 changes: 4 additions & 55 deletions pkg/cli/update_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"

"github.com/github/gh-aw/pkg/constants"
"golang.org/x/mod/semver"

"github.com/cli/go-gh/v2/pkg/api"
"github.com/github/gh-aw/pkg/console"
"github.com/github/gh-aw/pkg/constants"
"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/workflow"
)
Expand Down Expand Up @@ -64,31 +63,7 @@ func shouldCheckForUpdate(noCheckUpdate bool) bool {

// Check if we've already checked recently
lastCheckFile := getLastCheckFilePath()
if lastCheckFile == "" {
updateCheckLog.Print("Could not determine last check file path")
return false
}

// Read last check time
data, err := os.ReadFile(lastCheckFile)
if err != nil {
if !os.IsNotExist(err) {
updateCheckLog.Printf("Error reading last check file: %v", err)
}
// File doesn't exist or error reading - perform check
return true
}

lastCheck, err := time.Parse(time.RFC3339, strings.TrimSpace(string(data)))
if err != nil {
updateCheckLog.Printf("Error parsing last check time: %v", err)
// Invalid timestamp - perform check
return true
}

// Check if enough time has passed
if time.Since(lastCheck) < checkInterval {
updateCheckLog.Printf("Last check was %v ago, skipping", time.Since(lastCheck))
if !shouldRunUpdateCheckAtPath(lastCheckFile, checkInterval, "update check", updateCheckLog) {
return false
}

Expand All @@ -115,38 +90,12 @@ func getLastCheckFilePath() string {

// getLastCheckFilePathImpl is the actual implementation
func getLastCheckFilePathImpl() string {
return getLastCheckFilePathFor(lastCheckFileName)
}

func getLastCheckFilePathFor(fileName string) string {
// Use OS temp directory for cross-platform compatibility
tmpDir := os.TempDir()
if tmpDir == "" {
updateCheckLog.Print("Could not determine temp directory")
return ""
}

// Create a gh-aw subdirectory in temp
ghAwTmpDir := filepath.Join(tmpDir, "gh-aw")
if err := os.MkdirAll(ghAwTmpDir, constants.DirPermPublic); err != nil {
updateCheckLog.Printf("Error creating gh-aw temp directory: %v", err)
return ""
}

return filepath.Join(ghAwTmpDir, fileName)
return getUpdateCheckFilePathFor(lastCheckFileName, updateCheckLog)
}

// updateLastCheckTime updates the timestamp of the last update check
func updateLastCheckTime() {
lastCheckFile := getLastCheckFilePath()
if lastCheckFile == "" {
return
}

timestamp := time.Now().Format(time.RFC3339)
if err := os.WriteFile(lastCheckFile, []byte(timestamp), constants.FilePermPublic); err != nil {
updateCheckLog.Printf("Error writing last check time: %v", err)
}
writeUpdateCheckTime(getLastCheckFilePath(), constants.FilePermPublic, "update check", updateCheckLog)
}

// checkForUpdates checks if a newer version of gh-aw is available
Expand Down
67 changes: 67 additions & 0 deletions pkg/cli/update_check_state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package cli

import (
"os"
"path/filepath"
"strings"
"time"

"github.com/github/gh-aw/pkg/constants"
"github.com/github/gh-aw/pkg/logger"
)

func getUpdateCheckFilePathFor(fileName string, log *logger.Logger) string {
tmpDir := os.TempDir()
if tmpDir == "" {
log.Print("Could not determine temp directory")
return ""
}

ghAwTmpDir := filepath.Join(tmpDir, "gh-aw")
if err := os.MkdirAll(ghAwTmpDir, constants.DirPermPublic); err != nil {
log.Printf("Error creating gh-aw temp directory: %v", err)
return ""
}

return filepath.Join(ghAwTmpDir, fileName)
}

func shouldRunUpdateCheckAtPath(lastCheckFile string, interval time.Duration, label string, log *logger.Logger) bool {
if lastCheckFile == "" {
log.Printf("Could not determine %s file path", label)
return false
}

data, err := os.ReadFile(lastCheckFile)
if err != nil {
if !os.IsNotExist(err) {
log.Printf("Error reading %s file: %v", label, err)
}
return true
}

lastCheck, err := time.Parse(time.RFC3339, strings.TrimSpace(string(data)))
if err != nil {
log.Printf("Error parsing %s time: %v", label, err)
return true
}

elapsed := time.Since(lastCheck)
if elapsed < interval {
log.Printf("Last %s was %v ago, skipping", label, elapsed)
return false
}

return true
}

func writeUpdateCheckTime(path string, perm os.FileMode, label string, log *logger.Logger) {
if path == "" {
return
}

timestamp := time.Now().Format(time.RFC3339)
if err := os.WriteFile(path, []byte(timestamp), perm); err != nil {
log.Printf("Error writing %s time: %v", label, err)
}
}
3 changes: 2 additions & 1 deletion pkg/repoutil/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ The `repoutil` package provides utility functions for working with GitHub reposi

## Overview

This package offers a single focused helper for parsing and validating `owner/repo` repository slug strings, which are used throughout the codebase wherever GitHub repositories are referenced.
This package offers focused helpers for parsing and normalizing repository identifiers, which are used throughout the codebase wherever GitHub repositories are referenced.

## Public API

Expand All @@ -13,6 +13,7 @@ This package offers a single focused helper for parsing and validating `owner/re
| Function | Signature | Description |
|----------|-----------|-------------|
| `SplitRepoSlug` | `func(slug string) (owner, repo string, err error)` | Splits a repository slug of the form `owner/repo` into its two components; returns an error when the slug does not contain exactly one `/` or when either component is empty |
| `NormalizeRepoForAPI` | `func(repo string) (ownerRepo string, host string)` | Splits a repository string of the form `[HOST/]owner/repo` into the `owner/repo` portion and an optional host name for GHES/Proxima API calls |

## Usage Examples

Expand Down
11 changes: 11 additions & 0 deletions pkg/repoutil/repoutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,14 @@ func SplitRepoSlug(slug string) (owner, repo string, err error) {
repoutilLog.Printf("Split result: owner=%s, repo=%s", parts[0], parts[1])
return parts[0], parts[1], nil
}

// NormalizeRepoForAPI splits a repo string of the form "[HOST/]owner/repo" into
// the owner/repo portion and an optional host. Most callers pass plain
// "owner/repo", but GHES and Proxima installs may supply "HOST/owner/repo".
func NormalizeRepoForAPI(repo string) (ownerRepo string, host string) {
parts := strings.SplitN(repo, "/", 3)
if len(parts) == 3 {
return parts[1] + "/" + parts[2], parts[0]
}
return repo, ""
}
21 changes: 21 additions & 0 deletions pkg/repoutil/repoutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,24 @@ func BenchmarkSplitRepoSlug_Invalid(b *testing.B) {
_, _, _ = SplitRepoSlug(slug)
}
}

func TestNormalizeRepoForAPI(t *testing.T) {
tests := []struct {
name string
repo string
wantOwnerRepo string
wantHost string
}{
{"plain owner/repo", "owner/repo", "owner/repo", ""},
{"GHES HOST/owner/repo", "myhost.com/owner/repo", "owner/repo", "myhost.com"},
{"github.com/owner/repo treated as host prefix", "github.com/owner/repo", "owner/repo", "github.com"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ownerRepo, host := NormalizeRepoForAPI(tt.repo)
assert.Equal(t, tt.wantOwnerRepo, ownerRepo, "owner/repo portion")
assert.Equal(t, tt.wantHost, host, "host portion")
})
}
}
Loading
Loading