From d4585ff90d4210afdbc23b6ea8e70ea0652010e3 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 11 Mar 2026 16:01:09 +0500 Subject: [PATCH 01/21] Add running jobs pipeline fetcher scaffold --- .../background/pipeline_tasks/jobs_running.py | 176 ++++++++++++++ .../pipeline_tasks/test_running_jobs.py | 224 ++++++++++++++++++ 2 files changed, 400 insertions(+) create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py new file mode 100644 index 000000000..2e5440783 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -0,0 +1,176 @@ +import asyncio +import uuid +from dataclasses import dataclass +from datetime import timedelta +from typing import Sequence + +from sqlalchemy import or_, select +from sqlalchemy.orm import load_only + +from dstack._internal.core.models.runs import JobStatus, RunStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + Worker, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class JobRunningPipelineItem(PipelineItem): + status: JobStatus + + +class JobRunningPipeline(Pipeline[JobRunningPipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = 
timedelta(seconds=10), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[JobRunningPipelineItem]( + model_type=JobModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = JobRunningFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + JobRunningWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return JobModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[JobRunningPipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[JobRunningPipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["JobRunningWorker"]: + return self.__workers + + +class JobRunningFetcher(Fetcher[JobRunningPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[JobRunningPipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[JobRunningPipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.JobRunningFetcher.fetch") + async def 
fetch(self, limit: int) -> list[JobRunningPipelineItem]: + job_lock, _ = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__) + async with job_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(JobModel) + .join(JobModel.run) + .where( + JobModel.status.in_( + [JobStatus.PROVISIONING, JobStatus.PULLING, JobStatus.RUNNING] + ), + RunModel.status.not_in([RunStatus.TERMINATING]), + JobModel.last_processed_at < now - self._min_processing_interval, + or_( + JobModel.lock_expires_at.is_(None), + JobModel.lock_expires_at < now, + ), + or_( + JobModel.lock_owner.is_(None), + JobModel.lock_owner == JobRunningPipeline.__name__, + ), + ) + .order_by(JobModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True, of=JobModel) + .options( + load_only( + JobModel.id, + JobModel.lock_token, + JobModel.lock_expires_at, + JobModel.status, + ) + ) + ) + job_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for job_model in job_models: + prev_lock_expired = job_model.lock_expires_at is not None + job_model.lock_expires_at = lock_expires_at + job_model.lock_token = lock_token + job_model.lock_owner = JobRunningPipeline.__name__ + items.append( + JobRunningPipelineItem( + __tablename__=JobModel.__tablename__, + id=job_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + status=job_model.status, + ) + ) + await session.commit() + return items + + +class JobRunningWorker(Worker[JobRunningPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[JobRunningPipelineItem], + heartbeater: Heartbeater[JobRunningPipelineItem], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.JobRunningWorker.process") + async def process(self, item: 
JobRunningPipelineItem): + raise NotImplementedError("JobRunningWorker.process() is implemented in a later step") diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py new file mode 100644 index 000000000..ea362f9c3 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -0,0 +1,224 @@ +import asyncio +import uuid +from datetime import timedelta +from unittest.mock import Mock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.runs import JobStatus, RunStatus +from dstack._internal.server.background.pipeline_tasks.jobs_running import ( + JobRunningFetcher, + JobRunningPipeline, +) +from dstack._internal.server.testing.common import ( + create_job, + create_project, + create_repo, + create_run, + create_user, +) +from dstack._internal.utils.common import get_current_datetime + +pytestmark = pytest.mark.usefixtures("image_config_mock") + + +@pytest.fixture +def fetcher() -> JobRunningFetcher: + return JobRunningFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=10), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + +def _lock_job_foreign(job_model): + job_model.lock_expires_at = get_current_datetime() + timedelta(minutes=1) + job_model.lock_token = uuid.uuid4() + job_model.lock_owner = "OtherPipeline" + + +def _lock_job_expired_same_owner(job_model): + job_model.lock_expires_at = get_current_datetime() - timedelta(minutes=1) + job_model.lock_token = uuid.uuid4() + job_model.lock_owner = JobRunningPipeline.__name__ + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestJobRunningFetcher: + async def test_fetch_selects_eligible_jobs_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: JobRunningFetcher + ): + project = await 
create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + provisioning = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + last_processed_at=stale - timedelta(seconds=4), + ) + pulling = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + last_processed_at=stale - timedelta(seconds=3), + ) + running = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + last_processed_at=stale - timedelta(seconds=2), + ) + expired_same_owner = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + last_processed_at=stale - timedelta(seconds=1), + ) + _lock_job_expired_same_owner(expired_same_owner) + + recent = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + last_processed_at=now, + ) + foreign_locked = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + last_processed_at=stale, + ) + _lock_job_foreign(foreign_locked) + finished = await create_job( + session=session, + run=run, + status=JobStatus.DONE, + last_processed_at=stale - timedelta(seconds=5), + ) + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert [item.id for item in items] == [ + provisioning.id, + pulling.id, + running.id, + expired_same_owner.id, + ] + assert [item.status for item in items] == [ + JobStatus.PROVISIONING, + JobStatus.PULLING, + JobStatus.RUNNING, + JobStatus.RUNNING, + ] + + for job in [ + provisioning, + pulling, + running, + expired_same_owner, + recent, + foreign_locked, + finished, + ]: + await session.refresh(job) + + fetched_jobs = [provisioning, pulling, running, expired_same_owner] + assert all(job.lock_owner == JobRunningPipeline.__name__ for job in fetched_jobs) + assert 
all(job.lock_expires_at is not None for job in fetched_jobs) + assert all(job.lock_token is not None for job in fetched_jobs) + assert len({job.lock_token for job in fetched_jobs}) == 1 + + assert recent.lock_owner is None + assert foreign_locked.lock_owner == "OtherPipeline" + assert finished.lock_owner is None + + async def test_fetch_excludes_jobs_from_terminating_runs( + self, test_db, session: AsyncSession, fetcher: JobRunningFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + active_run = await create_run(session=session, project=project, repo=repo, user=user) + terminating_run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="terminating-run", + status=RunStatus.TERMINATING, + ) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + active_job = await create_job( + session=session, + run=active_run, + status=JobStatus.RUNNING, + last_processed_at=stale - timedelta(seconds=1), + ) + terminating_run_job = await create_job( + session=session, + run=terminating_run, + status=JobStatus.RUNNING, + last_processed_at=stale - timedelta(seconds=2), + ) + + items = await fetcher.fetch(limit=10) + + assert [item.id for item in items] == [active_job.id] + + await session.refresh(active_job) + await session.refresh(terminating_run_job) + + assert active_job.lock_owner == JobRunningPipeline.__name__ + assert terminating_run_job.lock_owner is None + + async def test_fetch_returns_oldest_jobs_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: JobRunningFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + now = get_current_datetime() + + oldest = await create_job( + 
session=session, + run=run, + status=JobStatus.PROVISIONING, + last_processed_at=now - timedelta(minutes=3), + ) + middle = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + last_processed_at=now - timedelta(minutes=2), + ) + newest = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + last_processed_at=now - timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == JobRunningPipeline.__name__ + assert middle.lock_owner == JobRunningPipeline.__name__ + assert newest.lock_owner is None From 88a7de15cd729c4a96f7d863e141528ef6622987 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 11 Mar 2026 16:27:53 +0500 Subject: [PATCH 02/21] Prototype job pipeline for provisioning and pullings states --- .../background/pipeline_tasks/jobs_running.py | 994 +++++++++++++++++- .../_internal/server/services/volumes.py | 2 + .../pipeline_tasks/test_running_jobs.py | 789 +++++++++++++- 3 files changed, 1772 insertions(+), 13 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 2e5440783..35643a861 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -1,35 +1,143 @@ import asyncio +import enum import uuid from dataclasses import dataclass -from datetime import timedelta -from typing import Sequence +from datetime import datetime, timedelta +from typing import Dict, Iterable, Literal, Optional, Sequence, Union -from sqlalchemy import or_, select -from sqlalchemy.orm import load_only +from sqlalchemy import and_, func, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from 
sqlalchemy.orm import aliased, contains_eager, joinedload, load_only -from dstack._internal.core.models.runs import JobStatus, RunStatus +from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import NetworkMode, RegistryAuth +from dstack._internal.core.models.files import FileArchiveMapping +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.profiles import StartupOrder +from dstack._internal.core.models.repos import RemoteRepoCreds +from dstack._internal.core.models.runs import ( + ClusterInfo, + Job, + JobProvisioningData, + JobRuntimeData, + JobSpec, + JobStatus, + JobSubmission, + JobTerminationReason, + Run, + RunStatus, +) +from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, + ItemUpdateMap, Pipeline, PipelineItem, Worker, + log_lock_token_changed_after_processing, + log_lock_token_mismatch, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, ) +from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout from dstack._internal.server.db import get_db, get_session_ctx -from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.models import ( + FleetModel, + InstanceModel, + JobModel, + ProbeModel, + ProjectModel, + RepoModel, + RunModel, + UserModel, +) +from dstack._internal.server.schemas.runner import TaskStatus +from dstack._internal.server.services import events +from dstack._internal.server.services import files as files_services +from dstack._internal.server.services.backends.provisioning import ( + get_instance_specific_gpu_devices, + get_instance_specific_mounts, + resolve_provisioning_image_name, +) +from 
dstack._internal.server.services.instances import ( + get_instance_remote_connection_info, + get_instance_ssh_private_keys, +) +from dstack._internal.server.services.jobs import ( + emit_job_status_change_event, + find_job, + get_job_attached_volumes, + get_job_runtime_data, + is_master_job, + job_model_to_job_submission, +) from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.repos import ( + get_code_model, + get_repo_creds, + repo_model_to_repo_head_with_creds, +) +from dstack._internal.server.services.runner import client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.server.services.runs import run_model_to_run +from dstack._internal.server.services.secrets import get_project_secrets_mapping +from dstack._internal.server.services.storage import get_default_storage from dstack._internal.server.utils import sentry_utils -from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils.common import get_current_datetime, get_or_error, run_async +from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) +JOB_DISCONNECTED_RETRY_TIMEOUT = timedelta(minutes=2) +"""`JOB_DISCONNECTED_RETRY_TIMEOUT` is the minimum time before terminating active job in case of connectivity issues.""" + + @dataclass class JobRunningPipelineItem(PipelineItem): status: JobStatus +@dataclass +class _RunningJobContext: + job_model: JobModel + run_model: RunModel + repo_model: RepoModel + project: ProjectModel + run: Run + job: Job + job_submission: JobSubmission + job_provisioning_data: Optional[JobProvisioningData] + initial_status: JobStatus + initial_disconnected_at: Optional[datetime] + server_ssh_private_keys: Optional[tuple[str, Optional[str]]] = None + + +@dataclass +class 
_RunningJobStartupContext: + cluster_info: ClusterInfo + volumes: list[Volume] + secrets: dict[str, str] + repo_creds: Optional[RemoteRepoCreds] + + +class _JobUpdateMap(ItemUpdateMap, total=False): + status: JobStatus + termination_reason: Optional[JobTerminationReason] + termination_reason_message: Optional[str] + job_provisioning_data: Optional[str] + job_runtime_data: Optional[str] + runner_timestamp: Optional[int] + disconnected_at: Optional[datetime] + inactivity_secs: Optional[int] + exit_status: Optional[int] + + class JobRunningPipeline(Pipeline[JobRunningPipelineItem]): def __init__( self, @@ -173,4 +281,874 @@ def __init__( @sentry_utils.instrument_named_task("pipeline_tasks.JobRunningWorker.process") async def process(self, item: JobRunningPipelineItem): - raise NotImplementedError("JobRunningWorker.process() is implemented in a later step") + if item.status == JobStatus.RUNNING: + raise NotImplementedError("RUNNING-state migration is implemented in a later step") + + context = await _load_running_job_context(item=item) + if context is None: + log_lock_token_mismatch(logger, item) + return + + await _process_running_job(context=context) + + job_update_map = _build_job_update_map(context.job_model) + set_processed_update_map_fields(job_update_map) + set_unlock_update_map_fields(job_update_map) + await _apply_process_result( + item=item, + job_model=context.job_model, + initial_status=context.initial_status, + initial_disconnected_at=context.initial_disconnected_at, + job_update_map=job_update_map, + ) + + +async def _load_running_job_context(item: JobRunningPipelineItem) -> Optional[_RunningJobContext]: + async with get_session_ctx() as session: + job_model = await _refetch_locked_job_model(session=session, item=item) + if job_model is None: + return None + run_model = await _fetch_run_model(session=session, run_id=job_model.run_id) + run = run_model_to_run(run_model, include_sensitive=True) + job_submission = job_model_to_job_submission(job_model) + 
server_ssh_private_keys = get_instance_ssh_private_keys(get_or_error(job_model.instance)) + return _RunningJobContext( + job_model=job_model, + run_model=run_model, + repo_model=run_model.repo, + project=run_model.project, + run=run, + job=find_job(run.jobs, job_model.replica_num, job_model.job_num), + job_submission=job_submission, + job_provisioning_data=job_submission.job_provisioning_data, + initial_status=job_model.status, + initial_disconnected_at=job_model.disconnected_at, + server_ssh_private_keys=server_ssh_private_keys, + ) + + +async def _process_running_job(context: _RunningJobContext) -> None: + if context.job_provisioning_data is None: + logger.error("%s: job_provisioning_data of an active job is None", fmt(context.job_model)) + _terminate_running_job( + job_model=context.job_model, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message=( + "Unexpected server error: job_provisioning_data of an active job is None" + ), + ) + return + + startup_context = None + if context.initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]: + startup_context = await _prepare_running_job_startup_context(context=context) + if startup_context is None: + return + + if context.initial_status == JobStatus.PROVISIONING: + await _process_running_job_provisioning_state( + context=context, + startup_context=get_or_error(startup_context), + ) + elif context.initial_status == JobStatus.PULLING: + await _process_running_job_pulling_state( + context=context, + startup_context=get_or_error(startup_context), + ) + + +async def _prepare_running_job_startup_context( + context: _RunningJobContext, +) -> Optional[_RunningJobStartupContext]: + job_provisioning_data = get_or_error(context.job_provisioning_data) + + for other_job in context.run.jobs: + if ( + other_job.job_spec.replica_num == context.job.job_spec.replica_num + and other_job.job_submissions[-1].status == JobStatus.SUBMITTED + ): + logger.debug( + "%s: waiting for all jobs in the 
replica to be provisioned", + fmt(context.job_model), + ) + return None + + cluster_info = _get_cluster_info( + jobs=context.run.jobs, + replica_num=context.job.job_spec.replica_num, + job_provisioning_data=job_provisioning_data, + job_runtime_data=context.job_submission.job_runtime_data, + ) + + async with get_session_ctx() as session: + volumes = await get_job_attached_volumes( + session=session, + project=context.project, + run_spec=context.run.run_spec, + job_num=context.job.job_spec.job_num, + job_provisioning_data=job_provisioning_data, + ) + repo_creds_model = await get_repo_creds( + session=session, + repo=context.repo_model, + user=context.run_model.user, + ) + secrets = await get_project_secrets_mapping(session=session, project=context.project) + + repo_creds = repo_model_to_repo_head_with_creds( + context.repo_model, + repo_creds_model, + ).repo_creds + try: + _interpolate_secrets(secrets, context.job.job_spec) + except InterpolatorError as e: + _terminate_running_job( + job_model=context.job_model, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message=f"Secrets interpolation error: {e.args[0]}", + ) + return None + + return _RunningJobStartupContext( + cluster_info=cluster_info, + volumes=volumes, + secrets=secrets, + repo_creds=repo_creds, + ) + + +async def _refetch_locked_job_model( + session: AsyncSession, item: JobRunningPipelineItem +) -> Optional[JobModel]: + res = await session.execute( + select(JobModel) + .where( + JobModel.id == item.id, + JobModel.lock_token == item.lock_token, + ) + .options(joinedload(JobModel.instance).joinedload(InstanceModel.project)) + .options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak)) + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one_or_none() + + +async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel: + latest_submissions_sq = ( + select( + JobModel.run_id.label("run_id"), + 
JobModel.replica_num.label("replica_num"), + JobModel.job_num.label("job_num"), + func.max(JobModel.submission_num).label("max_submission_num"), + ) + .where(JobModel.run_id == run_id) + .group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num) + .subquery() + ) + job_alias = aliased(JobModel) + res = await session.execute( + select(RunModel) + .where(RunModel.id == run_id) + .join(job_alias, job_alias.run_id == RunModel.id) + .join( + latest_submissions_sq, + onclause=and_( + job_alias.run_id == latest_submissions_sq.c.run_id, + job_alias.replica_num == latest_submissions_sq.c.replica_num, + job_alias.job_num == latest_submissions_sq.c.job_num, + job_alias.submission_num == latest_submissions_sq.c.max_submission_num, + ), + ) + .options(joinedload(RunModel.project)) + .options(joinedload(RunModel.user)) + .options(joinedload(RunModel.repo)) + .options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name)) + .options(contains_eager(RunModel.jobs, alias=job_alias)) + ) + return res.unique().scalar_one() + + +async def _process_running_job_provisioning_state( + context: _RunningJobContext, + startup_context: _RunningJobStartupContext, +) -> None: + job_provisioning_data = get_or_error(context.job_provisioning_data) + server_ssh_private_keys = get_or_error(context.server_ssh_private_keys) + + if job_provisioning_data.hostname is None: + _wait_for_instance_provisioning_data(context.job_model) + return + if _should_wait_for_other_nodes(context.run, context.job, context.job_model): + return + + success = False + if job_provisioning_data.dockerized: + logger.debug( + "%s: process provisioning job with shim, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + ssh_user = job_provisioning_data.username + assert context.run.run_spec.ssh_key_pub is not None + user_ssh_key = context.run.run_spec.ssh_key_pub.strip() + public_keys = [context.project.ssh_public_key.strip(), user_ssh_key] + if job_provisioning_data.backend == 
BackendType.LOCAL: + user_ssh_key = "" + success = await run_async( + _process_provisioning_with_shim, + server_ssh_private_keys, + job_provisioning_data, + None, + run=context.run, + job_model=context.job_model, + jpd=job_provisioning_data, + volumes=startup_context.volumes, + registry_auth=context.job.job_spec.registry_auth, + public_keys=public_keys, + ssh_user=ssh_user, + ssh_key=user_ssh_key, + ) + else: + logger.debug( + "%s: process provisioning job without shim, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + runner_availability = await run_async( + _get_runner_availability, + server_ssh_private_keys, + job_provisioning_data, + None, + ) + if runner_availability == _RunnerAvailability.AVAILABLE: + file_archives = await _get_job_file_archives( + archive_mappings=context.job.job_spec.file_archives, + user=context.run_model.user, + ) + code = await _get_job_code( + project=context.project, + repo=context.repo_model, + code_hash=_get_repo_code_hash(context.run, context.job), + ) + success = await run_async( + _submit_job_to_runner, + server_ssh_private_keys, + job_provisioning_data, + None, + run=context.run, + job_model=context.job_model, + job=context.job, + cluster_info=startup_context.cluster_info, + code=code, + file_archives=file_archives, + secrets=startup_context.secrets, + repo_credentials=startup_context.repo_creds, + success_if_not_available=False, + ) + + if success: + return + + provisioning_timeout = get_provisioning_timeout( + backend_type=job_provisioning_data.get_base_backend(), + instance_type_name=job_provisioning_data.instance_type.name, + ) + if context.job_submission.age > provisioning_timeout: + context.job_model.termination_reason = JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED + context.job_model.termination_reason_message = ( + f"Runner did not become available within {provisioning_timeout.total_seconds()}s." 
+ f" Job submission age: {context.job_submission.age.total_seconds()}s)" + ) + _switch_job_status(context.job_model, JobStatus.TERMINATING) + + +async def _process_running_job_pulling_state( + context: _RunningJobContext, + startup_context: _RunningJobStartupContext, +) -> None: + job_provisioning_data = get_or_error(context.job_provisioning_data) + server_ssh_private_keys = get_or_error(context.server_ssh_private_keys) + + logger.debug( + "%s: process pulling job with shim, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + shim_state = await run_async( + _sync_shim_pulling_state, + server_ssh_private_keys, + job_provisioning_data, + None, + job_model=context.job_model, + ) + if shim_state == _ShimPullingState.WAITING: + _reset_disconnected_at(context.job_model) + return + + if shim_state == _ShimPullingState.READY: + job_runtime_data = get_job_runtime_data(context.job_model) + runner_availability = await run_async( + _get_runner_availability, + server_ssh_private_keys, + job_provisioning_data, + job_runtime_data, + ) + if runner_availability == _RunnerAvailability.UNAVAILABLE: + _reset_disconnected_at(context.job_model) + return + + if runner_availability == _RunnerAvailability.AVAILABLE: + file_archives = await _get_job_file_archives( + archive_mappings=context.job.job_spec.file_archives, + user=context.run_model.user, + ) + code = await _get_job_code( + project=context.project, + repo=context.repo_model, + code_hash=_get_repo_code_hash(context.run, context.job), + ) + success = await run_async( + _submit_job_to_runner, + server_ssh_private_keys, + job_provisioning_data, + job_runtime_data, + run=context.run, + job_model=context.job_model, + job=context.job, + cluster_info=startup_context.cluster_info, + code=code, + file_archives=file_archives, + secrets=startup_context.secrets, + repo_credentials=startup_context.repo_creds, + success_if_not_available=True, + ) + if success: + _reset_disconnected_at(context.job_model) + return + + if 
context.job_model.termination_reason: + logger.warning( + "%s: failed due to %s, age=%s", + fmt(context.job_model), + context.job_model.termination_reason.value, + context.job_submission.age, + ) + _switch_job_status(context.job_model, JobStatus.TERMINATING) + return + + _set_disconnected_at_now(context.job_model) + if not _should_terminate_job_due_to_disconnect(context.job_model): + logger.warning( + "%s: is unreachable, waiting for the instance to become reachable again, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + return + + if job_provisioning_data.instance_type.resources.spot: + context.job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + else: + context.job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE + context.job_model.termination_reason_message = "Instance is unreachable" + _switch_job_status(context.job_model, JobStatus.TERMINATING) + + +def _build_job_update_map(job_model: JobModel) -> _JobUpdateMap: + return _JobUpdateMap( + status=job_model.status, + termination_reason=job_model.termination_reason, + termination_reason_message=job_model.termination_reason_message, + job_provisioning_data=job_model.job_provisioning_data, + job_runtime_data=job_model.job_runtime_data, + runner_timestamp=job_model.runner_timestamp, + disconnected_at=job_model.disconnected_at, + inactivity_secs=job_model.inactivity_secs, + exit_status=job_model.exit_status, + ) + + +async def _apply_process_result( + item: JobRunningPipelineItem, + job_model: JobModel, + initial_status: JobStatus, + initial_disconnected_at: Optional[datetime], + job_update_map: _JobUpdateMap, +) -> None: + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(job_update_map, now=now) + res = await session.execute( + update(JobModel) + .where( + JobModel.id == item.id, + JobModel.lock_token == item.lock_token, + ) + .values(**job_update_map) + .returning(JobModel.id) + ) + updated_ids = 
list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_after_processing(logger, item) + return + + emit_job_status_change_event( + session=session, + job_model=job_model, + old_status=initial_status, + new_status=job_update_map.get("status", initial_status), + termination_reason=job_update_map.get( + "termination_reason", job_model.termination_reason + ), + termination_reason_message=job_update_map.get( + "termination_reason_message", + job_model.termination_reason_message, + ), + ) + _emit_reachability_change_event( + session=session, + job_model=job_model, + initial_disconnected_at=initial_disconnected_at, + new_disconnected_at=job_update_map.get("disconnected_at", initial_disconnected_at), + ) + + +def _emit_reachability_change_event( + session: AsyncSession, + job_model: JobModel, + initial_disconnected_at: Optional[datetime], + new_disconnected_at: Optional[datetime], +) -> None: + if initial_disconnected_at is None and new_disconnected_at is not None: + events.emit( + session, + "Job became unreachable", + actor=events.SystemActor(), + targets=[events.Target.from_model(job_model)], + ) + elif initial_disconnected_at is not None and new_disconnected_at is None: + events.emit( + session, + "Job became reachable", + actor=events.SystemActor(), + targets=[events.Target.from_model(job_model)], + ) + + +def _terminate_running_job( + job_model: JobModel, + termination_reason: JobTerminationReason, + termination_reason_message: str, +) -> None: + job_model.termination_reason = termination_reason + job_model.termination_reason_message = termination_reason_message + _switch_job_status(job_model, JobStatus.TERMINATING) + + +def _wait_for_instance_provisioning_data(job_model: JobModel) -> None: + if job_model.instance is None: + logger.error( + "%s: cannot update job_provisioning_data. 
def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool:
    """Decide whether a multi-node job must wait before being submitted.

    Returns True when a sibling node of the same replica is still
    provisioning without an assigned IP, or when the configured startup
    order requires the master (or the workers) to be running first and that
    condition is not met yet.
    """
    replica_num = job.job_spec.replica_num
    siblings = [j for j in run.jobs if j.job_spec.replica_num == replica_num]

    for sibling in siblings:
        last_submission = sibling.job_submissions[-1]
        if (
            last_submission.status == JobStatus.PROVISIONING
            and last_submission.job_provisioning_data is not None
            and last_submission.job_provisioning_data.hostname is None
        ):
            logger.debug("%s: waiting for other job to have IP assigned", fmt(job_model))
            return True

    startup_order = run.run_spec.merged_profile.startup_order
    master_job = find_job(run.jobs, replica_num, 0)
    if (
        startup_order == StartupOrder.MASTER_FIRST
        and job.job_spec.job_num != 0
        and master_job.job_submissions[-1].status != JobStatus.RUNNING
    ):
        logger.debug("%s: waiting for master job to become running", fmt(job_model))
        return True

    if startup_order == StartupOrder.WORKERS_FIRST and is_master_job(job):
        for sibling in siblings:
            if (
                sibling.job_spec.job_num != job.job_spec.job_num
                and sibling.job_submissions[-1].status != JobStatus.RUNNING
            ):
                logger.debug("%s: waiting for worker job to become running", fmt(job_model))
                return True
    return False
Dict[int, int], + run: Run, + job_model: JobModel, + jpd: JobProvisioningData, + volumes: list[Volume], + registry_auth: Optional[RegistryAuth], + public_keys: list[str], + ssh_user: str, + ssh_key: str, +) -> bool: + job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + + resp = shim_client.healthcheck() + if resp is None: + logger.debug("%s: shim is not available yet", fmt(job_model)) + return False + + registry_username = "" + registry_password = "" + if registry_auth is not None: + registry_username = registry_auth.username + registry_password = registry_auth.password + + volume_mounts: list[VolumeMountPoint] = [] + instance_mounts: list[InstanceMountPoint] = [] + for mount in run.run_spec.configuration.volumes: + if isinstance(mount, VolumeMountPoint): + volume_mounts.append(mount.copy()) + elif isinstance(mount, InstanceMountPoint): + instance_mounts.append(mount) + else: + assert False, f"unexpected mount point: {mount!r}" + + for volume, volume_mount in zip(volumes, volume_mounts): + volume_mount.name = volume.name + + instance_mounts += get_instance_specific_mounts(jpd.backend, jpd.instance_type.name) + gpu_devices = get_instance_specific_gpu_devices(jpd.backend, jpd.instance_type.name) + + container_user = "root" + job_runtime_data = get_job_runtime_data(job_model) + if job_runtime_data is not None: + gpu = job_runtime_data.gpu + cpu = job_runtime_data.cpu + memory = job_runtime_data.memory + network_mode = job_runtime_data.network_mode + else: + gpu = None + cpu = None + memory = None + network_mode = NetworkMode.HOST + image_name = resolve_provisioning_image_name(job_spec, jpd) + if shim_client.is_api_v2_supported(): + shim_client.submit_task( + task_id=job_model.id, + name=job_model.job_name, + registry_username=registry_username, + registry_password=registry_password, + image_name=image_name, + container_user=container_user, + privileged=job_spec.privileged, + gpu=gpu, 
class _RunnerAvailability(enum.Enum):
    # Whether the runner HTTP API responds to healthchecks over the tunnel.
    AVAILABLE = "available"
    UNAVAILABLE = "unavailable"


class _ShimPullingState(enum.Enum):
    # Shim-side progress of the task while the job is in the PULLING state.
    WAITING = "waiting"
    READY = "ready"


@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1)
def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability:
    """Probe the runner over the SSH tunnel and report its availability."""
    runner = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT])
    if runner.healthcheck() is not None:
        return _RunnerAvailability.AVAILABLE
    return _RunnerAvailability.UNAVAILABLE
logger.warning( + "shim failed to execute job %s: %s (%s)", + job_model.job_name, + task.termination_reason, + task.termination_message, + ) + logger.debug("task status: %s", task.dict()) + job_model.termination_reason = JobTerminationReason(task.termination_reason.lower()) + job_model.termination_reason_message = task.termination_message + return False + + if task.status != TaskStatus.RUNNING: + return _ShimPullingState.WAITING + + job_runtime_data = get_job_runtime_data(job_model) + if job_runtime_data is not None: + if task.ports is None: + return _ShimPullingState.WAITING + job_runtime_data.ports = {pm.container: pm.host for pm in task.ports} + job_model.job_runtime_data = job_runtime_data.json() + else: + shim_status = shim_client.pull() + if ( + shim_status.state == "pending" + and shim_status.result is not None + and shim_status.result.reason != "" + ): + logger.warning( + "shim failed to execute job %s: %s (%s)", + job_model.job_name, + shim_status.result.reason, + shim_status.result.reason_message, + ) + logger.debug("shim status: %s", shim_status.dict()) + job_model.termination_reason = JobTerminationReason(shim_status.result.reason.lower()) + job_model.termination_reason_message = shim_status.result.reason_message + return False + + if shim_status.state in ("pulling", "creating"): + return _ShimPullingState.WAITING + + return _ShimPullingState.READY + + +def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool: + if job_model.disconnected_at is None: + return False + return get_current_datetime() > job_model.disconnected_at + JOB_DISCONNECTED_RETRY_TIMEOUT + + +def _set_disconnected_at_now(job_model: JobModel) -> None: + if job_model.disconnected_at is None: + job_model.disconnected_at = get_current_datetime() + + +def _reset_disconnected_at(job_model: JobModel) -> None: + if job_model.disconnected_at is not None: + job_model.disconnected_at = None + + +def _get_cluster_info( + jobs: list[Job], + replica_num: int, + job_provisioning_data: 
JobProvisioningData, + job_runtime_data: Optional[JobRuntimeData], +) -> ClusterInfo: + job_ips = [] + for job in jobs: + if job.job_spec.replica_num == replica_num: + job_ips.append( + get_or_error(job.job_submissions[-1].job_provisioning_data).internal_ip or "" + ) + gpus_per_job = len(job_provisioning_data.instance_type.resources.gpus) + if job_runtime_data is not None and job_runtime_data.offer is not None: + gpus_per_job = len(job_runtime_data.offer.instance.resources.gpus) + return ClusterInfo( + job_ips=job_ips, + master_job_ip=job_ips[0], + gpus_per_job=gpus_per_job, + ) + + +def _get_repo_code_hash(run: Run, job: Job) -> Optional[str]: + if ( + job.job_spec.repo_code_hash is None + and run.run_spec.repo_code_hash is not None + and job.job_submissions[-1].deployment_num == run.deployment_num + ): + return run.run_spec.repo_code_hash + return job.job_spec.repo_code_hash + + +async def _get_job_code(project: ProjectModel, repo: RepoModel, code_hash: Optional[str]) -> bytes: + if code_hash is None: + return b"" + async with get_session_ctx() as session: + code_model = await get_code_model(session=session, repo=repo, code_hash=code_hash) + if code_model is None: + return b"" + if code_model.blob is not None: + return code_model.blob + storage = get_default_storage() + if storage is None: + return b"" + blob = await run_async( + storage.get_code, + project.name, + repo.name, + code_hash, + ) + if blob is None: + logger.error( + "Failed to get repo code hash %s from storage for repo %s", code_hash, repo.name + ) + return b"" + return blob + + +async def _get_job_file_archives( + archive_mappings: Iterable[FileArchiveMapping], + user: UserModel, +) -> list[tuple[uuid.UUID, bytes]]: + archives: list[tuple[uuid.UUID, bytes]] = [] + for archive_mapping in archive_mappings: + archive_blob = await _get_job_file_archive(archive_id=archive_mapping.id, user=user) + archives.append((archive_mapping.id, archive_blob)) + return archives + + +async def 
async def _get_job_file_archive(archive_id: uuid.UUID, user: UserModel) -> bytes:
    """Load one file archive blob: DB first, then object storage.

    Returns an empty bytestring when the archive does not exist or cannot
    be retrieved.
    """
    async with get_session_ctx() as session:
        archive_model = await files_services.get_archive_model(session, id=archive_id, user=user)
        if archive_model is None:
            return b""
        if archive_model.blob is not None:
            return archive_model.blob
        storage = get_default_storage()
        if storage is None:
            return b""
        blob = await run_async(
            storage.get_archive,
            str(archive_model.user_id),
            archive_model.blob_hash,
        )
        if blob is not None:
            return blob
        logger.error("Failed to get file archive %s from storage", archive_id)
        return b""
job_info = runner_client.run_job() + if job_info is not None: + job_runtime_data = get_job_runtime_data(job_model) + if job_runtime_data is not None: + job_runtime_data.working_dir = job_info.working_dir + job_runtime_data.username = job_info.username + job_model.job_runtime_data = job_runtime_data.json() + + _switch_job_status(job_model, JobStatus.RUNNING) + return True + + +def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec) -> None: + interpolate = VariablesInterpolator({"secrets": secrets}).interpolate_or_error + job_spec.env = {k: interpolate(v) for k, v in job_spec.env.items()} + if job_spec.registry_auth is not None: + job_spec.registry_auth = RegistryAuth( + username=interpolate(job_spec.registry_auth.username), + password=interpolate(job_spec.registry_auth.password), + ) + + +def _switch_job_status(job_model: JobModel, new_status: JobStatus) -> None: + if job_model.status != new_status: + job_model.status = new_status diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index ac2f88a5d..7bb5f0713 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -209,6 +209,7 @@ async def list_project_volume_models( select(VolumeModel) .where(*filters) .options(joinedload(VolumeModel.user)) + .options(joinedload(VolumeModel.project)) .options( joinedload(VolumeModel.attachments) .joinedload(VolumeAttachmentModel.instance) @@ -245,6 +246,7 @@ async def get_project_volume_model_by_name( select(VolumeModel) .where(*filters) .options(joinedload(VolumeModel.user)) + .options(joinedload(VolumeModel.project)) .options( joinedload(VolumeModel.attachments) .joinedload(VolumeAttachmentModel.instance) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index ea362f9c3..f50c60e5f 100644 --- 
a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -1,22 +1,51 @@ import asyncio import uuid -from datetime import timedelta -from unittest.mock import Mock +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest +from freezegun import freeze_time from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.models.runs import JobStatus, RunStatus +from dstack._internal import settings +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import NetworkMode +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.profiles import StartupOrder +from dstack._internal.core.models.runs import ( + JobRuntimeData, + JobStatus, + JobTerminationReason, + RunStatus, +) +from dstack._internal.core.models.volumes import InstanceMountPoint, VolumeMountPoint, VolumeStatus from dstack._internal.server.background.pipeline_tasks.jobs_running import ( JobRunningFetcher, JobRunningPipeline, + JobRunningPipelineItem, + JobRunningWorker, + _RunnerAvailability, +) +from dstack._internal.server.schemas.runner import ( + HealthcheckResponse, + JobInfoResponse, + PortMapping, + TaskStatus, ) +from dstack._internal.server.services.volumes import volume_model_to_volume from dstack._internal.server.testing.common import ( + create_instance, create_job, create_project, create_repo, create_run, create_user, + create_volume, + get_job_provisioning_data, + get_job_runtime_data, + get_run_spec, + get_volume_configuration, + list_events, ) from dstack._internal.utils.common import get_current_datetime @@ -34,18 +63,81 @@ def fetcher() -> JobRunningFetcher: ) -def _lock_job_foreign(job_model): +@pytest.fixture +def worker() -> JobRunningWorker: + return JobRunningWorker(queue=Mock(), heartbeater=Mock()) + + 
@pytest.fixture
def ssh_tunnel_mock(monkeypatch: pytest.MonkeyPatch) -> Mock:
    """Patch SSHTunnel so tests never open a real SSH connection."""
    tunnel_mock = MagicMock()
    monkeypatch.setattr("dstack._internal.server.services.runner.ssh.SSHTunnel", tunnel_mock)
    return tunnel_mock


@pytest.fixture
def shim_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock:
    """Patch ShimClient with a mock that reports a healthy shim."""
    client_mock = Mock()
    client_mock.healthcheck.return_value = HealthcheckResponse(
        service="dstack-shim", version="latest"
    )
    monkeypatch.setattr(
        "dstack._internal.server.services.runner.client.ShimClient", Mock(return_value=client_mock)
    )
    return client_mock


@pytest.fixture
def runner_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock:
    """Patch RunnerClient with a mock that reports a healthy runner."""
    client_mock = Mock()
    client_mock.healthcheck.return_value = HealthcheckResponse(
        service="dstack-runner", version="0.0.1.dev2"
    )
    monkeypatch.setattr(
        "dstack._internal.server.services.runner.client.RunnerClient",
        Mock(return_value=client_mock),
    )
    return client_mock


def _lock_job_foreign(job_model) -> None:
    """Simulate a live lock held by some other pipeline."""
    job_model.lock_expires_at = get_current_datetime() + timedelta(minutes=1)
    job_model.lock_token = uuid.uuid4()
    job_model.lock_owner = "OtherPipeline"


def _lock_job_expired_same_owner(job_model) -> None:
    """Simulate an expired lock previously held by this pipeline."""
    job_model.lock_expires_at = get_current_datetime() - timedelta(minutes=1)
    job_model.lock_token = uuid.uuid4()
    job_model.lock_owner = JobRunningPipeline.__name__


def _lock_job(job_model) -> None:
    """Acquire a fresh, unexpired lock on behalf of this pipeline."""
    job_model.lock_expires_at = get_current_datetime() + timedelta(seconds=30)
    job_model.lock_token = uuid.uuid4()
    job_model.lock_owner = JobRunningPipeline.__name__


def _job_to_pipeline_item(job_model) -> JobRunningPipelineItem:
    """Build the pipeline item that corresponds to a locked job model."""
    assert job_model.lock_token is not None
    assert job_model.lock_expires_at is not None
    return JobRunningPipelineItem(
        __tablename__=job_model.__tablename__,
        id=job_model.id,
        lock_token=job_model.lock_token,
        lock_expires_at=job_model.lock_expires_at,
        prev_lock_expired=False,
        status=job_model.status,
    )
job_model, +) -> None: + _lock_job(job_model) + await session.commit() + await worker.process(_job_to_pipeline_item(job_model)) + + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestJobRunningFetcher: @@ -222,3 +314,690 @@ async def test_fetch_returns_oldest_jobs_first_up_to_limit( assert oldest.lock_owner == JobRunningPipeline.__name__ assert middle.lock_owner == JobRunningPipeline.__name__ assert newest.lock_owner is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestJobRunningWorker: + async def test_process_skips_when_lock_token_changes( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=False), + instance=instance, + instance_assigned=True, + ) + _lock_job(job) + await session.commit() + + item = _job_to_pipeline_item(job) + new_lock_token = uuid.uuid4() + job.lock_token = new_lock_token + await session.commit() + + await worker.process(item) + await session.refresh(job) + + assert job.lock_token == new_lock_token + assert job.status == JobStatus.PROVISIONING + assert job.lock_owner == JobRunningPipeline.__name__ + + async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( + self, test_db, session: AsyncSession, worker: JobRunningWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, 
project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=False), + instance=instance, + instance_assigned=True, + ) + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + new_callable=AsyncMock, + ) as get_job_code_mock, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.healthcheck.return_value = None + await _process_job(session, worker, job) + ssh_tunnel_cls.assert_called_once() + runner_client_mock.healthcheck.assert_called_once() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() + + await session.refresh(job) + assert job.status == JobStatus.PROVISIONING + assert job.lock_token is None + assert job.lock_owner is None + + async def test_runs_provisioning_job( + self, test_db, session: AsyncSession, worker: JobRunningWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + 
submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=False), + job_runtime_data=get_job_runtime_data(), + instance=instance, + instance_assigned=True, + ) + before_processed_at = job.last_processed_at + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-runner", version="0.0.1.dev2" + ) + runner_client_mock.run_job.return_value = JobInfoResponse( + working_dir="/dstack/run", username="dstack" + ) + await _process_job(session, worker, job) + assert ssh_tunnel_cls.call_count == 2 + assert runner_client_mock.healthcheck.call_count == 2 + runner_client_mock.submit_job.assert_called_once() + runner_client_mock.upload_code.assert_called_once() + runner_client_mock.run_job.assert_called_once() + + await session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner is None + assert job.last_processed_at > before_processed_at + job_runtime_data = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) + assert job_runtime_data.working_dir == "/dstack/run" + assert job_runtime_data.username == "dstack" + + @pytest.mark.parametrize("privileged", [False, True]) + async def test_provisioning_shim_with_volumes( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + privileged: bool, + ): + project_ssh_pub_key = "__project_ssh_pub_key__" + project = await create_project(session=session, ssh_public_key=project_ssh_pub_key) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + volume = await create_volume( + session=session, + 
project=project, + user=user, + status=VolumeStatus.ACTIVE, + configuration=get_volume_configuration( + name="my-vol", backend=BackendType.AWS, region="us-east-1" + ), + backend=BackendType.AWS, + region="us-east-1", + ) + run_spec = get_run_spec(run_name="test-run", repo_id=repo.name) + run_spec.configuration.privileged = privileged + run_spec.configuration.volumes = [ + VolumeMountPoint(name="my-vol", path="/volume"), + InstanceMountPoint(instance_path="/root/.cache", path="/cache"), + ] + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-run", + run_spec=run_spec, + ) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job_provisioning_data = get_job_provisioning_data(dockerized=True) + + with patch( + "dstack._internal.server.services.jobs.configurators.base.get_default_python_verison" + ) as py_version: + py_version.return_value = "3.13" + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=job_provisioning_data, + instance=instance, + instance_assigned=True, + ) + + await _process_job(session, worker, job) + + ssh_tunnel_mock.assert_called_once() + shim_client_mock.healthcheck.assert_called_once() + shim_client_mock.submit_task.assert_called_once_with( + task_id=job.id, + name="test-run-0-0", + registry_username="", + registry_password="", + image_name=( + f"dstackai/base:{settings.DSTACK_BASE_IMAGE_VERSION}-" + f"base-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ), + container_user="root", + privileged=privileged, + gpu=None, + cpu=None, + memory=None, + shm_size=None, + network_mode=NetworkMode.HOST, + volumes=[volume_model_to_volume(volume)], + volume_mounts=[VolumeMountPoint(name="my-vol", path="/volume")], + instance_mounts=[InstanceMountPoint(instance_path="/root/.cache", path="/cache")], + gpu_devices=[], + host_ssh_user="ubuntu", + 
host_ssh_keys=["user_ssh_key"], + container_ssh_keys=[project_ssh_pub_key, "user_ssh_key"], + instance_id=job_provisioning_data.instance_id, + ) + await session.refresh(job) + assert job.status == JobStatus.PULLING + + async def test_pulling_shim( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = [ + PortMapping(container=10022, host=32771), + PortMapping(container=10999, host=32772), + ] + runner_client_mock.run_job.return_value = JobInfoResponse( + working_dir="/dstack/run", username="dstack" + ) + + await _process_job(session, worker, job) + + assert ssh_tunnel_mock.call_count == 3 + shim_client_mock.get_task.assert_called_once() + assert runner_client_mock.healthcheck.call_count == 2 + runner_client_mock.submit_job.assert_called_once() + runner_client_mock.upload_code.assert_called_once() + runner_client_mock.run_job.assert_called_once() + await session.refresh(job) + assert job.status == JobStatus.RUNNING + job_runtime_data = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) + assert job_runtime_data.ports == {10022: 32771, 10999: 32772} + assert job_runtime_data.working_dir == 
"/dstack/run" + assert job_runtime_data.username == "dstack" + + async def test_pulling_shim_port_mapping_not_ready( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = None + + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + new_callable=AsyncMock, + ) as get_job_code_mock, + ): + await _process_job(session, worker, job) + ssh_tunnel_mock.assert_called_once() + shim_client_mock.get_task.assert_called_once() + runner_client_mock.healthcheck.assert_not_called() + runner_client_mock.submit_job.assert_not_called() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() + + await session.refresh(job) + assert job.status == JobStatus.PULLING + + async def test_pulling_shim_runner_not_ready( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = 
await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = [ + PortMapping(container=10022, host=32771), + PortMapping(container=10999, host=32772), + ] + runner_client_mock.healthcheck.return_value = None + + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + ) as get_job_file_archives_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + new_callable=AsyncMock, + ) as get_job_code_mock, + ): + await _process_job(session, worker, job) + assert ssh_tunnel_mock.call_count == 2 + shim_client_mock.get_task.assert_called_once() + runner_client_mock.healthcheck.assert_called_once() + runner_client_mock.submit_job.assert_not_called() + get_job_file_archives_mock.assert_not_awaited() + get_job_code_mock.assert_not_awaited() + + await session.refresh(job) + assert job.status == JobStatus.PULLING + + async def test_pulling_shim_uses_runtime_port_mapping_for_runner_calls( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, 
project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = [ + PortMapping(container=10022, host=32771), + PortMapping(container=10999, host=32772), + ] + expected_ports = {10022: 32771, 10999: 32772} + + def assert_runner_availability(_, __, job_runtime_data): + assert job_runtime_data is not None + assert job_runtime_data.ports == expected_ports + return _RunnerAvailability.AVAILABLE + + def assert_submit_job_to_runner(_, __, job_runtime_data, **kwargs): + assert job_runtime_data is not None + assert job_runtime_data.ports == expected_ports + return True + + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_runner_availability", + side_effect=assert_runner_availability, + ) as get_runner_availability_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._submit_job_to_runner", + side_effect=assert_submit_job_to_runner, + ) as submit_job_to_runner_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + return_value=[], + ), + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + new_callable=AsyncMock, + return_value=b"", + ), + ): + await _process_job(session, worker, job) + ssh_tunnel_mock.assert_called_once() + get_runner_availability_mock.assert_called_once() + submit_job_to_runner_mock.assert_called_once() + + 
await session.refresh(job) + assert job.status == JobStatus.PULLING + job_runtime_data = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) + assert job_runtime_data.ports == expected_ports + + async def test_pulling_shim_failed( + self, test_db, session: AsyncSession, worker: JobRunningWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.IDLE + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + ) + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.ssh.time.sleep"), + ): + from dstack._internal.core.errors import SSHError + + ssh_tunnel_cls.side_effect = SSHError + await _process_job(session, worker, job) + assert ssh_tunnel_cls.call_count == 3 + + await session.refresh(job) + events = await list_events(session) + assert job.disconnected_at is not None + assert job.status == JobStatus.PULLING + assert len(events) == 1 + assert events[0].message == "Job became unreachable" + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.ssh.time.sleep"), + freeze_time(job.disconnected_at + timedelta(minutes=5)), + ): + from dstack._internal.core.errors import SSHError + + ssh_tunnel_cls.side_effect = SSHError + await _process_job(session, worker, job) + assert ssh_tunnel_cls.call_count == 3 + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == 
JobTerminationReason.INSTANCE_UNREACHABLE + assert job.remove_at is None + + async def test_provisioning_shim_force_stop_if_already_running_api_v1( + self, + monkeypatch: pytest.MonkeyPatch, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec(run_name="test-run", repo_id=repo.name) + run_spec.configuration.image = "debian" + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-run", + run_spec=run_spec, + ) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + monkeypatch.setattr( + "dstack._internal.server.services.runner.ssh.SSHTunnel", Mock(return_value=MagicMock()) + ) + shim_client_mock = Mock() + monkeypatch.setattr( + "dstack._internal.server.services.runner.client.ShimClient", + Mock(return_value=shim_client_mock), + ) + shim_client_mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-shim", version="0.0.1.dev2" + ) + shim_client_mock.is_api_v2_supported.return_value = False + shim_client_mock.submit.return_value = False + + await _process_job(session, worker, job) + + shim_client_mock.healthcheck.assert_called_once() + shim_client_mock.submit.assert_called_once() + shim_client_mock.stop.assert_called_once_with(force=True) + await session.refresh(job) + assert job.status == JobStatus.PROVISIONING + + async def test_master_job_waits_for_workers( + self, test_db, session: AsyncSession, worker: JobRunningWorker + ): + project = await create_project(session=session) + user = await 
create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec(run_name="test-run", repo_id=repo.name) + run_spec.configuration.startup_order = StartupOrder.WORKERS_FIRST + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + ) + instance1 = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + instance2 = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job_provisioning_data = get_job_provisioning_data(dockerized=False) + master_job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=job_provisioning_data, + instance_assigned=True, + instance=instance1, + job_num=0, + last_processed_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + ) + worker_job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=job_provisioning_data, + instance_assigned=True, + instance=instance2, + job_num=1, + last_processed_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + ) + + await _process_job(session, worker, master_job) + await session.refresh(master_job) + assert master_job.status == JobStatus.PROVISIONING + + worker_job.status = JobStatus.RUNNING + master_job.last_processed_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel"), + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.healthcheck.return_value = HealthcheckResponse( + service="dstack-runner", version="0.0.1.dev2" + ) + await _process_job(session, worker, master_job) + + await session.refresh(master_job) + assert 
master_job.status == JobStatus.RUNNING + + async def test_apply_skips_when_lock_token_changes_after_processing( + self, test_db, session: AsyncSession, worker: JobRunningWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=get_job_provisioning_data(dockerized=False), + job_runtime_data=get_job_runtime_data(), + instance=instance, + instance_assigned=True, + ) + _lock_job(job) + await session.commit() + original_lock_token = job.lock_token + replacement_lock_token = uuid.uuid4() + + async def invalidate_lock(*args, **kwargs): + job.lock_token = replacement_lock_token + await session.commit() + return b"" + + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_runner_availability", + return_value=_RunnerAvailability.AVAILABLE, + ), + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + return_value=[], + ), + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + new_callable=AsyncMock, + side_effect=invalidate_lock, + ), + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._submit_job_to_runner", + return_value=True, + ), + ): + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.status == JobStatus.PROVISIONING + assert job.lock_token == replacement_lock_token + assert job.lock_token != original_lock_token From 6ae8dfcfd1482fcf09c3ebff976e9d45f01cac10 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 11 
Mar 2026 17:45:39 +0500 Subject: [PATCH 03/21] Treat job_model as read-only --- .../background/pipeline_tasks/jobs_running.py | 433 +++++++++++------- .../pipeline_tasks/test_running_jobs.py | 47 +- 2 files changed, 311 insertions(+), 169 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 35643a861..927b41e31 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -1,7 +1,7 @@ import asyncio import enum import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Dict, Iterable, Literal, Optional, Sequence, Union @@ -103,41 +103,6 @@ class JobRunningPipelineItem(PipelineItem): status: JobStatus -@dataclass -class _RunningJobContext: - job_model: JobModel - run_model: RunModel - repo_model: RepoModel - project: ProjectModel - run: Run - job: Job - job_submission: JobSubmission - job_provisioning_data: Optional[JobProvisioningData] - initial_status: JobStatus - initial_disconnected_at: Optional[datetime] - server_ssh_private_keys: Optional[tuple[str, Optional[str]]] = None - - -@dataclass -class _RunningJobStartupContext: - cluster_info: ClusterInfo - volumes: list[Volume] - secrets: dict[str, str] - repo_creds: Optional[RemoteRepoCreds] - - -class _JobUpdateMap(ItemUpdateMap, total=False): - status: JobStatus - termination_reason: Optional[JobTerminationReason] - termination_reason_message: Optional[str] - job_provisioning_data: Optional[str] - job_runtime_data: Optional[str] - runner_timestamp: Optional[int] - disconnected_at: Optional[datetime] - inactivity_secs: Optional[int] - exit_status: Optional[int] - - class JobRunningPipeline(Pipeline[JobRunningPipelineItem]): def __init__( self, @@ -289,20 +254,29 @@ async def process(self, item: 
JobRunningPipelineItem): log_lock_token_mismatch(logger, item) return - await _process_running_job(context=context) - - job_update_map = _build_job_update_map(context.job_model) - set_processed_update_map_fields(job_update_map) - set_unlock_update_map_fields(job_update_map) + result = await _process_running_job(context=context) + set_processed_update_map_fields(result.job_update_map) + set_unlock_update_map_fields(result.job_update_map) await _apply_process_result( item=item, job_model=context.job_model, - initial_status=context.initial_status, - initial_disconnected_at=context.initial_disconnected_at, - job_update_map=job_update_map, + result=result, ) +@dataclass +class _RunningJobContext: + job_model: JobModel + run_model: RunModel + repo_model: RepoModel + project: ProjectModel + run: Run + job: Job + job_submission: JobSubmission + job_provisioning_data: Optional[JobProvisioningData] + server_ssh_private_keys: Optional[tuple[str, Optional[str]]] = None + + async def _load_running_job_context(item: JobRunningPipelineItem) -> Optional[_RunningJobContext]: async with get_session_ctx() as session: job_model = await _refetch_locked_job_model(session=session, item=item) @@ -321,44 +295,76 @@ async def _load_running_job_context(item: JobRunningPipelineItem) -> Optional[_R job=find_job(run.jobs, job_model.replica_num, job_model.job_num), job_submission=job_submission, job_provisioning_data=job_submission.job_provisioning_data, - initial_status=job_model.status, - initial_disconnected_at=job_model.disconnected_at, server_ssh_private_keys=server_ssh_private_keys, ) -async def _process_running_job(context: _RunningJobContext) -> None: +class _JobUpdateMap(ItemUpdateMap, total=False): + status: JobStatus + termination_reason: Optional[JobTerminationReason] + termination_reason_message: Optional[str] + job_provisioning_data: Optional[str] + job_runtime_data: Optional[str] + runner_timestamp: Optional[int] + disconnected_at: Optional[datetime] + inactivity_secs: 
Optional[int] + exit_status: Optional[int] + + +@dataclass +class _ProcessResult: + job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) + + +async def _process_running_job(context: _RunningJobContext) -> _ProcessResult: + result = _ProcessResult() if context.job_provisioning_data is None: logger.error("%s: job_provisioning_data of an active job is None", fmt(context.job_model)) _terminate_running_job( job_model=context.job_model, + result=result, termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, termination_reason_message=( "Unexpected server error: job_provisioning_data of an active job is None" ), ) - return + return result startup_context = None - if context.initial_status in [JobStatus.PROVISIONING, JobStatus.PULLING]: - startup_context = await _prepare_running_job_startup_context(context=context) + if context.job_model.status in [JobStatus.PROVISIONING, JobStatus.PULLING]: + startup_context = await _prepare_running_job_startup_context( + context=context, + result=result, + ) if startup_context is None: - return + return result - if context.initial_status == JobStatus.PROVISIONING: + if context.job_model.status == JobStatus.PROVISIONING: await _process_running_job_provisioning_state( context=context, startup_context=get_or_error(startup_context), + result=result, ) - elif context.initial_status == JobStatus.PULLING: + elif context.job_model.status == JobStatus.PULLING: await _process_running_job_pulling_state( context=context, startup_context=get_or_error(startup_context), + result=result, ) + return result + + +@dataclass +class _RunningJobStartupContext: + cluster_info: ClusterInfo + volumes: list[Volume] + secrets: dict[str, str] + repo_creds: Optional[RemoteRepoCreds] async def _prepare_running_job_startup_context( context: _RunningJobContext, + result: _ProcessResult, ) -> Optional[_RunningJobStartupContext]: job_provisioning_data = get_or_error(context.job_provisioning_data) @@ -404,6 +410,7 @@ async def 
_prepare_running_job_startup_context( except InterpolatorError as e: _terminate_running_job( job_model=context.job_model, + result=result, termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, termination_reason_message=f"Secrets interpolation error: {e.args[0]}", ) @@ -471,17 +478,17 @@ async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel async def _process_running_job_provisioning_state( context: _RunningJobContext, startup_context: _RunningJobStartupContext, + result: _ProcessResult, ) -> None: job_provisioning_data = get_or_error(context.job_provisioning_data) server_ssh_private_keys = get_or_error(context.server_ssh_private_keys) if job_provisioning_data.hostname is None: - _wait_for_instance_provisioning_data(context.job_model) + _wait_for_instance_provisioning_data(context.job_model, result) return if _should_wait_for_other_nodes(context.run, context.job, context.job_model): return - success = False if job_provisioning_data.dockerized: logger.debug( "%s: process provisioning job with shim, age=%s", @@ -501,6 +508,7 @@ async def _process_running_job_provisioning_state( None, run=context.run, job_model=context.job_model, + jrd=get_job_runtime_data(context.job_model), jpd=job_provisioning_data, volumes=startup_context.volumes, registry_auth=context.job.job_spec.registry_auth, @@ -508,6 +516,9 @@ async def _process_running_job_provisioning_state( ssh_user=ssh_user, ssh_key=user_ssh_key, ) + if success: + _set_job_status(context.job_model, result, JobStatus.PULLING) + return else: logger.debug( "%s: process provisioning job without shim, age=%s", @@ -530,7 +541,7 @@ async def _process_running_job_provisioning_state( repo=context.repo_model, code_hash=_get_repo_code_hash(context.run, context.job), ) - success = await run_async( + submit_result = await run_async( _submit_job_to_runner, server_ssh_private_keys, job_provisioning_data, @@ -538,6 +549,7 @@ async def _process_running_job_provisioning_state( run=context.run, 
job_model=context.job_model, job=context.job, + jrd=get_job_runtime_data(context.job_model), cluster_info=startup_context.cluster_info, code=code, file_archives=file_archives, @@ -545,26 +557,35 @@ async def _process_running_job_provisioning_state( repo_credentials=startup_context.repo_creds, success_if_not_available=False, ) - - if success: - return + if submit_result is not False: + _apply_submit_job_to_runner_result( + job_model=context.job_model, + result=result, + submit_result=submit_result, + ) + if submit_result is not False and submit_result.success: + return provisioning_timeout = get_provisioning_timeout( backend_type=job_provisioning_data.get_base_backend(), instance_type_name=job_provisioning_data.instance_type.name, ) if context.job_submission.age > provisioning_timeout: - context.job_model.termination_reason = JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED - context.job_model.termination_reason_message = ( - f"Runner did not become available within {provisioning_timeout.total_seconds()}s." - f" Job submission age: {context.job_submission.age.total_seconds()}s)" + _terminate_running_job( + job_model=context.job_model, + result=result, + termination_reason=JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED, + termination_reason_message=( + f"Runner did not become available within {provisioning_timeout.total_seconds()}s." 
+ f" Job submission age: {context.job_submission.age.total_seconds()}s)" + ), ) - _switch_job_status(context.job_model, JobStatus.TERMINATING) async def _process_running_job_pulling_state( context: _RunningJobContext, startup_context: _RunningJobStartupContext, + result: _ProcessResult, ) -> None: job_provisioning_data = get_or_error(context.job_provisioning_data) server_ssh_private_keys = get_or_error(context.server_ssh_private_keys) @@ -580,13 +601,34 @@ async def _process_running_job_pulling_state( job_provisioning_data, None, job_model=context.job_model, + jrd=_get_result_job_runtime_data(context.job_model, result), ) - if shim_state == _ShimPullingState.WAITING: - _reset_disconnected_at(context.job_model) + if shim_state is False: + shim_state = None + elif shim_state.job_runtime_data is not None: + _set_job_runtime_data(result, shim_state.job_runtime_data) + + if shim_state is not None and shim_state.state == _ShimPullingState.WAITING: + _reset_disconnected_at(context.job_model, result) return - if shim_state == _ShimPullingState.READY: - job_runtime_data = get_job_runtime_data(context.job_model) + if shim_state is not None and shim_state.state == _ShimPullingState.FAILED: + logger.warning( + "%s: failed due to %s, age=%s", + fmt(context.job_model), + get_or_error(shim_state.termination_reason).value, + context.job_submission.age, + ) + _terminate_running_job( + job_model=context.job_model, + result=result, + termination_reason=get_or_error(shim_state.termination_reason), + termination_reason_message=get_or_error(shim_state.termination_reason_message), + ) + return + + if shim_state is not None and shim_state.state == _ShimPullingState.READY: + job_runtime_data = _get_result_job_runtime_data(context.job_model, result) runner_availability = await run_async( _get_runner_availability, server_ssh_private_keys, @@ -594,7 +636,7 @@ async def _process_running_job_pulling_state( job_runtime_data, ) if runner_availability == _RunnerAvailability.UNAVAILABLE: - 
_reset_disconnected_at(context.job_model) + _reset_disconnected_at(context.job_model, result) return if runner_availability == _RunnerAvailability.AVAILABLE: @@ -607,7 +649,7 @@ async def _process_running_job_pulling_state( repo=context.repo_model, code_hash=_get_repo_code_hash(context.run, context.job), ) - success = await run_async( + submit_result = await run_async( _submit_job_to_runner, server_ssh_private_keys, job_provisioning_data, @@ -615,6 +657,7 @@ async def _process_running_job_pulling_state( run=context.run, job_model=context.job_model, job=context.job, + jrd=job_runtime_data, cluster_info=startup_context.cluster_info, code=code, file_archives=file_archives, @@ -622,22 +665,20 @@ async def _process_running_job_pulling_state( repo_credentials=startup_context.repo_creds, success_if_not_available=True, ) - if success: - _reset_disconnected_at(context.job_model) + if submit_result is not False: + _apply_submit_job_to_runner_result( + job_model=context.job_model, + result=result, + submit_result=submit_result, + ) + if submit_result is not False and submit_result.success: + _reset_disconnected_at(context.job_model, result) return - if context.job_model.termination_reason: - logger.warning( - "%s: failed due to %s, age=%s", - fmt(context.job_model), - context.job_model.termination_reason.value, - context.job_submission.age, - ) - _switch_job_status(context.job_model, JobStatus.TERMINATING) - return - - _set_disconnected_at_now(context.job_model) - if not _should_terminate_job_due_to_disconnect(context.job_model): + _set_disconnected_at_now(context.job_model, result) + if not _should_terminate_job_due_to_disconnect( + _get_result_disconnected_at(context.job_model, result) + ): logger.warning( "%s: is unreachable, waiting for the instance to become reachable again, age=%s", fmt(context.job_model), @@ -646,44 +687,32 @@ async def _process_running_job_pulling_state( return if job_provisioning_data.instance_type.resources.spot: - 
context.job_model.termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY else: - context.job_model.termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE - context.job_model.termination_reason_message = "Instance is unreachable" - _switch_job_status(context.job_model, JobStatus.TERMINATING) - - -def _build_job_update_map(job_model: JobModel) -> _JobUpdateMap: - return _JobUpdateMap( - status=job_model.status, - termination_reason=job_model.termination_reason, - termination_reason_message=job_model.termination_reason_message, - job_provisioning_data=job_model.job_provisioning_data, - job_runtime_data=job_model.job_runtime_data, - runner_timestamp=job_model.runner_timestamp, - disconnected_at=job_model.disconnected_at, - inactivity_secs=job_model.inactivity_secs, - exit_status=job_model.exit_status, + termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE + _terminate_running_job( + job_model=context.job_model, + result=result, + termination_reason=termination_reason, + termination_reason_message="Instance is unreachable", ) async def _apply_process_result( item: JobRunningPipelineItem, job_model: JobModel, - initial_status: JobStatus, - initial_disconnected_at: Optional[datetime], - job_update_map: _JobUpdateMap, + result: _ProcessResult, ) -> None: async with get_session_ctx() as session: now = get_current_datetime() - resolve_now_placeholders(job_update_map, now=now) + resolve_now_placeholders(result.job_update_map, now=now) res = await session.execute( update(JobModel) .where( JobModel.id == item.id, JobModel.lock_token == item.lock_token, ) - .values(**job_update_map) + .values(**result.job_update_map) .returning(JobModel.id) ) updated_ids = list(res.scalars().all()) @@ -694,12 +723,12 @@ async def _apply_process_result( emit_job_status_change_event( session=session, job_model=job_model, - old_status=initial_status, - new_status=job_update_map.get("status", 
initial_status), - termination_reason=job_update_map.get( + old_status=job_model.status, + new_status=result.job_update_map.get("status", job_model.status), + termination_reason=result.job_update_map.get( "termination_reason", job_model.termination_reason ), - termination_reason_message=job_update_map.get( + termination_reason_message=result.job_update_map.get( "termination_reason_message", job_model.termination_reason_message, ), @@ -707,25 +736,28 @@ async def _apply_process_result( _emit_reachability_change_event( session=session, job_model=job_model, - initial_disconnected_at=initial_disconnected_at, - new_disconnected_at=job_update_map.get("disconnected_at", initial_disconnected_at), + old_disconnected_at=job_model.disconnected_at, + new_disconnected_at=result.job_update_map.get( + "disconnected_at", + job_model.disconnected_at, + ), ) def _emit_reachability_change_event( session: AsyncSession, job_model: JobModel, - initial_disconnected_at: Optional[datetime], + old_disconnected_at: Optional[datetime], new_disconnected_at: Optional[datetime], ) -> None: - if initial_disconnected_at is None and new_disconnected_at is not None: + if old_disconnected_at is None and new_disconnected_at is not None: events.emit( session, "Job became unreachable", actor=events.SystemActor(), targets=[events.Target.from_model(job_model)], ) - elif initial_disconnected_at is not None and new_disconnected_at is None: + elif old_disconnected_at is not None and new_disconnected_at is None: events.emit( session, "Job became reachable", @@ -736,15 +768,19 @@ def _emit_reachability_change_event( def _terminate_running_job( job_model: JobModel, + result: _ProcessResult, termination_reason: JobTerminationReason, termination_reason_message: str, ) -> None: - job_model.termination_reason = termination_reason - job_model.termination_reason_message = termination_reason_message - _switch_job_status(job_model, JobStatus.TERMINATING) + result.job_update_map["termination_reason"] = 
termination_reason + result.job_update_map["termination_reason_message"] = termination_reason_message + _set_job_status(job_model, result, JobStatus.TERMINATING) -def _wait_for_instance_provisioning_data(job_model: JobModel) -> None: +def _wait_for_instance_provisioning_data( + job_model: JobModel, + result: _ProcessResult, +) -> None: if job_model.instance is None: logger.error( "%s: cannot update job_provisioning_data. job_model.instance is None.", @@ -759,12 +795,15 @@ def _wait_for_instance_provisioning_data(job_model: JobModel) -> None: return if job_model.instance.status == InstanceStatus.TERMINATED: - job_model.termination_reason = JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED - job_model.termination_reason_message = "Instance is terminated" - _switch_job_status(job_model, JobStatus.TERMINATING) + _terminate_running_job( + job_model=job_model, + result=result, + termination_reason=JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED, + termination_reason_message="Instance is terminated", + ) return - job_model.job_provisioning_data = job_model.instance.job_provisioning_data + result.job_update_map["job_provisioning_data"] = job_model.instance.job_provisioning_data def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool: @@ -805,6 +844,7 @@ def _process_provisioning_with_shim( ports: Dict[int, int], run: Run, job_model: JobModel, + jrd: Optional[JobRuntimeData], jpd: JobProvisioningData, volumes: list[Volume], registry_auth: Optional[RegistryAuth], @@ -843,12 +883,11 @@ def _process_provisioning_with_shim( gpu_devices = get_instance_specific_gpu_devices(jpd.backend, jpd.instance_type.name) container_user = "root" - job_runtime_data = get_job_runtime_data(job_model) - if job_runtime_data is not None: - gpu = job_runtime_data.gpu - cpu = job_runtime_data.cpu - memory = job_runtime_data.memory - network_mode = job_runtime_data.network_mode + if jrd is not None: + gpu = jrd.gpu + cpu = jrd.cpu + memory = jrd.memory + network_mode 
= jrd.network_mode else: gpu = None cpu = None @@ -903,7 +942,6 @@ def _process_provisioning_with_shim( shim_client.stop(force=True) return False - _switch_job_status(job_model, JobStatus.PULLING) return True @@ -915,6 +953,15 @@ class _RunnerAvailability(enum.Enum): class _ShimPullingState(enum.Enum): WAITING = "waiting" READY = "ready" + FAILED = "failed" + + +@dataclass +class _SyncShimPullingStateResult: + state: _ShimPullingState + termination_reason: Optional[JobTerminationReason] = None + termination_reason_message: Optional[str] = None + job_runtime_data: Optional[JobRuntimeData] = None @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) @@ -929,7 +976,8 @@ def _get_runner_availability(ports: Dict[int, int]) -> _RunnerAvailability: def _sync_shim_pulling_state( ports: Dict[int, int], job_model: JobModel, -) -> Union[Literal[False], _ShimPullingState]: + jrd: Optional[JobRuntimeData] = None, +) -> Union[_SyncShimPullingStateResult, Literal[False]]: shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) if shim_client.is_api_v2_supported(): task = shim_client.get_task(job_model.id) @@ -941,19 +989,19 @@ def _sync_shim_pulling_state( task.termination_message, ) logger.debug("task status: %s", task.dict()) - job_model.termination_reason = JobTerminationReason(task.termination_reason.lower()) - job_model.termination_reason_message = task.termination_message - return False + return _SyncShimPullingStateResult( + state=_ShimPullingState.FAILED, + termination_reason=JobTerminationReason(task.termination_reason.lower()), + termination_reason_message=task.termination_message, + ) if task.status != TaskStatus.RUNNING: - return _ShimPullingState.WAITING + return _SyncShimPullingStateResult(state=_ShimPullingState.WAITING) - job_runtime_data = get_job_runtime_data(job_model) - if job_runtime_data is not None: + if jrd is not None: if task.ports is None: - return _ShimPullingState.WAITING - job_runtime_data.ports = {pm.container: pm.host for pm in 
task.ports} - job_model.job_runtime_data = job_runtime_data.json() + return _SyncShimPullingStateResult(state=_ShimPullingState.WAITING) + jrd.ports = {pm.container: pm.host for pm in task.ports} else: shim_status = shim_client.pull() if ( @@ -968,30 +1016,35 @@ def _sync_shim_pulling_state( shim_status.result.reason_message, ) logger.debug("shim status: %s", shim_status.dict()) - job_model.termination_reason = JobTerminationReason(shim_status.result.reason.lower()) - job_model.termination_reason_message = shim_status.result.reason_message - return False + return _SyncShimPullingStateResult( + state=_ShimPullingState.FAILED, + termination_reason=JobTerminationReason(shim_status.result.reason.lower()), + termination_reason_message=shim_status.result.reason_message, + ) if shim_status.state in ("pulling", "creating"): - return _ShimPullingState.WAITING + return _SyncShimPullingStateResult(state=_ShimPullingState.WAITING) - return _ShimPullingState.READY + return _SyncShimPullingStateResult( + state=_ShimPullingState.READY, + job_runtime_data=jrd, + ) -def _should_terminate_job_due_to_disconnect(job_model: JobModel) -> bool: - if job_model.disconnected_at is None: +def _should_terminate_job_due_to_disconnect(disconnected_at: Optional[datetime]) -> bool: + if disconnected_at is None: return False - return get_current_datetime() > job_model.disconnected_at + JOB_DISCONNECTED_RETRY_TIMEOUT + return get_current_datetime() > disconnected_at + JOB_DISCONNECTED_RETRY_TIMEOUT -def _set_disconnected_at_now(job_model: JobModel) -> None: - if job_model.disconnected_at is None: - job_model.disconnected_at = get_current_datetime() +def _set_disconnected_at_now(job_model: JobModel, result: _ProcessResult) -> None: + if _get_result_disconnected_at(job_model, result) is None: + result.job_update_map["disconnected_at"] = get_current_datetime() -def _reset_disconnected_at(job_model: JobModel) -> None: - if job_model.disconnected_at is not None: - job_model.disconnected_at = None +def 
_reset_disconnected_at(job_model: JobModel, result: _ProcessResult) -> None: + if _get_result_disconnected_at(job_model, result) is not None: + result.job_update_map["disconnected_at"] = None def _get_cluster_info( @@ -1084,19 +1137,27 @@ async def _get_job_file_archive(archive_id: uuid.UUID, user: UserModel) -> bytes return blob +@dataclass +class _SubmitJobToRunnerResult: + success: bool + set_running_status: bool = False + job_runtime_data: Optional[JobRuntimeData] = None + + @runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) def _submit_job_to_runner( ports: Dict[int, int], run: Run, job_model: JobModel, job: Job, + jrd: Optional[JobRuntimeData], cluster_info: ClusterInfo, code: bytes, file_archives: Iterable[tuple[uuid.UUID, bytes]], secrets: Dict[str, str], repo_credentials: Optional[RemoteRepoCreds], success_if_not_available: bool, -) -> bool: +) -> Union[_SubmitJobToRunnerResult, Literal[False]]: logger.debug("%s: submitting job spec", fmt(job_model)) logger.debug( "%s: repo clone URL is %s", @@ -1111,7 +1172,7 @@ def _submit_job_to_runner( runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) if runner_client.healthcheck() is None: - return success_if_not_available + return _SubmitJobToRunnerResult(success=success_if_not_available) runner_client.submit_job( run=run, @@ -1129,14 +1190,14 @@ def _submit_job_to_runner( logger.debug("%s: starting job", fmt(job_model)) job_info = runner_client.run_job() if job_info is not None: - job_runtime_data = get_job_runtime_data(job_model) - if job_runtime_data is not None: - job_runtime_data.working_dir = job_info.working_dir - job_runtime_data.username = job_info.username - job_model.job_runtime_data = job_runtime_data.json() - - _switch_job_status(job_model, JobStatus.RUNNING) - return True + if jrd is not None: + jrd.working_dir = job_info.working_dir + jrd.username = job_info.username + return _SubmitJobToRunnerResult( + success=True, + set_running_status=True, + job_runtime_data=jrd, 
+ ) def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec) -> None: @@ -1149,6 +1210,44 @@ def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec) -> None: ) -def _switch_job_status(job_model: JobModel, new_status: JobStatus) -> None: - if job_model.status != new_status: - job_model.status = new_status +def _set_job_status(job_model: JobModel, result: _ProcessResult, new_status: JobStatus) -> None: + if _get_result_status(job_model, result) != new_status: + result.job_update_map["status"] = new_status + + +def _get_result_status(job_model: JobModel, result: _ProcessResult) -> JobStatus: + return result.job_update_map.get("status", job_model.status) + + +def _get_result_disconnected_at(job_model: JobModel, result: _ProcessResult) -> Optional[datetime]: + return result.job_update_map.get("disconnected_at", job_model.disconnected_at) + + +def _get_result_job_runtime_data( + job_model: JobModel, result: _ProcessResult +) -> Optional[JobRuntimeData]: + raw_job_runtime_data = result.job_update_map.get( + "job_runtime_data", job_model.job_runtime_data + ) + if raw_job_runtime_data is None: + return None + return JobRuntimeData.__response__.parse_raw(raw_job_runtime_data) + + +def _set_job_runtime_data( + result: _ProcessResult, job_runtime_data: Optional[JobRuntimeData] +) -> None: + result.job_update_map["job_runtime_data"] = ( + None if job_runtime_data is None else job_runtime_data.json() + ) + + +def _apply_submit_job_to_runner_result( + job_model: JobModel, + result: _ProcessResult, + submit_result: _SubmitJobToRunnerResult, +) -> None: + if submit_result.job_runtime_data is not None: + _set_job_runtime_data(result, submit_result.job_runtime_data) + if submit_result.set_running_status: + _set_job_status(job_model, result, JobStatus.RUNNING) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index f50c60e5f..5e5497e36 100644 --- 
a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -25,6 +25,7 @@ JobRunningPipelineItem, JobRunningWorker, _RunnerAvailability, + _SubmitJobToRunnerResult, ) from dstack._internal.server.schemas.runner import ( HealthcheckResponse, @@ -645,6 +646,48 @@ async def test_pulling_shim_port_mapping_not_ready( await session.refresh(job) assert job.status == JobStatus.PULLING + async def test_pulling_shim_waiting_resets_disconnect_and_emits_reachable_event( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + submitted_at=get_current_datetime(), + disconnected_at=get_current_datetime() - timedelta(minutes=1), + job_provisioning_data=get_job_provisioning_data(dockerized=True), + job_runtime_data=get_job_runtime_data(network_mode="bridge", ports=None), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + shim_client_mock.get_task.return_value.ports = None + + await _process_job(session, worker, job) + + ssh_tunnel_mock.assert_called_once() + shim_client_mock.get_task.assert_called_once() + runner_client_mock.healthcheck.assert_not_called() + await session.refresh(job) + events = await list_events(session) + assert job.status == JobStatus.PULLING + assert job.disconnected_at is None + assert len(events) == 1 + assert events[0].message == "Job became reachable" + async def 
test_pulling_shim_runner_not_ready( self, test_db, @@ -739,7 +782,7 @@ def assert_runner_availability(_, __, job_runtime_data): def assert_submit_job_to_runner(_, __, job_runtime_data, **kwargs): assert job_runtime_data is not None assert job_runtime_data.ports == expected_ports - return True + return _SubmitJobToRunnerResult(success=True) with ( patch( @@ -992,7 +1035,7 @@ async def invalidate_lock(*args, **kwargs): ), patch( "dstack._internal.server.background.pipeline_tasks.jobs_running._submit_job_to_runner", - return_value=True, + return_value=_SubmitJobToRunnerResult(success=True), ), ): await worker.process(_job_to_pipeline_item(job)) From 22e724da014fb6463e8391fa99d8932aff5dec4b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 10:46:14 +0500 Subject: [PATCH 04/21] Finish running jobs pipeline worker --- .../background/pipeline_tasks/jobs_running.py | 654 +++++++++++++---- .../pipeline_tasks/test_running_jobs.py | 675 +++++++++++++++++- 2 files changed, 1178 insertions(+), 151 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 927b41e31..922193bb6 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -5,15 +5,19 @@ from datetime import datetime, timedelta from typing import Dict, Iterable, Literal, Optional, Sequence, Union +import httpx from sqlalchemy import and_, func, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.errors import GatewayError, SSHError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import NetworkMode, RegistryAuth +from 
dstack._internal.core.models.configurations import DevEnvironmentConfiguration from dstack._internal.core.models.files import FileArchiveMapping -from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.instances import InstanceStatus, SSHConnectionParams +from dstack._internal.core.models.metrics import Metric from dstack._internal.core.models.profiles import StartupOrder from dstack._internal.core.models.repos import RemoteRepoCreds from dstack._internal.core.models.runs import ( @@ -26,6 +30,7 @@ JobSubmission, JobTerminationReason, Run, + RunSpec, RunStatus, ) from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint @@ -57,11 +62,13 @@ from dstack._internal.server.schemas.runner import TaskStatus from dstack._internal.server.services import events from dstack._internal.server.services import files as files_services +from dstack._internal.server.services import logs as logs_services from dstack._internal.server.services.backends.provisioning import ( get_instance_specific_gpu_devices, get_instance_specific_mounts, resolve_provisioning_image_name, ) +from dstack._internal.server.services.gateways import get_or_add_gateway_connection from dstack._internal.server.services.instances import ( get_instance_remote_connection_info, get_instance_ssh_private_keys, @@ -76,6 +83,7 @@ ) from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.metrics import get_job_metrics from dstack._internal.server.services.repos import ( get_code_model, get_repo_creds, @@ -83,7 +91,7 @@ ) from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel -from dstack._internal.server.services.runs import run_model_to_run +from dstack._internal.server.services.runs import is_job_ready, run_model_to_run from 
dstack._internal.server.services.secrets import get_project_secrets_mapping from dstack._internal.server.services.storage import get_default_storage from dstack._internal.server.utils import sentry_utils @@ -95,7 +103,7 @@ JOB_DISCONNECTED_RETRY_TIMEOUT = timedelta(minutes=2) -"""`JOB_DISCONNECTED_RETRY_TIMEOUT` is the minimum time before terminating active job in case of connectivity issues.""" +"""The minimum time before terminating an active job in case of connectivity issues.""" @dataclass @@ -246,10 +254,7 @@ def __init__( @sentry_utils.instrument_named_task("pipeline_tasks.JobRunningWorker.process") async def process(self, item: JobRunningPipelineItem): - if item.status == JobStatus.RUNNING: - raise NotImplementedError("RUNNING-state migration is implemented in a later step") - - context = await _load_running_job_context(item=item) + context = await _load_process_context(item=item) if context is None: log_lock_token_mismatch(logger, item) return @@ -265,7 +270,7 @@ async def process(self, item: JobRunningPipelineItem): @dataclass -class _RunningJobContext: +class _ProcessContext: job_model: JobModel run_model: RunModel repo_model: RepoModel @@ -277,7 +282,36 @@ class _RunningJobContext: server_ssh_private_keys: Optional[tuple[str, Optional[str]]] = None -async def _load_running_job_context(item: JobRunningPipelineItem) -> Optional[_RunningJobContext]: +class _JobUpdateMap(ItemUpdateMap, total=False): + status: JobStatus + termination_reason: Optional[JobTerminationReason] + termination_reason_message: Optional[str] + job_provisioning_data: Optional[str] + job_runtime_data: Optional[str] + runner_timestamp: Optional[int] + disconnected_at: Optional[datetime] + inactivity_secs: Optional[int] + exit_status: Optional[int] + registered: bool + + +@dataclass +class _ProcessResult: + job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) + new_probe_models: list[ProbeModel] = field(default_factory=list) + emit_register_replica_event: bool = False + 
register_gateway_target: Optional[events.Target] = None + + +@dataclass +class _StartupContext: + cluster_info: ClusterInfo + volumes: list[Volume] + secrets: dict[str, str] + repo_creds: Optional[RemoteRepoCreds] + + +async def _load_process_context(item: JobRunningPipelineItem) -> Optional[_ProcessContext]: async with get_session_ctx() as session: job_model = await _refetch_locked_job_model(session=session, item=item) if job_model is None: @@ -286,7 +320,7 @@ async def _load_running_job_context(item: JobRunningPipelineItem) -> Optional[_R run = run_model_to_run(run_model, include_sensitive=True) job_submission = job_model_to_job_submission(job_model) server_ssh_private_keys = get_instance_ssh_private_keys(get_or_error(job_model.instance)) - return _RunningJobContext( + return _ProcessContext( job_model=job_model, run_model=run_model, repo_model=run_model.repo, @@ -299,24 +333,7 @@ async def _load_running_job_context(item: JobRunningPipelineItem) -> Optional[_R ) -class _JobUpdateMap(ItemUpdateMap, total=False): - status: JobStatus - termination_reason: Optional[JobTerminationReason] - termination_reason_message: Optional[str] - job_provisioning_data: Optional[str] - job_runtime_data: Optional[str] - runner_timestamp: Optional[int] - disconnected_at: Optional[datetime] - inactivity_secs: Optional[int] - exit_status: Optional[int] - - -@dataclass -class _ProcessResult: - job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) - - -async def _process_running_job(context: _RunningJobContext) -> _ProcessResult: +async def _process_running_job(context: _ProcessContext) -> _ProcessResult: result = _ProcessResult() if context.job_provisioning_data is None: logger.error("%s: job_provisioning_data of an active job is None", fmt(context.job_model)) @@ -332,7 +349,7 @@ async def _process_running_job(context: _RunningJobContext) -> _ProcessResult: startup_context = None if context.job_model.status in [JobStatus.PROVISIONING, JobStatus.PULLING]: - startup_context 
= await _prepare_running_job_startup_context( + startup_context = await _prepare_startup_context( context=context, result=result, ) @@ -340,32 +357,39 @@ async def _process_running_job(context: _RunningJobContext) -> _ProcessResult: return result if context.job_model.status == JobStatus.PROVISIONING: - await _process_running_job_provisioning_state( + await _process_provisioning_status( context=context, startup_context=get_or_error(startup_context), result=result, ) elif context.job_model.status == JobStatus.PULLING: - await _process_running_job_pulling_state( + await _process_pulling_status( context=context, startup_context=get_or_error(startup_context), result=result, ) - return result - + else: + await _process_running_status( + context=context, + result=result, + ) -@dataclass -class _RunningJobStartupContext: - cluster_info: ClusterInfo - volumes: list[Volume] - secrets: dict[str, str] - repo_creds: Optional[RemoteRepoCreds] + if _get_result_status(context.job_model, result) == JobStatus.RUNNING: + if context.job_model.status != JobStatus.RUNNING: + _initialize_running_job_probes( + job_model=context.job_model, + job=context.job, + result=result, + ) + await _maybe_register_replica(context=context, result=result) + await _check_gpu_utilization(context=context, result=result) + return result -async def _prepare_running_job_startup_context( - context: _RunningJobContext, +async def _prepare_startup_context( + context: _ProcessContext, result: _ProcessResult, -) -> Optional[_RunningJobStartupContext]: +) -> Optional[_StartupContext]: job_provisioning_data = get_or_error(context.job_provisioning_data) for other_job in context.run.jobs: @@ -405,6 +429,7 @@ async def _prepare_running_job_startup_context( context.repo_model, repo_creds_model, ).repo_creds + try: _interpolate_secrets(secrets, context.job.job_spec) except InterpolatorError as e: @@ -416,7 +441,7 @@ async def _prepare_running_job_startup_context( ) return None - return _RunningJobStartupContext( + return 
_StartupContext( cluster_info=cluster_info, volumes=volumes, secrets=secrets, @@ -475,9 +500,9 @@ async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel return res.unique().scalar_one() -async def _process_running_job_provisioning_state( - context: _RunningJobContext, - startup_context: _RunningJobStartupContext, +async def _process_provisioning_status( + context: _ProcessContext, + startup_context: _StartupContext, result: _ProcessResult, ) -> None: job_provisioning_data = get_or_error(context.job_provisioning_data) @@ -582,9 +607,9 @@ async def _process_running_job_provisioning_state( ) -async def _process_running_job_pulling_state( - context: _RunningJobContext, - startup_context: _RunningJobStartupContext, +async def _process_pulling_status( + context: _ProcessContext, + startup_context: _StartupContext, result: _ProcessResult, ) -> None: job_provisioning_data = get_or_error(context.job_provisioning_data) @@ -698,6 +723,54 @@ async def _process_running_job_pulling_state( ) +async def _process_running_status( + context: _ProcessContext, + result: _ProcessResult, +) -> None: + job_provisioning_data = get_or_error(context.job_provisioning_data) + server_ssh_private_keys = get_or_error(context.server_ssh_private_keys) + + logger.debug( + "%s: process running job, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + process_running_result = await run_async( + _process_running, + server_ssh_private_keys, + job_provisioning_data, + context.job_submission.job_runtime_data, + run_model=context.run_model, + job_model=context.job_model, + ) + if process_running_result is not False: + result.job_update_map.update(process_running_result.job_update_map) + _reset_disconnected_at(context.job_model, result) + return + + _set_disconnected_at_now(context.job_model, result) + if not _should_terminate_job_due_to_disconnect( + _get_result_disconnected_at(context.job_model, result) + ): + logger.warning( + "%s: is unreachable, waiting for 
the instance to become reachable again, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + return + + if job_provisioning_data.instance_type.resources.spot: + termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY + else: + termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE + _terminate_running_job( + job_model=context.job_model, + result=result, + termination_reason=termination_reason, + termination_reason_message="Instance is unreachable", + ) + + async def _apply_process_result( item: JobRunningPipelineItem, job_model: JobModel, @@ -720,6 +793,9 @@ async def _apply_process_result( log_lock_token_changed_after_processing(logger, item) return + if result.new_probe_models: + session.add_all(result.new_probe_models) + emit_job_status_change_event( session=session, job_model=job_model, @@ -742,28 +818,16 @@ async def _apply_process_result( job_model.disconnected_at, ), ) - - -def _emit_reachability_change_event( - session: AsyncSession, - job_model: JobModel, - old_disconnected_at: Optional[datetime], - new_disconnected_at: Optional[datetime], -) -> None: - if old_disconnected_at is None and new_disconnected_at is not None: - events.emit( - session, - "Job became unreachable", - actor=events.SystemActor(), - targets=[events.Target.from_model(job_model)], - ) - elif old_disconnected_at is not None and new_disconnected_at is None: - events.emit( - session, - "Job became reachable", - actor=events.SystemActor(), - targets=[events.Target.from_model(job_model)], - ) + if result.emit_register_replica_event: + targets = [events.Target.from_model(job_model)] + if result.register_gateway_target is not None: + targets.append(result.register_gateway_target) + events.emit( + session, + "Service replica registered to receive requests", + actor=events.SystemActor(), + targets=targets, + ) def _terminate_running_job( @@ -772,9 +836,12 @@ def _terminate_running_job( termination_reason: JobTerminationReason, termination_reason_message: 
str, ) -> None: - result.job_update_map["termination_reason"] = termination_reason - result.job_update_map["termination_reason_message"] = termination_reason_message - _set_job_status(job_model, result, JobStatus.TERMINATING) + _terminate_job( + job_model=job_model, + job_update_map=result.job_update_map, + termination_reason=termination_reason, + termination_reason_message=termination_reason_message, + ) def _wait_for_instance_provisioning_data( @@ -806,6 +873,161 @@ def _wait_for_instance_provisioning_data( result.job_update_map["job_provisioning_data"] = job_model.instance.job_provisioning_data +def _initialize_running_job_probes( + job_model: JobModel, + job: Job, + result: _ProcessResult, +) -> None: + for probe_num in range(len(job.job_spec.probes)): + result.new_probe_models.append( + ProbeModel( + name=f"{job_model.job_name}-{probe_num}", + job_id=job_model.id, + probe_num=probe_num, + due=get_current_datetime(), + success_streak=0, + active=True, + ) + ) + + +async def _maybe_register_replica( + context: _ProcessContext, + result: _ProcessResult, +) -> None: + if ( + context.run.run_spec.configuration.type != "service" + or _get_result_registered(context.job_model, result) + or context.job_model.job_num != 0 + or result.new_probe_models + or not is_job_ready(context.job_model.probes, context.job.job_spec.probes) + ): + return + + ssh_head_proxy: Optional[SSHConnectionParams] = None + ssh_head_proxy_private_key: Optional[str] = None + instance = get_or_error(context.job_model.instance) + rci = get_instance_remote_connection_info(instance) + if rci is not None and rci.ssh_proxy is not None: + ssh_head_proxy = rci.ssh_proxy + ssh_head_proxy_keys = get_or_error(rci.ssh_proxy_keys) + ssh_head_proxy_private_key = ssh_head_proxy_keys[0].private + + try: + gateway_target = await _register_service_replica( + context=context, + result=result, + ssh_head_proxy=ssh_head_proxy, + ssh_head_proxy_private_key=ssh_head_proxy_private_key, + ) + except GatewayError as e: + 
logger.warning("%s: failed to register service replica: %s", fmt(context.job_model), e) + _terminate_running_job( + job_model=context.job_model, + result=result, + termination_reason=JobTerminationReason.GATEWAY_ERROR, + termination_reason_message="Failed to register service replica", + ) + return + + result.job_update_map["registered"] = True + result.emit_register_replica_event = True + result.register_gateway_target = gateway_target + + +async def _register_service_replica( + context: _ProcessContext, + result: _ProcessResult, + ssh_head_proxy: Optional[SSHConnectionParams], + ssh_head_proxy_private_key: Optional[str], +) -> Optional[events.Target]: + if context.run_model.gateway_id is None: + return None + + async with get_session_ctx() as session: + gateway_model, conn = await get_or_add_gateway_connection( + session, context.run_model.gateway_id + ) + gateway_target = events.Target.from_model(gateway_model) + try: + logger.debug( + "%s: registering replica for service %s", fmt(context.job_model), context.run.id.hex + ) + # JobRuntimeData might change on PULLING -> RUNNING path + # so we must update job_submission with the result value. + job_submission = context.job_submission.copy(deep=True) + job_submission.job_runtime_data = _get_result_job_runtime_data(context.job_model, result) + async with conn.client() as gateway_client: + await gateway_client.register_replica( + run=context.run, + job_spec=JobSpec.__response__.parse_raw(context.job_model.job_spec_data), + job_submission=job_submission, + ssh_head_proxy=ssh_head_proxy, + ssh_head_proxy_private_key=ssh_head_proxy_private_key, + ) + except (httpx.RequestError, SSHError) as e: + logger.debug("Gateway request failed", exc_info=True) + raise GatewayError(repr(e)) + except GatewayError as e: + if "already exists in service" in e.msg: + logger.warning( + ( + "%s: could not register replica in gateway: %s." 
+ " NOTE: if you just updated dstack from pre-0.19.25 to 0.19.25+," + " expect to see this warning once for every running service replica" + ), + fmt(context.job_model), + e.msg, + ) + else: + raise + return gateway_target + + +async def _check_gpu_utilization( + context: _ProcessContext, + result: _ProcessResult, +) -> None: + policy = context.job.job_spec.utilization_policy + if policy is None: + return + + after = get_current_datetime() - timedelta(seconds=policy.time_window) + async with get_session_ctx() as session: + job_metrics = await get_job_metrics(session, context.job_model, after=after) + gpus_util_metrics: list[Metric] = [] + for metric in job_metrics.metrics: + if metric.name.startswith("gpu_util_percent_gpu"): + gpus_util_metrics.append(metric) + if not gpus_util_metrics or gpus_util_metrics[0].timestamps[-1] > after + timedelta(minutes=1): + logger.debug("%s: GPU utilization check: not enough samples", fmt(context.job_model)) + return + if _should_terminate_due_to_low_gpu_util( + policy.min_gpu_utilization, [metric.values for metric in gpus_util_metrics] + ): + logger.debug("%s: GPU utilization check: terminating", fmt(context.job_model)) + _terminate_running_job( + job_model=context.job_model, + result=result, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message=( + f"The job GPU utilization below {policy.min_gpu_utilization}%" + f" for {policy.time_window} seconds" + ), + ) + else: + logger.debug("%s: GPU utilization check: OK", fmt(context.job_model)) + + +def _should_terminate_due_to_low_gpu_util( + min_util: int, gpus_util: Iterable[Iterable[int]] +) -> bool: + for gpu_util in gpus_util: + if all(util < min_util for util in gpu_util): + return True + return False + + def _should_wait_for_other_nodes(run: Run, job: Job, job_model: JobModel) -> bool: for other_job in run.jobs: if ( @@ -1031,6 +1253,168 @@ def _sync_shim_pulling_state( ) +@dataclass +class _SubmitJobToRunnerResult: + success: bool + 
set_running_status: bool = False + job_runtime_data: Optional[JobRuntimeData] = None + + +@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) +def _submit_job_to_runner( + ports: Dict[int, int], + run: Run, + job_model: JobModel, + job: Job, + jrd: Optional[JobRuntimeData], + cluster_info: ClusterInfo, + code: bytes, + file_archives: Iterable[tuple[uuid.UUID, bytes]], + secrets: Dict[str, str], + repo_credentials: Optional[RemoteRepoCreds], + success_if_not_available: bool, +) -> Union[_SubmitJobToRunnerResult, Literal[False]]: + logger.debug("%s: submitting job spec", fmt(job_model)) + logger.debug( + "%s: repo clone URL is %s", + fmt(job_model), + None if repo_credentials is None else repo_credentials.clone_url, + ) + instance = job_model.instance + if instance is not None and (rci := get_instance_remote_connection_info(instance)) is not None: + instance_env = rci.env + else: + instance_env = None + + runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + if runner_client.healthcheck() is None: + return _SubmitJobToRunnerResult(success=success_if_not_available) + + runner_client.submit_job( + run=run, + job=job, + cluster_info=cluster_info, + secrets={}, + repo_credentials=repo_credentials, + instance_env=instance_env, + ) + logger.debug("%s: uploading file archive(s)", fmt(job_model)) + for archive_id, archive in file_archives: + runner_client.upload_archive(archive_id, archive) + logger.debug("%s: uploading code", fmt(job_model)) + runner_client.upload_code(code) + logger.debug("%s: starting job", fmt(job_model)) + job_info = runner_client.run_job() + if job_info is not None: + if jrd is not None: + jrd.working_dir = job_info.working_dir + jrd.username = job_info.username + return _SubmitJobToRunnerResult( + success=True, + set_running_status=True, + job_runtime_data=jrd, + ) + + +@dataclass +class _ProcessRunningResult: + job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) + + 
+@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT]) +def _process_running( + ports: Dict[int, int], + run_model: RunModel, + job_model: JobModel, +) -> Union[_ProcessRunningResult, Literal[False]]: + runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) + timestamp = job_model.runner_timestamp or 0 + resp = runner_client.pull(timestamp) + logs_services.write_logs( + project=run_model.project, + run_name=run_model.run_name, + job_submission_id=job_model.id, + runner_logs=resp.runner_logs, + job_logs=resp.job_logs, + ) + result = _ProcessRunningResult( + job_update_map=_JobUpdateMap(runner_timestamp=resp.last_updated) + ) + if len(resp.job_states) > 0: + latest_state_event = resp.job_states[-1] + latest_status = latest_state_event.state + if latest_status == JobStatus.DONE: + _terminate_job( + job_model=job_model, + job_update_map=result.job_update_map, + termination_reason=JobTerminationReason.DONE_BY_RUNNER, + termination_reason_message=None, + ) + elif latest_status in {JobStatus.FAILED, JobStatus.TERMINATED}: + termination_reason = JobTerminationReason.CONTAINER_EXITED_WITH_ERROR + if latest_state_event.termination_reason: + termination_reason = JobTerminationReason( + latest_state_event.termination_reason.lower() + ) + _terminate_job( + job_model=job_model, + job_update_map=result.job_update_map, + termination_reason=termination_reason, + termination_reason_message=latest_state_event.termination_message, + ) + if latest_state_event.exit_status is not None: + result.job_update_map["exit_status"] = latest_state_event.exit_status + if latest_state_event.exit_status != 0: + logger.info( + "%s: non-zero exit status %s", fmt(job_model), latest_state_event.exit_status + ) + else: + _terminate_if_inactivity_duration_exceeded( + run_model=run_model, + job_model=job_model, + job_update_map=result.job_update_map, + no_connections_secs=resp.no_connections_secs, + ) + return result + + +def _terminate_if_inactivity_duration_exceeded( + run_model: 
RunModel, + job_model: JobModel, + job_update_map: _JobUpdateMap, + no_connections_secs: Optional[int], +) -> None: + conf = RunSpec.__response__.parse_raw(run_model.run_spec).configuration + if not isinstance(conf, DevEnvironmentConfiguration) or not isinstance( + conf.inactivity_duration, int + ): + job_update_map["inactivity_secs"] = None + return + + logger.debug("%s: no SSH connections for %s seconds", fmt(job_model), no_connections_secs) + job_update_map["inactivity_secs"] = no_connections_secs + if no_connections_secs is None: + _terminate_job( + job_model=job_model, + job_update_map=job_update_map, + termination_reason=JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY, + termination_reason_message=( + "The selected instance was created before dstack 0.18.41" + " and does not support inactivity_duration" + ), + ) + elif no_connections_secs >= conf.inactivity_duration: + _terminate_job( + job_model=job_model, + job_update_map=job_update_map, + termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason_message=( + f"The job was inactive for {no_connections_secs} seconds," + f" exceeding the inactivity_duration of {conf.inactivity_duration} seconds" + ), + ) + + def _should_terminate_job_due_to_disconnect(disconnected_at: Optional[datetime]) -> bool: if disconnected_at is None: return False @@ -1137,67 +1521,26 @@ async def _get_job_file_archive(archive_id: uuid.UUID, user: UserModel) -> bytes return blob -@dataclass -class _SubmitJobToRunnerResult: - success: bool - set_running_status: bool = False - job_runtime_data: Optional[JobRuntimeData] = None - - -@runner_ssh_tunnel(ports=[DSTACK_RUNNER_HTTP_PORT], retries=1) -def _submit_job_to_runner( - ports: Dict[int, int], - run: Run, +def _emit_reachability_change_event( + session: AsyncSession, job_model: JobModel, - job: Job, - jrd: Optional[JobRuntimeData], - cluster_info: ClusterInfo, - code: bytes, - file_archives: Iterable[tuple[uuid.UUID, bytes]], - secrets: Dict[str, str], - 
repo_credentials: Optional[RemoteRepoCreds], - success_if_not_available: bool, -) -> Union[_SubmitJobToRunnerResult, Literal[False]]: - logger.debug("%s: submitting job spec", fmt(job_model)) - logger.debug( - "%s: repo clone URL is %s", - fmt(job_model), - None if repo_credentials is None else repo_credentials.clone_url, - ) - instance = job_model.instance - if instance is not None and (rci := get_instance_remote_connection_info(instance)) is not None: - instance_env = rci.env - else: - instance_env = None - - runner_client = client.RunnerClient(port=ports[DSTACK_RUNNER_HTTP_PORT]) - if runner_client.healthcheck() is None: - return _SubmitJobToRunnerResult(success=success_if_not_available) - - runner_client.submit_job( - run=run, - job=job, - cluster_info=cluster_info, - secrets={}, - repo_credentials=repo_credentials, - instance_env=instance_env, - ) - logger.debug("%s: uploading file archive(s)", fmt(job_model)) - for archive_id, archive in file_archives: - runner_client.upload_archive(archive_id, archive) - logger.debug("%s: uploading code", fmt(job_model)) - runner_client.upload_code(code) - logger.debug("%s: starting job", fmt(job_model)) - job_info = runner_client.run_job() - if job_info is not None: - if jrd is not None: - jrd.working_dir = job_info.working_dir - jrd.username = job_info.username - return _SubmitJobToRunnerResult( - success=True, - set_running_status=True, - job_runtime_data=jrd, - ) + old_disconnected_at: Optional[datetime], + new_disconnected_at: Optional[datetime], +) -> None: + if old_disconnected_at is None and new_disconnected_at is not None: + events.emit( + session, + "Job became unreachable", + actor=events.SystemActor(), + targets=[events.Target.from_model(job_model)], + ) + elif old_disconnected_at is not None and new_disconnected_at is None: + events.emit( + session, + "Job became reachable", + actor=events.SystemActor(), + targets=[events.Target.from_model(job_model)], + ) def _interpolate_secrets(secrets: Dict[str, str], 
job_spec: JobSpec) -> None: @@ -1210,9 +1553,28 @@ def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec) -> None: ) +def _set_job_update_status( + job_model: JobModel, + job_update_map: _JobUpdateMap, + new_status: JobStatus, +) -> None: + if job_update_map.get("status", job_model.status) != new_status: + job_update_map["status"] = new_status + + +def _terminate_job( + job_model: JobModel, + job_update_map: _JobUpdateMap, + termination_reason: JobTerminationReason, + termination_reason_message: Optional[str], +) -> None: + job_update_map["termination_reason"] = termination_reason + job_update_map["termination_reason_message"] = termination_reason_message + _set_job_update_status(job_model, job_update_map, JobStatus.TERMINATING) + + def _set_job_status(job_model: JobModel, result: _ProcessResult, new_status: JobStatus) -> None: - if _get_result_status(job_model, result) != new_status: - result.job_update_map["status"] = new_status + _set_job_update_status(job_model, result.job_update_map, new_status) def _get_result_status(job_model: JobModel, result: _ProcessResult) -> JobStatus: @@ -1226,20 +1588,18 @@ def _get_result_disconnected_at(job_model: JobModel, result: _ProcessResult) -> def _get_result_job_runtime_data( job_model: JobModel, result: _ProcessResult ) -> Optional[JobRuntimeData]: - raw_job_runtime_data = result.job_update_map.get( - "job_runtime_data", job_model.job_runtime_data - ) - if raw_job_runtime_data is None: + jrd = result.job_update_map.get("job_runtime_data", job_model.job_runtime_data) + if jrd is None: return None - return JobRuntimeData.__response__.parse_raw(raw_job_runtime_data) + return JobRuntimeData.__response__.parse_raw(jrd) -def _set_job_runtime_data( - result: _ProcessResult, job_runtime_data: Optional[JobRuntimeData] -) -> None: - result.job_update_map["job_runtime_data"] = ( - None if job_runtime_data is None else job_runtime_data.json() - ) +def _set_job_runtime_data(result: _ProcessResult, jrd: 
Optional[JobRuntimeData]) -> None: + result.job_update_map["job_runtime_data"] = None if jrd is None else jrd.json() + + +def _get_result_registered(job_model: JobModel, result: _ProcessResult) -> bool: + return result.job_update_map.get("registered", job_model.registered) def _apply_submit_job_to_runner_result( diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index 5e5497e36..750a5522b 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -1,17 +1,28 @@ import asyncio import uuid +from dataclasses import dataclass from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Optional from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from freezegun import freeze_time +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from dstack._internal import settings +from dstack._internal.core.errors import SSHError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import NetworkMode +from dstack._internal.core.models.configurations import ( + DevEnvironmentConfiguration, + ProbeConfig, + ServiceConfiguration, +) from dstack._internal.core.models.instances import InstanceStatus -from dstack._internal.core.models.profiles import StartupOrder +from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy from dstack._internal.core.models.runs import ( JobRuntimeData, JobStatus, @@ -19,6 +30,7 @@ RunStatus, ) from dstack._internal.core.models.volumes import InstanceMountPoint, VolumeMountPoint, VolumeStatus +from dstack._internal.server import settings as server_settings from dstack._internal.server.background.pipeline_tasks.jobs_running import ( 
JobRunningFetcher, JobRunningPipeline, @@ -27,16 +39,23 @@ _RunnerAvailability, _SubmitJobToRunnerResult, ) +from dstack._internal.server.models import JobModel, ProbeModel from dstack._internal.server.schemas.runner import ( HealthcheckResponse, JobInfoResponse, + JobStateEvent, PortMapping, + PullResponse, TaskStatus, ) +from dstack._internal.server.services.runner.client import RunnerClient, ShimClient +from dstack._internal.server.services.runner.ssh import SSHTunnel from dstack._internal.server.services.volumes import volume_model_to_volume from dstack._internal.server.testing.common import ( create_instance, create_job, + create_job_metrics_point, + create_probe, create_project, create_repo, create_run, @@ -53,6 +72,12 @@ pytestmark = pytest.mark.usefixtures("image_config_mock") +@dataclass +class _ProbeSetup: + success_streak: int + ready_after: int + + @pytest.fixture def fetcher() -> JobRunningFetcher: return JobRunningFetcher( @@ -71,14 +96,14 @@ def worker() -> JobRunningWorker: @pytest.fixture def ssh_tunnel_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: - mock = MagicMock() + mock = MagicMock(spec_set=SSHTunnel) monkeypatch.setattr("dstack._internal.server.services.runner.ssh.SSHTunnel", mock) return mock @pytest.fixture def shim_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: - mock = Mock() + mock = Mock(spec_set=ShimClient) mock.healthcheck.return_value = HealthcheckResponse(service="dstack-shim", version="latest") monkeypatch.setattr( "dstack._internal.server.services.runner.client.ShimClient", Mock(return_value=mock) @@ -88,7 +113,7 @@ def shim_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: @pytest.fixture def runner_client_mock(monkeypatch: pytest.MonkeyPatch) -> Mock: - mock = Mock() + mock = Mock(spec_set=RunnerClient) mock.healthcheck.return_value = HealthcheckResponse( service="dstack-runner", version="0.0.1.dev2" ) @@ -1044,3 +1069,645 @@ async def invalidate_lock(*args, **kwargs): assert job.status == JobStatus.PROVISIONING 
assert job.lock_token == replacement_lock_token assert job.lock_token != original_lock_token + + async def test_updates_running_job( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + tmp_path: Path, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(dockerized=False), + instance=instance, + instance_assigned=True, + ) + last_processed_at = job.last_processed_at + + with ( + patch.object(server_settings, "SERVER_DIR_PATH", tmp_path), + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.pull.return_value = PullResponse( + job_states=[JobStateEvent(timestamp=1, state=JobStatus.RUNNING)], + job_logs=[], + runner_logs=[], + last_updated=1, + ) + await _process_job(session, worker, job) + ssh_tunnel_cls.assert_called_once() + + await session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.runner_timestamp == 1 + + job.last_processed_at = last_processed_at + await session.commit() + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.pull.return_value = PullResponse( + job_states=[JobStateEvent(timestamp=1, state=JobStatus.DONE, exit_status=0)], + job_logs=[], + runner_logs=[], 
+ last_updated=2, + ) + await _process_job(session, worker, job) + ssh_tunnel_cls.assert_called_once() + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == JobTerminationReason.DONE_BY_RUNNER + assert job.exit_status == 0 + assert job.runner_timestamp == 2 + + async def test_running_job_disconnect_retries_then_terminates( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(dockerized=False), + instance=instance, + instance_assigned=True, + ) + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.ssh.time.sleep"), + ): + ssh_tunnel_cls.side_effect = SSHError + await _process_job(session, worker, job) + assert ssh_tunnel_cls.call_count == 3 + + await session.refresh(job) + events = await list_events(session) + assert job.status == JobStatus.RUNNING + assert job.disconnected_at is not None + assert len(events) == 1 + assert events[0].message == "Job became unreachable" + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch("dstack._internal.server.services.runner.ssh.time.sleep"), + freeze_time(job.disconnected_at + timedelta(minutes=5)), + ): + ssh_tunnel_cls.side_effect = SSHError + await _process_job(session, worker, job) + assert ssh_tunnel_cls.call_count == 3 + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == 
JobTerminationReason.INSTANCE_UNREACHABLE + + @pytest.mark.parametrize( + ( + "inactivity_duration", + "no_connections_secs", + "expected_status", + "expected_termination_reason", + "expected_inactivity_secs", + ), + [ + pytest.param( + "1h", + 60 * 60 - 1, + JobStatus.RUNNING, + None, + 60 * 60 - 1, + id="duration-not-exceeded", + ), + pytest.param( + "1h", + 60 * 60, + JobStatus.TERMINATING, + JobTerminationReason.TERMINATED_BY_SERVER, + 60 * 60, + id="duration-exceeded-exactly", + ), + pytest.param( + "1h", + 60 * 60 + 1, + JobStatus.TERMINATING, + JobTerminationReason.TERMINATED_BY_SERVER, + 60 * 60 + 1, + id="duration-exceeded", + ), + pytest.param("off", 60 * 60, JobStatus.RUNNING, None, None, id="duration-off"), + pytest.param(False, 60 * 60, JobStatus.RUNNING, None, None, id="duration-false"), + pytest.param(None, 60 * 60, JobStatus.RUNNING, None, None, id="duration-none"), + pytest.param( + "1h", + None, + JobStatus.TERMINATING, + JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY, + None, + id="legacy-runner", + ), + pytest.param( + None, + None, + JobStatus.RUNNING, + None, + None, + id="legacy-runner-without-duration", + ), + ], + ) + async def test_inactivity_duration( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + inactivity_duration, + no_connections_secs: Optional[int], + expected_status: JobStatus, + expected_termination_reason: Optional[JobTerminationReason], + expected_inactivity_secs: Optional[int], + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.RUNNING, + run_name="test-run", + run_spec=get_run_spec( + run_name="test-run", + repo_id=repo.name, + configuration=DevEnvironmentConfiguration( + name="test-run", + ide="vscode", + inactivity_duration=inactivity_duration, + ), + ), + ) + instance 
= await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(), + instance=instance, + instance_assigned=True, + ) + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.pull.return_value = PullResponse( + job_states=[], + job_logs=[], + runner_logs=[], + last_updated=0, + no_connections_secs=no_connections_secs, + ) + await _process_job(session, worker, job) + ssh_tunnel_cls.assert_called_once() + runner_client_mock.pull.assert_called_once() + + await session.refresh(job) + assert job.status == expected_status + assert job.termination_reason == expected_termination_reason + assert job.inactivity_secs == expected_inactivity_secs + + @pytest.mark.parametrize( + ["samples", "expected_status"], + [ + pytest.param( + [ + (datetime(2023, 1, 1, 12, 25, 20, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 25, 30, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 29, 50, tzinfo=timezone.utc), 40), + ], + JobStatus.RUNNING, + id="not-enough-points", + ), + pytest.param( + [ + (datetime(2023, 1, 1, 12, 20, 10, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 20, 20, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 29, 50, tzinfo=timezone.utc), 80), + ], + JobStatus.RUNNING, + id="any-above-min", + ), + pytest.param( + [ + (datetime(2023, 1, 1, 12, 10, 10, tzinfo=timezone.utc), 80), + (datetime(2023, 1, 1, 12, 10, 20, tzinfo=timezone.utc), 80), + (datetime(2023, 1, 1, 12, 20, 10, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 20, 20, tzinfo=timezone.utc), 30), + (datetime(2023, 1, 1, 12, 29, 50, tzinfo=timezone.utc), 40), + ], + JobStatus.TERMINATING, + 
id="all-below-min", + ), + ], + ) + @freeze_time(datetime(2023, 1, 1, 12, 30, tzinfo=timezone.utc)) + async def test_gpu_utilization( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + samples: list[tuple[datetime, int]], + expected_status: JobStatus, + ) -> None: + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.RUNNING, + run_name="test-run", + run_spec=get_run_spec( + run_name="test-run", + repo_id=repo.name, + configuration=DevEnvironmentConfiguration( + name="test-run", + ide="vscode", + utilization_policy=UtilizationPolicy( + min_gpu_utilization=80, + time_window=600, + ), + ), + ), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(), + instance=instance, + instance_assigned=True, + last_processed_at=datetime(2023, 1, 1, 11, 30, tzinfo=timezone.utc), + ) + for timestamp, gpu_util in samples: + await create_job_metrics_point( + session=session, + job_model=job, + timestamp=timestamp, + gpus_memory_usage_bytes=[1024, 1024], + gpus_util_percent=[gpu_util, 100], + ) + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as ssh_tunnel_cls, + patch( + "dstack._internal.server.services.runner.client.RunnerClient" + ) as runner_client_cls, + ): + runner_client_mock = runner_client_cls.return_value + runner_client_mock.pull.return_value = PullResponse( + job_states=[], + job_logs=[], + runner_logs=[], + last_updated=0, + no_connections_secs=0, + ) + await _process_job(session, worker, job) + ssh_tunnel_cls.assert_called_once() + runner_client_mock.pull.assert_called_once() + + await session.refresh(job) + assert 
job.status == expected_status + if expected_status == JobStatus.TERMINATING: + assert job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER + assert job.termination_reason_message == ( + "The job GPU utilization below 80% for 600 seconds" + ) + else: + assert job.termination_reason is None + assert job.termination_reason_message is None + + @pytest.mark.parametrize("probe_count", [1, 2]) + async def test_creates_probe_models_and_not_registers_service_replica( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + probe_count: int, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, + image="ubuntu", + probes=[ProbeConfig(type="http", url=f"/{i}") for i in range(probe_count)], + ), + ), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + + assert len(job.probes) == 0 + await _process_job(session, worker, job) + + await session.refresh(job) + job = ( + await session.execute( + select(JobModel) + .where(JobModel.id == job.id) + .options(selectinload(JobModel.probes)) + ) + ).scalar_one() + assert job.status == JobStatus.RUNNING + assert [probe.probe_num for probe in job.probes] == list(range(probe_count)) + assert not job.registered + + async def test_registers_service_replica_immediately_if_no_probes( + self, + 
test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration(port=80, image="ubuntu"), + ), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + + await _process_job(session, worker, job) + + await session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.registered + events = await list_events(session) + assert {event.message for event in events} == { + "Job status changed PULLING -> RUNNING", + "Service replica registered to receive requests", + } + + @pytest.mark.parametrize( + ("probes", "expect_to_register"), + [ + ([_ProbeSetup(success_streak=0, ready_after=1)], False), + ([_ProbeSetup(success_streak=1, ready_after=1)], True), + ( + [ + _ProbeSetup(success_streak=1, ready_after=1), + _ProbeSetup(success_streak=1, ready_after=2), + ], + False, + ), + ( + [ + _ProbeSetup(success_streak=1, ready_after=1), + _ProbeSetup(success_streak=3, ready_after=2), + ], + True, + ), + ], + ) + async def test_registers_service_replica_only_after_probes_pass( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + runner_client_mock: Mock, + probes: list[_ProbeSetup], + expect_to_register: bool, + ): + project = await create_project(session=session) + user = 
await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, + image="ubuntu", + probes=[ + ProbeConfig(type="http", url=f"/{i}", ready_after=probe.ready_after) + for i, probe in enumerate(probes) + ], + ), + ), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + registered=False, + ) + for i, probe in enumerate(probes): + await create_probe( + session=session, + job=job, + probe_num=i, + success_streak=probe.success_streak, + ) + runner_client_mock.pull.return_value = PullResponse( + job_states=[], + job_logs=[], + runner_logs=[], + last_updated=0, + ) + + await _process_job(session, worker, job) + + await session.refresh(job) + events = await list_events(session) + if expect_to_register: + assert job.registered + assert len(events) == 1 + assert events[0].message == "Service replica registered to receive requests" + else: + assert not job.registered + assert not events + + async def test_apply_skips_probe_insert_when_lock_token_changes_after_processing( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, + 
image="ubuntu", + probes=[ProbeConfig(type="http", url="/health")], + ), + ), + ) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + _lock_job(job) + await session.commit() + replacement_lock_token = uuid.uuid4() + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + + async def invalidate_lock(*args, **kwargs): + job.lock_token = replacement_lock_token + await session.commit() + return b"" + + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_file_archives", + new_callable=AsyncMock, + return_value=[], + ), + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._get_job_code", + new_callable=AsyncMock, + side_effect=invalidate_lock, + ), + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_running._submit_job_to_runner", + return_value=_SubmitJobToRunnerResult( + success=True, + set_running_status=True, + ), + ), + ): + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.status == JobStatus.PULLING + assert job.lock_token == replacement_lock_token + probes = ( + (await session.execute(select(ProbeModel).where(ProbeModel.job_id == job.id))) + .scalars() + .all() + ) + assert probes == [] From d5c35517a96d0d85328723e8321dd7833bc88356 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 10:53:05 +0500 Subject: [PATCH 05/21] Wire pipeline --- .../server/background/pipeline_tasks/__init__.py | 2 ++ .../server/background/scheduled_tasks/__init__.py | 12 ++++++------ .../server/background/scheduled_tasks/runs.py | 4 ++++ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py 
b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index 7b2d79047..3d7efae47 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -5,6 +5,7 @@ from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline from dstack._internal.server.background.pipeline_tasks.instances import InstancePipeline +from dstack._internal.server.background.pipeline_tasks.jobs_running import JobRunningPipeline from dstack._internal.server.background.pipeline_tasks.jobs_terminating import ( JobTerminatingPipeline, ) @@ -23,6 +24,7 @@ def __init__(self) -> None: ComputeGroupPipeline(), FleetPipeline(), GatewayPipeline(), + JobRunningPipeline(), JobTerminatingPipeline(), InstancePipeline(), PlacementGroupPipeline(), diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index a0b97e7d5..c61d24953 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -128,12 +128,6 @@ def start_scheduled_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 5}, max_instances=4 if replica == 0 else 1, ) - _scheduler.add_job( - process_running_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) _scheduler.add_job( process_runs, IntervalTrigger(seconds=2, jitter=1), @@ -153,6 +147,12 @@ def start_scheduled_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 1}, max_instances=2 if replica == 0 else 1, ) + _scheduler.add_job( + process_running_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) _scheduler.add_job( process_terminating_jobs, IntervalTrigger(seconds=4, 
jitter=2), diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index 56d9fea77..2a2405c4e 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -144,6 +144,10 @@ async def _process_next_run(): .where( JobModel.run_id == run_model.id, JobModel.id.not_in(job_lockset), + or_( + JobModel.lock_expires_at.is_(None), + JobModel.lock_expires_at < now, + ), ) .options( load_only(JobModel.id), From f392d53b4c86da5dc2172284c8a08ba9744fec0f Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 11:22:13 +0500 Subject: [PATCH 06/21] Restore TODOs and simplifify code --- .../background/pipeline_tasks/jobs_running.py | 188 +++++++++--------- 1 file changed, 97 insertions(+), 91 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 922193bb6..3453707cc 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -337,9 +337,9 @@ async def _process_running_job(context: _ProcessContext) -> _ProcessResult: result = _ProcessResult() if context.job_provisioning_data is None: logger.error("%s: job_provisioning_data of an active job is None", fmt(context.job_model)) - _terminate_running_job( + _terminate_job( job_model=context.job_model, - result=result, + job_update_map=result.job_update_map, termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, termination_reason_message=( "Unexpected server error: job_provisioning_data of an active job is None" @@ -433,9 +433,9 @@ async def _prepare_startup_context( try: _interpolate_secrets(secrets, context.job.job_spec) except InterpolatorError as e: - _terminate_running_job( + _terminate_job( job_model=context.job_model, - 
result=result, + job_update_map=result.job_update_map, termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, termination_reason_message=f"Secrets interpolation error: {e.args[0]}", ) @@ -596,9 +596,9 @@ async def _process_provisioning_status( instance_type_name=job_provisioning_data.instance_type.name, ) if context.job_submission.age > provisioning_timeout: - _terminate_running_job( + _terminate_job( job_model=context.job_model, - result=result, + job_update_map=result.job_update_map, termination_reason=JobTerminationReason.WAITING_RUNNER_LIMIT_EXCEEDED, termination_reason_message=( f"Runner did not become available within {provisioning_timeout.total_seconds()}s." @@ -628,31 +628,30 @@ async def _process_pulling_status( job_model=context.job_model, jrd=_get_result_job_runtime_data(context.job_model, result), ) - if shim_state is False: - shim_state = None - elif shim_state.job_runtime_data is not None: - _set_job_runtime_data(result, shim_state.job_runtime_data) + if shim_state is not False: + if shim_state.job_runtime_data is not None: + _set_job_runtime_data(result, shim_state.job_runtime_data) - if shim_state is not None and shim_state.state == _ShimPullingState.WAITING: - _reset_disconnected_at(context.job_model, result) - return + if shim_state.state == _ShimPullingState.WAITING: + _reset_disconnected_at(context.job_model, result) + return - if shim_state is not None and shim_state.state == _ShimPullingState.FAILED: - logger.warning( - "%s: failed due to %s, age=%s", - fmt(context.job_model), - get_or_error(shim_state.termination_reason).value, - context.job_submission.age, - ) - _terminate_running_job( - job_model=context.job_model, - result=result, - termination_reason=get_or_error(shim_state.termination_reason), - termination_reason_message=get_or_error(shim_state.termination_reason_message), - ) - return + if shim_state.state == _ShimPullingState.FAILED: + logger.warning( + "%s: failed due to %s, age=%s", + fmt(context.job_model), + 
get_or_error(shim_state.termination_reason).value, + context.job_submission.age, + ) + _terminate_job( + job_model=context.job_model, + job_update_map=result.job_update_map, + termination_reason=get_or_error(shim_state.termination_reason), + termination_reason_message=get_or_error(shim_state.termination_reason_message), + ) + return - if shim_state is not None and shim_state.state == _ShimPullingState.READY: + # _ShimPullingState.READY job_runtime_data = _get_result_job_runtime_data(context.job_model, result) runner_availability = await run_async( _get_runner_availability, @@ -700,6 +699,7 @@ async def _process_pulling_status( _reset_disconnected_at(context.job_model, result) return + # SSH tunnel failed or READY but runner submit failed — treat as disconnect _set_disconnected_at_now(context.job_model, result) if not _should_terminate_job_due_to_disconnect( _get_result_disconnected_at(context.job_model, result) @@ -715,9 +715,9 @@ async def _process_pulling_status( termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY else: termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE - _terminate_running_job( + _terminate_job( job_model=context.job_model, - result=result, + job_update_map=result.job_update_map, termination_reason=termination_reason, termination_reason_message="Instance is unreachable", ) @@ -763,9 +763,9 @@ async def _process_running_status( termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY else: termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE - _terminate_running_job( + _terminate_job( job_model=context.job_model, - result=result, + job_update_map=result.job_update_map, termination_reason=termination_reason, termination_reason_message="Instance is unreachable", ) @@ -796,52 +796,47 @@ async def _apply_process_result( if result.new_probe_models: session.add_all(result.new_probe_models) - emit_job_status_change_event( - session=session, - job_model=job_model, - old_status=job_model.status, - 
new_status=result.job_update_map.get("status", job_model.status), - termination_reason=result.job_update_map.get( - "termination_reason", job_model.termination_reason - ), - termination_reason_message=result.job_update_map.get( - "termination_reason_message", - job_model.termination_reason_message, - ), - ) - _emit_reachability_change_event( - session=session, - job_model=job_model, - old_disconnected_at=job_model.disconnected_at, - new_disconnected_at=result.job_update_map.get( - "disconnected_at", - job_model.disconnected_at, - ), - ) - if result.emit_register_replica_event: - targets = [events.Target.from_model(job_model)] - if result.register_gateway_target is not None: - targets.append(result.register_gateway_target) - events.emit( - session, - "Service replica registered to receive requests", - actor=events.SystemActor(), - targets=targets, - ) + _emit_result_events(session=session, job_model=job_model, result=result) -def _terminate_running_job( +def _emit_result_events( + session: AsyncSession, job_model: JobModel, result: _ProcessResult, - termination_reason: JobTerminationReason, - termination_reason_message: str, ) -> None: - _terminate_job( + """Emit audit events for changes recorded in result..""" + emit_job_status_change_event( + session=session, job_model=job_model, - job_update_map=result.job_update_map, - termination_reason=termination_reason, - termination_reason_message=termination_reason_message, + old_status=job_model.status, + new_status=result.job_update_map.get("status", job_model.status), + termination_reason=result.job_update_map.get( + "termination_reason", job_model.termination_reason + ), + termination_reason_message=result.job_update_map.get( + "termination_reason_message", + job_model.termination_reason_message, + ), + ) + _emit_reachability_change_event( + session=session, + job_model=job_model, + old_disconnected_at=job_model.disconnected_at, + new_disconnected_at=result.job_update_map.get( + "disconnected_at", + 
job_model.disconnected_at, + ), ) + if result.emit_register_replica_event: + targets = [events.Target.from_model(job_model)] + if result.register_gateway_target is not None: + targets.append(result.register_gateway_target) + events.emit( + session, + "Service replica registered to receive requests", + actor=events.SystemActor(), + targets=targets, + ) def _wait_for_instance_provisioning_data( @@ -862,9 +857,9 @@ def _wait_for_instance_provisioning_data( return if job_model.instance.status == InstanceStatus.TERMINATED: - _terminate_running_job( + _terminate_job( job_model=job_model, - result=result, + job_update_map=result.job_update_map, termination_reason=JobTerminationReason.WAITING_INSTANCE_LIMIT_EXCEEDED, termination_reason_message="Instance is terminated", ) @@ -922,9 +917,9 @@ async def _maybe_register_replica( ) except GatewayError as e: logger.warning("%s: failed to register service replica: %s", fmt(context.job_model), e) - _terminate_running_job( + _terminate_job( job_model=context.job_model, - result=result, + job_update_map=result.job_update_map, termination_reason=JobTerminationReason.GATEWAY_ERROR, termination_reason_message="Failed to register service replica", ) @@ -1006,9 +1001,10 @@ async def _check_gpu_utilization( policy.min_gpu_utilization, [metric.values for metric in gpus_util_metrics] ): logger.debug("%s: GPU utilization check: terminating", fmt(context.job_model)) - _terminate_running_job( + # TODO(0.19 or earlier): set JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY + _terminate_job( job_model=context.job_model, - result=result, + job_update_map=result.job_update_map, termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, termination_reason_message=( f"The job GPU utilization below {policy.min_gpu_utilization}%" @@ -1223,7 +1219,7 @@ def _sync_shim_pulling_state( if jrd is not None: if task.ports is None: return _SyncShimPullingStateResult(state=_ShimPullingState.WAITING) - jrd.ports = {pm.container: pm.host for pm in 
task.ports} + jrd = jrd.copy(update={"ports": {pm.container: pm.host for pm in task.ports}}) else: shim_status = shim_client.pull() if ( @@ -1294,6 +1290,8 @@ def _submit_job_to_runner( run=run, job=job, cluster_info=cluster_info, + # Do not send all the secrets since interpolation is already done by the server. + # TODO: Passing secrets may be necessary for filtering out secret values from logs. secrets={}, repo_credentials=repo_credentials, instance_env=instance_env, @@ -1394,6 +1392,7 @@ def _terminate_if_inactivity_duration_exceeded( logger.debug("%s: no SSH connections for %s seconds", fmt(job_model), no_connections_secs) job_update_map["inactivity_secs"] = no_connections_secs if no_connections_secs is None: + # TODO(0.19 or earlier): make no_connections_secs required _terminate_job( job_model=job_model, job_update_map=job_update_map, @@ -1404,6 +1403,7 @@ def _terminate_if_inactivity_duration_exceeded( ), ) elif no_connections_secs >= conf.inactivity_duration: + # TODO(0.19 or earlier): set JobTerminationReason.INACTIVITY_DURATION_EXCEEDED _terminate_job( job_model=job_model, job_update_map=job_update_map, @@ -1454,6 +1454,7 @@ def _get_cluster_info( def _get_repo_code_hash(run: Run, job: Job) -> Optional[str]: + # TODO: drop this function when supporting jobs submitted before 0.19.17 is no longer relevant. 
if ( job.job_spec.repo_code_hash is None and run.run_spec.repo_code_hash is not None @@ -1553,15 +1554,6 @@ def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec) -> None: ) -def _set_job_update_status( - job_model: JobModel, - job_update_map: _JobUpdateMap, - new_status: JobStatus, -) -> None: - if job_update_map.get("status", job_model.status) != new_status: - job_update_map["status"] = new_status - - def _terminate_job( job_model: JobModel, job_update_map: _JobUpdateMap, @@ -1573,10 +1565,28 @@ def _terminate_job( _set_job_update_status(job_model, job_update_map, JobStatus.TERMINATING) +def _set_job_update_status( + job_model: JobModel, + job_update_map: _JobUpdateMap, + new_status: JobStatus, +) -> None: + if job_update_map.get("status", job_model.status) != new_status: + job_update_map["status"] = new_status + + def _set_job_status(job_model: JobModel, result: _ProcessResult, new_status: JobStatus) -> None: _set_job_update_status(job_model, result.job_update_map, new_status) +def _set_job_runtime_data(result: _ProcessResult, jrd: Optional[JobRuntimeData]) -> None: + result.job_update_map["job_runtime_data"] = None if jrd is None else jrd.json() + + +# Convention: _get_result_* helpers merge the loaded job_model state with any pending +# updates recorded in result.job_update_map. Always use these (not job_model.attr directly) +# when the field may have been updated earlier in the same processing cycle. 
+ + def _get_result_status(job_model: JobModel, result: _ProcessResult) -> JobStatus: return result.job_update_map.get("status", job_model.status) @@ -1594,10 +1604,6 @@ def _get_result_job_runtime_data( return JobRuntimeData.__response__.parse_raw(jrd) -def _set_job_runtime_data(result: _ProcessResult, jrd: Optional[JobRuntimeData]) -> None: - result.job_update_map["job_runtime_data"] = None if jrd is None else jrd.json() - - def _get_result_registered(job_model: JobModel, result: _ProcessResult) -> bool: return result.job_update_map.get("registered", job_model.registered) From 0d1e4d949016694af3a6091c520f9249bcc0e7fc Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 11:22:41 +0500 Subject: [PATCH 07/21] Set TERMINATED_DUE_TO_UTILIZATION_POLICY --- .../_internal/server/background/pipeline_tasks/jobs_running.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 3453707cc..ece2c79e5 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -1001,11 +1001,10 @@ async def _check_gpu_utilization( policy.min_gpu_utilization, [metric.values for metric in gpus_util_metrics] ): logger.debug("%s: GPU utilization check: terminating", fmt(context.job_model)) - # TODO(0.19 or earlier): set JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY _terminate_job( job_model=context.job_model, job_update_map=result.job_update_map, - termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason=JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY, termination_reason_message=( f"The job GPU utilization below {policy.min_gpu_utilization}%" f" for {policy.time_window} seconds" From 8053e592de23753883a9e1fabf35966de1e2433e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov 
Date: Thu, 12 Mar 2026 11:23:33 +0500 Subject: [PATCH 08/21] Set INACTIVITY_DURATION_EXCEEDED --- .../_internal/server/background/pipeline_tasks/jobs_running.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index ece2c79e5..37957b523 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -1402,11 +1402,10 @@ def _terminate_if_inactivity_duration_exceeded( ), ) elif no_connections_secs >= conf.inactivity_duration: - # TODO(0.19 or earlier): set JobTerminationReason.INACTIVITY_DURATION_EXCEEDED _terminate_job( job_model=job_model, job_update_map=job_update_map, - termination_reason=JobTerminationReason.TERMINATED_BY_SERVER, + termination_reason=JobTerminationReason.INACTIVITY_DURATION_EXCEEDED, termination_reason_message=( f"The job was inactive for {no_connections_secs} seconds," f" exceeding the inactivity_duration of {conf.inactivity_duration} seconds" From 0a0f09e5bec0aae740a725aa2aef1b7d428bbb4d Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 11:53:03 +0500 Subject: [PATCH 09/21] Extract _handle_instance_unreachable --- .../background/pipeline_tasks/jobs_running.py | 175 ++++++++---------- .../pipeline_tasks/test_running_jobs.py | 8 +- 2 files changed, 85 insertions(+), 98 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 37957b523..224df44f9 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -273,14 +273,20 @@ async def process(self, item: JobRunningPipelineItem): class _ProcessContext: job_model: JobModel run_model: RunModel - repo_model: 
RepoModel - project: ProjectModel run: Run job: Job job_submission: JobSubmission job_provisioning_data: Optional[JobProvisioningData] server_ssh_private_keys: Optional[tuple[str, Optional[str]]] = None + @property + def repo_model(self) -> RepoModel: + return self.run_model.repo + + @property + def project(self) -> ProjectModel: + return self.run_model.project + class _JobUpdateMap(ItemUpdateMap, total=False): status: JobStatus @@ -295,12 +301,16 @@ class _JobUpdateMap(ItemUpdateMap, total=False): registered: bool +@dataclass +class _RegisterReplicaResult: + gateway_target: Optional[events.Target] # None = no gateway + + @dataclass class _ProcessResult: job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) new_probe_models: list[ProbeModel] = field(default_factory=list) - emit_register_replica_event: bool = False - register_gateway_target: Optional[events.Target] = None + replica_registration: Optional[_RegisterReplicaResult] = None # None = not registered yet @dataclass @@ -323,8 +333,6 @@ async def _load_process_context(item: JobRunningPipelineItem) -> Optional[_Proce return _ProcessContext( job_model=job_model, run_model=run_model, - repo_model=run_model.repo, - project=run_model.project, run=run, job=find_job(run.jobs, job_model.replica_num, job_model.job_num), job_submission=job_submission, @@ -347,32 +355,22 @@ async def _process_running_job(context: _ProcessContext) -> _ProcessResult: ) return result - startup_context = None - if context.job_model.status in [JobStatus.PROVISIONING, JobStatus.PULLING]: - startup_context = await _prepare_startup_context( - context=context, - result=result, - ) + if context.job_model.status == JobStatus.PROVISIONING: + startup_context = await _prepare_startup_context(context=context, result=result) if startup_context is None: return result - - if context.job_model.status == JobStatus.PROVISIONING: await _process_provisioning_status( - context=context, - startup_context=get_or_error(startup_context), - 
result=result, + context=context, startup_context=startup_context, result=result ) elif context.job_model.status == JobStatus.PULLING: + startup_context = await _prepare_startup_context(context=context, result=result) + if startup_context is None: + return result await _process_pulling_status( - context=context, - startup_context=get_or_error(startup_context), - result=result, - ) - else: - await _process_running_status( - context=context, - result=result, + context=context, startup_context=startup_context, result=result ) + elif context.job_model.status == JobStatus.RUNNING: + await _process_running_status(context=context, result=result) if _get_result_status(context.job_model, result) == JobStatus.RUNNING: if context.job_model.status != JobStatus.RUNNING: @@ -700,27 +698,7 @@ async def _process_pulling_status( return # SSH tunnel failed or READY but runner submit failed — treat as disconnect - _set_disconnected_at_now(context.job_model, result) - if not _should_terminate_job_due_to_disconnect( - _get_result_disconnected_at(context.job_model, result) - ): - logger.warning( - "%s: is unreachable, waiting for the instance to become reachable again, age=%s", - fmt(context.job_model), - context.job_submission.age, - ) - return - - if job_provisioning_data.instance_type.resources.spot: - termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY - else: - termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE - _terminate_job( - job_model=context.job_model, - job_update_map=result.job_update_map, - termination_reason=termination_reason, - termination_reason_message="Instance is unreachable", - ) + _handle_instance_unreachable(context, result, job_provisioning_data) async def _process_running_status( @@ -748,27 +726,7 @@ async def _process_running_status( _reset_disconnected_at(context.job_model, result) return - _set_disconnected_at_now(context.job_model, result) - if not _should_terminate_job_due_to_disconnect( - 
_get_result_disconnected_at(context.job_model, result) - ): - logger.warning( - "%s: is unreachable, waiting for the instance to become reachable again, age=%s", - fmt(context.job_model), - context.job_submission.age, - ) - return - - if job_provisioning_data.instance_type.resources.spot: - termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY - else: - termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE - _terminate_job( - job_model=context.job_model, - job_update_map=result.job_update_map, - termination_reason=termination_reason, - termination_reason_message="Instance is unreachable", - ) + _handle_instance_unreachable(context, result, job_provisioning_data) async def _apply_process_result( @@ -827,10 +785,10 @@ def _emit_result_events( job_model.disconnected_at, ), ) - if result.emit_register_replica_event: + if result.replica_registration is not None: targets = [events.Target.from_model(job_model)] - if result.register_gateway_target is not None: - targets.append(result.register_gateway_target) + if result.replica_registration.gateway_target is not None: + targets.append(result.replica_registration.gateway_target) events.emit( session, "Service replica registered to receive requests", @@ -868,6 +826,33 @@ def _wait_for_instance_provisioning_data( result.job_update_map["job_provisioning_data"] = job_model.instance.job_provisioning_data +def _handle_instance_unreachable( + context: _ProcessContext, + result: _ProcessResult, + job_provisioning_data: JobProvisioningData, +) -> None: + _set_disconnected_at_now(context.job_model, result) + if not _should_terminate_job_due_to_disconnect( + _get_result_disconnected_at(context.job_model, result) + ): + logger.warning( + "%s: is unreachable, waiting for the instance to become reachable again, age=%s", + fmt(context.job_model), + context.job_submission.age, + ) + return + if job_provisioning_data.instance_type.resources.spot: + termination_reason = JobTerminationReason.INTERRUPTED_BY_NO_CAPACITY 
+ else: + termination_reason = JobTerminationReason.INSTANCE_UNREACHABLE + _terminate_job( + job_model=context.job_model, + job_update_map=result.job_update_map, + termination_reason=termination_reason, + termination_reason_message="Instance is unreachable", + ) + + def _initialize_running_job_probes( job_model: JobModel, job: Job, @@ -926,8 +911,7 @@ async def _maybe_register_replica( return result.job_update_map["registered"] = True - result.emit_register_replica_event = True - result.register_gateway_target = gateway_target + result.replica_registration = _RegisterReplicaResult(gateway_target=gateway_target) async def _register_service_replica( @@ -1304,8 +1288,9 @@ def _submit_job_to_runner( job_info = runner_client.run_job() if job_info is not None: if jrd is not None: - jrd.working_dir = job_info.working_dir - jrd.username = job_info.username + jrd = jrd.copy( + update={"working_dir": job_info.working_dir, "username": job_info.username} + ) return _SubmitJobToRunnerResult( success=True, set_running_status=True, @@ -1520,6 +1505,16 @@ async def _get_job_file_archive(archive_id: uuid.UUID, user: UserModel) -> bytes return blob +def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec) -> None: + interpolate = VariablesInterpolator({"secrets": secrets}).interpolate_or_error + job_spec.env = {k: interpolate(v) for k, v in job_spec.env.items()} + if job_spec.registry_auth is not None: + job_spec.registry_auth = RegistryAuth( + username=interpolate(job_spec.registry_auth.username), + password=interpolate(job_spec.registry_auth.password), + ) + + def _emit_reachability_change_event( session: AsyncSession, job_model: JobModel, @@ -1542,16 +1537,6 @@ def _emit_reachability_change_event( ) -def _interpolate_secrets(secrets: Dict[str, str], job_spec: JobSpec) -> None: - interpolate = VariablesInterpolator({"secrets": secrets}).interpolate_or_error - job_spec.env = {k: interpolate(v) for k, v in job_spec.env.items()} - if job_spec.registry_auth is not None: - 
job_spec.registry_auth = RegistryAuth( - username=interpolate(job_spec.registry_auth.username), - password=interpolate(job_spec.registry_auth.password), - ) - - def _terminate_job( job_model: JobModel, job_update_map: _JobUpdateMap, @@ -1580,6 +1565,17 @@ def _set_job_runtime_data(result: _ProcessResult, jrd: Optional[JobRuntimeData]) result.job_update_map["job_runtime_data"] = None if jrd is None else jrd.json() +def _apply_submit_job_to_runner_result( + job_model: JobModel, + result: _ProcessResult, + submit_result: _SubmitJobToRunnerResult, +) -> None: + if submit_result.job_runtime_data is not None: + _set_job_runtime_data(result, submit_result.job_runtime_data) + if submit_result.set_running_status: + _set_job_status(job_model, result, JobStatus.RUNNING) + + # Convention: _get_result_* helpers merge the loaded job_model state with any pending # updates recorded in result.job_update_map. Always use these (not job_model.attr directly) # when the field may have been updated earlier in the same processing cycle. 
@@ -1604,14 +1600,3 @@ def _get_result_job_runtime_data( def _get_result_registered(job_model: JobModel, result: _ProcessResult) -> bool: return result.job_update_map.get("registered", job_model.registered) - - -def _apply_submit_job_to_runner_result( - job_model: JobModel, - result: _ProcessResult, - submit_result: _SubmitJobToRunnerResult, -) -> None: - if submit_result.job_runtime_data is not None: - _set_job_runtime_data(result, submit_result.job_runtime_data) - if submit_result.set_running_status: - _set_job_status(job_model, result, JobStatus.RUNNING) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index 750a5522b..8cee7f90b 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -1215,7 +1215,7 @@ async def test_running_job_disconnect_retries_then_terminates( "1h", 60 * 60, JobStatus.TERMINATING, - JobTerminationReason.TERMINATED_BY_SERVER, + JobTerminationReason.INACTIVITY_DURATION_EXCEEDED, 60 * 60, id="duration-exceeded-exactly", ), @@ -1223,7 +1223,7 @@ async def test_running_job_disconnect_retries_then_terminates( "1h", 60 * 60 + 1, JobStatus.TERMINATING, - JobTerminationReason.TERMINATED_BY_SERVER, + JobTerminationReason.INACTIVITY_DURATION_EXCEEDED, 60 * 60 + 1, id="duration-exceeded", ), @@ -1425,7 +1425,9 @@ async def test_gpu_utilization( await session.refresh(job) assert job.status == expected_status if expected_status == JobStatus.TERMINATING: - assert job.termination_reason == JobTerminationReason.TERMINATED_BY_SERVER + assert ( + job.termination_reason == JobTerminationReason.TERMINATED_DUE_TO_UTILIZATION_POLICY + ) assert job.termination_reason_message == ( "The job GPU utilization below 80% for 600 seconds" ) From 0e6c12385df09523235813f3164d049ce8b119c0 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 
Mar 2026 12:19:40 +0500 Subject: [PATCH 10/21] Unify jobs pipelines patterns --- .../background/pipeline_tasks/jobs_running.py | 2 +- .../pipeline_tasks/jobs_terminating.py | 20 ++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 224df44f9..f5819691a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -196,7 +196,7 @@ async def fetch(self, limit: int) -> list[JobRunningPipelineItem]: [JobStatus.PROVISIONING, JobStatus.PULLING, JobStatus.RUNNING] ), RunModel.status.not_in([RunStatus.TERMINATING]), - JobModel.last_processed_at < now - self._min_processing_interval, + JobModel.last_processed_at <= now - self._min_processing_interval, or_( JobModel.lock_expires_at.is_(None), JobModel.lock_expires_at < now, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index 61c8adaee..5360a9d0a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -289,6 +289,11 @@ class _VolumeUpdateRow(TypedDict): last_job_processed_at: UpdateMapDateTime +@dataclass +class _UnregisterReplicaResult: + gateway_target: Optional[events.Target] # None = no gateway + + @dataclass class _ProcessResult: job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) @@ -296,8 +301,9 @@ class _ProcessResult: volume_update_rows: list[_VolumeUpdateRow] = field(default_factory=list) detached_volume_ids: set[uuid.UUID] = field(default_factory=set) unassign_event_message: Optional[str] = None - emit_unregister_replica_event: bool = False - unregister_gateway_target: Optional[events.Target] = None 
+ replica_unregistration: Optional[_UnregisterReplicaResult] = ( + None # None = not unregistered yet + ) @dataclass @@ -536,10 +542,10 @@ async def _apply_process_result( ], ) - if result.emit_unregister_replica_event: + if result.replica_unregistration is not None: targets = [events.Target.from_model(job_model)] - if result.unregister_gateway_target is not None: - targets.append(result.unregister_gateway_target) + if result.replica_unregistration.gateway_target is not None: + targets.append(result.replica_unregistration.gateway_target) events.emit( session, "Service replica unregistered from receiving requests", @@ -689,10 +695,10 @@ async def _detach_job_volumes( async def _unregister_replica_and_update_result( result: _ProcessResult, job_model: JobModel ) -> None: - result.unregister_gateway_target = await _unregister_replica(job_model=job_model) + gateway_target = await _unregister_replica(job_model=job_model) if job_model.registered: result.job_update_map["registered"] = False - result.emit_unregister_replica_event = True + result.replica_unregistration = _UnregisterReplicaResult(gateway_target=gateway_target) async def _unregister_replica( From 748077c1e3f51e34a767ebd096475d12587ee622 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 12:42:54 +0500 Subject: [PATCH 11/21] Add context and apply to fleet pipeline --- .../background/pipeline_tasks/fleets.py | 272 ++++++++++-------- 1 file changed, 155 insertions(+), 117 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 3065c1e09..51f4230a8 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -196,106 +196,23 @@ def __init__( @sentry_utils.instrument_named_task("pipeline_tasks.FleetWorker.process") async def process(self, item: PipelineItem): - async with get_session_ctx() as session: - 
res = await session.execute( - select(FleetModel) - .where( - FleetModel.id == item.id, - FleetModel.lock_token == item.lock_token, - ) - .options(joinedload(FleetModel.project)) - .options( - selectinload(FleetModel.instances.and_(InstanceModel.deleted == False)) - .joinedload(InstanceModel.jobs) - .load_only(JobModel.id), - ) - .options( - selectinload( - FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) - ).load_only(RunModel.status) - ) - ) - fleet_model = res.unique().scalar_one_or_none() - if fleet_model is None: - log_lock_token_mismatch(logger, item) - return - - # Lock instance only if consolidation is needed. - locked_instance_ids: set[uuid.UUID] = set() - consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet_model) - consolidation_instances = None - if consolidation_fleet_spec is not None: - consolidation_instances = await _lock_fleet_instances_for_consolidation( - session=session, - item=item, - ) - if consolidation_instances is None: - return - locked_instance_ids = {instance.id for instance in consolidation_instances} - + process_context = await _load_process_context(item) + if process_context is None: + return result = await _process_fleet( - fleet_model, - consolidation_fleet_spec=consolidation_fleet_spec, - consolidation_instances=consolidation_instances, - ) - fleet_update_map = _FleetUpdateMap() - fleet_update_map.update(result.fleet_update_map) - set_processed_update_map_fields(fleet_update_map) - set_unlock_update_map_fields(fleet_update_map) - instance_update_rows = _build_instance_update_rows( - result.instance_id_to_update_map, - unlock_instance_ids=locked_instance_ids, + process_context.fleet_model, + consolidation_fleet_spec=process_context.consolidation_fleet_spec, + consolidation_instances=process_context.consolidation_instances, ) + await _apply_process_result(item, process_context, result) - async with get_session_ctx() as session: - now = get_current_datetime() - 
resolve_now_placeholders(fleet_update_map, now=now) - resolve_now_placeholders(instance_update_rows, now=now) - res = await session.execute( - update(FleetModel) - .where( - FleetModel.id == fleet_model.id, - FleetModel.lock_token == fleet_model.lock_token, - ) - .values(**fleet_update_map) - .returning(FleetModel.id) - ) - updated_ids = list(res.scalars().all()) - if len(updated_ids) == 0: - log_lock_token_changed_after_processing(logger, item) - if locked_instance_ids: - await _unlock_fleet_locked_instances( - session=session, - item=item, - locked_instance_ids=locked_instance_ids, - ) - # TODO: Clean up fleet. - return - - if fleet_update_map.get("deleted"): - await session.execute( - update(PlacementGroupModel) - .where(PlacementGroupModel.fleet_id == item.id) - .values(fleet_deleted=True) - ) - if instance_update_rows: - await session.execute( - update(InstanceModel), - instance_update_rows, - ) - if len(result.new_instance_creates) > 0: - await _create_missing_fleet_instances( - session=session, - fleet_model=fleet_model, - new_instance_creates=result.new_instance_creates, - ) - emit_fleet_status_change_event( - session=session, - fleet_model=fleet_model, - old_status=fleet_model.status, - new_status=fleet_update_map.get("status", fleet_model.status), - status_message=fleet_update_map.get("status_message", fleet_model.status_message), - ) + +@dataclass +class _ProcessContext: + fleet_model: FleetModel + consolidation_fleet_spec: Optional[FleetSpec] + consolidation_instances: Optional[list[InstanceModel]] + locked_instance_ids: set[uuid.UUID] = field(default_factory=set) class _FleetUpdateMap(ItemUpdateMap, total=False): @@ -318,6 +235,83 @@ class _InstanceUpdateMap(ItemUpdateMap, total=False): id: uuid.UUID +@dataclass +class _ProcessResult: + fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap) + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) + new_instance_creates: 
list["_NewInstanceCreate"] = field(default_factory=list) + + +class _NewInstanceCreate(TypedDict): + id: uuid.UUID + instance_num: int + + +@dataclass +class _MaintainNodesResult: + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) + new_instance_creates: list[_NewInstanceCreate] = field(default_factory=list) + changes_required: bool = False + + @property + def has_changes(self) -> bool: + return len(self.instance_id_to_update_map) > 0 or len(self.new_instance_creates) > 0 + + +async def _load_process_context(item: PipelineItem) -> Optional[_ProcessContext]: + async with get_session_ctx() as session: + fleet_model = await _refetch_locked_fleet(session=session, item=item) + if fleet_model is None: + log_lock_token_mismatch(logger, item) + return None + + consolidation_fleet_spec = _get_fleet_spec_if_ready_for_consolidation(fleet_model) + consolidation_instances = None + if consolidation_fleet_spec is not None: + consolidation_instances = await _lock_fleet_instances_for_consolidation( + session=session, + item=item, + ) + if consolidation_instances is None: + return None + + return _ProcessContext( + fleet_model=fleet_model, + consolidation_fleet_spec=consolidation_fleet_spec, + consolidation_instances=consolidation_instances, + locked_instance_ids=( + set() + if consolidation_instances is None + else {i.id for i in consolidation_instances} + ), + ) + + +async def _refetch_locked_fleet( + session: AsyncSession, + item: PipelineItem, +) -> Optional[FleetModel]: + res = await session.execute( + select(FleetModel) + .where( + FleetModel.id == item.id, + FleetModel.lock_token == item.lock_token, + ) + .options(joinedload(FleetModel.project)) + .options( + selectinload(FleetModel.instances.and_(InstanceModel.deleted == False)) + .joinedload(InstanceModel.jobs) + .load_only(JobModel.id), + ) + .options( + selectinload( + FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) + ).load_only(RunModel.status) + ) 
+ ) + return res.unique().scalar_one_or_none() + + def _get_fleet_spec_if_ready_for_consolidation(fleet_model: FleetModel) -> Optional[FleetSpec]: if fleet_model.status == FleetStatus.TERMINATING: return None @@ -398,27 +392,71 @@ async def _lock_fleet_instances_for_consolidation( return locked_instance_models -@dataclass -class _ProcessResult: - fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap) - instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) - new_instance_creates: list["_NewInstanceCreate"] = field(default_factory=list) - - -class _NewInstanceCreate(TypedDict): - id: uuid.UUID - instance_num: int - - -@dataclass -class _MaintainNodesResult: - instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) - new_instance_creates: list[_NewInstanceCreate] = field(default_factory=list) - changes_required: bool = False +async def _apply_process_result( + item: PipelineItem, + context: _ProcessContext, + result: "_ProcessResult", +) -> None: + fleet_update_map = _FleetUpdateMap() + fleet_update_map.update(result.fleet_update_map) + set_processed_update_map_fields(fleet_update_map) + set_unlock_update_map_fields(fleet_update_map) + instance_update_rows = _build_instance_update_rows( + result.instance_id_to_update_map, + unlock_instance_ids=context.locked_instance_ids, + ) - @property - def has_changes(self) -> bool: - return len(self.instance_id_to_update_map) > 0 or len(self.new_instance_creates) > 0 + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(fleet_update_map, now=now) + resolve_now_placeholders(instance_update_rows, now=now) + res = await session.execute( + update(FleetModel) + .where( + FleetModel.id == context.fleet_model.id, + FleetModel.lock_token == context.fleet_model.lock_token, + ) + .values(**fleet_update_map) + .returning(FleetModel.id) + ) + updated_ids = list(res.scalars().all()) + if 
len(updated_ids) == 0: + log_lock_token_changed_after_processing(logger, item) + if context.locked_instance_ids: + await _unlock_fleet_locked_instances( + session=session, + item=item, + locked_instance_ids=context.locked_instance_ids, + ) + # TODO: Clean up fleet. + return + + if fleet_update_map.get("deleted"): + await session.execute( + update(PlacementGroupModel) + .where(PlacementGroupModel.fleet_id == context.fleet_model.id) + .values(fleet_deleted=True) + ) + if instance_update_rows: + await session.execute( + update(InstanceModel), + instance_update_rows, + ) + if len(result.new_instance_creates) > 0: + await _create_missing_fleet_instances( + session=session, + fleet_model=context.fleet_model, + new_instance_creates=result.new_instance_creates, + ) + emit_fleet_status_change_event( + session=session, + fleet_model=context.fleet_model, + old_status=context.fleet_model.status, + new_status=fleet_update_map.get("status", context.fleet_model.status), + status_message=fleet_update_map.get( + "status_message", context.fleet_model.status_message + ), + ) async def _process_fleet( From 0fdafde8588b235cca7c2c5b4db854ee19e615b5 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 12:53:11 +0500 Subject: [PATCH 12/21] Describe Typical worker structure --- contributing/PIPELINES.md | 49 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/contributing/PIPELINES.md b/contributing/PIPELINES.md index 89037bb6b..e41949b4f 100644 --- a/contributing/PIPELINES.md +++ b/contributing/PIPELINES.md @@ -37,6 +37,55 @@ Brief checklist for implementing a new pipeline: 8. Register the pipeline in `PipelineManager` and hint fetch from services after commit via `pipeline_hinter.hint_fetch(Model.__name__)`. 9. Add minimum tests: fetch eligibility/order, successful unlock path, stale lock token path, and related lock contention retry path. 
+## Typical worker structure + +Most workers are easiest to reason about when `process()` is split into three phases: + +1. Load/refetch: open a short DB session, refetch the locked main row by `id + lock_token`, lock any required related rows, and gather any extra data needed for processing. +2. Process: do the heavy work outside DB sessions and build result objects or update maps instead of mutating detached ORM models. +3. Apply: open a short DB session, guard the main update by `id + lock_token`, resolve time placeholders, apply related updates, emit events, and unlock rows. + +A dedicated context object is often useful for the load step when the worker needs multiple loaded models, related lock metadata, or derived values that should be passed cleanly into processing and apply. For very small pipelines, a direct load -> process -> apply flow may still be clearer. + +Workers can share one context type and one apply function across all states even if the processing logic differs by state: + +```python +async def process(item): + context = await _load_process_context(item) + if context is None: + return + result = await _process_item(context) + await _apply_process_result(item, context, result) +``` + +Sometimes state-specific helpers are still the cleanest option, but they can still share a common apply phase if all states write results in the same general shape: + +```python +async def process(item): + if item.status == Status.PENDING: + context = await _load_pending_context(item) + elif item.status == Status.RUNNING: + context = await _load_running_context(item) + else: + return + if context is None: + return + result = await _process_item(context) + await _apply_process_result(item, context, result) +``` + +If different states have materially different write-side behavior, different apply paths are fine as well. 
This commonly happens when one state does a normal guarded update while another does delete-or-cleanup work with different related updates: + +```python +async def process(item): + if item.to_be_deleted: + await _process_to_be_deleted_item(item) + elif item.status == Status.SUBMITTED: + await _process_submitted_item(item) +``` + +It's ok not to force all pipelines into one exact shape. + ## Implementation patterns **Guarded apply by lock token** From c54c002a0ecf0d9fbb549744706c6dba5a945221 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 14:11:54 +0500 Subject: [PATCH 13/21] Move unlock/processed updates inside _apply_process_result --- .../pipeline_tasks/instances/__init__.py | 6 +- .../background/pipeline_tasks/jobs_running.py | 5 +- .../pipeline_tasks/jobs_terminating.py | 16 +-- .../background/pipeline_tasks/volumes.py | 134 +++++++----------- .../background/pipeline_tasks/test_volumes.py | 67 +++++++++ 5 files changed, 132 insertions(+), 96 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py index b5289e05e..67f017619 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/instances/__init__.py @@ -238,8 +238,7 @@ async def process(self, item: InstancePipelineItem): process_context = await _process_terminating_item(item) if process_context is None: return - set_processed_update_map_fields(process_context.result.instance_update_map) - set_unlock_update_map_fields(process_context.result.instance_update_map) + await _apply_process_result( item=item, instance_model=process_context.instance_model, @@ -376,6 +375,9 @@ async def _apply_process_result( instance_model: InstanceModel, result: ProcessResult, ) -> None: + set_processed_update_map_fields(result.instance_update_map) + 
set_unlock_update_map_fields(result.instance_update_map) + async with get_session_ctx() as session: if result.health_check_create is not None: session.add(InstanceHealthCheckModel(**result.health_check_create)) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index f5819691a..1c61458dd 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -260,8 +260,6 @@ async def process(self, item: JobRunningPipelineItem): return result = await _process_running_job(context=context) - set_processed_update_map_fields(result.job_update_map) - set_unlock_update_map_fields(result.job_update_map) await _apply_process_result( item=item, job_model=context.job_model, @@ -734,6 +732,9 @@ async def _apply_process_result( job_model: JobModel, result: _ProcessResult, ) -> None: + set_processed_update_map_fields(result.job_update_map) + set_unlock_update_map_fields(result.job_update_map) + async with get_session_ctx() as session: now = get_current_datetime() resolve_now_placeholders(result.job_update_map, now=now) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index 5360a9d0a..727120841 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -251,14 +251,6 @@ async def process(self, item: JobTerminatingPipelineItem): instance_model=get_or_error(instance_model), ) - set_processed_update_map_fields(result.job_update_map) - set_unlock_update_map_fields(result.job_update_map) - if instance_model is not None: - if result.instance_update_map is None: - result.instance_update_map = _InstanceUpdateMap() - instance_update_map = result.instance_update_map - 
set_processed_update_map_fields(instance_update_map) - set_unlock_update_map_fields(instance_update_map) await _apply_process_result( item=item, job_model=job_model, @@ -439,6 +431,14 @@ async def _apply_process_result( instance_model: Optional[InstanceModel], result: _ProcessResult, ) -> None: + set_processed_update_map_fields(result.job_update_map) + set_unlock_update_map_fields(result.job_update_map) + if instance_model is not None and result.instance_update_map is None: + result.instance_update_map = _InstanceUpdateMap() + if result.instance_update_map is not None: + set_processed_update_map_fields(result.instance_update_map) + set_unlock_update_map_fields(result.instance_update_map) + async with get_session_ctx() as session: now = get_current_datetime() related_instance_lock_owner = _get_related_instance_lock_owner(item.id) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index 89eea7d92..406c216c7 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass, field from datetime import timedelta -from typing import Sequence +from typing import Optional, Sequence from sqlalchemy import or_, select, update from sqlalchemy.orm import joinedload, load_only @@ -202,13 +202,22 @@ def __init__( @sentry_utils.instrument_named_task("pipeline_tasks.VolumeWorker.process") async def process(self, item: VolumePipelineItem): + volume_model = await _refetch_locked_volume(item) + if volume_model is None: + log_lock_token_mismatch(logger, item) + return + if item.to_be_deleted: - await _process_to_be_deleted_item(item) + result = await _process_to_be_deleted_volume(volume_model) elif item.status == VolumeStatus.SUBMITTED: - await _process_submitted_item(item) + result = await _process_submitted_volume(volume_model) + else: + return + 
await _apply_process_result(item=item, volume_model=volume_model, result=result) -async def _process_submitted_item(item: VolumePipelineItem): + +async def _refetch_locked_volume(item: VolumePipelineItem) -> Optional[VolumeModel]: async with get_session_ctx() as session: res = await session.execute( select(VolumeModel) @@ -217,7 +226,7 @@ async def _process_submitted_item(item: VolumePipelineItem): VolumeModel.lock_token == item.lock_token, ) .options(joinedload(VolumeModel.project).joinedload(ProjectModel.backends)) - .options(joinedload(VolumeModel.user)) + .options(joinedload(VolumeModel.user).load_only(UserModel.name)) .options( joinedload(VolumeModel.attachments) .joinedload(VolumeAttachmentModel.instance) @@ -225,13 +234,16 @@ async def _process_submitted_item(item: VolumePipelineItem): .load_only(FleetModel.name) ) ) - volume_model = res.unique().scalar_one_or_none() - if volume_model is None: - log_lock_token_mismatch(logger, item) - return + return res.unique().scalar_one_or_none() + - result = await _process_submitted_volume(volume_model) - update_map = result.update_map +async def _apply_process_result( + item: VolumePipelineItem, + volume_model: VolumeModel, + result: "_ProcessResult", +): + update_map = _VolumeUpdateMap() + update_map.update(result.update_map) set_processed_update_map_fields(update_map) set_unlock_update_map_fields(update_map) @@ -249,15 +261,25 @@ async def _process_submitted_item(item: VolumePipelineItem): updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: log_lock_token_changed_after_processing(logger, item) - # TODO: Clean up volume. + if item.status == VolumeStatus.SUBMITTED: + # TODO: Clean up volume. 
+ pass return - emit_volume_status_change_event( - session=session, - volume_model=volume_model, - old_status=volume_model.status, - new_status=update_map.get("status", volume_model.status), - status_message=update_map.get("status_message", volume_model.status_message), - ) + if item.to_be_deleted: + events.emit( + session, + "Volume deleted", + actor=events.SystemActor(), + targets=[events.Target.from_model(volume_model)], + ) + else: + emit_volume_status_change_event( + session=session, + volume_model=volume_model, + old_status=volume_model.status, + new_status=update_map.get("status", volume_model.status), + status_message=update_map.get("status_message", volume_model.status_message), + ) class _VolumeUpdateMap(ItemUpdateMap, total=False): @@ -269,11 +291,11 @@ class _VolumeUpdateMap(ItemUpdateMap, total=False): @dataclass -class _SubmittedResult: +class _ProcessResult: update_map: _VolumeUpdateMap = field(default_factory=_VolumeUpdateMap) -async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResult: +async def _process_submitted_volume(volume_model: VolumeModel) -> _ProcessResult: volume = volume_model_to_volume(volume_model) try: backend = await backends_services.get_project_backend_by_type_or_error( @@ -287,7 +309,7 @@ async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResu volume.name, volume.configuration.backend.value, ) - return _SubmittedResult( + return _ProcessResult( update_map={ "status": VolumeStatus.FAILED, "status_message": "Backend not available", @@ -314,7 +336,7 @@ async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResu status_message = f"Backend error: {repr(e)}" if len(e.args) > 0: status_message = str(e.args[0]) - return _SubmittedResult( + return _ProcessResult( update_map={ "status": VolumeStatus.FAILED, "status_message": status_message, @@ -322,7 +344,7 @@ async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResu ) except Exception as e: 
logger.exception("Got exception when creating volume %s", volume_model.name) - return _SubmittedResult( + return _ProcessResult( update_map={ "status": VolumeStatus.FAILED, "status_message": f"Unexpected error: {repr(e)}", @@ -332,7 +354,7 @@ async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResu logger.info("Added new volume %s", volume_model.name) # Provisioned volumes marked as active since they become available almost immediately in AWS # TODO: Consider checking volume state - return _SubmittedResult( + return _ProcessResult( update_map={ "status": VolumeStatus.ACTIVE, "volume_provisioning_data": vpd.json(), @@ -340,63 +362,7 @@ async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResu ) -async def _process_to_be_deleted_item(item: VolumePipelineItem): - async with get_session_ctx() as session: - res = await session.execute( - select(VolumeModel) - .where( - VolumeModel.id == item.id, - VolumeModel.lock_token == item.lock_token, - ) - .options(joinedload(VolumeModel.project).joinedload(ProjectModel.backends)) - .options(joinedload(VolumeModel.user).load_only(UserModel.name)) - .options( - joinedload(VolumeModel.attachments) - .joinedload(VolumeAttachmentModel.instance) - .joinedload(InstanceModel.fleet) - .load_only(FleetModel.name) - ) - ) - volume_model = res.unique().scalar_one_or_none() - if volume_model is None: - log_lock_token_mismatch(logger, item) - return - - result = await _process_to_be_deleted_volume(volume_model) - update_map = _VolumeUpdateMap() - update_map.update(result.update_map) - set_processed_update_map_fields(update_map) - set_unlock_update_map_fields(update_map) - async with get_session_ctx() as session: - now = get_current_datetime() - resolve_now_placeholders(update_map, now=now) - res = await session.execute( - update(VolumeModel) - .where( - VolumeModel.id == volume_model.id, - VolumeModel.lock_token == volume_model.lock_token, - ) - .values(**update_map) - .returning(VolumeModel.id) - 
) - updated_ids = list(res.scalars().all()) - if len(updated_ids) == 0: - log_lock_token_changed_after_processing(logger, item) - return - events.emit( - session, - "Volume deleted", - actor=events.SystemActor(), - targets=[events.Target.from_model(volume_model)], - ) - - -@dataclass -class _ProcessToBeDeletedResult: - update_map: _VolumeUpdateMap = field(default_factory=_VolumeUpdateMap) - - -async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _ProcessToBeDeletedResult: +async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _ProcessResult: volume = volume_model_to_volume(volume_model) if volume.external: return _get_deleted_result() @@ -437,8 +403,8 @@ async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _ProcessTo return _get_deleted_result() -def _get_deleted_result() -> _ProcessToBeDeletedResult: - return _ProcessToBeDeletedResult( +def _get_deleted_result() -> _ProcessResult: + return _ProcessResult( update_map={ "deleted": True, "deleted_at": NOW_PLACEHOLDER, diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py b/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py index 63dfaaa45..ad1cbb7cf 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_volumes.py @@ -9,6 +9,7 @@ from dstack._internal.core.errors import BackendError, BackendNotAvailable from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.volumes import VolumeProvisioningData, VolumeStatus +from dstack._internal.server.background.pipeline_tasks import volumes as volumes_pipeline from dstack._internal.server.background.pipeline_tasks.volumes import ( VolumeFetcher, VolumePipeline, @@ -349,6 +350,72 @@ async def test_marks_volume_failed_if_backend_returns_error( assert len(events) == 1 assert events[0].message == "Volume status changed SUBMITTED -> FAILED (Some error)" + 
async def test_skips_processing_if_lock_token_changed_before_refetch( + self, test_db, session: AsyncSession, worker: VolumeWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + ) + volume.lock_token = uuid.uuid4() + volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + item = _volume_to_pipeline_item(volume) + + volume.lock_token = uuid.uuid4() + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.volumes._process_submitted_volume" + ) as process_volume_mock: + await worker.process(item) + process_volume_mock.assert_not_awaited() + + await session.refresh(volume) + assert volume.status == VolumeStatus.SUBMITTED + events = await list_events(session) + assert len(events) == 0 + + async def test_skips_apply_if_lock_token_changed_after_processing( + self, test_db, session: AsyncSession, worker: VolumeWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.SUBMITTED, + ) + volume.lock_token = uuid.uuid4() + volume.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + async def _change_lock_token_and_return_result(_volume_model: VolumeModel): + volume.lock_token = uuid.uuid4() + await session.commit() + return volumes_pipeline._ProcessResult( + update_map={ + "status": VolumeStatus.ACTIVE, + } + ) + + with patch( + "dstack._internal.server.background.pipeline_tasks.volumes._process_submitted_volume", + side_effect=_change_lock_token_and_return_result, + ) as process_volume_mock: + await worker.process(_volume_to_pipeline_item(volume)) + process_volume_mock.assert_awaited_once() + + await session.refresh(volume) + assert 
volume.status == VolumeStatus.SUBMITTED + events = await list_events(session) + assert len(events) == 0 + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) From f6987ece3905246b6433281f886815cbbb711985 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 15:12:15 +0500 Subject: [PATCH 14/21] Add FIXME: Race condition when checking len(fleet_model.instances) == 0 --- .../server/background/scheduled_tasks/submitted_jobs.py | 3 +++ src/dstack/_internal/server/services/runs/__init__.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index 729ded205..df30bbd9c 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -848,6 +848,9 @@ async def _run_jobs_on_new_instances( ) continue finally: + # FIXME: Race condition when checking len(fleet_model.instances) == 0 + # if provisioning independent jobs in a cluster fleet. + # Leads to placement groups being marked for deletion while still in use. if fleet_model is not None and len(fleet_model.instances) == 0: # Clean up placement groups that did not end up being used. for pg in placement_group_models: diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py index f8aa3f288..4655f600f 100644 --- a/src/dstack/_internal/server/services/runs/__init__.py +++ b/src/dstack/_internal/server/services/runs/__init__.py @@ -987,9 +987,8 @@ def _get_job_submission_cost(job_submission: JobSubmission) -> float: async def process_terminating_run(session: AsyncSession, run_model: RunModel): """ - Used by both `process_runs` and `stop_run` to process a TERMINATING run. Stops the jobs gracefully and marks them as TERMINATING. 
- Jobs should be terminated by `process_terminating_jobs`. + Jobs then should be terminated by `process_terminating_jobs`. When all jobs are terminated, assigns a finished status to the run. Caller must acquire the lock on run. """ From 9b0b298a40dca7fa1bf2dbc24bcd57703b8197a5 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 15:44:46 +0500 Subject: [PATCH 15/21] Fix stale fleet_model read --- .../scheduled_tasks/submitted_jobs.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index df30bbd9c..f5c28b9ab 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -409,13 +409,14 @@ async def _process_submitted_job( await session.commit() return - master_instance_provisioning_data = ( - await _fetch_fleet_with_master_instance_provisioning_data( - exit_stack=exit_stack, - session=session, - fleet_model=fleet_model, - job=job, - ) + ( + fleet_model, + master_instance_provisioning_data, + ) = await _fetch_fleet_with_master_instance_provisioning_data( + exit_stack=exit_stack, + session=session, + fleet_model=fleet_model, + job=job, ) master_provisioning_data = ( master_job_provisioning_data or master_instance_provisioning_data @@ -573,7 +574,7 @@ async def _fetch_fleet_with_master_instance_provisioning_data( session: AsyncSession, fleet_model: Optional[FleetModel], job: Job, -) -> Optional[JobProvisioningData]: +) -> tuple[Optional[FleetModel], Optional[JobProvisioningData]]: # TODO: When submitted-jobs provisioning moves to pipelines, stop inferring the # cluster master from loaded fleet instances here. 
Resolve the current master via # FleetModel.current_master_instance_id so jobs follow the same master election @@ -626,7 +627,7 @@ async def _fetch_fleet_with_master_instance_provisioning_data( fleet_model=fleet_model, fleet_spec=fleet.spec, ) - return master_instance_provisioning_data + return fleet_model, master_instance_provisioning_data async def _assign_job_to_fleet_instance( @@ -848,9 +849,6 @@ async def _run_jobs_on_new_instances( ) continue finally: - # FIXME: Race condition when checking len(fleet_model.instances) == 0 - # if provisioning independent jobs in a cluster fleet. - # Leads to placement groups being marked for deletion while still in use. if fleet_model is not None and len(fleet_model.instances) == 0: # Clean up placement groups that did not end up being used. for pg in placement_group_models: From 8bdba43fb78c143ce03da06601de6e3e8b9dbb83 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 16:46:39 +0500 Subject: [PATCH 16/21] Clean up pipeline tests --- .../scheduled_tasks/test_running_jobs.py | 247 ++++++++---------- .../scheduled_tasks/test_submitted_jobs.py | 8 +- .../scheduled_tasks/test_terminating_jobs.py | 7 +- 3 files changed, 113 insertions(+), 149 deletions(-) diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index 80a18dc11..f28b10976 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -112,7 +112,7 @@ class TestProcessRunningJobs: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( - self, test_db, session: AsyncSession + self, test_db, session: AsyncSession, ssh_tunnel_mock: Mock, runner_client_mock: Mock ): project = await create_project(session=session) user = await 
create_user(session=session) @@ -141,37 +141,20 @@ async def test_leaves_provisioning_job_unchanged_if_runner_not_alive( instance=instance, instance_assigned=True, ) - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - patch( - "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_file_archives", - new_callable=AsyncMock, - ) as get_job_file_archives_mock, - patch( - "dstack._internal.server.background.scheduled_tasks.running_jobs._get_job_code", - new_callable=AsyncMock, - ) as get_job_code_mock, - patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock, - ): + runner_client_mock.healthcheck.return_value = None + with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock: datetime_mock.return_value = datetime(2023, 1, 2, 5, 12, 30, 10, tzinfo=timezone.utc) - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.healthcheck = Mock() - runner_client_mock.healthcheck.return_value = None await process_running_jobs() - SSHTunnelMock.assert_called_once() - runner_client_mock.healthcheck.assert_called_once() - get_job_file_archives_mock.assert_not_awaited() - get_job_code_mock.assert_not_awaited() + ssh_tunnel_mock.assert_called() + runner_client_mock.healthcheck.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PROVISIONING @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_runs_provisioning_job(self, test_db, session: AsyncSession): + async def test_runs_provisioning_job( + self, test_db, session: AsyncSession, ssh_tunnel_mock: Mock, runner_client_mock: Mock + ): project = await create_project(session=session) user = await create_user(session=session) repo = await create_repo( @@ -199,27 +182,16 @@ async def 
test_runs_provisioning_job(self, test_db, session: AsyncSession): instance=instance, instance_assigned=True, ) - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.healthcheck.return_value = HealthcheckResponse( - service="dstack-runner", version="0.0.1.dev2" - ) - runner_client_mock.run_job.return_value = JobInfoResponse( - working_dir="/dstack/run", username="dstack" - ) - await process_running_jobs() - assert SSHTunnelMock.call_count == 2 - assert runner_client_mock.healthcheck.call_count == 2 - runner_client_mock.submit_job.assert_called_once() - runner_client_mock.upload_code.assert_called_once() - runner_client_mock.run_job.assert_called_once() + runner_client_mock.run_job.return_value = JobInfoResponse( + working_dir="/dstack/run", username="dstack" + ) + await process_running_jobs() + ssh_tunnel_mock.assert_called() + assert runner_client_mock.healthcheck.call_count == 2 + runner_client_mock.submit_job.assert_called_once() + runner_client_mock.upload_code.assert_called_once() + runner_client_mock.run_job.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.RUNNING jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) assert jrd.working_dir == "/dstack/run" @@ -227,7 +199,14 @@ async def test_runs_provisioning_job(self, test_db, session: AsyncSession): @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_updates_running_job(self, test_db, session: AsyncSession, tmp_path: Path): + async def test_running_job_updates_runner_timestamp( + self, + test_db, + session: AsyncSession, + tmp_path: Path, + ssh_tunnel_mock: Mock, + runner_client_mock: Mock, + ): project = await create_project(session=session) user = await 
create_user(session=session) repo = await create_repo( @@ -254,46 +233,63 @@ async def test_updates_running_job(self, test_db, session: AsyncSession, tmp_pat instance=instance, instance_assigned=True, ) - last_processed_at = job.last_processed_at - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - patch.object(server_settings, "SERVER_DIR_PATH", tmp_path), - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.pull.return_value = PullResponse( - job_states=[JobStateEvent(timestamp=1, state=JobStatus.RUNNING)], - job_logs=[], - runner_logs=[], - last_updated=1, - ) + runner_client_mock.pull.return_value = PullResponse( + job_states=[JobStateEvent(timestamp=1, state=JobStatus.RUNNING)], + job_logs=[], + runner_logs=[], + last_updated=1, + ) + with patch.object(server_settings, "SERVER_DIR_PATH", tmp_path): await process_running_jobs() - SSHTunnelMock.assert_called_once() + ssh_tunnel_mock.assert_called() await session.refresh(job) - assert job is not None assert job.status == JobStatus.RUNNING assert job.runner_timestamp == 1 - job.last_processed_at = last_processed_at - await session.commit() - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.pull.return_value = PullResponse( - job_states=[JobStateEvent(timestamp=1, state=JobStatus.DONE, exit_status=0)], - job_logs=[], - runner_logs=[], - last_updated=2, - ) - await process_running_jobs() - SSHTunnelMock.assert_called_once() + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_running_job_terminates_when_done_by_runner( + self, + test_db, + session: AsyncSession, + 
ssh_tunnel_mock: Mock, + runner_client_mock: Mock, + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job_provisioning_data = get_job_provisioning_data(dockerized=False) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + job_provisioning_data=job_provisioning_data, + instance=instance, + instance_assigned=True, + ) + runner_client_mock.pull.return_value = PullResponse( + job_states=[JobStateEvent(timestamp=1, state=JobStatus.DONE, exit_status=0)], + job_logs=[], + runner_logs=[], + last_updated=2, + ) + await process_running_jobs() + ssh_tunnel_mock.assert_called() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.DONE_BY_RUNNER assert job.exit_status == 0 @@ -389,7 +385,6 @@ async def test_provisioning_shim_with_volumes( instance_id=job_provisioning_data.instance_id, ) await session.refresh(job) - assert job is not None assert job.status == JobStatus.PULLING @pytest.mark.asyncio @@ -436,14 +431,13 @@ async def test_pulling_shim( await process_running_jobs() - assert ssh_tunnel_mock.call_count == 3 + ssh_tunnel_mock.assert_called() shim_client_mock.get_task.assert_called_once() assert runner_client_mock.healthcheck.call_count == 2 runner_client_mock.submit_job.assert_called_once() runner_client_mock.upload_code.assert_called_once() runner_client_mock.run_job.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.RUNNING jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) assert jrd.ports == { @@ -509,7 +503,6 @@ async def 
test_pulling_shim_port_mapping_not_ready( get_job_file_archives_mock.assert_not_awaited() get_job_code_mock.assert_not_awaited() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PULLING @pytest.mark.asyncio @@ -572,7 +565,6 @@ async def test_pulling_shim_runner_not_ready( get_job_code_mock.assert_not_awaited() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PULLING @pytest.mark.asyncio @@ -655,7 +647,6 @@ def assert_submit_job_to_runner(_, __, job_runtime_data, **kwargs): submit_job_to_runner_mock.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PULLING jrd = JobRuntimeData.__response__.parse_raw(job.job_runtime_data) assert jrd.ports == expected_ports @@ -694,10 +685,9 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession): ): SSHTunnelMock.side_effect = SSHError await process_running_jobs() - assert SSHTunnelMock.call_count == 3 + SSHTunnelMock.assert_called() await session.refresh(job) events = await list_events(session) - assert job is not None assert job.disconnected_at is not None assert job.status == JobStatus.PULLING assert len(events) == 1 @@ -709,7 +699,7 @@ async def test_pulling_shim_failed(self, test_db, session: AsyncSession): ): SSHTunnelMock.side_effect = SSHError await process_running_jobs() - assert SSHTunnelMock.call_count == 3 + SSHTunnelMock.assert_called() await session.refresh(job) assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.INSTANCE_UNREACHABLE @@ -832,6 +822,8 @@ async def test_inactivity_duration( self, test_db, session: AsyncSession, + ssh_tunnel_mock: Mock, + runner_client_mock: Mock, inactivity_duration, no_connections_secs: Optional[int], expected_status: JobStatus, @@ -874,23 +866,16 @@ async def test_inactivity_duration( instance=instance, instance_assigned=True, ) - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") 
as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.pull.return_value = PullResponse( - job_states=[], - job_logs=[], - runner_logs=[], - last_updated=0, - no_connections_secs=no_connections_secs, - ) - await process_running_jobs() - SSHTunnelMock.assert_called_once() - runner_client_mock.pull.assert_called_once() + runner_client_mock.pull.return_value = PullResponse( + job_states=[], + job_logs=[], + runner_logs=[], + last_updated=0, + no_connections_secs=no_connections_secs, + ) + await process_running_jobs() + ssh_tunnel_mock.assert_called() + runner_client_mock.pull.assert_called_once() await session.refresh(job) assert job.status == expected_status assert job.termination_reason == expected_termination_reason @@ -937,6 +922,8 @@ async def test_gpu_utilization( self, test_db, session: AsyncSession, + ssh_tunnel_mock: Mock, + runner_client_mock: Mock, samples: list[tuple[datetime, int]], expected_status: JobStatus, ) -> None: @@ -989,23 +976,16 @@ async def test_gpu_utilization( gpus_memory_usage_bytes=[1024, 1024], gpus_util_percent=[gpu_util, 100], ) - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.pull.return_value = PullResponse( - job_states=[], - job_logs=[], - runner_logs=[], - last_updated=0, - no_connections_secs=0, - ) - await process_running_jobs() - SSHTunnelMock.assert_called_once() - runner_client_mock.pull.assert_called_once() + runner_client_mock.pull.return_value = PullResponse( + job_states=[], + job_logs=[], + runner_logs=[], + last_updated=0, + no_connections_secs=0, + ) + await process_running_jobs() + ssh_tunnel_mock.assert_called() + runner_client_mock.pull.assert_called_once() await 
session.refresh(job) assert job.status == expected_status if expected_status == JobStatus.TERMINATING: @@ -1019,7 +999,9 @@ async def test_gpu_utilization( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_master_job_waits_for_workers(self, test_db, session: AsyncSession): + async def test_master_job_waits_for_workers( + self, test_db, session: AsyncSession, ssh_tunnel_mock: Mock, runner_client_mock: Mock + ): project = await create_project(session=session) user = await create_user(session=session) repo = await create_repo( @@ -1069,6 +1051,9 @@ async def test_master_job_waits_for_workers(self, test_db, session: AsyncSession job_num=1, last_processed_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), ) + runner_client_mock.run_job.return_value = JobInfoResponse( + working_dir="/dstack/run", username="dstack" + ) await process_running_jobs() await session.refresh(master_job) assert master_job.status == JobStatus.PROVISIONING @@ -1076,17 +1061,7 @@ async def test_master_job_waits_for_workers(self, test_db, session: AsyncSession # To guarantee master_job is processed next master_job.last_processed_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) await session.commit() - with ( - patch("dstack._internal.server.services.runner.ssh.SSHTunnel"), - patch( - "dstack._internal.server.services.runner.client.RunnerClient" - ) as RunnerClientMock, - ): - runner_client_mock = RunnerClientMock.return_value - runner_client_mock.healthcheck.return_value = HealthcheckResponse( - service="dstack-runner", version="0.0.1.dev2" - ) - await process_running_jobs() + await process_running_jobs() await session.refresh(master_job) assert master_job.status == JobStatus.RUNNING diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py index db75bbf53..775569d99 100644 --- 
a/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_submitted_jobs.py @@ -87,7 +87,6 @@ async def test_fails_job_when_no_backends(self, test_db, session: AsyncSession): ) await process_submitted_jobs() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY @@ -145,7 +144,6 @@ async def test_provisions_job( backend_mock.compute.return_value.run_job.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PROVISIONING @pytest.mark.asyncio @@ -194,7 +192,6 @@ async def test_fails_job_when_privileged_true_and_no_offers_with_create_instance backend_mock.compute.return_value.run_job.assert_not_called() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY @@ -244,7 +241,6 @@ async def test_fails_job_when_instance_mounts_and_no_offers_with_create_instance backend_mock.compute.return_value.run_job.assert_not_called() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY @@ -291,7 +287,6 @@ async def test_provisions_job_with_optional_instance_volume_not_attached( await process_submitted_jobs() await session.refresh(job) - assert job is not None assert job.status == JobStatus.PROVISIONING @pytest.mark.asyncio @@ -328,13 +323,12 @@ async def test_fails_job_when_no_capacity(self, test_db, session: AsyncSession): await process_submitted_jobs() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY @pytest.mark.asyncio 
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_assignes_job_to_instance(self, test_db, session: AsyncSession): + async def test_assigns_job_to_instance(self, test_db, session: AsyncSession): project = await create_project(session) user = await create_user(session) repo = await create_repo( diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py index d2b4d2d31..da58e9708 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_terminating_jobs.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import timedelta, timezone from unittest.mock import Mock, patch import pytest @@ -57,7 +57,6 @@ async def test_terminates_job(self, session: AsyncSession): run=run, status=JobStatus.TERMINATING, termination_reason=JobTerminationReason.TERMINATED_BY_USER, - submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), job_provisioning_data=job_provisioning_data, instance=instance, ) @@ -70,7 +69,6 @@ async def test_terminates_job(self, session: AsyncSession): SSHTunnelMock.assert_called_once() shim_client_mock.healthcheck.assert_called_once() await session.refresh(job) - assert job is not None assert job.status == JobStatus.TERMINATED async def test_detaches_job_volumes(self, session: AsyncSession): @@ -103,7 +101,6 @@ async def test_detaches_job_volumes(self, session: AsyncSession): run=run, status=JobStatus.TERMINATING, termination_reason=JobTerminationReason.TERMINATED_BY_USER, - submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), job_provisioning_data=job_provisioning_data, instance=instance, ) @@ -149,7 +146,6 @@ async def test_force_detaches_job_volumes(self, session: AsyncSession): run=run, status=JobStatus.TERMINATING, 
termination_reason=JobTerminationReason.TERMINATED_BY_USER, - submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), job_provisioning_data=job_provisioning_data, instance=instance, ) @@ -304,7 +300,6 @@ async def test_detaches_job_volumes_on_shared_instance(self, session: AsyncSessi run=run, status=JobStatus.TERMINATING, termination_reason=JobTerminationReason.TERMINATED_BY_USER, - submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), job_provisioning_data=job_provisioning_data, job_runtime_data=get_job_runtime_data(volume_names=["vol-1"]), instance=instance, From b29dda0227e89bc2bccc82d961d02775e99c6215 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 16:58:41 +0500 Subject: [PATCH 17/21] Fix empty fleet select --- .../server/background/scheduled_tasks/submitted_jobs.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index f5c28b9ab..a82614d74 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta from typing import List, Optional, Union -from sqlalchemy import func, or_, select +from sqlalchemy import exists, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import ( contains_eager, @@ -596,12 +596,11 @@ async def _fetch_fleet_with_master_instance_provisioning_data( await sqlite_commit(session) res = await session.execute( select(FleetModel) - .outerjoin(FleetModel.instances) .where( FleetModel.id == fleet_model.id, - or_( - InstanceModel.id.is_(None), - InstanceModel.deleted == True, + ~exists().where( + InstanceModel.fleet_id == fleet_model.id, + InstanceModel.deleted == False, ), ) .with_for_update(key_share=True, of=FleetModel) From 
444c5faba517c5682c3230cf18bdbf14f9357a57 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 17:16:29 +0500 Subject: [PATCH 18/21] Fix missing az restriction for clusters in submitted_jobs --- .../background/scheduled_tasks/submitted_jobs.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index a82614d74..bf94bf6a9 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -105,7 +105,10 @@ ) from dstack._internal.server.services.locking import get_locker, string_to_lock_id from dstack._internal.server.services.logging import fmt -from dstack._internal.server.services.offers import get_offers_by_requirements +from dstack._internal.server.services.offers import ( + get_instance_offer_with_restricted_az, + get_offers_by_requirements, +) from dstack._internal.server.services.placement import ( find_or_create_suitable_placement_group, get_fleet_placement_group_models, @@ -782,6 +785,13 @@ async def _run_jobs_on_new_instances( offer_volumes = _get_offer_volumes(volumes, offer) job_configurations = [JobConfiguration(job=j, volumes=offer_volumes) for j in jobs] compute = backend.compute() + if master_job_provisioning_data is not None: + # `get_offers_by_requirements()` already restricts backend and region from the master. + # Availability zone still has to be narrowed per offer. 
+ offer = get_instance_offer_with_restricted_az( + instance_offer=offer, + master_job_provisioning_data=master_job_provisioning_data, + ) if ( fleet_model is not None and len(fleet_model.instances) == 0 From 0923340b45582ba423e45044d12838f7032274f6 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 17:22:16 +0500 Subject: [PATCH 19/21] Add deprecated note --- .../_internal/server/background/scheduled_tasks/running_jobs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index ea3d53973..6ceebf9e3 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -102,6 +102,8 @@ JOB_DISCONNECTED_RETRY_TIMEOUT = timedelta(minutes=2) +# NOTE: This scheduled task is going to be deprecated in favor of `JobRunningPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/jobs_running.py` in sync. 
async def process_running_jobs(batch_size: int = 1): tasks = [] for _ in range(batch_size): From 9d8c6f276b333939b6dddf35e0a62ad723221d01 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 17:42:36 +0500 Subject: [PATCH 20/21] Pass instance_project_ssh_private_key --- .../background/pipeline_tasks/jobs_running.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 1c61458dd..eda36d811 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -929,19 +929,24 @@ async def _register_service_replica( session, context.run_model.gateway_id ) gateway_target = events.Target.from_model(gateway_model) + assert context.job_model.instance is not None + instance_project_ssh_private_key = None + if context.job_model.project_id != context.job_model.instance.project_id: + instance_project_ssh_private_key = context.job_model.instance.project.ssh_private_key + # JobRuntimeData might change on PULLING -> RUNNING path + # so we must update job_submission with the result value. + job_submission = context.job_submission.copy(deep=True) + job_submission.job_runtime_data = _get_result_job_runtime_data(context.job_model, result) try: logger.debug( "%s: registering replica for service %s", fmt(context.job_model), context.run.id.hex ) - # JobRuntimeData might change on PULLING -> RUNNING path - # so we must update job_submission with the result value. 
- job_submission = context.job_submission.copy(deep=True) - job_submission.job_runtime_data = _get_result_job_runtime_data(context.job_model, result) async with conn.client() as gateway_client: await gateway_client.register_replica( run=context.run, job_spec=JobSpec.__response__.parse_raw(context.job_model.job_spec_data), job_submission=job_submission, + instance_project_ssh_private_key=instance_project_ssh_private_key, ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, ) From fdd38ddadbb1c49cb25a21e4b534d3cfb1b58e0e Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Thu, 12 Mar 2026 17:49:36 +0500 Subject: [PATCH 21/21] Fix missing pipeline tests --- .../pipeline_tasks/test_running_jobs.py | 171 +++++++++++++++++- 1 file changed, 170 insertions(+), 1 deletion(-) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index 8cee7f90b..a52924a55 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Optional -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch import pytest from freezegun import freeze_time @@ -21,6 +21,7 @@ ProbeConfig, ServiceConfiguration, ) +from dstack._internal.core.models.gateways import GatewayStatus from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy from dstack._internal.core.models.runs import ( @@ -52,6 +53,11 @@ from dstack._internal.server.services.runner.ssh import SSHTunnel from dstack._internal.server.services.volumes import volume_model_to_volume from dstack._internal.server.testing.common import ( + 
create_backend, + create_export, + create_fleet, + create_gateway, + create_gateway_compute, create_instance, create_job, create_job_metrics_point, @@ -1635,6 +1641,169 @@ async def test_registers_service_replica_only_after_probes_pass( assert not job.registered assert not events + async def test_registers_service_replica_in_gateway( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + mock_gateway_connection: AsyncMock, + ): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + ) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + name="test-gateway", + wildcard_domain="example.com", + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, image="ubuntu", gateway="test-gateway" + ), + ), + gateway=gateway, + ) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + fleet=fleet, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + + await _process_job(session, worker, job) + + await session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.registered + events = 
await list_events(session) + assert {event.message for event in events} == { + "Job status changed PULLING -> RUNNING", + "Service replica registered to receive requests", + } + mock_gateway_connection.return_value.client.return_value.__aenter__.return_value.register_replica.assert_called_once_with( + run=ANY, + job_spec=ANY, + job_submission=ANY, + instance_project_ssh_private_key=None, + ssh_head_proxy=None, + ssh_head_proxy_private_key=None, + ) + + async def test_registers_service_replica_in_gateway_when_running_on_imported_instance( + self, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + mock_gateway_connection: AsyncMock, + ): + user = await create_user(session=session) + exporter_project = await create_project( + session=session, name="exporter", owner=user, ssh_private_key="exporter-private-key" + ) + importer_project = await create_project(session=session, name="importer", owner=user) + fleet = await create_fleet(session=session, project=exporter_project) + instance = await create_instance( + session=session, + project=exporter_project, + status=InstanceStatus.BUSY, + fleet=fleet, + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + repo = await create_repo(session=session, project_id=importer_project.id) + backend = await create_backend(session=session, project_id=importer_project.id) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + ) + gateway = await create_gateway( + session=session, + project_id=importer_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + name="test-gateway", + wildcard_domain="example.com", + ) + run = await create_run( + session=session, + project=importer_project, + repo=repo, + user=user, + run_spec=get_run_spec( + 
run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, image="ubuntu", gateway="test-gateway" + ), + ), + gateway=gateway, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + + await _process_job(session, worker, job) + + await session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.registered + events = await list_events(session) + assert {event.message for event in events} == { + "Job status changed PULLING -> RUNNING", + "Service replica registered to receive requests", + } + mock_gateway_connection.return_value.client.return_value.__aenter__.return_value.register_replica.assert_called_once_with( + run=ANY, + job_spec=ANY, + job_submission=ANY, + instance_project_ssh_private_key="exporter-private-key", + ssh_head_proxy=None, + ssh_head_proxy_private_key=None, + ) + async def test_apply_skips_probe_insert_when_lock_token_changes_after_processing( self, test_db,