import logging
import os
import re
import tempfile
import time

import nltk
import pdfplumber
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from nltk.tokenize import sent_tokenize
from transformers import pipeline

# --- Hugging Face caches: point every cache variable at one writable /tmp dir ---
# NOTE(review): `transformers` is imported at the top of this file, before these
# variables are assigned. transformers / huggingface_hub likely read
# HF_HOME / HUGGINGFACE_HUB_CACHE / TRANSFORMERS_CACHE when they are imported,
# so these assignments may have no effect — confirm, and consider setting them
# before importing transformers.
TMP_DIR = "/tmp/huggingface_cache"
os.environ.update(
    dict.fromkeys(
        ("TRANSFORMERS_CACHE", "HF_HOME", "HUGGINGFACE_HUB_CACHE"),
        TMP_DIR,
    )
)
os.makedirs(TMP_DIR, exist_ok=True)

# ✅ Ensure NLTK Dependencies are Stored in a Writable Directory
NLTK_DATA_DIR = "/tmp/nltk_data"
os.makedirs(NLTK_DATA_DIR, exist_ok=True)
nltk.data.path.append(NLTK_DATA_DIR)


def _ensure_nltk_tokenizer(name):
    """Download ``tokenizers/<name>`` into NLTK_DATA_DIR unless NLTK already finds it."""
    try:
        nltk.data.find(f"tokenizers/{name}")
    except LookupError:
        nltk.download(name, download_dir=NLTK_DATA_DIR)


# sent_tokenize() loads the "punkt" model on older NLTK releases but the
# "punkt_tab" tables on newer ones (3.9+). With only "punkt" present, every
# sent_tokenize() call on a newer NLTK raises LookupError. Fetch both so the
# app works whichever version is installed. nltk.download() reports, rather
# than raises, a resource it cannot fetch.
_ensure_nltk_tokenizer("punkt")
_ensure_nltk_tokenizer("punkt_tab")

# --- FastAPI application instance served by uvicorn ---
app = FastAPI()

# Permissive CORS so browser front-ends on any origin can call the API.
# NOTE(review): with allow_origins=["*"] plus allow_credentials=True, Starlette's
# CORSMiddleware echoes the caller's Origin, so credentialed cross-site requests
# are accepted from any site. The endpoint visible here reads no cookies or auth
# headers; if that changes, restrict the origins or drop allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# --- Pipeline device: CUDA device 0 when torch sees a GPU, otherwise CPU (-1) ---
if torch.cuda.is_available():
    device = 0
    print("🚀 Running on GPU!")
else:
    device = -1
    print("⚠️ GPU Not Available! Running on CPU.")

# --- Summarization model (google/pegasus-xsum), loaded once at import time ---
# NOTE(review): `cache_dir` does not appear among pipeline()'s documented
# parameters. Extra keyword arguments are forwarded to the pipeline object and
# may end up as generation kwargs. Confirm the weights really land in TMP_DIR and
# that generation does not reject `cache_dir`. The documented route to
# from_pretrained() is `model_kwargs={"cache_dir": TMP_DIR}`.
summarizer = pipeline(
    "summarization",
    device=device,
    model="google/pegasus-xsum",
    cache_dir=TMP_DIR,
)

# --- **Generalized Cleaning** ---
# Inline citation markers such as "[12]", "(3)" or "(2020)". "\(\d+\)" already
# covers the four-digit year form, so no separate year alternative is needed.
_CITATION_RE = re.compile(r"\[\d+\]|\(\d+\)")
# Everything from a "References:" heading (any case) to the end of the text.
# DOTALL lets ".*" cross newlines, so a multi-line bibliography is dropped.
# Without it the pattern only matched when "References:" sat on the last line.
_REFERENCES_RE = re.compile(r"References:.*", flags=re.IGNORECASE | re.DOTALL)
_URL_RE = re.compile(r"https?://\S+|www\.\S+")
_WHITESPACE_RE = re.compile(r"\s+")


def clean_text(text):
    """Strip citation markers, a trailing references section and URLs from *text*.

    Steps, in order:
      1. remove "[n]" and "(n)" citation markers;
      2. drop everything from the first "References:" (case-insensitive) onward;
      3. remove http(s):// and www. URLs;
      4. collapse all whitespace runs (including newlines) to single spaces and trim.

    Returns the cleaned string ("" for empty input).
    """
    text = _CITATION_RE.sub("", text)
    text = _REFERENCES_RE.sub("", text)
    text = _URL_RE.sub("", text)
    return _WHITESPACE_RE.sub(" ", text).strip()

# --- **PDF Text Extraction** ---
def extract_text_from_pdf(pdf_path):
    """Return the text of all pages of a PDF, pages separated by newlines.

    Args:
        pdf_path: Anything ``pdfplumber.open`` accepts; the caller in this
            file passes a filesystem path.

    Returns:
        The page texts joined with "\\n". Pages for which pdfplumber returns no
        text (None or "") are skipped, so a PDF with no text yields "".
    """
    with pdfplumber.open(pdf_path) as pdf:
        # Call extract_text() once per page. The previous version called it
        # twice, once to test for text and again to keep it, repeating
        # potentially expensive layout work.
        page_texts = [
            text
            for text in (page.extract_text() for page in pdf.pages)
            if text
        ]
    return "\n".join(page_texts)

# --- **Chunking for Summarization** ---
def split_text(text, chunk_size=2048, *, tokenizer=None):
    """Group the sentences of *text* into chunks of roughly chunk_size characters.

    Sentences are packed greedily, in order, joined by single spaces. A chunk is
    closed when appending the next sentence (plus its separating space) would
    exceed chunk_size. A sentence longer than chunk_size is never cut. It
    becomes its own oversized chunk; the caller in this file passes
    truncation=True to the model.

    Args:
        text: Text to split.
        chunk_size: Character budget per chunk (characters, not model tokens).
        tokenizer: Optional callable mapping text to a list of sentences.
            Defaults to NLTK's ``sent_tokenize``.

    Returns:
        A list of non-empty chunk strings; ``[]`` when there are no sentences.
    """
    tokenize = sent_tokenize if tokenizer is None else tokenizer
    chunks = []
    current_chunk = ""
    for sentence in tokenize(text):
        # Only close a chunk that actually holds text. Previously an oversized
        # first sentence flushed the still-empty chunk, emitting "" for the
        # summarizer.
        if current_chunk.strip() and len(current_chunk) + len(sentence) + 1 > chunk_size:
            chunks.append(current_chunk.strip())
            current_chunk = ""
        current_chunk += sentence + " "
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks

# ✅ **Summarization API**
logger = logging.getLogger(__name__)


@app.post("/summarize-pdf/")
async def summarize_pdf(file: UploadFile = File(...)):
    """Summarize an uploaded PDF.

    Steps:
      1. Write the upload to a private temporary file. Requests never share a
         path, and the user's document is removed from disk afterwards.
      2. Extract and clean its text.
      3. Split the text into ~2048-character chunks.
      4. Summarize each chunk and join the chunk summaries with spaces.

    Returns:
        ``{"summary": str}`` on success, or ``{"error": str}`` when no text
        could be extracted or any step raised. The HTTP status is 200 in
        both cases, as before.
    """
    pdf_path = None
    try:
        start_time = time.perf_counter()
        pdf_content = await file.read()
        # delete=False: the handle is closed before pdfplumber reopens the file
        # by name; cleanup happens in `finally`.
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
            pdf_path = tmp.name  # record first so a failed write is still cleaned up
            tmp.write(pdf_content)

        full_text = extract_text_from_pdf(pdf_path)
        if not full_text.strip():
            return {"error": "No text extracted from the PDF."}

        cleaned_text = clean_text(full_text)
        # Never ask the model to summarize an empty chunk.
        text_chunks = [chunk for chunk in split_text(cleaned_text, chunk_size=2048) if chunk]
        # NOTE(review): the model call is synchronous inside an async endpoint,
        # so it blocks the event loop for the whole request. Before offloading it
        # to a thread, confirm the shared pipeline/tokenizer tolerates concurrent
        # calls. HF fast tokenizers have reported "Already borrowed" errors.
        summaries = [
            summarizer(chunk, max_new_tokens=250, num_beams=5, truncation=True)[0]["summary_text"]
            for chunk in text_chunks
        ]

        logger.info(
            "Summarized %d chunk(s) of %r in %.2fs",
            len(text_chunks),
            file.filename,
            time.perf_counter() - start_time,
        )
        return {"summary": " ".join(summaries)}

    except Exception as e:
        # Request boundary: keep the existing JSON error contract, but record the
        # traceback instead of discarding it.
        logger.exception("Failed to summarize %r", file.filename)
        return {"error": str(e)}
    finally:
        if pdf_path is not None:
            try:
                os.remove(pdf_path)
            except OSError:
                pass  # best-effort cleanup; the response is already decided

# --- Script entry point: serve the app with Uvicorn (Hugging Face Spaces) ---
if __name__ == "__main__":
    # Listen on all interfaces on port 7860, the port this deployment targets.
    uvicorn.run(app, port=7860, host="0.0.0.0")