flan-t5-custom-handler / app.py.backup

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import gradio as gr

app = FastAPI()

# Initialize model and tokenizer
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
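
# Place the model on GPU when one is available (a minimal sketch; assumes a
# CUDA-capable torch build on GPU machines, and falls back to CPU otherwise)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)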

class Query(BaseModel):
    inputs: str

@app.post("/")
async def generate(query: Query):
try:
# Tokenize input
inputs = tokenizer(query.inputs, return_tensors="pt", max_length=512, truncation=True)
# Generate response
outputs = model.generate(
inputs.input_ids,
max_length=512,
num_beams=4,
temperature=0.7,
top_p=0.9,
repetition_penalty=1.2,
early_stopping=True
)
# Decode response
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"generated_text": response}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# Gradio interface; the callback is async so the FastAPI handler above can be
# awaited directly (Gradio supports async functions)
async def generate_text(prompt):
    query = Query(inputs=prompt)
    response = await generate(query)
    return response["generated_text"]

iface = gr.Interface(
    fn=generate_text,
    inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
    outputs="text",
    title="Medical Assistant",
    description="Ask me anything about medical topics!"
)

# Mount the Gradio UI at the root path; the POST route registered above is
# matched first, so it still serves JSON requests
app = gr.mount_gradio_app(app, iface, path="/")
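
# Example JSON request against the POST endpoint (hypothetical local run on
# uvicorn's default port 8000):
#   curl -X POST http://localhost:8000/ \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "What are common symptoms of dehydration?"}'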

if __name__ == "__main__":
    import train  # This will start the training process
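    # To serve instead of train, a typical entry point would be:
    #   import uvicorn
    #   uvicorn.run(app, host="0.0.0.0", port=7860)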