maahi2412 committed
Commit ed7a266 · verified
1 Parent(s): 09a3a4e

Upload 3 files

Files changed (3)
  1. app.py +323 -0
  2. dockerfile +43 -0
  3. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,323 @@
+ import os
+ import pdfplumber
+ from PIL import Image
+ import pytesseract
+ import numpy as np
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
+ import transformers  # full import so library-wide logging can be configured
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer, BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
+ from datasets import load_dataset, concatenate_datasets
+ import torch
+ from sklearn.feature_extraction.text import TfidfVectorizer
+
+ app = Flask(__name__)
+ CORS(app)  # enable CORS for frontend compatibility
+ UPLOAD_FOLDER = os.path.join(os.getcwd(), 'uploads')
+ PEGASUS_MODEL_DIR = 'fine_tuned_pegasus'
+ BERT_MODEL_DIR = 'fine_tuned_bert'
+ LEGALBERT_MODEL_DIR = 'fine_tuned_legalbert'
+ MAX_FILE_SIZE = 100 * 1024 * 1024  # 100 MB limit
+
+ # Ensure upload folder exists
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+
+ transformers.logging.set_verbosity_error()  # suppress transformers warnings
+ os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
+
+ # Pegasus fine-tuning (abstractive summarization)
+ def load_or_finetune_pegasus():
+     if os.path.exists(PEGASUS_MODEL_DIR):
+         print("Loading fine-tuned Pegasus model...")
+         tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_MODEL_DIR)
+         model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_MODEL_DIR)
+     else:
+         print("Fine-tuning Pegasus on CNN/Daily Mail and XSum...")
+         tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
+         model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
+
+         cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:5000]")
+         xsum = load_dataset("xsum", split="train[:5000]")
+         # Align XSum's column names (document/summary) with CNN/Daily Mail
+         # (article/highlights); concatenate_datasets requires matching schemas.
+         xsum = xsum.rename_columns({"document": "article", "summary": "highlights"})
+         combined_dataset = concatenate_datasets([cnn_dm, xsum])
+
+         def preprocess_function(examples):
+             inputs = tokenizer(examples["article"],
+                                max_length=512, truncation=True, padding="max_length")
+             targets = tokenizer(examples["highlights"],
+                                 max_length=400, truncation=True, padding="max_length")
+             inputs["labels"] = targets["input_ids"]
+             return inputs
+
+         tokenized_dataset = combined_dataset.map(preprocess_function, batched=True)
+         train_dataset = tokenized_dataset.select(range(8000))
+         eval_dataset = tokenized_dataset.select(range(8000, 10000))
+
+         training_args = TrainingArguments(
+             output_dir="./pegasus_finetune",
+             num_train_epochs=3,
+             per_device_train_batch_size=1,
+             per_device_eval_batch_size=1,
+             warmup_steps=500,
+             weight_decay=0.01,
+             logging_dir="./logs",
+             logging_steps=10,
+             eval_strategy="epoch",
+             save_strategy="epoch",
+             load_best_model_at_end=True,
+         )
+
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+         )
+
+         trainer.train()
+         trainer.save_model(PEGASUS_MODEL_DIR)
+         tokenizer.save_pretrained(PEGASUS_MODEL_DIR)
+         print(f"Fine-tuned Pegasus saved to {PEGASUS_MODEL_DIR}")
+
+     return tokenizer, model
+
+ # BERT fine-tuning (extractive summarization)
+ def load_or_finetune_bert():
+     if os.path.exists(BERT_MODEL_DIR):
+         print("Loading fine-tuned BERT model...")
+         tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_DIR)
+         model = BertForSequenceClassification.from_pretrained(BERT_MODEL_DIR, num_labels=2)
+     else:
+         print("Fine-tuning BERT on CNN/Daily Mail for extractive summarization...")
+         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+         model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
+
+         cnn_dm = load_dataset("cnn_dailymail", "3.0.0", split="train[:5000]")
+
+         def preprocess_for_extractive(examples):
+             # Label each article sentence 1 if it appears in the reference highlights, else 0.
+             sentences = []
+             labels = []
+             for article, highlights in zip(examples["article"], examples["highlights"]):
+                 article_sents = article.split(". ")
+                 highlight_sents = highlights.split(". ")
+                 for sent in article_sents:
+                     if sent.strip():
+                         is_summary = any(sent.strip() in h for h in highlight_sents)
+                         sentences.append(sent)
+                         labels.append(1 if is_summary else 0)
+             return {"sentence": sentences, "label": labels}
+
+         dataset = cnn_dm.map(preprocess_for_extractive, batched=True, remove_columns=["article", "highlights", "id"])
+         tokenized_dataset = dataset.map(
+             lambda x: tokenizer(x["sentence"], max_length=512, truncation=True, padding="max_length"),
+             batched=True
+         )
+         tokenized_dataset = tokenized_dataset.remove_columns(["sentence"])
+         train_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset))))
+         eval_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset)), len(tokenized_dataset)))
+
+         training_args = TrainingArguments(
+             output_dir="./bert_finetune",
+             num_train_epochs=3,
+             per_device_train_batch_size=8,
+             per_device_eval_batch_size=8,
+             warmup_steps=500,
+             weight_decay=0.01,
+             logging_dir="./logs",
+             logging_steps=10,
+             eval_strategy="epoch",
+             save_strategy="epoch",
+             load_best_model_at_end=True,
+         )
+
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+         )
+
+         trainer.train()
+         trainer.save_model(BERT_MODEL_DIR)
+         tokenizer.save_pretrained(BERT_MODEL_DIR)
+         print(f"Fine-tuned BERT saved to {BERT_MODEL_DIR}")
+
+     return tokenizer, model
+
+ # LegalBERT fine-tuning (extractive summarization of legal text)
+ def load_or_finetune_legalbert():
+     if os.path.exists(LEGALBERT_MODEL_DIR):
+         print("Loading fine-tuned LegalBERT model...")
+         tokenizer = BertTokenizer.from_pretrained(LEGALBERT_MODEL_DIR)
+         model = BertForSequenceClassification.from_pretrained(LEGALBERT_MODEL_DIR, num_labels=2)
+     else:
+         print("Fine-tuning LegalBERT on BillSum for extractive summarization...")
+         tokenizer = BertTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
+         model = BertForSequenceClassification.from_pretrained("nlpaueb/legal-bert-base-uncased", num_labels=2)
+
+         billsum = load_dataset("billsum", split="train[:5000]")
+
+         def preprocess_for_extractive(examples):
+             # Label each bill sentence 1 if it appears in the reference summary, else 0.
+             sentences = []
+             labels = []
+             for text, summary in zip(examples["text"], examples["summary"]):
+                 text_sents = text.split(". ")
+                 summary_sents = summary.split(". ")
+                 for sent in text_sents:
+                     if sent.strip():
+                         is_summary = any(sent.strip() in s for s in summary_sents)
+                         sentences.append(sent)
+                         labels.append(1 if is_summary else 0)
+             return {"sentence": sentences, "label": labels}
+
+         dataset = billsum.map(preprocess_for_extractive, batched=True, remove_columns=["text", "summary", "title"])
+         tokenized_dataset = dataset.map(
+             lambda x: tokenizer(x["sentence"], max_length=512, truncation=True, padding="max_length"),
+             batched=True
+         )
+         tokenized_dataset = tokenized_dataset.remove_columns(["sentence"])
+         train_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset))))
+         eval_dataset = tokenized_dataset.select(range(int(0.8 * len(tokenized_dataset)), len(tokenized_dataset)))
+
+         training_args = TrainingArguments(
+             output_dir="./legalbert_finetune",
+             num_train_epochs=3,
+             per_device_train_batch_size=8,
+             per_device_eval_batch_size=8,
+             warmup_steps=500,
+             weight_decay=0.01,
+             logging_dir="./logs",
+             logging_steps=10,
+             eval_strategy="epoch",
+             save_strategy="epoch",
+             load_best_model_at_end=True,
+         )
+
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+         )
+
+         trainer.train()
+         trainer.save_model(LEGALBERT_MODEL_DIR)
+         tokenizer.save_pretrained(LEGALBERT_MODEL_DIR)
+         print(f"Fine-tuned LegalBERT saved to {LEGALBERT_MODEL_DIR}")
+
+     return tokenizer, model
+
+ # Load (or fine-tune) all three models at startup
+ pegasus_tokenizer, pegasus_model = load_or_finetune_pegasus()
+ bert_tokenizer, bert_model = load_or_finetune_bert()
+ legalbert_tokenizer, legalbert_model = load_or_finetune_legalbert()
+
+ def extract_text_from_pdf(file_path):
+     text = ""
+     with pdfplumber.open(file_path) as pdf:
+         for page in pdf.pages:
+             text += page.extract_text() or ""
+     return text
+
+ def extract_text_from_image(file_path):
+     image = Image.open(file_path)
+     text = pytesseract.image_to_string(image)
+     return text
+
+ def choose_model(text):
+     # Route legal-sounding text to LegalBERT, long text to Pegasus, everything else to BERT.
+     legal_keywords = ["court", "legal", "law", "judgment", "contract", "statute", "case"]
+     tfidf = TfidfVectorizer(vocabulary=legal_keywords)
+     tfidf_matrix = tfidf.fit_transform([text.lower()])
+     score = np.sum(tfidf_matrix.toarray())
+     if score > 0.1:
+         return "legalbert"
+     elif len(text.split()) > 50:
+         return "pegasus"
+     else:
+         return "bert"
+
+ def summarize_with_pegasus(text):
+     inputs = pegasus_tokenizer(text, truncation=True, padding="longest", return_tensors="pt", max_length=512)
+     summary_ids = pegasus_model.generate(
+         inputs["input_ids"],
+         max_length=400, min_length=80, length_penalty=1.5, num_beams=4
+     )
+     return pegasus_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+ def summarize_with_bert(text):
+     sentences = text.split(". ")
+     if len(sentences) < 6:
+         return text
+     inputs = bert_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
+     with torch.no_grad():
+         outputs = bert_model(**inputs)
+     logits = outputs.logits
+     probs = torch.softmax(logits, dim=1)[:, 1]
+     key_sentence_idx = probs.argsort(descending=True)[:5]
+     return ". ".join([sentences[idx] for idx in key_sentence_idx if sentences[idx].strip()])
+
+ def summarize_with_legalbert(text):
+     sentences = text.split(". ")
+     if len(sentences) < 6:
+         return text
+     inputs = legalbert_tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
+     with torch.no_grad():
+         outputs = legalbert_model(**inputs)
+     logits = outputs.logits
+     probs = torch.softmax(logits, dim=1)[:, 1]
+     key_sentence_idx = probs.argsort(descending=True)[:5]
+     return ". ".join([sentences[idx] for idx in key_sentence_idx if sentences[idx].strip()])
+
+ @app.route('/summarize', methods=['POST'])
+ def summarize_document():
+     if 'file' not in request.files:
+         return jsonify({"error": "No file uploaded"}), 400
+
+     file = request.files['file']
+     filename = file.filename
+     file.seek(0, os.SEEK_END)
+     file_size = file.tell()
+     if file_size > MAX_FILE_SIZE:
+         return jsonify({"error": f"File size exceeds {MAX_FILE_SIZE // (1024 * 1024)} MB"}), 413
+     file.seek(0)
+     file_path = os.path.join(UPLOAD_FOLDER, filename)
+     try:
+         file.save(file_path)
+     except Exception as e:
+         return jsonify({"error": f"Failed to save file: {str(e)}"}), 500
+
+     try:
+         if filename.endswith('.pdf'):
+             text = extract_text_from_pdf(file_path)
+         elif filename.endswith(('.png', '.jpeg', '.jpg')):
+             text = extract_text_from_image(file_path)
+         else:
+             os.remove(file_path)
+             return jsonify({"error": "Unsupported file format."}), 400
+     except Exception as e:
+         os.remove(file_path)
+         return jsonify({"error": f"Text extraction failed: {str(e)}"}), 500
+
+     if not text.strip():
+         os.remove(file_path)
+         return jsonify({"error": "No text extracted"}), 400
+
+     try:
+         model = choose_model(text)
+         if model == "pegasus":
+             summary = summarize_with_pegasus(text)
+         elif model == "bert":
+             summary = summarize_with_bert(text)
+         else:  # "legalbert"
+             summary = summarize_with_legalbert(text)
+     except Exception as e:
+         os.remove(file_path)
+         return jsonify({"error": f"Summarization failed: {str(e)}"}), 500
+
+     os.remove(file_path)
+     return jsonify({"model_used": model, "summary": summary})
+
+ if __name__ == '__main__':
+     port = int(os.environ.get("PORT", 5000))  # use PORT env var if set by Hugging Face
+     app.run(debug=False, host='0.0.0.0', port=port)
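
A minimal client sketch for the /summarize route above, assuming the app is reachable on localhost:5000; the requests library is not in requirements.txt and the file name is illustrative:

# Hypothetical client for the /summarize endpoint defined in app.py.
# Assumes the service runs locally on port 5000 and `requests` is installed.
import requests

with open("sample.pdf", "rb") as f:  # illustrative input file
    resp = requests.post(
        "http://localhost:5000/summarize",
        files={"file": ("sample.pdf", f, "application/pdf")},
    )

resp.raise_for_status()
payload = resp.json()
print("Model used:", payload["model_used"])
print("Summary:", payload["summary"])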
dockerfile ADDED
@@ -0,0 +1,43 @@
+ # Use an official Python runtime as the base image
+ FROM python:3.8-slim
+
+ # Set working directory
+ WORKDIR /app
+
+ # Install system dependencies for pdfplumber, pytesseract, and general compatibility
+ RUN apt-get update && apt-get install -y \
+     tesseract-ocr \
+     libtesseract-dev \
+     poppler-utils \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy application code
+ COPY . /app
+
+ # Install Python dependencies, including sentencepiece for Pegasus
+ RUN pip install --no-cache-dir \
+     flask \
+     flask-cors \
+     pdfplumber \
+     pillow \
+     pytesseract \
+     numpy \
+     torch \
+     transformers \
+     datasets \
+     scikit-learn \
+     gunicorn \
+     sentencepiece
+
+ # Create uploads and cache directories with proper permissions
+ RUN mkdir -p /app/uploads /app/cache && \
+     chmod -R 777 /app/uploads /app/cache
+
+ # Set environment variable for Hugging Face cache (using HF_HOME as per latest transformers recommendation)
+ ENV HF_HOME=/app/cache
+
+ # Expose port (Hugging Face Spaces typically uses 7860, but we'll stick to 5000 and adjust in app.py if needed)
+ EXPOSE 5000
+
+ # Run with Gunicorn
+ CMD ["gunicorn", "--bind", "0.0.0.0:5000", "app:app"]
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ flask
+ flask-cors
+ pdfplumber
+ pillow
+ pytesseract
+ numpy
+ torch
+ transformers
+ datasets
+ scikit-learn
+ gunicorn
+ sentencepiece