Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions internal/oauth/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ type Client interface {
Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error)
Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error)
Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error)
Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error)
}

type httpClient struct {
Expand Down Expand Up @@ -145,7 +146,7 @@ func (c *httpClient) Discover(ctx context.Context, endpoint string) (*OIDCConfig
func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) {
endpoint = strings.TrimRight(endpoint, "/")

// Discover OIDC configuration
// Discover OIDC configuration - caches on first call
config, err := c.Discover(ctx, endpoint)
if err != nil {
return nil, errors.Wrap(err, "OIDC discovery failed")
Expand Down Expand Up @@ -208,7 +209,7 @@ func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string
func (c *httpClient) Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) {
endpoint = strings.TrimRight(endpoint, "/")

// Discover OIDC configuration (should be cached from Start)
// Discover OIDC configuration - caches on first call
config, err := c.Discover(ctx, endpoint)
if err != nil {
return nil, errors.Wrap(err, "OIDC discovery failed")
Expand Down Expand Up @@ -307,3 +308,55 @@ func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode str

return &tokenResp, nil
}

// Refresh exchanges a refresh token for a new access token.
func (c *httpClient) Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error) {
endpoint = strings.TrimRight(endpoint, "/")

config, err := c.Discover(ctx, endpoint)
if err != nil {
return nil, errors.Wrap(err, "OIDC discovery failed")
}

if config.TokenEndpoint == "" {
return nil, errors.New("token endpoint not found in OIDC configuration")
}

data := url.Values{}
data.Set("client_id", c.clientID)
data.Set("grant_type", "refresh_token")
data.Set("refresh_token", refreshToken)

req, err := http.NewRequestWithContext(ctx, "POST", config.TokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, errors.Wrap(err, "creating refresh token request")
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")

resp, err := c.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "refresh token request failed")
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "reading refresh token response")
}

if resp.StatusCode != http.StatusOK {
var errResp ErrorResponse
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" {
return nil, errors.Newf("refresh token failed: %s: %s", errResp.Error, errResp.ErrorDescription)
}
return nil, errors.Newf("refresh token failed with status %d: %s", resp.StatusCode, string(body))
}

var tokenResp TokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, errors.Wrap(err, "parsing refresh token response")
}

return &tokenResp, nil
}
41 changes: 41 additions & 0 deletions internal/oauth/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,3 +507,44 @@ func TestPoll_ContextCancellation(t *testing.T) {
t.Errorf("error = %v, want context.Canceled or wrapped context canceled error", err)
}
}

func TestRefresh_Success(t *testing.T) {
server := newTestServer(t, testServerOptions{
handlers: map[string]http.HandlerFunc{
testTokenPath: func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if got := r.FormValue("grant_type"); got != "refresh_token" {
t.Errorf("grant_type = %q, want %q", got, "refresh_token")
}
if got := r.FormValue("refresh_token"); got != "test-refresh-token" {
t.Errorf("refresh_token = %q, want %q", got, "test-refresh-token")
}

w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(TokenResponse{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
TokenType: "Bearer",
ExpiresIn: 3600,
})
},
},
})
defer server.Close()

client := NewClient(DefaultClientID)
resp, err := client.Refresh(context.Background(), server.URL, "test-refresh-token")
if err != nil {
t.Fatalf("Refresh() error = %v", err)
}

if resp.AccessToken != "new-access-token" {
t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "new-access-token")
}
if resp.RefreshToken != "new-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", resp.RefreshToken, "new-refresh-token")
}
}
Loading