Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/pr-vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -939,4 +939,5 @@ jobs:
- name: Run sagemaker endpoint test
run: |
source .venv/bin/activate
python test/vllm/sagemaker/test_sm_endpoint.py --image-uri ${{ needs.set-sagemaker-test-environment.outputs.image-uri }} --endpoint-name test-sm-vllm-endpoint-${{ github.sha }}
cd test/
python3 -m pytest -vs -rA --image-uri ${{ needs.set-sagemaker-test-environment.outputs.image-uri }} vllm/sagemaker
28 changes: 5 additions & 23 deletions test/sglang/sagemaker/test_sm_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from pprint import pformat

import pytest
from botocore.exceptions import ClientError
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from test_utils import clean_string, random_suffix_name, wait_for_status
from test_utils import clean_string, get_hf_token, random_suffix_name, wait_for_status
from test_utils.constants import INFERENCE_AMI_VERSION, SAGEMAKER_ROLE

# To enable debugging, change logging.INFO to logging.DEBUG
LOGGER = logging.getLogger(__name__)
Expand All @@ -38,23 +38,6 @@ def get_endpoint_status(sagemaker_client, endpoint_name):
return response["EndpointStatus"]


def get_hf_token(aws_session):
    """Return the HuggingFace token kept in AWS Secrets Manager.

    The secret ``test/hf_token`` is read via ``aws_session.secretsmanager``;
    its ``SecretString`` is decoded as JSON and the ``HF_TOKEN`` entry is
    returned (``None`` when that entry is missing).

    A ``ClientError`` raised by the lookup is logged and re-raised.
    """
    LOGGER.info("Retrieving HuggingFace token from AWS Secrets Manager...")
    secrets_client = aws_session.secretsmanager

    try:
        secret_payload = secrets_client.get_secret_value(SecretId="test/hf_token")
        LOGGER.info("Successfully retrieved HuggingFace token")
    except ClientError as e:
        LOGGER.error(f"Failed to retrieve HuggingFace token: {e}")
        raise e

    # Never log the decoded secret.
    decoded = json.loads(secret_payload["SecretString"])
    return decoded.get("HF_TOKEN")


@pytest.fixture(scope="function")
def model_id(request):
# Return the model_id given by the test parameter
Expand All @@ -63,14 +46,13 @@ def model_id(request):

@pytest.fixture(scope="function")
def instance_type(request):
# Return the model_id given by the test parameter
# Return the instance_type given by the test parameter
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah good catch thanks!

return request.param


@pytest.fixture(scope="function")
def model_package(aws_session, image_uri, model_id):
sagemaker_client = aws_session.sagemaker
sagemaker_role = aws_session.iam_resource.Role("SageMakerRole").arn
cleaned_id = clean_string(model_id.split("/")[1], "_./")
model_name = random_suffix_name(f"sglang-{cleaned_id}-model-package", 50)

Expand All @@ -82,7 +64,7 @@ def model_package(aws_session, image_uri, model_id):
model = Model(
name=model_name,
image_uri=image_uri,
role=sagemaker_role,
role=SAGEMAKER_ROLE,
predictor_cls=Predictor,
env={
"SM_SGLANG_MODEL_PATH": model_id,
Expand Down Expand Up @@ -111,7 +93,7 @@ def model_endpoint(aws_session, model_package, instance_type):
instance_type=instance_type,
initial_instance_count=1,
endpoint_name=endpoint_name,
inference_ami_version="al2-ami-sagemaker-inference-gpu-3-1",
inference_ami_version=INFERENCE_AMI_VERSION,
serializer=JSONSerializer(),
wait=True,
)
Expand Down
22 changes: 22 additions & 0 deletions test/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@
When necessary, use docstrings to explain the functions' mechanisms.
"""

import json
import logging
import random
import string
import time
from collections.abc import Callable
from typing import Any

from botocore.exceptions import ClientError

from .aws import AWSSessionManager

LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)

Expand Down Expand Up @@ -58,3 +63,20 @@ def wait_for_status(

LOGGER.error(f"Wait for status: {expected_status} timed out. Actual status: {actual_status}")
return False


def get_hf_token(aws_session: AWSSessionManager) -> str:
    """Fetch the HuggingFace token stored in AWS Secrets Manager.

    Reads the secret at ``test/hf_token`` through
    ``aws_session.secretsmanager``, parses its ``SecretString`` as JSON and
    returns the value stored under the ``HF_TOKEN`` key.

    Args:
        aws_session: Session manager exposing a ``secretsmanager`` client.

    Returns:
        The value of ``HF_TOKEN``. Note that ``None`` is returned when the
        key is absent from the secret, because the lookup uses ``dict.get``.

    Raises:
        ClientError: Logged and re-raised when the secret lookup fails.
    """
    secret_id = "test/hf_token"
    LOGGER.info("Retrieving HuggingFace token from AWS Secrets Manager...")

    try:
        secret_response = aws_session.secretsmanager.get_secret_value(SecretId=secret_id)
    except ClientError as e:
        LOGGER.error(f"Failed to retrieve HuggingFace token: {e}")
        raise e
    else:
        LOGGER.info("Successfully retrieved HuggingFace token")

    # The secret payload itself must never be written to the logs.
    return json.loads(secret_response["SecretString"]).get("HF_TOKEN")
2 changes: 2 additions & 0 deletions test/test_utils/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
# Region name used by default; presumably an AWS region for the test sessions — confirm against callers.
DEFAULT_REGION = "us-west-2"
# Role identifier passed as `role=` when the endpoint tests build a SageMaker Model.
SAGEMAKER_ROLE = "SageMakerRole"
# Value passed as `inference_ami_version=` when the endpoint tests deploy a model.
INFERENCE_AMI_VERSION = "al2-ami-sagemaker-inference-gpu-3-1"
Loading
Loading