File size: 1,865 Bytes
35b22df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Weaviate utils."""

from typing import Any, Dict, List, Optional, Set, cast

from gpt_index.utils import get_new_int_id

DEFAULT_CLASS_PREFIX_STUB = "Gpt_Index"


def get_default_class_prefix(current_id_set: Optional[Set] = None) -> str:
    """Get default class prefix.

    Builds ``"<DEFAULT_CLASS_PREFIX_STUB>_<id>"``, where ``<id>`` is produced by
    ``get_new_int_id``.

    Args:
        current_id_set: Ids already in use. They are forwarded to
            ``get_new_int_id``, presumably so the new id avoids them (confirm
            against that helper). ``None``, the default, means an empty set.

    Returns:
        The generated class prefix, e.g. ``"Gpt_Index_12345"``.
    """
    # Build a fresh set per call instead of using a shared ``set()`` default,
    # so no state can leak between calls through the default argument.
    if current_id_set is None:
        current_id_set = set()
    return f"{DEFAULT_CLASS_PREFIX_STUB}_{get_new_int_id(current_id_set)}"


def validate_client(client: Any) -> None:
    """Validate client and import weaviate library.

    Only checks that the ``weaviate`` package and its ``Client`` class can be
    imported. The ``client`` argument itself is not inspected.

    Args:
        client: The Weaviate client supplied by the caller. It is kept for
            interface compatibility and is not currently type-checked.

    Raises:
        ImportError: If ``weaviate`` (or ``weaviate.Client``) cannot be
            imported. The original import error is chained as ``__cause__``.
    """
    try:
        import weaviate  # noqa: F401
        from weaviate import Client  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "Weaviate is not installed. "
            "Please install it with `pip install weaviate-client`."
        ) from err


def parse_get_response(response: Dict) -> Dict:
    """Parse get response from Weaviate.

    Args:
        response: The raw dict returned by executing a Weaviate query
            (``.do()``).

    Returns:
        The value stored under ``response["data"]["Get"]``.

    Raises:
        ValueError: If the response reports errors, carries no ``data``
            payload, or is not the result of a ``Get`` query.
    """
    if "errors" in response:
        raise ValueError("Invalid query, got errors: {}".format(response["errors"]))
    data_response = response.get("data")
    # A missing or null payload would otherwise escape as a bare KeyError or
    # TypeError. Report it the same way as the other invalid-response cases.
    if data_response is None:
        raise ValueError("Invalid query response, missing data.")
    if "Get" not in data_response:
        raise ValueError("Invalid query response, must be a Get query.")

    return data_response["Get"]


def get_by_id(
    client: Any, object_id: str, class_name: str, properties: List[str]
) -> Dict:
    """Fetch exactly one object of ``class_name`` from Weaviate by its id.

    Args:
        client: Weaviate client used to build and run the ``Get`` query.
        object_id: Id to match with an ``Equal`` filter on the ``id`` path.
        class_name: Weaviate class to query. It is also the key read from the
            parsed ``Get`` response.
        properties: Property names to return for the matched object.

    Returns:
        The single matching entry, including the fields requested via
        ``with_additional(["id", "vector"])``.

    Raises:
        ValueError: If no entry or more than one entry matches ``object_id``,
            or if ``parse_get_response`` rejects the raw response.
    """
    validate_client(client)

    id_filter = {"path": ["id"], "operator": "Equal", "valueString": object_id}
    get_query = client.query.get(class_name, properties)
    get_query = get_query.with_where(id_filter)
    get_query = get_query.with_additional(["id", "vector"])
    raw_response = get_query.do()

    matches = parse_get_response(raw_response)[class_name]
    match_count = len(matches)
    if match_count != 1:
        reason = (
            "No entry found for the given id"
            if match_count == 0
            else "More than one entry found for the given id"
        )
        raise ValueError(reason)
    return matches[0]