Skip to content

Commit 780cbe1

Browse files
committed
feat: disconnect upstream refreshing
Signed-off-by: maksim.nabokikh <max.nabokih@gmail.com>
1 parent f90a36c commit 780cbe1

File tree

2 files changed

+184
-2
lines changed

2 files changed

+184
-2
lines changed

server/refreshhandlers.go

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"time"
1111

1212
"github.com/dexidp/dex/connector"
13+
"github.com/dexidp/dex/pkg/featureflags"
1314
"github.com/dexidp/dex/server/internal"
1415
"github.com/dexidp/dex/storage"
1516
)
@@ -107,6 +108,10 @@ func newInternalServerError() *refreshError {
107108
return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
108109
}
109110

111+
func newUpstreamRefreshError(desc string) *refreshError {
112+
return &refreshError{msg: errInvalidGrant, desc: desc, code: http.StatusBadGateway}
113+
}
114+
110115
func newBadRequestError(desc string) *refreshError {
111116
return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest}
112117
}
@@ -271,7 +276,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext,
271276
newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident)
272277
if err != nil {
273278
s.logger.ErrorContext(ctx, "failed to refresh identity", "err", err)
274-
return ident, newInternalServerError()
279+
return ident, newUpstreamRefreshError(err.Error())
275280
}
276281

277282
return newIdent, nil
@@ -327,6 +332,20 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (
327332
Groups: rCtx.storageToken.Claims.Groups,
328333
}
329334

335+
// Pre-fetch UserIdentity outside the storage transaction to avoid deadlocks with
336+
// storage backends that use a single lock (e.g., memory storage).
337+
// This is used as a fallback when the upstream connector refresh fails.
338+
var cachedIdentity *storage.UserIdentity
339+
if featureflags.SessionsEnabled.Enabled() {
340+
ui, err := s.storage.GetUserIdentity(ctx, rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID)
341+
if err != nil {
342+
s.logger.WarnContext(ctx, "failed to pre-fetch user identity for upstream refresh fallback",
343+
"user_id", rCtx.storageToken.Claims.UserID, "connector_id", rCtx.storageToken.ConnectorID, "err", err)
344+
} else {
345+
cachedIdentity = &ui
346+
}
347+
}
348+
330349
refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
331350
rotationEnabled := s.refreshTokenPolicy.RotationEnabled()
332351
reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(old.LastUsed)
@@ -373,7 +392,26 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (
373392
// Dex will call the connector's Refresh method only once if request is not in reuse interval.
374393
ident, rerr = s.refreshWithConnector(ctx, rCtx, ident)
375394
if rerr != nil {
376-
return old, rerr
395+
// When sessions are enabled and the upstream provider fails (e.g., expired upstream
396+
// refresh token), fall back to claims stored in UserIdentity instead of failing the
397+
// entire refresh. This matches the behavior of other identity brokers (Keycloak, Auth0)
398+
// that do not contact the upstream on every downstream refresh.
399+
if cachedIdentity != nil {
400+
s.logger.WarnContext(ctx, "upstream refresh failed, using cached identity from last login",
401+
"err", rerr, "user_id", cachedIdentity.Claims.UserID, "connector_id", rCtx.storageToken.ConnectorID)
402+
ident = connector.Identity{
403+
UserID: cachedIdentity.Claims.UserID,
404+
Username: cachedIdentity.Claims.Username,
405+
PreferredUsername: cachedIdentity.Claims.PreferredUsername,
406+
Email: cachedIdentity.Claims.Email,
407+
EmailVerified: cachedIdentity.Claims.EmailVerified,
408+
Groups: cachedIdentity.Claims.Groups,
409+
}
410+
rerr = nil
411+
}
412+
if rerr != nil {
413+
return old, rerr
414+
}
377415
}
378416

379417
// Update the claims of the refresh token.

server/refreshhandlers_test.go

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package server
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/base64"
67
"encoding/json"
8+
"errors"
79
"log/slog"
810
"net/http"
911
"net/http/httptest"
@@ -16,6 +18,7 @@ import (
1618
"github.com/stretchr/testify/assert"
1719
"github.com/stretchr/testify/require"
1820

21+
"github.com/dexidp/dex/connector"
1922
"github.com/dexidp/dex/server/internal"
2023
"github.com/dexidp/dex/storage"
2124
)
@@ -347,6 +350,147 @@ func TestRefreshTokenAuthTime(t *testing.T) {
347350
}
348351
}
349352

353+
// failingRefreshConnector implements connector.CallbackConnector and connector.RefreshConnector
354+
// but always returns an error on Refresh, simulating an upstream provider failure.
355+
type failingRefreshConnector struct {
356+
identity connector.Identity
357+
}
358+
359+
func (f *failingRefreshConnector) LoginURL(_ connector.Scopes, callbackURL, state string) (string, []byte, error) {
360+
u, _ := url.Parse(callbackURL)
361+
v := u.Query()
362+
v.Set("state", state)
363+
u.RawQuery = v.Encode()
364+
return u.String(), nil, nil
365+
}
366+
367+
func (f *failingRefreshConnector) HandleCallback(_ connector.Scopes, _ []byte, _ *http.Request) (connector.Identity, error) {
368+
return f.identity, nil
369+
}
370+
371+
func (f *failingRefreshConnector) Refresh(_ context.Context, _ connector.Scopes, _ connector.Identity) (connector.Identity, error) {
372+
return connector.Identity{}, errors.New("upstream: refresh token expired")
373+
}
374+
375+
func TestUpstreamRefreshFailureFallsBackToUserIdentity(t *testing.T) {
376+
t0 := time.Now().UTC().Round(time.Second)
377+
loginTime := t0.Add(-10 * time.Minute)
378+
379+
tests := []struct {
380+
name string
381+
sessionsEnabled bool
382+
createUserIdentity bool
383+
wantOK bool
384+
}{
385+
{
386+
name: "sessions enabled with user identity - fallback succeeds",
387+
sessionsEnabled: true,
388+
createUserIdentity: true,
389+
wantOK: true,
390+
},
391+
{
392+
name: "sessions enabled without user identity - fallback fails",
393+
sessionsEnabled: true,
394+
createUserIdentity: false,
395+
wantOK: false,
396+
},
397+
{
398+
name: "sessions disabled - no fallback, error returned",
399+
sessionsEnabled: false,
400+
createUserIdentity: false,
401+
wantOK: false,
402+
},
403+
}
404+
405+
for _, tc := range tests {
406+
t.Run(tc.name, func(t *testing.T) {
407+
setSessionsEnabled(t, tc.sessionsEnabled)
408+
409+
httpServer, s := newTestServer(t, func(c *Config) {
410+
c.Now = func() time.Time { return t0 }
411+
})
412+
defer httpServer.Close()
413+
414+
if tc.sessionsEnabled {
415+
s.sessionConfig = &SessionConfig{
416+
CookieName: "dex_session",
417+
AbsoluteLifetime: 24 * time.Hour,
418+
}
419+
}
420+
421+
mockRefreshTokenTestStorage(t, s.storage, false)
422+
423+
// Replace the connector with one that always fails on Refresh.
424+
// ResourceVersion must match the storage connector (empty by default in
425+
// mockRefreshTokenTestStorage) to prevent getConnector from re-opening it.
426+
s.mu.Lock()
427+
s.connectors["test"] = Connector{
428+
Connector: &failingRefreshConnector{
429+
identity: connector.Identity{
430+
UserID: "0-385-28089-0",
431+
Username: "Kilgore Trout",
432+
Email: "kilgore@kilgore.trout",
433+
},
434+
},
435+
}
436+
s.mu.Unlock()
437+
438+
if tc.createUserIdentity {
439+
err := s.storage.CreateUserIdentity(t.Context(), storage.UserIdentity{
440+
UserID: "1",
441+
ConnectorID: "test",
442+
Claims: storage.Claims{
443+
UserID: "1",
444+
Username: "jane",
445+
Email: "jane.doe@example.com",
446+
EmailVerified: true,
447+
Groups: []string{"a", "b"},
448+
},
449+
CreatedAt: loginTime,
450+
LastLogin: loginTime,
451+
})
452+
require.NoError(t, err)
453+
}
454+
455+
u, err := url.Parse(s.issuerURL.String())
456+
require.NoError(t, err)
457+
458+
tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"})
459+
require.NoError(t, err)
460+
461+
u.Path = path.Join(u.Path, "/token")
462+
v := url.Values{}
463+
v.Add("grant_type", "refresh_token")
464+
v.Add("refresh_token", tokenData)
465+
466+
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode()))
467+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
468+
req.SetBasicAuth("test", "barfoo")
469+
470+
rr := httptest.NewRecorder()
471+
s.ServeHTTP(rr, req)
472+
473+
if tc.wantOK {
474+
require.Equal(t, http.StatusOK, rr.Code, "body: %s", rr.Body.String())
475+
476+
var resp struct {
477+
IDToken string `json:"id_token"`
478+
}
479+
err = json.Unmarshal(rr.Body.Bytes(), &resp)
480+
require.NoError(t, err)
481+
482+
// Verify the returned claims match UserIdentity, not the connector.
483+
claims := decodeJWTClaims(t, resp.IDToken)
484+
assert.Equal(t, "jane.doe@example.com", claims["email"])
485+
assert.Equal(t, "jane", claims["name"])
486+
} else {
487+
require.NotEqual(t, http.StatusOK, rr.Code,
488+
"expected error when upstream fails without fallback")
489+
}
490+
})
491+
}
492+
}
493+
350494
func TestRefreshTokenPolicy(t *testing.T) {
351495
lastTime := time.Now()
352496
l := slog.New(slog.DiscardHandler)

0 commit comments

Comments
 (0)