@@ -28,7 +28,7 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
2828 }
2929 serviceTimeout := time .Duration (85 * endpointConfig .Timeout .Nanoseconds ()/ 100 ) * time .Nanosecond
3030 combiner := getResponseCombiner (endpointConfig .ExtraConfig )
31- isSequential , propagatedParams := sequentialMergerConfig (endpointConfig )
31+ isSequential , sequentialReplacements := sequentialMergerConfig (endpointConfig )
3232
3333 logger .Debug (
3434 fmt .Sprintf (
@@ -40,6 +40,8 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
4040 ),
4141 )
4242
43+ bfFactory := backendFiltererFactory .filtererFactory
44+
4345 return func (next ... Proxy ) Proxy {
4446 if len (next ) != totalBackends {
4547 // we leave the panic here, because we do not want to continue
@@ -50,66 +52,61 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
5052 }
5153 reqClone := func (r * Request ) * Request { res := r .Clone (); return & res }
5254
55+ filters , err := bfFactory (endpointConfig )
56+ if err != nil {
57+ logger .Error (fmt .Sprintf ("[ENDPOINT: %s]%s %s" , endpointConfig .Endpoint , backendFiltererFactory .logPrefix , err ))
58+ return func (_ context.Context , _ * Request ) (* Response , error ) { return nil , err }
59+ }
60+
5361 if hasUnsafeBackends (endpointConfig ) {
5462 reqClone = CloneRequest
5563 }
5664
5765 if ! isSequential {
58- return parallelMerge (reqClone , serviceTimeout , combiner , next ... )
66+ return parallelMerge (reqClone , serviceTimeout , combiner , filters , next ... )
5967 }
6068
61- sequentialReplacements := make ([][]sequentialBackendReplacement , totalBackends )
69+ return sequentialMerge (reqClone , serviceTimeout , combiner , sequentialReplacements , filters , next ... )
70+ }
71+ }
6272
63- var rePropagatedParams = regexp .MustCompile (`[Rr]esp(\d+)_?([\w-.]+)?` )
64- var reUrlPatterns = regexp .MustCompile (`\{\{\.Resp(\d+)_([\w-.]+)\}\}` )
65- destKeyGenerator := func (i string , t string ) string {
66- key := "Resp" + i
67- if t != "" {
68- key += "_" + t
69- }
70- return key
71- }
73+ // BackendFiltererFactory is a factory function that returns a list of BackendFilterer
74+ // based on the provided EndpointConfig.
75+ // The returned list must be sorted by the backend index.
76+ // The list can contain nil values, which means that the backend in that index is untouched.
77+ type BackendFiltererFactory func (* config.EndpointConfig ) ([]BackendFilterer , error )
7278
73- for i , b := range endpointConfig .Backend {
74- for _ , match := range reUrlPatterns .FindAllStringSubmatch (b .URLPattern , - 1 ) {
75- if len (match ) > 1 {
76- backendIndex , err := strconv .Atoi (match [1 ])
77- if err != nil {
78- continue
79- }
79+ // BackendFilterer evalutes the request and returns true if the backend should be used,
80+ // otherwise the backend is skipped in both normal and sequential merging.
81+ // If the backend is skipped, the response will not be merged into the final response.
82+ type BackendFilterer func (* Request ) bool
8083
81- sequentialReplacements [i ] = append (sequentialReplacements [i ], sequentialBackendReplacement {
82- backendIndex : backendIndex ,
83- destination : destKeyGenerator (match [1 ], match [2 ]),
84- source : strings .Split (match [2 ], "." ),
85- fullResponse : match [2 ] == "" ,
86- })
87- }
88- }
84+ func defaultBackendFiltererFactory (_ * config.EndpointConfig ) ([]BackendFilterer , error ) {
85+ return []BackendFilterer {}, nil
86+ }
8987
90- if i > 0 {
91- for _ , p := range propagatedParams {
92- for _ , match := range rePropagatedParams .FindAllStringSubmatch (p , - 1 ) {
93- if len (match ) > 1 {
94- backendIndex , err := strconv .Atoi (match [1 ])
95- if err != nil || backendIndex >= totalBackends {
96- continue
97- }
88+ type backendFiltererRegistry struct {
89+ logPrefix string
90+ filtererFactory BackendFiltererFactory
91+ }
9892
99- sequentialReplacements [i ] = append (sequentialReplacements [i ], sequentialBackendReplacement {
100- backendIndex : backendIndex ,
101- destination : destKeyGenerator (match [1 ], match [2 ]),
102- source : strings .Split (match [2 ], "." ),
103- fullResponse : match [2 ] == "" ,
104- })
105- }
106- }
107- }
108- }
109- }
93+ var backendFiltererFactory = backendFiltererRegistry {
94+ filtererFactory : defaultBackendFiltererFactory ,
95+ }
11096
111- return sequentialMerge (reqClone , sequentialReplacements , serviceTimeout , combiner , next ... )
112- }
97+ // RegisterBackendFiltererFactory registers a new backend filterer factory
98+ // to be used by the merging middleware.
99+ // This factory is used to create a list of BackendFilterer
100+ // functions that will be used to filter backends based on the request.
101+ // Important: this function should be called everytime the middleware is created.
102+ func RegisterBackendFiltererFactory (logPrefix string , f BackendFiltererFactory ) {
103+ backendFiltererFactory .logPrefix = logPrefix
104+ backendFiltererFactory .filtererFactory = f
105+ }
106+
107+ func ResetBackendFiltererFactory () {
108+ backendFiltererFactory .logPrefix = ""
109+ backendFiltererFactory .filtererFactory = defaultBackendFiltererFactory
113110}
114111
115112type sequentialBackendReplacement struct {
@@ -119,9 +116,12 @@ type sequentialBackendReplacement struct {
119116 fullResponse bool
120117}
121118
122- func sequentialMergerConfig (cfg * config.EndpointConfig ) (bool , []string ) {
119+ func sequentialMergerConfig (cfg * config.EndpointConfig ) (bool , [][] sequentialBackendReplacement ) { // skipcq: GO-R1005
123120 enabled := false
121+ totalBackends := len (cfg .Backend )
122+ sequentialReplacements := make ([][]sequentialBackendReplacement , totalBackends )
124123 var propagatedParams []string
124+
125125 if v , ok := cfg .ExtraConfig [Namespace ]; ok {
126126 if e , ok := v .(map [string ]interface {}); ok {
127127 if v , ok := e [isSequentialKey ]; ok {
@@ -137,7 +137,54 @@ func sequentialMergerConfig(cfg *config.EndpointConfig) (bool, []string) {
137137 }
138138 }
139139 }
140- return enabled , propagatedParams
140+ var rePropagatedParams = regexp .MustCompile (`[Rr]esp(\d+)_?([\w-.]+)?` )
141+ var reUrlPatterns = regexp .MustCompile (`\{\{\.Resp(\d+)_([\w-.]+)\}\}` )
142+ destKeyGenerator := func (i string , t string ) string {
143+ key := "Resp" + i
144+ if t != "" {
145+ key += "_" + t
146+ }
147+ return key
148+ }
149+
150+ for i , b := range cfg .Backend {
151+ for _ , match := range reUrlPatterns .FindAllStringSubmatch (b .URLPattern , - 1 ) {
152+ if len (match ) > 1 {
153+ backendIndex , err := strconv .Atoi (match [1 ])
154+ if err != nil {
155+ continue
156+ }
157+
158+ sequentialReplacements [i ] = append (sequentialReplacements [i ], sequentialBackendReplacement {
159+ backendIndex : backendIndex ,
160+ destination : destKeyGenerator (match [1 ], match [2 ]),
161+ source : strings .Split (match [2 ], "." ),
162+ fullResponse : match [2 ] == "" ,
163+ })
164+ }
165+ }
166+
167+ if i > 0 {
168+ for _ , p := range propagatedParams {
169+ for _ , match := range rePropagatedParams .FindAllStringSubmatch (p , - 1 ) {
170+ if len (match ) > 1 {
171+ backendIndex , err := strconv .Atoi (match [1 ])
172+ if err != nil || backendIndex >= totalBackends {
173+ continue
174+ }
175+
176+ sequentialReplacements [i ] = append (sequentialReplacements [i ], sequentialBackendReplacement {
177+ backendIndex : backendIndex ,
178+ destination : destKeyGenerator (match [1 ], match [2 ]),
179+ source : strings .Split (match [2 ], "." ),
180+ fullResponse : match [2 ] == "" ,
181+ })
182+ }
183+ }
184+ }
185+ }
186+ }
187+ return enabled , sequentialReplacements
141188}
142189
143190func hasUnsafeBackends (cfg * config.EndpointConfig ) bool {
@@ -154,19 +201,32 @@ func hasUnsafeBackends(cfg *config.EndpointConfig) bool {
154201 return false
155202}
156203
157- func parallelMerge (reqCloner func (* Request ) * Request , timeout time.Duration , rc ResponseCombiner , next ... Proxy ) Proxy {
204+ func parallelMerge (
205+ reqCloner func (* Request ) * Request ,
206+ timeout time.Duration ,
207+ rc ResponseCombiner ,
208+ filters []BackendFilterer ,
209+ next ... Proxy ,
210+ ) Proxy {
158211 return func (ctx context.Context , request * Request ) (* Response , error ) {
159212 localCtx , cancel := context .WithTimeout (ctx , timeout )
160213
161- parts := make (chan * Response , len (next ))
162- failed := make (chan error , len (next ))
214+ proxyCount := len (next )
215+ filterCount := len (filters )
216+
217+ parts := make (chan * Response , proxyCount )
218+ failed := make (chan error , proxyCount )
163219
164- for _ , n := range next {
220+ for i , n := range next {
221+ if (i < filterCount ) && (filters [i ] != nil ) && ! filters [i ](request ) {
222+ proxyCount --
223+ continue
224+ }
165225 go requestPart (localCtx , n , reqCloner (request ), parts , failed )
166226 }
167227
168- acc := newIncrementalMergeAccumulator (len ( next ) , rc )
169- for i := 0 ; i < len ( next ) ; i ++ {
228+ acc := newIncrementalMergeAccumulator (proxyCount , rc )
229+ for i := 0 ; i < proxyCount ; i ++ {
170230 select {
171231 case err := <- failed :
172232 acc .Merge (nil , err )
@@ -181,10 +241,18 @@ func parallelMerge(reqCloner func(*Request) *Request, timeout time.Duration, rc
181241 }
182242}
183243
184- func sequentialMerge (reqCloner func (* Request ) * Request , sequentialReplacements [][]sequentialBackendReplacement , timeout time.Duration , rc ResponseCombiner , next ... Proxy ) Proxy { // skipcq: GO-R1005
244+ func sequentialMerge ( // skipcq: GO-R1005
245+ reqCloner func (* Request ) * Request ,
246+ timeout time.Duration ,
247+ rc ResponseCombiner ,
248+ sequentialReplacements [][]sequentialBackendReplacement ,
249+ filters []BackendFilterer ,
250+ next ... Proxy ,
251+ ) Proxy {
185252 return func (ctx context.Context , request * Request ) (* Response , error ) {
186253 localCtx , cancel := context .WithTimeout (ctx , timeout )
187254
255+ filterCount := len (filters )
188256 parts := make ([]* Response , len (next ))
189257 out := make (chan * Response , 1 )
190258 errCh := make (chan error , 1 )
@@ -270,6 +338,12 @@ func sequentialMerge(reqCloner func(*Request) *Request, sequentialReplacements [
270338 }
271339 }
272340
341+ if (i < filterCount ) && (filters [i ] != nil ) && ! filters [i ](request ) {
342+ parts [i ] = & Response {IsComplete : true , Data : make (map [string ]interface {})}
343+ acc .pending --
344+ continue
345+ }
346+
273347 sequentialRequestPart (localCtx , n , reqCloner (request ), out , errCh )
274348
275349 select {
@@ -335,7 +409,7 @@ func (i *incrementalMergeAccumulator) Result() (*Response, error) {
335409 return nil , newMergeError (i .errs )
336410 }
337411
338- if i .pending != 0 || len (i .errs ) != 0 {
412+ if i .pending > 0 || len (i .errs ) > 0 {
339413 i .data .IsComplete = false
340414 }
341415 return i .data , newMergeError (i .errs )
0 commit comments