Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions modules/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo
statusCodeWare := pipeline.NewStatusCodeAdjustWare()
traceIDStatusCodeWare := pipeline.NewStatusCodeAdjustWareWithAllowedCode(http.StatusNotFound)
urlDenyListWare := pipeline.NewURLDenyListWare(cfg.URLDenyList)
queryValidatorWare := pipeline.NewQueryValidatorWare()

tracePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
Expand All @@ -106,6 +107,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo
searchPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
urlDenyListWare,
queryValidatorWare,
multiTenantMiddleware(cfg, logger),
newAsyncSearchSharder(reader, o, cfg.Search.Sharder, logger),
},
Expand Down Expand Up @@ -143,6 +145,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo
queryRangePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
urlDenyListWare,
queryValidatorWare,
multiTenantMiddleware(cfg, logger),
newAsyncQueryRangeSharder(reader, o, cfg.Metrics.Sharder, logger),
},
Expand Down
50 changes: 50 additions & 0 deletions modules/frontend/pipeline/query_validator_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package pipeline

import (
"fmt"
"net/url"

"github.com/grafana/tempo/modules/frontend/combiner"
"github.com/grafana/tempo/pkg/traceql"
)

type queryValidatorWare struct {
next AsyncRoundTripper[combiner.PipelineResponse]
}

func NewQueryValidatorWare() AsyncMiddleware[combiner.PipelineResponse] {
return AsyncMiddlewareFunc[combiner.PipelineResponse](func(next AsyncRoundTripper[combiner.PipelineResponse]) AsyncRoundTripper[combiner.PipelineResponse] {
return &queryValidatorWare{
next: next,
}
})
}

func (c queryValidatorWare) RoundTrip(req Request) (Responses[combiner.PipelineResponse], error) {
query := req.HTTPRequest().URL.Query()
err := c.validateTraceQLQuery(query)
if err != nil {
return NewBadRequest(err), nil
}
return c.next.RoundTrip(req)
}

func (c queryValidatorWare) validateTraceQLQuery(queryParams url.Values) error {
var traceQLQuery string
if queryParams.Has("q") {
traceQLQuery = queryParams.Get("q")
}
if queryParams.Has("query") {
traceQLQuery = queryParams.Get("query")
}
if traceQLQuery != "" {
expr, err := traceql.Parse(traceQLQuery)
if err == nil {
err = traceql.Validate(expr)
}
if err != nil {
return fmt.Errorf("invalid TraceQL query: %w", err)
}
}
return nil
}
52 changes: 52 additions & 0 deletions modules/frontend/pipeline/query_validator_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package pipeline

import (
"bytes"
"context"
"io"
"net/http"
"testing"

"github.com/grafana/tempo/modules/frontend/combiner"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var nextFunc = AsyncRoundTripperFunc[combiner.PipelineResponse](func(_ Request) (Responses[combiner.PipelineResponse], error) {
return NewHTTPToAsyncResponse(&http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader([]byte{})),
}), nil
})

func TestQueryValidator(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search", roundTrip)
assert.Equal(t, 200, statusCode)
}

func TestQueryValidatorForAValidQuery(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search&q={}", roundTrip)
assert.Equal(t, 200, statusCode)
}

func TestQueryValidatorForAnInvalidTraceQLQuery(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search?q={. hi}", roundTrip)
assert.Equal(t, 400, statusCode)
}

func TestQueryValidatorForAnInvalidTraceQlQueryRegex(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search?query={span.a =~ \".*((?<!(-test))(?<!(-uat)))$\"}", roundTrip)
assert.Equal(t, 400, statusCode)
}

func doRequest(t *testing.T, url string, rt AsyncRoundTripper[combiner.PipelineResponse]) int {
req, _ := http.NewRequest(http.MethodGet, url, nil)
resp, _ := rt.RoundTrip(NewHTTPRequest(req))
httpResponse, _, err := resp.Next(context.Background())
require.NoError(t, err)
return httpResponse.HTTPResponse().StatusCode
}
7 changes: 0 additions & 7 deletions pkg/api/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,6 @@ func ParseSearchRequest(r *http.Request) (*tempopb.SearchRequest, error) {

query, queryFound := extractQueryParam(vals, urlParamQuery)
if queryFound {
// TODO hacky fix: we don't validate {} since this isn't handled correctly yet
if query != "{}" {
_, err := traceql.Parse(query)
if err != nil {
return nil, fmt.Errorf("invalid TraceQL query: %w", err)
}
}
req.Query = query
}

Expand Down
5 changes: 0 additions & 5 deletions pkg/api/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,6 @@ func TestQuerierParseSearchRequest(t *testing.T) {
SpansPerSpanSet: defaultSpansPerSpanSet,
},
},
{
name: "invalid traceql query",
urlQuery: "q=" + url.QueryEscape(`{ .foo="bar" `),
err: "invalid TraceQL query: parse error at line 1, col 14: syntax error: unexpected $end",
},
{
name: "traceql query and tags",
urlQuery: "q=" + url.QueryEscape(`{ .foo="bar" }`) + "&tags=" + url.QueryEscape("service.name=foo"),
Expand Down
5 changes: 5 additions & 0 deletions pkg/traceql/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package traceql

func Validate(expr *RootExpr) error {
return expr.validate()
}