Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion sky/client/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,14 @@ def upload_mounts_to_api_server(
"""

if server_common.is_api_server_local():
return dag, None
# Propagate blob id if any
preserved_blob_id: Optional[str] = None
for task_ in dag.tasks:
tid = getattr(task_, 'file_mounts_blob_id', None)
if tid is not None:
preserved_blob_id = tid
break
return dag, preserved_blob_id

def _full_path(src: str) -> str:
return os.path.abspath(os.path.expanduser(src))
Expand Down
4 changes: 4 additions & 0 deletions sky/server/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,10 @@ def _get_client_file_mounts_path(
yaml_utils.dump_yaml(str(translated_client_task_path), task_configs)

dag = dag_utils.load_dag_from_yaml(str(translated_client_task_path))
# Propagate blob id incase of resubmission in job controller
if file_mounts_blob_id is not None:
for task_in_dag in dag.tasks:
task_in_dag.file_mounts_blob_id = file_mounts_blob_id
Comment thread
aylei marked this conversation as resolved.
Outdated
return dag


Expand Down
6 changes: 6 additions & 0 deletions sky/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def __init__(
_volume_mounts: Optional[List[volume_lib.VolumeMount]] = None,
_metadata: Optional[Dict[str, Any]] = None,
_user_specified_yaml: Optional[str] = None,
_file_mounts_blob_id: Optional[str] = None,
):
"""Initializes a Task.

Expand Down Expand Up @@ -469,6 +470,9 @@ def _concat(commands: Optional[Union[str, List[str]]]) -> Optional[str]:
# For internal use only.
self.file_mounts_mapping: Optional[Dict[str,
str]] = _file_mounts_mapping
# Blob id of the uploaded file_mounts to handle server-internal task
# re-submissions (e.g. the jobs controller re-launch).
self.file_mounts_blob_id: Optional[str] = _file_mounts_blob_id
self.volume_mounts: Optional[List[volume_lib.VolumeMount]] = (
_volume_mounts)

Expand Down Expand Up @@ -759,6 +763,7 @@ def from_yaml_config(
event_callback=config.pop('event_callback', None),
api_server_access=config.pop('api_server_access', True),
_file_mounts_mapping=config.pop('file_mounts_mapping', None),
_file_mounts_blob_id=config.pop('file_mounts_blob_id', None),
_metadata=config.pop('_metadata', None),
_user_specified_yaml=user_specified_yaml,
)
Expand Down Expand Up @@ -1978,6 +1983,7 @@ def add_if_not_none(key, value, no_empty: bool = False):
})

add_if_not_none('file_mounts_mapping', self.file_mounts_mapping)
add_if_not_none('file_mounts_blob_id', self.file_mounts_blob_id)
add_if_not_none('volumes', self.volumes)
if self.volume_mounts is not None:
config['volume_mounts'] = [
Expand Down
3 changes: 3 additions & 0 deletions sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,9 @@ def get_task_schema():
'file_mounts_mapping': {
'type': 'object',
},
'file_mounts_blob_id': {
'type': 'string',
},
'config': _filter_schema(
get_config_schema(),
constants.OVERRIDEABLE_CONFIG_KEYS_IN_TASK),
Expand Down
37 changes: 37 additions & 0 deletions tests/unit_tests/test_sky/client/test_common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Unit tests for sky/client/common.py."""
import os
import re
from unittest import mock

import sky
from sky.client import common as client_common
from sky.client.common import _compute_zip_blob_id
from sky.data import storage_utils

Expand Down Expand Up @@ -109,3 +112,37 @@ def test_blob_id_empty_input(tmp_path):
hash_val = _compute_zip_blob_id(zip_path)
assert len(hash_val) == 64
assert re.fullmatch(r'[0-9a-f]{64}', hash_val)


@mock.patch('sky.client.common.server_common.is_api_server_local',
return_value=True)
def test_upload_mounts_local_preserves_stashed_blob_id(_mock_is_local):
"""Local server short-circuit must return a blob id stashed on the task.

This is what keeps the inner request (produced by a server-internal
re-submission, e.g. consolidation-mode jobs controller) carrying the
original blob id, so a replica dequeuing that request can re-resolve
the blob directory locally.
"""
blob_id = 'f' * 64
dag = sky.Dag()
t = sky.Task(name='t', run='echo hi')
t.file_mounts_blob_id = blob_id
dag.add(t)

out_dag, out_blob_id = client_common.upload_mounts_to_api_server(dag)
assert out_dag is dag
assert out_blob_id == blob_id


@mock.patch('sky.client.common.server_common.is_api_server_local',
return_value=True)
def test_upload_mounts_local_no_stash_returns_none(_mock_is_local):
"""User-authored Tasks never set file_mounts_blob_id; the local
short-circuit must return None, preserving pre-fix behavior."""
dag = sky.Dag()
dag.add(sky.Task(name='t', run='echo hi'))

out_dag, out_blob_id = client_common.upload_mounts_to_api_server(dag)
assert out_dag is dag
assert out_blob_id is None
68 changes: 68 additions & 0 deletions tests/unit_tests/test_sky/server/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,71 @@ def test_process_mounts_without_mapping(tmp_path, monkeypatch):
# Verify the dag was created successfully
assert dag is not None
assert len(dag.tasks) == 1


def test_process_mounts_stashes_blob_id_on_tasks(tmp_path, monkeypatch):
"""Blob id passed in should be stashed onto each task in the returned dag.

This lets server-internal re-submissions (e.g. the jobs controller in
consolidation mode calling sky.launch against the local API server)
propagate the blob id onto the new request's DB row, so a replica that
dequeues the inner request can re-resolve the blob directory locally.
"""
from sky.skylet import constants as skylet_constants

api_server_dir = tmp_path / 'api_server_clients'
monkeypatch.setattr('sky.server.common.API_SERVER_CLIENT_DIR',
api_server_dir)

blob_id = 'a' * 64
blob_dir = tmp_path / 'blobs' / blob_id
blob_dir.mkdir(parents=True)

monkeypatch.setattr('sky.server.common.resolve_blob_dir',
lambda bid, uh: str(blob_dir))

task_yaml = '''
name: test-task
resources:
cloud: aws
run: echo "hello"
'''

env_vars = {skylet_constants.USER_ID_ENV_VAR: 'test-user'}

dag = common.process_mounts_in_task_on_api_server(
task=task_yaml,
env_vars=env_vars,
workdir_only=False,
file_mounts_blob_id=blob_id)

assert len(dag.tasks) >= 1
for task in dag.tasks:
assert task.file_mounts_blob_id == blob_id


def test_process_mounts_no_blob_id_leaves_tasks_unstashed(
tmp_path, monkeypatch):
"""When called without a blob id, tasks in the returned dag must not have
``file_mounts_blob_id`` set. Regression guard for the standalone path."""
from sky.skylet import constants as skylet_constants

api_server_dir = tmp_path / 'api_server_clients'
monkeypatch.setattr('sky.server.common.API_SERVER_CLIENT_DIR',
api_server_dir)

task_yaml = '''
name: test-task
resources:
cloud: aws
run: echo "hello"
'''

env_vars = {skylet_constants.USER_ID_ENV_VAR: 'test-user'}

dag = common.process_mounts_in_task_on_api_server(task=task_yaml,
env_vars=env_vars,
workdir_only=False)

for task in dag.tasks:
assert task.file_mounts_blob_id is None
22 changes: 22 additions & 0 deletions tests/unit_tests/test_sky/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,3 +1336,25 @@ def test_multinode_rwx_volume_passes():
}
t.resolve_and_validate_volumes()
assert len(t.volume_mounts) == 1


def test_file_mounts_blob_id_roundtrip():
"""``file_mounts_blob_id`` should survive YAML round-trips."""
blob_id = 'a' * 64
t = task.Task(name='t', run='echo hi', _file_mounts_blob_id=blob_id)
assert t.file_mounts_blob_id == blob_id

config = t.to_yaml_config()
assert config.get('file_mounts_blob_id') == blob_id

t2 = task.Task.from_yaml_config(config)
assert t2.file_mounts_blob_id == blob_id


def test_file_mounts_blob_id_default_none():
"""User-authored Tasks have ``file_mounts_blob_id`` unset; the field must
not appear in the serialized YAML."""
t = task.Task(name='t', run='echo hi')
assert t.file_mounts_blob_id is None
config = t.to_yaml_config()
assert 'file_mounts_blob_id' not in config
Loading