Spaces:
Runtime error
Runtime error
File size: 4,145 Bytes
35b22df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
"""Response schema."""
from dataclasses import dataclass, field
from typing import Any, Dict, Generator, List, Optional, Union
from dataclasses_json import DataClassJsonMixin
from gpt_index.data_structs.data_structs import Node
from gpt_index.utils import truncate_text
@dataclass
class SourceNode(DataClassJsonMixin):
    """User-facing record of a source node.

    Holds the source text together with the id of the document it belongs
    to, plus optional metadata carried over from the underlying Node.
    """

    source_text: str
    doc_id: Optional[str]
    extra_info: Optional[Dict[str, Any]] = None
    node_info: Optional[Dict[str, Any]] = None
    # distance score between node and query, if applicable
    similarity: Optional[float] = None

    @classmethod
    def from_node(cls, node: Node, similarity: Optional[float] = None) -> "SourceNode":
        """Build a SourceNode from a Node.

        Args:
            node: The Node whose text, reference document id, extra info and
                node info are copied.
            similarity: Optional score to attach to the resulting SourceNode.

        Returns:
            A new SourceNode populated from ``node``.
        """
        field_values = {
            "source_text": node.get_text(),
            "doc_id": node.ref_doc_id,
            "extra_info": node.extra_info,
            "node_info": node.node_info,
            "similarity": similarity,
        }
        return cls(**field_values)

    @classmethod
    def from_nodes(cls, nodes: List[Node]) -> List["SourceNode"]:
        """Build one SourceNode per Node, preserving the input order.

        No similarity is passed, so each result keeps the default of None.
        """
        return list(map(cls.from_node, nodes))
@dataclass
class Response:
    """Response object.

    Returned if streaming=False during the `index.query()` call.

    Attributes:
        response: The response text.
        source_nodes: Source nodes attached to this response.
        extra_info: Optional extra information attached to this response.
    """

    response: Optional[str]
    source_nodes: List[SourceNode] = field(default_factory=list)
    extra_info: Optional[Dict[str, Any]] = None

    def __str__(self) -> str:
        """Return the response text, or the string "None" when it is falsy."""
        if self.response:
            return self.response
        return "None"

    def get_formatted_sources(self, length: int = 100) -> str:
        """Render every source node as a one-line summary.

        Args:
            length: Passed to ``truncate_text`` for each node's source text.

        Returns:
            The per-node summaries joined by blank lines; an empty string
            when there are no source nodes.
        """
        summaries = (
            "> Source (Doc id: {}): {}".format(
                node.doc_id or "None",
                truncate_text(node.source_text, length),
            )
            for node in self.source_nodes
        )
        return "\n\n".join(summaries)
@dataclass
class StreamingResponse:
    """StreamingResponse object.

    Returned if streaming=True during the `index.query()` call.

    The generator is consumed at most once; the concatenated text is cached
    in ``response_txt`` so later calls reuse it instead of re-reading an
    exhausted generator.

    Attributes:
        response_gen: The response generator.
        source_nodes: Source nodes attached to this response.
        extra_info: Optional extra information attached to this response.
        response_txt: Cached full response text; None until the generator
            has been consumed.
    """

    response_gen: Optional[Generator]
    source_nodes: List[SourceNode] = field(default_factory=list)
    extra_info: Optional[Dict[str, Any]] = None
    response_txt: Optional[str] = None

    def _consume_stream(self, echo: bool = False) -> None:
        """Drain ``response_gen`` into ``response_txt`` if not done already.

        Does nothing when the text is already cached or there is no
        generator.

        Args:
            echo: When True, print each chunk as it is received.
        """
        if self.response_txt is not None or self.response_gen is None:
            return
        chunks = []
        for text in self.response_gen:
            if echo:
                # Flush so each chunk is visible immediately rather than
                # sitting in the stdout buffer until the stream ends.
                print(text, end="", flush=True)
            chunks.append(text)
        self.response_txt = "".join(chunks)

    def __str__(self) -> str:
        """Convert to string representation.

        Consumes the generator on first use. Returns the string "None" when
        there is no text (or the text is empty).
        """
        self._consume_stream()
        return self.response_txt or "None"

    def get_response(self) -> Response:
        """Get a standard response object.

        Consumes the generator on first use and wraps the cached text, the
        source nodes and the extra info in a Response.
        """
        self._consume_stream()
        return Response(self.response_txt, self.source_nodes, self.extra_info)

    def print_response_stream(self) -> None:
        """Print the response stream.

        On first use the chunks are printed as they arrive and also cached,
        so ``str()`` and ``get_response()`` still see the full text
        afterwards. If the stream was already consumed (or there is no
        generator), the cached text is printed instead.
        """
        if self.response_txt is None and self.response_gen is not None:
            self._consume_stream(echo=True)
        else:
            print(self.response_txt)

    def get_formatted_sources(self, length: int = 100) -> str:
        """Get formatted sources text.

        Args:
            length: Passed to ``truncate_text`` for each node's source text.

        Returns:
            One summary line per source node, joined by blank lines; an
            empty string when there are no source nodes.
        """
        texts = []
        for source_node in self.source_nodes:
            fmt_text_chunk = truncate_text(source_node.source_text, length)
            doc_id = source_node.doc_id or "None"
            texts.append(f"> Source (Doc id: {doc_id}): {fmt_text_chunk}")
        return "\n\n".join(texts)
# Type alias covering both response flavours defined in this module: the
# plain Response and the StreamingResponse.
RESPONSE_TYPE = Union[Response, StreamingResponse]
|