I am trying to deploy a Llama 2 model for text generation inference using SageMaker and LangChain. I am writing code in a notebook instance and deploying SageMaker endpoints from that code. I followed the documentation at https://python.langchain.com/docs/integrations/llms/sagemaker and used the following code to create a chain for question answering:
from langchain.docstore.document import Document
example_doc_1 = """
Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.
Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.
Therefore, Peter stayed with her at the hospital for 3 days without leaving.
"""
docs = [
    Document(
        page_content=example_doc_1,
    )
]
from typing import Dict
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains.question_answering import load_qa_chain
import json
query = """How long was Elizabeth hospitalized?
"""
prompt_template = """Use the following pieces of context to answer the question at the end.
{context}
Question: {question}
Answer:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)
class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({prompt: prompt, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]
content_handler = ContentHandler()
chain = load_qa_chain(
    llm=SagemakerEndpoint(
        endpoint_name="XYZ",
        credentials_profile_name="XYZ",
        region_name="XYZ",
        model_kwargs={"temperature": 1e-10},
        content_handler=content_handler,
    ),
    prompt=PROMPT,
)
chain({"input_documents": docs, "question": query}, return_only_outputs=True)
But I got an error:
ValueError: Error raised by inference endpoint:
An error occurred (ModelError) when calling the InvokeEndpoint operation:
Received client error (422) from primary with message
"Failed to deserialize the JSON body into the target type: missing field `inputs` at line 1 column 966".
None of the tutorials I've seen include any inputs field, and I don't know whether the documentation I have been referring to was updated; either way, I can't resolve this problem.
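As far as I can tell, the body my transform_input produces never contains an inputs key at all: the dict key is the prompt variable rather than the literal string "prompt", so the prompt text itself becomes the JSON key. A minimal repro with stand-in values:

import json

prompt = "example prompt text"  # stand-in for the formatted QA prompt
model_kwargs = {"temperature": 1e-10}

# Serializes to {"example prompt text": "example prompt text", "temperature": 1e-10}
# -- there is no "inputs" field anywhere, which is what the 422 complains about.
print(json.dumps({prompt: prompt, **model_kwargs}))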
My question is: how do I resolve this error?

It looks like this is a known issue with the LangChain documentation; @sigvamo mentioned that the error can be worked around by updating the ContentHandler to include inputs in its transform_input method:
from typing import Dict, List
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
import json
class ContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"text_inputs": inputs, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["embedding"]
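Note that the snippet above targets an embeddings endpoint (EmbeddingsContentHandler). For the text-generation case in the original question, the analogous fix is to send the prompt under the inputs key. A minimal sketch, assuming the Llama 2 endpoint runs a Hugging Face TGI container (which expects an inputs field and generation settings nested under a parameters object):

from typing import Dict
import json

from langchain.llms.sagemaker_endpoint import LLMContentHandler

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Put the prompt under "inputs" -- the field the 422 says is missing --
        # and nest generation settings under "parameters" (TGI-style schema;
        # assumption: the endpoint was deployed with the HF TGI container).
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]

With this handler the rest of the original load_qa_chain snippet can stay as-is; depending on how the endpoint was deployed, model kwargs may instead need to be spread at the top level (as in the original code) once the prompt is under inputs.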