From c1ace05e4435bd4fe143aa47f200c064403b3a18 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Thu, 12 Mar 2026 10:25:56 +0100 Subject: [PATCH] Fix services on imported fleets For all server-to-container and gateway-to-container SSH connections: - Use the instance project's key to connect to the instance - Use the job project's key to connect to the container --- .../proxy/gateway/routers/registry.py | 1 + .../proxy/gateway/schemas/registry.py | 1 + .../proxy/gateway/services/registry.py | 2 + src/dstack/_internal/proxy/lib/models.py | 2 + .../proxy/lib/services/service_connection.py | 5 +- .../background/scheduled_tasks/probes.py | 2 +- .../server/services/gateways/client.py | 2 + .../_internal/server/services/proxy/repo.py | 17 +- .../server/services/services/__init__.py | 5 + src/dstack/_internal/server/services/ssh.py | 11 +- src/dstack/_internal/server/testing/common.py | 2 + .../scheduled_tasks/test_running_jobs.py | 173 +++++++++++++++++- src/tests/_internal/server/conftest.py | 13 +- .../_internal/server/routers/test_runs.py | 11 +- 14 files changed, 226 insertions(+), 21 deletions(-) diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index c5f4cf8a1a..61283e9082 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -78,6 +78,7 @@ async def register_replica( ssh_destination=body.ssh_host, ssh_port=body.ssh_port, ssh_proxy=body.ssh_proxy, + ssh_proxy_private_key=body.ssh_proxy_private_key, ssh_head_proxy=body.ssh_head_proxy, ssh_head_proxy_private_key=body.ssh_head_proxy_private_key, internal_ip=body.internal_ip, diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 802d23a700..33001cf25f 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -54,6 +54,7 @@ class RegisterReplicaRequest(BaseModel): ssh_host: str ssh_port: int ssh_proxy: Optional[SSHConnectionParams] + ssh_proxy_private_key: Optional[str] ssh_head_proxy: Optional[SSHConnectionParams] ssh_head_proxy_private_key: Optional[str] internal_ip: Optional[str] = None diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index adebe6f41d..919c05c0f2 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -134,6 +134,7 @@ async def register_replica( ssh_destination: str, ssh_port: int, ssh_proxy: Optional[SSHConnectionParams], + ssh_proxy_private_key: Optional[str], ssh_head_proxy: Optional[SSHConnectionParams], ssh_head_proxy_private_key: Optional[str], repo: GatewayProxyRepo, @@ -147,6 +148,7 @@ async def register_replica( ssh_destination=ssh_destination, ssh_port=ssh_port, ssh_proxy=ssh_proxy, + ssh_proxy_private_key=ssh_proxy_private_key, ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, internal_ip=internal_ip, diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index a0a724dbea..d15e4b7ef2 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -24,6 +24,8 @@ class Replica(ImmutableModel): ssh_destination: str ssh_port: int ssh_proxy: Optional[SSHConnectionParams] + ssh_proxy_private_key: Optional[str] = None + "`None` means same as service project's key" # Optional outer proxy, a head node/bastion ssh_head_proxy: Optional[SSHConnectionParams] = None ssh_head_proxy_private_key: Optional[str] = None diff --git a/src/dstack/_internal/proxy/lib/services/service_connection.py b/src/dstack/_internal/proxy/lib/services/service_connection.py index dc94cb27f3..37bdc5083a 100644 --- a/src/dstack/_internal/proxy/lib/services/service_connection.py +++ b/src/dstack/_internal/proxy/lib/services/service_connection.py @@ -53,7 +53,10 @@ def __init__(self, project: Project, service: Service, replica: Replica) -> None ssh_head_proxy_private_key = get_or_error(replica.ssh_head_proxy_private_key) ssh_proxies.append((replica.ssh_head_proxy, FileContent(ssh_head_proxy_private_key))) if replica.ssh_proxy is not None: - ssh_proxies.append((replica.ssh_proxy, None)) + if replica.ssh_proxy_private_key is not None: + ssh_proxies.append((replica.ssh_proxy, FileContent(replica.ssh_proxy_private_key))) + else: + ssh_proxies.append((replica.ssh_proxy, None)) self._tunnel = SSHTunnel( destination=replica.ssh_destination, port=replica.ssh_port, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/probes.py b/src/dstack/_internal/server/background/scheduled_tasks/probes.py index 4f712ff4cb..9b36bd09fe 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/probes.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/probes.py @@ -63,7 +63,7 @@ async def process_probes(): .joinedload(JobModel.instance) .joinedload(InstanceModel.project) ) - .options(joinedload(ProbeModel.job)) + .options(joinedload(ProbeModel.job).joinedload(JobModel.project)) .execution_options(populate_existing=True) ) probes = res.unique().scalars().all() diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index 9bc7a1f903..d83891c0b7 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -85,6 +85,7 @@ async def register_replica( run: Run, job_spec: JobSpec, job_submission: JobSubmission, + instance_project_ssh_private_key: Optional[str], ssh_head_proxy: Optional[SSHConnectionParams], ssh_head_proxy_private_key: Optional[str], ): @@ -122,6 +123,7 @@ async def register_replica( username=jpd.username, port=jpd.ssh_port, ).dict(), + "ssh_proxy_private_key": instance_project_ssh_private_key, } ) resp = await self._client.post( diff --git a/src/dstack/_internal/server/services/proxy/repo.py b/src/dstack/_internal/server/services/proxy/repo.py index 385c9e654f..ab34bf278d 100644 --- a/src/dstack/_internal/server/services/proxy/repo.py +++ b/src/dstack/_internal/server/services/proxy/repo.py @@ -3,7 +3,7 @@ import pydantic from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import contains_eager, joinedload import dstack._internal.server.services.jobs as jobs_services from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT @@ -30,7 +30,7 @@ TGIChatModelFormat, ) from dstack._internal.proxy.lib.repo import BaseProxyRepo -from dstack._internal.server.models import JobModel, ProjectModel, RunModel +from dstack._internal.server.models import InstanceModel, JobModel, ProjectModel, RunModel from dstack._internal.server.services.instances import get_instance_remote_connection_info from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE from dstack._internal.utils.common import get_or_error @@ -59,8 +59,9 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic JobModel.job_num == 0, ) .options( - joinedload(JobModel.run), - joinedload(JobModel.instance), + contains_eager(JobModel.run), + contains_eager(JobModel.project), + joinedload(JobModel.instance).joinedload(InstanceModel.project), ) ) jobs = res.unique().scalars().all() @@ -77,10 +78,12 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic ) assert jpd.hostname is not None assert jpd.ssh_port is not None + instance = get_or_error(job.instance) if not jpd.dockerized: ssh_destination = f"{jpd.username}@{jpd.hostname}" ssh_port = jpd.ssh_port ssh_proxy = jpd.ssh_proxy + ssh_proxy_private_key = None else: ssh_destination = "root@localhost" ssh_port = DSTACK_RUNNER_SSH_PORT @@ -93,11 +96,14 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic username=jpd.username, port=jpd.ssh_port, ) + ssh_proxy_private_key = None + if job.project_id != instance.project_id: + ssh_proxy_private_key = instance.project.ssh_private_key if jpd.backend == BackendType.LOCAL: ssh_proxy = None + ssh_proxy_private_key = None ssh_head_proxy: Optional[SSHConnectionParams] = None ssh_head_proxy_private_key: Optional[str] = None - instance = get_or_error(job.instance) rci = get_instance_remote_connection_info(instance) if rci is not None and rci.ssh_proxy is not None: ssh_head_proxy = rci.ssh_proxy @@ -109,6 +115,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic ssh_destination=ssh_destination, ssh_port=ssh_port, ssh_proxy=ssh_proxy, + ssh_proxy_private_key=ssh_proxy_private_key, ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, internal_ip=jpd.internal_ip, diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 2a730b5695..9a0bc03369 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -313,6 +313,10 @@ async def register_replica( if gateway_id is not None: gateway, conn = await get_or_add_gateway_connection(session, gateway_id) job_submission = jobs_services.job_model_to_job_submission(job_model) + assert job_model.instance is not None + instance_project_ssh_private_key = None + if job_model.project_id != job_model.instance.project_id: + instance_project_ssh_private_key = job_model.instance.project.ssh_private_key try: logger.debug("%s: registering replica for service %s", fmt(job_model), run.id.hex) async with conn.client() as client: @@ -320,6 +324,7 @@ async def register_replica( run=run, job_spec=JobSpec.__response__.parse_raw(job_model.job_spec_data), job_submission=job_submission, + instance_project_ssh_private_key=instance_project_ssh_private_key, ssh_head_proxy=ssh_head_proxy, ssh_head_proxy_private_key=ssh_head_proxy_private_key, ) diff --git a/src/dstack/_internal/server/services/ssh.py b/src/dstack/_internal/server/services/ssh.py index 0fa7c189e2..cb5d46c8c3 100644 --- a/src/dstack/_internal/server/services/ssh.py +++ b/src/dstack/_internal/server/services/ssh.py @@ -26,10 +26,12 @@ def container_ssh_tunnel( ) assert jpd.hostname is not None assert jpd.ssh_port is not None + instance = get_or_error(job.instance) if not jpd.dockerized: ssh_destination = f"{jpd.username}@{jpd.hostname}" ssh_port = jpd.ssh_port ssh_proxy = jpd.ssh_proxy + ssh_proxy_private_key = None else: ssh_destination = "root@localhost" ssh_port = DSTACK_RUNNER_SSH_PORT @@ -42,11 +44,14 @@ def container_ssh_tunnel( username=jpd.username, port=jpd.ssh_port, ) + ssh_proxy_private_key = None + if job.project_id != instance.project_id: + ssh_proxy_private_key = FileContent(instance.project.ssh_private_key) if jpd.backend == BackendType.LOCAL: ssh_proxy = None + ssh_proxy_private_key = None ssh_head_proxy: Optional[SSHConnectionParams] = None ssh_head_proxy_private_key: Optional[str] = None - instance = get_or_error(job.instance) rci = get_instance_remote_connection_info(instance) if rci is not None and rci.ssh_proxy is not None: ssh_head_proxy = rci.ssh_proxy @@ -56,12 +61,12 @@ def container_ssh_tunnel( ssh_head_proxy_private_key = get_or_error(ssh_head_proxy_private_key) ssh_proxies.append((ssh_head_proxy, FileContent(ssh_head_proxy_private_key))) if ssh_proxy is not None: - ssh_proxies.append((ssh_proxy, None)) + ssh_proxies.append((ssh_proxy, ssh_proxy_private_key)) return SSHTunnel( destination=ssh_destination, port=ssh_port, ssh_proxies=ssh_proxies, - identity=FileContent(instance.project.ssh_private_key), + identity=FileContent(job.project.ssh_private_key), forwarded_sockets=forwarded_sockets, options=options, ) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 2b4c347125..4f527ef490 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -309,6 +309,7 @@ async def create_run( repo: RepoModel, user: UserModel, fleet: Optional[FleetModel] = None, + gateway: Optional[GatewayModel] = None, run_name: Optional[str] = None, status: RunStatus = RunStatus.SUBMITTED, termination_reason: Optional[RunTerminationReason] = None, @@ -349,6 +350,7 @@ async def create_run( desired_replica_count=1, resubmission_attempt=resubmission_attempt, next_triggered_at=next_triggered_at, + gateway=gateway, ) session.add(run) await session.commit() diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index 80a18dc11d..239cc265b3 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -2,7 +2,7 @@ from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Optional -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch import pytest from freezegun import freeze_time @@ -19,6 +19,7 @@ ProbeConfig, ServiceConfiguration, ) +from dstack._internal.core.models.gateways import GatewayStatus from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.profiles import StartupOrder, UtilizationPolicy from dstack._internal.core.models.runs import ( @@ -52,6 +53,11 @@ volume_model_to_volume, ) from dstack._internal.server.testing.common import ( + create_backend, + create_export, + create_fleet, + create_gateway, + create_gateway_compute, create_instance, create_job, create_job_metrics_point, @@ -1297,3 +1303,168 @@ async def test_registers_service_replica_only_after_probes_pass( else: assert not job.registered assert not events + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_registers_service_replica_in_gateway( + self, + test_db, + session: AsyncSession, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + mock_gateway_connection: AsyncMock, + ): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + backend = await create_backend(session=session, project_id=project.id) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + ) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + name="test-gateway", + wildcard_domain="example.com", + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, image="ubuntu", gateway="test-gateway" + ), + ), + gateway=gateway, + ) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + fleet=fleet, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + + await process_running_jobs() + + await session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.registered + events = await list_events(session) + assert {e.message for e in events} == { + "Job status changed PULLING -> RUNNING", + "Service replica registered to receive requests", + } + mock_gateway_connection.return_value.client.return_value.__aenter__.return_value.register_replica.assert_called_once_with( + run=ANY, + job_spec=ANY, + job_submission=ANY, + instance_project_ssh_private_key=None, + ssh_head_proxy=None, + ssh_head_proxy_private_key=None, + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_registers_service_replica_in_gateway_when_running_on_imported_instance( + self, + test_db, + session: AsyncSession, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + runner_client_mock: Mock, + mock_gateway_connection: AsyncMock, + ): + user = await create_user(session=session) + exporter_project = await create_project( + session=session, name="exporter", owner=user, ssh_private_key="exporter-private-key" + ) + importer_project = await create_project(session=session, name="importer", owner=user) + fleet = await create_fleet(session=session, project=exporter_project) + instance = await create_instance( + session=session, + project=exporter_project, + status=InstanceStatus.BUSY, + fleet=fleet, + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + repo = await create_repo(session=session, project_id=importer_project.id) + backend = await create_backend(session=session, project_id=importer_project.id) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + ) + gateway = await create_gateway( + session=session, + project_id=importer_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + name="test-gateway", + wildcard_domain="example.com", + ) + run = await create_run( + session=session, + project=importer_project, + repo=repo, + user=user, + run_spec=get_run_spec( + run_name="test", + repo_id=repo.name, + configuration=ServiceConfiguration( + port=80, image="ubuntu", gateway="test-gateway" + ), + ), + gateway=gateway, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.PULLING, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + instance_assigned=True, + ) + shim_client_mock.get_task.return_value.status = TaskStatus.RUNNING + + await process_running_jobs() + + await session.refresh(job) + assert job.status == JobStatus.RUNNING + assert job.registered + events = await list_events(session) + assert {e.message for e in events} == { + "Job status changed PULLING -> RUNNING", + "Service replica registered to receive requests", + } + mock_gateway_connection.return_value.client.return_value.__aenter__.return_value.register_replica.assert_called_once_with( + run=ANY, + job_spec=ANY, + job_submission=ANY, + instance_project_ssh_private_key="exporter-private-key", + ssh_head_proxy=None, + ssh_head_proxy_private_key=None, + ) diff --git a/src/tests/_internal/server/conftest.py b/src/tests/_internal/server/conftest.py index 9bb508c5d6..dcd9291c59 100644 --- a/src/tests/_internal/server/conftest.py +++ b/src/tests/_internal/server/conftest.py @@ -1,5 +1,6 @@ +from collections.abc import Generator from pathlib import Path -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock, patch import httpx import pytest @@ -39,3 +40,13 @@ def image_config_mock(monkeypatch: pytest.MonkeyPatch) -> ImageConfig: Mock(return_value=ImageConfigObject(config=image_config)), ) return image_config + + +@pytest.fixture() +def mock_gateway_connection() -> Generator[AsyncMock, None, None]: + with patch( + "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" + ) as get_conn_mock: + get_conn_mock.return_value.client = Mock() + get_conn_mock.return_value.client.return_value = AsyncMock() + yield get_conn_mock diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 4d6e7aa95d..ea829d569d 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -1,7 +1,7 @@ import copy import json from datetime import datetime, timezone -from typing import Dict, Generator, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from unittest.mock import AsyncMock, Mock, patch from uuid import UUID @@ -2351,14 +2351,7 @@ async def test_returns_400_if_runs_active( @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestSubmitService: - @pytest.fixture(autouse=True) - def mock_gateway_connection(self) -> Generator[AsyncMock, None, None]: - with patch( - "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" - ) as get_conn_mock: - get_conn_mock.return_value.client = Mock() - get_conn_mock.return_value.client.return_value = AsyncMock() - yield get_conn_mock + pytestmark = pytest.mark.usefixtures("mock_gateway_connection") @pytest.mark.asyncio @pytest.mark.parametrize(