Spaces:
Running
on
Zero
Running
on
Zero
import logging | |
import os | |
from collections import Counter | |
from collections.abc import Mapping | |
from typing import Any | |
import httpx | |
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed | |
logger = logging.getLogger(__name__)

# Root URL of the remote API; this is None when the variable is unset.
BASE_URL = os.getenv("ETHER0_REMOTES_API_BASE_URL")

# Headers attached to every request. The token is interpolated as-is, so an
# unset ETHER0_REMOTES_API_TOKEN produces the literal value "Bearer None".
HEADERS = {
    "Authorization": f"Bearer {os.getenv('ETHER0_REMOTES_API_TOKEN')}",
    "Content-Type": "application/json",
}

# Per-endpoint tally of redirect/server-error responses, keyed by fetcher name.
SERVER_ERRORS_COUNTER = Counter(
    fetch_solubility=0,
    fetch_purchasable=0,
    fetch_forward_rxn=0,
    fetch_rxn_info=0,
)

# Tally at which a fetcher calls raise_for_status() on the offending response;
# overridable through the environment, defaults to 100.
THROW_500_ERROR_THRESHOLD = int(
    os.getenv("ETHER0_REMOTES_THROW_500_ERROR_THRESHOLD", "100")
)

# If our server throws a 501, we don't retry
OUR_SERVER_DONT_RETRY_CODE = int(httpx.codes.NOT_IMPLEMENTED)

# Passed as the request timeout by the fetchers in this module.
REMOTE_WORKER_COLD_START_TIME = 180  # sec
class RetryableServerError(Exception): | |
"""Retryable server error.""" | |
def check_raise( | |
cls, response: httpx.Response, kwargs: Mapping[str, Any] | None = None | |
) -> None: | |
if ( | |
response.is_server_error | |
and response.status_code != OUR_SERVER_DONT_RETRY_CODE | |
): | |
raise cls( | |
f"Retryable server error with status code {response.status_code}" | |
f" and inputs {kwargs or {}} and response {response=}." | |
) | |
def fetch_solubility(query_smiles: str) -> dict[str, Any]:
    """Request a solubility value for a SMILES string from the remote API.

    Sends a POST to ``{BASE_URL}/compute_solubility`` with the JSON body
    ``{"smiles": query_smiles}``.

    Args:
        query_smiles: SMILES string forwarded to the API.

    Returns:
        ``{"smiles": ..., "solubility": ...}`` when the call succeeds and the
        JSON payload has no ``"error"`` key (the value is the payload's
        ``"mean"`` entry). Otherwise ``{"smiles": ..., "error": ...}``
        describing the failure.

    Raises:
        RetryableServerError: On a server error whose status code is not
            ``OUR_SERVER_DONT_RETRY_CODE``.
        httpx.HTTPStatusError: From ``response.raise_for_status()`` once the
            ``"fetch_solubility"`` tally in ``SERVER_ERRORS_COUNTER`` reaches
            ``THROW_500_ERROR_THRESHOLD``.
    """
    response = httpx.post(
        f"{BASE_URL}/compute_solubility",
        json={"smiles": query_smiles},
        headers=HEADERS,
        timeout=REMOTE_WORKER_COLD_START_TIME,
    )

    if response.is_success:
        result = response.json()
        if "error" not in result:
            return {"smiles": query_smiles, "solubility": result["mean"]}
        # The API can report a failure inside a successful response.
        error_message = result["error"]
    else:
        # Keep the body for every failure, client errors included, so the
        # returned error and the log entry say what went wrong.
        error_message = response.text

    # Log before any exception below, so retryable failures leave a trace too.
    logger.warning(
        f"fetch_solubility did not succeed on {query_smiles=} with"
        f" {response=} and {error_message=}."
    )

    if response.is_redirect or response.is_server_error:
        # We should not have redirect responses or server errors, so let's retry these
        SERVER_ERRORS_COUNTER["fetch_solubility"] += 1
        if SERVER_ERRORS_COUNTER["fetch_solubility"] >= THROW_500_ERROR_THRESHOLD:
            response.raise_for_status()
        RetryableServerError.check_raise(
            response, kwargs={"query_smiles": query_smiles}
        )

    return {
        "smiles": query_smiles,
        "error": f"API error: {response} - {error_message}",
    }
def fetch_purchasable(query_smiles_list: list[str] | str) -> dict[str, bool]:
    """Ask the remote API which of the given SMILES are purchasable.

    Sends a POST to ``{BASE_URL}/is_purchasable`` with the JSON body
    ``{"smiles": query_smiles_list}``.

    Args:
        query_smiles_list: One SMILES string or a list of them, forwarded
            unchanged to the API.

    Returns:
        The decoded JSON body on success (presumably a SMILES -> bool
        mapping, per the annotation; confirm against the API), or an empty
        dict when the request fails without raising.

    Raises:
        RetryableServerError: On a server error whose status code is not
            ``OUR_SERVER_DONT_RETRY_CODE``.
        httpx.HTTPStatusError: From ``response.raise_for_status()`` once the
            ``"fetch_purchasable"`` tally in ``SERVER_ERRORS_COUNTER``
            reaches ``THROW_500_ERROR_THRESHOLD``.
    """
    endpoint = f"{BASE_URL}/is_purchasable"
    payload = {"smiles": query_smiles_list}
    response = httpx.post(
        endpoint,
        json=payload,
        headers=HEADERS,
        timeout=REMOTE_WORKER_COLD_START_TIME,
    )

    if not response.is_success:
        logger.warning(
            f"fetch_purchasable did not succeed on {query_smiles_list=} with"
            f" {response=} and {response.text=}."
        )
        should_retry = response.is_redirect or response.is_server_error
        if should_retry:
            # We should not have redirect responses or server errors, so let's retry these
            SERVER_ERRORS_COUNTER["fetch_purchasable"] += 1
            failures = SERVER_ERRORS_COUNTER["fetch_purchasable"]
            if failures >= THROW_500_ERROR_THRESHOLD:
                response.raise_for_status()
            RetryableServerError.check_raise(
                response, kwargs={"query_smiles_list": query_smiles_list}
            )
        # Non-retryable failure: report nothing rather than raising.
        return {}

    return response.json()
def fetch_forward_rxn(query_rxn_smiles: str) -> dict[str, str]:
    """Request the product of a reaction SMILES from the remote API.

    Sends a POST to ``{BASE_URL}/translate`` with the JSON body
    ``{"reaction": query_rxn_smiles}``.

    Args:
        query_rxn_smiles: Reaction SMILES forwarded to the API.

    Returns:
        ``{"smiles": ..., "product": ...}`` on success, where the product is
        the ``"product"`` entry of the JSON body. Otherwise
        ``{"smiles": ..., "error": ...}`` carrying the response and its text.

    Raises:
        RetryableServerError: On a server error whose status code is not
            ``OUR_SERVER_DONT_RETRY_CODE``.
        httpx.HTTPStatusError: From ``response.raise_for_status()`` once the
            ``"fetch_forward_rxn"`` tally in ``SERVER_ERRORS_COUNTER``
            reaches ``THROW_500_ERROR_THRESHOLD``.
    """
    endpoint = f"{BASE_URL}/translate"
    payload = {"reaction": query_rxn_smiles}
    response = httpx.post(
        endpoint,
        json=payload,
        headers=HEADERS,
        timeout=REMOTE_WORKER_COLD_START_TIME,
    )

    if response.is_success:
        return {"smiles": query_rxn_smiles, "product": response.json()["product"]}

    logger.warning(
        f"fetch_forward_rxn did not succeed on {query_rxn_smiles=} with"
        f" {response=} and {response.text=}."
    )

    should_retry = response.is_redirect or response.is_server_error
    if should_retry:
        # We should not have redirect responses or server errors, so let's retry these
        SERVER_ERRORS_COUNTER["fetch_forward_rxn"] += 1
        failures = SERVER_ERRORS_COUNTER["fetch_forward_rxn"]
        if failures >= THROW_500_ERROR_THRESHOLD:
            response.raise_for_status()
        RetryableServerError.check_raise(
            response, kwargs={"query_rxn_smiles": query_rxn_smiles}
        )

    # Non-retryable failure: hand the details back to the caller.
    return {
        "smiles": query_rxn_smiles,
        "error": f"API error: {response} - {response.text}",
    }