Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ markers = [
]
env = [
"DSTACK_CLI_RICH_FORCE_TERMINAL=0",
"DSTACK_SSHPROXY_API_TOKEN=test-token",
]
filterwarnings = [
# testcontainers modules use deprecated decorators – nothing we can do:
Expand All @@ -142,6 +143,7 @@ dev = [
"pytest-httpbin>=2.1.0",
"pytest-socket>=0.7.0",
"pytest-env>=1.1.0",
"pytest-unordered>=0.7.0",
"httpbin>=0.10.2", # indirect to make compatible with Werkzeug 3
"requests-mock>=1.12.1",
"openai>=1.68.2",
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/_internal/core/backends/kubernetes/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def run_job(
commands = get_docker_commands(
[run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()]
)
# There is a one jump pod per Kubernetes backend that is used
# as an ssh proxy jump to connect to all other services in Kubernetes.
# There is one jump pod per project that is used as an ssh proxy jump to connect
# to all job pods of the same project.
# The service is created here and configured later in update_provisioning_data()
jump_pod_name = f"dstack-{run.project_name}-ssh-jump-pod"
jump_pod_service_name = _get_pod_service_name(jump_pod_name)
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
runs,
secrets,
server,
sshproxy,
templates,
users,
volumes,
Expand Down Expand Up @@ -255,6 +256,7 @@ def register_routes(app: FastAPI, ui: bool = True):
app.include_router(events.root_router)
app.include_router(templates.router)
app.include_router(exports.project_router)
app.include_router(sshproxy.router)

@app.exception_handler(ForbiddenError)
async def forbidden_error_handler(request: Request, exc: ForbiddenError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
emit_job_status_change_event,
get_job_provisioning_data,
get_job_runtime_data,
get_job_spec,
)
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.logging import fmt
Expand Down Expand Up @@ -797,7 +798,7 @@ async def _detach_volumes_from_job_instance(
jpd: JobProvisioningData,
run_termination_reason: Optional[RunTerminationReason],
) -> _VolumeDetachResult:
job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
job_spec = get_job_spec(job_model)
backend = await backends_services.get_project_backend_by_type(
project=instance_model.project,
backend_type=jpd.backend,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlalchemy.orm import joinedload

from dstack._internal.core.errors import SSHError
from dstack._internal.core.models.runs import JobSpec, JobStatus, ProbeSpec
from dstack._internal.core.models.runs import JobStatus, ProbeSpec
from dstack._internal.core.services.ssh.tunnel import (
SSH_DEFAULT_OPTIONS,
IPSocket,
Expand All @@ -21,6 +21,7 @@
)
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import InstanceModel, JobModel, ProbeModel
from dstack._internal.server.services.jobs import get_job_spec
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.logging import fmt
from dstack._internal.server.services.ssh import container_ssh_tunnel
Expand Down Expand Up @@ -71,7 +72,7 @@ async def process_probes():
if probe.job.status != JobStatus.RUNNING:
probe.active = False
else:
job_spec: JobSpec = JobSpec.__response__.parse_raw(probe.job.job_spec_data)
job_spec = get_job_spec(probe.job)
probe_spec = job_spec.probes[probe.probe_num]
if probe_spec.until_ready and probe.success_streak >= probe_spec.ready_after:
probe.active = False
Expand Down Expand Up @@ -148,7 +149,7 @@ async def _get_service_replica_client(job: JobModel) -> AsyncGenerator[AsyncClie
**SSH_DEFAULT_OPTIONS,
"ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())),
}
job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data)
job_spec = get_job_spec(job)
with TemporaryDirectory() as temp_dir:
app_socket_path = (Path(temp_dir) / "replica.sock").absolute()
async with container_ssh_tunnel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
JobTerminationReason,
ProbeSpec,
Run,
RunSpec,
RunStatus,
)
from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint
Expand Down Expand Up @@ -67,6 +66,7 @@
find_job,
get_job_attached_volumes,
get_job_runtime_data,
get_job_spec,
is_master_job,
job_model_to_job_submission,
switch_job_status,
Expand All @@ -82,6 +82,7 @@
from dstack._internal.server.services.runner import client
from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
from dstack._internal.server.services.runs import (
get_run_spec,
is_job_ready,
run_model_to_run,
)
Expand Down Expand Up @@ -732,7 +733,7 @@ def _process_provisioning_with_shim(
Returns:
is successful
"""
job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
job_spec = get_job_spec(job_model)

shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])

Expand Down Expand Up @@ -980,7 +981,7 @@ def _terminate_if_inactivity_duration_exceeded(
job_model: JobModel,
no_connections_secs: Optional[int],
) -> None:
conf = RunSpec.__response__.parse_raw(run_model.run_spec).configuration
conf = get_run_spec(run_model).configuration
if not isinstance(conf, DevEnvironmentConfiguration) or not isinstance(
conf.inactivity_duration, int
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
from dstack._internal.core.models.runs import (
Job,
JobSpec,
JobStatus,
JobTerminationReason,
Run,
Expand All @@ -33,6 +32,7 @@
from dstack._internal.server.services import events
from dstack._internal.server.services.jobs import (
find_job,
get_job_spec,
get_job_specs_from_run_spec,
group_jobs_by_replica_latest,
is_master_job,
Expand Down Expand Up @@ -527,7 +527,7 @@ async def _handle_run_replicas(
if job.status.is_finished():
continue
try:
job_spec = JobSpec.__response__.parse_raw(job.job_spec_data)
job_spec = get_job_spec(job)
existing_group_names.add(job_spec.replica_group)
except Exception:
continue
Expand Down Expand Up @@ -643,7 +643,7 @@ async def _update_jobs_to_new_deployment_in_place(
replica_group_name = None

if replicas:
job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data)
job_spec = get_job_spec(job_models[0])
replica_group_name = job_spec.replica_group

# FIXME: Handle getting image configuration errors or skip it.
Expand All @@ -658,7 +658,7 @@ async def _update_jobs_to_new_deployment_in_place(
)
can_update_all_jobs = True
for old_job_model, new_job_spec in zip(job_models, new_job_specs):
old_job_spec = JobSpec.__response__.parse_raw(old_job_model.job_spec_data)
old_job_spec = get_job_spec(old_job_model)
if new_job_spec != old_job_spec:
can_update_all_jobs = False
break
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from dstack._internal.server.services.jobs import (
get_job_provisioning_data,
get_job_runtime_data,
get_job_spec,
switch_job_status,
)
from dstack._internal.server.services.locking import get_locker
Expand Down Expand Up @@ -356,7 +357,7 @@ async def _detach_volumes_from_job_instance(
instance_model: InstanceModel,
volume_models: list[VolumeModel],
) -> bool:
job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
job_spec = get_job_spec(job_model)
backend = await backends_services.get_project_backend_by_type(
project=project,
backend_type=jpd.backend,
Expand Down
39 changes: 39 additions & 0 deletions src/dstack/_internal/server/routers/sshproxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
from typing import Annotated

from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.server.db import get_session
from dstack._internal.server.schemas.sshproxy import GetUpstreamRequest, GetUpstreamResponse
from dstack._internal.server.security.permissions import AlwaysForbidden, ServiceAccount
from dstack._internal.server.services.sshproxy import get_upstream_response
from dstack._internal.server.utils.routers import (
CustomORJSONResponse,
get_base_api_additional_responses,
)

# Bearer-token auth for the sshproxy API. The API is enabled only when the
# DSTACK_SSHPROXY_API_TOKEN environment variable is set to a non-empty value;
# otherwise every request to this router is rejected.
_token = os.getenv("DSTACK_SSHPROXY_API_TOKEN")
if _token:
    _auth = ServiceAccount(_token)
else:
    _auth = AlwaysForbidden()


router = APIRouter(
    prefix="/api/sshproxy",
    tags=["sshproxy"],
    responses=get_base_api_additional_responses(),
    dependencies=[Depends(_auth)],
)


@router.post("/get_upstream", response_model=GetUpstreamResponse)
async def get_upstream(
    body: GetUpstreamRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Resolve the SSH upstream (host chain and authorized keys) for an id.

    Raises ResourceNotExistsError when no upstream matches the given id.
    """
    upstream = await get_upstream_response(session=session, upstream_id=body.id)
    if upstream is None:
        raise ResourceNotExistsError()
    return CustomORJSONResponse(upstream)
27 changes: 27 additions & 0 deletions src/dstack/_internal/server/schemas/sshproxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Annotated

from pydantic import Field

from dstack._internal.core.models.common import CoreModel


class GetUpstreamRequest(CoreModel):
    """Request body for the sshproxy ``get_upstream`` endpoint."""

    # The format of id is intentionally not limited to UUID to allow further extensions
    id: str


class UpstreamHost(CoreModel):
    """A single SSH host in the upstream chain returned by ``get_upstream``."""

    host: Annotated[str, Field(description="The hostname or IP address")]
    port: Annotated[int, Field(description="The SSH port")]
    user: Annotated[str, Field(description="The user to log in")]
    private_key: Annotated[str, Field(description="The private key in OpenSSH file format")]


class GetUpstreamResponse(CoreModel):
    """Response body for the sshproxy ``get_upstream`` endpoint."""

    hosts: Annotated[
        list[UpstreamHost],
        Field(description="The chain of SSH hosts, the jump host(s) first, the target host last"),
    ]
    authorized_keys: Annotated[
        list[str], Field(description="The list of authorized public keys in OpenSSH file format")
    ]
27 changes: 23 additions & 4 deletions src/dstack/_internal/server/security/permissions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from secrets import compare_digest
from typing import Annotated, Optional, Tuple
from uuid import UUID

Expand Down Expand Up @@ -219,9 +220,23 @@ async def __call__(
raise error_forbidden()


class OptionalServiceAccount:
class ServiceAccount:
    """FastAPI dependency that authenticates requests against a fixed bearer token.

    The expected token is stored as bytes so it can be compared with
    ``secrets.compare_digest``.
    """

    def __init__(self, token: str) -> None:
        # Encode once at construction time; compare_digest needs bytes.
        self._token = token.encode()

    async def __call__(
        self, token: Annotated[HTTPAuthorizationCredentials, Security(HTTPBearer())]
    ) -> None:
        supplied = token.credentials.encode()
        # Constant-time comparison to avoid leaking the token via timing.
        if not compare_digest(supplied, self._token):
            raise error_invalid_token()


class OptionalServiceAccount(ServiceAccount):
_token: Optional[bytes] = None

def __init__(self, token: Optional[str]) -> None:
self._token = token
if token is not None:
super().__init__(token)

async def __call__(
self,
Expand All @@ -233,8 +248,12 @@ async def __call__(
return
if token is None:
raise error_forbidden()
if token.credentials != self._token:
raise error_invalid_token()
await super().__call__(token)


class AlwaysForbidden:
    """FastAPI dependency that unconditionally rejects the request as forbidden."""

    async def __call__(self) -> None:
        raise error_forbidden()


async def get_project_member(
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/services/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ def get_job_runtime_data(job_model: JobModel) -> Optional[JobRuntimeData]:
return JobRuntimeData.__response__.parse_raw(job_model.job_runtime_data)


def get_job_spec(job_model: JobModel) -> JobSpec:
    """Deserialize the ``JobSpec`` stored as JSON on a ``JobModel`` row."""
    raw_spec = job_model.job_spec_data
    return JobSpec.__response__.parse_raw(raw_spec)


def delay_job_instance_termination(job_model: JobModel):
job_model.remove_at = common.get_current_datetime() + timedelta(seconds=15)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sqlalchemy.orm import aliased, joinedload

from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.core.models.runs import JobStatus, RunSpec, RunStatus
from dstack._internal.core.models.runs import JobStatus, RunStatus
from dstack._internal.server.models import (
InstanceModel,
JobMetricsPoint,
Expand All @@ -25,6 +25,7 @@
)
from dstack._internal.server.services.instances import get_instance_offer
from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data
from dstack._internal.server.services.runs import get_run_spec
from dstack._internal.utils.common import get_current_datetime


Expand Down Expand Up @@ -152,7 +153,7 @@ async def get_job_metrics(session: AsyncSession) -> Iterable[Metric]:
price = jrd.offer.price
gpus = resources.gpus
cpus = resources.cpus
run_spec = RunSpec.__response__.parse_raw(job.run.run_spec)
run_spec = get_run_spec(job.run)
labels = {
"dstack_project_name": job.project.name,
"dstack_user_name": job.run.user.name,
Expand Down Expand Up @@ -186,7 +187,7 @@ async def get_job_metrics(session: AsyncSession) -> Iterable[Metric]:
)
):
gpu_labels = labels.copy()
gpu_labels["dstack_gpu_num"] = gpu_num
gpu_labels["dstack_gpu_num"] = str(gpu_num)
metrics.add_sample(_JOB_GPU_USAGE_RATIO, gpu_labels, gpu_util / 100)
metrics.add_sample(_JOB_GPU_MEMORY_TOTAL, gpu_labels, gpu_memory_total)
metrics.add_sample(_JOB_GPU_MEMORY_USAGE, gpu_labels, gpu_memory_usage)
Expand Down
8 changes: 4 additions & 4 deletions src/dstack/_internal/server/services/proxy/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.runs import (
JobProvisioningData,
JobSpec,
JobStatus,
RunSpec,
RunStatus,
ServiceSpec,
get_service_port,
Expand All @@ -32,6 +30,8 @@
from dstack._internal.proxy.lib.repo import BaseProxyRepo
from dstack._internal.server.models import InstanceModel, JobModel, ProjectModel, RunModel
from dstack._internal.server.services.instances import get_instance_remote_connection_info
from dstack._internal.server.services.jobs import get_job_spec
from dstack._internal.server.services.runs import get_run_spec
from dstack._internal.server.settings import DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE
from dstack._internal.utils.common import get_or_error

Expand Down Expand Up @@ -68,7 +68,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
if not len(jobs):
return None
run = jobs[0].run
run_spec = RunSpec.__response__.parse_raw(run.run_spec)
run_spec = get_run_spec(run)
if not isinstance(run_spec.configuration, ServiceConfiguration):
return None
replicas = []
Expand Down Expand Up @@ -108,7 +108,7 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Servic
if rci is not None and rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data)
job_spec = get_job_spec(job)
replica = Replica(
id=job.id.hex,
app_port=get_service_port(job_spec, run_spec.configuration),
Expand Down
Loading
Loading