# Studai / app.py
# Author: WolfeLeo2
# Commit af53a88: "change to fastAPI" (2.52 kB)
import logging

import gradio as gr
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Configure process-wide logging once, then grab this module's logger.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the FLAN-T5 tokenizer and seq2seq model once, at import time; the
# request handlers below reuse these module-level objects.
model_name = "google/flan-t5-base"
logger.info(f"Loading {model_name} model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
logger.info("Model loaded successfully!")
# -----------------------------
# REST API SECTION
# -----------------------------
api = FastAPI()

# Cross-origin policy for browser clients calling the REST API.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive (Starlette's docs say credentialed CORS needs explicit origins);
# confirm whether credentials are needed, and restrict origins before release.
_cors_settings = {
    "allow_origins": ["*"],  # Change to your domain later
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
api.add_middleware(CORSMiddleware, **_cors_settings)
class SummarizeRequest(BaseModel):
    """JSON body accepted by ``POST /summarize``.

    ``max_length`` and ``min_length`` are token budgets for the generated
    summary. ``max_length`` must be at least 1: a non-positive budget is
    rejected here with a validation error instead of failing later inside
    text generation.
    """

    text: str  # Raw text to summarize.
    max_length: int = Field(150, ge=1)  # Upper bound on generated tokens.
    min_length: int = 30  # Lower bound on generated tokens.
@api.post("/summarize")
def summarize_endpoint(request: SummarizeRequest):
    """Summarize ``request.text`` with the loaded model.

    Returns ``{"summary": <str>}``. Texts shorter than 50 characters after
    stripping surrounding whitespace (including empty text) are echoed back
    unchanged instead of being run through the model.

    ``request.max_length`` is capped at 512 and floored at 1 new token;
    ``request.min_length`` is kept within ``[0, max_length - 1]`` so the two
    bounds can never conflict.
    """
    text = request.text.strip()
    # Too short to be worth summarizing: return it as-is.
    if len(text) < 50:
        return {"summary": text}

    logger.info("Summarizing via API. Length: %d", len(text))
    # Task prefix for the FLAN-T5 prompt; overly long inputs are truncated.
    inputs = tokenizer(
        f"summarize: {text}",
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    )

    # Clamp the requested budgets: generate() rejects a non-positive
    # max_new_tokens, and the minimum must stay strictly below the maximum.
    max_tokens = max(1, min(request.max_length, 512))
    min_tokens = max(0, min(request.min_length, max_tokens - 1))

    # Use min_new_tokens (not min_length) so both bounds count only newly
    # generated tokens; min_length also counts the decoder start token.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        min_new_tokens=min_tokens,
    )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"summary": summary}
# -----------------------------
# GRADIO UI SECTION
# -----------------------------
def summarize_text(text, max_length=150, min_length=30):
    """Gradio callback: summarize *text* and return just the summary string.

    Delegates to ``summarize_endpoint`` so the UI and the REST API share one
    code path. Slider values may arrive as floats and an empty textbox may
    arrive as ``None``, so both are normalised before the request model
    (which declares ``int``/``str`` fields) is built.
    """
    request = SummarizeRequest(
        text=text or "",
        max_length=int(max_length),
        min_length=int(min_length),
    )
    return summarize_endpoint(request)["summary"]
# Input widgets, listed in the positional order summarize_text() expects:
# the text, then the max and min length sliders.
_text_box = gr.Textbox(lines=10, label="Text to Summarize")
_max_length_slider = gr.Slider(50, 512, value=150, label="Max Length")
_min_length_slider = gr.Slider(10, 300, value=30, label="Min Length")

demo = gr.Interface(
    fn=summarize_text,
    inputs=[_text_box, _max_length_slider, _min_length_slider],
    outputs=gr.Textbox(label="Summary"),
    title="StudAI Text Summarization",
    description="Powered by google/flan-t5-base model",
)
# Mount Gradio + API: serve the REST API and the Gradio UI from one FastAPI
# app. Gradio is mounted at "/" after the API routes were registered on `api`,
# so POST /summarize (and FastAPI's /docs) match first and every other path
# falls through to the Gradio UI.
app = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    # Blocks.launch() cannot serve an external FastAPI app, so run the
    # combined app directly with uvicorn (a runtime dependency of gradio).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)