Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/sagemaker/serve/model_server/triton/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@ def auto_complete_config(auto_complete_model_config):
def initialize(self, args: dict) -> None:
    """Load the pickled inference spec and schema builder for the Triton Python backend.

    The serve.pkl payload is HMAC-verified against metadata.json BEFORE it is
    unpickled, because unpickling untrusted bytes is arbitrary code execution.

    Args:
        args: Triton backend initialization arguments (unused here).

    Raises:
        Whatever ``perform_integrity_check`` raises when the HMAC digest of
        serve.pkl does not match the signed digest recorded in metadata.json.
    """
    serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl")
    metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json")

    # Verify integrity of the pickle bytes before deserializing them.
    with open(str(serve_path), mode="rb") as f:
        buffer = f.read()
    perform_integrity_check(buffer=buffer, metadata_path=str(metadata_path))

    # Only after the check passes is it safe to unpickle.
    with open(str(serve_path), mode="rb") as f:
        inference_spec, schema_builder = cloudpickle.load(f)

    self.inference_spec = inference_spec
    self.schema_builder = schema_builder
Expand Down
14 changes: 12 additions & 2 deletions src/sagemaker/serve/model_server/triton/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,16 @@ def _start_triton_server(
env_vars.update(
{
"TRITON_MODEL_DIR": "/models/model",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}
)

# Only set SAGEMAKER_SERVE_SECRET_KEY for inference_spec path where
# pickle integrity verification is needed. The ONNX path does not
# use pickles, so no secret key is required.
if secret_key and isinstance(secret_key, str) and secret_key.strip():
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

if "cpu" not in image_uri:
self.container = docker_client.containers.run(
image=image_uri,
Expand Down Expand Up @@ -146,7 +151,12 @@ def _upload_triton_artifacts(
env_vars = {
"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
"TRITON_MODEL_DIR": "/opt/ml/model/model",
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
"LOCAL_PYTHON": platform.python_version(),
}

# Only set SAGEMAKER_SERVE_SECRET_KEY for inference_spec path where
# pickle integrity verification is needed.
if secret_key and isinstance(secret_key, str) and secret_key.strip():
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key

return s3_upload_path, env_vars
8 changes: 3 additions & 5 deletions src/sagemaker/serve/model_server/triton/triton_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ def _prepare_for_triton(self):
export_path.mkdir(parents=True)

if self.model:
self.secret_key = "dummy secret key for onnx backend"

# ONNX path: export model to ONNX format for Triton's native ONNX backend.
# No pickle is created or loaded at runtime, so no HMAC signing is needed.
if self._framework == "pytorch":
self._export_pytorch_to_onnx(
export_path=export_path, model=self.model, schema_builder=self.schema_builder
Expand Down Expand Up @@ -457,13 +457,11 @@ def _get_triton_predictor(self, endpoint_name: str, sagemaker_session: Session)
)

def _save_inference_spec(self) -> None:
    """Save the inference spec and schema builder as a pickle in the model repository.

    No-op when ``self.inference_spec`` is unset (e.g. the ONNX model path,
    which does not use pickled artifacts).
    """
    if self.inference_spec:
        # Triton expects serving artifacts under <model_path>/model_repository/model.
        pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
        save_pkl(pkl_path, (self.inference_spec, self.schema_builder))

def _build_for_triton(self):
"""Placeholder docstring"""
self._validate_for_triton()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def prepare_triton_builder_for_model(self, triton_builder: Triton) -> Triton:
mock_export = Mock()
triton_builder._export_pytorch_to_onnx = mock_export
triton_builder._export_tf_to_onnx = mock_export
triton_builder._hmac_signing = Mock()

return triton_builder

Expand Down
Loading