|
|
|
|
|
import json |
|
import os |
|
import re |
|
import unittest |
|
from pathlib import Path |
|
from typing import Any, cast |
|
from unittest import mock |
|
|
|
import pytest |
|
import yaml |
|
from pydantic import ValidationError |
|
|
|
import graphrag.config.defaults as defs |
|
from graphrag.config import ( |
|
ApiKeyMissingError, |
|
AzureApiBaseMissingError, |
|
AzureDeploymentNameMissingError, |
|
CacheConfig, |
|
CacheConfigInput, |
|
CacheType, |
|
ChunkingConfig, |
|
ChunkingConfigInput, |
|
ClaimExtractionConfig, |
|
ClaimExtractionConfigInput, |
|
ClusterGraphConfig, |
|
ClusterGraphConfigInput, |
|
CommunityReportsConfig, |
|
CommunityReportsConfigInput, |
|
EmbedGraphConfig, |
|
EmbedGraphConfigInput, |
|
EntityExtractionConfig, |
|
EntityExtractionConfigInput, |
|
GlobalSearchConfig, |
|
GraphRagConfig, |
|
GraphRagConfigInput, |
|
InputConfig, |
|
InputConfigInput, |
|
InputFileType, |
|
InputType, |
|
LLMParameters, |
|
LLMParametersInput, |
|
LocalSearchConfig, |
|
ParallelizationParameters, |
|
ReportingConfig, |
|
ReportingConfigInput, |
|
ReportingType, |
|
SnapshotsConfig, |
|
SnapshotsConfigInput, |
|
StorageConfig, |
|
StorageConfigInput, |
|
StorageType, |
|
SummarizeDescriptionsConfig, |
|
SummarizeDescriptionsConfigInput, |
|
TextEmbeddingConfig, |
|
TextEmbeddingConfigInput, |
|
UmapConfig, |
|
UmapConfigInput, |
|
create_graphrag_config, |
|
) |
|
from graphrag.index import ( |
|
PipelineConfig, |
|
PipelineCSVInputConfig, |
|
PipelineFileCacheConfig, |
|
PipelineFileReportingConfig, |
|
PipelineFileStorageConfig, |
|
PipelineInputConfig, |
|
PipelineTextInputConfig, |
|
PipelineWorkflowReference, |
|
create_pipeline_config, |
|
) |
|
|
|
# Directory containing this test module (handy for path-relative fixtures).
current_dir = os.path.dirname(__file__)

# One representative value for every GRAPHRAG_* environment variable that
# create_graphrag_config understands.  Consumed wholesale by
# test_create_parameters_from_env_vars and cross-checked against the docsite
# by test_all_env_vars_is_accurate.  All values are strings, as they would be
# in a real environment.
ALL_ENV_VARS = {
    "GRAPHRAG_API_BASE": "http://some/base",
    "GRAPHRAG_API_KEY": "test",
    "GRAPHRAG_API_ORGANIZATION": "test_org",
    "GRAPHRAG_API_PROXY": "http://some/proxy",
    "GRAPHRAG_API_VERSION": "v1234",
    "GRAPHRAG_ASYNC_MODE": "asyncio",
    "GRAPHRAG_CACHE_STORAGE_ACCOUNT_BLOB_URL": "cache_account_blob_url",
    "GRAPHRAG_CACHE_BASE_DIR": "/some/cache/dir",
    "GRAPHRAG_CACHE_CONNECTION_STRING": "test_cs1",
    "GRAPHRAG_CACHE_CONTAINER_NAME": "test_cn1",
    "GRAPHRAG_CACHE_TYPE": "blob",
    "GRAPHRAG_CHUNK_BY_COLUMNS": "a,b",
    "GRAPHRAG_CHUNK_OVERLAP": "12",
    "GRAPHRAG_CHUNK_SIZE": "500",
    "GRAPHRAG_CLAIM_EXTRACTION_ENABLED": "True",
    "GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION": "test 123",
    "GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS": "5000",
    "GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-a.txt",
    "GRAPHRAG_COMMUNITY_REPORTS_MAX_LENGTH": "23456",
    "GRAPHRAG_COMMUNITY_REPORTS_PROMPT_FILE": "tests/unit/config/prompt-b.txt",
    "GRAPHRAG_EMBEDDING_BATCH_MAX_TOKENS": "17",
    "GRAPHRAG_EMBEDDING_BATCH_SIZE": "1000000",
    "GRAPHRAG_EMBEDDING_CONCURRENT_REQUESTS": "12",
    "GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME": "model-deployment-name",
    "GRAPHRAG_EMBEDDING_MAX_RETRIES": "3",
    "GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT": "0.1123",
    "GRAPHRAG_EMBEDDING_MODEL": "text-embedding-2",
    "GRAPHRAG_EMBEDDING_REQUESTS_PER_MINUTE": "500",
    "GRAPHRAG_EMBEDDING_SKIP": "a1,b1,c1",
    "GRAPHRAG_EMBEDDING_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
    "GRAPHRAG_EMBEDDING_TARGET": "all",
    "GRAPHRAG_EMBEDDING_THREAD_COUNT": "2345",
    "GRAPHRAG_EMBEDDING_THREAD_STAGGER": "0.456",
    "GRAPHRAG_EMBEDDING_TOKENS_PER_MINUTE": "7000",
    "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
    "GRAPHRAG_ENCODING_MODEL": "test123",
    "GRAPHRAG_INPUT_STORAGE_ACCOUNT_BLOB_URL": "input_account_blob_url",
    "GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES": "cat,dog,elephant",
    "GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "112",
    "GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-c.txt",
    "GRAPHRAG_INPUT_BASE_DIR": "/some/input/dir",
    "GRAPHRAG_INPUT_CONNECTION_STRING": "input_cs",
    "GRAPHRAG_INPUT_CONTAINER_NAME": "input_cn",
    "GRAPHRAG_INPUT_DOCUMENT_ATTRIBUTE_COLUMNS": "test1,test2",
    "GRAPHRAG_INPUT_ENCODING": "utf-16",
    "GRAPHRAG_INPUT_FILE_PATTERN": ".*\\test\\.txt$",
    "GRAPHRAG_INPUT_SOURCE_COLUMN": "test_source",
    "GRAPHRAG_INPUT_TYPE": "blob",
    "GRAPHRAG_INPUT_TEXT_COLUMN": "test_text",
    "GRAPHRAG_INPUT_TIMESTAMP_COLUMN": "test_timestamp",
    "GRAPHRAG_INPUT_TIMESTAMP_FORMAT": "test_format",
    "GRAPHRAG_INPUT_TITLE_COLUMN": "test_title",
    "GRAPHRAG_INPUT_FILE_TYPE": "text",
    "GRAPHRAG_LLM_CONCURRENT_REQUESTS": "12",
    "GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x",
    "GRAPHRAG_LLM_MAX_RETRIES": "312",
    "GRAPHRAG_LLM_MAX_RETRY_WAIT": "0.1122",
    "GRAPHRAG_LLM_MAX_TOKENS": "15000",
    "GRAPHRAG_LLM_MODEL_SUPPORTS_JSON": "true",
    "GRAPHRAG_LLM_MODEL": "test-llm",
    "GRAPHRAG_LLM_N": "1",
    "GRAPHRAG_LLM_REQUEST_TIMEOUT": "12.7",
    "GRAPHRAG_LLM_REQUESTS_PER_MINUTE": "900",
    "GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False",
    "GRAPHRAG_LLM_THREAD_COUNT": "987",
    "GRAPHRAG_LLM_THREAD_STAGGER": "0.123",
    "GRAPHRAG_LLM_TOKENS_PER_MINUTE": "8000",
    "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
    "GRAPHRAG_MAX_CLUSTER_SIZE": "123",
    "GRAPHRAG_NODE2VEC_ENABLED": "true",
    "GRAPHRAG_NODE2VEC_ITERATIONS": "878787",
    "GRAPHRAG_NODE2VEC_NUM_WALKS": "5000000",
    "GRAPHRAG_NODE2VEC_RANDOM_SEED": "010101",
    "GRAPHRAG_NODE2VEC_WALK_LENGTH": "555111",
    "GRAPHRAG_NODE2VEC_WINDOW_SIZE": "12345",
    "GRAPHRAG_REPORTING_STORAGE_ACCOUNT_BLOB_URL": "reporting_account_blob_url",
    "GRAPHRAG_REPORTING_BASE_DIR": "/some/reporting/dir",
    "GRAPHRAG_REPORTING_CONNECTION_STRING": "test_cs2",
    "GRAPHRAG_REPORTING_CONTAINER_NAME": "test_cn2",
    "GRAPHRAG_REPORTING_TYPE": "blob",
    "GRAPHRAG_SKIP_WORKFLOWS": "a,b,c",
    "GRAPHRAG_SNAPSHOT_GRAPHML": "true",
    "GRAPHRAG_SNAPSHOT_RAW_ENTITIES": "true",
    "GRAPHRAG_SNAPSHOT_TOP_LEVEL_NODES": "true",
    "GRAPHRAG_STORAGE_STORAGE_ACCOUNT_BLOB_URL": "storage_account_blob_url",
    "GRAPHRAG_STORAGE_BASE_DIR": "/some/storage/dir",
    "GRAPHRAG_STORAGE_CONNECTION_STRING": "test_cs",
    "GRAPHRAG_STORAGE_CONTAINER_NAME": "test_cn",
    "GRAPHRAG_STORAGE_TYPE": "blob",
    "GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH": "12345",
    "GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE": "tests/unit/config/prompt-d.txt",
    "GRAPHRAG_LLM_TEMPERATURE": "0.0",
    "GRAPHRAG_LLM_TOP_P": "1.0",
    "GRAPHRAG_UMAP_ENABLED": "true",
    "GRAPHRAG_LOCAL_SEARCH_TEXT_UNIT_PROP": "0.713",
    "GRAPHRAG_LOCAL_SEARCH_COMMUNITY_PROP": "0.1234",
    "GRAPHRAG_LOCAL_SEARCH_LLM_TEMPERATURE": "0.1",
    "GRAPHRAG_LOCAL_SEARCH_LLM_TOP_P": "0.9",
    "GRAPHRAG_LOCAL_SEARCH_LLM_N": "2",
    "GRAPHRAG_LOCAL_SEARCH_LLM_MAX_TOKENS": "12",
    "GRAPHRAG_LOCAL_SEARCH_TOP_K_RELATIONSHIPS": "15",
    "GRAPHRAG_LOCAL_SEARCH_TOP_K_ENTITIES": "14",
    "GRAPHRAG_LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS": "2",
    "GRAPHRAG_LOCAL_SEARCH_MAX_TOKENS": "142435",
    "GRAPHRAG_GLOBAL_SEARCH_LLM_TEMPERATURE": "0.1",
    "GRAPHRAG_GLOBAL_SEARCH_LLM_TOP_P": "0.9",
    "GRAPHRAG_GLOBAL_SEARCH_LLM_N": "2",
    "GRAPHRAG_GLOBAL_SEARCH_MAX_TOKENS": "5123",
    "GRAPHRAG_GLOBAL_SEARCH_DATA_MAX_TOKENS": "123",
    "GRAPHRAG_GLOBAL_SEARCH_MAP_MAX_TOKENS": "4123",
    "GRAPHRAG_GLOBAL_SEARCH_CONCURRENCY": "7",
    "GRAPHRAG_GLOBAL_SEARCH_REDUCE_MAX_TOKENS": "15432",
}
|
|
|
|
|
class TestDefaultConfig(unittest.TestCase): |
|
def test_clear_warnings(self): |
|
"""Just clearing unused import warnings""" |
|
assert CacheConfig is not None |
|
assert ChunkingConfig is not None |
|
assert ClaimExtractionConfig is not None |
|
assert ClusterGraphConfig is not None |
|
assert CommunityReportsConfig is not None |
|
assert EmbedGraphConfig is not None |
|
assert EntityExtractionConfig is not None |
|
assert GlobalSearchConfig is not None |
|
assert GraphRagConfig is not None |
|
assert InputConfig is not None |
|
assert LLMParameters is not None |
|
assert LocalSearchConfig is not None |
|
assert ParallelizationParameters is not None |
|
assert ReportingConfig is not None |
|
assert SnapshotsConfig is not None |
|
assert StorageConfig is not None |
|
assert SummarizeDescriptionsConfig is not None |
|
assert TextEmbeddingConfig is not None |
|
assert UmapConfig is not None |
|
assert PipelineConfig is not None |
|
assert PipelineFileReportingConfig is not None |
|
assert PipelineFileStorageConfig is not None |
|
assert PipelineInputConfig is not None |
|
assert PipelineFileCacheConfig is not None |
|
assert PipelineWorkflowReference is not None |
|
|
|
@mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True) |
|
def test_string_repr(self): |
|
|
|
config = create_graphrag_config() |
|
string_repr = str(config) |
|
assert string_repr is not None |
|
assert json.loads(string_repr) is not None |
|
|
|
|
|
repr_str = config.__repr__() |
|
|
|
repr_str = repr_str.replace("async_mode=<AsyncType.Threaded: 'threaded'>,", "") |
|
assert eval(repr_str) is not None |
|
|
|
|
|
pipeline_config = create_pipeline_config(config) |
|
string_repr = str(pipeline_config) |
|
assert string_repr is not None |
|
assert json.loads(string_repr) is not None |
|
|
|
|
|
repr_str = pipeline_config.__repr__() |
|
|
|
repr_str = repr_str.replace( |
|
"'async_mode': <AsyncType.Threaded: 'threaded'>,", "" |
|
) |
|
assert eval(repr_str) is not None |
|
|
|
    @mock.patch.dict(os.environ, {}, clear=True)
    def test_default_config_with_no_env_vars_throws(self):
        """With no API key available anywhere, config creation must fail fast."""
        with pytest.raises(ApiKeyMissingError):
            create_pipeline_config(create_graphrag_config())
|
|
|
    @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True)
    def test_default_config_with_api_key_passes(self):
        """GRAPHRAG_API_KEY alone is sufficient for a default configuration."""
        config = create_pipeline_config(create_graphrag_config())
        assert config is not None
|
|
|
    @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True)
    def test_default_config_with_oai_key_passes_envvar(self):
        """The standard OPENAI_API_KEY env var is accepted as a fallback."""
        config = create_pipeline_config(create_graphrag_config())
        assert config is not None
|
|
|
def test_default_config_with_oai_key_passes_obj(self): |
|
|
|
config = create_pipeline_config( |
|
create_graphrag_config({"llm": {"api_key": "test"}}) |
|
) |
|
assert config is not None |
|
|
|
    @mock.patch.dict(
        os.environ,
        {"GRAPHRAG_API_KEY": "test", "GRAPHRAG_LLM_TYPE": "azure_openai_chat"},
        clear=True,
    )
    def test_throws_if_azure_is_used_without_api_base_envvar(self):
        """Selecting the Azure chat LLM without an api_base must raise."""
        with pytest.raises(AzureApiBaseMissingError):
            create_graphrag_config()
|
|
|
@mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) |
|
def test_throws_if_azure_is_used_without_api_base_obj(self): |
|
with pytest.raises(AzureApiBaseMissingError): |
|
create_graphrag_config( |
|
GraphRagConfigInput(llm=LLMParametersInput(type="azure_openai_chat")) |
|
) |
|
|
|
    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
            "GRAPHRAG_API_BASE": "http://some/base",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_llm_deployment_name_envvar(self):
        """Azure chat with an api_base but no deployment name must raise."""
        with pytest.raises(AzureDeploymentNameMissingError):
            create_graphrag_config()
|
|
|
@mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) |
|
def test_throws_if_azure_is_used_without_llm_deployment_name_obj(self): |
|
with pytest.raises(AzureDeploymentNameMissingError): |
|
create_graphrag_config( |
|
GraphRagConfigInput( |
|
llm=LLMParametersInput( |
|
type="azure_openai_chat", api_base="http://some/base" |
|
) |
|
) |
|
) |
|
|
|
    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
            "GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME": "x",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_embedding_api_base_envvar(self):
        """Azure embeddings with a deployment name but no api_base must raise."""
        with pytest.raises(AzureApiBaseMissingError):
            create_graphrag_config()
|
|
|
@mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) |
|
def test_throws_if_azure_is_used_without_embedding_api_base_obj(self): |
|
with pytest.raises(AzureApiBaseMissingError): |
|
create_graphrag_config( |
|
GraphRagConfigInput( |
|
embeddings=TextEmbeddingConfigInput( |
|
llm=LLMParametersInput( |
|
type="azure_openai_embedding", |
|
deployment_name="x", |
|
) |
|
), |
|
) |
|
) |
|
|
|
    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_API_BASE": "http://some/base",
            "GRAPHRAG_LLM_DEPLOYMENT_NAME": "x",
            "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
            "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_embedding_deployment_name_envvar(self):
        """A complete chat config does not excuse a missing embedding deployment."""
        with pytest.raises(AzureDeploymentNameMissingError):
            create_graphrag_config()
|
|
|
@mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) |
|
def test_throws_if_azure_is_used_without_embedding_deployment_name_obj(self): |
|
with pytest.raises(AzureDeploymentNameMissingError): |
|
create_graphrag_config( |
|
GraphRagConfigInput( |
|
llm=LLMParametersInput( |
|
type="azure_openai_chat", |
|
api_base="http://some/base", |
|
deployment_name="model-deployment-name-x", |
|
), |
|
embeddings=TextEmbeddingConfigInput( |
|
llm=LLMParametersInput( |
|
type="azure_openai_embedding", |
|
) |
|
), |
|
) |
|
) |
|
|
|
@mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) |
|
def test_minimim_azure_config_object(self): |
|
config = create_graphrag_config( |
|
GraphRagConfigInput( |
|
llm=LLMParametersInput( |
|
type="azure_openai_chat", |
|
api_base="http://some/base", |
|
deployment_name="model-deployment-name-x", |
|
), |
|
embeddings=TextEmbeddingConfigInput( |
|
llm=LLMParametersInput( |
|
type="azure_openai_embedding", |
|
deployment_name="model-deployment-name", |
|
) |
|
), |
|
) |
|
) |
|
assert config is not None |
|
|
|
    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
            "GRAPHRAG_LLM_DEPLOYMENT_NAME": "x",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_api_base(self):
        """A deployment name alone is not enough: Azure also needs an api_base."""
        with pytest.raises(AzureApiBaseMissingError):
            create_graphrag_config()
|
|
|
    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
            "GRAPHRAG_LLM_API_BASE": "http://some/base",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_llm_deployment_name(self):
        """The LLM-scoped api_base override also requires a deployment name."""
        with pytest.raises(AzureDeploymentNameMissingError):
            create_graphrag_config()
|
|
|
    @mock.patch.dict(
        os.environ,
        {
            "GRAPHRAG_API_KEY": "test",
            "GRAPHRAG_LLM_TYPE": "azure_openai_chat",
            "GRAPHRAG_API_BASE": "http://some/base",
            "GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x",
            "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding",
        },
        clear=True,
    )
    def test_throws_if_azure_is_used_without_embedding_deployment_name(self):
        """A fully-specified chat config still needs an embedding deployment name."""
        with pytest.raises(AzureDeploymentNameMissingError):
            create_graphrag_config()
|
|
|
@mock.patch.dict( |
|
os.environ, |
|
{"GRAPHRAG_API_KEY": "test", "GRAPHRAG_INPUT_FILE_TYPE": "csv"}, |
|
clear=True, |
|
) |
|
def test_csv_input_returns_correct_config(self): |
|
config = create_pipeline_config(create_graphrag_config(root_dir="/some/root")) |
|
assert config.root_dir == "/some/root" |
|
|
|
assert isinstance(config.input, PipelineCSVInputConfig) |
|
assert (config.input.file_pattern or "") == ".*\\.csv$" |
|
|
|
@mock.patch.dict( |
|
os.environ, |
|
{"GRAPHRAG_API_KEY": "test", "GRAPHRAG_INPUT_FILE_TYPE": "text"}, |
|
clear=True, |
|
) |
|
def test_text_input_returns_correct_config(self): |
|
config = create_pipeline_config(create_graphrag_config(root_dir=".")) |
|
assert isinstance(config.input, PipelineTextInputConfig) |
|
assert config.input is not None |
|
assert (config.input.file_pattern or "") == ".*\\.txt$" |
|
|
|
def test_all_env_vars_is_accurate(self): |
|
env_var_docs_path = Path("docsite/posts/config/env_vars.md") |
|
query_docs_path = Path("docsite/posts/query/3-cli.md") |
|
|
|
env_var_docs = env_var_docs_path.read_text(encoding="utf-8") |
|
query_docs = query_docs_path.read_text(encoding="utf-8") |
|
|
|
def find_envvar_names(text) -> set[str]: |
|
pattern = r"`(GRAPHRAG_[^`]+)`" |
|
found = re.findall(pattern, text) |
|
found = {f for f in found if not f.endswith("_")} |
|
return {*found} |
|
|
|
graphrag_strings = find_envvar_names(env_var_docs) | find_envvar_names( |
|
query_docs |
|
) |
|
|
|
missing = {s for s in graphrag_strings if s not in ALL_ENV_VARS} - { |
|
|
|
"GRAPHRAG_LLM_API_KEY", |
|
"GRAPHRAG_LLM_API_BASE", |
|
"GRAPHRAG_LLM_API_VERSION", |
|
"GRAPHRAG_LLM_API_ORGANIZATION", |
|
"GRAPHRAG_LLM_API_PROXY", |
|
"GRAPHRAG_EMBEDDING_API_KEY", |
|
"GRAPHRAG_EMBEDDING_API_BASE", |
|
"GRAPHRAG_EMBEDDING_API_VERSION", |
|
"GRAPHRAG_EMBEDDING_API_ORGANIZATION", |
|
"GRAPHRAG_EMBEDDING_API_PROXY", |
|
} |
|
if missing: |
|
msg = f"{len(missing)} missing env vars: {missing}" |
|
print(msg) |
|
raise ValueError(msg) |
|
|
|
    @mock.patch.dict(
        os.environ,
        {"GRAPHRAG_API_KEY": "test"},
        clear=True,
    )
    def test_malformed_input_dict_throws(self):
        """A structurally invalid input dict is rejected by pydantic validation."""
        with pytest.raises(ValidationError):
            create_graphrag_config(cast(Any, {"llm": 12}))
|
|
|
    @mock.patch.dict(
        os.environ,
        ALL_ENV_VARS,
        clear=True,
    )
    def test_create_parameters_from_env_vars(self) -> None:
        """Every supported GRAPHRAG_* env var must land on its config field."""
        parameters = create_graphrag_config()
        # top level
        assert parameters.async_mode == "asyncio"
        # cache
        assert parameters.cache.storage_account_blob_url == "cache_account_blob_url"
        assert parameters.cache.base_dir == "/some/cache/dir"
        assert parameters.cache.connection_string == "test_cs1"
        assert parameters.cache.container_name == "test_cn1"
        assert parameters.cache.type == CacheType.blob
        # chunking
        assert parameters.chunks.group_by_columns == ["a", "b"]
        assert parameters.chunks.overlap == 12
        assert parameters.chunks.size == 500
        # claim extraction
        assert parameters.claim_extraction.enabled
        assert parameters.claim_extraction.description == "test 123"
        assert parameters.claim_extraction.max_gleanings == 5000
        assert parameters.claim_extraction.prompt == "tests/unit/config/prompt-a.txt"
        # clustering
        assert parameters.cluster_graph.max_cluster_size == 123
        # community reports
        assert parameters.community_reports.max_length == 23456
        assert parameters.community_reports.prompt == "tests/unit/config/prompt-b.txt"
        # graph embedding (node2vec)
        assert parameters.embed_graph.enabled
        assert parameters.embed_graph.iterations == 878787
        assert parameters.embed_graph.num_walks == 5_000_000
        # "010101" parses as decimal 10101, not octal
        assert parameters.embed_graph.random_seed == 10101
        assert parameters.embed_graph.walk_length == 555111
        assert parameters.embed_graph.window_size == 12345
        # text embeddings
        assert parameters.embeddings.batch_max_tokens == 17
        assert parameters.embeddings.batch_size == 1_000_000
        assert parameters.embeddings.llm.concurrent_requests == 12
        assert parameters.embeddings.llm.deployment_name == "model-deployment-name"
        assert parameters.embeddings.llm.max_retries == 3
        assert parameters.embeddings.llm.max_retry_wait == 0.1123
        assert parameters.embeddings.llm.model == "text-embedding-2"
        assert parameters.embeddings.llm.requests_per_minute == 500
        assert parameters.embeddings.llm.sleep_on_rate_limit_recommendation is False
        assert parameters.embeddings.llm.tokens_per_minute == 7000
        assert parameters.embeddings.llm.type == "azure_openai_embedding"
        assert parameters.embeddings.parallelization.num_threads == 2345
        assert parameters.embeddings.parallelization.stagger == 0.456
        assert parameters.embeddings.skip == ["a1", "b1", "c1"]
        assert parameters.embeddings.target == "all"
        # encoding
        assert parameters.encoding_model == "test123"
        # entity extraction (llm api_base inherited from GRAPHRAG_API_BASE)
        assert parameters.entity_extraction.entity_types == ["cat", "dog", "elephant"]
        assert parameters.entity_extraction.llm.api_base == "http://some/base"
        assert parameters.entity_extraction.max_gleanings == 112
        assert parameters.entity_extraction.prompt == "tests/unit/config/prompt-c.txt"
        # input
        assert parameters.input.storage_account_blob_url == "input_account_blob_url"
        assert parameters.input.base_dir == "/some/input/dir"
        assert parameters.input.connection_string == "input_cs"
        assert parameters.input.container_name == "input_cn"
        assert parameters.input.document_attribute_columns == ["test1", "test2"]
        assert parameters.input.encoding == "utf-16"
        assert parameters.input.file_pattern == ".*\\test\\.txt$"
        assert parameters.input.file_type == InputFileType.text
        assert parameters.input.source_column == "test_source"
        assert parameters.input.text_column == "test_text"
        assert parameters.input.timestamp_column == "test_timestamp"
        assert parameters.input.timestamp_format == "test_format"
        assert parameters.input.title_column == "test_title"
        assert parameters.input.type == InputType.blob
        # llm
        assert parameters.llm.api_base == "http://some/base"
        assert parameters.llm.api_key == "test"
        assert parameters.llm.api_version == "v1234"
        assert parameters.llm.concurrent_requests == 12
        assert parameters.llm.deployment_name == "model-deployment-name-x"
        assert parameters.llm.max_retries == 312
        assert parameters.llm.max_retry_wait == 0.1122
        assert parameters.llm.max_tokens == 15000
        assert parameters.llm.model == "test-llm"
        assert parameters.llm.model_supports_json
        assert parameters.llm.n == 1
        assert parameters.llm.organization == "test_org"
        assert parameters.llm.proxy == "http://some/proxy"
        assert parameters.llm.request_timeout == 12.7
        assert parameters.llm.requests_per_minute == 900
        assert parameters.llm.sleep_on_rate_limit_recommendation is False
        assert parameters.llm.temperature == 0.0
        assert parameters.llm.top_p == 1.0
        assert parameters.llm.tokens_per_minute == 8000
        assert parameters.llm.type == "azure_openai_chat"
        # parallelization
        assert parameters.parallelization.num_threads == 987
        assert parameters.parallelization.stagger == 0.123
        # reporting
        assert (
            parameters.reporting.storage_account_blob_url
            == "reporting_account_blob_url"
        )
        assert parameters.reporting.base_dir == "/some/reporting/dir"
        assert parameters.reporting.connection_string == "test_cs2"
        assert parameters.reporting.container_name == "test_cn2"
        assert parameters.reporting.type == ReportingType.blob
        # workflow skipping
        assert parameters.skip_workflows == ["a", "b", "c"]
        # snapshots
        assert parameters.snapshots.graphml
        assert parameters.snapshots.raw_entities
        assert parameters.snapshots.top_level_nodes
        # storage
        assert parameters.storage.storage_account_blob_url == "storage_account_blob_url"
        assert parameters.storage.base_dir == "/some/storage/dir"
        assert parameters.storage.connection_string == "test_cs"
        assert parameters.storage.container_name == "test_cn"
        assert parameters.storage.type == StorageType.blob
        # summarization
        assert parameters.summarize_descriptions.max_length == 12345
        assert (
            parameters.summarize_descriptions.prompt == "tests/unit/config/prompt-d.txt"
        )
        # umap
        assert parameters.umap.enabled
        # local search
        assert parameters.local_search.text_unit_prop == 0.713
        assert parameters.local_search.community_prop == 0.1234
        assert parameters.local_search.llm_max_tokens == 12
        assert parameters.local_search.top_k_relationships == 15
        assert parameters.local_search.conversation_history_max_turns == 2
        assert parameters.local_search.top_k_entities == 14
        assert parameters.local_search.temperature == 0.1
        assert parameters.local_search.top_p == 0.9
        assert parameters.local_search.n == 2
        assert parameters.local_search.max_tokens == 142435

        # global search
        assert parameters.global_search.temperature == 0.1
        assert parameters.global_search.top_p == 0.9
        assert parameters.global_search.n == 2
        assert parameters.global_search.max_tokens == 5123
        assert parameters.global_search.data_max_tokens == 123
        assert parameters.global_search.map_max_tokens == 4123
        assert parameters.global_search.concurrency == 7
        assert parameters.global_search.reduce_max_tokens == 15432
|
|
|
    @mock.patch.dict(os.environ, {"API_KEY_X": "test"}, clear=True)
    def test_create_parameters(self) -> None:
        """A fully object-configured setup must land on the same config fields.

        Also exercises ${API_KEY_X} environment-variable interpolation inside
        the input object.
        """
        parameters = create_graphrag_config(
            GraphRagConfigInput(
                llm=LLMParametersInput(api_key="${API_KEY_X}", model="test-llm"),
                storage=StorageConfigInput(
                    type=StorageType.blob,
                    connection_string="test_cs",
                    container_name="test_cn",
                    base_dir="/some/storage/dir",
                    storage_account_blob_url="storage_account_blob_url",
                ),
                cache=CacheConfigInput(
                    type=CacheType.blob,
                    connection_string="test_cs1",
                    container_name="test_cn1",
                    base_dir="/some/cache/dir",
                    storage_account_blob_url="cache_account_blob_url",
                ),
                reporting=ReportingConfigInput(
                    type=ReportingType.blob,
                    connection_string="test_cs2",
                    container_name="test_cn2",
                    base_dir="/some/reporting/dir",
                    storage_account_blob_url="reporting_account_blob_url",
                ),
                input=InputConfigInput(
                    file_type=InputFileType.text,
                    file_encoding="utf-16",
                    document_attribute_columns=["test1", "test2"],
                    base_dir="/some/input/dir",
                    connection_string="input_cs",
                    container_name="input_cn",
                    file_pattern=".*\\test\\.txt$",
                    source_column="test_source",
                    text_column="test_text",
                    timestamp_column="test_timestamp",
                    timestamp_format="test_format",
                    title_column="test_title",
                    type="blob",
                    storage_account_blob_url="input_account_blob_url",
                ),
                embed_graph=EmbedGraphConfigInput(
                    enabled=True,
                    num_walks=5_000_000,
                    iterations=878787,
                    random_seed=10101,
                    walk_length=555111,
                ),
                embeddings=TextEmbeddingConfigInput(
                    batch_size=1_000_000,
                    batch_max_tokens=8000,
                    skip=["a1", "b1", "c1"],
                    llm=LLMParametersInput(model="text-embedding-2"),
                ),
                chunks=ChunkingConfigInput(
                    size=500, overlap=12, group_by_columns=["a", "b"]
                ),
                snapshots=SnapshotsConfigInput(
                    graphml=True,
                    raw_entities=True,
                    top_level_nodes=True,
                ),
                entity_extraction=EntityExtractionConfigInput(
                    max_gleanings=112,
                    entity_types=["cat", "dog", "elephant"],
                    prompt="entity_extraction_prompt_file.txt",
                ),
                summarize_descriptions=SummarizeDescriptionsConfigInput(
                    max_length=12345, prompt="summarize_prompt_file.txt"
                ),
                community_reports=CommunityReportsConfigInput(
                    max_length=23456,
                    prompt="community_report_prompt_file.txt",
                    max_input_length=12345,
                ),
                claim_extraction=ClaimExtractionConfigInput(
                    description="test 123",
                    max_gleanings=5000,
                    prompt="claim_extraction_prompt_file.txt",
                ),
                cluster_graph=ClusterGraphConfigInput(
                    max_cluster_size=123,
                ),
                umap=UmapConfigInput(enabled=True),
                encoding_model="test123",
                skip_workflows=["a", "b", "c"],
            ),
            ".",
        )

        # cache
        assert parameters.cache.base_dir == "/some/cache/dir"
        assert parameters.cache.connection_string == "test_cs1"
        assert parameters.cache.container_name == "test_cn1"
        assert parameters.cache.type == CacheType.blob
        assert parameters.cache.storage_account_blob_url == "cache_account_blob_url"
        # chunking
        assert parameters.chunks.group_by_columns == ["a", "b"]
        assert parameters.chunks.overlap == 12
        assert parameters.chunks.size == 500
        # claim extraction
        assert parameters.claim_extraction.description == "test 123"
        assert parameters.claim_extraction.max_gleanings == 5000
        assert parameters.claim_extraction.prompt == "claim_extraction_prompt_file.txt"
        # clustering
        assert parameters.cluster_graph.max_cluster_size == 123
        # community reports
        assert parameters.community_reports.max_input_length == 12345
        assert parameters.community_reports.max_length == 23456
        assert parameters.community_reports.prompt == "community_report_prompt_file.txt"
        # graph embedding (node2vec)
        assert parameters.embed_graph.enabled
        assert parameters.embed_graph.iterations == 878787
        assert parameters.embed_graph.num_walks == 5_000_000
        assert parameters.embed_graph.random_seed == 10101
        assert parameters.embed_graph.walk_length == 555111
        # text embeddings
        assert parameters.embeddings.batch_max_tokens == 8000
        assert parameters.embeddings.batch_size == 1_000_000
        assert parameters.embeddings.llm.model == "text-embedding-2"
        assert parameters.embeddings.skip == ["a1", "b1", "c1"]
        # encoding
        assert parameters.encoding_model == "test123"
        # entity extraction
        assert parameters.entity_extraction.entity_types == ["cat", "dog", "elephant"]
        assert parameters.entity_extraction.max_gleanings == 112
        assert (
            parameters.entity_extraction.prompt == "entity_extraction_prompt_file.txt"
        )
        # input
        assert parameters.input.base_dir == "/some/input/dir"
        assert parameters.input.connection_string == "input_cs"
        assert parameters.input.container_name == "input_cn"
        assert parameters.input.document_attribute_columns == ["test1", "test2"]
        assert parameters.input.encoding == "utf-16"
        assert parameters.input.file_pattern == ".*\\test\\.txt$"
        assert parameters.input.source_column == "test_source"
        assert parameters.input.type == "blob"
        assert parameters.input.text_column == "test_text"
        assert parameters.input.timestamp_column == "test_timestamp"
        assert parameters.input.timestamp_format == "test_format"
        assert parameters.input.title_column == "test_title"
        assert parameters.input.file_type == InputFileType.text
        assert parameters.input.storage_account_blob_url == "input_account_blob_url"
        # llm: ${API_KEY_X} must have been resolved from the environment
        assert parameters.llm.api_key == "test"
        assert parameters.llm.model == "test-llm"
        # reporting
        assert parameters.reporting.base_dir == "/some/reporting/dir"
        assert parameters.reporting.connection_string == "test_cs2"
        assert parameters.reporting.container_name == "test_cn2"
        assert parameters.reporting.type == ReportingType.blob
        assert (
            parameters.reporting.storage_account_blob_url
            == "reporting_account_blob_url"
        )
        # workflow skipping
        assert parameters.skip_workflows == ["a", "b", "c"]
        # snapshots
        assert parameters.snapshots.graphml
        assert parameters.snapshots.raw_entities
        assert parameters.snapshots.top_level_nodes
        # storage
        assert parameters.storage.base_dir == "/some/storage/dir"
        assert parameters.storage.connection_string == "test_cs"
        assert parameters.storage.container_name == "test_cn"
        assert parameters.storage.type == StorageType.blob
        assert parameters.storage.storage_account_blob_url == "storage_account_blob_url"
        # summarization
        assert parameters.summarize_descriptions.max_length == 12345
        assert parameters.summarize_descriptions.prompt == "summarize_prompt_file.txt"
        # umap
        assert parameters.umap.enabled
|
|
|
@mock.patch.dict( |
|
os.environ, |
|
{"GRAPHRAG_API_KEY": "test"}, |
|
clear=True, |
|
) |
|
def test_default_values(self) -> None: |
|
parameters = create_graphrag_config() |
|
assert parameters.async_mode == defs.ASYNC_MODE |
|
assert parameters.cache.base_dir == defs.CACHE_BASE_DIR |
|
assert parameters.cache.type == defs.CACHE_TYPE |
|
assert parameters.cache.base_dir == defs.CACHE_BASE_DIR |
|
assert parameters.chunks.group_by_columns == defs.CHUNK_GROUP_BY_COLUMNS |
|
assert parameters.chunks.overlap == defs.CHUNK_OVERLAP |
|
assert parameters.chunks.size == defs.CHUNK_SIZE |
|
assert parameters.claim_extraction.description == defs.CLAIM_DESCRIPTION |
|
assert parameters.claim_extraction.max_gleanings == defs.CLAIM_MAX_GLEANINGS |
|
assert ( |
|
parameters.community_reports.max_input_length |
|
== defs.COMMUNITY_REPORT_MAX_INPUT_LENGTH |
|
) |
|
assert ( |
|
parameters.community_reports.max_length == defs.COMMUNITY_REPORT_MAX_LENGTH |
|
) |
|
assert parameters.embeddings.batch_max_tokens == defs.EMBEDDING_BATCH_MAX_TOKENS |
|
assert parameters.embeddings.batch_size == defs.EMBEDDING_BATCH_SIZE |
|
assert parameters.embeddings.llm.model == defs.EMBEDDING_MODEL |
|
assert parameters.embeddings.target == defs.EMBEDDING_TARGET |
|
assert parameters.embeddings.llm.type == defs.EMBEDDING_TYPE |
|
assert ( |
|
parameters.embeddings.llm.requests_per_minute |
|
== defs.LLM_REQUESTS_PER_MINUTE |
|
) |
|
assert parameters.embeddings.llm.tokens_per_minute == defs.LLM_TOKENS_PER_MINUTE |
|
assert ( |
|
parameters.embeddings.llm.sleep_on_rate_limit_recommendation |
|
== defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION |
|
) |
|
assert ( |
|
parameters.entity_extraction.entity_types |
|
== defs.ENTITY_EXTRACTION_ENTITY_TYPES |
|
) |
|
assert ( |
|
parameters.entity_extraction.max_gleanings |
|
== defs.ENTITY_EXTRACTION_MAX_GLEANINGS |
|
) |
|
assert parameters.encoding_model == defs.ENCODING_MODEL |
|
assert parameters.input.base_dir == defs.INPUT_BASE_DIR |
|
assert parameters.input.file_pattern == defs.INPUT_CSV_PATTERN |
|
assert parameters.input.encoding == defs.INPUT_FILE_ENCODING |
|
assert parameters.input.type == defs.INPUT_TYPE |
|
assert parameters.input.base_dir == defs.INPUT_BASE_DIR |
|
assert parameters.input.text_column == defs.INPUT_TEXT_COLUMN |
|
assert parameters.input.file_type == defs.INPUT_FILE_TYPE |
|
assert parameters.llm.concurrent_requests == defs.LLM_CONCURRENT_REQUESTS |
|
assert parameters.llm.max_retries == defs.LLM_MAX_RETRIES |
|
assert parameters.llm.max_retry_wait == defs.LLM_MAX_RETRY_WAIT |
|
assert parameters.llm.max_tokens == defs.LLM_MAX_TOKENS |
|
assert parameters.llm.model == defs.LLM_MODEL |
|
assert parameters.llm.request_timeout == defs.LLM_REQUEST_TIMEOUT |
|
assert parameters.llm.requests_per_minute == defs.LLM_REQUESTS_PER_MINUTE |
|
assert parameters.llm.tokens_per_minute == defs.LLM_TOKENS_PER_MINUTE |
|
assert ( |
|
parameters.llm.sleep_on_rate_limit_recommendation |
|
== defs.LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION |
|
) |
|
assert parameters.llm.type == defs.LLM_TYPE |
|
assert parameters.cluster_graph.max_cluster_size == defs.MAX_CLUSTER_SIZE |
|
assert parameters.embed_graph.enabled == defs.NODE2VEC_ENABLED |
|
assert parameters.embed_graph.iterations == defs.NODE2VEC_ITERATIONS |
|
assert parameters.embed_graph.num_walks == defs.NODE2VEC_NUM_WALKS |
|
assert parameters.embed_graph.random_seed == defs.NODE2VEC_RANDOM_SEED |
|
assert parameters.embed_graph.walk_length == defs.NODE2VEC_WALK_LENGTH |
|
assert parameters.embed_graph.window_size == defs.NODE2VEC_WINDOW_SIZE |
|
assert ( |
|
parameters.parallelization.num_threads == defs.PARALLELIZATION_NUM_THREADS |
|
) |
|
assert parameters.parallelization.stagger == defs.PARALLELIZATION_STAGGER |
|
assert parameters.reporting.type == defs.REPORTING_TYPE |
|
assert parameters.reporting.base_dir == defs.REPORTING_BASE_DIR |
|
assert parameters.snapshots.graphml == defs.SNAPSHOTS_GRAPHML |
|
assert parameters.snapshots.raw_entities == defs.SNAPSHOTS_RAW_ENTITIES |
|
assert parameters.snapshots.top_level_nodes == defs.SNAPSHOTS_TOP_LEVEL_NODES |
|
assert parameters.storage.base_dir == defs.STORAGE_BASE_DIR |
|
assert parameters.storage.type == defs.STORAGE_TYPE |
|
assert parameters.umap.enabled == defs.UMAP_ENABLED |
|
|
|
@mock.patch.dict( |
|
os.environ, |
|
{"GRAPHRAG_API_KEY": "test"}, |
|
clear=True, |
|
) |
|
def test_prompt_file_reading(self): |
|
config = create_graphrag_config({ |
|
"entity_extraction": {"prompt": "tests/unit/config/prompt-a.txt"}, |
|
"claim_extraction": {"prompt": "tests/unit/config/prompt-b.txt"}, |
|
"community_reports": {"prompt": "tests/unit/config/prompt-c.txt"}, |
|
"summarize_descriptions": {"prompt": "tests/unit/config/prompt-d.txt"}, |
|
}) |
|
strategy = config.entity_extraction.resolved_strategy(".", "abc123") |
|
assert strategy["extraction_prompt"] == "Hello, World! A" |
|
assert strategy["encoding_name"] == "abc123" |
|
|
|
strategy = config.claim_extraction.resolved_strategy(".") |
|
assert strategy["extraction_prompt"] == "Hello, World! B" |
|
|
|
strategy = config.community_reports.resolved_strategy(".") |
|
assert strategy["extraction_prompt"] == "Hello, World! C" |
|
|
|
strategy = config.summarize_descriptions.resolved_strategy(".") |
|
assert strategy["summarize_prompt"] == "Hello, World! D" |
|
|
|
|
|
@mock.patch.dict(
    os.environ,
    {
        "PIPELINE_LLM_API_KEY": "test",
        "PIPELINE_LLM_API_BASE": "http://test",
        "PIPELINE_LLM_API_VERSION": "v1",
        "PIPELINE_LLM_MODEL": "test-llm",
        "PIPELINE_LLM_DEPLOYMENT_NAME": "test",
    },
    clear=True,
)
def test_yaml_load_e2e():
    """YAML config resolves ${ENV_VAR} tokens, and resolved values survive pipeline serialization."""
    raw_config = yaml.safe_load(
        """
input:
  file_type: text

llm:
  type: azure_openai_chat
  api_key: ${PIPELINE_LLM_API_KEY}
  api_base: ${PIPELINE_LLM_API_BASE}
  api_version: ${PIPELINE_LLM_API_VERSION}
  model: ${PIPELINE_LLM_MODEL}
  deployment_name: ${PIPELINE_LLM_DEPLOYMENT_NAME}
  model_supports_json: True
  tokens_per_minute: 80000
  requests_per_minute: 900
  thread_count: 50
  concurrent_requests: 25
"""
    )

    parameters = create_graphrag_config(raw_config, ".")

    # Environment-variable substitution happened at load time.
    assert parameters.llm.api_key == "test"
    assert parameters.llm.model == "test-llm"
    assert parameters.llm.api_base == "http://test"
    assert parameters.llm.api_version == "v1"
    assert parameters.llm.deployment_name == "test"

    # Serialize the derived pipeline config and verify no raw ${...}
    # placeholder leaked through into the output.
    pipeline_config = create_pipeline_config(parameters, True)
    serialized = pipeline_config.model_dump_json()
    for placeholder in (
        "${PIPELINE_LLM_API_KEY}",
        "${PIPELINE_LLM_API_BASE}",
        "${PIPELINE_LLM_API_VERSION}",
        "${PIPELINE_LLM_MODEL}",
        "${PIPELINE_LLM_DEPLOYMENT_NAME}",
    ):
        assert placeholder not in serialized
|
|