|
import os |
|
import json |
|
import yaml |
|
from typing import Any |
|
from pathlib import Path |
|
from dotenv import load_dotenv |
|
from collections import namedtuple |
|
from importlib import resources as pkg_resources |
|
|
|
|
|
from openfactcheck.utils.logging import get_logger, set_verbosity |
|
from openfactcheck.errors import ConfigValidationError |
|
from openfactcheck import templates as solver_config_templates_dir |
|
from openfactcheck import solvers as solver_templates_dir |
|
|
|
|
|
# Directory named "solver_configs" inside the package imported as
# ``solver_config_templates_dir``; its entries are used as the default
# solver configuration templates.
solver_config_templates_path = str(pkg_resources.files(solver_config_templates_dir) / "solver_configs")

# Path.iterdir() yields entries in arbitrary, filesystem-dependent order.
# Sort them so the templates are always merged in the same order and later
# key-order-dependent logic behaves the same on every machine.
solver_config_template_files = sorted(str(f) for f in Path(solver_config_templates_path).iterdir())

# Default solver directories that are always searched, in addition to any
# user-defined solver paths from the configuration file.
solver_templates_paths = [
    str(pkg_resources.files(solver_templates_dir) / "webservice"),
    str(pkg_resources.files(solver_templates_dir) / "factool"),
]

# Populate os.environ from a .env file (python-dotenv) before any
# configuration object checks for API keys.
load_dotenv()
|
|
|
|
|
class OpenFactCheckConfig: |
|
""" |
|
Class to load the OpenFactCheck configuration from a JSON or YAML file. |
|
|
|
Parameters |
|
---------- |
|
filename_or_path: str or path object |
|
The path to the configuration file. |
|
|
|
Attributes |
|
---------- |
|
retries: int |
|
Number of retries for the pipeline components. |
|
pipeline: namedtuple |
|
Namedtuple containing the pipeline components. |
|
solver_configs: dict |
|
Dictionary containing the solver configurations. |
|
solver_paths: dict |
|
Dictionary containing the paths to the solver models. |
|
output_path: str |
|
Path to the output directory. |
|
secrets: namedtuple |
|
Namedtuple containing the API keys. |
|
verbose: str |
|
The verbosity level for the logger. |
|
|
|
Methods |
|
------- |
|
solver_configuration(solver: str = None) -> dict: |
|
Get the solver configuration for a specific solver or all solvers. |
|
validate(): |
|
Validate the configuration file |
|
|
|
Examples |
|
-------- |
|
For loading the default configuration file 'config.json': |
|
>>> config = OpenFactCheckConfig() |
|
|
|
For loading the configuration file from a specific path or filename: |
|
>>> config = OpenFactCheckConfig("config.json") |
|
|
|
For loading the configuration file and validating it: |
|
>>> config = OpenFactCheckConfig("config.json") |
|
>>> config.validate() |
|
""" |
|
|
|
def __init__(self, filename_or_path: str | Path = "config.json"): |
|
|
|
self.logger = get_logger() |
|
|
|
|
|
self.filename_or_path = filename_or_path |
|
|
|
|
|
Secrets = namedtuple("Secrets", ["openai_api_key", "serper_api_key", "scraper_api_key"]) |
|
|
|
|
|
self.config: dict = {} |
|
self.retries: int = 3 |
|
self.pipeline: list = [] |
|
self.solver_configs: dict[Any, Any] = SolversConfig(solver_config_template_files)() |
|
self.solver_paths: dict[str, list[str]] = {"default": solver_templates_paths, "user_defined": []} |
|
self.output_path: str = "tmp/output" |
|
self.secrets: Secrets = Secrets(openai_api_key=None, serper_api_key=None, scraper_api_key=None) |
|
self.verbose = "WARNING" |
|
|
|
try: |
|
|
|
if Path(self.filename_or_path).exists(): |
|
|
|
with open(self.filename_or_path, encoding="utf-8") as file: |
|
self.config = json.load(file) |
|
self.logger.info(f"Config file loaded successfully from {self.filename_or_path}") |
|
else: |
|
|
|
self.logger.warning(f"Config file not found: {self.filename_or_path}") |
|
|
|
|
|
if "retries" in self.config: |
|
self.retries = self.config["retries"] |
|
else: |
|
self.logger.warning("No retries found in the configuration file. Using default value of 3.") |
|
|
|
|
|
|
|
if "solver_configs" in self.config: |
|
self.solver_configs = SolversConfig(solver_config_template_files + self.config["solver_configs"])() |
|
else: |
|
self.logger.warning( |
|
"No solver configurations found in the configuration file. Using default templates only." |
|
) |
|
|
|
|
|
if "solver_paths" in self.config: |
|
self.solver_paths = { |
|
"default": solver_templates_paths, |
|
"user_defined": self.config["solver_paths"], |
|
} |
|
else: |
|
self.logger.warning( |
|
"No solver paths found in the configuration file. Using default solver paths only." |
|
) |
|
|
|
|
|
if "output_path" in self.config: |
|
self.output_path = self.config["output_path"] |
|
os.makedirs(self.output_path, exist_ok=True) |
|
else: |
|
self.logger.warning( |
|
"No output path found in the configuration file. Using default output path 'tmp/output'." |
|
) |
|
self.output_path = "tmp/output" |
|
os.makedirs(self.output_path, exist_ok=True) |
|
|
|
|
|
if "pipeline" in self.config: |
|
self.pipeline = self.config["pipeline"] |
|
else: |
|
if self.solver_configs: |
|
solvers = list(self.solver_configs.keys()) |
|
claimprocessor = "factool_claimprocessor" if "factool_claimprocessor" in solvers else None |
|
retriever = "factool_retriever" if "factool_retriever" in solvers else None |
|
verifier = "factcheckgpt_verifier" if "factcheckgpt_verifier" in solvers else None |
|
for solver in solvers: |
|
if claimprocessor and retriever and verifier: |
|
break |
|
if "claimprocessor" in solver: |
|
claimprocessor = solver |
|
if "retriever" in solver: |
|
retriever = solver |
|
if "verifier" in solver: |
|
verifier = solver |
|
self.pipeline = [claimprocessor, retriever, verifier] |
|
self.logger.warning( |
|
f"No pipeline found in the configuration file. Using first solver as default pipeline. ClaimProcessor: {claimprocessor}, Retriever: {retriever}, Verifier: {verifier}" |
|
) |
|
|
|
|
|
if "secrets" in self.config: |
|
self.secrets = Secrets( |
|
openai_api_key=self.config["secrets"]["openai_api_key"], |
|
serper_api_key=self.config["secrets"]["serper_api_key"], |
|
scraper_api_key=self.config["secrets"]["scraper_api_key"], |
|
) |
|
else: |
|
self.logger.warning( |
|
"No secrets found in the configuration file. Make sure to set the environment variables." |
|
) |
|
|
|
|
|
if self.secrets.openai_api_key: |
|
os.environ["OPENAI_API_KEY"] = self.secrets.openai_api_key |
|
if self.secrets.serper_api_key: |
|
os.environ["SERPER_API_KEY"] = self.secrets.serper_api_key |
|
if self.secrets.scraper_api_key: |
|
os.environ["SCRAPER_API_KEY"] = self.secrets.scraper_api_key |
|
|
|
|
|
if "verbose" in self.config: |
|
self.verbose = self.config["verbose"] |
|
set_verbosity(self.verbose) |
|
else: |
|
self.logger.warning("No verbose level found in the configuration file. Using default level 'WARNING'.") |
|
|
|
|
|
self.validate() |
|
|
|
except FileNotFoundError: |
|
self.logger.error(f"Config file not found: {self.filename_or_path}") |
|
raise FileNotFoundError(f"Config file not found: {self.filename_or_path}") |
|
|
|
except json.JSONDecodeError: |
|
self.logger.error(f"Invalid JSON in config file: {self.filename_or_path}") |
|
raise ValueError(f"Invalid JSON in config file: {self.filename_or_path}") |
|
|
|
except ConfigValidationError as e: |
|
self.logger.error(f"Configuration validation failed: {e}") |
|
raise ConfigValidationError(f"Configuration validation failed: {e}") |
|
|
|
except Exception as e: |
|
self.logger.error(f"Unexpected error loading config file: {e}") |
|
raise Exception(f"Unexpected error loading config file: {e}") |
|
|
|
def validate(self): |
|
""" |
|
Validate the configuration file. |
|
|
|
Raises |
|
------ |
|
ValueError |
|
If the configuration file is invalid. |
|
|
|
Examples |
|
-------- |
|
>>> config = OpenFactCheckConfig("config.json") |
|
>>> config.validate() |
|
""" |
|
|
|
if "OPENAI_API_KEY" not in os.environ: |
|
self.logger.warning("OPENAI_API_KEY environment variable not found.") |
|
raise ConfigValidationError("OPENAI_API_KEY environment variable not found.") |
|
if "SERPER_API_KEY" not in os.environ: |
|
self.logger.warning("SERPER_API_KEY environment variable not found.") |
|
raise ConfigValidationError("SERPER_API_KEY environment variable not found.") |
|
if "SCRAPER_API_KEY" not in os.environ: |
|
self.logger.warning("SCRAPER_API_KEY environment variable not found.") |
|
raise ConfigValidationError("SCRAPER_API_KEY environment variable not found.") |
|
|
|
def solver_configuration(self, solver: str | None = None) -> dict: |
|
""" |
|
Get the solver configuration for a specific solver or all solvers. |
|
|
|
Parameters |
|
---------- |
|
solver: str |
|
The name of the solver to get the configuration for. |
|
If not provided, returns the configuration for all solvers. |
|
|
|
Returns |
|
------- |
|
dict |
|
The configuration for the specified solver or all solvers. |
|
|
|
Raises |
|
------ |
|
ConfigValidationError |
|
If the configuration validation fails. |
|
|
|
Examples |
|
-------- |
|
>>> config = OpenFactCheckConfig("config.json") |
|
>>> config.solver_configuration() |
|
>>> config.solver_configuration("factcheckgpt_claimprocessor") |
|
""" |
|
if solver: |
|
if solver in self.solver_configs: |
|
return self.solver_configs[solver] |
|
else: |
|
self.logger.error(f"Solver not found: {solver}") |
|
raise ValueError(f"Solver not found: {solver}") |
|
else: |
|
return self.solver_configs |
|
|
|
|
|
class SolversConfig: |
|
""" |
|
A class to load solver configurations from one or more JSON or YAML files. |
|
|
|
This class reads solver configurations from specified files, merges them, |
|
and provides access to the combined configuration as a dictionary. |
|
|
|
Parameters |
|
---------- |
|
filename_or_paths : str | Path | list[str | Path] |
|
The path or list of paths to the solver configuration files. |
|
|
|
Attributes |
|
---------- |
|
solvers : dict[Any, Any] |
|
Dictionary containing the merged solver configurations. |
|
|
|
Examples |
|
-------- |
|
Load solver configurations from a single file: |
|
|
|
>>> solvers = SolversConfig("solvers.yaml") |
|
>>> config = solvers() |
|
|
|
Load solver configurations from multiple files: |
|
|
|
>>> solvers = SolversConfig(["solvers1.json", "solvers2.yaml"]) |
|
>>> config = solvers() |
|
|
|
Access the solvers dictionary: |
|
|
|
>>> config = solvers() |
|
""" |
|
|
|
def __init__(self, filename_or_path_s: str | Path | list[str] | list[Path]) -> None: |
|
""" |
|
Initialize the SolversConfig class. |
|
|
|
Parameters |
|
---------- |
|
filename_or_path_s: str or path object or list of str or path objects |
|
The path to the solvers configuration or a list of paths to multiple solvers configurations. |
|
""" |
|
|
|
self.logger = get_logger() |
|
|
|
|
|
self.filename_or_path_or_path_s = filename_or_path_s |
|
|
|
|
|
self.solvers: dict[Any, Any] = {} |
|
|
|
try: |
|
if isinstance(self.filename_or_path_or_path_s, (str, Path)): |
|
self.__load_config(self.filename_or_path_or_path_s) |
|
elif isinstance(self.filename_or_path_or_path_s, list): |
|
self.__load_configs(self.filename_or_path_or_path_s) |
|
else: |
|
self.logger.error(f"Invalid filename type: {type(self.filename_or_path_or_path_s)}") |
|
raise ValueError(f"Invalid filename type: {type(self.filename_or_path_or_path_s)}") |
|
|
|
except FileNotFoundError: |
|
self.logger.error(f"Solvers file not found: {self.filename_or_path_or_path_s}") |
|
raise FileNotFoundError(f"Solvers file not found: {self.filename_or_path_or_path_s}") |
|
except json.JSONDecodeError: |
|
self.logger.error(f"Invalid JSON in solvers file: {self.filename_or_path_or_path_s}") |
|
raise ValueError(f"Invalid JSON in solvers file: {self.filename_or_path_or_path_s}") |
|
except Exception as e: |
|
self.logger.error(f"Unexpected error loading solvers file: {e}") |
|
raise Exception(f"Unexpected error loading solvers file: {e}") |
|
|
|
def __load_config(self, filename_or_path: str | Path) -> None: |
|
|
|
filename = str(filename_or_path) |
|
|
|
with open(filename, encoding="utf-8") as file: |
|
if filename.endswith(".yaml"): |
|
file_data = yaml.load(file, Loader=yaml.FullLoader) |
|
elif filename.endswith(".json"): |
|
file_data = json.load(file) |
|
else: |
|
self.logger.error(f"Invalid file format: {filename}") |
|
raise ValueError(f"Invalid file format: {filename}") |
|
|
|
|
|
self.solvers.update(file_data) |
|
|
|
|
|
if "template" in filename: |
|
self.logger.info(f"Template solver configuration loaded: {filename.split('/')[-1]}") |
|
else: |
|
self.logger.info(f"User-defined solver configuration loaded from: {filename}") |
|
|
|
def __load_configs(self, filenames: list[str] | list[Path]) -> None: |
|
for filename in filenames: |
|
self.__load_config(filename) |
|
|
|
def __call__(self) -> dict[Any, Any]: |
|
return self.solvers |
|
|