Spaces:
Runtime error
Runtime error
"""Response schema."""
from dataclasses import dataclass, field
from typing import Any, Dict, Generator, List, Optional, Union

from dataclasses_json import DataClassJsonMixin

from gpt_index.data_structs.data_structs import Node
from gpt_index.utils import truncate_text


@dataclass
class SourceNode(DataClassJsonMixin):
    """Source node.

    User-facing class containing the source text and the corresponding document id.

    Attributes:
        source_text: Text of the source node.
        doc_id: Id of the document the node came from; may be None.
        extra_info: Optional extra metadata copied from the node.
        node_info: Optional node-level metadata copied from the node.
        similarity: Score between node and query, if applicable.
    """

    source_text: str
    doc_id: Optional[str]
    extra_info: Optional[Dict[str, Any]] = None
    node_info: Optional[Dict[str, Any]] = None
    # distance score between node and query, if applicable
    similarity: Optional[float] = None

    @classmethod
    def from_node(cls, node: Node, similarity: Optional[float] = None) -> "SourceNode":
        """Create a SourceNode from a Node.

        Args:
            node: Node whose text (``get_text()``), ``ref_doc_id``,
                ``extra_info`` and ``node_info`` are copied.
            similarity: Optional score to attach to the new SourceNode.

        Returns:
            A new SourceNode built from ``node``.
        """
        return cls(
            source_text=node.get_text(),
            doc_id=node.ref_doc_id,
            extra_info=node.extra_info,
            node_info=node.node_info,
            similarity=similarity,
        )

    @classmethod
    def from_nodes(cls, nodes: List[Node]) -> List["SourceNode"]:
        """Create a list of SourceNodes from a list of Nodes.

        No similarity scores are attached (``similarity`` stays None).

        Args:
            nodes: Nodes to convert, in order.

        Returns:
            One SourceNode per input node, in the same order.
        """
        return [cls.from_node(node) for node in nodes]


@dataclass
class Response:
    """Response object.

    Returned if streaming=False during the `index.query()` call.

    Attributes:
        response: The response text.
        source_nodes: Source nodes associated with the response.
        extra_info: Optional extra metadata.
    """

    response: Optional[str]
    # Forward reference keeps the annotation independent of definition order.
    source_nodes: List["SourceNode"] = field(default_factory=list)
    extra_info: Optional[Dict[str, Any]] = None

    def __str__(self) -> str:
        """Convert to string representation.

        Returns "None" when the response is None or empty.
        """
        return self.response or "None"

    def get_formatted_sources(self, length: int = 100) -> str:
        """Get formatted sources text.

        Args:
            length: Length limit passed to ``truncate_text`` for each source text.

        Returns:
            One ``> Source (Doc id: ...): ...`` entry per source node, joined
            by blank lines; an empty string when there are no source nodes.
        """
        texts = []
        for source_node in self.source_nodes:
            fmt_text_chunk = truncate_text(source_node.source_text, length)
            doc_id = source_node.doc_id or "None"
            texts.append(f"> Source (Doc id: {doc_id}): {fmt_text_chunk}")
        return "\n\n".join(texts)


@dataclass
class StreamingResponse:
    """StreamingResponse object.

    Returned if streaming=True during the `index.query()` call.

    Attributes:
        response_gen: The response generator.
        source_nodes: Source nodes associated with the response.
        extra_info: Optional extra metadata.
        response_txt: Full response text, cached once ``response_gen`` has
            been consumed (a generator can only be iterated once).
    """

    response_gen: Optional[Generator]
    # Forward reference keeps the annotation independent of definition order.
    source_nodes: List["SourceNode"] = field(default_factory=list)
    extra_info: Optional[Dict[str, Any]] = None
    response_txt: Optional[str] = None

    def _consume_gen(self, echo: bool = False) -> None:
        """Drain ``response_gen`` into ``response_txt`` if not done already.

        Args:
            echo: If True, print each chunk as it arrives (no newline added).
        """
        if self.response_txt is not None or self.response_gen is None:
            return
        chunks = []
        for text in self.response_gen:
            if echo:
                # flush so streamed chunks show up immediately, not at the next newline
                print(text, end="", flush=True)
            chunks.append(text)
        self.response_txt = "".join(chunks)

    def __str__(self) -> str:
        """Convert to string representation.

        Consumes the response generator on first use and caches the text.
        Returns "None" when no text is available.
        """
        self._consume_gen()
        return self.response_txt or "None"

    def get_response(self) -> "Response":
        """Get a standard response object.

        Consumes the response generator on first use and caches the text.
        """
        self._consume_gen()
        return Response(self.response_txt, self.source_nodes, self.extra_info)

    def print_response_stream(self) -> None:
        """Print the response stream.

        On first use the generator is streamed to stdout chunk by chunk, and
        the accumulated text is cached in ``response_txt``. This is why later
        ``str()`` / ``get_response()`` calls still return it.
        Otherwise the cached text is printed.
        """
        if self.response_txt is None and self.response_gen is not None:
            self._consume_gen(echo=True)
        else:
            print(self.response_txt)

    def get_formatted_sources(self, length: int = 100) -> str:
        """Get formatted sources text.

        Args:
            length: Length limit passed to ``truncate_text`` for each source text.

        Returns:
            One ``> Source (Doc id: ...): ...`` entry per source node, joined
            by blank lines; an empty string when there are no source nodes.
        """
        texts = []
        for source_node in self.source_nodes:
            fmt_text_chunk = truncate_text(source_node.source_text, length)
            doc_id = source_node.doc_id or "None"
            texts.append(f"> Source (Doc id: {doc_id}): {fmt_text_chunk}")
        return "\n\n".join(texts)


# Either response object returned by `index.query()`, depending on `streaming`.
RESPONSE_TYPE = Union[Response, StreamingResponse]