Spaces:
Runtime error
Runtime error
"""Weaviate utils.""" | |
from typing import Any, Dict, List, Optional, Set, cast

from gpt_index.utils import get_new_int_id
# Fixed leading part of every default class prefix; get_default_class_prefix
# appends "_<int id>" to it.
DEFAULT_CLASS_PREFIX_STUB = "Gpt_Index"


def get_default_class_prefix(current_id_set: Optional[Set] = None) -> str:
    """Get default class prefix.

    Builds a string of the form ``"Gpt_Index_<id>"`` where ``<id>`` is the
    value returned by ``get_new_int_id``.

    Args:
        current_id_set: Set handed to ``get_new_int_id`` (presumably the ids
            already in use -- confirm against that helper). When omitted, a
            fresh empty set is created for this call.

    Returns:
        The stub, an underscore, and the new integer id joined together.
    """
    # A literal ``set()`` default would be created once at definition time and
    # shared by every call; build a new one per call instead.
    if current_id_set is None:
        current_id_set = set()
    return DEFAULT_CLASS_PREFIX_STUB + "_" + str(get_new_int_id(current_id_set))
def validate_client(client: Any) -> None:
    """Validate client and import weaviate library.

    Confirms that the ``weaviate`` package and its ``Client`` class can be
    imported. The ``client`` object itself is not inspected: the previous
    ``typing.cast`` calls were runtime no-ops and have been dropped.

    Args:
        client: The object callers intend to use as a Weaviate client.

    Raises:
        ImportError: If ``weaviate`` (or ``weaviate.Client``) cannot be
            imported. The original import failure is attached as the cause.
    """
    try:
        import weaviate  # noqa: F401
        from weaviate import Client  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "Weaviate is not installed. "
            "Please install it with `pip install weaviate-client`."
        ) from err
def parse_get_response(response: Dict) -> Dict:
    """Extract the ``Get`` payload from a Weaviate query response.

    Args:
        response: Raw response dictionary returned by a query.

    Returns:
        The value stored under ``response["data"]["Get"]``.

    Raises:
        ValueError: If the response has an ``"errors"`` key, or if its
            ``"data"`` section has no ``"Get"`` entry.
    """
    if "errors" in response:
        reported = response["errors"]
        raise ValueError(f"Invalid query, got errors: {reported}")

    payload = response["data"]
    if "Get" in payload:
        return payload["Get"]
    raise ValueError("Invalid query response, must be a Get query.")
def get_by_id(
    client: Any, object_id: str, class_name: str, properties: List[str]
) -> Dict:
    """Fetch the single entry of ``class_name`` whose id equals ``object_id``.

    Args:
        client: Client object exposing ``query.get(...)``.
        object_id: Id value to match with an ``Equal`` filter on ``id``.
        class_name: Class to query; also the key read from the parsed result.
        properties: Property names passed to ``query.get``.

    Returns:
        The one matching entry; ``id`` and ``vector`` are requested as
        additional fields.

    Raises:
        ValueError: If zero entries or more than one entry match.
    """
    validate_client(client)

    id_filter = {"path": ["id"], "operator": "Equal", "valueString": object_id}

    # Build the query step by step, then execute it.
    builder = client.query.get(class_name, properties)
    builder = builder.with_where(id_filter)
    builder = builder.with_additional(["id", "vector"])
    raw_result = builder.do()

    matches = parse_get_response(raw_result)[class_name]
    match_count = len(matches)
    if match_count == 0:
        raise ValueError("No entry found for the given id")
    if match_count > 1:
        raise ValueError("More than one entry found for the given id")
    return matches[0]