import os

# Set the Hugging Face cache directory before importing transformers,
# since the library reads this variable at import time.
os.environ["TRANSFORMERS_CACHE"] = "/home/user/.cache"

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and model once at startup so all requests reuse them.
tokenizer = AutoTokenizer.from_pretrained("matsant01/STEMerald-2b")
model = AutoModelForCausalLM.from_pretrained("matsant01/STEMerald-2b")
model.eval()

# Initialize the FastAPI app and allow cross-origin requests from any host.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Serve the HTML front end at the root path.
@app.get("/", response_class=HTMLResponse)
async def read_root():
    with open("index.html", "r") as f:
        return f.read()


@app.post("/generate/")
async def generate_text(prompt: str):
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    # Tokenize the prompt and generate without tracking gradients;
    # passing the attention mask avoids a transformers warning.
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=50,
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": generated_text}
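
# --- Usage sketch ---
# A minimal way to run and exercise this service, assuming the file is saved
# as main.py and uvicorn is installed (both names are illustrative, not
# specified in the original source).
#
# Start the server:
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Query the generation endpoint. Because the handler declares a bare `str`
# parameter, FastAPI treats `prompt` as a query parameter rather than a
# JSON body:
#   curl -X POST "http://localhost:8000/generate/?prompt=What%20is%20entropy%3F"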