import os
import pdfplumber
from PIL import Image
import pytesseract
import numpy as np
from flask import Flask, request, jsonify
from flask_cors import CORS
import transformers
from transformers import PegasusForConditionalGeneration, PegasusTokenizer, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset, concatenate_datasets
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from werkzeug.utils import secure_filename
app = Flask(__name__)
CORS(app)
UPLOAD_FOLDER = os.path.join(os.getcwd(), 'uploads')
PEGASUS_MODEL_DIR = 'fine_tuned_pegasus'
BERT_MODEL_DIR = 'fine_tuned_bert'
LEGALBERT_MODEL_DIR = 'fine_tuned_legalbert'
MAX_FILE_SIZE = 100 * 1024 * 1024
if not os.path.exists(UPLOAD_FOLDER):
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
transformers.logging.set_verbosity_error()
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
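# Overview: on startup the app loads (or fine-tunes and caches) three summarization
# models -- Pegasus for abstractive summaries, BERT and LegalBERT as extractive
# sentence classifiers -- then serves a single /summarize endpoint that extracts
# text from an uploaded PDF or image, routes it to one of the models, and returns
# the summary as JSON.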
# Pegasus Fine-Tuning
def load_or_finetune_pegasus():
    if os.path.exists(PEGASUS_MODEL_DIR):
        print("Loading fine-tuned Pegasus model...")
        tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_MODEL_DIR)
        model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_MODEL_DIR)
    else:
        print("Fine-tuning Pegasus on CNN/Daily Mail and XSUM...")
        tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
        model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
        # Load and normalize datasets so both expose 'text' and 'summary' columns
        cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:5000]").rename_column("article", "text").rename_column("highlights", "summary")
        xsum = load_dataset("xsum", split="train[:5000]", trust_remote_code=True).rename_column("document", "text")
        combined_dataset = concatenate_datasets([cnn_dm, xsum])

        def preprocess_function(examples):
            # Directly use normalized 'text' and 'summary' fields
            inputs = tokenizer(examples["text"], max_length=512, truncation=True, padding="max_length", return_tensors="pt")
            targets = tokenizer(examples["summary"], max_length=400, truncation=True, padding="max_length", return_tensors="pt")
            inputs["labels"] = targets["input_ids"]
            return inputs

        tokenized_dataset = combined_dataset.map(preprocess_function, batched=True)
        # 80/20 train/eval split over the 10,000 combined examples
        train_dataset = tokenized_dataset.select(range(8000))
        eval_dataset = tokenized_dataset.select(range(8000, 10000))
        training_args = TrainingArguments(
            output_dir="./pegasus_finetune",
            num_train_epochs=3,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        trainer.train()
        trainer.save_model(PEGASUS_MODEL_DIR)
        tokenizer.save_pretrained(PEGASUS_MODEL_DIR)
        print(f"Fine-tuned Pegasus saved to {PEGASUS_MODEL_DIR}")
    return tokenizer, model
# BERT Fine-Tuning
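# Extractive setup: CNN/Daily Mail articles are split into sentences and each
# sentence is labeled 1 if it also appears inside one of the reference highlights,
# otherwise 0. BERT is then fine-tuned as a binary sentence classifier; at inference
# time the highest-scoring sentences form the summary.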
def load_or_finetune_bert():
    if os.path.exists(BERT_MODEL_DIR):
        print("Loading fine-tuned BERT model...")
        tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_DIR)
        model = BertForSequenceClassification.from_pretrained(BERT_MODEL_DIR, num_labels=2)
    else:
        print("Fine-tuning BERT on CNN/Daily Mail for extractive summarization...")
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
        cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:5000]")

        def preprocess_for_extractive(examples):
            sentences = []
            labels = []
            for article, highlights in zip(examples["article"], examples["highlights"]):
                article_sents = article.split(". ")
                highlight_sents = highlights.split(". ")
                for sent in article_sents:
                    if sent.strip():
                        is_summary = any(sent.strip() in h for h in highlight_sents)
                        sentences.append(sent)
                        labels.append(1 if is_summary else 0)
            return {"sentence": sentences, "label": labels}

        dataset = cnn_dm.map(preprocess_for_extractive, batched=True, remove_columns=["article", "highlights", "id"])
        tokenized_dataset = dataset.map(
            lambda x: tokenizer(x["sentence"], max_length=512, truncation=True, padding="max_length"),
            batched=True
        )
        tokenized_dataset = tokenized_dataset.remove_columns(["sentence"])
        train_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset))))
        eval_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset)), len(tokenized_dataset)))
        training_args = TrainingArguments(
            output_dir="./bert_finetune",
            num_train_epochs=3,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        trainer.train()
        trainer.save_model(BERT_MODEL_DIR)
        tokenizer.save_pretrained(BERT_MODEL_DIR)
        print(f"Fine-tuned BERT saved to {BERT_MODEL_DIR}")
    return tokenizer, model
# LegalBERT Fine-Tuning
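# Same extractive labeling heuristic as above, applied to BillSum bill texts and
# their reference summaries, starting from nlpaueb/legal-bert-base-uncased.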
def load_or_finetune_legalbert():
    if os.path.exists(LEGALBERT_MODEL_DIR):
        print("Loading fine-tuned LegalBERT model...")
        tokenizer = BertTokenizer.from_pretrained(LEGALBERT_MODEL_DIR)
        model = BertForSequenceClassification.from_pretrained(LEGALBERT_MODEL_DIR, num_labels=2)
    else:
        print("Fine-tuning LegalBERT on BillSum for extractive summarization...")
        tokenizer = BertTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
        model = BertForSequenceClassification.from_pretrained("nlpaueb/legal-bert-base-uncased", num_labels=2)
        billsum = load_dataset("billsum", split="train[:5000]")

        def preprocess_for_extractive(examples):
            sentences = []
            labels = []
            for text, summary in zip(examples["text"], examples["summary"]):
                text_sents = text.split(". ")
                summary_sents = summary.split(". ")
                for sent in text_sents:
                    if sent.strip():
                        is_summary = any(sent.strip() in s for s in summary_sents)
                        sentences.append(sent)
                        labels.append(1 if is_summary else 0)
            return {"sentence": sentences, "label": labels}

        dataset = billsum.map(preprocess_for_extractive, batched=True, remove_columns=["text", "summary", "title"])
        tokenized_dataset = dataset.map(
            lambda x: tokenizer(x["sentence"], max_length=512, truncation=True, padding="max_length"),
            batched=True
        )
        tokenized_dataset = tokenized_dataset.remove_columns(["sentence"])
        train_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset))))
        eval_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset)), len(tokenized_dataset)))
        training_args = TrainingArguments(
            output_dir="./legalbert_finetune",
            num_train_epochs=3,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
        trainer.train()
        trainer.save_model(LEGALBERT_MODEL_DIR)
        tokenizer.save_pretrained(LEGALBERT_MODEL_DIR)
        print(f"Fine-tuned LegalBERT saved to {LEGALBERT_MODEL_DIR}")
    return tokenizer, model
# Load models
pegasus_tokenizer, pegasus_model = load_or_finetune_pegasus()
bert_tokenizer, bert_model = load_or_finetune_bert()
legalbert_tokenizer, legalbert_model = load_or_finetune_legalbert()
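# Note: the three loaders above run when the module is loaded, so the first startup
# without cached model directories triggers full fine-tuning before the server can
# accept requests; later startups only load the saved models from disk.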
def extract_text_from_pdf(file_path):
    text = ""
    with pdfplumber.open(file_path) as pdf:
        for page in pdf.pages:
            text += page.extract_text() or ""
    return text
def extract_text_from_image(file_path):
    image = Image.open(file_path)
    text = pytesseract.image_to_string(image)
    return text
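# pdfplumber only reads the embedded text layer (pages without one contribute an
# empty string), so scanned documents should be uploaded as PNG/JPEG images, which
# are run through Tesseract OCR instead.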
def choose_model(text):
    legal_keywords = ["court", "legal", "law", "judgment", "contract", "statute", "case"]
    tfidf = TfidfVectorizer(vocabulary=legal_keywords)
    tfidf_matrix = tfidf.fit_transform([text.lower()])
    score = np.sum(tfidf_matrix.toarray())
    if score > 0.1:
        return "legalbert"
    elif len(text.split()) > 50:
        return "pegasus"
    else:
        return "bert"
def summarize_with_pegasus(text):
    inputs = pegasus_tokenizer(text, truncation=True, padding="longest", return_tensors="pt", max_length=512)
    summary_ids = pegasus_model.generate(
        inputs["input_ids"],
        max_length=400, min_length=80, length_penalty=1.5, num_beams=4
    )
    return pegasus_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
def summarize_with_bert(text):
    sentences = text.split(". ")
    if len(sentences) < 6:
        return text
    inputs = bert_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = bert_model(**inputs)
    logits = outputs.logits
    probs = torch.softmax(logits, dim=1)[:, 1]
    key_sentence_idx = probs.argsort(descending=True)[:5]
    return ". ".join([sentences[idx] for idx in key_sentence_idx if sentences[idx].strip()])
def summarize_with_legalbert(text):
    sentences = text.split(". ")
    if len(sentences) < 6:
        return text
    inputs = legalbert_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = legalbert_model(**inputs)
    logits = outputs.logits
    probs = torch.softmax(logits, dim=1)[:, 1]
    key_sentence_idx = probs.argsort(descending=True)[:5]
    return ". ".join([sentences[idx] for idx in key_sentence_idx if sentences[idx].strip()])
@app.route('/summarize', methods=['POST'])
def summarize_document():
    if 'file' not in request.files:
        return jsonify({"error": "No file uploaded"}), 400
    file = request.files['file']
    # Sanitize the client-supplied name before using it as a path component
    filename = secure_filename(file.filename)
    if not filename:
        return jsonify({"error": "Invalid or missing filename"}), 400
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    if file_size > MAX_FILE_SIZE:
        return jsonify({"error": f"File size exceeds {MAX_FILE_SIZE // (1024 * 1024)} MB"}), 413
    file.seek(0)
    file_path = os.path.join(UPLOAD_FOLDER, filename)
    try:
        file.save(file_path)
    except Exception as e:
        return jsonify({"error": f"Failed to save file: {str(e)}"}), 500
    try:
        if filename.lower().endswith('.pdf'):
            text = extract_text_from_pdf(file_path)
        elif filename.lower().endswith(('.png', '.jpeg', '.jpg')):
            text = extract_text_from_image(file_path)
        else:
            os.remove(file_path)
            return jsonify({"error": "Unsupported file format."}), 400
    except Exception as e:
        os.remove(file_path)
        return jsonify({"error": f"Text extraction failed: {str(e)}"}), 500
    if not text.strip():
        os.remove(file_path)
        return jsonify({"error": "No text extracted"}), 400
    try:
        model = choose_model(text)
        if model == "pegasus":
            summary = summarize_with_pegasus(text)
        elif model == "bert":
            summary = summarize_with_bert(text)
        elif model == "legalbert":
            summary = summarize_with_legalbert(text)
    except Exception as e:
        os.remove(file_path)
        return jsonify({"error": f"Summarization failed: {str(e)}"}), 500
    os.remove(file_path)
    return jsonify({"model_used": model, "summary": summary})
if __name__ == '__main__':
    port = int(os.environ.get("PORT", 5000))
    app.run(debug=False, host='0.0.0.0', port=port)