From b4c7ba3cef80cfdc4eb6cadf1b309f17b653ce10 Mon Sep 17 00:00:00 2001 From: Pravali Uppugunduri Date: Wed, 18 Mar 2026 19:56:02 +0000 Subject: [PATCH] Fix Triton HMAC security vulnerabilities (v2) - Bug 1: Add HMAC integrity check before pickle deserialization in Triton handler initialize() method (model.py) - Bug 2: Remove the hardcoded dummy secret key from the ONNX export path, which creates no pickle and therefore needs no HMAC signing (triton_builder.py) - Bug 3: Only pass SAGEMAKER_SERVE_SECRET_KEY to the container when the secret key is a non-empty string, in _start_triton_server() and _upload_triton_artifacts() (server.py) Aligns Triton code path with existing HMAC verification patterns used by TorchServe, MMS, TF Serving, and SMD handlers. Ticket: P400136088 --- src/sagemaker/serve/model_server/triton/model.py | 8 ++++++-- src/sagemaker/serve/model_server/triton/server.py | 14 ++++++++++++-- .../serve/model_server/triton/triton_builder.py | 8 +++----- .../model_server/triton/test_triton_builder.py | 1 + 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/serve/model_server/triton/model.py b/src/sagemaker/serve/model_server/triton/model.py index a1c731b0d6..35bb84d97b 100644 --- a/src/sagemaker/serve/model_server/triton/model.py +++ b/src/sagemaker/serve/model_server/triton/model.py @@ -26,10 +26,14 @@ def auto_complete_config(auto_complete_model_config): def initialize(self, args: dict) -> None: """Placeholder docstring""" serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl") + metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json") + with open(str(serve_path), mode="rb") as f: - inference_spec, schema_builder = cloudpickle.load(f) + buffer = f.read() + perform_integrity_check(buffer=buffer, metadata_path=str(metadata_path)) - # TODO: HMAC signing for integrity check + with open(str(serve_path), mode="rb") as f: + inference_spec, schema_builder = cloudpickle.load(f) self.inference_spec = inference_spec self.schema_builder = schema_builder diff --git 
a/src/sagemaker/serve/model_server/triton/server.py b/src/sagemaker/serve/model_server/triton/server.py index e2f3c20d7a..c1bebbbddf 100644 --- a/src/sagemaker/serve/model_server/triton/server.py +++ b/src/sagemaker/serve/model_server/triton/server.py @@ -43,11 +43,16 @@ def _start_triton_server( env_vars.update( { "TRITON_MODEL_DIR": "/models/model", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } ) + # Only set SAGEMAKER_SERVE_SECRET_KEY for inference_spec path where + # pickle integrity verification is needed. The ONNX path does not + # use pickles, so no secret key is required. + if secret_key and isinstance(secret_key, str) and secret_key.strip(): + env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key + if "cpu" not in image_uri: self.container = docker_client.containers.run( image=image_uri, @@ -146,7 +151,12 @@ def _upload_triton_artifacts( env_vars = { "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model", "TRITON_MODEL_DIR": "/opt/ml/model/model", - "SAGEMAKER_SERVE_SECRET_KEY": secret_key, "LOCAL_PYTHON": platform.python_version(), } + + # Only set SAGEMAKER_SERVE_SECRET_KEY for inference_spec path where + # pickle integrity verification is needed. + if secret_key and isinstance(secret_key, str) and secret_key.strip(): + env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key + return s3_upload_path, env_vars diff --git a/src/sagemaker/serve/model_server/triton/triton_builder.py b/src/sagemaker/serve/model_server/triton/triton_builder.py index c47991fa09..3dda0ac28a 100644 --- a/src/sagemaker/serve/model_server/triton/triton_builder.py +++ b/src/sagemaker/serve/model_server/triton/triton_builder.py @@ -213,8 +213,8 @@ def _prepare_for_triton(self): export_path.mkdir(parents=True) if self.model: - self.secret_key = "dummy secret key for onnx backend" - + # ONNX path: export model to ONNX format for Triton's native ONNX backend. + # No pickle is created or loaded at runtime, so no HMAC signing is needed. 
if self._framework == "pytorch": self._export_pytorch_to_onnx( export_path=export_path, model=self.model, schema_builder=self.schema_builder @@ -457,13 +457,11 @@ def _get_triton_predictor(self, endpoint_name: str, sagemaker_session: Session) ) def _save_inference_spec(self) -> None: - """Placeholder docstring""" + """Save inference specification to pickle file.""" if self.inference_spec: pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model") save_pkl(pkl_path, (self.inference_spec, self.schema_builder)) - return - def _build_for_triton(self): """Placeholder docstring""" self._validate_for_triton() diff --git a/tests/unit/sagemaker/serve/model_server/triton/test_triton_builder.py b/tests/unit/sagemaker/serve/model_server/triton/test_triton_builder.py index ae5b9001c7..b2d490b45c 100644 --- a/tests/unit/sagemaker/serve/model_server/triton/test_triton_builder.py +++ b/tests/unit/sagemaker/serve/model_server/triton/test_triton_builder.py @@ -67,6 +67,7 @@ def prepare_triton_builder_for_model(self, triton_builder: Triton) -> Triton: mock_export = Mock() triton_builder._export_pytorch_to_onnx = mock_export triton_builder._export_tf_to_onnx = mock_export + triton_builder._hmac_signing = Mock() return triton_builder