@@ -134,7 +134,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
134134 }
135135}
136136
137- func (s * Server ) HandleRequestHeaders (ctx context.Context , reqeustID string , req * extProcPb.ProcessingRequest ) (* extProcPb.ProcessingResponse , string , string ) {
137+ func (s * Server ) HandleRequestHeaders (ctx context.Context , requestID string , req * extProcPb.ProcessingRequest ) (* extProcPb.ProcessingResponse , string , string ) {
138138 klog .Info ("--- In RequestHeaders processing ..." )
139139 var username , model , routingStrategy , targetPodIP string
140140 r := req .Request
@@ -155,52 +155,21 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
155155 }
156156 }
157157
158- user , err := utils .GetUser (utils.User {Name : username }, s .redisClient )
159- if err != nil {
160- return generateErrorResponse (
161- envoyTypePb .StatusCode_Forbidden ,
162- []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
163- Key : "x-user-missing" , RawValue : []byte ("true" ),
164- }}},
165- fmt .Sprintf ("pre query: username is missing: %v" , err .Error ())), username , targetPodIP
166- }
167-
168- if user .Rpm == 0 {
169- user .Rpm = int64 (defaultRPM )
170- }
171- if user .Tpm == 0 {
172- user .Tpm = user .Rpm * int64 (defaultTPMMultiplier )
173- }
174-
175- code , err := s .checkRPM (ctx , username , user .Rpm )
176- if err != nil {
177- return generateErrorResponse (
178- code ,
179- []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
180- Key : "x-rpm-exceeded" , RawValue : []byte ("true" ),
181- }}},
182- fmt .Sprintf ("pre query: error on checking rpm: %v" , err .Error ())), username , targetPodIP
183- }
184-
185- rpm , code , err := s .incrRPM (ctx , username )
186- if err != nil {
187- return generateErrorResponse (
188- code ,
189- []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
190- Key : "x-error-update-rpm" , RawValue : []byte ("true" ),
191- }}},
192- fmt .Sprintf ("pre query: error on updating rpm: %v" , err .Error ())), username , targetPodIP
193- }
194- klog .Infof ("RequestStart %s: RPM: %v for user: %v" , reqeustID , rpm , user .Name )
158+ if username != "" {
159+ user , err := utils .GetUser (utils.User {Name : username }, s .redisClient )
160+ if err != nil {
161+ return generateErrorResponse (
162+ envoyTypePb .StatusCode_Forbidden ,
163+ []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
164+ Key : "x-user-missing" , RawValue : []byte ("true" ),
165+ }}},
166+ fmt .Sprintf ("pre query: username is missing: %v" , err .Error ())), username , targetPodIP
167+ }
195168
196- code , err = s .checkTPM (ctx , username , user .Tpm )
197- if err != nil {
198- return generateErrorResponse (
199- code ,
200- []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
201- Key : "x-tpm-exceeded" , RawValue : []byte ("true" ),
202- }}},
203- fmt .Sprintf ("pre query: error on checking tpm: %v" , err .Error ())), username , targetPodIP
169+ errRes := s .checkLimits (ctx , requestID , user )
170+ if errRes != nil {
171+ return errRes , user .Name , targetPodIP
172+ }
204173 }
205174
206175 headers := []* configPb.HeaderValueOption {
@@ -210,18 +179,19 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
210179 RawValue : []byte ("true" ),
211180 },
212181 },
213- {
214- Header : & configPb.HeaderValue {
215- Key : "x-updated-rpm" ,
216- RawValue : []byte (fmt .Sprintf ("%d" , rpm )),
217- },
218- },
182+ // TODO (varun): refactor this part with model name input from request body
183+ // {
184+ // Header: &configPb.HeaderValue{
185+ // Key: "x-updated-rpm",
186+ // RawValue: []byte(fmt.Sprintf("%d", rpm)),
187+ // },
188+ // },
219189 }
220190 if routingStrategy != "" {
221191 pods , err := s .cache .GetPodsForModel (model )
222192 if len (pods ) == 0 || err != nil {
223193 return generateErrorResponse (
224- code ,
194+ envoyTypePb . StatusCode_InternalServerError ,
225195 []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
226196 Key : "x-no-model-deployment" , RawValue : []byte ("true" ),
227197 }}},
@@ -231,7 +201,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
231201 targetPodIP , err = s .selectTargetPod (ctx , routingStrategy , pods )
232202 if targetPodIP == "" || err != nil {
233203 return generateErrorResponse (
234- code ,
204+ envoyTypePb . StatusCode_InternalServerError ,
235205 []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
236206 Key : "x-select-target-pod" , RawValue : []byte ("true" ),
237207 }}},
@@ -244,7 +214,7 @@ func (s *Server) HandleRequestHeaders(ctx context.Context, reqeustID string, req
244214 RawValue : []byte (targetPodIP ),
245215 },
246216 })
247- klog .Infof ("RequestStart %s: SelectedTargetPodIP: %s" , reqeustID , targetPodIP )
217+ klog .Infof ("RequestStart %s: SelectedTargetPodIP: %s" , requestID , targetPodIP )
248218 }
249219
250220 resp := & extProcPb.ProcessingResponse {
@@ -340,29 +310,32 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqeustID string, req *
340310 err .Error ())
341311 }
342312
343- tpm , err := s .ratelimiter .Incr (ctx , fmt .Sprintf ("%v_TPM_CURRENT" , user ), int64 (res .Usage .TotalTokens ))
344- if err != nil {
345- return generateErrorResponse (
346- envoyTypePb .StatusCode_InternalServerError ,
347- []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
348- Key : "x-error-update-tpm" , RawValue : []byte ("true" ),
349- }}},
350- fmt .Sprintf ("post query: error on updating tpm: %v" , err .Error ()))
313+ if user != "" {
314+ tpm , err := s .ratelimiter .Incr (ctx , fmt .Sprintf ("%v_TPM_CURRENT" , user ), int64 (res .Usage .TotalTokens ))
315+ if err != nil {
316+ return generateErrorResponse (
317+ envoyTypePb .StatusCode_InternalServerError ,
318+ []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
319+ Key : "x-error-update-tpm" , RawValue : []byte ("true" ),
320+ }}},
321+ fmt .Sprintf ("post query: error on updating tpm: %v" , err .Error ()))
322+ }
323+ klog .Infof ("RequestEnd %s: TPM: %v for user: %v" , reqeustID , tpm , user )
351324 }
352- klog .Infof ("RequestEnd %s: TPM: %v for user: %v" , reqeustID , tpm , user )
353325
354326 return & extProcPb.ProcessingResponse {
355327 Response : & extProcPb.ProcessingResponse_ResponseBody {
356328 ResponseBody : & extProcPb.BodyResponse {
357329 Response : & extProcPb.CommonResponse {
358330 HeaderMutation : & extProcPb.HeaderMutation {
359331 SetHeaders : []* configPb.HeaderValueOption {
360- {
361- Header : & configPb.HeaderValue {
362- Key : "x-updated-tpm" ,
363- RawValue : []byte (fmt .Sprintf ("%d" , tpm )),
364- },
365- },
332+ // TODO (varun): refactor to use the model name read from the body
333+ // {
334+ // Header: &configPb.HeaderValue{
335+ // Key: "x-updated-tpm",
336+ // RawValue: []byte(fmt.Sprintf("%d", tpm)),
337+ // },
338+ // },
366339 },
367340 },
368341 },
@@ -371,6 +344,48 @@ func (s *Server) HandleResponseBody(ctx context.Context, reqeustID string, req *
371344 }
372345}
373346
347+ func (s * Server ) checkLimits (ctx context.Context , requestID string , user utils.User ) * extProcPb.ProcessingResponse {
348+ if user .Rpm == 0 {
349+ user .Rpm = int64 (defaultRPM )
350+ }
351+ if user .Tpm == 0 {
352+ user .Tpm = user .Rpm * int64 (defaultTPMMultiplier )
353+ }
354+
355+ code , err := s .checkRPM (ctx , user .Name , user .Rpm )
356+ if err != nil {
357+ return generateErrorResponse (
358+ code ,
359+ []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
360+ Key : "x-rpm-exceeded" , RawValue : []byte ("true" ),
361+ }}},
362+ fmt .Sprintf ("pre query: error on checking rpm: %v" , err .Error ()))
363+ }
364+
365+ rpm , code , err := s .incrRPM (ctx , user .Name )
366+ if err != nil {
367+ return generateErrorResponse (
368+ code ,
369+ []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
370+ Key : "x-error-update-rpm" , RawValue : []byte ("true" ),
371+ }}},
372+ fmt .Sprintf ("pre query: error on updating rpm: %v" , err .Error ()))
373+ }
374+ klog .Infof ("RequestStart %s: RPM: %v for user: %v" , requestID , rpm , user .Name )
375+
376+ code , err = s .checkTPM (ctx , user .Name , user .Tpm )
377+ if err != nil {
378+ return generateErrorResponse (
379+ code ,
380+ []* configPb.HeaderValueOption {{Header : & configPb.HeaderValue {
381+ Key : "x-tpm-exceeded" , RawValue : []byte ("true" ),
382+ }}},
383+ fmt .Sprintf ("pre query: error on checking tpm: %v" , err .Error ()))
384+ }
385+
386+ return nil
387+ }
388+
374389func (s * Server ) checkRPM (ctx context.Context , user string , rpmLimit int64 ) (envoyTypePb.StatusCode , error ) {
375390 rpmCurrent , err := s .ratelimiter .Get (ctx , fmt .Sprintf ("%v_RPM_CURRENT" , user ))
376391 if err != nil {
0 commit comments