Spaces:
Paused
Paused
import os | |
import typing | |
import json | |
from pydantic import root_validator | |
from langchain.llms import SagemakerEndpoint | |
from langchain.llms.sagemaker_endpoint import LLMContentHandler | |
from src.utils import FakeTokenizer | |
class ChatContentHandler(LLMContentHandler):
    """Content handler that wraps the prompt in a chat-style message list.

    The request body is JSON of the form
    ``{"inputs": [[<messages>]], "parameters": <model_kwargs>}`` where
    ``<messages>`` is an optional system message followed by one user message
    carrying the prompt. The response is read as a JSON array whose first
    element holds the reply under ``["generation"]["content"]``.
    """

    content_type = "application/json"
    accepts = "application/json"

    # System message placed ahead of the user prompt. Override in a subclass
    # (or set to "" / None) to change or omit it; the default reproduces the
    # previously hard-coded value.
    system_prompt: typing.Optional[str] = "You are a helpful assistant."

    def transform_input(self, prompt: str, model_kwargs: typing.Dict) -> bytes:
        """Encode ``prompt`` and ``model_kwargs`` as a UTF-8 JSON request body.

        Args:
            prompt: Text sent as the content of the single user message.
            model_kwargs: Passed through unchanged under the "parameters" key.

        Returns:
            The JSON document encoded as UTF-8 bytes.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({'role': 'user', 'content': prompt})

        # The message list is itself wrapped in a list under "inputs".
        request = {'inputs': [messages], "parameters": model_kwargs}
        return json.dumps(request).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Extract the generated message content from the endpoint response.

        Args:
            output: Response body. Despite the ``bytes`` annotation, it is
                consumed via ``output.read()``, so it must be a readable
                stream-like object.

        Returns:
            ``response[0]["generation"]["content"]`` of the decoded JSON.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generation"]['content']
class BaseContentHandler(LLMContentHandler):
    """Content handler that sends the prompt as a plain ``inputs`` string.

    The request body is JSON of the form
    ``{"inputs": <prompt>, "parameters": <model_kwargs>}``; the response is
    read as a JSON array whose first element holds the text under
    ``"generation"``.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: typing.Dict) -> bytes:
        """Serialize the prompt and model parameters to UTF-8 encoded JSON."""
        serialized = json.dumps({'inputs': prompt, "parameters": model_kwargs})
        return serialized.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Return the ``"generation"`` field of the first response element.

        ``output`` is consumed via ``read()``, so it must be a readable
        stream-like object even though it is annotated as ``bytes``.
        """
        raw_text = output.read().decode("utf-8")
        generations = json.loads(raw_text)
        first_entry = generations[0]
        return first_entry["generation"]
class H2OSagemakerEndpoint(SagemakerEndpoint):
    """SagemakerEndpoint variant accepting explicit AWS keys and a tokenizer.

    ``validate_environment`` builds the ``sagemaker-runtime`` client from
    ``credentials_profile_name``, ``region_name`` and the two key fields
    declared below, and stores it under ``values["client"]``.
    """

    # Explicit AWS credentials forwarded to ``session.client``. An empty
    # string means "not provided".
    aws_access_key_id: str = ""
    aws_secret_access_key: str = ""
    # Optional tokenizer exposing ``encode(text)``; when None,
    # ``get_token_ids`` falls back to ``FakeTokenizer``.
    tokenizer: typing.Any = None

    # Registered as a pydantic root validator so it actually runs when the
    # model is constructed; as a plain method taking ``cls``/``values`` it
    # would never be invoked.
    # NOTE(review): presumably this replaces the parent's validator of the
    # same name -- confirm against the installed langchain version.
    @root_validator()
    def validate_environment(cls, values: typing.Dict) -> typing.Dict:
        """Check that boto3 is importable and create the SageMaker client.

        Args:
            values: Field values of the model being validated. Reads
                ``credentials_profile_name``, ``region_name``,
                ``aws_access_key_id`` and ``aws_secret_access_key``.

        Returns:
            ``values`` with ``values["client"]`` set to a boto3
            ``sagemaker-runtime`` client.

        Raises:
            ImportError: If ``boto3`` cannot be imported.
            ValueError: If the session or client cannot be created; the
                underlying exception is chained as the cause.
        """
        try:
            import boto3
        except ImportError:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )

        try:
            profile_name = values["credentials_profile_name"]
            if profile_name is not None:
                session = boto3.Session(profile_name=profile_name)
            else:
                # use default credentials
                session = boto3.Session()

            # Empty-string keys (the field defaults) are normalized to None:
            # boto3 treats any non-None pair as explicit credentials, so ""
            # would bypass the session's own credential resolution and sign
            # requests with blank keys.
            values["client"] = session.client(
                "sagemaker-runtime",
                region_name=values['region_name'],
                aws_access_key_id=values['aws_access_key_id'] or None,
                aws_secret_access_key=values['aws_secret_access_key'] or None,
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        return values

    def get_token_ids(self, text: str) -> typing.List[int]:
        """Return the token ids for ``text``.

        Uses ``self.tokenizer.encode(text)`` when a tokenizer is configured;
        otherwise returns the ``'input_ids'`` entry of
        ``FakeTokenizer().encode(text)``.
        """
        if self.tokenizer is not None:
            return self.tokenizer.encode(text)
        return FakeTokenizer().encode(text)['input_ids']