Skip to content

Commit 3c9cdd0

Browse files
authored
Merge pull request #769 from luraproject/sc-691/customizable-backend-filtering
Customizable backend filtering
2 parents b36b91d + 393d87e commit 3c9cdd0

2 files changed

Lines changed: 274 additions & 73 deletions

File tree

proxy/merging.go

Lines changed: 132 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
2828
}
2929
serviceTimeout := time.Duration(85*endpointConfig.Timeout.Nanoseconds()/100) * time.Nanosecond
3030
combiner := getResponseCombiner(endpointConfig.ExtraConfig)
31-
isSequential, propagatedParams := sequentialMergerConfig(endpointConfig)
31+
isSequential, sequentialReplacements := sequentialMergerConfig(endpointConfig)
3232

3333
logger.Debug(
3434
fmt.Sprintf(
@@ -40,6 +40,8 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
4040
),
4141
)
4242

43+
bfFactory := backendFiltererFactory.filtererFactory
44+
4345
return func(next ...Proxy) Proxy {
4446
if len(next) != totalBackends {
4547
// we leave the panic here, because we do not want to continue
@@ -50,66 +52,61 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
5052
}
5153
reqClone := func(r *Request) *Request { res := r.Clone(); return &res }
5254

55+
filters, err := bfFactory(endpointConfig)
56+
if err != nil {
57+
logger.Error(fmt.Sprintf("[ENDPOINT: %s]%s %s", endpointConfig.Endpoint, backendFiltererFactory.logPrefix, err))
58+
return func(_ context.Context, _ *Request) (*Response, error) { return nil, err }
59+
}
60+
5361
if hasUnsafeBackends(endpointConfig) {
5462
reqClone = CloneRequest
5563
}
5664

5765
if !isSequential {
58-
return parallelMerge(reqClone, serviceTimeout, combiner, next...)
66+
return parallelMerge(reqClone, serviceTimeout, combiner, filters, next...)
5967
}
6068

61-
sequentialReplacements := make([][]sequentialBackendReplacement, totalBackends)
69+
return sequentialMerge(reqClone, serviceTimeout, combiner, sequentialReplacements, filters, next...)
70+
}
71+
}
6272

63-
var rePropagatedParams = regexp.MustCompile(`[Rr]esp(\d+)_?([\w-.]+)?`)
64-
var reUrlPatterns = regexp.MustCompile(`\{\{\.Resp(\d+)_([\w-.]+)\}\}`)
65-
destKeyGenerator := func(i string, t string) string {
66-
key := "Resp" + i
67-
if t != "" {
68-
key += "_" + t
69-
}
70-
return key
71-
}
73+
// BackendFiltererFactory is a factory function that returns a list of BackendFilterer
74+
// based on the provided EndpointConfig.
75+
// The returned list must be sorted by the backend index.
76+
// The list can contain nil values, which means that the backend in that index is untouched.
77+
type BackendFiltererFactory func(*config.EndpointConfig) ([]BackendFilterer, error)
7278

73-
for i, b := range endpointConfig.Backend {
74-
for _, match := range reUrlPatterns.FindAllStringSubmatch(b.URLPattern, -1) {
75-
if len(match) > 1 {
76-
backendIndex, err := strconv.Atoi(match[1])
77-
if err != nil {
78-
continue
79-
}
79+
// BackendFilterer evaluates the request and returns true if the backend should be used,
80+
// otherwise the backend is skipped in both normal and sequential merging.
81+
// If the backend is skipped, the response will not be merged into the final response.
82+
type BackendFilterer func(*Request) bool
8083

81-
sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{
82-
backendIndex: backendIndex,
83-
destination: destKeyGenerator(match[1], match[2]),
84-
source: strings.Split(match[2], "."),
85-
fullResponse: match[2] == "",
86-
})
87-
}
88-
}
84+
func defaultBackendFiltererFactory(_ *config.EndpointConfig) ([]BackendFilterer, error) {
85+
return []BackendFilterer{}, nil
86+
}
8987

90-
if i > 0 {
91-
for _, p := range propagatedParams {
92-
for _, match := range rePropagatedParams.FindAllStringSubmatch(p, -1) {
93-
if len(match) > 1 {
94-
backendIndex, err := strconv.Atoi(match[1])
95-
if err != nil || backendIndex >= totalBackends {
96-
continue
97-
}
88+
type backendFiltererRegistry struct {
89+
logPrefix string
90+
filtererFactory BackendFiltererFactory
91+
}
9892

99-
sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{
100-
backendIndex: backendIndex,
101-
destination: destKeyGenerator(match[1], match[2]),
102-
source: strings.Split(match[2], "."),
103-
fullResponse: match[2] == "",
104-
})
105-
}
106-
}
107-
}
108-
}
109-
}
93+
var backendFiltererFactory = backendFiltererRegistry{
94+
filtererFactory: defaultBackendFiltererFactory,
95+
}
11096

111-
return sequentialMerge(reqClone, sequentialReplacements, serviceTimeout, combiner, next...)
112-
}
97+
// RegisterBackendFiltererFactory registers a new backend filterer factory
98+
// to be used by the merging middleware.
99+
// This factory is used to create a list of BackendFilterer
100+
// functions that will be used to filter backends based on the request.
101+
// Important: this function should be called every time the middleware is created.
102+
func RegisterBackendFiltererFactory(logPrefix string, f BackendFiltererFactory) {
103+
backendFiltererFactory.logPrefix = logPrefix
104+
backendFiltererFactory.filtererFactory = f
105+
}
106+
107+
func ResetBackendFiltererFactory() {
108+
backendFiltererFactory.logPrefix = ""
109+
backendFiltererFactory.filtererFactory = defaultBackendFiltererFactory
113110
}
114111

115112
type sequentialBackendReplacement struct {
@@ -119,9 +116,12 @@ type sequentialBackendReplacement struct {
119116
fullResponse bool
120117
}
121118

122-
func sequentialMergerConfig(cfg *config.EndpointConfig) (bool, []string) {
119+
func sequentialMergerConfig(cfg *config.EndpointConfig) (bool, [][]sequentialBackendReplacement) { // skipcq: GO-R1005
123120
enabled := false
121+
totalBackends := len(cfg.Backend)
122+
sequentialReplacements := make([][]sequentialBackendReplacement, totalBackends)
124123
var propagatedParams []string
124+
125125
if v, ok := cfg.ExtraConfig[Namespace]; ok {
126126
if e, ok := v.(map[string]interface{}); ok {
127127
if v, ok := e[isSequentialKey]; ok {
@@ -137,7 +137,54 @@ func sequentialMergerConfig(cfg *config.EndpointConfig) (bool, []string) {
137137
}
138138
}
139139
}
140-
return enabled, propagatedParams
140+
var rePropagatedParams = regexp.MustCompile(`[Rr]esp(\d+)_?([\w-.]+)?`)
141+
var reUrlPatterns = regexp.MustCompile(`\{\{\.Resp(\d+)_([\w-.]+)\}\}`)
142+
destKeyGenerator := func(i string, t string) string {
143+
key := "Resp" + i
144+
if t != "" {
145+
key += "_" + t
146+
}
147+
return key
148+
}
149+
150+
for i, b := range cfg.Backend {
151+
for _, match := range reUrlPatterns.FindAllStringSubmatch(b.URLPattern, -1) {
152+
if len(match) > 1 {
153+
backendIndex, err := strconv.Atoi(match[1])
154+
if err != nil {
155+
continue
156+
}
157+
158+
sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{
159+
backendIndex: backendIndex,
160+
destination: destKeyGenerator(match[1], match[2]),
161+
source: strings.Split(match[2], "."),
162+
fullResponse: match[2] == "",
163+
})
164+
}
165+
}
166+
167+
if i > 0 {
168+
for _, p := range propagatedParams {
169+
for _, match := range rePropagatedParams.FindAllStringSubmatch(p, -1) {
170+
if len(match) > 1 {
171+
backendIndex, err := strconv.Atoi(match[1])
172+
if err != nil || backendIndex >= totalBackends {
173+
continue
174+
}
175+
176+
sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{
177+
backendIndex: backendIndex,
178+
destination: destKeyGenerator(match[1], match[2]),
179+
source: strings.Split(match[2], "."),
180+
fullResponse: match[2] == "",
181+
})
182+
}
183+
}
184+
}
185+
}
186+
}
187+
return enabled, sequentialReplacements
141188
}
142189

143190
func hasUnsafeBackends(cfg *config.EndpointConfig) bool {
@@ -154,19 +201,32 @@ func hasUnsafeBackends(cfg *config.EndpointConfig) bool {
154201
return false
155202
}
156203

157-
func parallelMerge(reqCloner func(*Request) *Request, timeout time.Duration, rc ResponseCombiner, next ...Proxy) Proxy {
204+
func parallelMerge(
205+
reqCloner func(*Request) *Request,
206+
timeout time.Duration,
207+
rc ResponseCombiner,
208+
filters []BackendFilterer,
209+
next ...Proxy,
210+
) Proxy {
158211
return func(ctx context.Context, request *Request) (*Response, error) {
159212
localCtx, cancel := context.WithTimeout(ctx, timeout)
160213

161-
parts := make(chan *Response, len(next))
162-
failed := make(chan error, len(next))
214+
proxyCount := len(next)
215+
filterCount := len(filters)
216+
217+
parts := make(chan *Response, proxyCount)
218+
failed := make(chan error, proxyCount)
163219

164-
for _, n := range next {
220+
for i, n := range next {
221+
if (i < filterCount) && (filters[i] != nil) && !filters[i](request) {
222+
proxyCount--
223+
continue
224+
}
165225
go requestPart(localCtx, n, reqCloner(request), parts, failed)
166226
}
167227

168-
acc := newIncrementalMergeAccumulator(len(next), rc)
169-
for i := 0; i < len(next); i++ {
228+
acc := newIncrementalMergeAccumulator(proxyCount, rc)
229+
for i := 0; i < proxyCount; i++ {
170230
select {
171231
case err := <-failed:
172232
acc.Merge(nil, err)
@@ -181,10 +241,18 @@ func parallelMerge(reqCloner func(*Request) *Request, timeout time.Duration, rc
181241
}
182242
}
183243

184-
func sequentialMerge(reqCloner func(*Request) *Request, sequentialReplacements [][]sequentialBackendReplacement, timeout time.Duration, rc ResponseCombiner, next ...Proxy) Proxy { // skipcq: GO-R1005
244+
func sequentialMerge( // skipcq: GO-R1005
245+
reqCloner func(*Request) *Request,
246+
timeout time.Duration,
247+
rc ResponseCombiner,
248+
sequentialReplacements [][]sequentialBackendReplacement,
249+
filters []BackendFilterer,
250+
next ...Proxy,
251+
) Proxy {
185252
return func(ctx context.Context, request *Request) (*Response, error) {
186253
localCtx, cancel := context.WithTimeout(ctx, timeout)
187254

255+
filterCount := len(filters)
188256
parts := make([]*Response, len(next))
189257
out := make(chan *Response, 1)
190258
errCh := make(chan error, 1)
@@ -270,6 +338,12 @@ func sequentialMerge(reqCloner func(*Request) *Request, sequentialReplacements [
270338
}
271339
}
272340

341+
if (i < filterCount) && (filters[i] != nil) && !filters[i](request) {
342+
parts[i] = &Response{IsComplete: true, Data: make(map[string]interface{})}
343+
acc.pending--
344+
continue
345+
}
346+
273347
sequentialRequestPart(localCtx, n, reqCloner(request), out, errCh)
274348

275349
select {
@@ -335,7 +409,7 @@ func (i *incrementalMergeAccumulator) Result() (*Response, error) {
335409
return nil, newMergeError(i.errs)
336410
}
337411

338-
if i.pending != 0 || len(i.errs) != 0 {
412+
if i.pending > 0 || len(i.errs) > 0 {
339413
i.data.IsComplete = false
340414
}
341415
return i.data, newMergeError(i.errs)

0 commit comments

Comments
 (0)