# Spaces:
# Build error
# Build error
import os

import torch
from flask import Flask, request, render_template
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Initialize the Flask application.
app = Flask(__name__)

# Load the fine-tuned seq2seq model and its tokenizer once, at import time,
# so every request reuses the same objects. The directory can be overridden
# with the MODEL_DIR environment variable; it defaults to the original path.
model_dir = os.environ.get("MODEL_DIR", "./finetune_model")
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
# Run on a GPU when one is available, otherwise stay on the CPU. Inputs are
# moved to model.device at generation time, so they follow automatically.
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
model.eval()  # Inference only: puts layers such as dropout in eval mode.
# Generate a headline from an article.
def generate_headline(article, max_length=128, num_beams=5, max_input_length=256):
    """Return a model-generated headline for *article*.

    Args:
        article: Article text to condense into a headline.
        max_length: Maximum length of the generated sequence, passed to
            ``model.generate``.
        num_beams: Beam-search width, passed to ``model.generate``.
        max_input_length: Articles longer than this many tokens are truncated
            before being fed to the model (previously fixed at 256).

    Returns:
        The decoded headline, with special tokens removed.
    """
    # Tokenize a single sequence without padding. Padding to a fixed length
    # only makes the encoder process filler tokens, which the attention mask
    # then hides again, so it costs compute without changing the result.
    inputs = tokenizer(
        article,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    )
    # Move inputs to the same device as the model.
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
    )
    # Only one article is encoded, so the first returned sequence is the result.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Home route: render the form on GET and generate a headline on POST.
# Without the route decorator Flask never registers this view, and Flask
# routes accept only GET by default, so POST must be listed explicitly.
@app.route("/", methods=["GET", "POST"])
def home():
    """Render ``index.html``, including a generated headline after a POST.

    On POST, the ``article`` form field is read. If it contains any
    non-whitespace text, a headline is generated from it. The template always
    receives ``headline``, which is ``None`` when nothing was generated.
    """
    headline = None
    if request.method == "POST":
        # Treat a missing or whitespace-only field as "no article" rather than
        # asking the model to produce a headline from nothing.
        article = (request.form.get("article") or "").strip()
        if article:
            headline = generate_headline(article)
    return render_template("index.html", headline=headline)
# Run the development server when this file is executed directly.
if __name__ == "__main__":
    # Do not enable debug mode by default. With host 0.0.0.0 the Werkzeug
    # interactive debugger would be reachable from the network, and it allows
    # arbitrary code execution. Opt in with FLASK_DEBUG=1 for local work only.
    debug = os.environ.get("FLASK_DEBUG", "0").strip().lower() in ("1", "true", "yes")
    port = int(os.environ.get("PORT", "5000"))
    app.run(host="0.0.0.0", port=port, debug=debug)