From a57b7281fe00aac8d42a57d4912f70d3143f06aa Mon Sep 17 00:00:00 2001 From: Eoin O'Shaughnessy Date: Fri, 6 Jun 2025 11:59:53 +0100 Subject: [PATCH] feat: async support for process_ --- src/f5_ai_gateway_sdk/processor.py | 87 +++++++------ tests/contract/test_processor_exchanges.py | 116 ++++++------------ .../test_processor_routes_exchanges.py | 2 +- tests/libs/fakes/processors.py | 70 +++++++---- tests/test_processor.py | 42 +++++++ 5 files changed, 183 insertions(+), 134 deletions(-) diff --git a/src/f5_ai_gateway_sdk/processor.py b/src/f5_ai_gateway_sdk/processor.py index b8406b6..e1ca545 100644 --- a/src/f5_ai_gateway_sdk/processor.py +++ b/src/f5_ai_gateway_sdk/processor.py @@ -5,6 +5,7 @@ LICENSE file in the root directory of this source tree. """ +import inspect import json import logging from abc import ABC @@ -225,6 +226,13 @@ def __init_subclass__(cls, **kwargs): "The DEPRECATED 'process' method must not be implemented " "alongside 'process_input' or 'process_response'." ) + if is_process_overridden and inspect.iscoroutinefunction(cls.process): + # we don't want to add async capabilities to the deprecated function + raise TypeError( + f"Cannot create concrete class {cls.__name__}. " + "The DEPRECATED 'process' method does not support async. " + "Implement 'process_input' and/or 'process_response' instead." + ) return @@ -875,15 +883,18 @@ async def _parse_and_process(self, request: Request) -> Response: prompt_hash, response_hash = (None, None) if input_direction: prompt_hash = prompt.hash() - result: Result | Reject = self.process_input( + result = await self._handle_process_function( + self.process_input, metadata=metadata, parameters=parameters, prompt=prompt, request=request, ) + else: response_hash = response.hash() - result: Result | Reject = self.process_response( + result = await self._handle_process_function( + self.process_response, metadata=metadata, parameters=parameters, prompt=prompt, @@ -1014,7 +1025,16 @@ def _is_method_overridden(self, method_name: str) -> bool: # the method object directly from the Processor class, then it has been overridden. return instance_class_method_obj is not base_class_method_obj - def process_input( + async def _process_fallback(self, **kwargs) -> Result | Reject: + warnings.warn( + f"{type(self).__name__} uses the deprecated 'process' method. " + "Implement 'process_input' and/or 'process_response' instead.", + DeprecationWarning, + stacklevel=2, + ) + return await self._handle_process_function(self.process, **kwargs) + + async def process_input( self, prompt: PROMPT, metadata: Metadata, @@ -1043,26 +1063,20 @@ def process_input(self, prompt, response, metadata, parameters, request): return Result(processor_result=result) """ - if self._is_method_overridden("process"): - warnings.warn( - f"{type(self).__name__} uses the deprecated 'process' method for input. " - "Implement 'process_input' instead.", - DeprecationWarning, - stacklevel=2, # Points the warning to the caller of process_input + if not self._is_method_overridden("process"): + raise NotImplementedError( + f"{type(self).__name__} must implement 'process_input' or the " + "deprecated 'process' method to handle input." ) - return self.process( - prompt=prompt, - response=None, - metadata=metadata, - parameters=parameters, - request=request, - ) - raise NotImplementedError( - f"{type(self).__name__} must implement 'process_input' or the " - "deprecated 'process' method to handle input." + return await self._process_fallback( + prompt=prompt, + response=None, + metadata=metadata, + parameters=parameters, + request=request, ) - def process_response( + async def process_response( self, prompt: PROMPT | None, response: RESPONSE, @@ -1096,23 +1110,17 @@ def process_response(self, prompt, response, metadata, parameters, request): return Result(processor_result=result) """ - if self._is_method_overridden("process"): - warnings.warn( - f"{type(self).__name__} uses the deprecated 'process' method for response. " - "Implement 'process_response' instead.", - DeprecationWarning, - stacklevel=2, # Points the warning to the caller of process_input + if not self._is_method_overridden("process"): + raise NotImplementedError( + f"{type(self).__name__} must implement 'process_response' or the " + "deprecated 'process' method to handle input." ) - return self.process( - prompt=prompt, - response=response, - metadata=metadata, - parameters=parameters, - request=request, - ) - raise NotImplementedError( - f"{type(self).__name__} must implement 'process_response' or the " - "deprecated 'process' method to handle input." + return await self._process_fallback( + prompt=prompt, + response=response, + metadata=metadata, + parameters=parameters, + request=request, ) def process( @@ -1159,6 +1167,13 @@ def process(self, prompt, response, metadata, parameters, request): "'process_input'/'process_response'." ) + async def _handle_process_function(self, func, **kwargs) -> Result | Reject: + if inspect.iscoroutinefunction(func): + result = await func(**kwargs) + else: + result = func(**kwargs) + return result + def _validation_error_as_messages(err: ValidationError) -> list[str]: return [_error_details_to_str(e) for e in err.errors()] diff --git a/tests/contract/test_processor_exchanges.py b/tests/contract/test_processor_exchanges.py index 6114f7a..5b40fb0 100644 --- a/tests/contract/test_processor_exchanges.py +++ b/tests/contract/test_processor_exchanges.py @@ -42,6 +42,12 @@ SIGNATURE_PATH = f"signature/{PROCESSOR_NAMESPACE}/{PROCESSOR_NAME}" CONTENT_TYPE = "application/json" +TEST_PROCESSORS = [ + (fake_processors.JudgySync), + (fake_processors.JudgyAsync), + (fake_processors.DeprecatedJudgy), +] + def test_multipart_fields_breaking_change(): """Verify that the multipart_fields have not changed without this test failing. @@ -99,9 +105,7 @@ def test_multipart_fields_breaking_change(): ), f"{result - expected} and {expected - result} should be empty" -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_processor_response_parameters_a_prompt_mismatch( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -151,9 +155,7 @@ def test_processor_response_parameters_a_prompt_mismatch( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_processor_overload_both_parameters( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -205,9 +207,7 @@ def test_processor_overload_both_parameters( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_processor_500_raising( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -257,28 +257,26 @@ def test_processor_500_raising( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_processor_returns_none( data_loader, processor_client_loader, test_logger, judgy_class ): """Verify that with a stood up processor that the request will reject a request with no prompt.""" - if judgy_class == fake_processors.Judgy: + if judgy_class.uses_process_method(): class NoneReturningProcessor(judgy_class): - def process_input(*_, **__): + def process(*_, **__): """Return None as a matter of existence.""" return None + else: - def process_response(*_, **__): + class NoneReturningProcessor(judgy_class): + def process_input(*_, **__): """Return None as a matter of existence.""" return None - else: - class NoneReturningProcessor(judgy_class): - def process(*_, **__): + def process_response(*_, **__): """Return None as a matter of existence.""" return None @@ -331,9 +329,7 @@ def process(*_, **__): ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_processor_returns_bogus_class( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -347,12 +343,12 @@ class BogusClass: tags = Tags() """Bogus class placeholder that is not a valid response object.""" - if judgy_class == fake_processors.Judgy: + if judgy_class.uses_process_method(): class BogusClassReturningProcessor(judgy_class): """Bogus processor whose process method returns BogusClass type.""" - def process_input(*_, **__): + def process(*_, **__): """Return BogusClass type as a matter of existence.""" return BogusClass() else: @@ -360,7 +356,7 @@ def process_input(*_, **__): class BogusClassReturningProcessor(judgy_class): """Bogus processor whose process method returns BogusClass type.""" - def process(*_, **__): + def process_input(*_, **__): """Return BogusClass type as a matter of existence.""" return BogusClass() @@ -411,9 +407,7 @@ def process(*_, **__): ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_raising_processor( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -465,9 +459,7 @@ def test_raising_processor( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_no_prompt( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -516,9 +508,7 @@ def test_request_no_prompt( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_null_parameters( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -567,9 +557,7 @@ def test_request_null_parameters( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_empty_metadata( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -619,9 +607,7 @@ def test_request_empty_metadata( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_invalid_metadata( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -696,9 +682,7 @@ def test_request_invalid_metadata( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_string_metadata( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -742,9 +726,7 @@ def test_request_string_metadata( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_query_get_command(processor_client_loader, test_logger, judgy_class): """Verify that with a ?command=parameters we get the parameters back.""" expected_status_code = http_status_codes.HTTP_200_OK @@ -775,9 +757,7 @@ def test_request_query_get_command(processor_client_loader, test_logger, judgy_c assert response.json()["parameters"] == judgy.parameters_class.model_json_schema() -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_query_post_command_invalid_json( processor_client_loader, test_logger, judgy_class ): @@ -821,9 +801,7 @@ def test_request_query_post_command_invalid_json( ] -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_query_post_command_invalid_parameters( processor_client_loader, test_logger, judgy_class ): @@ -867,9 +845,7 @@ def test_request_query_post_command_invalid_parameters( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_invalid_parameters( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -919,9 +895,7 @@ def test_request_invalid_parameters( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_required_parameters_missing( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -970,9 +944,7 @@ def test_request_required_parameters_missing( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_required_parameters_present( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -1016,9 +988,7 @@ def test_request_required_parameters_present( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_required_metadata_response_fields( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -1054,9 +1024,7 @@ def test_request_required_metadata_response_fields( assert "processor_version" in response.text -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_request_required_parameters_missing_multipart( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -1100,9 +1068,7 @@ def test_request_required_parameters_missing_multipart( ) -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_modification_with_reject( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -1110,10 +1076,10 @@ def test_modification_with_reject( expected_valid_status_code = http_status_codes.HTTP_400_BAD_REQUEST method = "post" - if judgy_class == fake_processors.Judgy: + if judgy_class.uses_process_method(): class ModifyAndRejectProcessor(judgy_class): - def process_input(*_, **__): + def process(*_, **__): """Return None as a matter of existence.""" return Result( modified_prompt=RequestInput( @@ -1126,7 +1092,7 @@ def process_input(*_, **__): else: class ModifyAndRejectProcessor(judgy_class): - def process(*_, **__): + def process_input(*_, **__): """Return None as a matter of existence.""" return Result( modified_prompt=RequestInput( @@ -1171,9 +1137,7 @@ def process(*_, **__): assert "mutually exclusive" in response.text -@pytest.mark.parametrize( - "judgy_class", [(fake_processors.Judgy), (fake_processors.DeprecatedJudgy)] -) +@pytest.mark.parametrize("judgy_class", TEST_PROCESSORS) def test_get_signature_definition( data_loader, processor_client_loader, test_logger, judgy_class ): @@ -1267,7 +1231,7 @@ def retrieval(): return multipart(header, retrieval) -def fake_judgy(judgy_class=fake_processors.Judgy) -> fake_processors.Judgy: +def fake_judgy(judgy_class=fake_processors.JudgySync) -> fake_processors.JudgySync: return judgy_class( PROCESSOR_NAME, PROCESSOR_VERSION, diff --git a/tests/contract/test_processor_routes_exchanges.py b/tests/contract/test_processor_routes_exchanges.py index 744e73f..7e5ce02 100644 --- a/tests/contract/test_processor_routes_exchanges.py +++ b/tests/contract/test_processor_routes_exchanges.py @@ -156,7 +156,7 @@ def test_processor_routes_get_with_one_processor( width=width, background=background, ).strip() - judgy = fake_processors.Judgy( + judgy = fake_processors.JudgyAsync( processor_name, processor_version, processor_namespace, diff --git a/tests/libs/fakes/processors.py b/tests/libs/fakes/processors.py index c950d71..c8b89e4 100644 --- a/tests/libs/fakes/processors.py +++ b/tests/libs/fakes/processors.py @@ -7,6 +7,8 @@ SDK Processor that is dynamically judgy based upon the request and is self-aware for reporting. """ +import asyncio +import functools from pydantic import Field from starlette.requests import Request @@ -88,15 +90,19 @@ class JudgyRequiredParameters(JudgyParameters): required_message: str -class Judgy(Processor): +class JudgySync(Processor): """Complete processor that behaves differently depending on JudgyParameters settings.""" + @classmethod + def uses_process_method(cls): + return False + def __init__(self, *processor_args, **processor_kwargs): """Allow for exceptions to be raised from Judgy during process().""" self.raise_error = None super().__init__(signature=BOTH_SIGNATURE, *processor_args, **processor_kwargs) - def process_input( + async def process_input( self, prompt: RequestInput, metadata: Metadata, @@ -105,13 +111,12 @@ def process_input( ) -> Result | Reject: return self._internal_process( prompt=prompt, - response=None, metadata=metadata, parameters=parameters, request=request, ) - def process_response( + async def process_response( self, prompt: RequestInput, response: ResponseOutput, @@ -130,10 +135,10 @@ def process_response( def _internal_process( self, prompt: RequestInput, - response: ResponseOutput, metadata: Metadata, parameters: JudgyParameters, request: Request, + response: ResponseOutput | None = None, ) -> Result | Reject: """Respond dynamically based upon parameters given to the object initially by the test.""" if isinstance((raise_error := self.raise_error), Exception): @@ -167,32 +172,55 @@ def _internal_process( return Result(**my_response) +class JudgyAsync(Processor): + """ + Implementation using async methods for process_input and process_response + """ + + @classmethod + def uses_process_method(cls): + return False + + def __init__(self, *processor_args, **processor_kwargs): + """Allow for exceptions to be raised from Judgy during process().""" + self.raise_error = None + self._internal_judgy = JudgySync(*processor_args, **processor_kwargs) + super().__init__(signature=BOTH_SIGNATURE, *processor_args, **processor_kwargs) + + async def process_input(self, **kwargs) -> Result | Reject: + if isinstance((raise_error := self.raise_error), Exception): + raise raise_error + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + None, functools.partial(self._internal_judgy._internal_process, **kwargs) + ) + + async def process_response(self, **kwargs) -> Result | Reject: + if isinstance((raise_error := self.raise_error), Exception): + raise raise_error + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + None, functools.partial(self._internal_judgy._internal_process, **kwargs) + ) + + class DeprecatedJudgy(Processor): """ Implementation using the deprecated process method instead of process_input and process_response """ + @classmethod + def uses_process_method(cls): + return True + def __init__(self, *processor_args, **processor_kwargs): """Allow for exceptions to be raised from Judgy during process().""" self.raise_error = None - self._internal_judgy = Judgy(*processor_args, **processor_kwargs) + self._internal_judgy = JudgySync(*processor_args, **processor_kwargs) super().__init__(signature=BOTH_SIGNATURE, *processor_args, **processor_kwargs) - def process( - self, - prompt: RequestInput, - response: ResponseOutput, - metadata: Metadata, - parameters: JudgyParameters, - request: Request, - ) -> Result | Reject: + def process(self, **kwargs) -> Result | Reject: """Respond dynamically based upon parameters given to the object initially by the test.""" if isinstance((raise_error := self.raise_error), Exception): raise raise_error - return self._internal_judgy._internal_process( - prompt=prompt, - response=response, - metadata=metadata, - parameters=parameters, - request=request, - ) + return self._internal_judgy._internal_process(**kwargs) diff --git a/tests/test_processor.py b/tests/test_processor.py index 2586a60..09259a4 100644 --- a/tests/test_processor.py +++ b/tests/test_processor.py @@ -813,6 +813,48 @@ def process_response(self): AllImplementedProcessor() self.assertIn(expected_message, err.value.args, str(err.value.args)) + def test_async_implemented(self): + class AsyncImplementedProcessor(Processor): + def __init__(self): + super().__init__( + name="non-implemented-processor", + namespace="fake", + signature=BOTH_SIGNATURE, + version="v1", + ) + + async def process_input(self): + return Result() + + async def process_response(self): + return Result() + + self.assertIsNotNone(AsyncImplementedProcessor()) + + def test_async_process_implemented(self): + expected_message = ( + "Cannot create concrete class AsyncProcessImplementedProcessor. " + "The DEPRECATED 'process' method does not support async. " + "Implement 'process_input' and/or 'process_response' instead." + ) + + with pytest.raises(TypeError) as err: + + class AsyncProcessImplementedProcessor(Processor): + def __init__(self): + super().__init__( + name="non-implemented-processor", + namespace="fake", + signature=BOTH_SIGNATURE, + version="v1", + ) + + async def process(self): + return Result() + + AsyncProcessImplementedProcessor() + self.assertIn(expected_message, err.value.args, str(err.value.args)) + def test_input_signature_match(self): """Verify we can instantiate a correct input-only processor"""