Spaces:
Sleeping
Sleeping
| import argparse | |
| import logging | |
| import os | |
| from pathlib import Path | |
| import shutil | |
| import torch | |
| from groq import Groq | |
| from nougat import NougatModel | |
| from nougat.utils.device import move_to_device | |
| from nougat.postprocessing import markdown_compatible | |
| from pypdf import PdfReader | |
| from tqdm import tqdm | |
| from dotenv import load_dotenv | |
| import pypdfium2 as pdfium | |
| from torchvision.transforms.functional import to_tensor | |
# Configure root logging once at import time: every record goes both to the
# console (stderr) and to "pdf_processing.log" in the current working directory.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        # Explicit UTF-8: log messages embed PDF file names, which may contain
        # characters the platform default encoding cannot represent.
        logging.FileHandler("pdf_processing.log", encoding="utf-8"),
    ],
)
class NougatPDFProcessor:
    """Convert PDFs to Markdown with Nougat and prepend an AMA citation.

    For each PDF the processor:

    1. extracts the first page's text with pypdf and asks a Groq-hosted LLM
       to turn it into an AMA-style citation;
    2. renders the first page with pypdfium2 and runs it through the Nougat
       model to obtain Markdown;
    3. writes ``<output_dir>/<pdf stem>.md`` containing citation + Markdown.

    NOTE: only the first page of each PDF is converted; later pages are
    not rendered or passed to the model.
    """

    # Checkpoint identifier passed to ``NougatModel.from_pretrained``.
    NOUGAT_CHECKPOINT = "facebook/nougat-small"
    # Groq chat model used to generate the citation.
    CITATION_MODEL = "llama3-8b-8192"
    # Number of leading first-page characters embedded in the citation prompt.
    CITATION_PROMPT_CHARS = 4000
    # ``max_tokens`` sent with the citation request.
    CITATION_MAX_TOKENS = 200

    def __init__(self, input_dir: str, output_dir: str):
        """Create the directories, the Groq client and the Nougat model.

        Args:
            input_dir: Directory containing the source PDF files.
            output_dir: Directory that receives the generated Markdown files;
                created (including missing parents) if it does not exist.

        Raises:
            ValueError: If ``GROQ_API_KEY`` is unset after loading ``.env``.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.temp_dir = self.output_dir / "temp_nougat_output"
        self._prepare_directories()

        load_dotenv()
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key:
            raise ValueError("GROQ_API_KEY not found in .env file")
        self.groq_client = Groq(api_key=groq_api_key)

        # Load the model in bfloat16, move it to the device chosen by
        # nougat's ``move_to_device`` helper, and switch to eval mode.
        model = NougatModel.from_pretrained(self.NOUGAT_CHECKPOINT).to(torch.bfloat16)
        self.model = move_to_device(model)
        self.model.eval()

    def _prepare_directories(self) -> None:
        """Create ``output_dir`` and ``temp_dir``, including missing parents."""
        # parents=True: the output directory may be nested (for example
        # "src/processed_markdown"); a plain mkdir() raises FileNotFoundError
        # when an intermediate directory does not exist yet.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.temp_dir.mkdir(parents=True, exist_ok=True)

    def _markdown_path(self, pdf_path: Path) -> Path:
        """Return the Markdown output path that corresponds to ``pdf_path``."""
        return self.output_dir / f"{pdf_path.stem}.md"

    @staticmethod
    def _format_citation(body: str) -> str:
        """Wrap ``body`` in the citation section placed above the document."""
        return f"## Citation\n\n{body}\n\n---\n\n"

    def _get_first_page_text(self, pdf_path: Path) -> str:
        """Return the text of the first page of ``pdf_path``.

        Best effort: any failure (unreadable file, no pages, extraction
        error) is logged and reported as an empty string.
        """
        try:
            reader = PdfReader(pdf_path)
            return reader.pages[0].extract_text() or ""
        except Exception as e:
            logging.error(f"Could not extract text from first page of '{pdf_path.name}': {e}")
            return ""

    def _generate_ama_citation(self, text: str) -> str:
        """Return a Markdown citation section generated from ``text``.

        Args:
            text: Text of the document's first page; only the first
                ``CITATION_PROMPT_CHARS`` characters are sent to the LLM.

        Returns:
            A ``## Citation`` section terminated by a ``---`` rule. When
            ``text`` is empty or the Groq request fails, the section holds an
            explanatory placeholder instead of a citation.
        """
        if not text:
            # Use the same section layout as the other outcomes so the
            # placeholder does not run into the document body.
            return self._format_citation(
                "Citation could not be generated: No text found on the first page."
            )

        prompt = (
            "Based on the following text from the first page of a medical document, "
            "please generate a concise AMA (American Medical Association) style citation. "
            "Include authors, title, journal/source, year, and volume/page numbers if available. "
            "If some information is missing, create the best citation possible with the available data. "
            "Output only the citation itself, with no additional text or labels.\n\n"
            f"--- DOCUMENT TEXT ---\n{text[:self.CITATION_PROMPT_CHARS]}\n\n--- END DOCUMENT TEXT ---\n\nAMA Citation:"
        )
        try:
            chat_completion = self.groq_client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=self.CITATION_MODEL,
                temperature=0,
                max_tokens=self.CITATION_MAX_TOKENS,
            )
            citation = chat_completion.choices[0].message.content.strip()
        except Exception as e:
            # Citation is optional metadata: never let it abort the conversion.
            logging.error(f"Groq API call failed for citation generation: {e}")
            return self._format_citation("Citation could not be generated due to an error.")
        return self._format_citation(citation)

    def _first_page_to_markdown(self, pdf_path: Path) -> str:
        """Render the first page of ``pdf_path`` and convert it with Nougat.

        Raises:
            Exception: Whatever pypdfium2 or Nougat raise; the caller handles it.
        """
        pdf = pdfium.PdfDocument(pdf_path)
        try:
            page = pdf[0]
            try:
                # scale=1 renders one pixel per PDF point (72 DPI).
                bitmap = page.render(scale=1)
                image = bitmap.to_pil()
                # Pass the PIL image so Nougat runs its own input preparation
                # (encoder.prepare_input). The previous manual resize +
                # to_tensor path distorted the aspect ratio and skipped the
                # normalization the encoder expects.
                predictions = self.model.inference(image=image)
            finally:
                page.close()
        finally:
            # Release the native document handle even when rendering or
            # inference fails.
            pdf.close()

        # A single image yields a single entry in predictions["predictions"].
        return markdown_compatible(predictions["predictions"][0])

    def process_single_pdf(self, pdf_path: Path) -> None:
        """Convert one PDF and write ``<output_dir>/<stem>.md``.

        On a Nougat or write failure the error is logged and an error note is
        written to the same output path, so the batch run skips the file next
        time instead of retrying it.
        """
        logging.info(f"Processing '{pdf_path.name}'...")
        final_md_path = self._markdown_path(pdf_path)

        # 1. Generate citation
        logging.info(f"Generating AMA citation for '{pdf_path.name}'...")
        citation_md = self._generate_ama_citation(self._get_first_page_text(pdf_path))
        logging.info(f"Citation generated for '{pdf_path.name}'.")

        # 2. Process with Nougat, 3. combine and save
        logging.info(f"Processing PDF '{pdf_path.name}' with Nougat...")
        try:
            nougat_markdown = self._first_page_to_markdown(pdf_path)
            logging.info(f"Successfully processed '{pdf_path.name}' with Nougat.")
            final_md_path.write_text(citation_md + nougat_markdown, encoding="utf-8")
            logging.info(f"Successfully saved final markdown to '{final_md_path}'.")
        except Exception as e:
            # Per-file boundary: one broken PDF must not stop a batch run.
            logging.error(f"Failed to process '{pdf_path.name}' with Nougat: {e}")
            # Create an error file to avoid reprocessing
            final_md_path.write_text(
                f"Failed to process this document with Nougat.\n\nError: {e}",
                encoding="utf-8",
            )

    def process_all_pdfs(self) -> None:
        """Convert every ``*.pdf`` in ``input_dir`` that has no output file yet."""
        pdf_files = sorted(self.input_dir.glob("*.pdf"))
        if not pdf_files:
            logging.warning(f"No PDF files found in {self.input_dir}")
            return

        logging.info(f"Found {len(pdf_files)} PDF(s) to process.")
        for pdf_path in tqdm(pdf_files, desc="Processing PDFs with Nougat"):
            if self._markdown_path(pdf_path).exists():
                logging.info(f"Skipping '{pdf_path.name}' as it has already been processed.")
                continue
            self.process_single_pdf(pdf_path)
def main():
    """Parse command-line arguments and run the PDF-to-Markdown conversion.

    With ``--file`` a single PDF from the input directory is converted (even
    if its output already exists); otherwise every PDF in the input directory
    is processed.

    Raises:
        SystemExit: With status 1 when ``--file`` names a PDF that does not
            exist in the input directory.
    """
    parser = argparse.ArgumentParser(description="PDF to Markdown Converter using Nougat with AMA Citations.")
    parser.add_argument("--input-dir", type=str, default="Obs", help="Directory containing source PDF files.")
    parser.add_argument("--output-dir", type=str, default="src/processed_markdown", help="Directory to save final Markdown files.")
    parser.add_argument("--file", type=str, help="Process a single PDF file by name (e.g., 'my_doc.pdf').")
    args = parser.parse_args()

    # Validate the cheap part first: constructing the processor loads the
    # Nougat model, which is wasted work if the requested file is missing.
    single_pdf = Path(args.input_dir) / args.file if args.file else None
    if single_pdf is not None and not single_pdf.exists():
        logging.error(f"Specified file not found: {single_pdf}")
        # Non-zero status so shells and schedulers can detect the failure.
        raise SystemExit(1)

    processor = NougatPDFProcessor(input_dir=args.input_dir, output_dir=args.output_dir)
    if single_pdf is not None:
        processor.process_single_pdf(single_pdf)
    else:
        processor.process_all_pdfs()
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()