Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions pkg/graphql/resolver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,31 @@ const HeadersKey string = "headers"

func contextToHeaders(ctx context.Context, request *http.Request) {
if ctx.Value(HeadersKey) != nil {
headers, ok := ctx.Value(HeadersKey).(map[string]string)
headers, ok := ctx.Value(HeadersKey).(map[string]interface{})
if ok {
for key, value := range headers {
if value != "" {
request.Header.Add(key, value)
switch v := value.(type) {
case string:
if v != "" {
request.Header.Add(key, v)
}
case []string:
// Handle multiple values for the same header (e.g., Impersonate-Group)
for _, val := range v {
if val != "" {
request.Header.Add(key, val)
}
}
}
}
}
}
}

type initPayload struct {
ImpersonateUser string `json:"Impersonate-User"`
ImpersonateGroup string `json:"Impersonate-Group"`
ImpersonateUser string `json:"Impersonate-User"`
ImpersonateGroup string `json:"Impersonate-Group"`
ImpersonateGroups []string `json:"Impersonate-Groups"`
}

func InitPayload(ctx context.Context, payload json.RawMessage) context.Context {
Expand All @@ -33,10 +44,19 @@ func InitPayload(ctx context.Context, payload json.RawMessage) context.Context {
if err != nil {
return ctx
}
headers, ok := ctx.Value(HeadersKey).(map[string]string)
headers, ok := ctx.Value(HeadersKey).(map[string]interface{})
if ok {
headers["Impersonate-User"] = initPayload.ImpersonateUser
headers["Impersonate-Group"] = initPayload.ImpersonateGroup
if initPayload.ImpersonateUser != "" {
headers["Impersonate-User"] = initPayload.ImpersonateUser
}
// Support both single group (backward compatibility) and multiple groups
if len(initPayload.ImpersonateGroups) > 0 {
groups := initPayload.ImpersonateGroups
groups = append(groups, "system:authenticated")
headers["Impersonate-Group"] = groups
} else if initPayload.ImpersonateGroup != "" {
headers["Impersonate-Group"] = []string{initPayload.ImpersonateGroup, "system:authenticated"}
}
ctx = context.WithValue(ctx, HeadersKey, headers)
}
return ctx
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ func (s *Server) HTTPHandler() (http.Handler, error) {
handler.InitPayload = resolver.InitPayload
graphQLHandler := handler.NewHandlerFunc(schema, gql.NewHttpHandler(schema))
handle("/api/graphql", authHandlerWithUser(func(user *auth.User, w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(context.Background(), resolver.HeadersKey, map[string]string{
ctx := context.WithValue(context.Background(), resolver.HeadersKey, map[string]interface{}{
"Authorization": fmt.Sprintf("Bearer %s", user.Token),
})
graphQLHandler(w, r.WithContext(ctx))
Expand Down