Rudrameher45 commited on
Commit
f520122
·
verified ·
1 Parent(s): 6fa95ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -43
app.py CHANGED
@@ -1,47 +1,28 @@
1
- from flask import Flask, request, render_template
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- import torch
4
 
5
  # Initialize Flask app
6
  app = Flask(__name__)
7
 
8
- # Load the fine-tuned model and tokenizer
9
- model_dir = "./finetune_model"
10
- tokenizer = AutoTokenizer.from_pretrained(model_dir)
11
-
12
-
13
-
14
- model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
15
- model.eval() # Set model to evaluation mode
16
-
17
- # Generate headline from article
18
- def generate_headline(article, max_length=128, num_beams=5):
19
- # Tokenize the input article
20
- inputs = tokenizer(article, max_length=256, truncation=True, return_tensors="pt", padding="max_length")
21
-
22
- # Move inputs to the same device as the model
23
- input_ids = inputs['input_ids'].to(model.device)
24
- attention_mask = inputs['attention_mask'].to(model.device)
25
-
26
- # Generate headline
27
- outputs = model.generate(input_ids, attention_mask=attention_mask, max_length=max_length, num_beams=num_beams, early_stopping=True)
28
- headline = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
-
30
- return headline
31
-
32
- # Home route to render the form and handle POST requests
33
- @app.route('/', methods=['GET', 'POST'])
34
- def home():
35
- headline = None
36
- if request.method == 'POST':
37
- article = request.form.get('article')
38
- if article:
39
- headline = generate_headline(article)
40
- return render_template('index.html', headline=headline)
41
-
42
- # Run the app
43
- if __name__ == '__main__':
44
- app.run(app.run(port=5001, debug=False))
45
-
46
-
47
-
 
1
+ from flask import Flask, request, jsonify, render_template
2
+ from transformers import pipeline
 
3
 
4
  # Initialize Flask app
5
  app = Flask(__name__)
6
 
7
+ # Load the fine-tuned model
8
+ model_path = "./finetune_model"
9
+ headline_generator = pipeline("text2text-generation", model=model_path)
10
+
11
+ @app.route('/')
12
+ def index():
13
+ return render_template('index.html')
14
+
15
+ @app.route('/predict', methods=['POST'])
16
+ def predict():
17
+ article = request.form.get('article')
18
+ if not article:
19
+ return jsonify({"error": "No article provided"}), 400
20
+ try:
21
+ prediction = headline_generator(article, max_length=50, num_return_sequences=1)
22
+ headline = prediction[0]['generated_text']
23
+ return jsonify({"headline": headline})
24
+ except Exception as e:
25
+ return jsonify({"error": str(e)}), 500
26
+
27
+ if __name__ == "__main__":
28
+ app.run(debug=True)