diff --git a/pyproject.toml b/pyproject.toml index d8acee223c..b601fe6466 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ markers = [ ] env = [ "DSTACK_CLI_RICH_FORCE_TERMINAL=0", + "DSTACK_SSHPROXY_API_TOKEN=test-token", ] filterwarnings = [ # testcontainers modules use deprecated decorators – nothing we can do: @@ -142,6 +143,7 @@ dev = [ "pytest-httpbin>=2.1.0", "pytest-socket>=0.7.0", "pytest-env>=1.1.0", + "pytest-unordered>=0.7.0", "httpbin>=0.10.2", # indirect to make compatible with Werkzeug 3 "requests-mock>=1.12.1", "openai>=1.68.2", diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index 21ac9228c9..962cf4b5c0 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -130,8 +130,8 @@ def run_job( commands = get_docker_commands( [run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()] ) - # There is a one jump pod per Kubernetes backend that is used - # as an ssh proxy jump to connect to all other services in Kubernetes. + # There is one jump pod per project that is used as an ssh proxy jump to connect + # to all job pods of the same project. 
# The service is created here and configured later in update_provisioning_data() jump_pod_name = f"dstack-{run.project_name}-ssh-jump-pod" jump_pod_service_name = _get_pod_service_name(jump_pod_name) diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index a0e22b8473..0a39d5e41e 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -45,6 +45,7 @@ runs, secrets, server, + sshproxy, templates, users, volumes, @@ -255,6 +256,7 @@ def register_routes(app: FastAPI, ui: bool = True): app.include_router(events.root_router) app.include_router(templates.router) app.include_router(exports.project_router) + app.include_router(sshproxy.router) @app.exception_handler(ForbiddenError) async def forbidden_error_handler(request: Request, exc: ForbiddenError): diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index 61c8adaeec..accbe9875e 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -59,6 +59,7 @@ emit_job_status_change_event, get_job_provisioning_data, get_job_runtime_data, + get_job_spec, ) from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt @@ -797,7 +798,7 @@ async def _detach_volumes_from_job_instance( jpd: JobProvisioningData, run_termination_reason: Optional[RunTerminationReason], ) -> _VolumeDetachResult: - job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + job_spec = get_job_spec(job_model) backend = await backends_services.get_project_backend_by_type( project=instance_model.project, backend_type=jpd.backend, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/probes.py b/src/dstack/_internal/server/background/scheduled_tasks/probes.py index 9b36bd09fe..ee5b5c9b4d 
100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/probes.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/probes.py @@ -12,7 +12,7 @@ from sqlalchemy.orm import joinedload from dstack._internal.core.errors import SSHError -from dstack._internal.core.models.runs import JobSpec, JobStatus, ProbeSpec +from dstack._internal.core.models.runs import JobStatus, ProbeSpec from dstack._internal.core.services.ssh.tunnel import ( SSH_DEFAULT_OPTIONS, IPSocket, @@ -21,6 +21,7 @@ ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import InstanceModel, JobModel, ProbeModel +from dstack._internal.server.services.jobs import get_job_spec from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt from dstack._internal.server.services.ssh import container_ssh_tunnel @@ -71,7 +72,7 @@ async def process_probes(): if probe.job.status != JobStatus.RUNNING: probe.active = False else: - job_spec: JobSpec = JobSpec.__response__.parse_raw(probe.job.job_spec_data) + job_spec = get_job_spec(probe.job) probe_spec = job_spec.probes[probe.probe_num] if probe_spec.until_ready and probe.success_streak >= probe_spec.ready_after: probe.active = False @@ -148,7 +149,7 @@ async def _get_service_replica_client(job: JobModel) -> AsyncGenerator[AsyncClie **SSH_DEFAULT_OPTIONS, "ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())), } - job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data) + job_spec = get_job_spec(job) with TemporaryDirectory() as temp_dir: app_socket_path = (Path(temp_dir) / "replica.sock").absolute() async with container_ssh_tunnel( diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index ea3d539734..37b427dbd0 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ 
b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -34,7 +34,6 @@ JobTerminationReason, ProbeSpec, Run, - RunSpec, RunStatus, ) from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint @@ -67,6 +66,7 @@ find_job, get_job_attached_volumes, get_job_runtime_data, + get_job_spec, is_master_job, job_model_to_job_submission, switch_job_status, @@ -82,6 +82,7 @@ from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel from dstack._internal.server.services.runs import ( + get_run_spec, is_job_ready, run_model_to_run, ) @@ -732,7 +733,7 @@ def _process_provisioning_with_shim( Returns: is successful """ - job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + job_spec = get_job_spec(job_model) shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) @@ -980,7 +981,7 @@ def _terminate_if_inactivity_duration_exceeded( job_model: JobModel, no_connections_secs: Optional[int], ) -> None: - conf = RunSpec.__response__.parse_raw(run_model.run_spec).configuration + conf = get_run_spec(run_model).configuration if not isinstance(conf, DevEnvironmentConfiguration) or not isinstance( conf.inactivity_duration, int ): diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index 56d9fea77a..fdffa2dd80 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -13,7 +13,6 @@ from dstack._internal.core.models.profiles import RetryEvent, StopCriteria from dstack._internal.core.models.runs import ( Job, - JobSpec, JobStatus, JobTerminationReason, Run, @@ -33,6 +32,7 @@ from dstack._internal.server.services import events from dstack._internal.server.services.jobs import ( find_job, + get_job_spec, get_job_specs_from_run_spec, group_jobs_by_replica_latest, 
is_master_job, @@ -527,7 +527,7 @@ async def _handle_run_replicas( if job.status.is_finished(): continue try: - job_spec = JobSpec.__response__.parse_raw(job.job_spec_data) + job_spec = get_job_spec(job) existing_group_names.add(job_spec.replica_group) except Exception: continue @@ -643,7 +643,7 @@ async def _update_jobs_to_new_deployment_in_place( replica_group_name = None if replicas: - job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data) + job_spec = get_job_spec(job_models[0]) replica_group_name = job_spec.replica_group # FIXME: Handle getting image configuration errors or skip it. @@ -658,7 +658,7 @@ async def _update_jobs_to_new_deployment_in_place( ) can_update_all_jobs = True for old_job_model, new_job_spec in zip(job_models, new_job_specs): - old_job_spec = JobSpec.__response__.parse_raw(old_job_model.job_spec_data) + old_job_spec = get_job_spec(old_job_model) if new_job_spec != old_job_spec: can_update_all_jobs = False break diff --git a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py index 4cf63f2b7f..ab25c2c7d5 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py @@ -39,6 +39,7 @@ from dstack._internal.server.services.jobs import ( get_job_provisioning_data, get_job_runtime_data, + get_job_spec, switch_job_status, ) from dstack._internal.server.services.locking import get_locker @@ -356,7 +357,7 @@ async def _detach_volumes_from_job_instance( instance_model: InstanceModel, volume_models: list[VolumeModel], ) -> bool: - job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + job_spec = get_job_spec(job_model) backend = await backends_services.get_project_backend_by_type( project=project, backend_type=jpd.backend, diff --git a/src/dstack/_internal/server/routers/sshproxy.py 
b/src/dstack/_internal/server/routers/sshproxy.py new file mode 100644 index 0000000000..3edc927e96 --- /dev/null +++ b/src/dstack/_internal/server/routers/sshproxy.py @@ -0,0 +1,39 @@ +import os +from typing import Annotated + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import ResourceNotExistsError +from dstack._internal.server.db import get_session +from dstack._internal.server.schemas.sshproxy import GetUpstreamRequest, GetUpstreamResponse +from dstack._internal.server.security.permissions import AlwaysForbidden, ServiceAccount +from dstack._internal.server.services.sshproxy import get_upstream_response +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) + +if _token := os.getenv("DSTACK_SSHPROXY_API_TOKEN"): + _auth = ServiceAccount(_token) +else: + _auth = AlwaysForbidden() + + +router = APIRouter( + prefix="/api/sshproxy", + tags=["sshproxy"], + responses=get_base_api_additional_responses(), + dependencies=[Depends(_auth)], +) + + +@router.post("/get_upstream", response_model=GetUpstreamResponse) +async def get_upstream( + body: GetUpstreamRequest, + session: Annotated[AsyncSession, Depends(get_session)], +): + response = await get_upstream_response(session=session, upstream_id=body.id) + if response is None: + raise ResourceNotExistsError() + return CustomORJSONResponse(response) diff --git a/src/dstack/_internal/server/schemas/sshproxy.py b/src/dstack/_internal/server/schemas/sshproxy.py new file mode 100644 index 0000000000..10c9297d88 --- /dev/null +++ b/src/dstack/_internal/server/schemas/sshproxy.py @@ -0,0 +1,27 @@ +from typing import Annotated + +from pydantic import Field + +from dstack._internal.core.models.common import CoreModel + + +class GetUpstreamRequest(CoreModel): + # The format of id is intentionally not limited to UUID to allow further extensions + id: str + + +class UpstreamHost(CoreModel): + 
host: Annotated[str, Field(description="The hostname or IP address")] + port: Annotated[int, Field(description="The SSH port")] + user: Annotated[str, Field(description="The user to log in")] + private_key: Annotated[str, Field(description="The private key in OpenSSH file format")] + + +class GetUpstreamResponse(CoreModel): + hosts: Annotated[ + list[UpstreamHost], + Field(description="The chain of SSH hosts, the jump host(s) first, the target host last"), + ] + authorized_keys: Annotated[ + list[str], Field(description="The list of authorized public keys in OpenSSH file format") + ] diff --git a/src/dstack/_internal/server/security/permissions.py b/src/dstack/_internal/server/security/permissions.py index 107e526d30..a343152e6e 100644 --- a/src/dstack/_internal/server/security/permissions.py +++ b/src/dstack/_internal/server/security/permissions.py @@ -1,3 +1,4 @@ +from secrets import compare_digest from typing import Annotated, Optional, Tuple from uuid import UUID @@ -219,9 +220,23 @@ async def __call__( raise error_forbidden() -class OptionalServiceAccount: +class ServiceAccount: + def __init__(self, token: str) -> None: + self._token = token.encode() + + async def __call__( + self, token: Annotated[HTTPAuthorizationCredentials, Security(HTTPBearer())] + ) -> None: + if not compare_digest(token.credentials.encode(), self._token): + raise error_invalid_token() + + +class OptionalServiceAccount(ServiceAccount): + _token: Optional[bytes] = None + def __init__(self, token: Optional[str]) -> None: - self._token = token + if token is not None: + super().__init__(token) async def __call__( self, @@ -233,8 +248,12 @@ async def __call__( return if token is None: raise error_forbidden() - if token.credentials != self._token: - raise error_invalid_token() + await super().__call__(token) + + +class AlwaysForbidden: + async def __call__(self) -> None: + raise error_forbidden() async def get_project_member( diff --git a/src/dstack/_internal/server/services/jobs/__init__.py 
b/src/dstack/_internal/server/services/jobs/__init__.py index f718a80ce6..62254d7659 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -267,6 +267,10 @@ def get_job_runtime_data(job_model: JobModel) -> Optional[JobRuntimeData]: return JobRuntimeData.__response__.parse_raw(job_model.job_runtime_data) +def get_job_spec(job_model: JobModel) -> JobSpec: + return JobSpec.__response__.parse_raw(job_model.job_spec_data) + + def delay_job_instance_termination(job_model: JobModel): job_model.remove_at = common.get_current_datetime() + timedelta(seconds=15) diff --git a/src/dstack/_internal/server/services/prometheus/custom_metrics.py b/src/dstack/_internal/server/services/prometheus/custom_metrics.py index e9105b525c..6880f5c3ed 100644 --- a/src/dstack/_internal/server/services/prometheus/custom_metrics.py +++ b/src/dstack/_internal/server/services/prometheus/custom_metrics.py @@ -13,7 +13,7 @@ from sqlalchemy.orm import aliased, joinedload from dstack._internal.core.models.instances import InstanceStatus -from dstack._internal.core.models.runs import JobStatus, RunSpec, RunStatus +from dstack._internal.core.models.runs import JobStatus, RunStatus from dstack._internal.server.models import ( InstanceModel, JobMetricsPoint, @@ -25,6 +25,7 @@ ) from dstack._internal.server.services.instances import get_instance_offer from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data +from dstack._internal.server.services.runs import get_run_spec from dstack._internal.utils.common import get_current_datetime @@ -152,7 +153,7 @@ async def get_job_metrics(session: AsyncSession) -> Iterable[Metric]: price = jrd.offer.price gpus = resources.gpus cpus = resources.cpus - run_spec = RunSpec.__response__.parse_raw(job.run.run_spec) + run_spec = get_run_spec(job.run) labels = { "dstack_project_name": job.project.name, "dstack_user_name": job.run.user.name, @@ -186,7 +187,7 @@ 
async def get_job_metrics(session: AsyncSession) -> Iterable[Metric]: ) ): gpu_labels = labels.copy() - gpu_labels["dstack_gpu_num"] = gpu_num + gpu_labels["dstack_gpu_num"] = str(gpu_num) metrics.add_sample(_JOB_GPU_USAGE_RATIO, gpu_labels, gpu_util / 100) metrics.add_sample(_JOB_GPU_MEMORY_TOTAL, gpu_labels, gpu_memory_total) metrics.add_sample(_JOB_GPU_MEMORY_USAGE, gpu_labels, gpu_memory_usage) diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index ab34bf278d..a454b74ba8 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -12,9 +12,7 @@ from dstack._internal.core.models.instances import SSHConnectionParams from dstack._internal.core.models.runs import ( JobProvisioningData, - JobSpec, JobStatus, - RunSpec, RunStatus, ServiceSpec, get_service_port, @@ -32,6 +30,8 @@ from dstack._internal.proxy.lib.repo import BaseProxyRepo from dstack._internal.server.models import InstanceModel, JobModel, ProjectModel, RunModel from dstack._internal.server.services.instances import get_instance_remote_connection_info +from dstack._internal.server.services.jobs import get_job_spec +from dstack._internal.server.services.runs import get_run_spec from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE from dstack._internal.utils.common import get_or_error @@ -68,7 +68,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic if not len(jobs): return None run = jobs[0].run - run_spec = RunSpec.__response__.parse_raw(run.run_spec) + run_spec = get_run_spec(run) if not isinstance(run_spec.configuration, ServiceConfiguration): return None replicas = [] @@ -108,7 +108,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic if rci is not None and rci.ssh_proxy is not None: ssh_head_proxy = rci.ssh_proxy ssh_head_proxy_private_key = 
get_or_error(rci.ssh_proxy_keys)[0].private - job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data) + job_spec = get_job_spec(job) replica = Replica( id=job.id.hex, app_port=get_service_port(job_spec, run_spec.configuration), diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py index f8aa3f288b..958306bfaa 100644 --- a/src/dstack/_internal/server/services/runs/__init__.py +++ b/src/dstack/_internal/server/services/runs/__init__.py @@ -24,7 +24,6 @@ from dstack._internal.core.models.runs import ( ApplyRunPlanInput, Job, - JobSpec, JobStatus, JobSubmission, JobTerminationReason, @@ -54,6 +53,7 @@ check_can_attach_job_volumes, delay_job_instance_termination, get_job_configured_volumes, + get_job_spec, get_jobs_from_run_spec, job_model_to_job_submission, remove_job_spec_sensitive_info, @@ -112,6 +112,10 @@ def switch_run_status( events.emit(session, msg, actor=actor, targets=[events.Target.from_model(run_model)]) +def get_run_spec(run_model: RunModel) -> RunSpec: + return RunSpec.__response__.parse_raw(run_model.run_spec) + + async def list_user_runs( session: AsyncSession, user: UserModel, @@ -743,7 +747,7 @@ def run_model_to_run( include_sensitive=include_sensitive, ) - run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + run_spec = get_run_spec(run_model) latest_job_submission = None if len(jobs) > 0 and len(jobs[0].job_submissions) > 0: @@ -831,7 +835,7 @@ def _get_run_jobs_with_submissions( submissions.append(job_submission) if job_model is not None: # Use the spec from the latest submission. 
Submissions can have different specs - job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + job_spec = get_job_spec(job_model) if not include_sensitive: remove_job_spec_sensitive_info(job_spec) jobs.append(Job(job_spec=job_spec, job_submissions=submissions)) @@ -857,7 +861,7 @@ def _get_run_status_message(run_model: RunModel) -> str: if run_model.status in [RunStatus.SUBMITTED, RunStatus.PENDING]: # Show `retrying` if any job caused the run to retry for job_models in job_models_grouped_by_job: - last_job_spec = JobSpec.__response__.parse_raw(job_models[-1].job_spec_data) + last_job_spec = get_job_spec(job_models[-1]) retry_on_events = last_job_spec.retry.on_events if last_job_spec.retry else [] last_job_termination_reason = _get_last_job_termination_reason(job_models) if ( diff --git a/src/dstack/_internal/server/services/runs/replicas.py b/src/dstack/_internal/server/services/runs/replicas.py index 4f6c7ee19d..ffb7bd2169 100644 --- a/src/dstack/_internal/server/services/runs/replicas.py +++ b/src/dstack/_internal/server/services/runs/replicas.py @@ -4,10 +4,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.models.configurations import ReplicaGroup -from dstack._internal.core.models.runs import JobSpec, JobStatus, JobTerminationReason, RunSpec +from dstack._internal.core.models.runs import JobStatus, JobTerminationReason, RunSpec from dstack._internal.server.models import JobModel, RunModel from dstack._internal.server.services import events from dstack._internal.server.services.jobs import ( + get_job_spec, get_jobs_from_run_spec, group_jobs_by_replica_latest, switch_job_status, @@ -15,6 +16,7 @@ from dstack._internal.server.services.logging import fmt from dstack._internal.server.services.runs import ( create_job_model_for_new_submission, + get_run_spec, logger, ) from dstack._internal.server.services.secrets import get_project_secrets_mapping @@ -30,8 +32,8 @@ async def retry_run_replica_jobs( ) # Determine replica 
group from existing job - run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) - job_spec = JobSpec.__response__.parse_raw(latest_jobs[0].job_spec_data) + run_spec = get_run_spec(run_model) + job_spec = get_job_spec(latest_jobs[0]) replica_group_name = job_spec.replica_group new_jobs = await get_jobs_from_run_spec( @@ -86,7 +88,7 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica ) active_replicas, inactive_replicas = build_replica_lists(run_model) - run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + run_spec = get_run_spec(run_model) if replicas_diff < 0: scale_down_replicas(session, active_replicas, abs(replicas_diff)) @@ -259,7 +261,7 @@ async def scale_run_replicas_per_group( run_model=run_model, group=group, replicas_diff=group_diff, - run_spec=RunSpec.__response__.parse_raw(run_model.run_spec), + run_spec=get_run_spec(run_model), active_replicas=active_replicas, inactive_replicas=inactive_replicas, ) @@ -300,7 +302,7 @@ async def scale_run_replicas_for_group( def job_belongs_to_group(job: JobModel, group_name: str) -> bool: - job_spec = JobSpec.__response__.parse_raw(job.job_spec_data) + job_spec = get_job_spec(job) return job_spec.replica_group == group_name diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 9a0bc03369..82d25a94b9 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -32,7 +32,7 @@ RouterType, SGLangServiceRouterConfig, ) -from dstack._internal.core.models.runs import JobSpec, Run, RunSpec, ServiceModelSpec, ServiceSpec +from dstack._internal.core.models.runs import Run, RunSpec, ServiceModelSpec, ServiceSpec from dstack._internal.core.models.services import OpenAIChatModel from dstack._internal.proxy.gateway.const import SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE from dstack._internal.server import settings @@ -322,7 
+322,7 @@ async def register_replica( async with conn.client() as client: await client.register_replica( run=run, - job_spec=JobSpec.__response__.parse_raw(job_model.job_spec_data), + job_spec=jobs_services.get_job_spec(job_model), job_submission=job_submission, instance_project_ssh_private_key=instance_project_ssh_private_key, ssh_head_proxy=ssh_head_proxy, diff --git a/src/dstack/_internal/server/services/ssh.py b/src/dstack/_internal/server/services/ssh.py index cb5d46c8c3..72a0f65905 100644 --- a/src/dstack/_internal/server/services/ssh.py +++ b/src/dstack/_internal/server/services/ssh.py @@ -1,72 +1,100 @@ from collections.abc import Iterable -from typing import Optional -import dstack._internal.server.services.jobs as jobs_services from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.core.services.ssh.tunnel import SSH_DEFAULT_OPTIONS, SocketPair, SSHTunnel from dstack._internal.server.models import JobModel from dstack._internal.server.services.instances import get_instance_remote_connection_info +from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data from dstack._internal.utils.common import get_or_error from dstack._internal.utils.path import FileContent -def container_ssh_tunnel( - job: JobModel, - forwarded_sockets: Iterable[SocketPair] = (), - options: dict[str, str] = SSH_DEFAULT_OPTIONS, -) -> SSHTunnel: +def get_container_ssh_credentials(job: JobModel) -> list[tuple[SSHConnectionParams, FileContent]]: """ - Build SSHTunnel for connecting to the container running the specified job. + Returns the information needed to connect to the SSH server inside the job container. 
+ + The user of the target host (container) is set to: + * VM-based backends and SSH instances: "root" + * container-based backends: `JobProvisioningData.username`, which, as of 2026-03-10, + is always "root" on all supported backends (Runpod, Vast.ai, Kubernetes) + + Args: + job: `JobModel` with `project`, `instance` and `instance.project` fields loaded. + + Returns: + A list of host credentials as (host's `SSHConnectionParams`, private key's `FileContent`) + pairs ordered from the first proxy jump (if any) to the target host (container). """ - jpd: JobProvisioningData = JobProvisioningData.__response__.parse_raw( - job.job_provisioning_data - ) + hosts: list[tuple[SSHConnectionParams, FileContent]] = [] + + instance = get_or_error(job.instance) + + rci = get_instance_remote_connection_info(instance) + if rci is not None and (head_proxy := rci.ssh_proxy) is not None: + head_key = FileContent(get_or_error(get_or_error(rci.ssh_proxy_keys)[0].private)) + hosts.append((head_proxy, head_key)) + + jpd = get_job_provisioning_data(job) + assert jpd is not None assert jpd.hostname is not None assert jpd.ssh_port is not None - instance = get_or_error(job.instance) - if not jpd.dockerized: - ssh_destination = f"{jpd.username}@{jpd.hostname}" - ssh_port = jpd.ssh_port - ssh_proxy = jpd.ssh_proxy - ssh_proxy_private_key = None - else: - ssh_destination = "root@localhost" + + job_project_key = FileContent(job.project.ssh_private_key) + + if jpd.dockerized: + if jpd.backend != BackendType.LOCAL: + instance_proxy = SSHConnectionParams( + hostname=jpd.hostname, + username=jpd.username, + port=jpd.ssh_port, + ) + instance_project_key = FileContent(instance.project.ssh_private_key) + hosts.append((instance_proxy, instance_project_key)) ssh_port = DSTACK_RUNNER_SSH_PORT - job_submission = jobs_services.job_model_to_job_submission(job) - jrd = job_submission.job_runtime_data + jrd = get_job_runtime_data(job) if jrd is not None and jrd.ports is not None: ssh_port = 
jrd.ports.get(ssh_port, ssh_port) - ssh_proxy = SSHConnectionParams( + target_host = SSHConnectionParams( + hostname="localhost", + username="root", + port=ssh_port, + ) + hosts.append((target_host, job_project_key)) + else: + if jpd.ssh_proxy is not None: + # As of 2026-03-13, the only container-based backend with SSH proxy is Kubernetes, + # which is implemented as follows: the jump pod (JobProvisioningData.ssh_proxy) + # is created once per project via Compute.run_job() with a public key submitted as + # a method argument, that is, with the public key of the project of the first (within + # that project) job submitted to the cluster. + hosts.append((jpd.ssh_proxy, job_project_key)) + target_host = SSHConnectionParams( hostname=jpd.hostname, username=jpd.username, port=jpd.ssh_port, ) - ssh_proxy_private_key = None - if job.project_id != instance.project_id: - ssh_proxy_private_key = FileContent(instance.project.ssh_private_key) - if jpd.backend == BackendType.LOCAL: - ssh_proxy = None - ssh_proxy_private_key = None - ssh_head_proxy: Optional[SSHConnectionParams] = None - ssh_head_proxy_private_key: Optional[str] = None - rci = get_instance_remote_connection_info(instance) - if rci is not None and rci.ssh_proxy is not None: - ssh_head_proxy = rci.ssh_proxy - ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private - ssh_proxies = [] - if ssh_head_proxy is not None: - ssh_head_proxy_private_key = get_or_error(ssh_head_proxy_private_key) - ssh_proxies.append((ssh_head_proxy, FileContent(ssh_head_proxy_private_key))) - if ssh_proxy is not None: - ssh_proxies.append((ssh_proxy, ssh_proxy_private_key)) + hosts.append((target_host, job_project_key)) + + return hosts + + +def container_ssh_tunnel( + job: JobModel, + forwarded_sockets: Iterable[SocketPair] = (), + options: dict[str, str] = SSH_DEFAULT_OPTIONS, +) -> SSHTunnel: + """ + Build SSHTunnel for connecting to the container running the specified job. 
+ """ + hosts = get_container_ssh_credentials(job) + target, identity = hosts[-1] return SSHTunnel( - destination=ssh_destination, - port=ssh_port, - ssh_proxies=ssh_proxies, - identity=FileContent(job.project.ssh_private_key), + destination=f"{target.username}@{target.hostname}", + port=target.port, + ssh_proxies=hosts[:-1], + identity=identity, forwarded_sockets=forwarded_sockets, options=options, ) diff --git a/src/dstack/_internal/server/services/sshproxy.py b/src/dstack/_internal/server/services/sshproxy.py new file mode 100644 index 0000000000..0877436d0e --- /dev/null +++ b/src/dstack/_internal/server/services/sshproxy.py @@ -0,0 +1,87 @@ +from typing import Optional +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from dstack._internal.core.models.runs import JobStatus +from dstack._internal.server.models import ( + InstanceModel, + JobModel, + ProjectModel, + RunModel, + UserModel, +) +from dstack._internal.server.schemas.sshproxy import GetUpstreamResponse, UpstreamHost +from dstack._internal.server.services.jobs import get_job_runtime_data, get_job_spec +from dstack._internal.server.services.runs import get_run_spec +from dstack._internal.server.services.ssh import get_container_ssh_credentials + + +async def get_upstream_response( + session: AsyncSession, + upstream_id: str, +) -> Optional[GetUpstreamResponse]: + # The format of upstream_id is intentionally not limited to UUID in the API schema to allow + # further extensions. 
Currently, it's just a JobModel.id + try: + job_id = UUID(upstream_id) + except ValueError: + return None + + res = await session.execute( + select(JobModel) + .where( + JobModel.id == job_id, + JobModel.status == JobStatus.RUNNING, + ) + .options( + (joinedload(JobModel.project, innerjoin=True).load_only(ProjectModel.ssh_private_key)), + ( + joinedload(JobModel.instance, innerjoin=True) + .load_only(InstanceModel.remote_connection_info) + .joinedload(InstanceModel.project, innerjoin=True) + .load_only(ProjectModel.ssh_private_key) + ), + ( + joinedload(JobModel.run, innerjoin=True) + .load_only(RunModel.run_spec) + .joinedload(RunModel.user, innerjoin=True) + .load_only(UserModel.ssh_public_key) + ), + ) + ) + job = res.scalar_one_or_none() + if job is None: + return None + + hosts: list[UpstreamHost] = [] + for ssh_params, private_key in get_container_ssh_credentials(job): + hosts.append( + UpstreamHost( + host=ssh_params.hostname, + port=ssh_params.port, + user=ssh_params.username, + private_key=private_key.content, + ) + ) + + username: Optional[str] = None + if (jrd := get_job_runtime_data(job)) is not None: + username = jrd.username + if username is None and (job_spec_user := get_job_spec(job).user) is not None: + username = job_spec_user.username + if username is not None: + hosts[-1].user = username + + authorized_keys: set[str] = set() + if (run_spec_key := get_run_spec(job.run).ssh_key_pub) is not None: + authorized_keys.add(run_spec_key) + if (user_key := job.run.user.ssh_public_key) is not None: + authorized_keys.add(user_key) + + return GetUpstreamResponse( + hosts=hosts, + authorized_keys=list(authorized_keys), + ) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 4f527ef490..437da47617 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -55,6 +55,7 @@ InstanceType, RemoteConnectionInfo, Resources, + SSHConnectionParams, SSHKey, 
) from dstack._internal.core.models.placement import ( @@ -428,6 +429,9 @@ def get_job_provisioning_data( internal_ip: Optional[str] = "127.0.0.4", price: float = 10.5, instance_type: Optional[InstanceType] = None, + username: str = "ubuntu", + ssh_port: int = 22, + ssh_proxy: Optional[SSHConnectionParams] = None, ) -> JobProvisioningData: gpus = [ Gpu( @@ -451,11 +455,11 @@ def get_job_provisioning_data( internal_ip=internal_ip, region=region, price=price, - username="ubuntu", - ssh_port=22, + username=username, + ssh_port=ssh_port, dockerized=dockerized, backend_data=None, - ssh_proxy=None, + ssh_proxy=ssh_proxy, ) @@ -869,6 +873,8 @@ def get_remote_connection_info( port: int = 22, ssh_user: str = "ubuntu", ssh_keys: Optional[list[SSHKey]] = None, + ssh_proxy: Optional[SSHConnectionParams] = None, + ssh_proxy_keys: Optional[list[SSHKey]] = None, env: Optional[Union[Env, dict]] = None, ): if ssh_keys is None: @@ -882,6 +888,8 @@ def get_remote_connection_info( port=port, ssh_user=ssh_user, ssh_keys=ssh_keys, + ssh_proxy=ssh_proxy, + ssh_proxy_keys=ssh_proxy_keys, env=env, ) diff --git a/src/tests/_internal/server/routers/test_prometheus.py b/src/tests/_internal/server/routers/test_prometheus.py index ab9549965d..f87f43a80f 100644 --- a/src/tests/_internal/server/routers/test_prometheus.py +++ b/src/tests/_internal/server/routers/test_prometheus.py @@ -369,7 +369,7 @@ async def test_returns_404_if_not_enabled( async def test_returns_403_if_not_authenticated( self, monkeypatch: pytest.MonkeyPatch, client: AsyncClient, token: Optional[str] ): - monkeypatch.setattr("dstack._internal.server.routers.prometheus._auth._token", "secret") + monkeypatch.setattr("dstack._internal.server.routers.prometheus._auth._token", b"secret") if token is not None: headers = get_auth_headers(token) else: @@ -380,7 +380,7 @@ async def test_returns_403_if_not_authenticated( async def test_returns_200_if_token_is_valid( self, monkeypatch: pytest.MonkeyPatch, client: AsyncClient ): - 
monkeypatch.setattr("dstack._internal.server.routers.prometheus._auth._token", "secret") + monkeypatch.setattr("dstack._internal.server.routers.prometheus._auth._token", b"secret") response = await client.get("/metrics", headers=get_auth_headers("secret")) assert response.status_code == 200 diff --git a/src/tests/_internal/server/routers/test_sshproxy.py b/src/tests/_internal/server/routers/test_sshproxy.py new file mode 100644 index 0000000000..2b546d7d66 --- /dev/null +++ b/src/tests/_internal/server/routers/test_sshproxy.py @@ -0,0 +1,189 @@ +import os +from typing import Optional + +import pytest +from httpx import AsyncClient +from pytest_unordered import unordered +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import ServerClientErrorCode +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.configurations import DevEnvironmentConfiguration +from dstack._internal.core.models.runs import ( + JobStatus, +) +from dstack._internal.server.testing.common import ( + create_instance, + create_job, + create_project, + create_repo, + create_run, + create_user, + get_auth_headers, + get_job_provisioning_data, + get_job_runtime_data, + get_run_spec, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("image_config_mock", "test_db") +class TestGetUpstream: + @pytest.fixture + def token(self) -> str: + token_var = "DSTACK_SSHPROXY_API_TOKEN" + token = os.getenv(token_var) + assert token is not None, f"{token_var} must be set via pytest-env" + return token + + async def test_returns_40x_if_no_api_token_provided(self, client: AsyncClient): + response = await client.post("/api/sshproxy/get_upstream") + + assert response.status_code in [401, 403] + + async def test_returns_40x_if_api_token_is_not_valid(self, client: AsyncClient): + response = await client.post( + "/api/sshproxy/get_upstream", 
headers=get_auth_headers("invalid-token") + ) + + assert response.status_code in [401, 403] + + async def test_returns_resource_not_exists_if_upstream_id_is_not_uuid( + self, client: AsyncClient, token: str + ): + response = await client.post( + "/api/sshproxy/get_upstream", + headers=get_auth_headers(token), + json={"id": "some-string"}, + ) + + assert response.json()["detail"][0]["code"] == ServerClientErrorCode.RESOURCE_NOT_EXISTS + + async def test_returns_resource_not_exists_if_job_is_not_running( + self, + session: AsyncSession, + client: AsyncClient, + token: str, + ): + project = await create_project(session=session) + instance = await create_instance(session=session, project=project) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, user=user, repo=repo) + job = await create_job( + session=session, + run=run, + instance=instance, + status=JobStatus.TERMINATING, + ) + + response = await client.post( + "/api/sshproxy/get_upstream", + headers=get_auth_headers(token), + json={"id": str(job.id)}, + ) + + assert response.json()["detail"][0]["code"] == ServerClientErrorCode.RESOURCE_NOT_EXISTS + + async def test_response( + self, + session: AsyncSession, + client: AsyncClient, + token: str, + ): + project = await create_project(session=session, ssh_private_key="project-key") + instance = await create_instance( + session=session, project=project, backend=BackendType.RUNPOD + ) + user = await create_user(session=session, ssh_public_key="user-key") + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec(repo_id=repo.name, ssh_key_pub="run-spec-key") + run = await create_run( + session=session, project=project, user=user, repo=repo, run_spec=run_spec + ) + jpd = get_job_provisioning_data( + dockerized=False, + backend=BackendType.RUNPOD, + hostname="100.100.100.100", + username="root", + ssh_port=32768, + 
ssh_proxy=None, + ) + jrd = get_job_runtime_data(username="test-user") + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + job_runtime_data=jrd, + status=JobStatus.RUNNING, + ) + + response = await client.post( + "/api/sshproxy/get_upstream", + headers=get_auth_headers(token), + json={"id": str(job.id)}, + ) + + assert response.json() == { + "hosts": [ + { + "host": "100.100.100.100", + "port": 32768, + "private_key": "project-key", + "user": "test-user", + }, + ], + "authorized_keys": unordered( + [ + "user-key", + "run-spec-key", + ] + ), + } + + @pytest.mark.parametrize( + ["jrd_user", "conf_user", "expected_user"], + [ + pytest.param("jrd", "conf", "jrd", id="from-runner"), + pytest.param(None, "conf", "conf", id="from-configuration"), + pytest.param(None, None, "root", id="default"), + ], + ) + async def test_username_fallbacks( + self, + session: AsyncSession, + client: AsyncClient, + token: str, + jrd_user: Optional[str], + conf_user: Optional[str], + expected_user: str, + ): + project = await create_project(session=session, ssh_private_key="project-key") + instance = await create_instance(session=session, project=project, backend=BackendType.AWS) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + configuration = DevEnvironmentConfiguration(ide="vscode", user=conf_user) + run_spec = get_run_spec(repo_id=repo.name, configuration=configuration) + run = await create_run( + session=session, project=project, user=user, repo=repo, run_spec=run_spec + ) + jpd = get_job_provisioning_data(dockerized=True, backend=BackendType.AWS, username="root") + jrd = get_job_runtime_data(username=jrd_user) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + job_runtime_data=jrd, + status=JobStatus.RUNNING, + ) + + response = await client.post( + "/api/sshproxy/get_upstream", + headers=get_auth_headers(token), 
+ json={"id": str(job.id)}, + ) + + assert response.json()["hosts"][-1]["user"] == expected_user diff --git a/src/tests/_internal/server/services/test_ssh.py b/src/tests/_internal/server/services/test_ssh.py new file mode 100644 index 0000000000..d9c492e225 --- /dev/null +++ b/src/tests/_internal/server/services/test_ssh.py @@ -0,0 +1,324 @@ +from typing import Optional + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import NetworkMode +from dstack._internal.core.models.instances import SSHConnectionParams, SSHKey +from dstack._internal.core.models.runs import ( + JobRuntimeData, +) +from dstack._internal.server.models import ProjectModel, RunModel +from dstack._internal.server.services.ssh import get_container_ssh_credentials +from dstack._internal.server.testing.common import ( + create_instance, + create_job, + create_project, + create_repo, + create_run, + create_user, + get_job_provisioning_data, + get_job_runtime_data, + get_remote_connection_info, +) +from dstack._internal.utils.path import FileContent + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("test_db", "image_config_mock") +class TestGetContainerSSHCredentials: + instance_project_key = "instance-project-key" + run_project_key = "run-project-key" + + @pytest_asyncio.fixture + async def instance_project(self, session: AsyncSession) -> ProjectModel: + owner = await create_user(session=session, name="instance-project-owner") + return await create_project( + session=session, + name="instance-project", + owner=owner, + ssh_private_key=self.instance_project_key, + ) + + @pytest_asyncio.fixture + async def run(self, session: AsyncSession) -> RunModel: + run_project_owner = await create_user(session=session, 
name="run-project-owner") + run_project = await create_project( + session=session, name="run-project", ssh_private_key=self.run_project_key + ) + repo = await create_repo(session=session, project_id=run_project.id) + run = await create_run( + session=session, project=run_project, user=run_project_owner, repo=repo + ) + # Triggers session magic, attaches ProjectModel to JobModel somehow + assert run.project is not None + return run + + @pytest.mark.parametrize( + ["jrd", "expected_port"], + [ + pytest.param(None, DSTACK_RUNNER_SSH_PORT, id="no-jrd"), + pytest.param( + get_job_runtime_data(network_mode=NetworkMode.HOST, ports={}), + DSTACK_RUNNER_SSH_PORT, + id="host", + ), + pytest.param( + get_job_runtime_data( + network_mode=NetworkMode.HOST, ports={DSTACK_RUNNER_SSH_PORT: 32772} + ), + 32772, + id="bridge", + ), + ], + ) + async def test_vm_based_backend( + self, + session: AsyncSession, + instance_project: ProjectModel, + run: RunModel, + jrd: Optional[JobRuntimeData], + expected_port: int, + ): + instance = await create_instance( + session=session, project=instance_project, backend=BackendType.AWS + ) + jpd = get_job_provisioning_data( + backend=BackendType.AWS, + dockerized=True, + hostname="80.80.80.80", + username="ubuntu", + ssh_port=22, + ssh_proxy=None, + ) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + job_runtime_data=jrd, + ) + + hosts = get_container_ssh_credentials(job) + + assert hosts == [ + ( + SSHConnectionParams( + hostname="80.80.80.80", + username="ubuntu", + port=22, + ), + FileContent(self.instance_project_key), + ), + ( + SSHConnectionParams( + hostname="localhost", + username="root", + port=expected_port, + ), + FileContent(self.run_project_key), + ), + ] + + async def test_container_based_backend( + self, + session: AsyncSession, + instance_project: ProjectModel, + run: RunModel, + ): + instance = await create_instance( + session=session, project=instance_project, 
backend=BackendType.RUNPOD + ) + jpd = get_job_provisioning_data( + backend=BackendType.RUNPOD, + dockerized=False, + hostname="100.100.100.100", + username="root", + ssh_port=32768, + ssh_proxy=None, + ) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + ) + + hosts = get_container_ssh_credentials(job) + + assert hosts == [ + ( + SSHConnectionParams( + hostname="100.100.100.100", + username="root", + port=32768, + ), + FileContent(self.run_project_key), + ), + ] + + async def test_container_based_backend_with_proxy( + self, + session: AsyncSession, + instance_project: ProjectModel, + run: RunModel, + ): + instance = await create_instance( + session=session, project=instance_project, backend=BackendType.KUBERNETES + ) + jpd = get_job_provisioning_data( + backend=BackendType.KUBERNETES, + dockerized=False, + hostname="10.105.30.22", + username="root", + ssh_port=DSTACK_RUNNER_SSH_PORT, + ssh_proxy=SSHConnectionParams( + hostname="120.120.120.120", + username="root", + port=30022, + ), + ) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + ) + + hosts = get_container_ssh_credentials(job) + + assert hosts == [ + ( + SSHConnectionParams( + hostname="120.120.120.120", + username="root", + port=30022, + ), + FileContent(self.run_project_key), + ), + ( + SSHConnectionParams( + hostname="10.105.30.22", + username="root", + port=DSTACK_RUNNER_SSH_PORT, + ), + FileContent(self.run_project_key), + ), + ] + + async def test_ssh_instance_with_head_proxy( + self, + session: AsyncSession, + instance_project: ProjectModel, + run: RunModel, + ): + rci = get_remote_connection_info( + host="192.168.100.50", + port=22222, + ssh_user="ubuntu", + # User-provided key is only used for instance provisioning, then we always use + # the project key, which is added during provisioning + ssh_keys=[SSHKey(public="public", private="instance-key")], + ssh_proxy=SSHConnectionParams( + 
hostname="140.140.140.140", + username="bastion", + port=22, + ), + ssh_proxy_keys=[SSHKey(public="public", private="head-key")], + ) + instance = await create_instance( + session=session, + project=instance_project, + backend=BackendType.REMOTE, + remote_connection_info=rci, + ) + jpd = get_job_provisioning_data( + backend=BackendType.REMOTE, + dockerized=True, + hostname="192.168.100.50", + username="ubuntu", + ssh_port=22222, + # Actually, JobModel.job_provisioning_data.ssh_proxy is set to + # InstanceModel.remote_connection_info.ssh_proxy but not used in the function we test + ssh_proxy=None, + ) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + # jrd is tested in vm-based backend tests + job_runtime_data=None, + ) + + hosts = get_container_ssh_credentials(job) + + assert hosts == [ + ( + SSHConnectionParams( + hostname="140.140.140.140", + username="bastion", + port=22, + ), + FileContent("head-key"), + ), + ( + SSHConnectionParams( + hostname="192.168.100.50", + username="ubuntu", + port=22222, + ), + FileContent(self.instance_project_key), + ), + ( + SSHConnectionParams( + hostname="localhost", + username="root", + port=DSTACK_RUNNER_SSH_PORT, + ), + FileContent(self.run_project_key), + ), + ] + + async def test_local_backend( + self, + session: AsyncSession, + instance_project: ProjectModel, + run: RunModel, + ): + instance = await create_instance( + session=session, project=instance_project, backend=BackendType.LOCAL + ) + jpd = get_job_provisioning_data( + backend=BackendType.LOCAL, + dockerized=True, + hostname="127.0.0.1", + username="root", + ssh_port=DSTACK_RUNNER_SSH_PORT, + ssh_proxy=None, + ) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + # jrd is tested in vm-based backend tests + job_runtime_data=None, + ) + + hosts = get_container_ssh_credentials(job) + + assert hosts == [ + ( + SSHConnectionParams( + hostname="localhost", + 
username="root", + port=DSTACK_RUNNER_SSH_PORT, + ), + FileContent(self.run_project_key), + ), + ]