mtyrrell's picture
refactor
72cb6c4
from huggingface_hub import InferenceClient
from auditqa.process_chunks import getconfig
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.chat_models.huggingface import ChatHuggingFace
import os
from dotenv import load_dotenv
# Load variables from a .env file (if one exists) into os.environ before any
# environment lookups below.
load_dotenv()
# Parsed model configuration; nvidia_client() reads the [reader] NVIDIA_ENDPOINT
# value from it. NOTE(review): presumably a ConfigParser-like object returned by
# auditqa.process_chunks.getconfig — confirm.
model_config = getconfig("model_params.cfg")
# Disabled (#TESTING). NOTE(review): presumably commented out so that importing
# this module does not require NVIDIA_SERVERLESS to be set — confirm.
# NVIDIA_SERVER = os.environ["NVIDIA_SERVERLESS"] #TESTING
# Hugging Face access token, read at import time: importing this module raises
# KeyError if LLAMA_3_1 is unset.
HF_token = os.environ["LLAMA_3_1"]
def nvidia_client():
    """Return an InferenceClient for the NVIDIA serverless endpoint.

    The endpoint URL is taken from ``[reader] NVIDIA_ENDPOINT`` in the model
    config and the API key from the ``NVIDIA_SERVERLESS`` environment variable.

    Returns:
        huggingface_hub.InferenceClient pointed at the NVIDIA endpoint.

    Raises:
        KeyError: if ``NVIDIA_SERVERLESS`` is not set.
        ValueError: if ``NVIDIA_SERVERLESS`` is set but empty.
    """
    # Read the key at call time instead of relying on a module-level
    # NVIDIA_SERVER constant (disabled in this file, which made every call
    # fail with NameError). Importing the module still does not need the key.
    api_key = os.environ["NVIDIA_SERVERLESS"]
    if not api_key:
        raise ValueError("NVIDIA_SERVERLESS environment variable is empty")
    client = InferenceClient(
        base_url=model_config.get('reader', 'NVIDIA_ENDPOINT'),
        api_key=api_key,
    )
    print("getting nvidia client")
    return client
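# TESTING VERSION: returns a plain InferenceClient for a hosted model instead of
# the ChatHuggingFace-wrapped dedicated endpoint (original kept commented out below).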
def dedicated_endpoint(model_id="meta-llama/Meta-Llama-3-8B-Instruct"):
    """Return an InferenceClient for a Hugging Face hosted model (testing version).

    Args:
        model_id: Hugging Face model repo id to query. Defaults to the model
            this function previously hard-coded, so existing callers are
            unaffected.

    Returns:
        huggingface_hub.InferenceClient bound to ``model_id`` and
        authenticated with the ``LLAMA_3_1`` token.

    Raises:
        KeyError: if the ``LLAMA_3_1`` environment variable is not set.
        ValueError: if ``LLAMA_3_1`` is set but empty.
    """
    # Use a distinct local name so the module-level HF_token is not shadowed;
    # the token is read on each call, as before.
    token = os.environ["LLAMA_3_1"]
    if not token:
        raise ValueError("LLAMA_3_1 environment variable is empty")
    return InferenceClient(model=model_id, api_key=token)
# def dedicated_endpoint():
# """ returns the dedicated server endpoint"""
# # Set up the streaming callback handler
# callback = StreamingStdOutCallbackHandler()
# # Initialize the HuggingFaceEndpoint with streaming enabled
# llm_qa = HuggingFaceEndpoint(
# endpoint_url=model_config.get('reader', 'DEDICATED_ENDPOINT'),
# max_new_tokens=int(model_config.get('reader','MAX_TOKENS')),
# repetition_penalty=1.03,
# timeout=70,
# huggingfacehub_api_token=HF_token,
# streaming=True, # Enable streaming for real-time token generation
# callbacks=[callback] # Add the streaming callback handler
# )
# # Create a ChatHuggingFace instance with the streaming-enabled endpoint
# chat_model = ChatHuggingFace(llm=llm_qa)
# print("getting dedicated endpoint wrapped in ChathuggingFace ")
# return chat_model