# trial2/app.py — FastAPI text-generation service (Hugging Face Space, commit 19cd085)
# by alaa-ahmed14
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
# Point the Transformers cache at a writable directory (HF Spaces' default
# cache path is not writable by the app user).
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME, and is typically read at import time —
# setting it after `import transformers` may have no effect; verify.
os.environ["TRANSFORMERS_CACHE"] = "/home/user/.cache"
# Load the tokenizer and model at import time. This downloads the weights on
# first run and blocks startup until loading completes.
tokenizer = AutoTokenizer.from_pretrained("matsant01/STEMerald-2b")
model = AutoModelForCausalLM.from_pretrained("matsant01/STEMerald-2b")
# Initialize FastAPI app
app = FastAPI()
# Allow cross-origin requests from any origin so the static front-end can
# call the API from a different host/port.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# rejected by browsers for credentialed requests (the CORS spec forbids
# wildcard origins with credentials); harmless here only if no credentials
# are sent — confirm.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Serve the HTML file
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the static front-end page.

    Returns:
        The contents of ``index.html`` (from the working directory) as HTML.
    """
    # Explicit encoding: without it, open() uses the platform default, which
    # can mangle non-ASCII characters in the page on some systems.
    with open("index.html", "r", encoding="utf-8") as f:
        return f.read()
@app.post("/generate/")
async def generate_text(prompt: str):
    """Generate a text continuation for ``prompt`` with the loaded model.

    Args:
        prompt: The input text to continue (passed as a query parameter).

    Returns:
        A dict ``{"generated_text": ...}`` with the decoded model output
        (prompt included, special tokens stripped).

    Raises:
        HTTPException: 400 if the prompt is empty.
    """
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    inputs = tokenizer(prompt, return_tensors="pt")
    # inference_mode: no autograd graph is built, saving memory and time —
    # greedy decoding output is unchanged.
    with torch.inference_mode():
        outputs = model.generate(
            inputs["input_ids"],
            # Pass the attention mask explicitly; omitting it makes
            # transformers guess from pad tokens and emit a warning.
            attention_mask=inputs["attention_mask"],
            max_length=50,  # total length cap (prompt + generated tokens)
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": generated_text}