import json
import os
import warnings
from typing import Dict, Any, List, Union, Type, get_origin, get_args
from .variables.default import DEFAULT_CONFIG
from .variables.base import BaseConfig
from ..retrievers.utils import get_all_retriever_names


class Config:
"""Config class for GPT Researcher."""
    CONFIG_DIR = os.path.join(os.path.dirname(__file__), "variables")

    def __init__(self, config_path: str | None = None):
"""Initialize the config class."""
self.config_path = config_path
self.llm_kwargs: Dict[str, Any] = {}
self.embedding_kwargs: Dict[str, Any] = {}
config_to_use = self.load_config(config_path)
self._set_attributes(config_to_use)
self._set_embedding_attributes()
self._set_llm_attributes()
self._handle_deprecated_attributes()
self._set_doc_path(config_to_use)
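
    # Typical use (path hypothetical): Config() applies DEFAULT_CONFIG plus any
    # environment-variable overrides; Config("/path/to/custom.json") first
    # merges the JSON file over the defaults, then applies the same overrides.
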
    def _set_attributes(self, config: Dict[str, Any]) -> None:
        """Set each config key as a lowercase attribute, with environment variables taking precedence."""
for key, value in config.items():
env_value = os.getenv(key)
if env_value is not None:
value = self.convert_env_value(key, env_value, BaseConfig.__annotations__[key])
setattr(self, key.lower(), value)
# Handle RETRIEVER with default value
retriever_env = os.environ.get("RETRIEVER", config.get("RETRIEVER", "tavily"))
try:
self.retrievers = self.parse_retrievers(retriever_env)
except ValueError as e:
print(f"Warning: {str(e)}. Defaulting to 'tavily' retriever.")
self.retrievers = ["tavily"]
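
    # Example of the retriever fallback above (values illustrative):
    #   RETRIEVER="tavily, arxiv"   -> self.retrievers == ["tavily", "arxiv"]
    #   RETRIEVER="not_a_retriever" -> prints a warning, falls back to ["tavily"]
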
    def _set_embedding_attributes(self) -> None:
        """Parse the EMBEDDING string into provider and model."""
self.embedding_provider, self.embedding_model = self.parse_embedding(
self.embedding
)
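
    # e.g. EMBEDDING="openai:text-embedding-3-large" yields
    # embedding_provider == "openai" and embedding_model == "text-embedding-3-large".
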
    def _set_llm_attributes(self) -> None:
        """Parse the fast, smart and strategic LLM strings into provider/model pairs."""
self.fast_llm_provider, self.fast_llm_model = self.parse_llm(self.fast_llm)
self.smart_llm_provider, self.smart_llm_model = self.parse_llm(self.smart_llm)
self.strategic_llm_provider, self.strategic_llm_model = self.parse_llm(self.strategic_llm)
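
    # e.g. FAST_LLM="openai:gpt-4o-mini" yields fast_llm_provider == "openai"
    # and fast_llm_model == "gpt-4o-mini"; SMART_LLM and STRATEGIC_LLM use the
    # same "<provider>:<model>" format.
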
    def _handle_deprecated_attributes(self) -> None:
        """Honor deprecated environment variables while warning that they will be removed."""
if os.getenv("EMBEDDING_PROVIDER") is not None:
warnings.warn(
"EMBEDDING_PROVIDER is deprecated and will be removed soon. Use EMBEDDING instead.",
FutureWarning,
stacklevel=2,
)
self.embedding_provider = (
os.environ["EMBEDDING_PROVIDER"] or self.embedding_provider
)
match os.environ["EMBEDDING_PROVIDER"]:
case "ollama":
self.embedding_model = os.environ["OLLAMA_EMBEDDING_MODEL"]
case "custom":
self.embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL", "custom")
case "openai":
self.embedding_model = "text-embedding-3-large"
case "azure_openai":
self.embedding_model = "text-embedding-3-large"
case "huggingface":
self.embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
                case _:
                    raise Exception(
                        f"Unsupported deprecated embedding provider: {os.environ['EMBEDDING_PROVIDER']}."
                    )
_deprecation_warning = (
"LLM_PROVIDER, FAST_LLM_MODEL and SMART_LLM_MODEL are deprecated and "
"will be removed soon. Use FAST_LLM and SMART_LLM instead."
)
if os.getenv("LLM_PROVIDER") is not None:
warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
self.fast_llm_provider = (
os.environ["LLM_PROVIDER"] or self.fast_llm_provider
)
self.smart_llm_provider = (
os.environ["LLM_PROVIDER"] or self.smart_llm_provider
)
if os.getenv("FAST_LLM_MODEL") is not None:
warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
self.fast_llm_model = os.environ["FAST_LLM_MODEL"] or self.fast_llm_model
if os.getenv("SMART_LLM_MODEL") is not None:
warnings.warn(_deprecation_warning, FutureWarning, stacklevel=2)
self.smart_llm_model = os.environ["SMART_LLM_MODEL"] or self.smart_llm_model
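
    # e.g. a legacy LLM_PROVIDER="openai" plus FAST_LLM_MODEL="gpt-4o-mini"
    # still maps onto fast_llm_provider/fast_llm_model, but each access emits a
    # FutureWarning pointing at FAST_LLM/SMART_LLM.
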
    def _set_doc_path(self, config: Dict[str, Any]) -> None:
        """Set and validate the local documents path, falling back to the default on error."""
self.doc_path = config['DOC_PATH']
if self.doc_path:
try:
self.validate_doc_path()
except Exception as e:
print(f"Warning: Error validating doc_path: {str(e)}. Using default doc_path.")
            self.doc_path = DEFAULT_CONFIG['DOC_PATH']

    @classmethod
def load_config(cls, config_path: str | None) -> Dict[str, Any]:
"""Load a configuration by name."""
if config_path is None:
return DEFAULT_CONFIG
# config_path = os.path.join(cls.CONFIG_DIR, config_path)
if not os.path.exists(config_path):
if config_path and config_path != "default":
print(f"Warning: Configuration not found at '{config_path}'. Using default configuration.")
                if not config_path.endswith(".json"):
                    print(f"Did you mean '{config_path}.json'?")
return DEFAULT_CONFIG
with open(config_path, "r") as f:
custom_config = json.load(f)
# Merge with default config to ensure all keys are present
merged_config = DEFAULT_CONFIG.copy()
merged_config.update(custom_config)
return merged_config
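
    # Illustrative example (file name hypothetical): loading "my_config.json"
    # containing {"FAST_LLM": "openai:gpt-4o-mini"} returns DEFAULT_CONFIG with
    # only FAST_LLM overridden, since the file is merged over the defaults.
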
@classmethod
def list_available_configs(cls) -> List[str]:
"""List all available configuration names."""
configs = ["default"]
for file in os.listdir(cls.CONFIG_DIR):
if file.endswith(".json"):
configs.append(file[:-5]) # Remove .json extension
return configs
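
    # e.g. if CONFIG_DIR contains "custom.json" (file name illustrative),
    # this returns ["default", "custom"].
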
def parse_retrievers(self, retriever_str: str) -> List[str]:
"""Parse the retriever string into a list of retrievers and validate them."""
retrievers = [retriever.strip()
for retriever in retriever_str.split(",")]
valid_retrievers = get_all_retriever_names() or []
invalid_retrievers = [r for r in retrievers if r not in valid_retrievers]
if invalid_retrievers:
raise ValueError(
f"Invalid retriever(s) found: {', '.join(invalid_retrievers)}. "
f"Valid options are: {', '.join(valid_retrievers)}."
)
return retrievers
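
    # e.g. parse_retrievers("tavily , arxiv") -> ["tavily", "arxiv"] (whitespace
    # is stripped); any name missing from get_all_retriever_names() raises a
    # ValueError that lists the valid options.
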
@staticmethod
def parse_llm(llm_str: str | None) -> tuple[str | None, str | None]:
"""Parse llm string into (llm_provider, llm_model)."""
from gpt_researcher.llm_provider.generic.base import _SUPPORTED_PROVIDERS
if llm_str is None:
return None, None
try:
llm_provider, llm_model = llm_str.split(":", 1)
assert llm_provider in _SUPPORTED_PROVIDERS, (
f"Unsupported {llm_provider}.\nSupported llm providers are: "
+ ", ".join(_SUPPORTED_PROVIDERS)
)
return llm_provider, llm_model
except ValueError:
            raise ValueError(
                "Set FAST_LLM, SMART_LLM or STRATEGIC_LLM = '<llm_provider>:<llm_model>', "
                "e.g. 'openai:gpt-4o-mini'"
            ) from None
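
    # Note that split(":", 1) keeps later colons inside the model part, so a
    # value like "ollama:llama3:8b" (assuming "ollama" is a supported provider)
    # parses to ("ollama", "llama3:8b").
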
@staticmethod
def parse_embedding(embedding_str: str | None) -> tuple[str | None, str | None]:
"""Parse embedding string into (embedding_provider, embedding_model)."""
from gpt_researcher.memory.embeddings import _SUPPORTED_PROVIDERS
if embedding_str is None:
return None, None
try:
embedding_provider, embedding_model = embedding_str.split(":", 1)
assert embedding_provider in _SUPPORTED_PROVIDERS, (
f"Unsupported {embedding_provider}.\nSupported embedding providers are: "
+ ", ".join(_SUPPORTED_PROVIDERS)
)
return embedding_provider, embedding_model
except ValueError:
            raise ValueError(
                "Set EMBEDDING = '<embedding_provider>:<embedding_model>', "
                "e.g. 'openai:text-embedding-3-large'"
            ) from None
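
    # e.g. parse_embedding("openai:text-embedding-3-large") returns
    # ("openai", "text-embedding-3-large"); when EMBEDDING is unset (None),
    # it returns (None, None).
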
def validate_doc_path(self):
"""Ensure that the folder exists at the doc path"""
os.makedirs(self.doc_path, exist_ok=True)
@staticmethod
def convert_env_value(key: str, env_value: str, type_hint: Type) -> Any:
"""Convert environment variable to the appropriate type based on the type hint."""
origin = get_origin(type_hint)
args = get_args(type_hint)
        if origin is Union:
            # Handle Union types (e.g., Union[str, None]). Check the None
            # sentinel first; otherwise Optional[str] could never yield None,
            # since str conversion accepts any value, including "none".
            if type(None) in args and env_value.lower() in ("none", "null", ""):
                return None
            for arg in args:
                if arg is type(None):
                    continue
                try:
                    return Config.convert_env_value(key, env_value, arg)
                except ValueError:
                    continue
            raise ValueError(f"Cannot convert {env_value} to any of {args}")
if type_hint is bool:
return env_value.lower() in ("true", "1", "yes", "on")
elif type_hint is int:
return int(env_value)
elif type_hint is float:
return float(env_value)
elif type_hint in (str, Any):
return env_value
        elif origin is list:
            # get_origin() normalizes both list[...] and typing.List[...] to
            # the builtin list, so one identity check suffices; parse as JSON.
            return json.loads(env_value)
else:
raise ValueError(f"Unsupported type {type_hint} for key {key}")