Skip to content

Commit 8986fdd

Browse files
Add digest functions flag
1 parent 74d7b17 commit 8986fdd

File tree

11 files changed

+184
-26
lines changed

11 files changed

+184
-26
lines changed

cache/grpcproxy/grpcproxy_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ func newFixture(t *testing.T, proxy cache.Proxy, storageMode string) *fixture {
237237
}
238238
grpcServer := grpc.NewServer()
239239
go func() {
240-
err := server.ServeGRPC(listener, grpcServer, false, false, true, diskCache, logger, logger)
240+
err := server.ServeGRPC(listener, grpcServer, false, false, true, diskCache, logger, logger, hashing.DigestFunctions())
241241
if err != nil {
242242
logger.Printf(err.Error())
243243
}

config/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ go_library(
1717
"//cache/azblobproxy:go_default_library",
1818
"//cache/gcsproxy:go_default_library",
1919
"//cache/grpcproxy:go_default_library",
20+
"//cache/hashing:go_default_library",
2021
"//cache/httpproxy:go_default_library",
2122
"//cache/s3proxy:go_default_library",
23+
"//genproto/build/bazel/remote/execution/v2:go_default_library",
2224
"@com_github_azure_azure_sdk_for_go_sdk_azcore//:go_default_library",
2325
"@com_github_azure_azure_sdk_for_go_sdk_azidentity//:go_default_library",
2426
"@com_github_grpc_ecosystem_go_grpc_prometheus//:go_default_library",

config/config.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ import (
1717

1818
"github.com/buchgr/bazel-remote/v2/cache"
1919
"github.com/buchgr/bazel-remote/v2/cache/azblobproxy"
20+
"github.com/buchgr/bazel-remote/v2/cache/hashing"
2021
"github.com/buchgr/bazel-remote/v2/cache/s3proxy"
22+
pb "github.com/buchgr/bazel-remote/v2/genproto/build/bazel/remote/execution/v2"
2123

2224
"github.com/urfave/cli/v2"
2325
yaml "gopkg.in/yaml.v3"
@@ -114,6 +116,7 @@ type Config struct {
114116
LogTimezone string `yaml:"log_timezone"`
115117
MaxBlobSize int64 `yaml:"max_blob_size"`
116118
MaxProxyBlobSize int64 `yaml:"max_proxy_blob_size"`
119+
DigestFunctions []pb.DigestFunction_Value
117120

118121
// Fields that are created by combinations of the flags above.
119122
ProxyBackend cache.Proxy
@@ -125,6 +128,9 @@ type Config struct {
125128
type YamlConfig struct {
126129
Config `yaml:",inline"`
127130

131+
// Complext types that are converted later
132+
DigestFunctionNames []string `yaml:"digest_functions"`
133+
128134
// Deprecated fields, retained for backwards compatibility when
129135
// parsing config files.
130136

@@ -169,7 +175,8 @@ func newFromArgs(dir string, maxSize int, storageMode string, zstdImplementation
169175
accessLogLevel string,
170176
logTimezone string,
171177
maxBlobSize int64,
172-
maxProxyBlobSize int64) (*Config, error) {
178+
maxProxyBlobSize int64,
179+
digestFunctions []pb.DigestFunction_Value) (*Config, error) {
173180

174181
c := Config{
175182
HTTPAddress: httpAddress,
@@ -205,6 +212,7 @@ func newFromArgs(dir string, maxSize int, storageMode string, zstdImplementation
205212
LogTimezone: logTimezone,
206213
MaxBlobSize: maxBlobSize,
207214
MaxProxyBlobSize: maxProxyBlobSize,
215+
DigestFunctions: digestFunctions,
208216
}
209217

210218
err := validateConfig(&c)
@@ -234,6 +242,7 @@ func newFromYamlFile(path string) (*Config, error) {
234242

235243
func newFromYaml(data []byte) (*Config, error) {
236244
yc := YamlConfig{
245+
DigestFunctionNames: []string{"sha256"},
237246
Config: Config{
238247
StorageMode: "zstd",
239248
ZstdImplementation: "go",
@@ -270,6 +279,16 @@ func newFromYaml(data []byte) (*Config, error) {
270279
sort.Float64s(c.MetricsDurationBuckets)
271280
}
272281

282+
dfs := make([]pb.DigestFunction_Value, 0)
283+
for _, dfn := range yc.DigestFunctionNames {
284+
df := hashing.DigestFunction(dfn)
285+
if df == pb.DigestFunction_UNKNOWN {
286+
return nil, fmt.Errorf("unknown digest function %s", dfn)
287+
}
288+
dfs = append(dfs, hashing.DigestFunction(dfn))
289+
}
290+
c.DigestFunctions = dfs
291+
273292
err = validateConfig(&c)
274293
if err != nil {
275294
return nil, err
@@ -462,6 +481,15 @@ func validateConfig(c *Config) error {
462481
return errors.New("'log_timezone' must be set to either \"UTC\", \"local\" or \"none\"")
463482
}
464483

484+
if c.DigestFunctions == nil {
485+
return errors.New("at least on digest function must be supported")
486+
}
487+
for _, df := range c.DigestFunctions {
488+
if !hashing.Supported(df) {
489+
return fmt.Errorf("unsupported hashing function %s", df)
490+
}
491+
}
492+
465493
return nil
466494
}
467495

@@ -590,6 +618,17 @@ func get(ctx *cli.Context) (*Config, error) {
590618
}
591619
}
592620

621+
dfs := make([]pb.DigestFunction_Value, 0)
622+
if ctx.String("digest_functions") != "" {
623+
for _, dfn := range strings.Split(ctx.String("digest_functions"), ",") {
624+
df := hashing.DigestFunction(dfn)
625+
if df == pb.DigestFunction_UNKNOWN {
626+
return nil, fmt.Errorf("unknown digest function %s", dfn)
627+
}
628+
dfs = append(dfs, df)
629+
}
630+
}
631+
593632
return newFromArgs(
594633
ctx.String("dir"),
595634
ctx.Int("max_size"),
@@ -623,5 +662,6 @@ func get(ctx *cli.Context) (*Config, error) {
623662
ctx.String("log_timezone"),
624663
ctx.Int64("max_blob_size"),
625664
ctx.Int64("max_proxy_blob_size"),
665+
dfs,
626666
)
627667
}

config/config_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"testing"
1010
"time"
1111

12+
pb "github.com/buchgr/bazel-remote/v2/genproto/build/bazel/remote/execution/v2"
13+
1214
"github.com/google/go-cmp/cmp"
1315
)
1416

@@ -60,6 +62,7 @@ log_timezone: local
6062
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
6163
AccessLogLevel: "none",
6264
LogTimezone: "local",
65+
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
6366
}
6467

6568
if !reflect.DeepEqual(config, expectedConfig) {
@@ -103,6 +106,7 @@ gcs_proxy:
103106
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
104107
AccessLogLevel: "all",
105108
LogTimezone: "UTC",
109+
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
106110
}
107111

108112
if !cmp.Equal(config, expectedConfig) {
@@ -147,6 +151,7 @@ http_proxy:
147151
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
148152
AccessLogLevel: "all",
149153
LogTimezone: "UTC",
154+
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
150155
}
151156

152157
if !cmp.Equal(config, expectedConfig) {
@@ -224,6 +229,7 @@ s3_proxy:
224229
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
225230
AccessLogLevel: "all",
226231
LogTimezone: "UTC",
232+
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
227233
}
228234

229235
if !cmp.Equal(config, expectedConfig) {
@@ -258,6 +264,7 @@ profile_address: :7070
258264
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
259265
AccessLogLevel: "all",
260266
LogTimezone: "UTC",
267+
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
261268
}
262269

263270
if !cmp.Equal(config, expectedConfig) {
@@ -306,6 +313,7 @@ endpoint_metrics_duration_buckets: [.005, .1, 5]
306313
MetricsDurationBuckets: []float64{0.005, 0.1, 5},
307314
AccessLogLevel: "all",
308315
LogTimezone: "UTC",
316+
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
309317
}
310318

311319
if !cmp.Equal(config, expectedConfig) {
@@ -438,6 +446,7 @@ storage_mode: zstd
438446
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
439447
AccessLogLevel: "all",
440448
LogTimezone: "UTC",
449+
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
441450
}
442451

443452
if !cmp.Equal(config, expectedConfig) {
@@ -472,6 +481,7 @@ storage_mode: zstd
472481
MetricsDurationBuckets: []float64{.5, 1, 2.5, 5, 10, 20, 40, 80, 160, 320},
473482
AccessLogLevel: "all",
474483
LogTimezone: "UTC",
484+
DigestFunctions: []pb.DigestFunction_Value{pb.DigestFunction_SHA256},
475485
}
476486

477487
if !cmp.Equal(config, expectedConfig) {
@@ -495,3 +505,74 @@ func TestSocketPathMissing(t *testing.T) {
495505
t.Fatal("Expected the error message to mention the missing 'http_address' key/flag")
496506
}
497507
}
508+
509+
func TestDigestFunctions(t *testing.T) {
510+
t.Run("Default", func(t *testing.T) {
511+
yaml := `dir: /opt/cache-dir
512+
max_size: 42
513+
`
514+
config, err := newFromYaml([]byte(yaml))
515+
if err != nil {
516+
t.Fatal(err)
517+
}
518+
if len(config.DigestFunctions) != 1 {
519+
t.Fatal("Expected exactly one digest function")
520+
}
521+
if config.DigestFunctions[0] != pb.DigestFunction_SHA256 {
522+
t.Fatal("Expected sha256 digest function")
523+
}
524+
err = validateConfig(config)
525+
if err != nil {
526+
t.Fatal(err)
527+
}
528+
})
529+
530+
t.Run("Success", func(t *testing.T) {
531+
yaml := `dir: /opt/cache-dir
532+
max_size: 42
533+
digest_functions: [sha256]
534+
`
535+
config, err := newFromYaml([]byte(yaml))
536+
if err != nil {
537+
t.Fatal(err)
538+
}
539+
if len(config.DigestFunctions) != 1 {
540+
t.Fatal("Expected exactly one digest function")
541+
}
542+
if config.DigestFunctions[0] != pb.DigestFunction_SHA256 {
543+
t.Fatal("Expected sha256 digest function")
544+
}
545+
err = validateConfig(config)
546+
if err != nil {
547+
t.Fatal(err)
548+
}
549+
})
550+
551+
t.Run("UnknownFunction", func(t *testing.T) {
552+
yaml := `dir: /opt/cache-dir
553+
max_size: 42
554+
digest_functions: [sha256, foo]
555+
`
556+
_, err := newFromYaml([]byte(yaml))
557+
if err == nil {
558+
t.Fatal("Expected error")
559+
}
560+
if !strings.Contains(err.Error(), "unknown") {
561+
t.Fatalf("Unexpected error: %s", err.Error())
562+
}
563+
})
564+
565+
t.Run("UnsupportedFunction", func(t *testing.T) {
566+
yaml := `dir: /opt/cache-dir
567+
max_size: 42
568+
digest_functions: [md5]
569+
`
570+
_, err := newFromYaml([]byte(yaml))
571+
if err == nil {
572+
t.Fatal("Expected error")
573+
}
574+
if !strings.Contains(err.Error(), "unsupported") {
575+
t.Fatalf("Unexpected error: %s", err.Error())
576+
}
577+
})
578+
}

main.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ func startHttpServer(c *config.Config, httpServer **http.Server,
239239
checkClientCertForWrites := c.TLSCaFile != ""
240240
validateAC := !c.DisableHTTPACValidation
241241
h := server.NewHTTPCache(diskCache, c.AccessLogger, c.ErrorLogger, validateAC,
242-
c.EnableACKeyInstanceMangling, checkClientCertForReads, checkClientCertForWrites, gitCommit)
242+
c.EnableACKeyInstanceMangling, checkClientCertForReads, checkClientCertForWrites, gitCommit, c.DigestFunctions)
243243

244244
cacheHandler := h.CacheHandler
245245
var basicAuthenticator auth.BasicAuth
@@ -429,7 +429,8 @@ func startGrpcServer(c *config.Config, grpcServer **grpc.Server,
429429
validateAC,
430430
c.EnableACKeyInstanceMangling,
431431
enableRemoteAssetAPI,
432-
diskCache, c.AccessLogger, c.ErrorLogger)
432+
diskCache, c.AccessLogger, c.ErrorLogger,
433+
c.DigestFunctions)
433434
}
434435

435436
// A http.HandlerFunc wrapper which requires successful basic

server/grpc.go

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package server
22

33
import (
44
"context"
5+
"fmt"
56
"net"
67
"net/http"
78

@@ -30,11 +31,12 @@ import (
3031
const grpcHealthServiceName = "/grpc.health.v1.Health/Check"
3132

3233
type grpcServer struct {
33-
cache disk.Cache
34-
accessLogger cache.Logger
35-
errorLogger cache.Logger
36-
depsCheck bool
37-
mangleACKeys bool
34+
cache disk.Cache
35+
accessLogger cache.Logger
36+
errorLogger cache.Logger
37+
depsCheck bool
38+
mangleACKeys bool
39+
digestFunctions map[pb.DigestFunction_Value]bool
3840
}
3941

4042
var readOnlyMethods = map[string]struct{}{
@@ -55,26 +57,33 @@ func ListenAndServeGRPC(
5557
validateACDeps bool,
5658
mangleACKeys bool,
5759
enableRemoteAssetAPI bool,
58-
c disk.Cache, a cache.Logger, e cache.Logger) error {
60+
c disk.Cache, a cache.Logger, e cache.Logger,
61+
digestFunctions []pb.DigestFunction_Value) error {
5962

6063
listener, err := net.Listen(network, addr)
6164
if err != nil {
6265
return err
6366
}
6467

65-
return ServeGRPC(listener, srv, validateACDeps, mangleACKeys, enableRemoteAssetAPI, c, a, e)
68+
return ServeGRPC(listener, srv, validateACDeps, mangleACKeys, enableRemoteAssetAPI, c, a, e, digestFunctions)
6669
}
6770

6871
func ServeGRPC(l net.Listener, srv *grpc.Server,
6972
validateACDepsCheck bool,
7073
mangleACKeys bool,
7174
enableRemoteAssetAPI bool,
72-
c disk.Cache, a cache.Logger, e cache.Logger) error {
75+
c disk.Cache, a cache.Logger, e cache.Logger,
76+
digestFunctions []pb.DigestFunction_Value) error {
7377

78+
dfs := make(map[pb.DigestFunction_Value]bool)
79+
for _, df := range digestFunctions {
80+
dfs[df] = true
81+
}
7482
s := &grpcServer{
7583
cache: c, accessLogger: a, errorLogger: e,
76-
depsCheck: validateACDepsCheck,
77-
mangleACKeys: mangleACKeys,
84+
depsCheck: validateACDepsCheck,
85+
mangleACKeys: mangleACKeys,
86+
digestFunctions: dfs,
7887
}
7988
pb.RegisterActionCacheServer(srv, s)
8089
pb.RegisterCapabilitiesServer(srv, s)
@@ -129,10 +138,14 @@ func (s *grpcServer) GetCapabilities(ctx context.Context,
129138
func (s *grpcServer) getHasher(df pb.DigestFunction_Value) (hashing.Hasher, error) {
130139
var err error
131140
var hasher hashing.Hasher
141+
132142
switch df {
133143
case pb.DigestFunction_UNKNOWN:
134144
hasher, err = hashing.Get(hashing.LegacyFn)
135145
default:
146+
if _, ok := s.digestFunctions[df]; !ok {
147+
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("unsupported digest function %s", df))
148+
}
136149
hasher, err = hashing.Get(df)
137150
}
138151
if err != nil {

server/grpc_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func grpcTestSetupInternal(t *testing.T, mangleACKeys bool) (tc grpcTestFixture)
105105
validateAC,
106106
mangleACKeys,
107107
enableRemoteAssetAPI,
108-
diskCache, accessLogger, errorLogger)
108+
diskCache, accessLogger, errorLogger, hashing.DigestFunctions())
109109
if err2 != nil {
110110
fmt.Println(err2)
111111
os.Exit(1)

0 commit comments

Comments
 (0)