mtyrrell's picture
refactor
b1ab347
raw
history blame
1.84 kB
from huggingface_hub import InferenceClient
from auditqa.process_chunks import getconfig
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.chat_models.huggingface import ChatHuggingFace
import os
from dotenv import load_dotenv
# Presumably loads variables from a local .env file into the process
# environment (python-dotenv), so the os.environ lookups below can see them.
load_dotenv()
# TESTING DEBUG LOG
# Project-specific logging setup; it is called before this module's logger
# is created on the lines below.
from auditqa.logging_config import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# Settings loaded from model_params.cfg; nvidia_client() reads the
# 'NVIDIA_ENDPOINT' option of the 'reader' section from this object.
model_config = getconfig("model_params.cfg")
# NVIDIA_SERVER = os.environ["NVIDIA_SERVERLESS"] #TESTING
# NOTE: this lookup runs at import time, so importing the module raises
# KeyError when LLAMA_3_1 is not set. dedicated_endpoint() reads the same
# variable again on each call rather than using this module-level value.
HF_token = os.environ["LLAMA_3_1"]
def nvidia_client():
    """Return an InferenceClient configured for the NVIDIA endpoint.

    The base URL is taken from the ``NVIDIA_ENDPOINT`` option in the
    ``reader`` section of the module-level ``model_config``. The API key is
    read from the ``NVIDIA_SERVERLESS`` environment variable at call time.

    Returns:
        An ``InferenceClient`` built with that base URL and API key.

    Raises:
        KeyError: If ``NVIDIA_SERVERLESS`` is not set in the environment.
    """
    logger.info("NVIDIA client activated")

    # Keep the try body to the single lookup that can signal a missing
    # variable, so a KeyError raised elsewhere is not mislabelled.
    try:
        api_key = os.environ["NVIDIA_SERVERLESS"]
    except KeyError as err:
        raise KeyError(
            "NVIDIA_SERVERLESS environment variable not set. Required for NVIDIA endpoint."
        ) from err

    client = InferenceClient(
        base_url=model_config.get('reader', 'NVIDIA_ENDPOINT'),
        api_key=api_key,
    )
    # Routed through the module logger instead of print(), matching the
    # other client factory in this file.
    logger.info("getting nvidia client")
    return client
# TESTING VERSION
def dedicated_endpoint():
    """Build an InferenceClient for the Meta-Llama-3-8B-Instruct model.

    The API token is read from the ``LLAMA_3_1`` environment variable on
    every call. Any failure during setup is logged through the module
    logger and then re-raised unchanged.

    Returns:
        An ``InferenceClient`` created with the model id and the token.

    Raises:
        KeyError: If ``LLAMA_3_1`` is not present in the environment.
        ValueError: If ``LLAMA_3_1`` is set but empty.
    """
    logger.info("Serverless endpoint activated")
    try:
        api_token = os.environ["LLAMA_3_1"]
        if not api_token:
            raise ValueError("LLAMA_3_1 environment variable is empty")

        llama_model = "meta-llama/Meta-Llama-3-8B-Instruct"
        logger.info(f"Initializing InferenceClient with model: {llama_model}")
        hf_client = InferenceClient(model=llama_model, api_key=api_token)
        logger.info("Serverless InferenceClient initialization successful")
    except Exception as err:
        # Log the failure, then let the original exception propagate.
        logger.error(f"Error initializing dedicated endpoint: {str(err)}")
        raise
    # Reached only when the try block completed without raising.
    return hf_client