|
10 | 10 | from typing import TYPE_CHECKING, TypeVar, cast
|
11 | 11 |
|
12 | 12 | from streamflow.core import utils
|
| 13 | +from streamflow.core.deployment import Connector, ExecutionLocation, Target |
13 | 14 | from streamflow.core.exception import WorkflowExecutionException
|
14 | 15 | from streamflow.core.persistence import (
|
15 | 16 | DatabaseLoadingContext,
|
16 | 17 | DependencyType,
|
17 | 18 | PersistableEntity,
|
18 | 19 | )
|
| 20 | +from streamflow.core.utils import get_class_from_name, get_class_fullname |
19 | 21 |
|
20 | 22 | if TYPE_CHECKING:
|
21 | 23 | from typing import Any
|
22 | 24 |
|
23 | 25 | from streamflow.core.context import StreamFlowContext
|
24 | 26 |
|
25 | 27 |
|
| 28 | +class Command(ABC): |
| 29 | + def __init__(self, step: Step): |
| 30 | + super().__init__() |
| 31 | + self.step: Step = step |
| 32 | + |
| 33 | + @abstractmethod |
| 34 | + async def execute(self, job: Job) -> CommandOutput: ... |
| 35 | + |
| 36 | + @classmethod |
| 37 | + async def load( |
| 38 | + cls, |
| 39 | + context: StreamFlowContext, |
| 40 | + row: MutableMapping[str, Any], |
| 41 | + loading_context: DatabaseLoadingContext, |
| 42 | + step: Step, |
| 43 | + ) -> Command: |
| 44 | + type_ = cast(type[Command], utils.get_class_from_name(row["type"])) |
| 45 | + return await type_._load(context, row["params"], loading_context, step) |
| 46 | + |
| 47 | + async def save(self, context: StreamFlowContext): |
| 48 | + return { |
| 49 | + "type": utils.get_class_fullname(type(self)), |
| 50 | + "params": await self._save_additional_params(context), |
| 51 | + } |
| 52 | + |
| 53 | + @classmethod |
| 54 | + async def _load( |
| 55 | + cls, |
| 56 | + context: StreamFlowContext, |
| 57 | + row: MutableMapping[str, Any], |
| 58 | + loading_context: DatabaseLoadingContext, |
| 59 | + step: Step, |
| 60 | + ): |
| 61 | + return cls(step=step) |
| 62 | + |
| 63 | + async def _save_additional_params( |
| 64 | + self, context: StreamFlowContext |
| 65 | + ) -> MutableMapping[str, Any]: |
| 66 | + return {} |
| 67 | + |
| 68 | + |
| 69 | +class CommandOptions(ABC): |
| 70 | + pass |
| 71 | + |
| 72 | + |
| 73 | +class CommandOutput: |
| 74 | + __slots__ = ("value", "status") |
| 75 | + |
| 76 | + def __init__(self, value: Any, status: Status): |
| 77 | + self.value: Any = value |
| 78 | + self.status: Status = status |
| 79 | + |
| 80 | + def update(self, value: Any): |
| 81 | + return CommandOutput(value=value, status=self.status) |
| 82 | + |
| 83 | + |
| 84 | +class CommandOutputProcessor(ABC): |
| 85 | + def __init__(self, name: str, workflow: Workflow, target: Target | None = None): |
| 86 | + self.name: str = name |
| 87 | + self.workflow: Workflow = workflow |
| 88 | + self.target: Target | None = target |
| 89 | + |
| 90 | + def _get_connector(self, connector: Connector | None, job: Job) -> Connector: |
| 91 | + return connector or self.workflow.context.scheduler.get_connector(job.name) |
| 92 | + |
| 93 | + async def _get_locations( |
| 94 | + self, connector: Connector | None, job: Job |
| 95 | + ) -> MutableSequence[ExecutionLocation]: |
| 96 | + if self.target: |
| 97 | + available_locations = await connector.get_available_locations( |
| 98 | + service=self.target.service |
| 99 | + ) |
| 100 | + return [loc.location for loc in available_locations.values()] |
| 101 | + else: |
| 102 | + return self.workflow.context.scheduler.get_locations(job.name) |
| 103 | + |
| 104 | + @classmethod |
| 105 | + async def _load( |
| 106 | + cls, |
| 107 | + context: StreamFlowContext, |
| 108 | + row: MutableMapping[str, Any], |
| 109 | + loading_context: DatabaseLoadingContext, |
| 110 | + ) -> CommandOutputProcessor: |
| 111 | + return cls( |
| 112 | + name=row["name"], |
| 113 | + workflow=await loading_context.load_workflow(context, row["workflow"]), |
| 114 | + target=( |
| 115 | + (await loading_context.load_target(context, row["workflow"])) |
| 116 | + if row["target"] |
| 117 | + else None |
| 118 | + ), |
| 119 | + ) |
| 120 | + |
| 121 | + async def _save_additional_params( |
| 122 | + self, context: StreamFlowContext |
| 123 | + ) -> MutableMapping[str, Any]: |
| 124 | + if self.target: |
| 125 | + await self.target.save(context) |
| 126 | + return { |
| 127 | + "name": self.name, |
| 128 | + "workflow": self.workflow.persistent_id, |
| 129 | + "target": self.target.persistent_id if self.target else None, |
| 130 | + } |
| 131 | + |
| 132 | + @classmethod |
| 133 | + async def load( |
| 134 | + cls, |
| 135 | + context: StreamFlowContext, |
| 136 | + row: MutableMapping[str, Any], |
| 137 | + loading_context: DatabaseLoadingContext, |
| 138 | + ) -> CommandOutputProcessor: |
| 139 | + type_ = cast( |
| 140 | + type[CommandOutputProcessor], utils.get_class_from_name(row["type"]) |
| 141 | + ) |
| 142 | + return await type_._load(context, row["params"], loading_context) |
| 143 | + |
| 144 | + @abstractmethod |
| 145 | + async def process( |
| 146 | + self, |
| 147 | + job: Job, |
| 148 | + command_output: CommandOutput, |
| 149 | + connector: Connector | None = None, |
| 150 | + ) -> Token | None: ... |
| 151 | + |
| 152 | + async def save(self, context: StreamFlowContext): |
| 153 | + return { |
| 154 | + "type": utils.get_class_fullname(type(self)), |
| 155 | + "params": await self._save_additional_params(context), |
| 156 | + } |
| 157 | + |
| 158 | + |
| 159 | +class CommandToken: |
| 160 | + __slots__ = ("name", "position", "value") |
| 161 | + |
| 162 | + def __init__(self, name: str | None, position: int | None, value: Any): |
| 163 | + self.name: str | None = name |
| 164 | + self.position: int | None = position |
| 165 | + self.value: Any = value |
| 166 | + |
| 167 | + |
| 168 | +class CommandTokenProcessor(ABC): |
| 169 | + def __init__(self, name: str): |
| 170 | + self.name: str = name |
| 171 | + |
| 172 | + @classmethod |
| 173 | + async def _load( |
| 174 | + cls, |
| 175 | + context: StreamFlowContext, |
| 176 | + row: MutableMapping[str, Any], |
| 177 | + loading_context: DatabaseLoadingContext, |
| 178 | + ): |
| 179 | + return cls(name=row["name"]) |
| 180 | + |
| 181 | + async def _save_additional_params( |
| 182 | + self, context: StreamFlowContext |
| 183 | + ) -> MutableMapping[str, Any]: |
| 184 | + return {"name": self.name} |
| 185 | + |
| 186 | + @abstractmethod |
| 187 | + def bind( |
| 188 | + self, |
| 189 | + token: Token | None, |
| 190 | + position: int | None, |
| 191 | + options: CommandOptions, |
| 192 | + ) -> CommandToken: ... |
| 193 | + |
| 194 | + @abstractmethod |
| 195 | + def check_type(self, token: Token) -> bool: ... |
| 196 | + |
| 197 | + @classmethod |
| 198 | + async def load( |
| 199 | + cls, |
| 200 | + context: StreamFlowContext, |
| 201 | + row: MutableMapping[str, Any], |
| 202 | + loading_context: DatabaseLoadingContext, |
| 203 | + ) -> CommandTokenProcessor: |
| 204 | + type_ = cast(type[CommandTokenProcessor], get_class_from_name(row["type"])) |
| 205 | + return await type_._load(context, row["params"], loading_context) |
| 206 | + |
| 207 | + async def save(self, context: StreamFlowContext): |
| 208 | + return { |
| 209 | + "type": get_class_fullname(type(self)), |
| 210 | + "params": await self._save_additional_params(context), |
| 211 | + } |
| 212 | + |
| 213 | + |
26 | 214 | class Executor(ABC):
|
27 | 215 | def __init__(self, workflow: Workflow):
|
28 | 216 | self.workflow: Workflow = workflow
|
@@ -514,6 +702,7 @@ def __init__(
|
514 | 702 | self.context: StreamFlowContext = context
|
515 | 703 | self.config: MutableMapping[str, Any] = config
|
516 | 704 | self.name: str = name if name is not None else str(uuid.uuid4())
|
| 705 | + self.input_ports: MutableMapping[str, str] = {} |
517 | 706 | self.output_ports: MutableMapping[str, str] = {}
|
518 | 707 | self.ports: MutableMapping[str, Port] = {}
|
519 | 708 | self.steps: MutableMapping[str, Step] = {}
|
@@ -548,6 +737,12 @@ def create_step(self, cls: type[S], name: str = None, **kwargs) -> S:
|
548 | 737 | self.steps[name] = step
|
549 | 738 | return step
|
550 | 739 |
|
| 740 | + def get_input_port(self, name: str) -> Port: |
| 741 | + return self.ports[self.input_ports[name]] |
| 742 | + |
| 743 | + def get_input_ports(self) -> MutableMapping[str, Port]: |
| 744 | + return {name: self.ports[p] for name, p in self.input_ports.items()} |
| 745 | + |
551 | 746 | def get_output_port(self, name: str) -> Port:
|
552 | 747 | return self.ports[self.output_ports[name]]
|
553 | 748 |
|
|
0 commit comments