import os

from fastapi import FastAPI, Query
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch

from retriever import retrieve_documents

# Set writable cache location
#os.environ["HF_HOME"] = "/tmp/huggingface"
#os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"

# Model selection (Mistral 7B kept for reference; Phi-4 Mini is currently active)
#MODEL_NAME = "mistralai/Mistral-7B-v0.1"
MODEL_NAME = "microsoft/Phi-4-mini-instruct"

# Load tokenizer and model, authenticating with the Hugging Face Hub token if set
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
    cache_dir="/tmp/huggingface",
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
    cache_dir="/tmp/huggingface",
    device_map="auto",
    torch_dtype=torch.float16,
)

# Create inference pipeline
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

# FastAPI server
app = FastAPI()


@app.get("/")
def read_root():
    return {"message": "Phi-4 Mini RAG API is running!"}


@app.get("/generate/")
def generate_response(query: str = Query(..., title="User Query")):
    # Retrieve relevant documents
    retrieved_docs = retrieve_documents(query)

    # Format prompt for RAG
    prompt = f"Use the following information to answer:\n{retrieved_docs}\n\nUser: {query}\nAI:"

    # Generate response; max_new_tokens bounds only the completion (max_length would
    # include the prompt, which can exceed 256 tokens once retrieved docs are added),
    # and return_full_text=False strips the prompt from the returned string
    output = generator(
        prompt,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        return_full_text=False,
    )[0]["generated_text"]

    return {"query": query, "response": output}
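

# --- Optional entry point (not part of the original snippet) ---
# A minimal sketch of how this server might be launched directly; the host/port
# values are assumptions. In practice the app can also be started from the shell,
# e.g. `uvicorn main:app --host 0.0.0.0 --port 8000` (module name assumed).
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app on all interfaces; adjust host/port as needed
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is running (assumed port 8000):
#   curl "http://localhost:8000/generate/?query=What%20is%20RAG%3F"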