77 "crypto/tls"
88 "crypto/x509"
99 "fmt"
10- "net"
1110 "net/http"
1211 "net/url"
1312 "reflect"
@@ -24,6 +23,33 @@ import (
2423 "github.com/go-git/go-git/v5/utils/ioutil"
2524)
2625
26+ type contextKey int
27+
28+ const initialRequestKey contextKey = iota
29+
30+ // RedirectPolicy controls how the HTTP transport follows redirects.
31+ //
32+ // The values mirror Git's http.followRedirects config:
33+ // "true" follows redirects for all requests, "false" treats redirects as
34+ // errors, and "initial" follows redirects only for the initial
35+ // /info/refs discovery request. The zero value defaults to "initial".
36+ type RedirectPolicy string
37+
38+ const (
39+ FollowInitialRedirects RedirectPolicy = "initial"
40+ FollowRedirects RedirectPolicy = "true"
41+ NoFollowRedirects RedirectPolicy = "false"
42+ )
43+
44+ func withInitialRequest (ctx context.Context ) context.Context {
45+ return context .WithValue (ctx , initialRequestKey , true )
46+ }
47+
48+ func isInitialRequest (req * http.Request ) bool {
49+ v , _ := req .Context ().Value (initialRequestKey ).(bool )
50+ return v
51+ }
52+
2753// it requires a bytes.Buffer, because we need to know the length
2854func applyHeadersToRequest (req * http.Request , content * bytes.Buffer , host string , requestType string ) {
2955 req .Header .Add ("User-Agent" , capability .DefaultAgent ())
@@ -54,12 +80,15 @@ func advertisedReferences(ctx context.Context, s *session, serviceName string) (
5480
5581 s .ApplyAuthToRequest (req )
5682 applyHeadersToRequest (req , nil , s .endpoint .Host , serviceName )
57- res , err := s .client .Do (req .WithContext (ctx ))
83+ res , err := s .client .Do (req .WithContext (withInitialRequest ( ctx ) ))
5884 if err != nil {
5985 return nil , err
6086 }
6187
62- s .ModifyEndpointIfRedirect (res )
88+ if err := s .ModifyEndpointIfRedirect (res ); err != nil {
89+ _ = res .Body .Close ()
90+ return nil , err
91+ }
6392 defer ioutil .CheckClose (res .Body , & err )
6493
6594 if err = NewErr (res ); err != nil {
@@ -96,6 +125,7 @@ type client struct {
96125 client * http.Client
97126 transports * lru.Cache
98127 mutex sync.RWMutex
128+ follow RedirectPolicy
99129}
100130
101131// ClientOptions holds user configurable options for the client.
@@ -106,6 +136,11 @@ type ClientOptions struct {
106136 // size, will result in the least recently used transport getting deleted
107137 // before the provided transport is added to the cache.
108138 CacheMaxEntries int
139+
140+ // RedirectPolicy controls redirect handling. Supported values are
141+ // "true", "false", and "initial". The zero value defaults to
142+ // "initial", matching Git's http.followRedirects default.
143+ RedirectPolicy RedirectPolicy
109144}
110145
111146var (
@@ -150,12 +185,16 @@ func NewClientWithOptions(c *http.Client, opts *ClientOptions) transport.Transpo
150185 }
151186 cl := & client {
152187 client : c ,
188+ follow : FollowInitialRedirects ,
153189 }
154190
155191 if opts != nil {
156192 if opts .CacheMaxEntries > 0 {
157193 cl .transports = lru .New (opts .CacheMaxEntries )
158194 }
195+ if opts .RedirectPolicy != "" {
196+ cl .follow = opts .RedirectPolicy
197+ }
159198 }
160199 return cl
161200}
@@ -289,14 +328,9 @@ func newSession(c *client, ep *transport.Endpoint, auth transport.AuthMethod) (*
289328 }
290329 }
291330
292- httpClient = & http.Client {
293- Transport : transport ,
294- CheckRedirect : c .client .CheckRedirect ,
295- Jar : c .client .Jar ,
296- Timeout : c .client .Timeout ,
297- }
331+ httpClient = c .cloneHTTPClient (transport )
298332 } else {
299- httpClient = c .client
333+ httpClient = c .cloneHTTPClient ( c . client . Transport )
300334 }
301335
302336 s := & session {
@@ -324,30 +358,122 @@ func (s *session) ApplyAuthToRequest(req *http.Request) {
324358 s .auth .SetAuth (req )
325359}
326360
327- func (s * session ) ModifyEndpointIfRedirect (res * http.Response ) {
361+ func (s * session ) ModifyEndpointIfRedirect (res * http.Response ) error {
328362 if res .Request == nil {
329- return
363+ return nil
364+ }
365+ if s .endpoint == nil {
366+ return fmt .Errorf ("http redirect: nil endpoint" )
330367 }
331368
332369 r := res .Request
333370 if ! strings .HasSuffix (r .URL .Path , infoRefsPath ) {
334- return
371+ return fmt .Errorf ("http redirect: target %q does not end with %s" , r .URL .Path , infoRefsPath )
372+ }
373+ if r .URL .Scheme != "http" && r .URL .Scheme != "https" {
374+ return fmt .Errorf ("http redirect: unsupported scheme %q" , r .URL .Scheme )
375+ }
376+ if r .URL .Scheme != s .endpoint .Protocol &&
377+ ! (s .endpoint .Protocol == "http" && r .URL .Scheme == "https" ) {
378+ return fmt .Errorf ("http redirect: changes scheme from %q to %q" , s .endpoint .Protocol , r .URL .Scheme )
335379 }
336380
337- h , p , err := net .SplitHostPort (r .URL .Host )
381+ host := endpointHost (r .URL .Hostname ())
382+ port , err := endpointPort (r .URL .Port ())
338383 if err != nil {
339- h = r . URL . Host
384+ return err
340385 }
341- if p != "" {
342- port , err := strconv . Atoi ( p )
343- if err == nil {
344- s .endpoint .Port = port
345- }
386+
387+ if host != s . endpoint . Host || effectivePort ( r . URL . Scheme , port ) != effectivePort ( s . endpoint . Protocol , s . endpoint . Port ) {
388+ s . endpoint . User = ""
389+ s .endpoint .Password = ""
390+ s . auth = nil
346391 }
347- s .endpoint .Host = h
392+
393+ s .endpoint .Host = host
394+ s .endpoint .Port = port
348395
349396 s .endpoint .Protocol = r .URL .Scheme
350397 s .endpoint .Path = r .URL .Path [:len (r .URL .Path )- len (infoRefsPath )]
398+ return nil
399+ }
400+
401+ func endpointHost (host string ) string {
402+ if strings .Contains (host , ":" ) {
403+ return "[" + host + "]"
404+ }
405+
406+ return host
407+ }
408+
409+ func endpointPort (port string ) (int , error ) {
410+ if port == "" {
411+ return 0 , nil
412+ }
413+
414+ parsed , err := strconv .Atoi (port )
415+ if err != nil {
416+ return 0 , fmt .Errorf ("http redirect: invalid port %q" , port )
417+ }
418+
419+ return parsed , nil
420+ }
421+
422+ func effectivePort (scheme string , port int ) int {
423+ if port != 0 {
424+ return port
425+ }
426+
427+ switch strings .ToLower (scheme ) {
428+ case "http" :
429+ return 80
430+ case "https" :
431+ return 443
432+ default :
433+ return 0
434+ }
435+ }
436+
437+ func (c * client ) cloneHTTPClient (transport http.RoundTripper ) * http.Client {
438+ return & http.Client {
439+ Transport : transport ,
440+ CheckRedirect : wrapCheckRedirect (c .follow , c .client .CheckRedirect ),
441+ Jar : c .client .Jar ,
442+ Timeout : c .client .Timeout ,
443+ }
444+ }
445+
446+ func wrapCheckRedirect (policy RedirectPolicy , next func (* http.Request , []* http.Request ) error ) func (* http.Request , []* http.Request ) error {
447+ return func (req * http.Request , via []* http.Request ) error {
448+ if err := checkRedirect (req , via , policy ); err != nil {
449+ return err
450+ }
451+ if next != nil {
452+ return next (req , via )
453+ }
454+ return nil
455+ }
456+ }
457+
458+ func checkRedirect (req * http.Request , via []* http.Request , policy RedirectPolicy ) error {
459+ switch policy {
460+ case FollowRedirects :
461+ case NoFollowRedirects :
462+ return fmt .Errorf ("http redirect: redirects disabled to %s" , req .URL )
463+ case "" , FollowInitialRedirects :
464+ if ! isInitialRequest (req ) {
465+ return fmt .Errorf ("http redirect: redirect on non-initial request to %s" , req .URL )
466+ }
467+ default :
468+ return fmt .Errorf ("http redirect: invalid redirect policy %q" , policy )
469+ }
470+ if req .URL .Scheme != "http" && req .URL .Scheme != "https" {
471+ return fmt .Errorf ("http redirect: unsupported scheme %q" , req .URL .Scheme )
472+ }
473+ if len (via ) >= 10 {
474+ return fmt .Errorf ("http redirect: too many redirects" )
475+ }
476+ return nil
351477}
352478
353479func (* session ) Close () error {
0 commit comments