from __future__ import annotations

import base64
import os
from pathlib import Path
from typing import Dict, List, Literal, Tuple

import anthropic
import requests
from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel

from .logger import get_review_logger
from .utils import extract_all_tags

load_dotenv()

# ---------------------------------------------------------------------------
# Pydantic models
# ---------------------------------------------------------------------------
class Point(BaseModel):
    content: str
    importance: Literal["critical", "minor"]


class Review(BaseModel):
    contributions: str
    strengths: List[Point]
    weaknesses: List[Point]
    requested_changes: List[Point]
    impact_concerns: str
    claims_and_evidence: str
    audience_interest: str


IMPORTANCE_MAPPING = {"critical": 2, "minor": 1}


# ---------------------------------------------------------------------------
# Reviewer Class
# ---------------------------------------------------------------------------
class PDFReviewer:
    """Encapsulates the full PDF review life-cycle.

    Parameters
    ----------
    openai_key:
        API key for the OpenAI client. Falls back to the ``OPENAI_API_KEY`` env var.
    anthropic_key:
        API key for the Anthropic Claude client. Falls back to the ``ANTHROPIC_API_KEY`` env var.
    cache_dir:
        Directory where temporary PDFs are stored.
    debug:
        When ``True``, cheaper models are used and failures drop into ``pdb``.
    """

    def __init__(
        self,
        *,
        openai_key: str | None = None,
        anthropic_key: str | None = None,
        cache_dir: str | Path | None = None,
        debug: bool = False,
    ) -> None:
        self.openai_key = openai_key or os.getenv("OPENAI_API_KEY")
        self.anthropic_key = anthropic_key or os.getenv("ANTHROPIC_API_KEY")
        if not self.openai_key:
            raise EnvironmentError("Missing OPENAI_API_KEY env var or parameter")
        if not self.anthropic_key:
            raise EnvironmentError("Missing ANTHROPIC_API_KEY env var or parameter")

        self.client = OpenAI(api_key=self.openai_key)
        self.claude_client = anthropic.Anthropic(api_key=self.anthropic_key)

        cache_dir = cache_dir or os.getenv("TMLR_CACHE_DIR", "/tmp/tmlr_cache")
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.debug = debug
        self.logger = get_review_logger()

        # Lazy import prompts to avoid a circular dependency during tests
        import importlib

        self.PROMPTS = importlib.import_module("prompts")

    # ---------------------------------------------------------------------
    # Public high-level API
    # ---------------------------------------------------------------------
    def review_pdf(self, pdf_path: str | Path) -> Dict[str, str]:
        """Main entry point: review *pdf_path* and return the parsed results."""
        pdf_path = Path(pdf_path)
        self.logger.info("Starting review for %s", pdf_path.name)

        file_uploaded = self._step("upload_pdf", self._upload_pdf, pdf_path)
        self.logger.info("PDF uploaded, id=%s", file_uploaded.id)

        literature_report = self._step("literature_search", self._literature_search, file_uploaded)
        self.logger.info("Literature search complete")

        raw_review = self._step(
            "generate_initial_review", self._generate_initial_review, file_uploaded, literature_report
        )
        self.logger.info("Initial review generated")

        # Optional defense / revision stage
        defended_review = self._step("defend_review", self._defend_review, file_uploaded, raw_review)

        parsed_review = self._step("parse_final", self._parse_final, defended_review)
        self.logger.info("Review parsed")
        return parsed_review

    # ------------------------------------------------------------------
    # Internal helpers (prefixed with _)
    # ------------------------------------------------------------------
    def _upload_pdf(self, pdf_path: Path):
        """Upload *pdf_path* to OpenAI and return the file object."""
        with open(pdf_path, "rb") as pdf_file:
            return self.client.files.create(file=pdf_file, purpose="user_data")

    def _literature_search(self, file):
        """Run the literature-search tool call."""
        model_name = "gpt-4o" if self.debug else "gpt-4.1"
        resp = self.client.responses.create(
            model=model_name,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_file", "file_id": file.id},
                        {"type": "input_text", "text": self.PROMPTS.literature_search},
                    ],
                }
            ],
            tools=[{"type": "web_search"}],
        )
        return resp.output_text

    def _generate_initial_review(self, file, literature_report: str):
        """Query the GPT model with the combined prompts to get an initial review."""
        prompt = self.PROMPTS.review_prompt.format(
            literature_search_report=literature_report,
            acceptance_criteria=self.PROMPTS.acceptance_criteria,
            review_format=self.PROMPTS.review_format,
        )
        model_name = "gpt-4o" if self.debug else "o4-mini"
        resp = self.client.responses.create(
            model=model_name,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_file", "file_id": file.id},
                        {"type": "input_text", "text": prompt},
                    ],
                }
            ],
        )
        return resp.output_text

    # ------------------------------------------------------------------
    # Static/utility parsing helpers
    # ------------------------------------------------------------------
    def _parse_final(
        self,
        parsed: Dict | str,
        *,
        max_strengths: int = 3,
        max_weaknesses: int = 5,
        max_requested_changes: int = 5,
    ) -> Dict[str, str]:
        """Convert the model's structured response into simplified text blobs."""
        self.logger.debug("Parsing final review json -> human readable")
        if isinstance(parsed, str):
            # Attempt to parse via Pydantic
            try:
                parsed = Review.model_validate_json(parsed).model_dump()
            except Exception:
                self.logger.warning(
                    "parse_final received a string that could not be parsed by the Review model. "
                    "Returning it as-is under 'contributions'."
                )
                return {"contributions": parsed}

        new_parsed: Dict[str, str] = {}
        new_parsed["contributions"] = parsed["contributions"]
        new_parsed["claims_and_evidence"] = parsed["claims_and_evidence"]
        new_parsed["audience_interest"] = parsed["audience_interest"]
        new_parsed["impact_concerns"] = parsed["impact_concerns"]
        new_parsed["strengths"] = "\n".join(
            f"- {point['content']}" for point in parsed["strengths"][:max_strengths]
        )
        new_parsed["weaknesses"] = "\n".join(
            f"- {point['content']}" for point in parsed["weaknesses"][:max_weaknesses]
        )
        requested_changes_sorted = sorted(
            parsed["requested_changes"],
            key=lambda x: IMPORTANCE_MAPPING[x["importance"]],
            reverse=True,
        )
        new_parsed["requested_changes"] = "\n".join(
            f"- {point['content']}" for point in requested_changes_sorted[:max_requested_changes]
        )
        return new_parsed
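
    # Illustrative shape of the dict returned above (field text is hypothetical;
    # shown only to document the contract of ``_parse_final``):
    #
    #     {
    #         "contributions": "...",
    #         "claims_and_evidence": "...",
    #         "audience_interest": "...",
    #         "impact_concerns": "...",
    #         "strengths": "- first strength\n- second strength",
    #         "weaknesses": "- first weakness",
    #         "requested_changes": "- critical change first\n- minor change after",
    #     }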

    # ------------------------------------------------------------------
    # Optional: unit-test-style sanity checks
    # ------------------------------------------------------------------
    def _run_unit_tests(self, pdf_path: Path, review: Dict[str, str]) -> Tuple[bool, str | None]:
        """Run post-hoc sanity tests powered by Claude prompts."""
        test_prompt = self.PROMPTS.unit_test_prompt.format(review=review)
        response = self._ask_claude(test_prompt, pdf_path)
        results = extract_all_tags(response)
        for test_name in [
            "reviewing_process_references",
            "inappropriate_language",
            "llm_generated_review",
            "hallucinations",
            "formatting_and_style",
        ]:
            self.logger.info("Unit test %s: %s", test_name, results.get(test_name))
            if results.get(test_name) == "FAIL":
                return False, test_name
        return True, None

    # ------------------------------------------------------------------
    # Claude wrapper
    # ------------------------------------------------------------------
    def _ask_claude(
        self,
        query: str,
        pdf_path: str | Path | None = None,
        *,
        max_tokens: int = 8000,
        model: str = "claude-3-5-sonnet-20241022",
    ) -> str:
        content = query
        betas: List[str] = []
        # Attach the PDF for context if provided
        if pdf_path is not None:
            if str(pdf_path).startswith(("http://", "https://")):
                binary_data = requests.get(str(pdf_path)).content
            else:
                with open(pdf_path, "rb") as fp:
                    binary_data = fp.read()
            pdf_data = base64.standard_b64encode(binary_data).decode()
            content = [
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "media_type": "application/pdf",
                        "data": pdf_data,
                    },
                },
                {"type": "text", "text": query},
            ]
            betas.append("pdfs-2024-09-25")

        kwargs = {
            "model": model,
            "max_tokens": max_tokens,
            "messages": [{"role": "user", "content": content}],
        }
        if betas:
            kwargs["betas"] = betas
        message = self.claude_client.beta.messages.create(**kwargs)  # type: ignore[arg-type]
        return message.content[0].text

    # ------------------------------------------------------------------
    # Public utility methods
    # ------------------------------------------------------------------
    def get_prompts(self):
        """Return the prompts module for inspection."""
        return self.PROMPTS

    def get_logger(self):
        """Return the logger for inspection."""
        return self.logger

    # ------------------------------------------------------------------
    # _step helper (defined at the end to avoid cluttering core logic)
    # ------------------------------------------------------------------
    def _step(self, name: str, fn, *args, **kwargs):
        """Execute *fn* and, if an exception occurs, trigger pdb in debug mode."""
        try:
            self.logger.info("Starting step: %s", name)
            result = fn(*args, **kwargs)
            self.logger.info("Completed step: %s", name)
            return result
        except Exception:
            self.logger.exception("Step %s failed", name)
            if self.debug:
                import pdb
                import traceback

                traceback.print_exc()
                pdb.post_mortem()
            raise

    # ------------------------------------------------------------------
    # Defense / revision helpers
    # ------------------------------------------------------------------
    def _run_query_on_file(self, file, prompt: str, *, model_name: str):
        """Thin wrapper around OpenAI ``responses.create`` used by several steps."""
        return self.client.responses.create(
            model=model_name,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_file", "file_id": file.id},
                        {"type": "input_text", "text": prompt},
                    ],
                }
            ],
        ).output_text

    def _defend_review(self, file, review: str):
        """Run defense → revision → human-style polishing, as in the legacy workflow."""
        model_name = "gpt-4o" if self.debug else "o3"
        defense = self._run_query_on_file(
            file,
            self.PROMPTS.defend_prompt.format(combined_review=review),
            model_name=model_name,
        )
        revision_prompt = self.PROMPTS.revise_prompt.format(
            review_format=self.PROMPTS.review_format.format(
                acceptance_criteria=self.PROMPTS.acceptance_criteria,
                review_format=self.PROMPTS.review_format,
            ),
            combined_review=review,
            defended_paper=defense,
        )
        revision = self._run_query_on_file(file, revision_prompt, model_name=model_name)
        humanised = self._run_query_on_file(
            file,
            self.PROMPTS.human_style.format(review=revision),
            model_name=model_name,
        )
        # Finally, convert to a structured Review JSON using the formatting prompt
        formatted = self._format_review(humanised, model_name=model_name)
        return formatted

    def _format_review(self, review_text: str, *, model_name: str):
        """Use OpenAI structured outputs to map *review_text* → a ``Review`` model dict."""
        chat_completion = self.client.beta.chat.completions.parse(
            messages=[
                {
                    "role": "user",
                    "content": self.PROMPTS.formatting_prompt.format(review=review_text),
                }
            ],
            model=model_name,
            response_format=Review,
        )
        return chat_completion.choices[0].message.parsed.model_dump()
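

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch)
# ---------------------------------------------------------------------------
# A minimal sketch of how a caller might drive the class, assuming the
# OPENAI_API_KEY and ANTHROPIC_API_KEY env vars are set; the import path and
# "paper.pdf" filename below are hypothetical placeholders.
#
#     from tmlr_review.reviewer import PDFReviewer  # hypothetical import path
#
#     reviewer = PDFReviewer(debug=True)  # debug=True -> cheaper models, pdb on failure
#     review = reviewer.review_pdf("paper.pdf")
#     print(review["requested_changes"])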