Skip to content

Commit 9b8a7c3

Browse files
authored
Merge pull request #745 from luraproject/propagate_seq_merger_params
Global param propagation for sequential merger
2 parents fff63b0 + fb3d36f commit 9b8a7c3

3 files changed

Lines changed: 193 additions & 89 deletions

File tree

proxy/merging.go

Lines changed: 145 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package proxy
55
import (
66
"context"
77
"fmt"
8+
"io"
89
"net/http"
910
"regexp"
1011
"strconv"
@@ -16,7 +17,7 @@ import (
1617
)
1718

1819
// NewMergeDataMiddleware creates proxy middleware for merging responses from several backends
19-
func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.EndpointConfig) Middleware {
20+
func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.EndpointConfig) Middleware { // skipcq: GO-R1005
2021
totalBackends := len(endpointConfig.Backend)
2122
if totalBackends == 0 {
2223
logger.Fatal("all endpoints must have at least one backend: NewMergeDataMiddleware")
@@ -27,7 +28,7 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
2728
}
2829
serviceTimeout := time.Duration(85*endpointConfig.Timeout.Nanoseconds()/100) * time.Nanosecond
2930
combiner := getResponseCombiner(endpointConfig.ExtraConfig)
30-
isSequential := shouldRunSequentialMerger(endpointConfig)
31+
isSequential, propagatedParams := sequentialMergerConfig(endpointConfig)
3132

3233
logger.Debug(
3334
fmt.Sprintf(
@@ -57,24 +58,86 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
5758
return parallelMerge(reqClone, serviceTimeout, combiner, next...)
5859
}
5960

60-
patterns := make([]string, len(endpointConfig.Backend))
61+
sequentialReplacements := make([][]sequentialBackendReplacement, totalBackends)
62+
63+
var rePropagatedParams = regexp.MustCompile(`[Rr]esp(\d+)_?([\w-.]+)?`)
64+
var reUrlPatterns = regexp.MustCompile(`\{\{\.Resp(\d+)_([\w-.]+)\}\}`)
65+
destKeyGenerator := func(i string, t string) string {
66+
key := "Resp" + i
67+
if t != "" {
68+
key += "_" + t
69+
}
70+
return key
71+
}
72+
6173
for i, b := range endpointConfig.Backend {
62-
patterns[i] = b.URLPattern
74+
for _, match := range reUrlPatterns.FindAllStringSubmatch(b.URLPattern, -1) {
75+
if len(match) > 1 {
76+
backendIndex, err := strconv.Atoi(match[1])
77+
if err != nil {
78+
continue
79+
}
80+
81+
sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{
82+
backendIndex: backendIndex,
83+
destination: destKeyGenerator(match[1], match[2]),
84+
source: strings.Split(match[2], "."),
85+
fullResponse: match[2] == "",
86+
})
87+
}
88+
}
89+
90+
if i > 0 {
91+
for _, p := range propagatedParams {
92+
for _, match := range rePropagatedParams.FindAllStringSubmatch(p, -1) {
93+
if len(match) > 1 {
94+
backendIndex, err := strconv.Atoi(match[1])
95+
if err != nil || backendIndex >= totalBackends {
96+
continue
97+
}
98+
99+
sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{
100+
backendIndex: backendIndex,
101+
destination: destKeyGenerator(match[1], match[2]),
102+
source: strings.Split(match[2], "."),
103+
fullResponse: match[2] == "",
104+
})
105+
}
106+
}
107+
}
108+
}
63109
}
64-
return sequentialMerge(reqClone, patterns, serviceTimeout, combiner, next...)
110+
111+
return sequentialMerge(reqClone, sequentialReplacements, serviceTimeout, combiner, next...)
65112
}
66113
}
67114

68-
func shouldRunSequentialMerger(cfg *config.EndpointConfig) bool {
115+
type sequentialBackendReplacement struct {
116+
backendIndex int
117+
destination string
118+
source []string
119+
fullResponse bool
120+
}
121+
122+
func sequentialMergerConfig(cfg *config.EndpointConfig) (bool, []string) {
123+
enabled := false
124+
var propagatedParams []string
69125
if v, ok := cfg.ExtraConfig[Namespace]; ok {
70126
if e, ok := v.(map[string]interface{}); ok {
71127
if v, ok := e[isSequentialKey]; ok {
72128
c, ok := v.(bool)
73-
return ok && c
129+
enabled = ok && c
130+
}
131+
if v, ok := e[sequentialPropagateKey]; ok {
132+
if a, ok := v.([]interface{}); ok {
133+
for _, p := range a {
134+
propagatedParams = append(propagatedParams, p.(string))
135+
}
136+
}
74137
}
75138
}
76139
}
77-
return false
140+
return enabled, propagatedParams
78141
}
79142

80143
func hasUnsafeBackends(cfg *config.EndpointConfig) bool {
@@ -118,75 +181,92 @@ func parallelMerge(reqCloner func(*Request) *Request, timeout time.Duration, rc
118181
}
119182
}
120183

121-
var reMergeKey = regexp.MustCompile(`\{\{\.Resp(\d+)_([\w-\.]+)\}\}`)
122-
123-
func sequentialMerge(reqCloner func(*Request) *Request, patterns []string, timeout time.Duration, rc ResponseCombiner, next ...Proxy) Proxy {
184+
func sequentialMerge(reqCloner func(*Request) *Request, sequentialReplacements [][]sequentialBackendReplacement, timeout time.Duration, rc ResponseCombiner, next ...Proxy) Proxy { // skipcq: GO-R1005
124185
return func(ctx context.Context, request *Request) (*Response, error) {
125186
localCtx, cancel := context.WithTimeout(ctx, timeout)
126187

127188
parts := make([]*Response, len(next))
128189
out := make(chan *Response, 1)
129190
errCh := make(chan error, 1)
191+
sequentialMergeRegistry := map[string]string{}
130192

131193
acc := newIncrementalMergeAccumulator(len(next), rc)
132194
TxLoop:
133195
for i, n := range next {
134196
if i > 0 {
135-
for _, match := range reMergeKey.FindAllStringSubmatch(patterns[i], -1) {
136-
if len(match) > 1 {
137-
rNum, err := strconv.Atoi(match[1])
138-
if err != nil || rNum >= i || parts[rNum] == nil {
139-
continue
140-
}
141-
key := "Resp" + match[1] + "_" + match[2]
142-
143-
var v interface{}
144-
var ok bool
145-
146-
data := parts[rNum].Data
147-
keys := strings.Split(match[2], ".")
148-
if len(keys) > 1 {
149-
for _, k := range keys[:len(keys)-1] {
150-
v, ok = data[k]
151-
if !ok {
152-
break
153-
}
154-
clean, ok := v.(map[string]interface{})
155-
if !ok {
156-
break
157-
}
158-
data = clean
197+
for _, r := range sequentialReplacements[i] {
198+
if r.backendIndex >= i || parts[r.backendIndex] == nil {
199+
continue
200+
}
201+
202+
var v interface{}
203+
var ok bool
204+
205+
data := parts[r.backendIndex].Data
206+
if len(r.source) > 1 {
207+
for _, k := range r.source[:len(r.source)-1] {
208+
v, ok = data[k]
209+
if !ok {
210+
break
159211
}
212+
clean, ok := v.(map[string]interface{})
213+
if !ok {
214+
break
215+
}
216+
data = clean
160217
}
218+
}
161219

162-
v, ok = data[keys[len(keys)-1]]
163-
if !ok {
220+
if found := sequentialMergeRegistry[r.destination]; found != "" {
221+
request.Params[r.destination] = found
222+
continue
223+
}
224+
225+
if r.fullResponse {
226+
if parts[r.backendIndex].Io == nil {
164227
continue
165228
}
166-
switch clean := v.(type) {
167-
case []interface{}:
168-
if len(clean) == 0 {
169-
request.Params[key] = ""
170-
continue
171-
}
172-
var b strings.Builder
173-
for i := 0; i < len(clean)-1; i++ {
174-
fmt.Fprintf(&b, "%v,", clean[i])
175-
}
176-
fmt.Fprintf(&b, "%v", clean[len(clean)-1])
177-
request.Params[key] = b.String()
178-
case string:
179-
request.Params[key] = clean
180-
case int:
181-
request.Params[key] = strconv.Itoa(clean)
182-
case float64:
183-
request.Params[key] = strconv.FormatFloat(clean, 'E', -1, 32)
184-
case bool:
185-
request.Params[key] = strconv.FormatBool(clean)
186-
default:
187-
request.Params[key] = fmt.Sprintf("%v", v)
229+
buf, err := io.ReadAll(parts[r.backendIndex].Io)
230+
231+
if err == nil {
232+
request.Params[r.destination] = string(buf)
233+
sequentialMergeRegistry[r.destination] = string(buf)
188234
}
235+
continue
189236
}
237+
238+
v, ok = data[r.source[len(r.source)-1]]
239+
if !ok {
240+
continue
241+
}
242+
243+
var param string
244+
245+
switch clean := v.(type) {
246+
case []interface{}:
247+
if len(clean) == 0 {
248+
request.Params[r.destination] = ""
249+
break
250+
}
251+
var b strings.Builder
252+
for i := 0; i < len(clean)-1; i++ {
253+
fmt.Fprintf(&b, "%v,", clean[i])
254+
}
255+
fmt.Fprintf(&b, "%v", clean[len(clean)-1])
256+
param = b.String()
257+
case string:
258+
param = clean
259+
case int:
260+
param = strconv.Itoa(clean)
261+
case float64:
262+
param = strconv.FormatFloat(clean, 'E', -1, 32)
263+
case bool:
264+
param = strconv.FormatBool(clean)
265+
default:
266+
param = fmt.Sprintf("%v", v)
267+
}
268+
request.Params[r.destination] = param
269+
sequentialMergeRegistry[r.destination] = param
190270
}
191271
}
192272

@@ -284,30 +364,25 @@ func requestPart(ctx context.Context, next Proxy, request *Request, out chan<- *
284364
}
285365

286366
func sequentialRequestPart(ctx context.Context, next Proxy, request *Request, out chan<- *Response, failed chan<- error) {
287-
localCtx, cancel := context.WithCancel(ctx)
288-
289367
copyRequest := CloneRequest(request)
290368

291-
in, err := next(localCtx, request)
369+
in, err := next(ctx, request)
292370

293371
*request = *copyRequest
294372

295373
if err != nil {
296374
failed <- err
297-
cancel()
298375
return
299376
}
300377
if in == nil {
301378
failed <- errNullResult
302-
cancel()
303379
return
304380
}
305381
select {
306382
case out <- in:
307383
case <-ctx.Done():
308384
failed <- ctx.Err()
309385
}
310-
cancel()
311386
}
312387

313388
func newMergeError(errs []error) error {
@@ -342,9 +417,10 @@ func RegisterResponseCombiner(name string, f ResponseCombiner) {
342417
}
343418

344419
const (
345-
mergeKey = "combiner"
346-
isSequentialKey = "sequential"
347-
defaultCombinerName = "default"
420+
mergeKey = "combiner"
421+
isSequentialKey = "sequential"
422+
sequentialPropagateKey = "sequential_propagated_params"
423+
defaultCombinerName = "default"
348424
)
349425

350426
var responseCombiners = initResponseCombiners()
@@ -382,7 +458,7 @@ func combineData(total int, parts []*Response) *Response {
382458
}
383459
isComplete = isComplete && part.IsComplete
384460
if retResponse == nil {
385-
retResponse = part
461+
retResponse = &Response{Data: part.Data, IsComplete: isComplete}
386462
continue
387463
}
388464
for k, v := range part.Data {

0 commit comments

Comments (0)