@@ -2,8 +2,10 @@ package server
22
33import (
44 "bytes"
5+ "context"
56 "encoding/base64"
67 "encoding/json"
8+ "errors"
79 "log/slog"
810 "net/http"
911 "net/http/httptest"
@@ -16,6 +18,7 @@ import (
1618 "github.com/stretchr/testify/assert"
1719 "github.com/stretchr/testify/require"
1820
21+ "github.com/dexidp/dex/connector"
1922 "github.com/dexidp/dex/server/internal"
2023 "github.com/dexidp/dex/storage"
2124)
@@ -347,6 +350,147 @@ func TestRefreshTokenAuthTime(t *testing.T) {
347350 }
348351}
349352
353+ // failingRefreshConnector implements connector.CallbackConnector and connector.RefreshConnector
354+ // but always returns an error on Refresh, simulating an upstream provider failure.
355+ type failingRefreshConnector struct {
356+ identity connector.Identity
357+ }
358+
359+ func (f * failingRefreshConnector ) LoginURL (_ connector.Scopes , callbackURL , state string ) (string , []byte , error ) {
360+ u , _ := url .Parse (callbackURL )
361+ v := u .Query ()
362+ v .Set ("state" , state )
363+ u .RawQuery = v .Encode ()
364+ return u .String (), nil , nil
365+ }
366+
367+ func (f * failingRefreshConnector ) HandleCallback (_ connector.Scopes , _ []byte , _ * http.Request ) (connector.Identity , error ) {
368+ return f .identity , nil
369+ }
370+
371+ func (f * failingRefreshConnector ) Refresh (_ context.Context , _ connector.Scopes , _ connector.Identity ) (connector.Identity , error ) {
372+ return connector.Identity {}, errors .New ("upstream: refresh token expired" )
373+ }
374+
375+ func TestUpstreamRefreshFailureFallsBackToUserIdentity (t * testing.T ) {
376+ t0 := time .Now ().UTC ().Round (time .Second )
377+ loginTime := t0 .Add (- 10 * time .Minute )
378+
379+ tests := []struct {
380+ name string
381+ sessionsEnabled bool
382+ createUserIdentity bool
383+ wantOK bool
384+ }{
385+ {
386+ name : "sessions enabled with user identity - fallback succeeds" ,
387+ sessionsEnabled : true ,
388+ createUserIdentity : true ,
389+ wantOK : true ,
390+ },
391+ {
392+ name : "sessions enabled without user identity - fallback fails" ,
393+ sessionsEnabled : true ,
394+ createUserIdentity : false ,
395+ wantOK : false ,
396+ },
397+ {
398+ name : "sessions disabled - no fallback, error returned" ,
399+ sessionsEnabled : false ,
400+ createUserIdentity : false ,
401+ wantOK : false ,
402+ },
403+ }
404+
405+ for _ , tc := range tests {
406+ t .Run (tc .name , func (t * testing.T ) {
407+ setSessionsEnabled (t , tc .sessionsEnabled )
408+
409+ httpServer , s := newTestServer (t , func (c * Config ) {
410+ c .Now = func () time.Time { return t0 }
411+ })
412+ defer httpServer .Close ()
413+
414+ if tc .sessionsEnabled {
415+ s .sessionConfig = & SessionConfig {
416+ CookieName : "dex_session" ,
417+ AbsoluteLifetime : 24 * time .Hour ,
418+ }
419+ }
420+
421+ mockRefreshTokenTestStorage (t , s .storage , false )
422+
423+ // Replace the connector with one that always fails on Refresh.
424+ // ResourceVersion must match the storage connector (empty by default in
425+ // mockRefreshTokenTestStorage) to prevent getConnector from re-opening it.
426+ s .mu .Lock ()
427+ s .connectors ["test" ] = Connector {
428+ Connector : & failingRefreshConnector {
429+ identity : connector.Identity {
430+ UserID : "0-385-28089-0" ,
431+ Username : "Kilgore Trout" ,
432+ Email : "kilgore@kilgore.trout" ,
433+ },
434+ },
435+ }
436+ s .mu .Unlock ()
437+
438+ if tc .createUserIdentity {
439+ err := s .storage .CreateUserIdentity (t .Context (), storage.UserIdentity {
440+ UserID : "1" ,
441+ ConnectorID : "test" ,
442+ Claims : storage.Claims {
443+ UserID : "1" ,
444+ Username : "jane" ,
445+ Email : "jane.doe@example.com" ,
446+ EmailVerified : true ,
447+ Groups : []string {"a" , "b" },
448+ },
449+ CreatedAt : loginTime ,
450+ LastLogin : loginTime ,
451+ })
452+ require .NoError (t , err )
453+ }
454+
455+ u , err := url .Parse (s .issuerURL .String ())
456+ require .NoError (t , err )
457+
458+ tokenData , err := internal .Marshal (& internal.RefreshToken {RefreshId : "test" , Token : "bar" })
459+ require .NoError (t , err )
460+
461+ u .Path = path .Join (u .Path , "/token" )
462+ v := url.Values {}
463+ v .Add ("grant_type" , "refresh_token" )
464+ v .Add ("refresh_token" , tokenData )
465+
466+ req , _ := http .NewRequest ("POST" , u .String (), bytes .NewBufferString (v .Encode ()))
467+ req .Header .Set ("Content-Type" , "application/x-www-form-urlencoded; param=value" )
468+ req .SetBasicAuth ("test" , "barfoo" )
469+
470+ rr := httptest .NewRecorder ()
471+ s .ServeHTTP (rr , req )
472+
473+ if tc .wantOK {
474+ require .Equal (t , http .StatusOK , rr .Code , "body: %s" , rr .Body .String ())
475+
476+ var resp struct {
477+ IDToken string `json:"id_token"`
478+ }
479+ err = json .Unmarshal (rr .Body .Bytes (), & resp )
480+ require .NoError (t , err )
481+
482+ // Verify the returned claims match UserIdentity, not the connector.
483+ claims := decodeJWTClaims (t , resp .IDToken )
484+ assert .Equal (t , "jane.doe@example.com" , claims ["email" ])
485+ assert .Equal (t , "jane" , claims ["name" ])
486+ } else {
487+ require .NotEqual (t , http .StatusOK , rr .Code ,
488+ "expected error when upstream fails without fallback" )
489+ }
490+ })
491+ }
492+ }
493+
350494func TestRefreshTokenPolicy (t * testing.T ) {
351495 lastTime := time .Now ()
352496 l := slog .New (slog .DiscardHandler )
0 commit comments