Skip to content

Commit b055151

Browse files
authored
Add context to execution_date_fn in ExternalTaskSensor (apache#8702)
Co-authored-by: Ace Haidrey <ahaidrey@pinterest.com>
1 parent 9a4a2d1 commit b055151

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

airflow/sensors/external_task_sensor.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def poke(self, context, session=None):
119119
if self.execution_delta:
120120
dttm = context['execution_date'] - self.execution_delta
121121
elif self.execution_date_fn:
122-
dttm = self.execution_date_fn(context['execution_date'])
122+
dttm = self._handle_execution_date_fn(context=context)
123123
else:
124124
dttm = context['execution_date']
125125

@@ -204,6 +204,26 @@ def get_count(self, dttm_filter, session, states):
204204
).scalar()
205205
return count
206206

207+
def _handle_execution_date_fn(self, context):
208+
"""
209+
This function is to handle backwards compatibility with how this operator was
210+
previously where it only passes the execution date, but also allow for the newer
211+
implementation to pass all context through as well, to allow for more sophisticated
212+
returns of dates to return.
213+
Namely, this function check the number of arguments in the execution_date_fn
214+
signature and if its 1, treat the legacy way, if it's 2, pass the context as
215+
the 2nd argument, and if its more, throw an exception.
216+
"""
217+
num_fxn_params = self.execution_date_fn.__code__.co_argcount
218+
if num_fxn_params == 1:
219+
return self.execution_date_fn(context['execution_date'])
220+
elif num_fxn_params == 2:
221+
return self.execution_date_fn(context['execution_date'], context)
222+
else:
223+
raise AirflowException(
224+
'execution_date_fn passed {} args but only allowed up to 2'.format(num_fxn_params)
225+
)
226+
207227

208228
class ExternalTaskMarker(DummyOperator):
209229
"""

tests/sensors/test_external_task_sensor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,28 @@ def test_external_task_sensor_fn(self):
308308
ignore_ti_state=True
309309
)
310310

311+
def test_external_task_sensor_fn_multiple_args(self):
312+
"""Check this task sensor passes multiple args with full context. If no failure, means clean run."""
313+
self.test_time_sensor()
314+
315+
def my_func(dt, context):
316+
assert context['execution_date'] == dt
317+
return dt + timedelta(0)
318+
319+
op1 = ExternalTaskSensor(
320+
task_id='test_external_task_sensor_multiple_arg_fn',
321+
external_dag_id=TEST_DAG_ID,
322+
external_task_id=TEST_TASK_ID,
323+
execution_date_fn=my_func,
324+
allowed_states=['success'],
325+
dag=self.dag
326+
)
327+
op1.run(
328+
start_date=DEFAULT_DATE,
329+
end_date=DEFAULT_DATE,
330+
ignore_ti_state=True
331+
)
332+
311333
def test_external_task_sensor_error_delta_and_fn(self):
312334
self.test_time_sensor()
313335
# Test that providing execution_delta and a function raises an error

0 commit comments

Comments
 (0)