import os

import torch
import streamlit as st
from dotenv import load_dotenv
from peft import PeftModel
from chromadb import HttpClient
from utils.embedding_utils import CustomEmbeddingFunction
from transformers import AutoModelForCausalLM, AutoTokenizer

st.title("FormulAI")
# Device and model configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "unsloth/Llama-3.2-1B"

# Load the pretrained base model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
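# Note (a suggestion, not part of the original app): Streamlit reruns this whole
# script on every interaction, so the model and tokenizer are reloaded each time.
# A common pattern is to wrap the loading in a function decorated with
# @st.cache_resource so it runs once per process, e.g.:
#
# @st.cache_resource
# def load_model_and_tokenizer():
#     m = AutoModelForCausalLM.from_pretrained(model_name).to(device)
#     t = AutoTokenizer.from_pretrained(model_name)
#     return m, t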
# Load the trained LoRA adapter and apply it to the base model.
# Note: PeftModel(model, peft_config) would only attach a freshly initialised
# adapter; from_pretrained() is required to load the trained adapter weights.
adapter_name = "FormulAI/FormuLLaMa-3.2-1B-LoRA"
model = PeftModel.from_pretrained(model, adapter_name).to(device)
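# Optional (a suggestion, not part of the original app): merging the LoRA
# weights into the base model removes the adapter overhead at inference time.
# model = model.merge_and_unload()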
template = """Answer the following QUESTION based on the CONTEXT given.
If you do not know the answer, or the CONTEXT does not contain the answer, truthfully say "I don't know".
CONTEXT:
{context}
QUESTION:
{question}
ANSWER:
"""
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []

def get_text():
    return st.text_input("Ask a question regarding Formula 1: ", "", key="input")
# Connect to the Chroma server described in chroma.env
load_dotenv("chroma.env")
chroma_host = os.getenv("CHROMA_HOST", "localhost")
chroma_port = int(os.getenv("CHROMA_PORT", "8000"))
chroma_collection = os.getenv("CHROMA_COLLECTION", "F1-wiki")

chroma_client = HttpClient(host=chroma_host, port=chroma_port)
collection = chroma_client.get_collection(name=chroma_collection, embedding_function=CustomEmbeddingFunction())
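# Optional sanity check (a suggestion, not part of the original app):
# heartbeat() raises if the Chroma server is unreachable, failing fast
# with a clearer error than the first query would.
# chroma_client.heartbeat()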
question = get_text()

# Llama tokenizers ship without a pad token; fall back to EOS so generate()
# does not warn about a missing pad_token_id
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
if question:
    with st.spinner("Generating answer..."):
        # Retrieve the 5 most relevant documents and join them into one context string
        response = collection.query(query_texts=[question], include=['documents'], n_results=5)
        context = " ".join(response['documents'][0])
        input_text = template.replace("{context}", context).replace("{question}", question)

        # tokenizer(...) returns input_ids plus a matching attention_mask, so the
        # mask no longer has to be derived from the pad token by hand
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        # early_stopping is dropped: it only applies to beam search and had no effect here
        output = model.generate(**inputs, max_new_tokens=200)
        # Decode only the newly generated tokens instead of splitting the full text
        # on "ANSWER:", which breaks if the model repeats that marker
        answer = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()

        st.session_state.past.append(question)
        st.session_state.generated.append(answer)
        st.write(answer)
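
# Sketch (a suggestion, not part of the original app): the 'past'/'generated'
# lists are maintained above but never displayed. One way to render the chat
# history, newest exchange first (the latest answer will appear here as well
# as in the st.write(answer) call above):
for i in range(len(st.session_state['generated']) - 1, -1, -1):
    st.write(f"Q: {st.session_state['past'][i]}")
    st.write(f"A: {st.session_state['generated'][i]}")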