|
|
import gradio as gr |
|
|
import logging |
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
import torch |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
|
|
|
|
|
# Root logging for the whole service: INFO and above, with timestamp,
# logger name and level in every line.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-scoped logger used by the model loader and the request handlers.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Hugging Face checkpoint shared by the HTTP API and the Gradio UI.
model_name = "google/flan-t5-base"

# Tokenizer and model are created once, at import time, and reused by every
# request that reaches summarize_endpoint.
logger.info(f"Loading {model_name} model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
logger.info("Model loaded successfully!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# JSON API application; the /summarize route is registered on it below.
api = FastAPI()

# Let browser clients on any origin call the API with any method/header.
# NOTE(review): wildcard origins combined with allow_credentials=True means
# Starlette reflects the caller's Origin on credentialed requests, i.e. any
# site may make credentialed calls. Nothing visible in this file reads cookies
# or auth headers, so allow_credentials=False may be sufficient — confirm no
# client relies on credentialed requests before changing it.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
|
class SummarizeRequest(BaseModel):
    """JSON body accepted by ``POST /summarize``.

    Both length fields are requested sizes of the generated summary in
    tokens; the endpoint caps and adjusts them before generation.
    """

    # Text to summarize; the endpoint strips surrounding whitespace.
    text: str
    # Requested upper bound on generated tokens (capped at 512 by the endpoint).
    max_length: int = 150
    # Requested lower bound on generated tokens (kept below the effective maximum).
    min_length: int = 30
|
|
|
|
|
@api.post("/summarize")
def summarize_endpoint(request: SummarizeRequest):
    """Summarize ``request.text`` with the module-level seq2seq model.

    Inputs shorter than 50 characters (after stripping surrounding
    whitespace, including empty input) are returned unchanged.

    Args:
        request: Body holding the text and the requested minimum/maximum
            number of generated summary tokens. Out-of-range bounds are
            clamped rather than rejected.

    Returns:
        A dict of the form ``{"summary": <str>}``.
    """
    text = request.text.strip()

    # Too short to be worth summarizing; len() < 50 also covers "".
    if len(text) < 50:
        return {"summary": text}

    logger.info("Summarizing via API. Length: %d", len(text))

    # Prompt with the task prefix; overly long inputs are cut at 1024 tokens.
    inputs = tokenizer(
        f"summarize: {text}",
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    )

    # Clamp client-supplied bounds into a range generate() accepts:
    # 1..512 new tokens for the maximum, and 0..max-1 for the minimum.
    # Without the lower clamps, max_length <= 0 or min_length < 0 made
    # generate() raise, which surfaced as a 500 error.
    max_tokens = max(1, min(request.max_length, 512))
    min_tokens = max(0, min(request.min_length, max_tokens - 1))

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        # min_new_tokens counts only generated tokens, like max_new_tokens;
        # min_length would count the total decoder sequence length instead.
        min_new_tokens=min_tokens,
    )

    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"summary": summary}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def summarize_text(text, max_length=150, min_length=30):
    """Gradio adapter that reuses the HTTP endpoint's summarization path.

    Builds a SummarizeRequest from the UI values, calls summarize_endpoint,
    and returns only the summary string (not the response dict).
    """
    payload = SummarizeRequest(
        text=text,
        max_length=max_length,
        min_length=min_length,
    )
    response = summarize_endpoint(payload)
    return response["summary"]
|
|
|
|
|
# Browser UI. The three input components map positionally onto
# summarize_text(text, max_length, min_length).
demo = gr.Interface(
    title="StudAI Text Summarization",
    description="Powered by google/flan-t5-base model",
    fn=summarize_text,
    inputs=[
        gr.Textbox(label="Text to Summarize", lines=10),
        gr.Slider(50, 512, label="Max Length", value=150),
        gr.Slider(10, 300, label="Min Length", value=30),
    ],
    outputs=gr.Textbox(label="Summary"),
)
|
|
|
|
|
|
|
|
# Wrapper ASGI app that mounts the JSON API (`api`, which defines
# POST /summarize) at the root path.
# NOTE(review): `app` is only handed to `demo.launch` below; nothing else in
# this file serves it.
app = FastAPI()
app.mount("/", api)

# Start the Gradio UI server on all interfaces, port 7860.
# NOTE(review): recent Gradio releases document no `app` keyword for
# `Blocks.launch()`, so this call likely raises TypeError and /summarize is
# never exposed. A common alternative is
# `app = gr.mount_gradio_app(api, demo, path="/")` served by an ASGI server
# (e.g. uvicorn). Confirm against the installed Gradio version before changing.
demo.launch(server_name="0.0.0.0", server_port=7860, root_path="/", app=app)