diff --git a/contributing/PIPELINES.md b/contributing/PIPELINES.md index 89037bb6b..e41949b4f 100644 --- a/contributing/PIPELINES.md +++ b/contributing/PIPELINES.md @@ -37,6 +37,55 @@ Brief checklist for implementing a new pipeline: 8. Register the pipeline in `PipelineManager` and hint fetch from services after commit via `pipeline_hinter.hint_fetch(Model.__name__)`. 9. Add minimum tests: fetch eligibility/order, successful unlock path, stale lock token path, and related lock contention retry path. +## Typical worker structure + +Most workers are easiest to reason about when `process()` is split into three phases: + +1. Load/refetch: open a short DB session, refetch the locked main row by `id + lock_token`, lock any required related rows, and gather any extra data needed for processing. +2. Process: do the heavy work outside DB sessions and build result objects or update maps instead of mutating detached ORM models. +3. Apply: open a short DB session, guard the main update by `id + lock_token`, resolve time placeholders, apply related updates, emit events, and unlock rows. + +A dedicated context object is often useful for the load step when the worker needs multiple loaded models, related lock metadata, or derived values that should be passed cleanly into processing and apply. For very small pipelines, a direct load -> process -> apply flow may still be clearer. 
+ +Workers can share one context type and one apply function across all states even if the processing logic differs by state: + +```python +async def process(item): + context = await _load_process_context(item) + if context is None: + return + result = await _process_item(context) + await _apply_process_result(item, context, result) +``` + +Sometimes state-specific helpers are still the cleanest option, but they can still share a common apply phase if all states write results in the same general shape: + +```python +async def process(item): + if item.status == Status.PENDING: + context = await _load_pending_context(item) + elif item.status == Status.RUNNING: + context = await _load_running_context(item) + else: + return + if context is None: + return + result = await _process_item(context) + await _apply_process_result(item, context, result) +``` + +If different states have materially different write-side behavior, different apply paths are fine as well. This commonly happens when one state does a normal guarded update while another does delete-or-cleanup work with different related updates: + +```python +async def process(item): + if item.to_be_deleted: + await _process_to_be_deleted_item(item) + elif item.status == Status.SUBMITTED: + await _process_submitted_item(item) +``` + +It's ok not to force all pipelines into one exact shape. 
+ ## Implementation patterns **Guarded apply by lock token** diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index 7b2d79047..3d7efae47 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -5,6 +5,7 @@ from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline from dstack._internal.server.background.pipeline_tasks.instances import InstancePipeline +from dstack._internal.server.background.pipeline_tasks.jobs_running import JobRunningPipeline from dstack._internal.server.background.pipeline_tasks.jobs_terminating import ( JobTerminatingPipeline, ) @@ -23,6 +24,7 @@ def __init__(self) -> None: ComputeGroupPipeline(), FleetPipeline(), GatewayPipeline(), + JobRunningPipeline(), JobTerminatingPipeline(), InstancePipeline(), PlacementGroupPipeline(), diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 3065c1e09..51f4230a8 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -196,106 +196,23 @@ def __init__( @sentry_utils.instrument_named_task("pipeline_tasks.FleetWorker.process") async def process(self, item: PipelineItem): - async with get_session_ctx() as session: - res = await session.execute( - select(FleetModel) - .where( - FleetModel.id == item.id, - FleetModel.lock_token == item.lock_token, - ) - .options(joinedload(FleetModel.project)) - .options( - selectinload(FleetModel.instances.and_(InstanceModel.deleted == False)) - .joinedload(InstanceModel.jobs) - .load_only(JobModel.id), - ) - .options( - selectinload( - 
FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) - ).load_only(RunModel.status) - ) - ) - fleet_model = res.unique().scalar_one_or_none() - if fleet_model is None: - log_lock_token_mismatch(logger, item) - return - - # Lock instance only if consolidation is needed. - locked_instance_ids: set[uuid.UUID] = set() - consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet_model) - consolidation_instances = None - if consolidation_fleet_spec is not None: - consolidation_instances = await _lock_fleet_instances_for_consolidation( - session=session, - item=item, - ) - if consolidation_instances is None: - return - locked_instance_ids = {instance.id for instance in consolidation_instances} - + process_context = await _load_process_context(item) + if process_context is None: + return result = await _process_fleet( - fleet_model, - consolidation_fleet_spec=consolidation_fleet_spec, - consolidation_instances=consolidation_instances, - ) - fleet_update_map = _FleetUpdateMap() - fleet_update_map.update(result.fleet_update_map) - set_processed_update_map_fields(fleet_update_map) - set_unlock_update_map_fields(fleet_update_map) - instance_update_rows = _build_instance_update_rows( - result.instance_id_to_update_map, - unlock_instance_ids=locked_instance_ids, + process_context.fleet_model, + consolidation_fleet_spec=process_context.consolidation_fleet_spec, + consolidation_instances=process_context.consolidation_instances, ) + await _apply_process_result(item, process_context, result) - async with get_session_ctx() as session: - now = get_current_datetime() - resolve_now_placeholders(fleet_update_map, now=now) - resolve_now_placeholders(instance_update_rows, now=now) - res = await session.execute( - update(FleetModel) - .where( - FleetModel.id == fleet_model.id, - FleetModel.lock_token == fleet_model.lock_token, - ) - .values(**fleet_update_map) - .returning(FleetModel.id) - ) - updated_ids = list(res.scalars().all()) - if 
len(updated_ids) == 0: - log_lock_token_changed_after_processing(logger, item) - if locked_instance_ids: - await _unlock_fleet_locked_instances( - session=session, - item=item, - locked_instance_ids=locked_instance_ids, - ) - # TODO: Clean up fleet. - return - - if fleet_update_map.get("deleted"): - await session.execute( - update(PlacementGroupModel) - .where(PlacementGroupModel.fleet_id == item.id) - .values(fleet_deleted=True) - ) - if instance_update_rows: - await session.execute( - update(InstanceModel), - instance_update_rows, - ) - if len(result.new_instance_creates) > 0: - await _create_missing_fleet_instances( - session=session, - fleet_model=fleet_model, - new_instance_creates=result.new_instance_creates, - ) - emit_fleet_status_change_event( - session=session, - fleet_model=fleet_model, - old_status=fleet_model.status, - new_status=fleet_update_map.get("status", fleet_model.status), - status_message=fleet_update_map.get("status_message", fleet_model.status_message), - ) + +@dataclass +class _ProcessContext: + fleet_model: FleetModel + consolidation_fleet_spec: Optional[FleetSpec] + consolidation_instances: Optional[list[InstanceModel]] + locked_instance_ids: set[uuid.UUID] = field(default_factory=set) class _FleetUpdateMap(ItemUpdateMap, total=False): @@ -318,6 +235,83 @@ class _InstanceUpdateMap(ItemUpdateMap, total=False): id: uuid.UUID +@dataclass +class _ProcessResult: + fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap) + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) + new_instance_creates: list["_NewInstanceCreate"] = field(default_factory=list) + + +class _NewInstanceCreate(TypedDict): + id: uuid.UUID + instance_num: int + + +@dataclass +class _MaintainNodesResult: + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) + new_instance_creates: list[_NewInstanceCreate] = field(default_factory=list) + changes_required: bool = False + + 
@property + def has_changes(self) -> bool: + return len(self.instance_id_to_update_map) > 0 or len(self.new_instance_creates) > 0 + + +async def _load_process_context(item: PipelineItem) -> Optional[_ProcessContext]: + async with get_session_ctx() as session: + fleet_model = await _refetch_locked_fleet(session=session, item=item) + if fleet_model is None: + log_lock_token_mismatch(logger, item) + return None + + consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet_model) + consolidation_instances = None + if consolidation_fleet_spec is not None: + consolidation_instances = await _lock_fleet_instances_for_consolidation( + session=session, + item=item, + ) + if consolidation_instances is None: + return None + + return _ProcessContext( + fleet_model=fleet_model, + consolidation_fleet_spec=consolidation_fleet_spec, + consolidation_instances=consolidation_instances, + locked_instance_ids=( + set() + if consolidation_instances is None + else {i.id for i in consolidation_instances} + ), + ) + + +async def _refetch_locked_fleet( + session: AsyncSession, + item: PipelineItem, +) -> Optional[FleetModel]: + res = await session.execute( + select(FleetModel) + .where( + FleetModel.id == item.id, + FleetModel.lock_token == item.lock_token, + ) + .options(joinedload(FleetModel.project)) + .options( + selectinload(FleetModel.instances.and_(InstanceModel.deleted == False)) + .joinedload(InstanceModel.jobs) + .load_only(JobModel.id), + ) + .options( + selectinload( + FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) + ).load_only(RunModel.status) + ) + ) + return res.unique().scalar_one_or_none() + + def _get_fleet_spec_if_ready_for_consolidation(fleet_model: FleetModel) -> Optional[FleetSpec]: if fleet_model.status == FleetStatus.TERMINATING: return None @@ -398,27 +392,71 @@ async def _lock_fleet_instances_for_consolidation( return locked_instance_models -@dataclass -class _ProcessResult: - fleet_update_map: _FleetUpdateMap = 
field(default_factory=_FleetUpdateMap) - instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) - new_instance_creates: list["_NewInstanceCreate"] = field(default_factory=list) - - -class _NewInstanceCreate(TypedDict): - id: uuid.UUID - instance_num: int - - -@dataclass -class _MaintainNodesResult: - instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) - new_instance_creates: list[_NewInstanceCreate] = field(default_factory=list) - changes_required: bool = False +async def _apply_process_result( + item: PipelineItem, + context: _ProcessContext, + result: "_ProcessResult", +) -> None: + fleet_update_map = _FleetUpdateMap() + fleet_update_map.update(result.fleet_update_map) + set_processed_update_map_fields(fleet_update_map) + set_unlock_update_map_fields(fleet_update_map) + instance_update_rows = _build_instance_update_rows( + result.instance_id_to_update_map, + unlock_instance_ids=context.locked_instance_ids, + ) - @property - def has_changes(self) -> bool: - return len(self.instance_id_to_update_map) > 0 or len(self.new_instance_creates) > 0 + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(fleet_update_map, now=now) + resolve_now_placeholders(instance_update_rows, now=now) + res = await session.execute( + update(FleetModel) + .where( + FleetModel.id == context.fleet_model.id, + FleetModel.lock_token == context.fleet_model.lock_token, + ) + .values(**fleet_update_map) + .returning(FleetModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_after_processing(logger, item) + if context.locked_instance_ids: + await _unlock_fleet_locked_instances( + session=session, + item=item, + locked_instance_ids=context.locked_instance_ids, + ) + # TODO: Clean up fleet. 
+ return + + if fleet_update_map.get("deleted"): + await session.execute( + update(PlacementGroupModel) + .where(PlacementGroupModel.fleet_id == context.fleet_model.id) + .values(fleet_deleted=True) + ) + if instance_update_rows: + await session.execute( + update(InstanceModel), + instance_update_rows, + ) + if len(result.new_instance_creates) > 0: + await _create_missing_fleet_instances( + session=session, + fleet_model=context.fleet_model, + new_instance_creates=result.new_instance_creates, + ) + emit_fleet_status_change_event( + session=session, + fleet_model=context.fleet_model, + old_status=context.fleet_model.status, + new_status=fleet_update_map.get("status", context.fleet_model.status), + status_message=fleet_update_map.get( + "status_message", context.fleet_model.status_message + ), + ) async def _process_fleet( diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index b5289e05e..67f017619 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -238,8 +238,7 @@ async def process(self, item: InstancePipelineItem): process_context = await _process_terminating_item(item) if process_context is None: return - set_processed_update_map_fields(process_context.result.instance_update_map) - set_unlock_update_map_fields(process_context.result.instance_update_map) + await _apply_process_result( item=item, instance_model=process_context.instance_model, @@ -376,6 +375,9 @@ async def _apply_process_result( instance_model: InstanceModel, result: ProcessResult, ) -> None: + set_processed_update_map_fields(result.instance_update_map) + set_unlock_update_map_fields(result.instance_update_map) + async with get_session_ctx() as session: if result.health_check_create is not None: session.add(InstanceHealthCheckModel(**result.health_check_create)) 
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py new file mode 100644 index 000000000..eda36d811 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -0,0 +1,1608 @@ +import asyncio +import enum +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Dict, Iterable, Literal, Optional, Sequence, Union + +import httpx +from sqlalchemy import and_, func, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only + +from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.errors import GatewayError, SSHError +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import NetworkMode, RegistryAuth +from dstack._internal.core.models.configurations import DevEnvironmentConfiguration +from dstack._internal.core.models.files import FileArchiveMapping +from dstack._internal.core.models.instances import InstanceStatus, SSHConnectionParams +from dstack._internal.core.models.metrics import Metric +from dstack._internal.core.models.profiles import StartupOrder +from dstack._internal.core.models.repos import RemoteRepoCreds +from dstack._internal.core.models.runs import ( + ClusterInfo, + Job, + JobProvisioningData, + JobRuntimeData, + JobSpec, + JobStatus, + JobSubmission, + JobTerminationReason, + Run, + RunSpec, + RunStatus, +) +from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + ItemUpdateMap, + Pipeline, + PipelineItem, + Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, + resolve_now_placeholders, + 
set_processed_update_map_fields, + set_unlock_update_map_fields, +) +from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + FleetModel, + InstanceModel, + JobModel, + ProbeModel, + ProjectModel, + RepoModel, + RunModel, + UserModel, +) +from dstack._internal.server.schemas.runner import TaskStatus +from dstack._internal.server.services import events +from dstack._internal.server.services import files as files_services +from dstack._internal.server.services import logs as logs_services +from dstack._internal.server.services.backends.provisioning import ( + get_instance_specific_gpu_devices, + get_instance_specific_mounts, + resolve_provisioning_image_name, +) +from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.instances import ( + get_instance_remote_connection_info, + get_instance_ssh_private_keys, +) +from dstack._internal.server.services.jobs import ( + emit_job_status_change_event, + find_job, + get_job_attached_volumes, + get_job_runtime_data, + is_master_job, + job_model_to_job_submission, +) +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.metrics import get_job_metrics +from dstack._internal.server.services.repos import ( + get_code_model, + get_repo_creds, + repo_model_to_repo_head_with_creds, +) +from dstack._internal.server.services.runner import client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.server.services.runs import is_job_ready, run_model_to_run +from dstack._internal.server.services.secrets import get_project_secrets_mapping +from dstack._internal.server.services.storage import get_default_storage +from dstack._internal.server.utils import sentry_utils +from 
dstack._internal.utils.common import get_current_datetime, get_or_error, run_async +from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +JOB_DISCONNECTED_RETRY_TIMEOUT = timedelta(minutes=2) +"""The minimum time before terminating an active job in case of connectivity issues.""" + + +@dataclass +class JobRunningPipelineItem(PipelineItem): + status: JobStatus + + +class JobRunningPipeline(Pipeline[JobRunningPipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=10), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[JobRunningPipelineItem]( + model_type=JobModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = JobRunningFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + JobRunningWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return JobModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[JobRunningPipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[JobRunningPipelineItem]: + return self.__fetcher + + @property + def _workers(self) 
-> Sequence["JobRunningWorker"]: + return self.__workers + + +class JobRunningFetcher(Fetcher[JobRunningPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[JobRunningPipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[JobRunningPipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.JobRunningFetcher.fetch") + async def fetch(self, limit: int) -> list[JobRunningPipelineItem]: + job_lock, _ = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__) + async with job_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(JobModel) + .join(JobModel.run) + .where( + JobModel.status.in_( + [JobStatus.PROVISIONING, JobStatus.PULLING, JobStatus.RUNNING] + ), + RunModel.status.not_in([RunStatus.TERMINATING]), + JobModel.last_processed_at <= now - self._min_processing_interval, + or_( + JobModel.lock_expires_at.is_(None), + JobModel.lock_expires_at < now, + ), + or_( + JobModel.lock_owner.is_(None), + JobModel.lock_owner == JobRunningPipeline.__name__, + ), + ) + .order_by(JobModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True, of=JobModel) + .options( + load_only( + JobModel.id, + JobModel.lock_token, + JobModel.lock_expires_at, + JobModel.status, + ) + ) + ) + job_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for job_model in job_models: + prev_lock_expired = job_model.lock_expires_at is not None + job_model.lock_expires_at = lock_expires_at + job_model.lock_token = 
lock_token + job_model.lock_owner = JobRunningPipeline.__name__ + items.append( + JobRunningPipelineItem( + __tablename__=JobModel.__tablename__, + id=job_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + status=job_model.status, + ) + ) + await session.commit() + return items + + +class JobRunningWorker(Worker[JobRunningPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[JobRunningPipelineItem], + heartbeater: Heartbeater[JobRunningPipelineItem], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.JobRunningWorker.process") + async def process(self, item: JobRunningPipelineItem): + context = await _load_process_context(item=item) + if context is None: + log_lock_token_mismatch(logger, item) + return + + result = await _process_running_job(context=context) + await _apply_process_result( + item=item, + job_model=context.job_model, + result=result, + ) + + +@dataclass +class _ProcessContext: + job_model: JobModel + run_model: RunModel + run: Run + job: Job + job_submission: JobSubmission + job_provisioning_data: Optional[JobProvisioningData] + server_ssh_private_keys: Optional[tuple[str, Optional[str]]] = None + + @property + def repo_model(self) -> RepoModel: + return self.run_model.repo + + @property + def project(self) -> ProjectModel: + return self.run_model.project + + +class _JobUpdateMap(ItemUpdateMap, total=False): + status: JobStatus + termination_reason: Optional[JobTerminationReason] + termination_reason_message: Optional[str] + job_provisioning_data: Optional[str] + job_runtime_data: Optional[str] + runner_timestamp: Optional[int] + disconnected_at: Optional[datetime] + inactivity_secs: Optional[int] + exit_status: Optional[int] + registered: bool + + +@dataclass +class _RegisterReplicaResult: + gateway_target: Optional[events.Target] # None = no gateway + + +@dataclass +class _ProcessResult: + 
job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) + new_probe_models: list[ProbeModel] = field(default_factory=list) + replica_registration: Optional[_RegisterReplicaResult] = None # None = not registered yet + + +@dataclass +class _StartupContext: + cluster_info: ClusterInfo + volumes: list[Volume] + secrets: dict[str, str] + repo_creds: Optional[RemoteRepoCreds] + + +async def _load_process_context(item: JobRunningPipelineItem) -> Optional[_ProcessContext]: + async with get_session_ctx() as session: + job_model = await _refetch_locked_job_model(session=session, item=item) + if job_model is None: + return None + run_model = await _fetch_run_model(session=session, run_id=job_model.run_id) + run = run_model_to_run(run_model, include_sensitive=True) + job_submission = job_model_to_job_submission(job_model) + server_ssh_private_keys = get_instance_ssh_private_keys(get_or_error(job_model.instance)) + return _ProcessContext( + job_model=job_model, + run_model=run_model, + run=run, + job=find_job(run.jobs, job_model.replica_num, job_model.job_num), + job_submission=job_submission, + job_provisioning_data=job_submission.job_provisioning_data, + server_ssh_private_keys=server_ssh_private_keys, + ) + + +async def _process_running_job(context: _ProcessContext) -> _ProcessResult: + result = _ProcessResult() + if context.job_provisioning_data is None: + logger.error("%s: job_provisioning_data of an active job is None", fmt(context.job_model)) + _terminate_job( + job_model=context.job_model, + job_update_map=result.job_update_map, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message=( + "Unexpected server error: job_provisioning_data of an active job is None" + ), + ) + return result + + if context.job_model.status == JobStatus.PROVISIONING: + startup_context = await _prepare_startup_context(context=context, result=result) + if startup_context is None: + return result + await _process_provisioning_status( + 
context=context, startup_context=startup_context, result=result + ) + elif context.job_model.status == JobStatus.PULLING: + startup_context = await _prepare_startup_context(context=context, result=result) + if startup_context is None: + return result + await _process_pulling_status( + context=context, startup_context=startup_context, result=result + ) + elif context.job_model.status == JobStatus.RUNNING: + await _process_running_status(context=context, result=result) + + if _get_result_status(context.job_model, result) == JobStatus.RUNNING: + if context.job_model.status != JobStatus.RUNNING: + _initialize_running_job_probes( + job_model=context.job_model, + job=context.job, + result=result, + ) + await _maybe_register_replica(context=context, result=result) + await _check_gpu_utilization(context=context, result=result) + return result + + +async def _prepare_startup_context( + context: _ProcessContext, + result: _ProcessResult, +) -> Optional[_StartupContext]: + job_provisioning_data = get_or_error(context.job_provisioning_data) + + for other_job in context.run.jobs: + if ( + other_job.job_spec.replica_num == context.job.job_spec.replica_num + and other_job.job_submissions[-1].status == JobStatus.SUBMITTED + ): + logger.debug( + "%s: waiting for all jobs in the replica to be provisioned", + fmt(context.job_model), + ) + return None + + cluster_info = _get_cluster_info( + jobs=context.run.jobs, + replica_num=context.job.job_spec.replica_num, + job_provisioning_data=job_provisioning_data, + job_runtime_data=context.job_submission.job_runtime_data, + ) + + async with get_session_ctx() as session: + volumes = await get_job_attached_volumes( + session=session, + project=context.project, + run_spec=context.run.run_spec, + job_num=context.job.job_spec.job_num, + job_provisioning_data=job_provisioning_data, + ) + repo_creds_model = await get_repo_creds( + session=session, + repo=context.repo_model, + user=context.run_model.user, + ) + secrets = await 
get_project_secrets_mapping(session=session, project=context.project) + + repo_creds = repo_model_to_repo_head_with_creds( + context.repo_model, + repo_creds_model, + ).repo_creds + + try: + _interpolate_secrets(secrets, context.job.job_spec) + except InterpolatorError as e: + _terminate_job( + job_model=context.job_model, + job_update_map=result.job_update_map, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message=f"Secrets interpolation error: {e.args[0]}", + ) + return None + + return _StartupContext( + cluster_info=cluster_info, + volumes=volumes, + secrets=secrets, + repo_creds=repo_creds, + ) + + +async def _refetch_locked_job_model( + session: AsyncSession, item: JobRunningPipelineItem +) -> Optional[JobModel]: + res = await session.execute( + select(JobModel) + .where( + JobModel.id == item.id, + JobModel.lock_token == item.lock_token, + ) + .options(joinedload(JobModel.instance).joinedload(InstanceModel.project)) + .options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak)) + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one_or_none() + + +async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel: + latest_submissions_sq = ( + select( + JobModel.run_id.label("run_id"), + JobModel.replica_num.label("replica_num"), + JobModel.job_num.label("job_num"), + func.max(JobModel.submission_num).label("max_submission_num"), + ) + .where(JobModel.run_id == run_id) + .group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num) + .subquery() + ) + job_alias = aliased(JobModel) + res = await session.execute( + select(RunModel) + .where(RunModel.id == run_id) + .join(job_alias, job_alias.run_id == RunModel.id) + .join( + latest_submissions_sq, + onclause=and_( + job_alias.run_id == latest_submissions_sq.c.run_id, + job_alias.replica_num == latest_submissions_sq.c.replica_num, + job_alias.job_num == latest_submissions_sq.c.job_num, + job_alias.submission_num == 
latest_submissions_sq.c.max_submission_num, + ), + ) + .options(joinedload(RunModel.project)) + .options(joinedload(RunModel.user)) + .options(joinedload(RunModel.repo)) + .options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name)) + .options(contains_eager(RunModel.jobs, alias=job_alias)) + ) + return res.unique().scalar_one() + + +async def _process_provisioning_status( + context: _ProcessContext, + startup_context: _StartupContext, + result: _ProcessResult, +) -> None: + job_provisioning_data = get_or_error(context.job_provisioning_data) + server_ssh_private_keys = get_or_error(context.server_ssh_private_keys) + + if job_provisioning_data.hostname is None: + _wait_for_instance_provisioning_data(context.job_model, result) + return + if _should_wait_for_other_nodes(context.run, context.job, context.job_model): + return + + if job_provisioning_data.dockerized: + logger.debug( + "%s: process provisioning job with shim, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + ssh_user = job_provisioning_data.username + assert context.run.run_spec.ssh_key_pub is not None + user_ssh_key = context.run.run_spec.ssh_key_pub.strip() + public_keys = [context.project.ssh_public_key.strip(), user_ssh_key] + if job_provisioning_data.backend == BackendType.LOCAL: + user_ssh_key = "" + success = await run_async( + _process_provisioning_with_shim, + server_ssh_private_keys, + job_provisioning_data, + None, + run=context.run, + job_model=context.job_model, + jrd=get_job_runtime_data(context.job_model), + jpd=job_provisioning_data, + volumes=startup_context.volumes, + registry_auth=context.job.job_spec.registry_auth, + public_keys=public_keys, + ssh_user=ssh_user, + ssh_key=user_ssh_key, + ) + if success: + _set_job_status(context.job_model, result, JobStatus.PULLING) + return + else: + logger.debug( + "%s: process provisioning job without shim, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + runner_availability = await 
run_async( + _get_runner_availability, + server_ssh_private_keys, + job_provisioning_data, + None, + ) + if runner_availability == _RunnerAvailability.AVAILABLE: + file_archives = await _get_job_file_archives( + archive_mappings=context.job.job_spec.file_archives, + user=context.run_model.user, + ) + code = await _get_job_code( + project=context.project, + repo=context.repo_model, + code_hash=_get_repo_code_hash(context.run, context.job), + ) + submit_result = await run_async( + _submit_job_to_runner, + server_ssh_private_keys, + job_provisioning_data, + None, + run=context.run, + job_model=context.job_model, + job=context.job, + jrd=get_job_runtime_data(context.job_model), + cluster_info=startup_context.cluster_info, + code=code, + file_archives=file_archives, + secrets=startup_context.secrets, + repo_credentials=startup_context.repo_creds, + success_if_not_available=False, + ) + if submit_result is not False: + _apply_submit_job_to_runner_result( + job_model=context.job_model, + result=result, + submit_result=submit_result, + ) + if submit_result is not False and submit_result.success: + return + + provisioning_timeout = get_provisioning_timeout( + backend_type=job_provisioning_data.get_base_backend(), + instance_type_name=job_provisioning_data.instance_type.name, + ) + if context.job_submission.age > provisioning_timeout: + _terminate_job( + job_model=context.job_model, + job_update_map=result.job_update_map, + termination_reason=JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED, + termination_reason_message=( + f"Runner did not become available within {provisioning_timeout.total_seconds()}s." 
+ f" Job submission age: {context.job_submission.age.total_seconds()}s"
+ ) + code = await _get_job_code( + project=context.project, + repo=context.repo_model, + code_hash=_get_repo_code_hash(context.run, context.job), + ) + submit_result = await run_async( + _submit_job_to_runner, + server_ssh_private_keys, + job_provisioning_data, + job_runtime_data, + run=context.run, + job_model=context.job_model, + job=context.job, + jrd=job_runtime_data, + cluster_info=startup_context.cluster_info, + code=code, + file_archives=file_archives, + secrets=startup_context.secrets, + repo_credentials=startup_context.repo_creds, + success_if_not_available=True, + ) + if submit_result is not False: + _apply_submit_job_to_runner_result( + job_model=context.job_model, + result=result, + submit_result=submit_result, + ) + if submit_result is not False and submit_result.success: + _reset_disconnected_at(context.job_model, result) + return + + # SSH tunnel failed or READY but runner submit failed — treat as disconnect + _handle_instance_unreachable(context, result, job_provisioning_data) + + +async def _process_running_status( + context: _ProcessContext, + result: _ProcessResult, +) -> None: + job_provisioning_data = get_or_error(context.job_provisioning_data) + server_ssh_private_keys = get_or_error(context.server_ssh_private_keys) + + logger.debug( + "%s: process running job, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + process_running_result = await run_async( + _process_running, + server_ssh_private_keys, + job_provisioning_data, + context.job_submission.job_runtime_data, + run_model=context.run_model, + job_model=context.job_model, + ) + if process_running_result is not False: + result.job_update_map.update(process_running_result.job_update_map) + _reset_disconnected_at(context.job_model, result) + return + + _handle_instance_unreachable(context, result, job_provisioning_data) + + +async def _apply_process_result( + item: JobRunningPipelineItem, + job_model: JobModel, + result: _ProcessResult, +) -> None: + 
set_processed_update_map_fields(result.job_update_map) + set_unlock_update_map_fields(result.job_update_map) + + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(result.job_update_map, now=now) + res = await session.execute( + update(JobModel) + .where( + JobModel.id == item.id, + JobModel.lock_token == item.lock_token, + ) + .values(**result.job_update_map) + .returning(JobModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_after_processing(logger, item) + return + + if result.new_probe_models: + session.add_all(result.new_probe_models) + + _emit_result_events(session=session, job_model=job_model, result=result) + + +def _emit_result_events( + session: AsyncSession, + job_model: JobModel, + result: _ProcessResult, +) -> None: + """Emit audit events for changes recorded in result..""" + emit_job_status_change_event( + session=session, + job_model=job_model, + old_status=job_model.status, + new_status=result.job_update_map.get("status", job_model.status), + termination_reason=result.job_update_map.get( + "termination_reason", job_model.termination_reason + ), + termination_reason_message=result.job_update_map.get( + "termination_reason_message", + job_model.termination_reason_message, + ), + ) + _emit_reachability_change_event( + session=session, + job_model=job_model, + old_disconnected_at=job_model.disconnected_at, + new_disconnected_at=result.job_update_map.get( + "disconnected_at", + job_model.disconnected_at, + ), + ) + if result.replica_registration is not None: + targets = [events.Target.from_model(job_model)] + if result.replica_registration.gateway_target is not None: + targets.append(result.replica_registration.gateway_target) + events.emit( + session, + "Service replica registered to receive requests", + actor=events.SystemActor(), + targets=targets, + ) + + +def _wait_for_instance_provisioning_data( + job_model: JobModel, + result: _ProcessResult, 
+) -> None: + if job_model.instance is None: + logger.error( + "%s: cannot update job_provisioning_data. job_model.instance is None.", + fmt(job_model), + ) + return + if job_model.instance.job_provisioning_data is None: + logger.error( + "%s: cannot update job_provisioning_data. job_model.job_provisioning_data is None.", + fmt(job_model), + ) + return + + if job_model.instance.status == InstanceStatus.TERMINATED: + _terminate_job( + job_model=job_model, + job_update_map=result.job_update_map, + termination_reason=JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED, + termination_reason_message="Instance is terminated", + ) + return + + result.job_update_map["job_provisioning_data"] = job_model.instance.job_provisioning_data + + +def _handle_instance_unreachable( + context: _ProcessContext, + result: _ProcessResult, + job_provisioning_data: JobProvisioningData, +) -> None: + _set_disconnected_at_now(context.job_model, result) + if not _should_terminate_job_due_to_disconnect( + _get_result_disconnected_at(context.job_model, result) + ): + logger.warning( + "%s: is unreachable, waiting for the instance to become reachable again, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + return + if job_provisioning_data.instance_type.resources.spot: + termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + else: + termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE + _terminate_job( + job_model=context.job_model, + job_update_map=result.job_update_map, + termination_reason=termination_reason, + termination_reason_message="Instance is unreachable", + ) + + +def _initialize_running_job_probes( + job_model: JobModel, + job: Job, + result: _ProcessResult, +) -> None: + for probe_num in range(len(job.job_spec.probes)): + result.new_probe_models.append( + ProbeModel( + name=f"{job_model.job_name}-{probe_num}", + job_id=job_model.id, + probe_num=probe_num, + due=get_current_datetime(), + success_streak=0, + active=True, + ) + ) + + 
+async def _maybe_register_replica( + context: _ProcessContext, + result: _ProcessResult, +) -> None: + if ( + context.run.run_spec.configuration.type != "service" + or _get_result_registered(context.job_model, result) + or context.job_model.job_num != 0 + or result.new_probe_models + or not is_job_ready(context.job_model.probes, context.job.job_spec.probes) + ): + return + + ssh_head_proxy: Optional[SSHConnectionParams] = None + ssh_head_proxy_private_key: Optional[str] = None + instance = get_or_error(context.job_model.instance) + rci = get_instance_remote_connection_info(instance) + if rci is not None and rci.ssh_proxy is not None: + ssh_head_proxy = rci.ssh_proxy + ssh_head_proxy_keys = get_or_error(rci.ssh_proxy_keys) + ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private + + try: + gateway_target = await _register_service_replica( + context=context, + result=result, + ssh_head_proxy=ssh_head_proxy, + ssh_head_proxy_private_key=ssh_head_proxy_private_key, + ) + except GatewayError as e: + logger.warning("%s: failed to register service replica: %s", fmt(context.job_model), e) + _terminate_job( + job_model=context.job_model, + job_update_map=result.job_update_map, + termination_reason=JobTerminationReason.GATEWAY_ERROR, + termination_reason_message="Failed to register service replica", + ) + return + + result.job_update_map["registered"] = True + result.replica_registration = _RegisterReplicaResult(gateway_target=gateway_target) + + +async def _register_service_replica( + context: _ProcessContext, + result: _ProcessResult, + ssh_head_proxy: Optional[SSHConnectionParams], + ssh_head_proxy_private_key: Optional[str], +) -> Optional[events.Target]: + if context.run_model.gateway_id is None: + return None + + async with get_session_ctx() as session: + gateway_model, conn = await get_or_add_gateway_connection( + session, context.run_model.gateway_id + ) + gateway_target = events.Target.from_model(gateway_model) + assert context.job_model.instance is not None + 
instance_project_ssh_private_key = None + if context.job_model.project_id != context.job_model.instance.project_id: + instance_project_ssh_private_key = context.job_model.instance.project.ssh_private_key + # JobRuntimeData might change on PULLING -> RUNNING path + # so we must update job_submission with the result value. + job_submission = context.job_submission.copy(deep=True) + job_submission.job_runtime_data = _get_result_job_runtime_data(context.job_model, result) + try: + logger.debug( + "%s: registering replica for service %s", fmt(context.job_model), context.run.id.hex + ) + async with conn.client() as gateway_client: + await gateway_client.register_replica( + run=context.run, + job_spec=JobSpec.__response__.parse_raw(context.job_model.job_spec_data), + job_submission=job_submission, + instance_project_ssh_private_key=instance_project_ssh_private_key, + ssh_head_proxy=ssh_head_proxy, + ssh_head_proxy_private_key=ssh_head_proxy_private_key, + ) + except (httpx.RequestError, SSHError) as e: + logger.debug("Gateway request failed", exc_info=True) + raise GatewayError(repr(e)) + except GatewayError as e: + if "already exists in service" in e.msg: + logger.warning( + ( + "%s: could not register replica in gateway: %s." 
+ " NOTE: if you just updated dstack from pre-0.19.25 to 0.19.25+," + " expect to see this warning once for every running service replica" + ), + fmt(context.job_model), + e.msg, + ) + else: + raise + return gateway_target + + +async def _check_gpu_utilization( + context: _ProcessContext, + result: _ProcessResult, +) -> None: + policy = context.job.job_spec.utilization_policy + if policy is None: + return + + after = get_current_datetime() - timedelta(seconds=policy.time_window) + async with get_session_ctx() as session: + job_metrics = await get_job_metrics(session, context.job_model, after=after) + gpus_util_metrics: list[Metric] = [] + for metric in job_metrics.metrics: + if metric.name.startswith("gpu_util_percent_gpu"): + gpus_util_metrics.append(metric) + if not gpus_util_metrics or gpus_util_metrics[0].timestamps[-1] > after + timedelta(minutes=1): + logger.debug("%s: GPU utilization check: not enough samples", fmt(context.job_model)) + return + if _should_terminate_due_to_low_gpu_util( + policy.min_gpu_utilization, [metric.values for metric in gpus_util_metrics] + ): + logger.debug("%s: GPU utilization check: terminating", fmt(context.job_model)) + _terminate_job( + job_model=context.job_model, + job_update_map=result.job_update_map, + termination_reason=JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY, + termination_reason_message=( + f"The job GPU utilization below {policy.min_gpu_utilization}%" + f" for {policy.time_window} seconds" + ), + ) + else: + logger.debug("%s: GPU utilization check: OK", fmt(context.job_model)) + + +def _should_terminate_due_to_low_gpu_util( + min_util: int, gpus_util: Iterable[Iterable[int]] +) -> bool: + for gpu_util in gpus_util: + if all(util < min_util for util in gpu_util): + return True + return False + + +def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool: + for other_job in run.jobs: + if ( + other_job.job_spec.replica_num == job.job_spec.replica_num + and 
other_job.job_submissions[-1].status == JobStatus.PROVISIONING + and other_job.job_submissions[-1].job_provisioning_data is not None + and other_job.job_submissions[-1].job_provisioning_data.hostname is None + ): + logger.debug("%s: waiting for other job to have IP assigned", fmt(job_model)) + return True + master_job = find_job(run.jobs, job.job_spec.replica_num, 0) + if ( + job.job_spec.job_num != 0 + and run.run_spec.merged_profile.startup_order == StartupOrder.MASTER_FIRST + and master_job.job_submissions[-1].status != JobStatus.RUNNING + ): + logger.debug("%s: waiting for master job to become running", fmt(job_model)) + return True + if ( + is_master_job(job) + and run.run_spec.merged_profile.startup_order == StartupOrder.WORKERS_FIRST + ): + for other_job in run.jobs: + if ( + other_job.job_spec.replica_num == job.job_spec.replica_num + and other_job.job_spec.job_num != job.job_spec.job_num + and other_job.job_submissions[-1].status != JobStatus.RUNNING + ): + logger.debug("%s: waiting for worker job to become running", fmt(job_model)) + return True + return False + + +@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT], retries=1) +def _process_provisioning_with_shim( + ports: Dict[int, int], + run: Run, + job_model: JobModel, + jrd: Optional[JobRuntimeData], + jpd: JobProvisioningData, + volumes: list[Volume], + registry_auth: Optional[RegistryAuth], + public_keys: list[str], + ssh_user: str, + ssh_key: str, +) -> bool: + job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + + resp = shim_client.healthcheck() + if resp is None: + logger.debug("%s: shim is not available yet", fmt(job_model)) + return False + + registry_username = "" + registry_password = "" + if registry_auth is not None: + registry_username = registry_auth.username + registry_password = registry_auth.password + + volume_mounts: list[VolumeMountPoint] = [] + instance_mounts: list[InstanceMountPoint] = [] + for 
mount in run.run_spec.configuration.volumes: + if isinstance(mount, VolumeMountPoint): + volume_mounts.append(mount.copy()) + elif isinstance(mount, InstanceMountPoint): + instance_mounts.append(mount) + else: + assert False, f"unexpected mount point: {mount!r}" + + for volume, volume_mount in zip(volumes, volume_mounts): + volume_mount.name = volume.name + + instance_mounts += get_instance_specific_mounts(jpd.backend, jpd.instance_type.name) + gpu_devices = get_instance_specific_gpu_devices(jpd.backend, jpd.instance_type.name) + + container_user = "root" + if jrd is not None: + gpu = jrd.gpu + cpu = jrd.cpu + memory = jrd.memory + network_mode = jrd.network_mode + else: + gpu = None + cpu = None + memory = None + network_mode = NetworkMode.HOST + image_name = resolve_provisioning_image_name(job_spec, jpd) + if shim_client.is_api_v2_supported(): + shim_client.submit_task( + task_id=job_model.id, + name=job_model.job_name, + registry_username=registry_username, + registry_password=registry_password, + image_name=image_name, + container_user=container_user, + privileged=job_spec.privileged, + gpu=gpu, + cpu=cpu, + memory=memory, + shm_size=job_spec.requirements.resources.shm_size, + network_mode=network_mode, + volumes=volumes, + volume_mounts=volume_mounts, + instance_mounts=instance_mounts, + gpu_devices=gpu_devices, + host_ssh_user=ssh_user, + host_ssh_keys=[ssh_key] if ssh_key else [], + container_ssh_keys=public_keys, + instance_id=jpd.instance_id, + ) + else: + submitted = shim_client.submit( + username=registry_username, + password=registry_password, + image_name=image_name, + privileged=job_spec.privileged, + container_name=job_model.job_name, + container_user=container_user, + shm_size=job_spec.requirements.resources.shm_size, + public_keys=public_keys, + ssh_user=ssh_user, + ssh_key=ssh_key, + mounts=volume_mounts, + volumes=volumes, + instance_mounts=instance_mounts, + instance_id=jpd.instance_id, + ) + if not submitted: + logger.warning( + "%s: failed to 
submit, shim is already running a job, stopping it now, retry later", + fmt(job_model), + ) + shim_client.stop(force=True) + return False + + return True + + +class _RunnerAvailability(enum.Enum): + AVAILABLE = "available" + UNAVAILABLE = "unavailable" + + +class _ShimPullingState(enum.Enum): + WAITING = "waiting" + READY = "ready" + FAILED = "failed" + + +@dataclass +class _SyncShimPullingStateResult: + state: _ShimPullingState + termination_reason: Optional[JobTerminationReason] = None + termination_reason_message: Optional[str] = None + job_runtime_data: Optional[JobRuntimeData] = None + + +@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) +def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability: + runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + if runner_client.healthcheck() is None: + return _RunnerAvailability.UNAVAILABLE + return _RunnerAvailability.AVAILABLE + + +@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) +def _sync_shim_pulling_state( + ports: Dict[int, int], + job_model: JobModel, + jrd: Optional[JobRuntimeData] = None, +) -> Union[_SyncShimPullingStateResult, Literal[False]]: + shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + if shim_client.is_api_v2_supported(): + task = shim_client.get_task(job_model.id) + if task.status == TaskStatus.TERMINATED: + logger.warning( + "shim failed to execute job %s: %s (%s)", + job_model.job_name, + task.termination_reason, + task.termination_message, + ) + logger.debug("task status: %s", task.dict()) + return _SyncShimPullingStateResult( + state=_ShimPullingState.FAILED, + termination_reason=JobTerminationReason(task.termination_reason.lower()), + termination_reason_message=task.termination_message, + ) + + if task.status != TaskStatus.RUNNING: + return _SyncShimPullingStateResult(state=_ShimPullingState.WAITING) + + if jrd is not None: + if task.ports is None: + return _SyncShimPullingStateResult(state=_ShimPullingState.WAITING) + jrd 
= jrd.copy(update={"ports": {pm.container: pm.host for pm in task.ports}}) + else: + shim_status = shim_client.pull() + if ( + shim_status.state == "pending" + and shim_status.result is not None + and shim_status.result.reason != "" + ): + logger.warning( + "shim failed to execute job %s: %s (%s)", + job_model.job_name, + shim_status.result.reason, + shim_status.result.reason_message, + ) + logger.debug("shim status: %s", shim_status.dict()) + return _SyncShimPullingStateResult( + state=_ShimPullingState.FAILED, + termination_reason=JobTerminationReason(shim_status.result.reason.lower()), + termination_reason_message=shim_status.result.reason_message, + ) + + if shim_status.state in ("pulling", "creating"): + return _SyncShimPullingStateResult(state=_ShimPullingState.WAITING) + + return _SyncShimPullingStateResult( + state=_ShimPullingState.READY, + job_runtime_data=jrd, + ) + + +@dataclass +class _SubmitJobToRunnerResult: + success: bool + set_running_status: bool = False + job_runtime_data: Optional[JobRuntimeData] = None + + +@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) +def _submit_job_to_runner( + ports: Dict[int, int], + run: Run, + job_model: JobModel, + job: Job, + jrd: Optional[JobRuntimeData], + cluster_info: ClusterInfo, + code: bytes, + file_archives: Iterable[tuple[uuid.UUID, bytes]], + secrets: Dict[str, str], + repo_credentials: Optional[RemoteRepoCreds], + success_if_not_available: bool, +) -> Union[_SubmitJobToRunnerResult, Literal[False]]: + logger.debug("%s: submitting job spec", fmt(job_model)) + logger.debug( + "%s: repo clone URL is %s", + fmt(job_model), + None if repo_credentials is None else repo_credentials.clone_url, + ) + instance = job_model.instance + if instance is not None and (rci := get_instance_remote_connection_info(instance)) is not None: + instance_env = rci.env + else: + instance_env = None + + runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + if runner_client.healthcheck() is None: + 
return _SubmitJobToRunnerResult(success=success_if_not_available) + + runner_client.submit_job( + run=run, + job=job, + cluster_info=cluster_info, + # Do not send all the secrets since interpolation is already done by the server. + # TODO: Passing secrets may be necessary for filtering out secret values from logs. + secrets={}, + repo_credentials=repo_credentials, + instance_env=instance_env, + ) + logger.debug("%s: uploading file archive(s)", fmt(job_model)) + for archive_id, archive in file_archives: + runner_client.upload_archive(archive_id, archive) + logger.debug("%s: uploading code", fmt(job_model)) + runner_client.upload_code(code) + logger.debug("%s: starting job", fmt(job_model)) + job_info = runner_client.run_job() + if job_info is not None: + if jrd is not None: + jrd = jrd.copy( + update={"working_dir": job_info.working_dir, "username": job_info.username} + ) + return _SubmitJobToRunnerResult( + success=True, + set_running_status=True, + job_runtime_data=jrd, + ) + + +@dataclass +class _ProcessRunningResult: + job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) + + +@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT]) +def _process_running( + ports: Dict[int, int], + run_model: RunModel, + job_model: JobModel, +) -> Union[_ProcessRunningResult, Literal[False]]: + runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + timestamp = job_model.runner_timestamp or 0 + resp = runner_client.pull(timestamp) + logs_services.write_logs( + project=run_model.project, + run_name=run_model.run_name, + job_submission_id=job_model.id, + runner_logs=resp.runner_logs, + job_logs=resp.job_logs, + ) + result = _ProcessRunningResult( + job_update_map=_JobUpdateMap(runner_timestamp=resp.last_updated) + ) + if len(resp.job_states) > 0: + latest_state_event = resp.job_states[-1] + latest_status = latest_state_event.state + if latest_status == JobStatus.DONE: + _terminate_job( + job_model=job_model, + job_update_map=result.job_update_map, + 
termination_reason=JobTerminationReason.DONE_BY_RUNNER, + termination_reason_message=None, + ) + elif latest_status in {JobStatus.FAILED, JobStatus.TERMINATED}: + termination_reason = JobTerminationReason.CONTAINER_EXITED_WITH_ERROR + if latest_state_event.termination_reason: + termination_reason = JobTerminationReason( + latest_state_event.termination_reason.lower() + ) + _terminate_job( + job_model=job_model, + job_update_map=result.job_update_map, + termination_reason=termination_reason, + termination_reason_message=latest_state_event.termination_message, + ) + if latest_state_event.exit_status is not None: + result.job_update_map["exit_status"] = latest_state_event.exit_status + if latest_state_event.exit_status != 0: + logger.info( + "%s: non-zero exit status %s", fmt(job_model), latest_state_event.exit_status + ) + else: + _terminate_if_inactivity_duration_exceeded( + run_model=run_model, + job_model=job_model, + job_update_map=result.job_update_map, + no_connections_secs=resp.no_connections_secs, + ) + return result + + +def _terminate_if_inactivity_duration_exceeded( + run_model: RunModel, + job_model: JobModel, + job_update_map: _JobUpdateMap, + no_connections_secs: Optional[int], +) -> None: + conf = RunSpec.__response__.parse_raw(run_model.run_spec).configuration + if not isinstance(conf, DevEnvironmentConfiguration) or not isinstance( + conf.inactivity_duration, int + ): + job_update_map["inactivity_secs"] = None + return + + logger.debug("%s: no SSH connections for %s seconds", fmt(job_model), no_connections_secs) + job_update_map["inactivity_secs"] = no_connections_secs + if no_connections_secs is None: + # TODO(0.19 or earlier): make no_connections_secs required + _terminate_job( + job_model=job_model, + job_update_map=job_update_map, + termination_reason=JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY, + termination_reason_message=( + "The selected instance was created before dstack 0.18.41" + " and does not support inactivity_duration" + ), + ) + 
elif no_connections_secs >= conf.inactivity_duration: + _terminate_job( + job_model=job_model, + job_update_map=job_update_map, + termination_reason=JobTerminationReason.INACTIVITY_DURATION_EXCEEDED, + termination_reason_message=( + f"The job was inactive for {no_connections_secs} seconds," + f" exceeding the inactivity_duration of {conf.inactivity_duration} seconds" + ), + ) + + +def _should_terminate_job_due_to_disconnect(disconnected_at: Optional[datetime]) -> bool: + if disconnected_at is None: + return False + return get_current_datetime() > disconnected_at + JOB_DISCONNECTED_RETRY_TIMEOUT + + +def _set_disconnected_at_now(job_model: JobModel, result: _ProcessResult) -> None: + if _get_result_disconnected_at(job_model, result) is None: + result.job_update_map["disconnected_at"] = get_current_datetime() + + +def _reset_disconnected_at(job_model: JobModel, result: _ProcessResult) -> None: + if _get_result_disconnected_at(job_model, result) is not None: + result.job_update_map["disconnected_at"] = None + + +def _get_cluster_info( + jobs: list[Job], + replica_num: int, + job_provisioning_data: JobProvisioningData, + job_runtime_data: Optional[JobRuntimeData], +) -> ClusterInfo: + job_ips = [] + for job in jobs: + if job.job_spec.replica_num == replica_num: + job_ips.append( + get_or_error(job.job_submissions[-1].job_provisioning_data).internal_ip or "" + ) + gpus_per_job = len(job_provisioning_data.instance_type.resources.gpus) + if job_runtime_data is not None and job_runtime_data.offer is not None: + gpus_per_job = len(job_runtime_data.offer.instance.resources.gpus) + return ClusterInfo( + job_ips=job_ips, + master_job_ip=job_ips[0], + gpus_per_job=gpus_per_job, + ) + + +def _get_repo_code_hash(run: Run, job: Job) -> Optional[str]: + # TODO: drop this function when supporting jobs submitted before 0.19.17 is no longer relevant. 
+ if ( + job.job_spec.repo_code_hash is None + and run.run_spec.repo_code_hash is not None + and job.job_submissions[-1].deployment_num == run.deployment_num + ): + return run.run_spec.repo_code_hash + return job.job_spec.repo_code_hash + + +async def _get_job_code(project: ProjectModel, repo: RepoModel, code_hash: Optional[str]) -> bytes: + if code_hash is None: + return b"" + async with get_session_ctx() as session: + code_model = await get_code_model(session=session, repo=repo, code_hash=code_hash) + if code_model is None: + return b"" + if code_model.blob is not None: + return code_model.blob + storage = get_default_storage() + if storage is None: + return b"" + blob = await run_async( + storage.get_code, + project.name, + repo.name, + code_hash, + ) + if blob is None: + logger.error( + "Failed to get repo code hash %s from storage for repo %s", code_hash, repo.name + ) + return b"" + return blob + + +async def _get_job_file_archives( + archive_mappings: Iterable[FileArchiveMapping], + user: UserModel, +) -> list[tuple[uuid.UUID, bytes]]: + archives: list[tuple[uuid.UUID, bytes]] = [] + for archive_mapping in archive_mappings: + archive_blob = await _get_job_file_archive(archive_id=archive_mapping.id, user=user) + archives.append((archive_mapping.id, archive_blob)) + return archives + + +async def _get_job_file_archive(archive_id: uuid.UUID, user: UserModel) -> bytes: + async with get_session_ctx() as session: + archive_model = await files_services.get_archive_model(session, id=archive_id, user=user) + if archive_model is None: + return b"" + if archive_model.blob is not None: + return archive_model.blob + storage = get_default_storage() + if storage is None: + return b"" + blob = await run_async( + storage.get_archive, + str(archive_model.user_id), + archive_model.blob_hash, + ) + if blob is None: + logger.error("Failed to get file archive %s from storage", archive_id) + return b"" + return blob + + +def _interpolate_secrets(secrets: Dict[str, str], job_spec: 
JobSpec) -> None: + interpolate = VariablesInterpolator({"secrets": secrets}).interpolate_or_error + job_spec.env = {k: interpolate(v) for k, v in job_spec.env.items()} + if job_spec.registry_auth is not None: + job_spec.registry_auth = RegistryAuth( + username=interpolate(job_spec.registry_auth.username), + password=interpolate(job_spec.registry_auth.password), + ) + + +def _emit_reachability_change_event( + session: AsyncSession, + job_model: JobModel, + old_disconnected_at: Optional[datetime], + new_disconnected_at: Optional[datetime], +) -> None: + if old_disconnected_at is None and new_disconnected_at is not None: + events.emit( + session, + "Job became unreachable", + actor=events.SystemActor(), + targets=[events.Target.from_model(job_model)], + ) + elif old_disconnected_at is not None and new_disconnected_at is None: + events.emit( + session, + "Job became reachable", + actor=events.SystemActor(), + targets=[events.Target.from_model(job_model)], + ) + + +def _terminate_job( + job_model: JobModel, + job_update_map: _JobUpdateMap, + termination_reason: JobTerminationReason, + termination_reason_message: Optional[str], +) -> None: + job_update_map["termination_reason"] = termination_reason + job_update_map["termination_reason_message"] = termination_reason_message + _set_job_update_status(job_model, job_update_map, JobStatus.TERMINATING) + + +def _set_job_update_status( + job_model: JobModel, + job_update_map: _JobUpdateMap, + new_status: JobStatus, +) -> None: + if job_update_map.get("status", job_model.status) != new_status: + job_update_map["status"] = new_status + + +def _set_job_status(job_model: JobModel, result: _ProcessResult, new_status: JobStatus) -> None: + _set_job_update_status(job_model, result.job_update_map, new_status) + + +def _set_job_runtime_data(result: _ProcessResult, jrd: Optional[JobRuntimeData]) -> None: + result.job_update_map["job_runtime_data"] = None if jrd is None else jrd.json() + + +def _apply_submit_job_to_runner_result( + 
job_model: JobModel, + result: _ProcessResult, + submit_result: _SubmitJobToRunnerResult, +) -> None: + if submit_result.job_runtime_data is not None: + _set_job_runtime_data(result, submit_result.job_runtime_data) + if submit_result.set_running_status: + _set_job_status(job_model, result, JobStatus.RUNNING) + + +# Convention: _get_result_* helpers merge the loaded job_model state with any pending +# updates recorded in result.job_update_map. Always use these (not job_model.attr directly) +# when the field may have been updated earlier in the same processing cycle. + + +def _get_result_status(job_model: JobModel, result: _ProcessResult) -> JobStatus: + return result.job_update_map.get("status", job_model.status) + + +def _get_result_disconnected_at(job_model: JobModel, result: _ProcessResult) -> Optional[datetime]: + return result.job_update_map.get("disconnected_at", job_model.disconnected_at) + + +def _get_result_job_runtime_data( + job_model: JobModel, result: _ProcessResult +) -> Optional[JobRuntimeData]: + jrd = result.job_update_map.get("job_runtime_data", job_model.job_runtime_data) + if jrd is None: + return None + return JobRuntimeData.__response__.parse_raw(jrd) + + +def _get_result_registered(job_model: JobModel, result: _ProcessResult) -> bool: + return result.job_update_map.get("registered", job_model.registered) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index 61c8adaee..727120841 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -251,14 +251,6 @@ async def process(self, item: JobTerminatingPipelineItem): instance_model=get_or_error(instance_model), ) - set_processed_update_map_fields(result.job_update_map) - set_unlock_update_map_fields(result.job_update_map) - if instance_model is not None: - if 
result.instance_update_map is None: - result.instance_update_map = _InstanceUpdateMap() - instance_update_map = result.instance_update_map - set_processed_update_map_fields(instance_update_map) - set_unlock_update_map_fields(instance_update_map) await _apply_process_result( item=item, job_model=job_model, @@ -289,6 +281,11 @@ class _VolumeUpdateRow(TypedDict): last_job_processed_at: UpdateMapDateTime +@dataclass +class _UnregisterReplicaResult: + gateway_target: Optional[events.Target] # None = no gateway + + @dataclass class _ProcessResult: job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) @@ -296,8 +293,9 @@ class _ProcessResult: volume_update_rows: list[_VolumeUpdateRow] = field(default_factory=list) detached_volume_ids: set[uuid.UUID] = field(default_factory=set) unassign_event_message: Optional[str] = None - emit_unregister_replica_event: bool = False - unregister_gateway_target: Optional[events.Target] = None + replica_unregistration: Optional[_UnregisterReplicaResult] = ( + None # None = not unregistered yet + ) @dataclass @@ -433,6 +431,14 @@ async def _apply_process_result( instance_model: Optional[InstanceModel], result: _ProcessResult, ) -> None: + set_processed_update_map_fields(result.job_update_map) + set_unlock_update_map_fields(result.job_update_map) + if instance_model is not None and result.instance_update_map is None: + result.instance_update_map = _InstanceUpdateMap() + if result.instance_update_map is not None: + set_processed_update_map_fields(result.instance_update_map) + set_unlock_update_map_fields(result.instance_update_map) + async with get_session_ctx() as session: now = get_current_datetime() related_instance_lock_owner = _get_related_instance_lock_owner(item.id) @@ -536,10 +542,10 @@ async def _apply_process_result( ], ) - if result.emit_unregister_replica_event: + if result.replica_unregistration is not None: targets = [events.Target.from_model(job_model)] - if result.unregister_gateway_target is not None: - 
targets.append(result.unregister_gateway_target) + if result.replica_unregistration.gateway_target is not None: + targets.append(result.replica_unregistration.gateway_target) events.emit( session, "Service replica unregistered from receiving requests", @@ -689,10 +695,10 @@ async def _detach_job_volumes( async def _unregister_replica_and_update_result( result: _ProcessResult, job_model: JobModel ) -> None: - result.unregister_gateway_target = await _unregister_replica(job_model=job_model) + gateway_target = await _unregister_replica(job_model=job_model) if job_model.registered: result.job_update_map["registered"] = False - result.emit_unregister_replica_event = True + result.replica_unregistration = _UnregisterReplicaResult(gateway_target=gateway_target) async def _unregister_replica( diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index 89eea7d92..406c216c7 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass, field from datetime import timedelta -from typing import Sequence +from typing import Optional, Sequence from sqlalchemy import or_, select, update from sqlalchemy.orm import joinedload, load_only @@ -202,13 +202,22 @@ def __init__( @sentry_utils.instrument_named_task("pipeline_tasks.VolumeWorker.process") async def process(self, item: VolumePipelineItem): + volume_model = await _refetch_locked_volume(item) + if volume_model is None: + log_lock_token_mismatch(logger, item) + return + if item.to_be_deleted: - await _process_to_be_deleted_item(item) + result = await _process_to_be_deleted_volume(volume_model) elif item.status == VolumeStatus.SUBMITTED: - await _process_submitted_item(item) + result = await _process_submitted_volume(volume_model) + else: + return + await _apply_process_result(item=item, 
volume_model=volume_model, result=result) -async def _process_submitted_item(item: VolumePipelineItem): + +async def _refetch_locked_volume(item: VolumePipelineItem) -> Optional[VolumeModel]: async with get_session_ctx() as session: res = await session.execute( select(VolumeModel) @@ -217,7 +226,7 @@ async def _process_submitted_item(item: VolumePipelineItem): VolumeModel.lock_token == item.lock_token, ) .options(joinedload(VolumeModel.project).joinedload(ProjectModel.backends)) - .options(joinedload(VolumeModel.user)) + .options(joinedload(VolumeModel.user).load_only(UserModel.name)) .options( joinedload(VolumeModel.attachments) .joinedload(VolumeAttachmentModel.instance) @@ -225,13 +234,16 @@ async def _process_submitted_item(item: VolumePipelineItem): .load_only(FleetModel.name) ) ) - volume_model = res.unique().scalar_one_or_none() - if volume_model is None: - log_lock_token_mismatch(logger, item) - return + return res.unique().scalar_one_or_none() + - result = await _process_submitted_volume(volume_model) - update_map = result.update_map +async def _apply_process_result( + item: VolumePipelineItem, + volume_model: VolumeModel, + result: "_ProcessResult", +): + update_map = _VolumeUpdateMap() + update_map.update(result.update_map) set_processed_update_map_fields(update_map) set_unlock_update_map_fields(update_map) @@ -249,15 +261,25 @@ async def _process_submitted_item(item: VolumePipelineItem): updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: log_lock_token_changed_after_processing(logger, item) - # TODO: Clean up volume. + if item.status == VolumeStatus.SUBMITTED: + # TODO: Clean up volume. 
+ pass return - emit_volume_status_change_event( - session=session, - volume_model=volume_model, - old_status=volume_model.status, - new_status=update_map.get("status", volume_model.status), - status_message=update_map.get("status_message", volume_model.status_message), - ) + if item.to_be_deleted: + events.emit( + session, + "Volume deleted", + actor=events.SystemActor(), + targets=[events.Target.from_model(volume_model)], + ) + else: + emit_volume_status_change_event( + session=session, + volume_model=volume_model, + old_status=volume_model.status, + new_status=update_map.get("status", volume_model.status), + status_message=update_map.get("status_message", volume_model.status_message), + ) class _VolumeUpdateMap(ItemUpdateMap, total=False): @@ -269,11 +291,11 @@ class _VolumeUpdateMap(ItemUpdateMap, total=False): @dataclass -class _SubmittedResult: +class _ProcessResult: update_map: _VolumeUpdateMap = field(default_factory=_VolumeUpdateMap) -async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResult: +async def _process_submitted_volume(volume_model: VolumeModel) -> _ProcessResult: volume = volume_model_to_volume(volume_model) try: backend = await backends_services.get_project_backend_by_type_or_error( @@ -287,7 +309,7 @@ async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResu volume.name, volume.configuration.backend.value, ) - return _SubmittedResult( + return _ProcessResult( update_map={ "status": VolumeStatus.FAILED, "status_message": "Backend not available", @@ -314,7 +336,7 @@ async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResu status_message = f"Backend error: {repr(e)}" if len(e.args) > 0: status_message = str(e.args[0]) - return _SubmittedResult( + return _ProcessResult( update_map={ "status": VolumeStatus.FAILED, "status_message": status_message, @@ -322,7 +344,7 @@ async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResu ) except Exception as e: 
logger.exception("Got exception when creating volume %s", volume_model.name) - return _SubmittedResult( + return _ProcessResult( update_map={ "status": VolumeStatus.FAILED, "status_message": f"Unexpected error: {repr(e)}", @@ -332,7 +354,7 @@ async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResu logger.info("Added new volume %s", volume_model.name) # Provisioned volumes marked as active since they become available almost immediately in AWS # TODO: Consider checking volume state - return _SubmittedResult( + return _ProcessResult( update_map={ "status": VolumeStatus.ACTIVE, "volume_provisioning_data": vpd.json(), @@ -340,63 +362,7 @@ async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResu ) -async def _process_to_be_deleted_item(item: VolumePipelineItem): - async with get_session_ctx() as session: - res = await session.execute( - select(VolumeModel) - .where( - VolumeModel.id == item.id, - VolumeModel.lock_token == item.lock_token, - ) - .options(joinedload(VolumeModel.project).joinedload(ProjectModel.backends)) - .options(joinedload(VolumeModel.user).load_only(UserModel.name)) - .options( - joinedload(VolumeModel.attachments) - .joinedload(VolumeAttachmentModel.instance) - .joinedload(InstanceModel.fleet) - .load_only(FleetModel.name) - ) - ) - volume_model = res.unique().scalar_one_or_none() - if volume_model is None: - log_lock_token_mismatch(logger, item) - return - - result = await _process_to_be_deleted_volume(volume_model) - update_map = _VolumeUpdateMap() - update_map.update(result.update_map) - set_processed_update_map_fields(update_map) - set_unlock_update_map_fields(update_map) - async with get_session_ctx() as session: - now = get_current_datetime() - resolve_now_placeholders(update_map, now=now) - res = await session.execute( - update(VolumeModel) - .where( - VolumeModel.id == volume_model.id, - VolumeModel.lock_token == volume_model.lock_token, - ) - .values(**update_map) - .returning(VolumeModel.id) - 
) - updated_ids = list(res.scalars().all()) - if len(updated_ids) == 0: - log_lock_token_changed_after_processing(logger, item) - return - events.emit( - session, - "Volume deleted", - actor=events.SystemActor(), - targets=[events.Target.from_model(volume_model)], - ) - - -@dataclass -class _ProcessToBeDeletedResult: - update_map: _VolumeUpdateMap = field(default_factory=_VolumeUpdateMap) - - -async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _ProcessToBeDeletedResult: +async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _ProcessResult: volume = volume_model_to_volume(volume_model) if volume.external: return _get_deleted_result() @@ -437,8 +403,8 @@ async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _ProcessTo return _get_deleted_result() -def _get_deleted_result() -> _ProcessToBeDeletedResult: - return _ProcessToBeDeletedResult( +def _get_deleted_result() -> _ProcessResult: + return _ProcessResult( update_map={ "deleted": True, "deleted_at": NOW_PLACEHOLDER, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index a0b97e7d5..c61d24953 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -128,12 +128,6 @@ def start_scheduled_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 5}, max_instances=4 if replica == 0 else 1, ) - _scheduler.add_job( - process_running_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) _scheduler.add_job( process_runs, IntervalTrigger(seconds=2, jitter=1), @@ -153,6 +147,12 @@ def start_scheduled_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 1}, max_instances=2 if replica == 0 else 1, ) + _scheduler.add_job( + process_running_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 
if replica == 0 else 1, + ) _scheduler.add_job( process_terminating_jobs, IntervalTrigger(seconds=4, jitter=2), diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index ea3d53973..6ceebf9e3 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -102,6 +102,8 @@ JOB_DISCONNECTED_RETRY_TIMEOUT = timedelta(minutes=2) +# NOTE: This scheduled task is going to be deprecated in favor of `JobRunningPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/jobs_running.py` in sync. async def process_running_jobs(batch_size: int = 1): tasks = [] for _ in range(batch_size): diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index 56d9fea77..2a2405c4e 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -144,6 +144,10 @@ async def _process_next_run(): .where( JobModel.run_id == run_model.id, JobModel.id.not_in(job_lockset), + or_( + JobModel.lock_expires_at.is_(None), + JobModel.lock_expires_at < now, + ), ) .options( load_only(JobModel.id), diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index fc67e8024..87b44b50a 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta from typing import List, Optional, Union -from sqlalchemy import func, or_, select +from sqlalchemy import exists, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import ( contains_eager, @@ -105,7 
+105,10 @@ ) from dstack._internal.server.services.locking import get_locker, string_to_lock_id from dstack._internal.server.services.logging import fmt -from dstack._internal.server.services.offers import get_offers_by_requirements +from dstack._internal.server.services.offers import ( + get_instance_offer_with_restricted_az, + get_offers_by_requirements, +) from dstack._internal.server.services.placement import ( find_or_create_suitable_placement_group, get_fleet_placement_group_models, @@ -409,13 +412,14 @@ async def _process_submitted_job( await session.commit() return - master_instance_provisioning_data = ( - await _fetch_fleet_with_master_instance_provisioning_data( - exit_stack=exit_stack, - session=session, - fleet_model=fleet_model, - job=job, - ) + ( + fleet_model, + master_instance_provisioning_data, + ) = await _fetch_fleet_with_master_instance_provisioning_data( + exit_stack=exit_stack, + session=session, + fleet_model=fleet_model, + job=job, ) master_provisioning_data = ( master_job_provisioning_data or master_instance_provisioning_data @@ -573,7 +577,7 @@ async def _fetch_fleet_with_master_instance_provisioning_data( session: AsyncSession, fleet_model: Optional[FleetModel], job: Job, -) -> Optional[JobProvisioningData]: +) -> tuple[Optional[FleetModel], Optional[JobProvisioningData]]: # TODO: When submitted-jobs provisioning moves to pipelines, stop inferring the # cluster master from loaded fleet instances here. 
Resolve the current master via # FleetModel.current_master_instance_id so jobs follow the same master election @@ -595,12 +599,11 @@ async def _fetch_fleet_with_master_instance_provisioning_data( await sqlite_commit(session) res = await session.execute( select(FleetModel) - .outerjoin(FleetModel.instances) .where( FleetModel.id == fleet_model.id, - or_( - InstanceModel.id.is_(None), - InstanceModel.deleted == True, + ~exists().where( + InstanceModel.fleet_id == fleet_model.id, + InstanceModel.deleted == False, ), ) .with_for_update(key_share=True, of=FleetModel) @@ -626,7 +629,7 @@ async def _fetch_fleet_with_master_instance_provisioning_data( fleet_model=fleet_model, fleet_spec=fleet_spec, ) - return master_instance_provisioning_data + return fleet_model, master_instance_provisioning_data async def _assign_job_to_fleet_instance( @@ -781,6 +784,13 @@ async def _run_jobs_on_new_instances( offer_volumes = _get_offer_volumes(volumes, offer) job_configurations = [JobConfiguration(job=j, volumes=offer_volumes) for j in jobs] compute = backend.compute() + if master_job_provisioning_data is not None: + # `get_offers_by_requirements()` already restricts backend and region from the master. + # Availability zone still has to be narrowed per offer. + offer = get_instance_offer_with_restricted_az( + instance_offer=offer, + master_job_provisioning_data=master_job_provisioning_data, + ) if ( fleet_model is not None and len(fleet_model.instances) == 0 diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py index f8aa3f288..4655f600f 100644 --- a/src/dstack/_internal/server/services/runs/__init__.py +++ b/src/dstack/_internal/server/services/runs/__init__.py @@ -987,9 +987,8 @@ def _get_job_submission_cost(job_submission: JobSubmission) -> float: async def process_terminating_run(session: AsyncSession, run_model: RunModel): """ - Used by both `process_runs` and `stop_run` to process a TERMINATING run. 
Stops the jobs gracefully and marks them as TERMINATING. - Jobs should be terminated by `process_terminating_jobs`. + Jobs then should be terminated by `process_terminating_jobs`. When all jobs are terminated, assigns a finished status to the run. Caller must acquire the lock on run. """ diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index ac2f88a5d..7bb5f0713 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -209,6 +209,7 @@ async def list_project_volume_models( select(VolumeModel) .where(*filters) .options(joinedload(VolumeModel.user)) + .options(joinedload(VolumeModel.project)) .options( joinedload(VolumeModel.attachments) .joinedload(VolumeAttachmentModel.instance) @@ -245,6 +246,7 @@ async def get_project_volume_model_by_name( select(VolumeModel) .where(*filters) .options(joinedload(VolumeModel.user)) + .options(joinedload(VolumeModel.project)) .options( joinedload(VolumeModel.attachments) .joinedload(VolumeAttachmentModel.instance) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py new file mode 100644 index 000000000..a52924a55 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -0,0 +1,1884 @@ +import asyncio +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Optional +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch + +import pytest +from freezegun import freeze_time +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from dstack._internal import settings +from dstack._internal.core.errors import SSHError +from dstack._internal.core.models.backends.base import BackendType +from 
dstack._internal.core.models.common import NetworkMode +from dstack._internal.core.models.configurations import ( + DevEnvironmentConfiguration, + ProbeConfig, + ServiceConfiguration, +) +from dstack._internal.core.models.gateways import GatewayStatus +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy +from dstack._internal.core.models.runs import ( + JobRuntimeData, + JobStatus, + JobTerminationReason, + RunStatus, +) +from dstack._internal.core.models.volumes import InstanceMountPoint, VolumeMountPoint, VolumeStatus +from dstack._internal.server import settings as server_settings +from dstack._internal.server.background.pipeline_tasks.jobs_running import ( + JobRunningFetcher, + JobRunningPipeline, + JobRunningPipelineItem, + JobRunningWorker, + _RunnerAvailability, + _SubmitJobToRunnerResult, +) +from dstack._internal.server.models import JobModel, ProbeModel +from dstack._internal.server.schemas.runner import ( + HealthcheckResponse, + JobInfoResponse, + JobStateEvent, + PortMapping, + PullResponse, + TaskStatus, +) +from dstack._internal.server.services.runner.client import RunnerClient, ShimClient +from dstack._internal.server.services.runner.ssh import SSHTunnel +from dstack._internal.server.services.volumes import volume_model_to_volume +from dstack._internal.server.testing.common import ( + create_backend, + create_export, + create_fleet, + create_gateway, + create_gateway_compute, + create_instance, + create_job, + create_job_metrics_point, + create_probe, + create_project, + create_repo, + create_run, + create_user, + create_volume, + get_job_provisioning_data, + get_job_runtime_data, + get_run_spec, + get_volume_configuration, + list_events, +) +from dstack._internal.utils.common import get_current_datetime + +pytestmark = pytest.mark.usefixtures("image_config_mock") + + +@dataclass +class _ProbeSetup: + success_streak: int + ready_after: int + + 
+@pytest.fixture +def fetcher() -> JobRunningFetcher: + return JobRunningFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=10), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + +@pytest.fixture +def worker() -> JobRunningWorker: + return JobRunningWorker(queue=Mock(), heartbeater=Mock()) + + +@pytest.fixture +def ssh_tunnel_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = MagicMock(spec_set=SSHTunnel) + monkeypatch.setattr("dstack._internal.server.services.runner.ssh.SSHTunnel", mock) + return mock + + +@pytest.fixture +def shim_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(spec_set=ShimClient) + mock.healthcheck.return_value = HealthcheckResponse(service="dstack-shim", version="latest") + monkeypatch.setattr( + "dstack._internal.server.services.runner.client.ShimClient", Mock(return_value=mock) + ) + return mock + + +@pytest.fixture +def runner_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: + mock = Mock(spec_set=RunnerClient) + mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-runner", version="0.0.1.dev2" + ) + monkeypatch.setattr( + "dstack._internal.server.services.runner.client.RunnerClient", Mock(return_value=mock) + ) + return mock + + +def _lock_job_foreign(job_model) -> None: + job_model.lock_expires_at = get_current_datetime() + timedelta(minutes=1) + job_model.lock_token = uuid.uuid4() + job_model.lock_owner = "OtherPipeline" + + +def _lock_job_expired_same_owner(job_model) -> None: + job_model.lock_expires_at = get_current_datetime() - timedelta(minutes=1) + job_model.lock_token = uuid.uuid4() + job_model.lock_owner = JobRunningPipeline.__name__ + + +def _lock_job(job_model) -> None: + job_model.lock_expires_at = get_current_datetime() + timedelta(seconds=30) + job_model.lock_token = uuid.uuid4() + job_model.lock_owner = JobRunningPipeline.__name__ + + +def _job_to_pipeline_item(job_model) -> JobRunningPipelineItem: + 
assert job_model.lock_token is not None + assert job_model.lock_expires_at is not None + return JobRunningPipelineItem( + __tablename__=job_model.__tablename__, + id=job_model.id, + lock_token=job_model.lock_token, + lock_expires_at=job_model.lock_expires_at, + prev_lock_expired=False, + status=job_model.status, + ) + + +async def _process_job( + session: AsyncSession, + worker: JobRunningWorker, + job_model, +) -> None: + _lock_job(job_model) + await session.commit() + await worker.process(_job_to_pipeline_item(job_model)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestJobRunningFetcher: + async def test_fetch_selects_eligible_jobs_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: JobRunningFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + provisioning = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + last_processed_at=stale - timedelta(seconds=4), + ) + pulling = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + last_processed_at=stale - timedelta(seconds=3), + ) + running = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + last_processed_at=stale - timedelta(seconds=2), + ) + expired_same_owner = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + last_processed_at=stale - timedelta(seconds=1), + ) + _lock_job_expired_same_owner(expired_same_owner) + + recent = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + last_processed_at=now, + ) + foreign_locked = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + last_processed_at=stale, + ) + 
_lock_job_foreign(foreign_locked) + finished = await create_job( + session=session, + run=run, + status=JobStatus.DONE, + last_processed_at=stale - timedelta(seconds=5), + ) + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert [item.id for item in items] == [ + provisioning.id, + pulling.id, + running.id, + expired_same_owner.id, + ] + assert [item.status for item in items] == [ + JobStatus.PROVISIONING, + JobStatus.PULLING, + JobStatus.RUNNING, + JobStatus.RUNNING, + ] + + for job in [ + provisioning, + pulling, + running, + expired_same_owner, + recent, + foreign_locked, + finished, + ]: + await session.refresh(job) + + fetched_jobs = [provisioning, pulling, running, expired_same_owner] + assert all(job.lock_owner == JobRunningPipeline.__name__ for job in fetched_jobs) + assert all(job.lock_expires_at is not None for job in fetched_jobs) + assert all(job.lock_token is not None for job in fetched_jobs) + assert len({job.lock_token for job in fetched_jobs}) == 1 + + assert recent.lock_owner is None + assert foreign_locked.lock_owner == "OtherPipeline" + assert finished.lock_owner is None + + async def test_fetch_excludes_jobs_from_terminating_runs( + self, test_db, session: AsyncSession, fetcher: JobRunningFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + active_run = await create_run(session=session, project=project, repo=repo, user=user) + terminating_run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="terminating-run", + status=RunStatus.TERMINATING, + ) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + active_job = await create_job( + session=session, + run=active_run, + status=JobStatus.RUNNING, + last_processed_at=stale - timedelta(seconds=1), + ) + terminating_run_job = await create_job( + session=session, + run=terminating_run, + 
status=JobStatus.RUNNING, + last_processed_at=stale - timedelta(seconds=2), + ) + + items = await fetcher.fetch(limit=10) + + assert [item.id for item in items] == [active_job.id] + + await session.refresh(active_job) + await session.refresh(terminating_run_job) + + assert active_job.lock_owner == JobRunningPipeline.__name__ + assert terminating_run_job.lock_owner is None + + async def test_fetch_returns_oldest_jobs_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: JobRunningFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + now = get_current_datetime() + + oldest = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + last_processed_at=now - timedelta(minutes=3), + ) + middle = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + last_processed_at=now - timedelta(minutes=2), + ) + newest = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + last_processed_at=now - timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == JobRunningPipeline.__name__ + assert middle.lock_owner == JobRunningPipeline.__name__ + assert newest.lock_owner is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestJobRunningWorker: + async def test_process_skips_when_lock_token_changes( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = 
await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=False), + instance=instance, + instance_assigned=True, + ) + _lock_job(job) + await session.commit() + + item = _job_to_pipeline_item(job) + new_lock_token = uuid.uuid4() + job.lock_token = new_lock_token + await session.commit() + + await worker.process(item) + await session.refresh(job) + + assert job.lock_token == new_lock_token + assert job.status == JobStatus.PROVISIONING + assert job.lock_owner == JobRunningPipeline.__name__ + + async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( + self, test_db, session: AsyncSession, worker: JobRunningWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=False), + instance=instance, + instance_assigned=True, + ) + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + 
new_callable=AsyncMock, + ) as get_job_code_mock, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.healthcheck.return_value = None + await _process_job(session, worker, job) + ssh_tunnel_cls.assert_called_once() + runner_client_mock.healthcheck.assert_called_once() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() + + await session.refresh(job) + assert job.status == JobStatus.PROVISIONING + assert job.lock_token is None + assert job.lock_owner is None + + async def test_runs_provisioning_job( + self, test_db, session: AsyncSession, worker: JobRunningWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=False), + job_runtime_data=get_job_runtime_data(), + instance=instance, + instance_assigned=True, + ) + before_processed_at = job.last_processed_at + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-runner", version="0.0.1.dev2" + ) + runner_client_mock.run_job.return_value = JobInfoResponse( + working_dir="/dstack/run", username="dstack" + ) + await _process_job(session, worker, job) + assert ssh_tunnel_cls.call_count == 2 + assert runner_client_mock.healthcheck.call_count == 2 + 
runner_client_mock.submit_job.assert_called_once() + runner_client_mock.upload_code.assert_called_once() + runner_client_mock.run_job.assert_called_once() + + await session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner is None + assert job.last_processed_at > before_processed_at + job_runtime_data = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) + assert job_runtime_data.working_dir == "/dstack/run" + assert job_runtime_data.username == "dstack" + + @pytest.mark.parametrize("privileged", [False, True]) + async def test_provisioning_shim_with_volumes( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + privileged: bool, + ): + project_ssh_pub_key = "__project_ssh_pub_key__" + project = await create_project(session=session, ssh_public_key=project_ssh_pub_key) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + configuration=get_volume_configuration( + name="my-vol", backend=BackendType.AWS, region="us-east-1" + ), + backend=BackendType.AWS, + region="us-east-1", + ) + run_spec = get_run_spec(run_name="test-run", repo_id=repo.name) + run_spec.configuration.privileged = privileged + run_spec.configuration.volumes = [ + VolumeMountPoint(name="my-vol", path="/volume"), + InstanceMountPoint(instance_path="/root/.cache", path="/cache"), + ] + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-run", + run_spec=run_spec, + ) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job_provisioning_data = get_job_provisioning_data(dockerized=True) + + with patch( + 
"dstack._internal.server.services.jobs.configurators.base.get_default_python_verison" + ) as py_version: + py_version.return_value = "3.13" + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=job_provisioning_data, + instance=instance, + instance_assigned=True, + ) + + await _process_job(session, worker, job) + + ssh_tunnel_mock.assert_called_once() + shim_client_mock.healthcheck.assert_called_once() + shim_client_mock.submit_task.assert_called_once_with( + task_id=job.id, + name="test-run-0-0", + registry_username="", + registry_password="", + image_name=( + f"dstackai/base:{settings.DSTACK_BASE_IMAGE_VERSION}-" + f"base-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ), + container_user="root", + privileged=privileged, + gpu=None, + cpu=None, + memory=None, + shm_size=None, + network_mode=NetworkMode.HOST, + volumes=[volume_model_to_volume(volume)], + volume_mounts=[VolumeMountPoint(name="my-vol", path="/volume")], + instance_mounts=[InstanceMountPoint(instance_path="/root/.cache", path="/cache")], + gpu_devices=[], + host_ssh_user="ubuntu", + host_ssh_keys=["user_ssh_key"], + container_ssh_keys=[project_ssh_pub_key, "user_ssh_key"], + instance_id=job_provisioning_data.instance_id, + ) + await session.refresh(job) + assert job.status == JobStatus.PULLING + + async def test_pulling_shim( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, 
+ submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = [ + PortMapping(container=10022, host=32771), + PortMapping(container=10999, host=32772), + ] + runner_client_mock.run_job.return_value = JobInfoResponse( + working_dir="/dstack/run", username="dstack" + ) + + await _process_job(session, worker, job) + + assert ssh_tunnel_mock.call_count == 3 + shim_client_mock.get_task.assert_called_once() + assert runner_client_mock.healthcheck.call_count == 2 + runner_client_mock.submit_job.assert_called_once() + runner_client_mock.upload_code.assert_called_once() + runner_client_mock.run_job.assert_called_once() + await session.refresh(job) + assert job.status == JobStatus.RUNNING + job_runtime_data = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) + assert job_runtime_data.ports == {10022: 32771, 10999: 32772} + assert job_runtime_data.working_dir == "/dstack/run" + assert job_runtime_data.username == "dstack" + + async def test_pulling_shim_port_mapping_not_ready( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + 
job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = None + + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + new_callable=AsyncMock, + ) as get_job_code_mock, + ): + await _process_job(session, worker, job) + ssh_tunnel_mock.assert_called_once() + shim_client_mock.get_task.assert_called_once() + runner_client_mock.healthcheck.assert_not_called() + runner_client_mock.submit_job.assert_not_called() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() + + await session.refresh(job) + assert job.status == JobStatus.PULLING + + async def test_pulling_shim_waiting_resets_disconnect_and_emits_reachable_event( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + disconnected_at=get_current_datetime() - timedelta(minutes=1), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = 
TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = None + + await _process_job(session, worker, job) + + ssh_tunnel_mock.assert_called_once() + shim_client_mock.get_task.assert_called_once() + runner_client_mock.healthcheck.assert_not_called() + await session.refresh(job) + events = await list_events(session) + assert job.status == JobStatus.PULLING + assert job.disconnected_at is None + assert len(events) == 1 + assert events[0].message == "Job became reachable" + + async def test_pulling_shim_runner_not_ready( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = [ + PortMapping(container=10022, host=32771), + PortMapping(container=10999, host=32772), + ] + runner_client_mock.healthcheck.return_value = None + + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + new_callable=AsyncMock, + ) as get_job_code_mock, + ): + await _process_job(session, worker, job) + assert 
ssh_tunnel_mock.call_count == 2 + shim_client_mock.get_task.assert_called_once() + runner_client_mock.healthcheck.assert_called_once() + runner_client_mock.submit_job.assert_not_called() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() + + await session.refresh(job) + assert job.status == JobStatus.PULLING + + async def test_pulling_shim_uses_runtime_port_mapping_for_runner_calls( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = [ + PortMapping(container=10022, host=32771), + PortMapping(container=10999, host=32772), + ] + expected_ports = {10022: 32771, 10999: 32772} + + def assert_runner_availability(_, __, job_runtime_data): + assert job_runtime_data is not None + assert job_runtime_data.ports == expected_ports + return _RunnerAvailability.AVAILABLE + + def assert_submit_job_to_runner(_, __, job_runtime_data, **kwargs): + assert job_runtime_data is not None + assert job_runtime_data.ports == expected_ports + return _SubmitJobToRunnerResult(success=True) + + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_runner_availability", + 
side_effect=assert_runner_availability, + ) as get_runner_availability_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._submit_job_to_runner", + side_effect=assert_submit_job_to_runner, + ) as submit_job_to_runner_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + return_value=[], + ), + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + new_callable=AsyncMock, + return_value=b"", + ), + ): + await _process_job(session, worker, job) + ssh_tunnel_mock.assert_called_once() + get_runner_availability_mock.assert_called_once() + submit_job_to_runner_mock.assert_called_once() + + await session.refresh(job) + assert job.status == JobStatus.PULLING + job_runtime_data = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) + assert job_runtime_data.ports == expected_ports + + async def test_pulling_shim_failed( + self, test_db, session: AsyncSession, worker: JobRunningWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.IDLE + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + ) + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.ssh.time.sleep"), + ): + from dstack._internal.core.errors import SSHError + + ssh_tunnel_cls.side_effect = SSHError + await _process_job(session, worker, job) + assert ssh_tunnel_cls.call_count == 3 + + await session.refresh(job) + events = await 
list_events(session) + assert job.disconnected_at is not None + assert job.status == JobStatus.PULLING + assert len(events) == 1 + assert events[0].message == "Job became unreachable" + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.ssh.time.sleep"), + freeze_time(job.disconnected_at + timedelta(minutes=5)), + ): + from dstack._internal.core.errors import SSHError + + ssh_tunnel_cls.side_effect = SSHError + await _process_job(session, worker, job) + assert ssh_tunnel_cls.call_count == 3 + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == JobTerminationReason.INSTANCE_UNREACHABLE + assert job.remove_at is None + + async def test_provisioning_shim_force_stop_if_already_running_api_v1( + self, + monkeypatch: pytest.MonkeyPatch, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec(run_name="test-run", repo_id=repo.name) + run_spec.configuration.image = "debian" + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-run", + run_spec=run_spec, + ) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + monkeypatch.setattr( + "dstack._internal.server.services.runner.ssh.SSHTunnel", Mock(return_value=MagicMock()) + ) + shim_client_mock = Mock() + monkeypatch.setattr( + "dstack._internal.server.services.runner.client.ShimClient", + Mock(return_value=shim_client_mock), + ) + 
shim_client_mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-shim", version="0.0.1.dev2" + ) + shim_client_mock.is_api_v2_supported.return_value = False + shim_client_mock.submit.return_value = False + + await _process_job(session, worker, job) + + shim_client_mock.healthcheck.assert_called_once() + shim_client_mock.submit.assert_called_once() + shim_client_mock.stop.assert_called_once_with(force=True) + await session.refresh(job) + assert job.status == JobStatus.PROVISIONING + + async def test_master_job_waits_for_workers( + self, test_db, session: AsyncSession, worker: JobRunningWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec(run_name="test-run", repo_id=repo.name) + run_spec.configuration.startup_order = StartupOrder.WORKERS_FIRST + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + ) + instance1 = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + instance2 = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job_provisioning_data = get_job_provisioning_data(dockerized=False) + master_job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=job_provisioning_data, + instance_assigned=True, + instance=instance1, + job_num=0, + last_processed_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + ) + worker_job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=job_provisioning_data, + instance_assigned=True, + instance=instance2, + job_num=1, + last_processed_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + ) + + await _process_job(session, worker, master_job) + 
await session.refresh(master_job) + assert master_job.status == JobStatus.PROVISIONING + + worker_job.status = JobStatus.RUNNING + master_job.last_processed_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel"), + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-runner", version="0.0.1.dev2" + ) + await _process_job(session, worker, master_job) + + await session.refresh(master_job) + assert master_job.status == JobStatus.RUNNING + + async def test_apply_skips_when_lock_token_changes_after_processing( + self, test_db, session: AsyncSession, worker: JobRunningWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=False), + job_runtime_data=get_job_runtime_data(), + instance=instance, + instance_assigned=True, + ) + _lock_job(job) + await session.commit() + original_lock_token = job.lock_token + replacement_lock_token = uuid.uuid4() + + async def invalidate_lock(*args, **kwargs): + job.lock_token = replacement_lock_token + await session.commit() + return b"" + + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_runner_availability", + return_value=_RunnerAvailability.AVAILABLE, + ), + patch( + 
"dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + return_value=[], + ), + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + new_callable=AsyncMock, + side_effect=invalidate_lock, + ), + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._submit_job_to_runner", + return_value=_SubmitJobToRunnerResult(success=True), + ), + ): + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.status == JobStatus.PROVISIONING + assert job.lock_token == replacement_lock_token + assert job.lock_token != original_lock_token + + async def test_updates_running_job( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + tmp_path: Path, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(dockerized=False), + instance=instance, + instance_assigned=True, + ) + last_processed_at = job.last_processed_at + + with ( + patch.object(server_settings, "SERVER_DIR_PATH", tmp_path), + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.pull.return_value = PullResponse( + job_states=[JobStateEvent(timestamp=1, state=JobStatus.RUNNING)], + job_logs=[], + runner_logs=[], + last_updated=1, + ) + await _process_job(session, worker, job) + ssh_tunnel_cls.assert_called_once() + + await 
session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.runner_timestamp == 1 + + job.last_processed_at = last_processed_at + await session.commit() + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.pull.return_value = PullResponse( + job_states=[JobStateEvent(timestamp=1, state=JobStatus.DONE, exit_status=0)], + job_logs=[], + runner_logs=[], + last_updated=2, + ) + await _process_job(session, worker, job) + ssh_tunnel_cls.assert_called_once() + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == JobTerminationReason.DONE_BY_RUNNER + assert job.exit_status == 0 + assert job.runner_timestamp == 2 + + async def test_running_job_disconnect_retries_then_terminates( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(dockerized=False), + instance=instance, + instance_assigned=True, + ) + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.ssh.time.sleep"), + ): + ssh_tunnel_cls.side_effect = SSHError + await _process_job(session, worker, job) + assert ssh_tunnel_cls.call_count == 3 + + await session.refresh(job) + events = await list_events(session) + assert job.status == 
JobStatus.RUNNING + assert job.disconnected_at is not None + assert len(events) == 1 + assert events[0].message == "Job became unreachable" + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.ssh.time.sleep"), + freeze_time(job.disconnected_at + timedelta(minutes=5)), + ): + ssh_tunnel_cls.side_effect = SSHError + await _process_job(session, worker, job) + assert ssh_tunnel_cls.call_count == 3 + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == JobTerminationReason.INSTANCE_UNREACHABLE + + @pytest.mark.parametrize( + ( + "inactivity_duration", + "no_connections_secs", + "expected_status", + "expected_termination_reason", + "expected_inactivity_secs", + ), + [ + pytest.param( + "1h", + 60 * 60 - 1, + JobStatus.RUNNING, + None, + 60 * 60 - 1, + id="duration-not-exceeded", + ), + pytest.param( + "1h", + 60 * 60, + JobStatus.TERMINATING, + JobTerminationReason.INACTIVITY_DURATION_EXCEEDED, + 60 * 60, + id="duration-exceeded-exactly", + ), + pytest.param( + "1h", + 60 * 60 + 1, + JobStatus.TERMINATING, + JobTerminationReason.INACTIVITY_DURATION_EXCEEDED, + 60 * 60 + 1, + id="duration-exceeded", + ), + pytest.param("off", 60 * 60, JobStatus.RUNNING, None, None, id="duration-off"), + pytest.param(False, 60 * 60, JobStatus.RUNNING, None, None, id="duration-false"), + pytest.param(None, 60 * 60, JobStatus.RUNNING, None, None, id="duration-none"), + pytest.param( + "1h", + None, + JobStatus.TERMINATING, + JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY, + None, + id="legacy-runner", + ), + pytest.param( + None, + None, + JobStatus.RUNNING, + None, + None, + id="legacy-runner-without-duration", + ), + ], + ) + async def test_inactivity_duration( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + inactivity_duration, + no_connections_secs: Optional[int], + expected_status: JobStatus, + 
expected_termination_reason: Optional[JobTerminationReason], + expected_inactivity_secs: Optional[int], + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.RUNNING, + run_name="test-run", + run_spec=get_run_spec( + run_name="test-run", + repo_id=repo.name, + configuration=DevEnvironmentConfiguration( + name="test-run", + ide="vscode", + inactivity_duration=inactivity_duration, + ), + ), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(), + instance=instance, + instance_assigned=True, + ) + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.pull.return_value = PullResponse( + job_states=[], + job_logs=[], + runner_logs=[], + last_updated=0, + no_connections_secs=no_connections_secs, + ) + await _process_job(session, worker, job) + ssh_tunnel_cls.assert_called_once() + runner_client_mock.pull.assert_called_once() + + await session.refresh(job) + assert job.status == expected_status + assert job.termination_reason == expected_termination_reason + assert job.inactivity_secs == expected_inactivity_secs + + @pytest.mark.parametrize( + ["samples", "expected_status"], + [ + pytest.param( + [ + (datetime(2023, 1, 1, 12, 25, 20, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 25, 30, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 29, 50, tzinfo=timezone.utc), 40), + ], + JobStatus.RUNNING, + id="not-enough-points", + ), + 
pytest.param( + [ + (datetime(2023, 1, 1, 12, 20, 10, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 20, 20, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 29, 50, tzinfo=timezone.utc), 80), + ], + JobStatus.RUNNING, + id="any-above-min", + ), + pytest.param( + [ + (datetime(2023, 1, 1, 12, 10, 10, tzinfo=timezone.utc), 80), + (datetime(2023, 1, 1, 12, 10, 20, tzinfo=timezone.utc), 80), + (datetime(2023, 1, 1, 12, 20, 10, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 20, 20, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 29, 50, tzinfo=timezone.utc), 40), + ], + JobStatus.TERMINATING, + id="all-below-min", + ), + ], + ) + @freeze_time(datetime(2023, 1, 1, 12, 30, tzinfo=timezone.utc)) + async def test_gpu_utilization( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + samples: list[tuple[datetime, int]], + expected_status: JobStatus, + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.RUNNING, + run_name="test-run", + run_spec=get_run_spec( + run_name="test-run", + repo_id=repo.name, + configuration=DevEnvironmentConfiguration( + name="test-run", + ide="vscode", + utilization_policy=UtilizationPolicy( + min_gpu_utilization=80, + time_window=600, + ), + ), + ), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(), + instance=instance, + instance_assigned=True, + last_processed_at=datetime(2023, 1, 1, 11, 30, tzinfo=timezone.utc), + ) + for timestamp, gpu_util in samples: + await create_job_metrics_point( + session=session, + job_model=job, + timestamp=timestamp, + 
gpus_memory_usage_bytes=[1024, 1024], + gpus_util_percent=[gpu_util, 100], + ) + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.pull.return_value = PullResponse( + job_states=[], + job_logs=[], + runner_logs=[], + last_updated=0, + no_connections_secs=0, + ) + await _process_job(session, worker, job) + ssh_tunnel_cls.assert_called_once() + runner_client_mock.pull.assert_called_once() + + await session.refresh(job) + assert job.status == expected_status + if expected_status == JobStatus.TERMINATING: + assert ( + job.termination_reason == JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY + ) + assert job.termination_reason_message == ( + "The job GPU utilization below 80% for 600 seconds" + ) + else: + assert job.termination_reason is None + assert job.termination_reason_message is None + + @pytest.mark.parametrize("probe_count", [1, 2]) + async def test_creates_probe_models_and_not_registers_service_replica( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + probe_count: int, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, + image="ubuntu", + probes=[ProbeConfig(type="http", url=f"/{i}") for i in range(probe_count)], + ), + ), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + 
job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + + assert len(job.probes) == 0 + await _process_job(session, worker, job) + + await session.refresh(job) + job = ( + await session.execute( + select(JobModel) + .where(JobModel.id == job.id) + .options(selectinload(JobModel.probes)) + ) + ).scalar_one() + assert job.status == JobStatus.RUNNING + assert [probe.probe_num for probe in job.probes] == list(range(probe_count)) + assert not job.registered + + async def test_registers_service_replica_immediately_if_no_probes( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration(port=80, image="ubuntu"), + ), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + + await _process_job(session, worker, job) + + await session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.registered + events = await list_events(session) + assert {event.message for event in events} == { + "Job status changed PULLING -> RUNNING", + "Service replica registered to receive requests", + } + + @pytest.mark.parametrize( + ("probes", "expect_to_register"), + [ + 
([_ProbeSetup(success_streak=0, ready_after=1)], False), + ([_ProbeSetup(success_streak=1, ready_after=1)], True), + ( + [ + _ProbeSetup(success_streak=1, ready_after=1), + _ProbeSetup(success_streak=1, ready_after=2), + ], + False, + ), + ( + [ + _ProbeSetup(success_streak=1, ready_after=1), + _ProbeSetup(success_streak=3, ready_after=2), + ], + True, + ), + ], + ) + async def test_registers_service_replica_only_after_probes_pass( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + runner_client_mock: Mock, + probes: list[_ProbeSetup], + expect_to_register: bool, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, + image="ubuntu", + probes=[ + ProbeConfig(type="http", url=f"/{i}", ready_after=probe.ready_after) + for i, probe in enumerate(probes) + ], + ), + ), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + registered=False, + ) + for i, probe in enumerate(probes): + await create_probe( + session=session, + job=job, + probe_num=i, + success_streak=probe.success_streak, + ) + runner_client_mock.pull.return_value = PullResponse( + job_states=[], + job_logs=[], + runner_logs=[], + last_updated=0, + ) + + await _process_job(session, worker, job) + + await session.refresh(job) + events = await list_events(session) + if expect_to_register: + assert job.registered + assert len(events) == 1 + assert events[0].message == "Service replica registered to 
receive requests" + else: + assert not job.registered + assert not events + + async def test_registers_service_replica_in_gateway( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + mock_gateway_connection: AsyncMock, + ): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + ) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + name="test-gateway", + wildcard_domain="example.com", + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, image="ubuntu", gateway="test-gateway" + ), + ), + gateway=gateway, + ) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + fleet=fleet, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + + await _process_job(session, worker, job) + + await session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.registered + events = await list_events(session) + assert {event.message for event in events} == { + "Job status changed PULLING -> RUNNING", + "Service replica registered to receive requests", + } + 
mock_gateway_connection.return_value.client.return_value.__aenter__.return_value.register_replica.assert_called_once_with( + run=ANY, + job_spec=ANY, + job_submission=ANY, + instance_project_ssh_private_key=None, + ssh_head_proxy=None, + ssh_head_proxy_private_key=None, + ) + + async def test_registers_service_replica_in_gateway_when_running_on_imported_instance( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + mock_gateway_connection: AsyncMock, + ): + user = await create_user(session=session) + exporter_project = await create_project( + session=session, name="exporter", owner=user, ssh_private_key="exporter-private-key" + ) + importer_project = await create_project(session=session, name="importer", owner=user) + fleet = await create_fleet(session=session, project=exporter_project) + instance = await create_instance( + session=session, + project=exporter_project, + status=InstanceStatus.BUSY, + fleet=fleet, + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + repo = await create_repo(session=session, project_id=importer_project.id) + backend = await create_backend(session=session, project_id=importer_project.id) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + ) + gateway = await create_gateway( + session=session, + project_id=importer_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + name="test-gateway", + wildcard_domain="example.com", + ) + run = await create_run( + session=session, + project=importer_project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, image="ubuntu", gateway="test-gateway" + ), + ), + gateway=gateway, + ) + job = await create_job( + 
session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + + await _process_job(session, worker, job) + + await session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.registered + events = await list_events(session) + assert {event.message for event in events} == { + "Job status changed PULLING -> RUNNING", + "Service replica registered to receive requests", + } + mock_gateway_connection.return_value.client.return_value.__aenter__.return_value.register_replica.assert_called_once_with( + run=ANY, + job_spec=ANY, + job_submission=ANY, + instance_project_ssh_private_key="exporter-private-key", + ssh_head_proxy=None, + ssh_head_proxy_private_key=None, + ) + + async def test_apply_skips_probe_insert_when_lock_token_changes_after_processing( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, + image="ubuntu", + probes=[ProbeConfig(type="http", url="/health")], + ), + ), + ) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + _lock_job(job) + await session.commit() + replacement_lock_token = uuid.uuid4() + shim_client_mock.get_task.return_value.status = 
TaskStatus.RUNNING + + async def invalidate_lock(*args, **kwargs): + job.lock_token = replacement_lock_token + await session.commit() + return b"" + + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + return_value=[], + ), + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + new_callable=AsyncMock, + side_effect=invalidate_lock, + ), + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._submit_job_to_runner", + return_value=_SubmitJobToRunnerResult( + success=True, + set_running_status=True, + ), + ), + ): + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.status == JobStatus.PULLING + assert job.lock_token == replacement_lock_token + probes = ( + (await session.execute(select(ProbeModel).where(ProbeModel.job_id == job.id))) + .scalars() + .all() + ) + assert probes == [] diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py b/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py index 63dfaaa45..ad1cbb7cf 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py @@ -9,6 +9,7 @@ from dstack._internal.core.errors import BackendError, BackendNotAvailable from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.volumes import VolumeProvisioningData, VolumeStatus +from dstack._internal.server.background.pipeline_tasks import volumes as volumes_pipeline from dstack._internal.server.background.pipeline_tasks.volumes import ( VolumeFetcher, VolumePipeline, @@ -349,6 +350,72 @@ async def test_marks_volume_failed_if_backend_returns_error( assert len(events) == 1 assert events[0].message == "Volume status changed SUBMITTED -> FAILED (Some error)" + async def 
test_skips_processing_if_lock_token_changed_before_refetch( + self, test_db, session: AsyncSession, worker: VolumeWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + ) + volume.lock_token = uuid.uuid4() + volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + item = _volume_to_pipeline_item(volume) + + volume.lock_token = uuid.uuid4() + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.volumes._process_submitted_volume" + ) as process_volume_mock: + await worker.process(item) + process_volume_mock.assert_not_awaited() + + await session.refresh(volume) + assert volume.status == VolumeStatus.SUBMITTED + events = await list_events(session) + assert len(events) == 0 + + async def test_skips_apply_if_lock_token_changed_after_processing( + self, test_db, session: AsyncSession, worker: VolumeWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + ) + volume.lock_token = uuid.uuid4() + volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + async def _change_lock_token_and_return_result(_volume_model: VolumeModel): + volume.lock_token = uuid.uuid4() + await session.commit() + return volumes_pipeline._ProcessResult( + update_map={ + "status": VolumeStatus.ACTIVE, + } + ) + + with patch( + "dstack._internal.server.background.pipeline_tasks.volumes._process_submitted_volume", + side_effect=_change_lock_token_and_return_result, + ) as process_volume_mock: + await worker.process(_volume_to_pipeline_item(volume)) + process_volume_mock.assert_awaited_once() + + await session.refresh(volume) + assert volume.status == 
VolumeStatus.SUBMITTED + events = await list_events(session) + assert len(events) == 0 + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index 239cc265b..66b38f331 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -118,7 +118,7 @@ class TestProcessRunningJobs: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( - self, test_db, session: AsyncSession + self, test_db, session: AsyncSession, ssh_tunnel_mock: Mock, runner_client_mock: Mock ): project = await create_project(session=session) user = await create_user(session=session) @@ -147,37 +147,20 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( instance=instance, instance_assigned=True, ) - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - patch( - "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives", - new_callable=AsyncMock, - ) as get_job_file_archives_mock, - patch( - "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code", - new_callable=AsyncMock, - ) as get_job_code_mock, - patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock, - ): + runner_client_mock.healthcheck.return_value = None + with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: datetime_mock.return_value = datetime(2023, 1, 2, 5, 12, 30, 10, tzinfo=timezone.utc) - runner_client_mock = RunnerClientMock.return_value - 
runner_client_mock.healthcheck = Mock() - runner_client_mock.healthcheck.return_value = None await process_running_jobs() - SSHTunnelMock.assert_called_once() - runner_client_mock.healthcheck.assert_called_once() - get_job_file_archives_mock.assert_not_awaited() - get_job_code_mock.assert_not_awaited() + ssh_tunnel_mock.assert_called() + runner_client_mock.healthcheck.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PROVISIONING @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_runs_provisioning_job(self, test_db, session: AsyncSession): + async def test_runs_provisioning_job( + self, test_db, session: AsyncSession, ssh_tunnel_mock: Mock, runner_client_mock: Mock + ): project = await create_project(session=session) user = await create_user(session=session) repo = await create_repo( @@ -205,27 +188,16 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession): instance=instance, instance_assigned=True, ) - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.healthcheck.return_value = HealthcheckResponse( - service="dstack-runner", version="0.0.1.dev2" - ) - runner_client_mock.run_job.return_value = JobInfoResponse( - working_dir="/dstack/run", username="dstack" - ) - await process_running_jobs() - assert SSHTunnelMock.call_count == 2 - assert runner_client_mock.healthcheck.call_count == 2 - runner_client_mock.submit_job.assert_called_once() - runner_client_mock.upload_code.assert_called_once() - runner_client_mock.run_job.assert_called_once() + runner_client_mock.run_job.return_value = JobInfoResponse( + working_dir="/dstack/run", username="dstack" + ) + await process_running_jobs() + ssh_tunnel_mock.assert_called() + 
assert runner_client_mock.healthcheck.call_count == 2 + runner_client_mock.submit_job.assert_called_once() + runner_client_mock.upload_code.assert_called_once() + runner_client_mock.run_job.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.RUNNING jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) assert jrd.working_dir == "/dstack/run" @@ -233,7 +205,14 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession): @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_updates_running_job(self, test_db, session: AsyncSession, tmp_path: Path): + async def test_running_job_updates_runner_timestamp( + self, + test_db, + session: AsyncSession, + tmp_path: Path, + ssh_tunnel_mock: Mock, + runner_client_mock: Mock, + ): project = await create_project(session=session) user = await create_user(session=session) repo = await create_repo( @@ -260,46 +239,63 @@ async def test_updates_running_job(self, test_db, session: AsyncSession, tmp_pat instance=instance, instance_assigned=True, ) - last_processed_at = job.last_processed_at - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - patch.object(server_settings, "SERVER_DIR_PATH", tmp_path), - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.pull.return_value = PullResponse( - job_states=[JobStateEvent(timestamp=1, state=JobStatus.RUNNING)], - job_logs=[], - runner_logs=[], - last_updated=1, - ) + runner_client_mock.pull.return_value = PullResponse( + job_states=[JobStateEvent(timestamp=1, state=JobStatus.RUNNING)], + job_logs=[], + runner_logs=[], + last_updated=1, + ) + with patch.object(server_settings, "SERVER_DIR_PATH", tmp_path): await process_running_jobs() - SSHTunnelMock.assert_called_once() + 
ssh_tunnel_mock.assert_called() await session.refresh(job) - assert job is not None assert job.status == JobStatus.RUNNING assert job.runner_timestamp == 1 - job.last_processed_at = last_processed_at - await session.commit() - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.pull.return_value = PullResponse( - job_states=[JobStateEvent(timestamp=1, state=JobStatus.DONE, exit_status=0)], - job_logs=[], - runner_logs=[], - last_updated=2, - ) - await process_running_jobs() - SSHTunnelMock.assert_called_once() + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_running_job_terminates_when_done_by_runner( + self, + test_db, + session: AsyncSession, + ssh_tunnel_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job_provisioning_data = get_job_provisioning_data(dockerized=False) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=job_provisioning_data, + instance=instance, + instance_assigned=True, + ) + runner_client_mock.pull.return_value = PullResponse( + job_states=[JobStateEvent(timestamp=1, state=JobStatus.DONE, exit_status=0)], + job_logs=[], + runner_logs=[], + last_updated=2, + ) + await process_running_jobs() + ssh_tunnel_mock.assert_called() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATING assert job.termination_reason 
== JobTerminationReason.DONE_BY_RUNNER assert job.exit_status == 0 @@ -395,7 +391,6 @@ async def test_provisioning_shim_with_volumes( instance_id=job_provisioning_data.instance_id, ) await session.refresh(job) - assert job is not None assert job.status == JobStatus.PULLING @pytest.mark.asyncio @@ -442,14 +437,13 @@ async def test_pulling_shim( await process_running_jobs() - assert ssh_tunnel_mock.call_count == 3 + ssh_tunnel_mock.assert_called() shim_client_mock.get_task.assert_called_once() assert runner_client_mock.healthcheck.call_count == 2 runner_client_mock.submit_job.assert_called_once() runner_client_mock.upload_code.assert_called_once() runner_client_mock.run_job.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.RUNNING jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) assert jrd.ports == { @@ -515,7 +509,6 @@ async def test_pulling_shim_port_mapping_not_ready( get_job_file_archives_mock.assert_not_awaited() get_job_code_mock.assert_not_awaited() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PULLING @pytest.mark.asyncio @@ -578,7 +571,6 @@ async def test_pulling_shim_runner_not_ready( get_job_code_mock.assert_not_awaited() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PULLING @pytest.mark.asyncio @@ -661,7 +653,6 @@ def assert_submit_job_to_runner(_, __, job_runtime_data, **kwargs): submit_job_to_runner_mock.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PULLING jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) assert jrd.ports == expected_ports @@ -700,10 +691,9 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession): ): SSHTunnelMock.side_effect = SSHError await process_running_jobs() - assert SSHTunnelMock.call_count == 3 + SSHTunnelMock.assert_called() await session.refresh(job) events = await list_events(session) - 
assert job is not None assert job.disconnected_at is not None assert job.status == JobStatus.PULLING assert len(events) == 1 @@ -715,7 +705,7 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession): ): SSHTunnelMock.side_effect = SSHError await process_running_jobs() - assert SSHTunnelMock.call_count == 3 + SSHTunnelMock.assert_called() await session.refresh(job) assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.INSTANCE_UNREACHABLE @@ -838,6 +828,8 @@ async def test_inactivity_duration( self, test_db, session: AsyncSession, + ssh_tunnel_mock: Mock, + runner_client_mock: Mock, inactivity_duration, no_connections_secs: Optional[int], expected_status: JobStatus, @@ -880,23 +872,16 @@ async def test_inactivity_duration( instance=instance, instance_assigned=True, ) - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.pull.return_value = PullResponse( - job_states=[], - job_logs=[], - runner_logs=[], - last_updated=0, - no_connections_secs=no_connections_secs, - ) - await process_running_jobs() - SSHTunnelMock.assert_called_once() - runner_client_mock.pull.assert_called_once() + runner_client_mock.pull.return_value = PullResponse( + job_states=[], + job_logs=[], + runner_logs=[], + last_updated=0, + no_connections_secs=no_connections_secs, + ) + await process_running_jobs() + ssh_tunnel_mock.assert_called() + runner_client_mock.pull.assert_called_once() await session.refresh(job) assert job.status == expected_status assert job.termination_reason == expected_termination_reason @@ -943,6 +928,8 @@ async def test_gpu_utilization( self, test_db, session: AsyncSession, + ssh_tunnel_mock: Mock, + runner_client_mock: Mock, samples: list[tuple[datetime, int]], expected_status: JobStatus, ) -> None: 
@@ -995,23 +982,16 @@ async def test_gpu_utilization( gpus_memory_usage_bytes=[1024, 1024], gpus_util_percent=[gpu_util, 100], ) - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.pull.return_value = PullResponse( - job_states=[], - job_logs=[], - runner_logs=[], - last_updated=0, - no_connections_secs=0, - ) - await process_running_jobs() - SSHTunnelMock.assert_called_once() - runner_client_mock.pull.assert_called_once() + runner_client_mock.pull.return_value = PullResponse( + job_states=[], + job_logs=[], + runner_logs=[], + last_updated=0, + no_connections_secs=0, + ) + await process_running_jobs() + ssh_tunnel_mock.assert_called() + runner_client_mock.pull.assert_called_once() await session.refresh(job) assert job.status == expected_status if expected_status == JobStatus.TERMINATING: @@ -1025,7 +1005,9 @@ async def test_gpu_utilization( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_master_job_waits_for_workers(self, test_db, session: AsyncSession): + async def test_master_job_waits_for_workers( + self, test_db, session: AsyncSession, ssh_tunnel_mock: Mock, runner_client_mock: Mock + ): project = await create_project(session=session) user = await create_user(session=session) repo = await create_repo( @@ -1075,6 +1057,9 @@ async def test_master_job_waits_for_workers(self, test_db, session: AsyncSession job_num=1, last_processed_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), ) + runner_client_mock.run_job.return_value = JobInfoResponse( + working_dir="/dstack/run", username="dstack" + ) await process_running_jobs() await session.refresh(master_job) assert master_job.status == JobStatus.PROVISIONING @@ -1082,17 +1067,7 @@ async def test_master_job_waits_for_workers(self, test_db, 
session: AsyncSession # To guarantee master_job is processed next master_job.last_processed_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) await session.commit() - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel"), - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.healthcheck.return_value = HealthcheckResponse( - service="dstack-runner", version="0.0.1.dev2" - ) - await process_running_jobs() + await process_running_jobs() await session.refresh(master_job) assert master_job.status == JobStatus.RUNNING diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py index b0481ad4e..60780d06b 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py @@ -87,7 +87,6 @@ async def test_fails_job_when_no_backends(self, test_db, session: AsyncSession): ) await process_submitted_jobs() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY @@ -145,7 +144,6 @@ async def test_provisions_job( backend_mock.compute.return_value.run_job.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PROVISIONING @pytest.mark.asyncio @@ -194,7 +192,6 @@ async def test_fails_job_when_privileged_true_and_no_offers_with_create_instance backend_mock.compute.return_value.run_job.assert_not_called() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY @@ -244,7 +241,6 @@ async def 
test_fails_job_when_instance_mounts_and_no_offers_with_create_instance backend_mock.compute.return_value.run_job.assert_not_called() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY @@ -291,7 +287,6 @@ async def test_provisions_job_with_optional_instance_volume_not_attached( await process_submitted_jobs() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PROVISIONING @pytest.mark.asyncio @@ -328,13 +323,12 @@ async def test_fails_job_when_no_capacity(self, test_db, session: AsyncSession): await process_submitted_jobs() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_assignes_job_to_instance(self, test_db, session: AsyncSession): + async def test_assigns_job_to_instance(self, test_db, session: AsyncSession): project = await create_project(session) user = await create_user(session) repo = await create_repo( diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py index d2b4d2d31..da58e9708 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import timedelta, timezone from unittest.mock import Mock, patch import pytest @@ -57,7 +57,6 @@ async def test_terminates_job(self, session: AsyncSession): run=run, status=JobStatus.TERMINATING, termination_reason=JobTerminationReason.TERMINATED_BY_USER, - submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, 
tzinfo=timezone.utc), job_provisioning_data=job_provisioning_data, instance=instance, ) @@ -70,7 +69,6 @@ async def test_terminates_job(self, session: AsyncSession): SSHTunnelMock.assert_called_once() shim_client_mock.healthcheck.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATED async def test_detaches_job_volumes(self, session: AsyncSession): @@ -103,7 +101,6 @@ async def test_detaches_job_volumes(self, session: AsyncSession): run=run, status=JobStatus.TERMINATING, termination_reason=JobTerminationReason.TERMINATED_BY_USER, - submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), job_provisioning_data=job_provisioning_data, instance=instance, ) @@ -149,7 +146,6 @@ async def test_force_detaches_job_volumes(self, session: AsyncSession): run=run, status=JobStatus.TERMINATING, termination_reason=JobTerminationReason.TERMINATED_BY_USER, - submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), job_provisioning_data=job_provisioning_data, instance=instance, ) @@ -304,7 +300,6 @@ async def test_detaches_job_volumes_on_shared_instance(self, session: AsyncSessi run=run, status=JobStatus.TERMINATING, termination_reason=JobTerminationReason.TERMINATED_BY_USER, - submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), job_provisioning_data=job_provisioning_data, job_runtime_data=get_job_runtime_data(volume_names=["vol-1"]), instance=instance,