# Spaces: Sleeping  (page-header residue, apparently from the web view this file was copied from; not part of the module)
import contextlib | |
import json | |
import logging | |
import os | |
import re | |
from abc import ABC, abstractmethod | |
from typing import Any, AsyncGenerator | |
import backoff | |
import httpx | |
import tqdm | |
from article_embedding.utils import env_str | |
log = logging.getLogger(__name__) | |
class Checkpoint(ABC): | |
def get(self) -> str | None: ... | |
def set(self, value: str) -> None: ... | |
def reset(self) -> None: ... | |
class NullCheckpoint(Checkpoint):
    """Checkpoint that keeps no state: reads yield ``None``, writes are dropped."""

    def get(self) -> str | None:
        """Always report that nothing has been stored."""
        return None

    def set(self, value: str) -> None:
        """Ignore *value*."""

    def reset(self) -> None:
        """Do nothing; there is no state to clear."""


# Module-level instance, shared wherever a no-op checkpoint is needed.
_NULL_CHECKPOINT = NullCheckpoint()
class FileCheckpoint(Checkpoint):
    """Checkpoint stored as the text content of a single file."""

    def __init__(self, path: str) -> None:
        """Remember *path*; the file itself is only touched by the other methods."""
        self.path = path

    def get(self) -> str | None:
        """Return the file content with surrounding whitespace stripped.

        Returns ``None`` when the file does not exist. An existing but empty
        file yields an empty string.
        """
        try:
            with open(self.path, encoding="utf-8") as file:
                return file.read().strip()
        except FileNotFoundError:
            return None

    def set(self, value: str) -> None:
        """Replace the checkpoint file content with *value*.

        The value is written to a sibling temporary file which is then moved
        over the real path, so an interrupted write cannot leave a truncated
        or empty checkpoint behind.
        """
        tmp_path = f"{self.path}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as file:
            file.write(value)
        os.replace(tmp_path, self.path)

    def reset(self) -> None:
        """Delete the checkpoint file; a missing file is not an error."""
        with contextlib.suppress(FileNotFoundError):
            os.remove(self.path)
class CouchDB:
    """Shared CouchDB access object built on an ``httpx.AsyncClient``.

    ``CouchDB()`` always returns the same instance. Configuration comes from
    the environment: ``COUCHDB_URL``, ``COUCHDB_USER`` and ``COUCHDB_PASSWORD``
    are read from ``os.environ``; ``COUCHDB_DB`` and ``DOCS_PATH_VIEW`` are
    read through ``env_str``.
    """

    # Flipped to True on the instance once __init__ has completed, so that
    # later ``CouchDB()`` calls do not rebuild (and leak) the HTTP client.
    _initialized = False

    def __new__(cls) -> "CouchDB":
        """Create the instance on first use and hand back the same one afterwards."""
        if not hasattr(cls, "_instance"):
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Configure the shared instance; repeated construction is a no-op."""
        # __init__ runs on every ``CouchDB()`` call even though __new__ returns
        # the existing object, so guard against re-initialisation.
        if self._initialized:
            return
        self.client = self.make_client()
        self.database = env_str("COUCHDB_DB")
        # Kept out of the f-string: reusing the same quote character inside an
        # f-string expression is a syntax error before Python 3.12.
        docs_path_view = env_str("DOCS_PATH_VIEW")
        self.path_view = f"/{self.database}/{docs_path_view}"
        # Only mark as initialised once everything above succeeded, so a
        # failed first attempt (e.g. missing environment) can be retried.
        self._initialized = True

    def make_client(self) -> httpx.AsyncClient:
        """Build an async client whose ``get`` re-authenticates on HTTP 401.

        When a ``get`` returns status 401, the credentials are posted to
        ``/_session`` on the same client and the request is retried once
        (``max_tries=2``); ``factor=0`` makes the computed backoff wait zero.
        Only ``get`` is wrapped.

        Raises:
            KeyError: if a required ``COUCHDB_*`` environment variable is unset.
        """
        url = os.environ["COUCHDB_URL"]
        user = os.environ["COUCHDB_USER"]
        password = os.environ["COUCHDB_PASSWORD"]
        auth = {"name": user, "password": password}
        client = httpx.AsyncClient(base_url=url)

        async def on_backoff(details: Any) -> None:
            # Log in on the client being built, i.e. the one whose request
            # was just rejected.
            response = await client.post("/_session", json=auth)
            response.raise_for_status()

        decorator = backoff.on_predicate(
            backoff.expo,
            predicate=lambda r: r.status_code == 401,
            on_backoff=on_backoff,
            max_tries=2,
            factor=0,
        )
        client.get = decorator(client.get)  # type: ignore[method-assign]
        return client

    async def changes(
        self, *, batch_size: int, checkpoint: Checkpoint = _NULL_CHECKPOINT
    ) -> AsyncGenerator[list[Any], None]:
        """Yield documents from the database ``_changes`` feed, batch by batch.

        Starts from ``checkpoint.get()`` (or ``0`` when that is empty) and
        requests at most *batch_size* changes per call with ``include_docs``.
        After each batch has been consumed, the response's ``last_seq`` is
        stored in *checkpoint*. Iteration ends when the response reports
        ``pending == 0``.

        Raises:
            httpx.HTTPStatusError: if a request returns an error status.
        """
        since = checkpoint.get() or 0
        params: dict[str, Any] = {"since": since, "limit": batch_size, "include_docs": True}
        while True:
            response = await self.client.get(f"/{self.database}/_changes", params=params)
            response.raise_for_status()
            data = response.json()
            yield [change["doc"] for change in data["results"]]
            # Checkpoint.set takes a str. Coerce instead of asserting: an
            # assert is stripped under ``python -O`` and would crash on a
            # non-string sequence value.
            since = str(data["last_seq"])
            params["since"] = since
            checkpoint.set(since)
            if data["pending"] == 0:
                break

    async def estimate_total_changes(self, *, checkpoint: Checkpoint = _NULL_CHECKPOINT) -> int:
        """Return the ``pending`` count reported after the checkpoint, plus one.

        NOTE(review): the ``+ 1`` presumably accounts for CouchDB treating
        ``limit=0`` like ``limit=1``, so one change is returned and not
        counted as pending -- confirm against the server version in use.

        Raises:
            httpx.HTTPStatusError: if the request returns an error status.
        """
        since = checkpoint.get() or 0
        params = {"since": since, "limit": 0}
        response = await self.client.get(f"/{self.database}/_changes", params=params)
        response.raise_for_status()
        data = response.json()
        return int(data["pending"]) + 1

    async def get_doc_by_id(self, doc_id: str) -> Any:
        """Fetch a document by id.

        Returns the decoded JSON document, or ``None`` when the server answers
        404 or when any error occurs (errors are logged, not raised).
        """
        try:
            response = await self.client.get(f"/{self.database}/{doc_id}")
            if response.status_code == 404:
                return None
            response.raise_for_status()
            return response.json()
        except Exception:
            # Deliberate best-effort lookup: log with traceback, report "not found".
            log.exception("Error fetching document by ID")
            return None

    async def get_doc_by_path(self, path: str) -> Any:
        """Fetch the first document the path view returns for key *path*.

        Returns the row's ``doc``, or ``None`` when the view has no rows or
        when any error occurs (errors are logged, not raised).
        """
        try:
            params = {
                "limit": "1",
                "key": json.dumps(path),  # the key is sent JSON-encoded
                "include_docs": "true",
            }
            response = await self.client.get(self.path_view, params=params)
            response.raise_for_status()
            data = response.json()
            rows = data["rows"]
            if not rows:
                return None
            return rows[0]["doc"]
        except Exception:
            # Use the module logger, consistent with get_doc_by_id.
            log.exception("Error fetching document by path")
            return None

    async def get_doc(self, id_or_path: str) -> Any:
        """Resolve *id_or_path* to a document, trying ids first, then the path.

        Every UUID found in the string is tried as a document id, in order;
        the first truthy result wins. Otherwise, if a ``.html`` path can be
        extracted from the string, it is looked up through the path view.
        Returns ``None`` when nothing matches.
        """
        uuids = extract_doc_ids(id_or_path)
        for uuid in uuids:
            doc = await self.get_doc_by_id(uuid)
            if doc:
                return doc
        path = extract_doc_path(id_or_path)
        if path:
            return await self.get_doc_by_path(path)
        return None
# Lowercase, hyphenated UUID whose version digit is 1-5 and whose variant
# digit is one of 8, 9, a, b.
UUID_PATTERN = re.compile(r"[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}")


def extract_doc_ids(s: str) -> list[str]:
    """Return every non-overlapping ``UUID_PATTERN`` match in *s*, left to right."""
    return [match.group() for match in UUID_PATTERN.finditer(s)]
def extract_doc_path(s: str) -> str | None: | |
if not s.endswith(".html"): | |
return None | |
if s.startswith("/"): | |
return s | |
if "://" in s: | |
s = s.split("://", 1)[1] | |
if "/" in s: | |
return "/" + s.split("/", 1)[1] | |
return None | |
if __name__ == "__main__":
    import asyncio

    from dotenv import load_dotenv

    async def main() -> None:
        """Walk the changes feed from the saved checkpoint with a progress bar.

        Uses the ``.checkpoint`` file to resume, reads batches of 40 changes,
        and advances the bar by one per document. The bar description shows
        id, type, language and file name for ``article`` documents, and just
        the type for everything else.
        """
        couch = CouchDB()
        saved_position = FileCheckpoint(".checkpoint")
        expected = await couch.estimate_total_changes(checkpoint=saved_position)
        with tqdm.tqdm(total=expected) as pbar:
            async for batch in couch.changes(batch_size=40, checkpoint=saved_position):
                for doc in batch:
                    kind = doc.get("type")
                    if kind != "article":
                        pbar.desc = f"{kind}"
                    else:
                        # Evaluated left to right: _id, language, then path.
                        doc_id, lang, filename = (
                            doc["_id"],
                            doc["language"],
                            os.path.basename(doc["path"]),
                        )
                        pbar.desc = f"{doc_id}: {kind} {lang} {filename}"
                    pbar.update(1)

    # Load variables from a .env file before the client reads its configuration.
    load_dotenv()
    asyncio.run(main())