"""Text summarization service: a FastAPI REST endpoint plus a Gradio UI.

Both front ends share one FLAN-T5 model that is loaded once at import time.
Run directly (``python <this file>``) to serve the UI at ``/`` and the REST
endpoint at ``POST /summarize`` on port 7860.
"""

import logging

import gradio as gr
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Inputs shorter than this many characters are returned unchanged.
MIN_INPUT_CHARS = 50
# The input prompt is truncated to this many tokens before generation.
MAX_INPUT_TOKENS = 1024
# Hard upper bound on generated tokens, regardless of what the caller asks for.
MAX_SUMMARY_TOKENS = 512

# Load FLAN-T5 model
model_name = "google/flan-t5-base"
logger.info("Loading %s model...", model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
logger.info("Model loaded successfully!")

# -----------------------------
# REST API SECTION
# -----------------------------
api = FastAPI()

api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change to your domain later
    # A wildcard origin must not be combined with credentialed requests; this
    # endpoint uses no cookies or auth headers, so credentials stay disabled.
    # Re-enable only together with an explicit origin list.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


class SummarizeRequest(BaseModel):
    """Request body for ``POST /summarize``."""

    # Raw text to summarize; surrounding whitespace is stripped by the handler.
    text: str
    # Requested upper bound on generated tokens (capped server-side).
    max_length: int = 150
    # Requested lower bound on summary length (kept below the upper bound).
    min_length: int = 30


@api.post("/summarize")
def summarize_endpoint(request: SummarizeRequest):
    """Summarize ``request.text`` with the FLAN-T5 model.

    Args:
        request: Text plus the requested length bounds.

    Returns:
        A dict ``{"summary": <str>}``. Text shorter than ``MIN_INPUT_CHARS``
        characters (after stripping) is returned as-is without running the
        model.
    """
    text = request.text.strip()
    if len(text) < MIN_INPUT_CHARS:
        return {"summary": text}

    logger.info("Summarizing via API. Length: %d", len(text))

    input_text = f"summarize: {text}"
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
    )

    # Safe dynamic length handling: keep 1 <= max_tokens <= cap and
    # 0 <= min_tokens < max_tokens so zero/negative or inverted client values
    # cannot produce invalid generation arguments.
    max_tokens = max(1, min(request.max_length, MAX_SUMMARY_TOKENS))
    min_tokens = max(0, min(request.min_length, max_tokens - 1))

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        min_length=min_tokens,
    )

    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"summary": summary}


# -----------------------------
# GRADIO UI SECTION
# -----------------------------
def summarize_text(text: str, max_length: int = 150, min_length: int = 30) -> str:
    """Gradio callback: summarize ``text`` by delegating to the REST handler.

    Slider values are coerced to ``int`` and a missing textbox value is
    treated as empty text so request-model validation cannot fail on UI input.
    """
    request = SummarizeRequest(
        text=text or "",
        max_length=int(max_length),
        min_length=int(min_length),
    )
    return summarize_endpoint(request)["summary"]


demo = gr.Interface(
    fn=summarize_text,
    inputs=[
        gr.Textbox(lines=10, label="Text to Summarize"),
        gr.Slider(50, 512, value=150, step=1, label="Max Length"),
        gr.Slider(10, 300, value=30, step=1, label="Min Length"),
    ],
    outputs=gr.Textbox(label="Summary"),
    title="StudAI Text Summarization",
    description="Powered by google/flan-t5-base model",
)

# Mount Gradio + API.
# The /summarize route is registered on `api` before the UI is mounted at "/",
# so it is matched first and the Gradio app handles every other path.
app = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    # Imported here so that importing this module (e.g. `uvicorn module:app`)
    # does not start a second server. uvicorn is the ASGI server Gradio itself
    # depends on.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)