Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions lib/galaxy/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Callable,
Optional,
TYPE_CHECKING,
Union,
)

import yaml
Expand Down Expand Up @@ -1019,7 +1020,7 @@ def __init__(
app: MinimalManagerApp,
use_persisted_destination: bool = False,
tool: Optional["Tool"] = None,
):
) -> None:
self.job_id = job.id
self.session_id = job.session_id
self.user_id = job.user_id
Expand All @@ -1029,7 +1030,7 @@ def __init__(
self.extra_filenames: list[str] = []
self.environment_variables: list[dict[str, str]] = []
self.interactivetools: list[dict[str, Any]] = []
self.command_line = None
self.command_line: Union[str, None] = None
self.version_command_line = None
self._dependency_shell_commands = None
# Tool versioning variables
Expand Down Expand Up @@ -2826,26 +2827,19 @@ class TaskWrapper(JobWrapper):

is_task = True

def __init__(self, task, queue):
def __init__(self, task: Task, queue: "BaseJobHandlerQueue") -> None:
self.task_id = task.id
super().__init__(task.job, queue)
if task.prepare_input_files_cmd is not None:
self.prepare_input_files_cmds = [task.prepare_input_files_cmd]
else:
self.prepare_input_files_cmds = None
self.prepare_input_files_cmds = (
[task.prepare_input_files_cmd] if task.prepare_input_files_cmd is not None else None
)
self.status = task.states.NEW

def can_split(self):
    """Return False: a TaskWrapper represents an already-split task."""
    # Should the job handler split this job up? TaskWrapper should
    # always return False as the job has already been split.
    return False

def get_job(self):
    """Fetch the parent Job record for this task, or None.

    Looks the job up by primary key via the SQLAlchemy session;
    returns None when no job id was recorded on this wrapper.
    """
    if self.job_id:
        return self.sa_session.get(Job, self.job_id)
    else:
        return None

def get_task(self):
    """Fetch this wrapper's Task record by primary key via the SQLAlchemy session."""
    return self.sa_session.get(Task, self.task_id)

Expand Down
54 changes: 36 additions & 18 deletions lib/galaxy/jobs/runners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
)
from typing import (
Any,
Generic,
Optional,
TYPE_CHECKING,
TypeVar,
Union,
)

Expand Down Expand Up @@ -325,7 +327,7 @@ def queue_job(self, job_wrapper: "MinimalJobWrapper") -> None:
def stop_job(self, job_wrapper):
raise NotImplementedError()

def recover(self, job, job_wrapper):
def recover(self, job: model.Job, job_wrapper: "MinimalJobWrapper") -> None:
raise NotImplementedError()

def build_command_line(
Expand Down Expand Up @@ -591,9 +593,15 @@ def _handle_runner_state(self, runner_state, job_state: "JobState"):
except Exception:
log.exception("Caught exception in runner state handler")

def fail_job(self, job_state: "JobState", exception=False, message="Job failed", full_status=None):
def fail_job(
self,
job_state: "JobState",
exception: bool = False,
message: str = "Job failed",
full_status: Union[dict[str, Any], None] = None,
) -> None:
job = job_state.job_wrapper.get_job()
if getattr(job_state, "stop_job", True) and job.state != model.Job.states.NEW:
if job_state.stop_job and job.state != model.Job.states.NEW:
self.stop_job(job_state.job_wrapper)
job_state.job_wrapper.reclaim_ownership()
self._handle_runner_state("failure", job_state)
Expand Down Expand Up @@ -705,13 +713,14 @@ class JobState:

runner_states = runner_states

def __init__(self, job_wrapper: "JobWrapper", job_destination: "JobDestination"):
def __init__(self, job_wrapper: "MinimalJobWrapper", job_destination: "JobDestination") -> None:
self.runner_state_handled = False
self.job_wrapper = job_wrapper
self.job_destination = job_destination
self.runner_state = None
self.redact_email_in_job_name = True
self._exit_code_file = None
self.stop_job = True
if self.job_wrapper:
self.redact_email_in_job_name = self.job_wrapper.app.config.redact_email_in_job_name

Expand Down Expand Up @@ -765,23 +774,26 @@ class AsynchronousJobState(JobState):
to communicate with distributed resource manager.
"""

old_state: Union["JobStateEnum", None]

def __init__(
self,
job_wrapper: "MinimalJobWrapper",
job_destination: "JobDestination",
*,
files_dir=None,
job_wrapper=None,
job_id: Union[str, None] = None,
job_file=None,
output_file=None,
error_file=None,
exit_code_file=None,
job_name=None,
job_destination=None,
):
) -> None:
super().__init__(job_wrapper, job_destination)
self.old_state: Union[JobStateEnum, None] = None
self.old_state = None
self._running = False
self.check_count = 0
self.start_time = None
self.start_time: Union[datetime.datetime, None] = None

# job_id is the DRM's job id, not the Galaxy job id
self.job_id = job_id
Expand All @@ -796,11 +808,11 @@ def __init__(
self.set_defaults(files_dir)

@property
def running(self):
def running(self) -> bool:
return self._running

@running.setter
def running(self, is_running):
def running(self, is_running: bool) -> None:
self._running = is_running
# This will be invalid for job recovery
if self.start_time is None:
Expand Down Expand Up @@ -834,22 +846,28 @@ def init_job_stream_files(self):
pass


class AsynchronousJobRunner(BaseJobRunner, Monitors):
T = TypeVar("T", bound=AsynchronousJobState)


class AsynchronousJobRunner(BaseJobRunner, Monitors, Generic[T]):
"""Parent class for any job runner that runs jobs asynchronously (e.g. via
a distributed resource manager). Provides general methods for having a
thread to monitor the state of asynchronous jobs and submitting those jobs
to the correct methods (queue, finish, cleanup) at appropriate times..
"""

monitor_queue: Queue[T]
watched: list[T]

def __init__(self, app: "GalaxyManagerApplication", nworkers: int, **kwargs) -> None:
super().__init__(app, nworkers, **kwargs)
# 'watched' and 'queue' are both used to keep track of jobs to watch.
# 'queue' is used to add new watched jobs, and can be called from
# any thread (usually by the 'queue_job' method). 'watched' must only
# be modified by the monitor thread, which will move items from 'queue'
# to 'watched' and then manage the watched jobs.
self.watched: list[AsynchronousJobState] = []
self.monitor_queue: Queue[AsynchronousJobState] = Queue()
self.watched = []
self.monitor_queue = Queue()

def _init_monitor_thread(self):
name = f"{self.runner_name}.monitor_thread"
Expand Down Expand Up @@ -892,7 +910,7 @@ def monitor(self):
# Sleep a bit before the next state check
time.sleep(self.app.config.job_runner_monitor_sleep)

def monitor_job(self, job_state: AsynchronousJobState) -> None:
def monitor_job(self, job_state: T) -> None:
self.monitor_queue.put(job_state)

def shutdown(self):
Expand All @@ -903,7 +921,7 @@ def shutdown(self):
self.shutdown_monitor()
super().shutdown()

def check_watched_items(self):
def check_watched_items(self) -> None:
"""
This method is responsible for iterating over self.watched and handling
state changes and updating self.watched with a new list of watched job
Expand All @@ -919,7 +937,7 @@ def check_watched_items(self):
self.watched = new_watched

# Subclasses should implement this unless they override check_watched_items all together.
def check_watched_item(self, job_state: AsynchronousJobState) -> Union[AsynchronousJobState, None]:
def check_watched_item(self, job_state: T) -> Union[T, None]:
raise NotImplementedError()

def _collect_job_output(self, job_id: int, external_job_id: Optional[str], job_state: JobState):
Expand All @@ -943,7 +961,7 @@ def _collect_job_output(self, job_id: int, external_job_id: Optional[str], job_s
which_try += 1
return collect_output_success, stdout, stderr

def finish_job(self, job_state: AsynchronousJobState):
def finish_job(self, job_state: T) -> None:
"""
Get the output/error for a finished job, pass to `job_wrapper.finish`
and cleanup all the job's temporary files.
Expand Down
22 changes: 12 additions & 10 deletions lib/galaxy/jobs/runners/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import time
from queue import Empty
from typing import (
Any,
TYPE_CHECKING,
Union,
)

from galaxy import model
Expand Down Expand Up @@ -81,7 +83,7 @@ def _add_resource_requirements(destination_params):
return rval


class AWSBatchJobRunner(AsynchronousJobRunner):
class AWSBatchJobRunner(AsynchronousJobRunner[AsynchronousJobState]):
"""
This runner uses container only. It requires that an AWS EFS is mounted as a local drive
and all Galaxy job-related paths, such as objects, job_directory, tool_directory and so
Expand Down Expand Up @@ -213,7 +215,7 @@ def __init__(self, app, nworkers, **kwargs):
)
self._batch_client = session.client("batch")

def queue_job(self, job_wrapper):
def queue_job(self, job_wrapper: "MinimalJobWrapper") -> None:
log.debug(f"Starting queue_job for job {job_wrapper.get_id_tag()}")
if not self.prepare_job(job_wrapper, include_metadata=False, modify_command_for_container=False):
log.debug(f"Not ready {job_wrapper.get_id_tag()}")
Expand All @@ -225,11 +227,11 @@ def queue_job(self, job_wrapper):
job_name, job_id = self._submit_job(job_def, job_wrapper, destination_params)
job_wrapper.set_external_id(job_id)
ajs = AsynchronousJobState(
files_dir=job_wrapper.working_directory,
job_wrapper=job_wrapper,
job_destination=job_destination,
files_dir=job_wrapper.working_directory,
job_name=job_name,
job_id=job_id,
job_destination=job_destination,
)
self.monitor_queue.put(ajs)

Expand Down Expand Up @@ -395,16 +397,16 @@ def stop_job(self, job_wrapper):
msg = "Job {name!r} is terminated"
log.debug(msg.format(name=job_name))

def recover(self, job, job_wrapper):
def recover(self, job: model.Job, job_wrapper: "MinimalJobWrapper") -> None:
msg = "(name!r/runner!r) is still in {state!s} state, adding to the runner monitor queue"
job_id = job.get_job_runner_external_id()
job_name = self.JOB_NAME_PREFIX + job_wrapper.get_id_tag()
ajs = AsynchronousJobState(
files_dir=job_wrapper.working_directory,
job_wrapper=job_wrapper,
job_destination=job_wrapper.job_destination,
files_dir=job_wrapper.working_directory,
job_id=str(job_id),
job_name=job_name,
job_destination=job_wrapper.job_destination,
)
if job.state in (model.Job.states.RUNNING, model.Job.states.STOPPED):
log.debug(msg.format(name=job.id, runner=job.job_runner_name, state=job.state))
Expand All @@ -417,9 +419,9 @@ def recover(self, job, job_wrapper):
ajs.running = False
self.monitor_queue.put(ajs)

def fail_job(self, job_state: JobState, exception=False, message="Job failed", full_status=None):
def fail_job(self, job_state: JobState, exception: bool = False, message: str = "Job failed", full_status: Union[dict[str, Any], None] = None) -> None:
job = job_state.job_wrapper.get_job()
if getattr(job_state, "stop_job", True) and job.state != model.Job.states.NEW:
if job_state.stop_job and job.state != model.Job.states.NEW:
self.stop_job(job_state.job_wrapper)
job_state.job_wrapper.reclaim_ownership()
self._handle_runner_state("failure", job_state)
Expand Down Expand Up @@ -460,7 +462,7 @@ def monitor(self):
# Sleep a bit before the next state check
time.sleep(max(self.app.config.job_runner_monitor_sleep, self.MIN_QUERY_INTERVAL))

def check_watched_items(self):
def check_watched_items(self) -> None:
done: set[str] = set()
self.check_watched_items_by_batch(0, len(self.watched), done)
self.watched = [ajs for ajs in self.watched if ajs.job_id not in done]
Expand Down
27 changes: 16 additions & 11 deletions lib/galaxy/jobs/runners/chronos.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import functools
import logging
import os
from typing import Union
from typing import (
TYPE_CHECKING,
Union,
)

from galaxy import model
from galaxy.jobs.runners import (
Expand All @@ -10,6 +13,9 @@
)
from galaxy.util import unicodify

if TYPE_CHECKING:
from galaxy.jobs import MinimalJobWrapper

CHRONOS_IMPORT_MSG = (
"The Python 'chronos' package is required to use "
"this feature, please install it or correct the "
Expand Down Expand Up @@ -82,7 +88,7 @@ def _add_galaxy_environment_variables(cpus, memory):
return [{"name": "GALAXY_SLOTS", "value": cpus}, {"name": "GALAXY_MEMORY_MB", "value": memory}]


class ChronosJobRunner(AsynchronousJobRunner):
class ChronosJobRunner(AsynchronousJobRunner[AsynchronousJobState]):
runner_name = "ChronosRunner"
RUNNER_PARAM_SPEC_KEY = "runner_param_specs"
JOB_NAME_PREFIX = "galaxy-chronos-"
Expand Down Expand Up @@ -148,7 +154,7 @@ def __init__(self, app, nworkers, **kwargs):
)

@handle_exception_call
def queue_job(self, job_wrapper):
def queue_job(self, job_wrapper: "MinimalJobWrapper") -> None:
LOGGER.debug(f"Starting queue_job for job {job_wrapper.get_id_tag()}")
if not self.prepare_job(job_wrapper, include_metadata=False, modify_command_for_container=False):
LOGGER.debug(f"Not ready {job_wrapper.get_id_tag()}")
Expand All @@ -158,10 +164,10 @@ def queue_job(self, job_wrapper):
job_name = chronos_job_spec["name"]
self._chronos_client.add(chronos_job_spec)
ajs = AsynchronousJobState(
files_dir=job_wrapper.working_directory,
job_wrapper=job_wrapper,
job_id=job_name,
job_destination=job_destination,
files_dir=job_wrapper.working_directory,
job_id=job_name,
)
self.monitor_queue.put(ajs)

Expand All @@ -178,16 +184,15 @@ def stop_job(self, job_wrapper):
msg = "Job {name!r} not found. It cannot be terminated."
LOGGER.error(msg.format(name=job_name))

def recover(self, job, job_wrapper):
def recover(self, job: model.Job, job_wrapper: "MinimalJobWrapper") -> None:
msg = "(name!r/runner!r) is still in {state!s} state, adding to the runner monitor queue"
job_id = job.get_job_runner_external_id()
ajs = AsynchronousJobState(
files_dir=job_wrapper.working_directory,
job_wrapper=job_wrapper,
job_id=self.JOB_NAME_PREFIX + str(job_id),
job_destination=job_wrapper.job_destination,
files_dir=job_wrapper.working_directory,
job_id=self.JOB_NAME_PREFIX + str(job_id),
)
ajs.command_line = job.command_line
if job.state in (model.Job.states.RUNNING, model.Job.states.STOPPED):
LOGGER.debug(msg.format(name=job.id, runner=job.job_runner_external_id, state=job.state))
ajs.old_state = model.Job.states.RUNNING
Expand Down Expand Up @@ -241,14 +246,14 @@ def _mark_as_active(self, job_state: AsynchronousJobState) -> AsynchronousJobSta
def _mark_as_failed(self, job_state: AsynchronousJobState, reason: str) -> None:
_write_logfile(job_state.error_file, reason)
job_state.running = False
job_state.stop_job = True
job_state.stop_job = False
Copy link

Copilot AI Oct 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assignment should be job_state.stop_job = True to properly stop the job, not False.

Suggested change
job_state.stop_job = False
job_state.stop_job = True

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this signals to the stop_job() method that the job on the cluster has already terminated and doesn't need to be stopped.

job_state.job_wrapper.change_state(model.Job.states.ERROR)
job_state.fail_message = reason
self.mark_as_failed(job_state)
return None

@handle_exception_call
def finish_job(self, job_state):
def finish_job(self, job_state: AsynchronousJobState) -> None:
super().finish_job(job_state)
self._chronos_client.delete(job_state.job_id)

Expand Down
Loading
Loading