# med-summary-v2 / main.py
# Hugging Face Space by hariharan220 — commit 9e8f213 ("Update main.py", verified).
import io
import logging
import os
import re
import time

import nltk
import pdfplumber
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from nltk.tokenize import sent_tokenize
from transformers import pipeline
# Point every Hugging Face cache location at a writable directory under /tmp.
# NOTE(review): transformers is imported above this point, and hub cache
# locations are typically resolved when those libraries are imported, so these
# variables may not take effect for the already-imported modules — confirm, or
# move this setup before the imports.
TMP_DIR = "/tmp/huggingface_cache"
os.environ.update(
    dict.fromkeys(
        ("TRANSFORMERS_CACHE", "HF_HOME", "HUGGINGFACE_HUB_CACHE"),
        TMP_DIR,
    )
)
os.makedirs(TMP_DIR, exist_ok=True)
# Keep NLTK data in a writable directory under /tmp and add it to NLTK's search path.
NLTK_DATA_DIR = "/tmp/nltk_data"
os.makedirs(NLTK_DATA_DIR, exist_ok=True)
nltk.data.path.append(NLTK_DATA_DIR)


def _ensure_nltk_resource(resource_path, package):
    """Download NLTK *package* into NLTK_DATA_DIR unless *resource_path* is already findable.

    Args:
        resource_path: Resource name passed to ``nltk.data.find``
            (e.g. ``"tokenizers/punkt"``).
        package: Package id passed to ``nltk.download``.
    """
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package, download_dir=NLTK_DATA_DIR)


# sent_tokenize needs NLTK's Punkt sentence model. Older NLTK releases load it
# from the pickled "punkt" package, while newer ones (3.9+) load "punkt_tab" and
# raise LookupError when only "punkt" is installed. Ensure both so sentence
# splitting works whichever NLTK version is installed.
_ensure_nltk_resource("tokenizers/punkt", "punkt")
_ensure_nltk_resource("tokenizers/punkt_tab", "punkt_tab")
# Create the FastAPI application.
app = FastAPI()

# CORS: accept cross-origin requests from any origin, with any method and header.
# NOTE(review): allow_credentials=True together with a wildcard origin may let
# any site make credentialed requests; the endpoint in this file uses no cookies
# or auth, so consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Select the device for the summarization pipeline: the first CUDA GPU when one
# is available (device=0), otherwise the CPU (device=-1).
if torch.cuda.is_available():
    device = 0
    print("🚀 Running on GPU!")
else:
    device = -1
    print("⚠️ GPU Not Available! Running on CPU.")
# Load the summarization pipeline once at import time, on the device chosen above.
summarizer = pipeline(
    "summarization",
    model="google/pegasus-xsum",
    device=device,
    # pipeline() has no cache_dir parameter of its own: unrecognised keyword
    # arguments are handed to the pipeline object (for summarization they end up
    # as generate() options) rather than to from_pretrained(). model_kwargs is
    # forwarded to the from_pretrained() calls, so the model files are cached
    # in the writable /tmp directory even if the HF_HOME-style environment
    # variables were set too late to be picked up.
    model_kwargs={"cache_dir": TMP_DIR},
)
# --- Generalized cleaning ---
def clean_text(text):
    """Strip citation markers, the references section and URLs from *text*.

    Steps, in order:
      1. Remove numeric citation markers such as ``[12]`` and ``(3)``; this
         also removes parenthesised years like ``(2020)``.
      2. Drop everything from a line starting with ``References:``
         (case-insensitive, optional leading spaces/tabs) to the end of text.
      3. Remove ``http://``, ``https://`` and ``www.`` URLs.
      4. Collapse every whitespace run to one space and trim both ends.

    Args:
        text: Raw extracted text. Newlines should still be present so the
            references heading can be recognised at the start of a line.

    Returns:
        The cleaned text as a single-line string.
    """
    # "\(\d+\)" already covers four-digit years, so no separate year pattern.
    text = re.sub(r"\[\d+\]|\(\d+\)", "", text)
    # DOTALL lets ".*" run across newlines to the end of the text. Without it,
    # only a references section sitting entirely on the last line was removed.
    # MULTILINE anchors the heading to a line start, so a mid-sentence
    # "references:" does not truncate the document.
    text = re.sub(
        r"^[ \t]*References:.*",
        "",
        text,
        flags=re.IGNORECASE | re.MULTILINE | re.DOTALL,
    )
    text = re.sub(r"https?://\S+|www\.\S+", "", text)
    return re.sub(r"\s+", " ", text).strip()
# --- PDF text extraction ---
def extract_text_from_pdf(pdf_path):
    """Return the text of all pages of a PDF, joined with newlines.

    Args:
        pdf_path: Whatever ``pdfplumber.open`` accepts: a filesystem path or a
            binary file-like object.

    Returns:
        The page texts joined by ``"\\n"``. Pages for which ``extract_text()``
        returns ``None`` or an empty string are skipped.
    """
    page_texts = []
    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            # extract_text() runs layout analysis on the page; call it once per
            # page and reuse the result instead of calling it again to filter.
            text = page.extract_text()
            if text:
                page_texts.append(text)
    return "\n".join(page_texts)
# --- Chunking for summarization ---
def split_text(text, chunk_size=2048, *, sentence_splitter=None):
    """Group the sentences of *text* into chunks of roughly *chunk_size* characters.

    Sentences are never split. A sentence is added to the current chunk while
    the chunk's running length (each sentence counted with one trailing
    separator) plus the sentence and one more character stays within
    *chunk_size*. Otherwise a new chunk is started. A single sentence longer
    than *chunk_size* therefore becomes its own oversized chunk.

    Args:
        text: Text to split.
        chunk_size: Soft upper bound on chunk length, in characters.
        sentence_splitter: Callable mapping a string to a list of sentences.
            Defaults to NLTK's ``sent_tokenize``.

    Returns:
        A list of non-empty chunks, each being its sentences joined by single
        spaces; ``[]`` when *text* yields no sentences.
    """
    if sentence_splitter is None:
        sentence_splitter = sent_tokenize

    chunks = []

    def flush(parts):
        # Never emit an empty chunk. This happens, e.g., when the very first
        # sentence is already longer than chunk_size; "" would otherwise be sent
        # to the summarizer.
        chunk = " ".join(parts).strip()
        if chunk:
            chunks.append(chunk)

    current = []      # sentences of the chunk being built
    current_len = 0   # sum of len(sentence) + 1 over `current`
    for sentence in sentence_splitter(text):
        if current_len + len(sentence) + 1 <= chunk_size:
            current.append(sentence)
            current_len += len(sentence) + 1
        else:
            flush(current)
            current = [sentence]
            current_len = len(sentence) + 1
    flush(current)
    return chunks
# --- Summarization API ---
@app.post("/summarize-pdf/")
async def summarize_pdf(file: UploadFile = File(...)):
    """Summarize the text of an uploaded PDF.

    The upload is read into memory and its text is extracted with
    ``extract_text_from_pdf``. The text is cleaned with ``clean_text``, split
    into ~2048-character chunks with ``split_text``, and each chunk is
    summarized by the module-level ``summarizer`` pipeline. The per-chunk
    summaries are joined with single spaces.

    Returns:
        ``{"summary": str}`` on success, or ``{"error": str}`` when no text is
        available or any step raises. Errors are reported in the JSON body,
        as before.
    """
    try:
        pdf_content = await file.read()
        # Parse from memory rather than a fixed path such as /tmp/temp.pdf.
        # A shared path lets concurrent uploads overwrite each other's file,
        # and that file was never removed.
        full_text = extract_text_from_pdf(io.BytesIO(pdf_content))
        if not full_text.strip():
            return {"error": "No text extracted from the PDF."}

        cleaned_text = clean_text(full_text)
        text_chunks = split_text(cleaned_text, chunk_size=2048)
        if not text_chunks:
            # Cleaning removed everything, e.g. the PDF held only a references section.
            return {"error": "No text left to summarize after cleaning the PDF."}

        # NOTE(review): extraction and inference are synchronous and block the
        # event loop for the whole request. Consider offloading them to a worker
        # thread (with a lock around the shared pipeline) if concurrency matters.
        summaries = [
            summarizer(chunk, max_new_tokens=250, num_beams=5, truncation=True)[0]["summary_text"]
            for chunk in text_chunks
        ]
        return {"summary": " ".join(summaries)}
    except Exception as e:
        # Top-level boundary: keep the {"error": ...} response contract, but
        # record the traceback server-side instead of discarding it.
        logging.getLogger(__name__).exception("Failed to summarize uploaded PDF")
        return {"error": str(e)}
# When run as a script (e.g. on Hugging Face Spaces), serve the app with
# Uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(app, port=7860, host="0.0.0.0")