# ChatCSV/models/llm_setup.py
from typing import Optional

import torch
from llama_index.llms.huggingface import HuggingFaceLLM
from transformers import BitsAndBytesConfig

def setup_llm(model_name: str = "microsoft/Phi-3-mini-4k-instruct",
              device: Optional[str] = None,
              context_window: int = 4096,
              max_new_tokens: int = 512) -> HuggingFaceLLM:
"""
Set up the language model for the CSV chatbot.
Args:
model_name: Name of the Hugging Face model to use
device: Device to run the model on ('cuda', 'cpu', etc.)
context_window: Maximum context window size
max_new_tokens: Maximum number of new tokens to generate
Returns:
Configured LLM instance
"""
    # Pick a device automatically when none is given
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Configure 4-bit quantization for memory efficiency on GPU;
    # bitsandbytes quantization is not available on CPU, so skip it there
    if device == "cuda":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16
        )
    else:
        quantization_config = None
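
    # Rough memory math (estimate, not measured): Phi-3-mini has ~3.8B
    # parameters, so the weights take ~7.6 GB in fp16 but only ~2 GB at
    # 4-bit, which is what makes this model practical on small Spaces GPUs.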
    # Model-loading kwargs sized for constrained environments such as HF Spaces;
    # fall back to float32 on CPU, where float16 inference is poorly supported
    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.float16 if device == "cuda" else torch.float32,
    }
    if quantization_config:
        model_kwargs["quantization_config"] = quantization_config
    # Initialize the LlamaIndex wrapper; it loads both the model and the
    # tokenizer itself from model_name / tokenizer_name
    llm = HuggingFaceLLM(
        model_name=model_name,
        tokenizer_name=model_name,
        context_window=context_window,
        max_new_tokens=max_new_tokens,
        # do_sample=True is required for temperature/top_p to take effect
        generate_kwargs={"do_sample": True, "temperature": 0.7, "top_p": 0.95},
        device_map=device,
        tokenizer_kwargs={"trust_remote_code": True},
        model_kwargs=model_kwargs,
        # Cache the model to avoid re-downloading between restarts
        cache_folder="./model_cache"
    )
    return llm
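

# Minimal usage sketch (not part of the original module): assumes a runtime
# with llama-index, transformers, accelerate and, on GPU, bitsandbytes
# installed; the prompt below is illustrative only.
if __name__ == "__main__":
    llm = setup_llm()
    # complete() returns a CompletionResponse; .text holds the generated string
    response = llm.complete("In one sentence, what can a CSV chatbot do?")
    print(response.text)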