AbeerTrial's picture
Duplicate from AbeerTrial/SOAPAssist
35b22df
raw
history blame
4.15 kB
"""Response schema."""
from dataclasses import dataclass, field
from typing import Any, Dict, Generator, List, Optional, Union
from dataclasses_json import DataClassJsonMixin
from gpt_index.data_structs.data_structs import Node
from gpt_index.utils import truncate_text
@dataclass
class SourceNode(DataClassJsonMixin):
    """User-facing view of a retrieved source.

    Holds the text of a source together with the id of the document it
    came from, plus optional metadata copied from the underlying node.
    """

    source_text: str
    doc_id: Optional[str]
    extra_info: Optional[Dict[str, Any]] = None
    node_info: Optional[Dict[str, Any]] = None

    # Distance score between the node and the query, when one is available.
    similarity: Optional[float] = None

    @classmethod
    def from_node(cls, node: Node, similarity: Optional[float] = None) -> "SourceNode":
        """Build a SourceNode from a single Node.

        Args:
            node: Node whose text, reference doc id and info dicts are copied.
            similarity: Optional score to attach to the resulting SourceNode.

        Returns:
            A new instance of ``cls`` populated from ``node``.
        """
        # Read the node attributes in a fixed order, then construct in one go.
        text = node.get_text()
        ref_id = node.ref_doc_id
        extra = node.extra_info
        info = node.node_info
        return cls(
            source_text=text,
            doc_id=ref_id,
            extra_info=extra,
            node_info=info,
            similarity=similarity,
        )

    @classmethod
    def from_nodes(cls, nodes: List[Node]) -> List["SourceNode"]:
        """Build one SourceNode per Node, preserving order.

        No similarity is attached to the resulting SourceNodes.
        """
        converted: List["SourceNode"] = []
        for item in nodes:
            converted.append(cls.from_node(item))
        return converted
@dataclass
class Response:
    """Non-streaming query result.

    Returned if streaming=False during the `index.query()` call.

    Attributes:
        response: The response text (may be None).
        source_nodes: SourceNodes associated with the response.
        extra_info: Optional extra metadata attached to the response.
    """

    response: Optional[str]
    source_nodes: List[SourceNode] = field(default_factory=list)
    extra_info: Optional[Dict[str, Any]] = None

    def __str__(self) -> str:
        """Return the response text, or the literal "None" when it is empty."""
        if self.response:
            return self.response
        return "None"

    def get_formatted_sources(self, length: int = 100) -> str:
        """Render every source node as a one-line summary.

        Args:
            length: Value passed to ``truncate_text`` for each source text.

        Returns:
            The per-source summaries joined by blank lines.
        """

        def _describe(src: "SourceNode") -> str:
            # Truncate first, then resolve the doc id label.
            snippet = truncate_text(src.source_text, length)
            origin = src.doc_id if src.doc_id else "None"
            return f"> Source (Doc id: {origin}): {snippet}"

        return "\n\n".join(map(_describe, self.source_nodes))
@dataclass
class StreamingResponse:
    """Streaming query result.

    Returned if streaming=True during the `index.query()` call.

    The generator can only be consumed once; whichever method consumes it
    first caches the full text in ``response_txt`` so later calls reuse it.

    Attributes:
        response_gen: The response generator yielding text chunks.
        source_nodes: SourceNodes associated with the response.
        extra_info: Optional extra metadata attached to the response.
        response_txt: Cached full response text, set once the generator
            has been consumed.
    """

    response_gen: Optional[Generator]
    source_nodes: List[SourceNode] = field(default_factory=list)
    extra_info: Optional[Dict[str, Any]] = None
    response_txt: Optional[str] = None

    def _consume_stream(self, echo: bool = False) -> bool:
        """Drain the generator into ``response_txt`` if not done already.

        Args:
            echo: When True, print each chunk as it is received.

        Returns:
            True if the generator was consumed by this call, False if there
            was nothing to do (text already cached, or no generator).
        """
        if self.response_txt is not None or self.response_gen is None:
            return False
        chunks = []
        for text in self.response_gen:
            if echo:
                # Flush so each chunk shows up immediately rather than
                # waiting for the output buffer to fill.
                print(text, end="", flush=True)
            chunks.append(text)
        # Assign only after the generator finished, so a failure mid-stream
        # does not leave a partial text cached.
        self.response_txt = "".join(chunks)
        return True

    def __str__(self) -> str:
        """Return the full response text, or the literal "None" when empty.

        Consumes the generator on first use.
        """
        self._consume_stream()
        return self.response_txt or "None"

    def get_response(self) -> Response:
        """Get a standard response object.

        Consumes the generator on first use and wraps the cached text,
        source nodes and extra info in a ``Response``.
        """
        self._consume_stream()
        return Response(self.response_txt, self.source_nodes, self.extra_info)

    def print_response_stream(self) -> None:
        """Print the response, streaming chunk by chunk when possible.

        If the generator has not been consumed yet, each chunk is printed as
        it arrives AND accumulated into ``response_txt``, so that the text is
        still available to ``__str__`` / ``get_response`` afterwards.
        Otherwise the cached ``response_txt`` is printed as a whole.
        """
        if not self._consume_stream(echo=True):
            print(self.response_txt)

    def get_formatted_sources(self, length: int = 100) -> str:
        """Get formatted sources text.

        Args:
            length: Value passed to ``truncate_text`` for each source text.

        Returns:
            One "> Source (Doc id: ...)" entry per source node, joined by
            blank lines.
        """
        texts = []
        for source_node in self.source_nodes:
            fmt_text_chunk = truncate_text(source_node.source_text, length)
            doc_id = source_node.doc_id or "None"
            texts.append(f"> Source (Doc id: {doc_id}): {fmt_text_chunk}")
        return "\n\n".join(texts)
# Type alias covering both response classes: Response (non-streaming) and
# StreamingResponse (streaming).
RESPONSE_TYPE = Union[Response, StreamingResponse]