amazon-web-services, amazon-sagemaker, large-language-model

Deploying LLM from S3 on Amazon SageMaker


I trained Llama 2 7B and am now trying to deploy the model on SageMaker.


import json

import sagemaker
from sagemaker.huggingface import HuggingFaceModel

model_s3_path = 's3://bucket/model/model.tar.gz'


# sagemaker config
instance_type = "ml.g4dn.2xlarge"
number_of_gpu = 1
health_check_timeout = 300
image='763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04'

# Define Model and Endpoint configuration parameter
config = {
  'HF_MODEL_ID': "/opt/ml/model", # path to where sagemaker stores the model
  'SM_NUM_GPUS': json.dumps(number_of_gpu), # Number of GPU used per replica
  'MAX_INPUT_LENGTH': json.dumps(1024), # Max length of input text
  'MAX_TOTAL_TOKENS': json.dumps(2048), # Max length of the generation (including input text)
}

# create HuggingFaceModel with the image uri
llm_model = HuggingFaceModel(
  image_uri=image, 
  role=sagemaker.get_execution_role(),  
  model_data=model_s3_path,
  entry_point="deploy.py",
  source_dir="src",
  env=config,
)

and to deploy I have

llm = llm_model.deploy(
  initial_instance_count=1,
  instance_type=instance_type,
  container_startup_health_check_timeout=health_check_timeout, # 5 minutes (300 s) to give SageMaker time to download the model
)

In my SageMaker workspace I have a src directory containing deploy.py, where I load the model.

The problem is that control never reaches deploy.py; when the llm_model.deploy cell executes, I get the following error:

Traceback (most recent call last):
  File "/usr/local/bin/dockerd-entrypoint.py", line 23, in <module>
    serving.main()
  File "/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/serving.py", line 34, in main
    _start_mms()
  File "/opt/conda/lib/python3.10/site-packages/retrying.py", line 56, in wrapped_f
    return Retrying(*dargs, **dkw).call(f, *args, **kw)
  File "/opt/conda/lib/python3.10/site-packages/retrying.py", line 257, in call
    return attempt.get(self._wrap_exception)
  File "/opt/conda/lib/python3.10/site-packages/retrying.py", line 301, in get
    six.reraise(self.value[0], self.value[1], self.value[2])
  File "/opt/conda/lib/python3.10/site-packages/six.py", line 719, in reraise
    raise value
  File "/opt/conda/lib/python3.10/site-packages/retrying.py", line 251, in call
    attempt = Attempt(fn(*args, **kwargs), attempt_number, False)
  File "/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/serving.py", line 30, in _start_mms
    mms_model_server.start_model_server(handler_service=HANDLER_SERVICE)
  File "/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/mms_model_server.py", line 81, in start_model_server
    storage_dir = _load_model_from_hub(
  File "/opt/conda/lib/python3.10/site-packages/sagemaker_huggingface_inference_toolkit/transformers_utils.py", line 204, in _load_model_from_hub
    files = HfApi().model_info(model_id).siblings
  File "/opt/conda/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 110, in _inner_fn
    validate_repo_id(arg_value)
  File "/opt/conda/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 158, in validate_repo_id
    raise HFValidationError(
huggingface_hub.utils._validators.HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/opt/ml/model'. Use `repo_type` argument if needed.

The container is trying to fetch the model from the Hugging Face Hub instead of loading it from S3. How can I fix this?


Solution

  • sagemaker.huggingface.HuggingFaceModel can take an S3 path for the model_data argument, as shown in the official SageMaker samples.

    As you are using a custom image via image_uri, it is likely that the image is not compatible with SageMaker and is not handling the entry point script you specified.

    To isolate the problem, change your code to use SageMaker's official image (see the sketch below), then investigate why your custom image is not loading the entry point script.

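    A minimal sketch of that isolation step: drop the hand-picked image_uri and let the SDK resolve the official Hugging Face inference image from framework version pins. The version numbers below are assumptions; use a combination your SDK release supports. HF_MODEL_ID is also removed, since the traceback shows the toolkit trying to resolve its value ('/opt/ml/model') as a Hub repo id.

import json

import sagemaker
from sagemaker.huggingface import HuggingFaceModel

model_s3_path = "s3://bucket/model/model.tar.gz"

llm_model = HuggingFaceModel(
    model_data=model_s3_path,              # S3 archive is supported directly
    role=sagemaker.get_execution_role(),
    transformers_version="4.28",           # assumed; lets the SDK pick the official image
    pytorch_version="2.0",                 # assumed
    py_version="py310",                    # assumed
    entry_point="deploy.py",               # custom inference script
    source_dir="src",
    env={
        # HF_MODEL_ID removed: the toolkit tried to resolve it as a Hub repo id
        "SM_NUM_GPUS": json.dumps(1),
        "MAX_INPUT_LENGTH": json.dumps(1024),
        "MAX_TOTAL_TOKENS": json.dumps(2048),
    },
)

llm = llm_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.2xlarge",
    container_startup_health_check_timeout=300,
)

    If deployment works with the official image, the fault lies with the custom image; if it still fails, look more closely at how the model archive and entry point script are packaged.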