import json
import os
import warnings
from typing import Dict, Any, List, Union, Type, get_origin, get_args

from .variables.default import DEFAULT_CONFIG
from .variables.base import BaseConfig
from ..retrievers.utils import get_all_retriever_names

class Config:
    """Config class for GPT Researcher."""

    CONFIG_DIR = os.path.join(os.path.dirname(__file__), "variables")

    def __init__(self, config_path: str | None = None):
        """Initialize the config class."""
        self.config_path = config_path
        self.llm_kwargs: Dict[str, Any] = {}
        self.embedding_kwargs: Dict[str, Any] = {}

        config_to_use = self.load_config(config_path)
        self._set_attributes(config_to_use)
        self._set_embedding_attributes()
        self._set_llm_attributes()
        self._handle_deprecated_attributes()
        self._set_doc_path(config_to_use)
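    # A minimal usage sketch (hedged; assumes this module is imported as part
    # of the gpt_researcher package and any provider API keys are set):
    #   cfg = Config()                    # DEFAULT_CONFIG plus env overrides
    #   cfg = Config("my_config.json")    # hypothetical file, merged over defaults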

    def _set_attributes(self, config: Dict[str, Any]) -> None:
        for key, value in config.items():
            env_value = os.getenv(key)
            if env_value is not None:
                value = self.convert_env_value(
                    key, env_value, BaseConfig.__annotations__[key]
                )
            setattr(self, key.lower(), value)

        # Handle RETRIEVER with a default value
        retriever_env = os.environ.get("RETRIEVER", config.get("RETRIEVER", "tavily"))
        try:
            self.retrievers = self.parse_retrievers(retriever_env)
        except ValueError as e:
            print(f"Warning: {e}. Defaulting to 'tavily' retriever.")
            self.retrievers = ["tavily"]
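    # Example (illustrative values): with MAX_ITERATIONS=4 in the environment
    # and MAX_ITERATIONS annotated as int on BaseConfig, the env string "4" is
    # converted to the int 4 and stored as self.max_iterations.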

    def _set_embedding_attributes(self) -> None:
        self.embedding_provider, self.embedding_model = self.parse_embedding(
            self.embedding
        )

    def _set_llm_attributes(self) -> None:
        self.fast_llm_provider, self.fast_llm_model = self.parse_llm(self.fast_llm)
        self.smart_llm_provider, self.smart_llm_model = self.parse_llm(self.smart_llm)
        self.strategic_llm_provider, self.strategic_llm_model = self.parse_llm(
            self.strategic_llm
        )
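    # Example: FAST_LLM="openai:gpt-4o-mini" splits on the first ":" into
    # provider "openai" and model "gpt-4o-mini" (model name illustrative).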

    def _handle_deprecated_attributes(self) -> None:
        if os.getenv("EMBEDDING_PROVIDER") is not None:
            warnings.warn(
                "EMBEDDING_PROVIDER is deprecated and will be removed soon. "
                "Use EMBEDDING instead.",
                FutureWarning,
                stacklevel=2,
            )
            self.embedding_provider = (
                os.environ["EMBEDDING_PROVIDER"] or self.embedding_provider
            )

            match os.environ["EMBEDDING_PROVIDER"]:
                case "ollama":
                    self.embedding_model = os.environ["OLLAMA_EMBEDDING_MODEL"]
                case "custom":
                    self.embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL", "custom")
                case "openai":
                    self.embedding_model = "text-embedding-3-large"
                case "azure_openai":
                    self.embedding_model = "text-embedding-3-large"
                case "huggingface":
                    self.embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
                case _:
                    raise Exception("Embedding provider not found.")

        _deprecation_warning = (
            "LLM_PROVIDER, FAST_LLM_MODEL and SMART_LLM_MODEL are deprecated and "
            "will be removed soon. Use FAST_LLM and SMART_LLM instead."
        )
        if os.getenv("LLM_PROVIDER") is not None:
            warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
            self.fast_llm_provider = (
                os.environ["LLM_PROVIDER"] or self.fast_llm_provider
            )
            self.smart_llm_provider = (
                os.environ["LLM_PROVIDER"] or self.smart_llm_provider
            )
        if os.getenv("FAST_LLM_MODEL") is not None:
            warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
            self.fast_llm_model = os.environ["FAST_LLM_MODEL"] or self.fast_llm_model
        if os.getenv("SMART_LLM_MODEL") is not None:
            warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
            self.smart_llm_model = os.environ["SMART_LLM_MODEL"] or self.smart_llm_model
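    # Migration example (hedged): the legacy pair LLM_PROVIDER="openai" plus
    # SMART_LLM_MODEL="gpt-4o" corresponds to the newer SMART_LLM="openai:gpt-4o".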

    def _set_doc_path(self, config: Dict[str, Any]) -> None:
        self.doc_path = config["DOC_PATH"]
        if self.doc_path:
            try:
                self.validate_doc_path()
            except Exception as e:
                print(f"Warning: Error validating doc_path: {e}. Using default doc_path.")
                self.doc_path = DEFAULT_CONFIG["DOC_PATH"]

    @classmethod
    def load_config(cls, config_path: str | None) -> Dict[str, Any]:
        """Load a configuration by name."""
        if config_path is None:
            return DEFAULT_CONFIG

        # config_path = os.path.join(cls.CONFIG_DIR, config_path)
        if not os.path.exists(config_path):
            if config_path and config_path != "default":
                print(
                    f"Warning: Configuration not found at '{config_path}'. "
                    "Using default configuration."
                )
                if not config_path.endswith(".json"):
                    print(f"Did you mean '{config_path}.json'?")
            return DEFAULT_CONFIG

        with open(config_path, "r") as f:
            custom_config = json.load(f)

        # Merge with the default config to ensure all keys are present
        merged_config = DEFAULT_CONFIG.copy()
        merged_config.update(custom_config)
        return merged_config
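    # Example override file (illustrative keys; only overrides need to appear,
    # since the file is merged over DEFAULT_CONFIG):
    #   {"RETRIEVER": "tavily", "FAST_LLM": "openai:gpt-4o-mini"}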

    @classmethod
    def list_available_configs(cls) -> List[str]:
        """List all available configuration names."""
        configs = ["default"]
        for file in os.listdir(cls.CONFIG_DIR):
            if file.endswith(".json"):
                configs.append(file[:-5])  # Remove the .json extension
        return configs

    def parse_retrievers(self, retriever_str: str) -> List[str]:
        """Parse the retriever string into a list of retrievers and validate them."""
        retrievers = [retriever.strip() for retriever in retriever_str.split(",")]
        valid_retrievers = get_all_retriever_names() or []
        invalid_retrievers = [r for r in retrievers if r not in valid_retrievers]
        if invalid_retrievers:
            raise ValueError(
                f"Invalid retriever(s) found: {', '.join(invalid_retrievers)}. "
                f"Valid options are: {', '.join(valid_retrievers)}."
            )
        return retrievers
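    # Example: "tavily, bing" -> ["tavily", "bing"], assuming both names are
    # among those returned by get_all_retriever_names().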

    @staticmethod
    def parse_llm(llm_str: str | None) -> tuple[str | None, str | None]:
        """Parse an llm string into (llm_provider, llm_model)."""
        from gpt_researcher.llm_provider.generic.base import _SUPPORTED_PROVIDERS

        if llm_str is None:
            return None, None
        try:
            llm_provider, llm_model = llm_str.split(":", 1)
            assert llm_provider in _SUPPORTED_PROVIDERS, (
                f"Unsupported llm provider: {llm_provider}.\nSupported llm providers are: "
                + ", ".join(_SUPPORTED_PROVIDERS)
            )
            return llm_provider, llm_model
        except ValueError:
            raise ValueError(
                "Set SMART_LLM or FAST_LLM = '<llm_provider>:<llm_model>', "
                "e.g. 'openai:gpt-4o-mini'."
            )
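    # Example: parse_llm("anthropic:claude-3-5-sonnet") would return
    # ("anthropic", "claude-3-5-sonnet"), provided "anthropic" is in
    # _SUPPORTED_PROVIDERS (provider and model names illustrative).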

    @staticmethod
    def parse_embedding(embedding_str: str | None) -> tuple[str | None, str | None]:
        """Parse an embedding string into (embedding_provider, embedding_model)."""
        from gpt_researcher.memory.embeddings import _SUPPORTED_PROVIDERS

        if embedding_str is None:
            return None, None
        try:
            embedding_provider, embedding_model = embedding_str.split(":", 1)
            assert embedding_provider in _SUPPORTED_PROVIDERS, (
                f"Unsupported embedding provider: {embedding_provider}.\n"
                "Supported embedding providers are: " + ", ".join(_SUPPORTED_PROVIDERS)
            )
            return embedding_provider, embedding_model
        except ValueError:
            raise ValueError(
                "Set EMBEDDING = '<embedding_provider>:<embedding_model>', "
                "e.g. 'openai:text-embedding-3-large'."
            )
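    # Example: parse_embedding("openai:text-embedding-3-large") returns
    # ("openai", "text-embedding-3-large").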

    def validate_doc_path(self):
        """Ensure that the folder exists at the doc path."""
        os.makedirs(self.doc_path, exist_ok=True)

    @staticmethod
    def convert_env_value(key: str, env_value: str, type_hint: Type) -> Any:
        """Convert an environment variable to the appropriate type based on the type hint."""
        origin = get_origin(type_hint)
        args = get_args(type_hint)

        if origin is Union:
            # Handle Union types (e.g., Union[str, None])
            for arg in args:
                if arg is type(None):
                    if env_value.lower() in ("none", "null", ""):
                        return None
                else:
                    try:
                        return Config.convert_env_value(key, env_value, arg)
                    except ValueError:
                        continue
            raise ValueError(f"Cannot convert {env_value} to any of {args}")

        if type_hint is bool:
            return env_value.lower() in ("true", "1", "yes", "on")
        elif type_hint is int:
            return int(env_value)
        elif type_hint is float:
            return float(env_value)
        elif type_hint in (str, Any):
            return env_value
        elif origin is list or origin is List:
            return json.loads(env_value)
        else:
            raise ValueError(f"Unsupported type {type_hint} for key {key}")