diff --git a/internal/resolver/dns/dns_resolver.go b/internal/resolver/dns/dns_resolver.go index abab35e250ef..f3f52a59a863 100644 --- a/internal/resolver/dns/dns_resolver.go +++ b/internal/resolver/dns/dns_resolver.go @@ -41,18 +41,24 @@ import ( "google.golang.org/grpc/serviceconfig" ) -// EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB -// addresses from SRV records. Must not be changed after init time. -var EnableSRVLookups = false - -// ResolvingTimeout specifies the maximum duration for a DNS resolution request. -// If the timeout expires before a response is received, the request will be canceled. -// -// It is recommended to set this value at application startup. Avoid modifying this variable -// after initialization as it's not thread-safe for concurrent modification. -var ResolvingTimeout = 30 * time.Second - -var logger = grpclog.Component("dns") +var ( + // EnableSRVLookups controls whether the DNS resolver attempts to fetch gRPCLB + // addresses from SRV records. Must not be changed after init time. + EnableSRVLookups = false + + // MinResolutionInterval is the minimum interval at which re-resolutions are + // allowed. This helps to prevent excessive re-resolution. + MinResolutionInterval = 30 * time.Second + + // ResolvingTimeout specifies the maximum duration for a DNS resolution request. + // If the timeout expires before a response is received, the request will be canceled. + // + // It is recommended to set this value at application startup. Avoid modifying this variable + // after initialization as it's not thread-safe for concurrent modification. + ResolvingTimeout = 30 * time.Second + + logger = grpclog.Component("dns") +) func init() { resolver.Register(NewBuilder()) @@ -208,7 +214,7 @@ func (d *dnsResolver) watcher() { // Success resolving, wait for the next ResolveNow. However, also wait 30 // seconds at the very least to prevent constantly re-resolving. backoffIndex = 1 - waitTime = internal.MinResolutionRate + waitTime = MinResolutionInterval select { case <-d.ctx.Done(): return diff --git a/internal/resolver/dns/dns_resolver_test.go b/internal/resolver/dns/dns_resolver_test.go index 498cf5b83e27..95fd4b5eeeb5 100644 --- a/internal/resolver/dns/dns_resolver_test.go +++ b/internal/resolver/dns/dns_resolver_test.go @@ -68,11 +68,11 @@ func overrideNetResolver(t *testing.T, r *testNetResolver) { t.Cleanup(func() { dnsinternal.NewNetResolver = origNetResolver }) } -// Override the DNS Min Res Rate used by the resolver. -func overrideResolutionRate(t *testing.T, d time.Duration) { - origMinResRate := dnsinternal.MinResolutionRate - dnsinternal.MinResolutionRate = d - t.Cleanup(func() { dnsinternal.MinResolutionRate = origMinResRate }) +// Override the DNS minimum resolution interval used by the resolver. +func overrideResolutionInterval(t *testing.T, d time.Duration) { + origMinResInterval := dns.MinResolutionInterval + dnspublic.SetMinResolutionInterval(d) + t.Cleanup(func() { dnspublic.SetMinResolutionInterval(origMinResInterval) }) } // Override the timer used by the DNS resolver to fire after a duration of d. @@ -636,7 +636,7 @@ func (s) TestDNSResolver_ExponentialBackoff(t *testing.T) { func (s) TestDNSResolver_ResolveNow(t *testing.T) { const target = "foo.bar.com" - overrideResolutionRate(t, 0) + overrideResolutionInterval(t, 0) overrideTimeAfterFunc(t, 0) tr := &testNetResolver{ hostLookupTable: map[string][]string{ @@ -739,7 +739,7 @@ func (s) TestIPResolver(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - overrideResolutionRate(t, 0) + overrideResolutionInterval(t, 0) overrideTimeAfterFunc(t, 2*defaultTestTimeout) r, stateCh, _ := buildResolverWithTestClientConn(t, test.target) @@ -1258,3 +1258,35 @@ func (s) TestResolveTimeout(t *testing.T) { } } } + +// Test verifies that changing [MinResolutionInterval] variable correctly effects +// the resolution behaviour +func (s) TestMinResolutionInterval(t *testing.T) { + const target = "foo.bar.com" + + overrideResolutionInterval(t, 1*time.Millisecond) + tr := &testNetResolver{ + hostLookupTable: map[string][]string{ + "foo.bar.com": {"1.2.3.4", "5.6.7.8"}, + }, + txtLookupTable: map[string][]string{ + "_grpc_config.foo.bar.com": txtRecordServiceConfig(txtRecordGood), + }, + } + overrideNetResolver(t, tr) + + r, stateCh, _ := buildResolverWithTestClientConn(t, target) + + wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}} + wantSC := scJSON + + for i := 0; i < 5; i++ { + // set context timeout slightly higher than the min resolution interval to make sure resolutions + // happen successfully + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + verifyUpdateFromResolver(ctx, t, stateCh, wantAddrs, nil, wantSC) + r.ResolveNow(resolver.ResolveNowOptions{}) + } +} diff --git a/internal/resolver/dns/internal/internal.go b/internal/resolver/dns/internal/internal.go index c7fc557d00c1..9cb3b5c58a6a 100644 --- a/internal/resolver/dns/internal/internal.go +++ b/internal/resolver/dns/internal/internal.go @@ -50,10 +50,6 @@ var ( // The following vars are overridden from tests. var ( - // MinResolutionRate is the minimum rate at which re-resolutions are - // allowed. This helps to prevent excessive re-resolution. - MinResolutionRate = 30 * time.Second - // TimeAfterFunc is used by the DNS resolver to wait for the given duration // to elapse. In non-test code, this is implemented by time.After. In test // code, this can be used to control the amount of time the resolver is diff --git a/resolver/dns/dns_resolver.go b/resolver/dns/dns_resolver.go index b54a3a3225d4..f1320b9f44c1 100644 --- a/resolver/dns/dns_resolver.go +++ b/resolver/dns/dns_resolver.go @@ -52,3 +52,9 @@ func SetResolvingTimeout(timeout time.Duration) { func NewBuilder() resolver.Builder { return dns.NewBuilder() } + +// SetMinResolutionInterval sets the default minimum interval at which DNS re-resolutions are +// allowed. This helps to prevent excessive re-resolution. +func SetMinResolutionInterval(d time.Duration) { + dns.MinResolutionInterval = d +}