File size: 6,849 Bytes
30fabb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2a2581
30fabb4
 
c2a2581
 
 
 
 
 
30fabb4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from dataclasses import dataclass
from typing import List, Optional
import re
import textwrap
from cot_reasoning import VisualizationConfig

@dataclass
class BSNode:
    """One node of the Beam Search tree.

    Attributes:
        id: Identifier of the node.
        content: Text carried by the node.
        score: Score attached to the node itself.
        parent_id: Identifier of the parent node, or None when there is none.
        children: Child nodes; a fresh empty list is created when omitted.
        is_best_path: Flag telling whether the node is marked as part of the
            best path (False until something sets it).
        path_score: Optional score attached to the path, when one is known.
    """
    id: str
    content: str
    score: float
    parent_id: Optional[str] = None
    children: List['BSNode'] = None
    is_best_path: bool = False
    path_score: Optional[float] = None

    def __post_init__(self):
        # None is only a placeholder default: give each instance its own list
        # so that nodes never share one mutable container.
        self.children = [] if self.children is None else self.children

@dataclass
class BSResponse:
    """A fully parsed Beam Search response.

    Attributes:
        question: The question the response belongs to.
        root: Root node of the tree.
        answer: Final answer text, or None when there is none.
        best_score: Score reported together with the answer, or None.
        result_nodes: Nodes collected as results; a fresh empty list is
            created when omitted.
    """
    question: str
    root: BSNode
    answer: Optional[str] = None
    best_score: Optional[float] = None
    result_nodes: List[BSNode] = None

    def __post_init__(self):
        # None is only a placeholder default: every response gets its own list.
        self.result_nodes = [] if self.result_nodes is None else self.result_nodes

def parse_bs_response(response_text: str, question: str) -> BSResponse:
    """Parse Beam Search response text to extract nodes and build the tree.

    Nodes are read from tags of the form
    ``<node id="..." parent="..." score="..." path_score="...">text</node>``
    where ``parent`` and ``path_score`` are optional. The final answer is read
    from ``<answer>Best path (path_score: X): text</answer>`` when present.

    Args:
        response_text: Raw text containing the ``<node>`` / ``<answer>`` tags.
        question: The question, stored unchanged on the returned response.

    Returns:
        A BSResponse whose ``root`` is the node without a parent (the last one
        parsed if there are several, None if there is none), whose
        ``result_nodes`` are the nodes with an id starting with ``result``, and
        whose ``answer`` / ``best_score`` are None when no answer tag matched.

    Raises:
        ValueError: If a ``score`` / ``path_score`` attribute or the answer's
            path score is not a valid float literal.
    """
    node_pattern = r'<node id="([^"]+)"(?:\s+parent="([^"]+)")?\s*score="([^"]+)"(?:\s+path_score="([^"]+)")?\s*>\s*(.*?)\s*</node>'
    nodes_dict = {}

    # First pass: create all nodes (a repeated id replaces the earlier node).
    for match in re.finditer(node_pattern, response_text, re.DOTALL):
        node_id, parent_id, raw_score, raw_path_score, content = match.groups()
        nodes_dict[node_id] = BSNode(
            id=node_id,
            content=content.strip(),
            score=float(raw_score),
            parent_id=parent_id,
            path_score=float(raw_path_score) if raw_path_score else None,
        )

    # Collect result nodes from the final mapping so that a repeated id cannot
    # leave behind a stale object that is not part of the tree.
    result_nodes = [
        node for node in nodes_dict.values() if node.id.startswith('result')
    ]

    # Second pass: build tree relationships. Nodes whose parent id is unknown
    # are kept in the mapping but are not attached anywhere.
    root = None
    for node in nodes_dict.values():
        if node.parent_id is None:
            root = node
        else:
            parent = nodes_dict.get(node.parent_id)
            if parent:
                parent.children.append(node)

    # Parse answer if present
    answer_pattern = r'<answer>\s*Best path \(path_score: ([^\)]+)\):\s*(.*?)\s*</answer>'
    answer_match = re.search(answer_pattern, response_text, re.DOTALL)
    answer = None
    best_score = None

    if answer_match:
        best_score = float(answer_match.group(1))
        answer = answer_match.group(2).strip()

        # Mark every node whose path score equals the best score, together
        # with all of its ancestors.
        for node in nodes_dict.values():
            # Compare against None explicitly: 0.0 is a valid path score.
            if node.path_score is not None and abs(node.path_score - best_score) < 1e-6:
                current = node
                # Stopping at an already-marked node is safe (its ancestors
                # were marked by the same walk) and guarantees termination
                # even if malformed input contains a parent cycle.
                while current is not None and not current.is_best_path:
                    current.is_best_path = True
                    current = nodes_dict.get(current.parent_id)

    return BSResponse(
        question=question,
        root=root,
        answer=answer,
        best_score=best_score,
        result_nodes=result_nodes,
    )

def create_mermaid_diagram(bs_response: BSResponse, config: VisualizationConfig) -> str:
    """Convert Beam Search response to Mermaid diagram.

    Args:
        bs_response: Parsed response providing the question, the node tree,
            the result nodes and the optional final answer.
        config: Visualization settings forwarded to ``wrap_text`` for every
            label.

    Returns:
        A ``<div class="mermaid">`` block containing a top-down (``graph TD``)
        flowchart: the question node ``Q``, one node per tree node labelled
        with its content and scores, an ``Answer`` node linked from every
        result node when an answer exists, and the class/style definitions.
    """
    diagram = ['<div class="mermaid">', 'graph TD']

    # Add question node
    question_content = wrap_text(bs_response.question, config)
    diagram.append(f'    Q["{question_content}"]')

    def add_node_and_children(node: BSNode, parent_id: Optional[str] = None):
        """Emit ``node``, its edge from ``parent_id`` and, recursively, its subtree."""
        # Format content to include scores
        score_info = f"Score: {node.score:.2f}"
        # Compare against None explicitly so a path score of 0.0 is still shown.
        if node.path_score is not None:
            score_info += f"<br>Path Score: {node.path_score:.2f}"
        node_content = f"{wrap_text(node.content, config)}<br>{score_info}"

        # Style depends on the node type (result vs. intermediate) and on
        # whether the node is marked as part of the best path.
        kind = 'result' if node.id.startswith('result') else 'intermediate'
        node_style = f'best_{kind}' if node.is_best_path else kind

        # Add node
        diagram.append(f'    {node.id}["{node_content}"]')
        diagram.append(f'    class {node.id} {node_style};')

        # Add connection from parent
        if parent_id:
            diagram.append(f'    {parent_id} --> {node.id}')

        # Process children
        for child in node.children:
            add_node_and_children(child, node.id)

    # Build tree structure
    if bs_response.root:
        diagram.append(f'    Q --> {bs_response.root.id}')
        add_node_and_children(bs_response.root)

    # Add final answer
    if bs_response.answer:
        # An answer may exist without a score; formatting None with ".2f"
        # would raise, so the score is only shown when it is available.
        if bs_response.best_score is not None:
            answer_header = f"Final Answer (Path Score: {bs_response.best_score:.2f}):"
        else:
            answer_header = "Final Answer:"
        answer_content = wrap_text(
            f"{answer_header}<br>{bs_response.answer}",
            config
        )
        diagram.append(f'    Answer["{answer_content}"]')

        # Connect all result nodes to the answer
        for result_node in bs_response.result_nodes:
            diagram.append(f'    {result_node.id} --> Answer')

        diagram.append('    class Answer final_answer;')

    # Add styles. best_intermediate uses a blue tint so that best-path nodes
    # actually stand out from ordinary intermediate nodes.
    diagram.extend([
        '    classDef intermediate fill:#f9f9f9,stroke:#333,stroke-width:2px;',
        '    classDef best_intermediate fill:#dbeafe,stroke:#3b82f6,stroke-width:2px;',
        '    classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
        '    classDef result fill:#f3f4f6,stroke:#4b5563,stroke-width:2px;',
        '    classDef best_result fill:#bfdbfe,stroke:#3b82f6,stroke-width:2px;',
        '    classDef final_answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
        '    class Q question;',
        '    linkStyle default stroke:#666,stroke-width:2px;'
    ])

    diagram.append('</div>')
    return '\n'.join(diagram)

def wrap_text(text: str, config: VisualizationConfig) -> str:
    """Wrap text to fit within box constraints.

    Newlines are flattened to spaces and double quotes are replaced with
    single quotes before wrapping. If the wrapped text needs more lines than
    allowed, the surplus lines are dropped and the last kept line ends in
    ``...``.

    Args:
        text: The text to wrap.
        config: Object providing ``max_chars_per_line`` (wrap width) and
            ``max_lines`` (line limit; values below 1 are treated as 1).

    Returns:
        The wrapped lines joined with ``<br>``; an empty string for empty or
        whitespace-only text.
    """
    text = text.replace('\n', ' ').replace('"', "'")
    wrapped_lines = textwrap.wrap(text, width=config.max_chars_per_line)

    # Always keep at least one line: with a limit of 0 (or less) the slice
    # below would be empty and indexing its last element would fail.
    max_lines = max(config.max_lines, 1)

    if len(wrapped_lines) > max_lines:
        wrapped_lines = wrapped_lines[:max_lines]
        # Leave room for the ellipsis; clamp at 0 so widths below 3 do not
        # become a negative slice bound that cuts from the wrong end.
        keep = max(config.max_chars_per_line - 3, 0)
        wrapped_lines[-1] = wrapped_lines[-1][:keep] + "..."

    return "<br>".join(wrapped_lines)