File size: 5,440 Bytes
8f51ef2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python3
"""Minimal Streamlit interface for the EchoPilot ReAct agent."""

from __future__ import annotations

import shutil
import tempfile
import uuid
from pathlib import Path
from typing import Dict, List, Tuple

import streamlit as st

from agents import get_intelligent_agent
from config import Config
from utils.video_utils import convert_video_to_h264

# Directory that contains this script; main() places browser-playable video
# copies under PROJECT_ROOT / "temp".
PROJECT_ROOT = Path(__file__).resolve().parents[0]


@st.cache_resource(show_spinner=False)
def load_agent():
    """Construct the EchoPilot agent once and reuse it across Streamlit reruns.

    ``st.cache_resource`` memoizes the returned object, so the (potentially
    expensive) agent construction happens only on the first call.

    Returns:
        An agent instance built on ``Config.DEVICE``.
    """
    agent_factory, _unused = get_intelligent_agent()
    return agent_factory(device=Config.DEVICE)


def _persist_upload(upload) -> Tuple[Path, Path]:
    """Write uploaded file to a temporary directory and return its path."""
    suffix = Path(upload.name or "input.mp4").suffix or ".mp4"
    temp_dir = Path(tempfile.mkdtemp(prefix="echopilot_"))
    video_path = temp_dir / f"input{suffix}"
    with open(video_path, "wb") as handle:
        handle.write(upload.getbuffer())
    return video_path, temp_dir


def _extract_key_metrics(response) -> List[Tuple[str, str]]:
    metrics: List[Tuple[str, str]] = []
    results = response.execution_result.results or {}
    tool_results: Dict[str, Dict] = results.get("tool_results") or {}
    measurement = tool_results.get("echo_measurement_prediction")
    if isinstance(measurement, dict) and measurement.get("status") == "success":
        entries = measurement.get("measurements") or []
        if entries:
            data = entries[0].get("measurements", {})

            def _format_metric(key: str, label: str, precision: int = 1):
                info = data.get(key)
                if not isinstance(info, dict):
                    return
                value = info.get("value")
                unit = info.get("unit", "")
                if value is None:
                    return
                try:
                    value_str = f"{float(value):.{precision}f}"
                except (TypeError, ValueError):
                    value_str = str(value)
                unit_str = f" {unit}".strip()
                metrics.append((label, f"{value_str}{unit_str}"))

            for key, label in [
                ("ejection_fraction", "Ejection Fraction"),
                ("EF", "Ejection Fraction"),
            ]:
                if key in data:
                    _format_metric(key, label, precision=1)
                    break

            if "pulmonary_artery_pressure_continuous" in data:
                _format_metric("pulmonary_artery_pressure_continuous", "Pulmonary Artery Pressure", precision=1)
            if "dilated_ivc" in data:
                _format_metric("dilated_ivc", "IVC Diameter", precision=2)
    return metrics


def _prepare_display_video(video_path: Path) -> Path | None:
    """Transcode the upload to an H.264 copy under PROJECT_ROOT/temp for playback."""
    temp_display_dir = PROJECT_ROOT / "temp"
    temp_display_dir.mkdir(parents=True, exist_ok=True)
    display_target = temp_display_dir / f"display_{uuid.uuid4().hex}.mp4"
    converted = convert_video_to_h264(str(video_path), str(display_target))
    # NOTE(review): a falsy return is treated as "no playable copy" rather than
    # crashing on Path(None); confirm convert_video_to_h264's failure contract.
    # TODO(review): files in PROJECT_ROOT/temp are never deleted by this script.
    return Path(converted) if converted else None


def _run_analysis(uploaded_video, query: str) -> Tuple[object, Path | None]:
    """Persist the upload, build a display copy, and run the agent on it.

    The raw upload's temporary directory is always removed, even when
    conversion or analysis raises.

    Returns:
        ``(response, display_video)`` where ``display_video`` may be ``None``.
    """
    agent = load_agent()
    video_path, temp_dir = _persist_upload(uploaded_video)
    try:
        with st.spinner("EchoPilot is analyzing the study..."):
            display_video = _prepare_display_video(video_path)
            response = agent.process_query(query, str(video_path))
    finally:
        # Clean up the original upload to save disk space. rmtree also copes
        # with nested entries, which per-file unlink + rmdir would not.
        shutil.rmtree(temp_dir, ignore_errors=True)
    return response, display_video


def _render_results(query: str, response, display_video: Path | None) -> None:
    """Show the playable video, the key metrics, and the agent's answer."""
    st.success("Analysis complete")
    metrics = _extract_key_metrics(response)

    video_col, metrics_col = st.container().columns([2, 1])
    if display_video is not None and display_video.exists():
        with video_col:
            st.video(str(display_video))
    if metrics:
        with metrics_col:
            st.markdown("#### Key Measurements")
            for label, value in metrics:
                st.metric(label, value)

    st.divider()
    st.markdown("#### EchoPilot Response")
    st.chat_message("user").write(query)
    st.chat_message("assistant").write(response.response_text)


def main() -> None:
    """Render the EchoPilot page and run the agent when the user clicks Run."""
    st.set_page_config(page_title="EchoPilot Agent", page_icon="🫀", layout="wide")
    st.title("EchoPilot · Echocardiography Co-Pilot")
    st.caption("Upload a study, ask a focused question, and EchoPilot will run the appropriate tools to answer.")

    upload_col, info_col = st.columns([2, 1])
    with upload_col:
        uploaded_video = st.file_uploader(
            "Echo video file",
            type=["mp4", "mov", "m4v", "avi", "wmv"],
            help="Standard ultrasound formats are supported.",
        )
        default_question = "Estimate the ejection fraction and note any major abnormalities."
        query = st.text_area("Clinical question", value=default_question, height=120)
    with info_col:
        st.markdown("### How it works")
        st.write(
            "- EchoPilot uses a ReAct loop to decide which tools to call.\n"
            "- It may segment chambers, compute EchoPrime measurements, or run disease classifiers.\n"
            "- Results are summarized below; raw tool logs are hidden for clarity."
        )

    question = query.strip()
    run_clicked = st.button(
        "Run Analysis",
        type="primary",
        use_container_width=True,
        disabled=not uploaded_video or not question,
    )
    # The button is disabled without an upload; the None check is a guard.
    if not run_clicked or uploaded_video is None:
        return

    response, display_video = _run_analysis(uploaded_video, question)
    if response:
        _render_results(question, response, display_video)


if __name__ == "__main__":
    main()