Spaces:
Running
Running
File size: 6,849 Bytes
30fabb4 c2a2581 30fabb4 c2a2581 30fabb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import re
import textwrap
from dataclasses import dataclass, field
from typing import List, Optional

from cot_reasoning import VisualizationConfig
@dataclass
class BSNode:
    """A single node in a Beam Search tree.

    Instances are built by ``parse_bs_response`` from ``<node>`` tags.
    ``children`` and ``is_best_path`` are filled in after all nodes exist.
    """

    # Value of the tag's ``id`` attribute; ids starting with "result" are
    # treated as result nodes by the parser and the diagram renderer.
    id: str
    # Text between the opening and closing ``<node>`` tags (stripped).
    content: str
    # Value of the ``score`` attribute.
    score: float
    # Value of the optional ``parent`` attribute; None for the root.
    parent_id: Optional[str] = None
    # Direct children; every instance gets its own fresh list.
    children: List['BSNode'] = field(default_factory=list)
    # True when this node lies on the best-scoring path.
    is_best_path: bool = False
    # Value of the optional ``path_score`` attribute.
    path_score: Optional[float] = None

    def __post_init__(self):
        # Keep accepting an explicit ``children=None`` (the previous default)
        # and normalize it to an empty list.
        if self.children is None:
            self.children = []
@dataclass
class BSResponse:
    """A parsed Beam Search response: the question, the tree, and the answer."""

    # The question text the response belongs to.
    question: str
    # Parentless node of the tree; ``parse_bs_response`` passes None when the
    # text contains no parentless node.
    root: Optional['BSNode']
    # Text of the ``<answer>`` element, if one was found.
    answer: Optional[str] = None
    # Path score reported inside the ``<answer>`` element, if one was found.
    best_score: Optional[float] = None
    # Nodes whose id starts with "result", in document order; each instance
    # gets its own fresh list.
    result_nodes: List['BSNode'] = field(default_factory=list)

    def __post_init__(self):
        # Keep accepting an explicit ``result_nodes=None`` (the previous
        # default) and normalize it to an empty list.
        if self.result_nodes is None:
            self.result_nodes = []
def parse_bs_response(response_text: str, question: str) -> BSResponse:
    """Parse Beam Search response text into a node tree.

    The text is scanned for elements of the form
    ``<node id="..." [parent="..."] score="..." [path_score="..."]>content</node>``
    and for an optional
    ``<answer>Best path (path_score: X): ...</answer>`` element.

    Args:
        response_text: Raw text containing the tagged nodes and answer.
        question: Question text, stored verbatim on the result.

    Returns:
        A BSResponse in which:
        - ``root`` is the node without a ``parent`` attribute (the last such
          node in document order if there are several, None if there is none).
        - ``result_nodes`` lists nodes whose id starts with ``"result"``, in
          document order.
        - ``answer`` / ``best_score`` come from the ``<answer>`` element when
          present. In that case every node whose ``path_score`` is within
          1e-6 of ``best_score``, plus all of its ancestors, gets
          ``is_best_path = True``.

    Raises:
        ValueError: If a score, path score, or the answer's path score is not
            a valid float.
    """
    node_pattern = r'<node id="([^"]+)"(?:\s+parent="([^"]+)")?\s*score="([^"]+)"(?:\s+path_score="([^"]+)")?\s*>\s*(.*?)\s*</node>'
    nodes_dict = {}
    result_nodes = []

    # First pass: create every node, keyed by id. A repeated id replaces the
    # earlier entry in nodes_dict (result_nodes keeps both objects).
    for match in re.finditer(node_pattern, response_text, re.DOTALL):
        node_id, parent_id, score_text, path_score_text, content = match.groups()
        node = BSNode(
            id=node_id,
            content=content.strip(),
            score=float(score_text),
            parent_id=parent_id,
            path_score=float(path_score_text) if path_score_text else None,
        )
        nodes_dict[node_id] = node
        if node_id.startswith('result'):
            result_nodes.append(node)

    # Second pass: attach each node to its parent's children. Nodes whose
    # parent id is unknown stay detached from the tree.
    root = None
    for node in nodes_dict.values():
        if node.parent_id is None:
            root = node
        else:
            parent = nodes_dict.get(node.parent_id)
            if parent is not None:
                parent.children.append(node)

    answer_pattern = r'<answer>\s*Best path \(path_score: ([^\)]+)\):\s*(.*?)\s*</answer>'
    answer_match = re.search(answer_pattern, response_text, re.DOTALL)
    answer = None
    best_score = None
    if answer_match:
        best_score = float(answer_match.group(1))
        answer = answer_match.group(2).strip()

        # Mark each node whose path score matches the reported best score,
        # and all of its ancestors, as lying on the best path.
        for node in nodes_dict.values():
            # Compare against None explicitly: a path score of 0.0 is valid.
            if node.path_score is not None and abs(node.path_score - best_score) < 1e-6:
                # `visited` stops malformed input with a parent cycle (e.g. a
                # node naming itself as parent) from looping forever.
                visited = set()
                current = node
                while current is not None and current.id not in visited:
                    visited.add(current.id)
                    current.is_best_path = True
                    current = nodes_dict.get(current.parent_id)

    return BSResponse(
        question=question,
        root=root,
        answer=answer,
        best_score=best_score,
        result_nodes=result_nodes,
    )
def create_mermaid_diagram(bs_response: BSResponse, config: VisualizationConfig) -> str:
    """Render a parsed Beam Search response as a Mermaid ``graph TD`` diagram.

    The output is wrapped in ``<div class="mermaid">`` and contains:
    - a question node ``Q``;
    - one node per tree node reachable from ``bs_response.root``, labelled
      with its wrapped content, score and (if set) path score, and styled
      by whether it is a result node and whether it is on the best path;
    - an ``Answer`` node, when an answer is present, linked from every
      rendered result node;
    - the class definitions used for styling.

    Args:
        bs_response: Parsed response (e.g. from ``parse_bs_response``).
        config: Passed to ``wrap_text`` to wrap label text.

    Returns:
        The diagram as a single newline-joined string.
    """
    diagram = ['<div class="mermaid">', 'graph TD']
    # Ids of the nodes actually emitted, so edges only point at defined nodes.
    rendered_ids = set()

    question_content = wrap_text(bs_response.question, config)
    diagram.append(f' Q["{question_content}"]')

    def add_node_and_children(node: BSNode, parent_id: Optional[str] = None):
        """Emit `node`, its edge from `parent_id`, then its subtree (pre-order)."""
        score_info = f"Score: {node.score:.2f}"
        # `is not None` so that a legitimate path score of 0.0 is still shown.
        if node.path_score is not None:
            score_info += f"<br>Path Score: {node.path_score:.2f}"
        node_content = f"{wrap_text(node.content, config)}<br>{score_info}"

        if node.id.startswith('result'):
            node_style = 'best_result' if node.is_best_path else 'result'
        else:
            node_style = 'best_intermediate' if node.is_best_path else 'intermediate'

        diagram.append(f' {node.id}["{node_content}"]')
        diagram.append(f' class {node.id} {node_style};')
        rendered_ids.add(node.id)

        if parent_id is not None:
            diagram.append(f' {parent_id} --> {node.id}')

        for child in node.children:
            add_node_and_children(child, node.id)

    if bs_response.root is not None:
        diagram.append(f' Q --> {bs_response.root.id}')
        add_node_and_children(bs_response.root)

    if bs_response.answer:
        if bs_response.best_score is not None:
            header = f"Final Answer (Path Score: {bs_response.best_score:.2f}):"
        else:
            header = "Final Answer:"
        # Wrap only the answer text. Wrapping header + "<br>" + answer as one
        # string counted the tag toward line width, glued it to the first
        # answer word, and could split it when textwrap breaks long words.
        answer_content = f"{header}<br>{wrap_text(bs_response.answer, config)}"
        diagram.append(f' Answer["{answer_content}"]')
        # Only link result nodes that were drawn; linking an undrawn id would
        # make Mermaid create a bare, content-less node for it.
        for result_node in bs_response.result_nodes:
            if result_node.id in rendered_ids:
                diagram.append(f' {result_node.id} --> Answer')
        diagram.append(' class Answer final_answer;')

    # NOTE(review): best_intermediate currently has the same style as
    # intermediate, so best-path intermediate nodes are not highlighted —
    # confirm whether that is intended.
    diagram.extend([
        ' classDef intermediate fill:#f9f9f9,stroke:#333,stroke-width:2px;',
        ' classDef best_intermediate fill:#f9f9f9,stroke:#333,stroke-width:2px;',
        ' classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
        ' classDef result fill:#f3f4f6,stroke:#4b5563,stroke-width:2px;',
        ' classDef best_result fill:#bfdbfe,stroke:#3b82f6,stroke-width:2px;',
        ' classDef final_answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
        ' class Q question;',
        ' linkStyle default stroke:#666,stroke-width:2px;'
    ])
    diagram.append('</div>')
    return '\n'.join(diagram)
def wrap_text(text: str, config: VisualizationConfig) -> str:
    """Flatten, sanitize and wrap `text` into ``<br>``-separated lines.

    Newlines become spaces and double quotes become single quotes; callers
    embed the result inside double-quoted Mermaid labels. The text is then
    wrapped at ``config.max_chars_per_line`` columns. If that yields more
    than ``config.max_lines`` lines, only the first ``max_lines`` are kept,
    and the last kept line is cut to ``max_chars_per_line - 3`` characters
    and suffixed with ``"..."``.
    """
    limit = config.max_chars_per_line
    sanitized = text.replace('\n', ' ').replace('"', "'")
    lines = textwrap.wrap(sanitized, width=limit)

    if len(lines) <= config.max_lines:
        return "<br>".join(lines)

    kept = lines[:config.max_lines]
    kept[-1] = kept[-1][:limit - 3] + "..."
    return "<br>".join(kept)