Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions integration/e2e/query_range_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,13 @@ sendLoop:
callQueryRange(t, tempo.Endpoint(3200), query, debugMode)
})
}

res := doRequest(t, tempo.Endpoint(3200), "{. a}")
require.Equal(t, 400, res.StatusCode)
}

func callQueryRange(t *testing.T, endpoint, query string, printBody bool) {
url := buildURL(endpoint, fmt.Sprintf("%s with(exemplars=true)", query))
req, err := http.NewRequest(http.MethodGet, url, nil)
require.NoError(t, err)

res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
res := doRequest(t, endpoint, query)
require.Equal(t, http.StatusOK, res.StatusCode)

// Read body and print it
Expand All @@ -89,6 +87,16 @@ func callQueryRange(t *testing.T, endpoint, query string, printBody bool) {
require.GreaterOrEqual(t, exemplarCount, 1)
}

func doRequest(t *testing.T, endpoint, query string) *http.Response {
url := buildURL(endpoint, fmt.Sprintf("%s with(exemplars=true)", query))
req, err := http.NewRequest(http.MethodGet, url, nil)
require.NoError(t, err)

res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
return res
}

func buildURL(endpoint, query string) string {
return fmt.Sprintf(
"http://%s/api/metrics/query_range?query=%s&start=%d&end=%d&step=%s",
Expand Down
4 changes: 4 additions & 0 deletions modules/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo
statusCodeWare := pipeline.NewStatusCodeAdjustWare()
traceIDStatusCodeWare := pipeline.NewStatusCodeAdjustWareWithAllowedCode(http.StatusNotFound)
urlDenyListWare := pipeline.NewURLDenyListWare(cfg.URLDenyList)
queryValidatorWare := pipeline.NewQueryValidatorWare()

tracePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
Expand All @@ -106,6 +107,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo
searchPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
urlDenyListWare,
queryValidatorWare,
multiTenantMiddleware(cfg, logger),
newAsyncSearchSharder(reader, o, cfg.Search.Sharder, logger),
},
Expand Down Expand Up @@ -134,6 +136,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo
metricsPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
urlDenyListWare,
queryValidatorWare,
multiTenantUnsupportedMiddleware(cfg, logger),
},
[]pipeline.Middleware{statusCodeWare, retryWare},
Expand All @@ -143,6 +146,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo
queryRangePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
urlDenyListWare,
queryValidatorWare,
multiTenantMiddleware(cfg, logger),
newAsyncQueryRangeSharder(reader, o, cfg.Metrics.Sharder, logger),
},
Expand Down
4 changes: 2 additions & 2 deletions modules/frontend/metrics_query_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ func newMetricsQueryInstantHTTPHandler(cfg Config, next pipeline.AsyncRoundTripp
if err != nil {
level.Error(logger).Log("msg", "query instant: query range combiner failed", "err", err)
return &http.Response{
StatusCode: http.StatusInternalServerError,
Status: http.StatusText(http.StatusInternalServerError),
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest),
Body: io.NopCloser(strings.NewReader(err.Error())),
}, nil
}
Expand Down
4 changes: 2 additions & 2 deletions modules/frontend/metrics_query_range_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ func newMetricsQueryRangeHTTPHandler(cfg Config, next pipeline.AsyncRoundTripper
if err != nil {
level.Error(logger).Log("msg", "query range: query range combiner failed", "err", err)
return &http.Response{
StatusCode: http.StatusInternalServerError,
Status: http.StatusText(http.StatusInternalServerError),
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest),
Body: io.NopCloser(strings.NewReader(err.Error())),
}, nil
}
Expand Down
50 changes: 50 additions & 0 deletions modules/frontend/pipeline/async_query_validator_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package pipeline

import (
"fmt"
"net/url"

"github.com/grafana/tempo/modules/frontend/combiner"
"github.com/grafana/tempo/pkg/traceql"
)

type queryValidatorWare struct {
next AsyncRoundTripper[combiner.PipelineResponse]
}

func NewQueryValidatorWare() AsyncMiddleware[combiner.PipelineResponse] {
return AsyncMiddlewareFunc[combiner.PipelineResponse](func(next AsyncRoundTripper[combiner.PipelineResponse]) AsyncRoundTripper[combiner.PipelineResponse] {
return &queryValidatorWare{
next: next,
}
})
}

func (c queryValidatorWare) RoundTrip(req Request) (Responses[combiner.PipelineResponse], error) {
query := req.HTTPRequest().URL.Query()
err := c.validateTraceQLQuery(query)
if err != nil {
return NewBadRequest(err), nil
}
return c.next.RoundTrip(req)
}

func (c queryValidatorWare) validateTraceQLQuery(queryParams url.Values) error {
var traceQLQuery string
if queryParams.Has("q") {
traceQLQuery = queryParams.Get("q")
}
if queryParams.Has("query") {
traceQLQuery = queryParams.Get("query")
}
if traceQLQuery != "" {
expr, err := traceql.Parse(traceQLQuery)
if err == nil {
err = traceql.Validate(expr)
}
if err != nil {
return fmt.Errorf("invalid TraceQL query: %w", err)
}
}
return nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package pipeline

import (
"bytes"
"context"
"io"
"net/http"
"testing"

"github.com/grafana/tempo/modules/frontend/combiner"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var nextFunc = AsyncRoundTripperFunc[combiner.PipelineResponse](func(_ Request) (Responses[combiner.PipelineResponse], error) {
return NewHTTPToAsyncResponse(&http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader([]byte{})),
}), nil
})

func TestQueryValidator(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search", roundTrip)
assert.Equal(t, 200, statusCode)
}

func TestQueryValidatorForAValidQuery(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search&q={}", roundTrip)
assert.Equal(t, 200, statusCode)
}

func TestQueryValidatorForAnInvalidTraceQLQuery(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search?q={. hi}", roundTrip)
assert.Equal(t, 400, statusCode)
}

func TestQueryValidatorForAnInvalidTraceQlQueryRegex(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search?query={span.a =~ \"[\"}", roundTrip)
assert.Equal(t, 400, statusCode)
}

func doRequest(t *testing.T, url string, rt AsyncRoundTripper[combiner.PipelineResponse]) int {
req, _ := http.NewRequest(http.MethodGet, url, nil)
resp, _ := rt.RoundTrip(NewHTTPRequest(req))
httpResponse, _, err := resp.Next(context.Background())
require.NoError(t, err)
return httpResponse.HTTPResponse().StatusCode
}
7 changes: 0 additions & 7 deletions pkg/api/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,6 @@ func ParseSearchRequest(r *http.Request) (*tempopb.SearchRequest, error) {

query, queryFound := extractQueryParam(vals, urlParamQuery)
if queryFound {
// TODO hacky fix: we don't validate {} since this isn't handled correctly yet
if query != "{}" {
_, err := traceql.Parse(query)
if err != nil {
return nil, fmt.Errorf("invalid TraceQL query: %w", err)
}
}
req.Query = query
}

Expand Down
5 changes: 0 additions & 5 deletions pkg/api/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,6 @@ func TestQuerierParseSearchRequest(t *testing.T) {
SpansPerSpanSet: defaultSpansPerSpanSet,
},
},
{
name: "invalid traceql query",
urlQuery: "q=" + url.QueryEscape(`{ .foo="bar" `),
err: "invalid TraceQL query: parse error at line 1, col 14: syntax error: unexpected $end",
},
{
name: "traceql query and tags",
urlQuery: "q=" + url.QueryEscape(`{ .foo="bar" }`) + "&tags=" + url.QueryEscape("service.name=foo"),
Expand Down
5 changes: 5 additions & 0 deletions pkg/traceql/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package traceql

func Validate(expr *RootExpr) error {
return expr.validate()
}