from fastapi import FastAPI, Query
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch

from retriever import retrieve_documents

# Load Mistral 7B model
MODEL_NAME = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Create inference pipeline
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

# FastAPI server
app = FastAPI()


@app.get("/")
def read_root():
    return {"message": "Mistral 7B RAG API is running!"}


@app.get("/generate/")
def generate_response(query: str = Query(..., title="User Query")):
    # Retrieve relevant documents
    retrieved_docs = retrieve_documents(query)

    # Format prompt for RAG
    prompt = f"Use the following information to answer:\n{retrieved_docs}\n\nUser: {query}\nAI:"

    # Generate response
    output = generator(
        prompt,
        max_new_tokens=256,      # cap the completion length; max_length would count prompt tokens too
        do_sample=True,
        temperature=0.7,
        return_full_text=False,  # return only the model's answer, not the prompt and retrieved docs
    )[0]["generated_text"]

    return {"query": query, "response": output}
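
# A minimal sketch for running the server locally; uvicorn is the standard
# ASGI server for FastAPI. The host/port values and the example request
# below are assumptions for illustration, not part of the original code.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is up (assumed port 8000):
#   curl "http://localhost:8000/generate/?query=What%20is%20RAG%3F"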