import os
import torch
import streamlit as st
from dotenv import load_dotenv
from peft import PeftModel
from chromadb import HttpClient
from utils.embedding_utils import CustomEmbeddingFunction
from transformers import AutoModelForCausalLM, AutoTokenizer

st.title("FormulAI")

# Device and model configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "unsloth/Llama-3.2-1B"

# Load pretrained model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the trained LoRA adapter weights on top of the base model.
# PeftModel.from_pretrained loads the saved adapter; constructing
# PeftModel(model, peft_config) would only attach a fresh, untrained adapter.
adapter_name = "FormulAI/FormuLLaMa-3.2-1B-LoRA"
model = PeftModel.from_pretrained(model, adapter_name).to(device)

template = """Answer the following QUESTION based on the CONTEXT given.
If you do not know the answer and the CONTEXT doesn't contain the answer, truthfully say "I don't know".

CONTEXT:
{context}

QUESTION:
{question}

ANSWER:
"""

# Session state holding the conversation history
if 'generated' not in st.session_state:
    st.session_state['generated'] = []
if 'past' not in st.session_state:
    st.session_state['past'] = []


def get_text():
    input_text = st.text_input("Ask a question regarding Formula 1: ", "", key="input")
    return input_text


# Connect to the Chroma vector store configured in chroma.env
load_dotenv("chroma.env")
chroma_host = os.getenv("CHROMA_HOST", "localhost")
chroma_port = int(os.getenv("CHROMA_PORT", "8000"))
chroma_collection = os.getenv("CHROMA_COLLECTION", "F1-wiki")

chroma_client = HttpClient(host=chroma_host, port=chroma_port)
collection = chroma_client.get_collection(name=chroma_collection, embedding_function=CustomEmbeddingFunction())

question = get_text()

# Llama tokenizers ship without a pad token; fall back to the EOS token
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

if question:
    with st.spinner("Generating answer... "):
        # Retrieve the most relevant documents for the question
        response = collection.query(query_texts=question, include=['documents'], n_results=5)
        context = " ".join(response['documents'][0])

        # Fill the prompt template and generate an answer
        input_text = template.replace("{context}", context).replace("{question}", question)
        input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
        attention_mask = (input_ids != tokenizer.pad_token_id).long().to(device)

        output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=200)
        # The decoded output includes the prompt; keep only the text after "ANSWER:"
        answer = tokenizer.decode(output[0], skip_special_tokens=True).split("ANSWER:")[1].strip()

        st.session_state.past.append(question)
        st.session_state.generated.append(answer)

    st.write(answer)
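
# The 'past' and 'generated' session-state lists above accumulate the full
# conversation, but only the latest answer is rendered. A minimal, purely
# illustrative sketch of how that stored history could be displayed is left
# commented out below (not part of the original app's behavior):
#
# for past_q, past_a in zip(st.session_state['past'], st.session_state['generated']):
#     st.markdown(f"**Q:** {past_q}")
#     st.markdown(f"**A:** {past_a}")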