Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ async def register_replica(
ssh_destination=body.ssh_host,
ssh_port=body.ssh_port,
ssh_proxy=body.ssh_proxy,
ssh_proxy_private_key=body.ssh_proxy_private_key,
ssh_head_proxy=body.ssh_head_proxy,
ssh_head_proxy_private_key=body.ssh_head_proxy_private_key,
internal_ip=body.internal_ip,
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/proxy/gateway/schemas/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class RegisterReplicaRequest(BaseModel):
ssh_host: str
ssh_port: int
ssh_proxy: Optional[SSHConnectionParams]
ssh_proxy_private_key: Optional[str]
ssh_head_proxy: Optional[SSHConnectionParams]
ssh_head_proxy_private_key: Optional[str]
internal_ip: Optional[str] = None
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/proxy/gateway/services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ async def register_replica(
ssh_destination: str,
ssh_port: int,
ssh_proxy: Optional[SSHConnectionParams],
ssh_proxy_private_key: Optional[str],
ssh_head_proxy: Optional[SSHConnectionParams],
ssh_head_proxy_private_key: Optional[str],
repo: GatewayProxyRepo,
Expand All @@ -147,6 +148,7 @@ async def register_replica(
ssh_destination=ssh_destination,
ssh_port=ssh_port,
ssh_proxy=ssh_proxy,
ssh_proxy_private_key=ssh_proxy_private_key,
ssh_head_proxy=ssh_head_proxy,
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
internal_ip=internal_ip,
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/proxy/lib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Replica(ImmutableModel):
ssh_destination: str
ssh_port: int
ssh_proxy: Optional[SSHConnectionParams]
ssh_proxy_private_key: Optional[str] = None
"`None` means same as service project's key"
# Optional outer proxy, a head node/bastion
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def __init__(self, project: Project, service: Service, replica: Replica) -> None
ssh_head_proxy_private_key = get_or_error(replica.ssh_head_proxy_private_key)
ssh_proxies.append((replica.ssh_head_proxy, FileContent(ssh_head_proxy_private_key)))
if replica.ssh_proxy is not None:
ssh_proxies.append((replica.ssh_proxy, None))
if replica.ssh_proxy_private_key is not None:
ssh_proxies.append((replica.ssh_proxy, FileContent(replica.ssh_proxy_private_key)))
else:
ssh_proxies.append((replica.ssh_proxy, None))
self._tunnel = SSHTunnel(
destination=replica.ssh_destination,
port=replica.ssh_port,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def process_probes():
.joinedload(JobModel.instance)
.joinedload(InstanceModel.project)
)
.options(joinedload(ProbeModel.job))
.options(joinedload(ProbeModel.job).joinedload(JobModel.project))
.execution_options(populate_existing=True)
)
probes = res.unique().scalars().all()
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/services/gateways/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ async def register_replica(
run: Run,
job_spec: JobSpec,
job_submission: JobSubmission,
instance_project_ssh_private_key: Optional[str],
ssh_head_proxy: Optional[SSHConnectionParams],
ssh_head_proxy_private_key: Optional[str],
):
Expand Down Expand Up @@ -122,6 +123,7 @@ async def register_replica(
username=jpd.username,
port=jpd.ssh_port,
).dict(),
"ssh_proxy_private_key": instance_project_ssh_private_key,
}
)
resp = await self._client.post(
Expand Down
17 changes: 12 additions & 5 deletions src/dstack/_internal/server/services/proxy/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pydantic
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import contains_eager, joinedload

import dstack._internal.server.services.jobs as jobs_services
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
Expand All @@ -30,7 +30,7 @@
TGIChatModelFormat,
)
from dstack._internal.proxy.lib.repo import BaseProxyRepo
from dstack._internal.server.models import JobModel, ProjectModel, RunModel
from dstack._internal.server.models import InstanceModel, JobModel, ProjectModel, RunModel
from dstack._internal.server.services.instances import get_instance_remote_connection_info
from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE
from dstack._internal.utils.common import get_or_error
Expand Down Expand Up @@ -59,8 +59,9 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
JobModel.job_num == 0,
)
.options(
joinedload(JobModel.run),
joinedload(JobModel.instance),
contains_eager(JobModel.run),
contains_eager(JobModel.project),
joinedload(JobModel.instance).joinedload(InstanceModel.project),
)
)
jobs = res.unique().scalars().all()
Expand All @@ -77,10 +78,12 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
)
assert jpd.hostname is not None
assert jpd.ssh_port is not None
instance = get_or_error(job.instance)
if not jpd.dockerized:
ssh_destination = f"{jpd.username}@{jpd.hostname}"
ssh_port = jpd.ssh_port
ssh_proxy = jpd.ssh_proxy
ssh_proxy_private_key = None
else:
ssh_destination = "root@localhost"
ssh_port = DSTACK_RUNNER_SSH_PORT
Expand All @@ -93,11 +96,14 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
username=jpd.username,
port=jpd.ssh_port,
)
ssh_proxy_private_key = None
if job.project_id != instance.project_id:
ssh_proxy_private_key = instance.project.ssh_private_key
if jpd.backend == BackendType.LOCAL:
ssh_proxy = None
ssh_proxy_private_key = None
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None
instance = get_or_error(job.instance)
rci = get_instance_remote_connection_info(instance)
if rci is not None and rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
Expand All @@ -109,6 +115,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
ssh_destination=ssh_destination,
ssh_port=ssh_port,
ssh_proxy=ssh_proxy,
ssh_proxy_private_key=ssh_proxy_private_key,
ssh_head_proxy=ssh_head_proxy,
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
internal_ip=jpd.internal_ip,
Expand Down
5 changes: 5 additions & 0 deletions src/dstack/_internal/server/services/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,18 @@ async def register_replica(
if gateway_id is not None:
gateway, conn = await get_or_add_gateway_connection(session, gateway_id)
job_submission = jobs_services.job_model_to_job_submission(job_model)
assert job_model.instance is not None
instance_project_ssh_private_key = None
if job_model.project_id != job_model.instance.project_id:
instance_project_ssh_private_key = job_model.instance.project.ssh_private_key
try:
logger.debug("%s: registering replica for service %s", fmt(job_model), run.id.hex)
async with conn.client() as client:
await client.register_replica(
run=run,
job_spec=JobSpec.__response__.parse_raw(job_model.job_spec_data),
job_submission=job_submission,
instance_project_ssh_private_key=instance_project_ssh_private_key,
ssh_head_proxy=ssh_head_proxy,
ssh_head_proxy_private_key=ssh_head_proxy_private_key,
)
Expand Down
11 changes: 8 additions & 3 deletions src/dstack/_internal/server/services/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ def container_ssh_tunnel(
)
assert jpd.hostname is not None
assert jpd.ssh_port is not None
instance = get_or_error(job.instance)
if not jpd.dockerized:
ssh_destination = f"{jpd.username}@{jpd.hostname}"
ssh_port = jpd.ssh_port
ssh_proxy = jpd.ssh_proxy
ssh_proxy_private_key = None
else:
ssh_destination = "root@localhost"
ssh_port = DSTACK_RUNNER_SSH_PORT
Expand All @@ -42,11 +44,14 @@ def container_ssh_tunnel(
username=jpd.username,
port=jpd.ssh_port,
)
ssh_proxy_private_key = None
if job.project_id != instance.project_id:
ssh_proxy_private_key = FileContent(instance.project.ssh_private_key)
if jpd.backend == BackendType.LOCAL:
ssh_proxy = None
ssh_proxy_private_key = None
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None
instance = get_or_error(job.instance)
rci = get_instance_remote_connection_info(instance)
if rci is not None and rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
Expand All @@ -56,12 +61,12 @@ def container_ssh_tunnel(
ssh_head_proxy_private_key = get_or_error(ssh_head_proxy_private_key)
ssh_proxies.append((ssh_head_proxy, FileContent(ssh_head_proxy_private_key)))
if ssh_proxy is not None:
ssh_proxies.append((ssh_proxy, None))
ssh_proxies.append((ssh_proxy, ssh_proxy_private_key))
return SSHTunnel(
destination=ssh_destination,
port=ssh_port,
ssh_proxies=ssh_proxies,
identity=FileContent(instance.project.ssh_private_key),
identity=FileContent(job.project.ssh_private_key),
forwarded_sockets=forwarded_sockets,
options=options,
)
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ async def create_run(
repo: RepoModel,
user: UserModel,
fleet: Optional[FleetModel] = None,
gateway: Optional[GatewayModel] = None,
run_name: Optional[str] = None,
status: RunStatus = RunStatus.SUBMITTED,
termination_reason: Optional[RunTerminationReason] = None,
Expand Down Expand Up @@ -349,6 +350,7 @@ async def create_run(
desired_replica_count=1,
resubmission_attempt=resubmission_attempt,
next_triggered_at=next_triggered_at,
gateway=gateway,
)
session.add(run)
await session.commit()
Expand Down
Loading
Loading