import os
from fastapi import FastAPI, Query
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from retriever import retrieve_documents

# Set writable cache location
#os.environ["HF_HOME"] = "/tmp/huggingface"
#os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"

# Load the model (Mistral 7B left commented out in favor of Phi-4 Mini Instruct)
#MODEL_NAME = "mistralai/Mistral-7B-v0.1"
MODEL_NAME = "microsoft/Phi-4-mini-instruct"

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, token=os.getenv("HUGGING_FACE_HUB_TOKEN"), cache_dir="/tmp/huggingface"
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=os.getenv("HUGGING_FACE_HUB_TOKEN"),
    cache_dir="/tmp/huggingface",
    device_map="auto",
    torch_dtype=torch.float16,
)

# Create inference pipeline
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

# FastAPI server
app = FastAPI()

@app.get("/")
def read_root():
    return {"message": "Phi-4 Mini RAG API is running!"}

@app.get("/generate")  # route path assumed; adjust to match your deployment
def generate_response(query: str = Query(..., title="User Query")):
    # Retrieve relevant documents
    retrieved_docs = retrieve_documents(query)

    # Format prompt for RAG
    prompt = f"Use the following information to answer:\n{retrieved_docs}\n\nUser: {query}\nAI:"

    # Generate response; max_new_tokens bounds only the newly generated tokens,
    # and return_full_text=False keeps the prompt out of the returned text
    output = generator(prompt, max_new_tokens=256, do_sample=True, temperature=0.7,
                       return_full_text=False)[0]["generated_text"]

    return {"query": query, "response": output}