#!/usr/bin/env python3
"""Minimal Streamlit interface for the EchoPilot ReAct agent."""

from __future__ import annotations

import shutil
import tempfile
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import streamlit as st

from agents import get_intelligent_agent
from config import Config
from utils.video_utils import convert_video_to_h264

PROJECT_ROOT = Path(__file__).resolve().parent

# Keys under which the ejection fraction may be reported, in preference order.
_EF_KEYS: Tuple[str, ...] = ("ejection_fraction", "EF")

# (measurement key, display label, decimal places) for the remaining metrics.
# NOTE(review): "dilated_ivc" is displayed as a diameter; confirm the tool
# really reports a length here rather than a classification score.
_EXTRA_METRICS: Tuple[Tuple[str, str, int], ...] = (
    ("pulmonary_artery_pressure_continuous", "Pulmonary Artery Pressure", 1),
    ("dilated_ivc", "IVC Diameter", 2),
)


@st.cache_resource(show_spinner=False)
def load_agent():
    """Instantiate the agent on ``Config.DEVICE``.

    Wrapped in ``st.cache_resource`` so the instance is reused instead of
    being rebuilt on every call.
    """
    agent_cls, _ = get_intelligent_agent()
    return agent_cls(device=Config.DEVICE)


def _persist_upload(upload) -> Tuple[Path, Path]:
    """Write an uploaded file into a fresh temporary directory.

    Args:
        upload: Object exposing ``name`` and ``getbuffer()`` (the value
            returned by ``st.file_uploader`` in ``main``).

    Returns:
        ``(video_path, temp_dir)``: the written file and the directory that
        contains it. The caller is responsible for removing ``temp_dir``.

    Raises:
        Whatever the write raises; the temporary directory is removed first
        so a failed upload does not leak it.
    """
    suffix = Path(upload.name or "input.mp4").suffix or ".mp4"
    temp_dir = Path(tempfile.mkdtemp(prefix="echopilot_"))
    video_path = temp_dir / f"input{suffix}"
    try:
        with open(video_path, "wb") as handle:
            handle.write(upload.getbuffer())
    except Exception:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return video_path, temp_dir


def _format_measurement(info, precision: int = 1) -> Optional[str]:
    """Render one measurement entry as ``"<value> <unit>"``.

    Args:
        info: Expected to be a dict with a ``"value"`` and optional ``"unit"``.
        precision: Decimal places used when the value is numeric.

    Returns:
        The formatted text, or ``None`` when ``info`` is not a dict or has no
        value. Non-numeric values are shown via ``str()``.
    """
    if not isinstance(info, dict):
        return None
    value = info.get("value")
    if value is None:
        return None
    try:
        value_text = f"{float(value):.{precision}f}"
    except (TypeError, ValueError):
        value_text = str(value)
    # ``or ""`` also covers an explicit ``"unit": None`` entry.
    unit = str(info.get("unit") or "").strip()
    return f"{value_text} {unit}" if unit else value_text


def _extract_key_metrics(response) -> List[Tuple[str, str]]:
    """Collect the headline measurements from an agent response.

    Reads ``response.execution_result.results["tool_results"]
    ["echo_measurement_prediction"]`` and, when its status is ``"success"``,
    formats selected values from the first measurement entry.

    Returns:
        A list of ``(label, formatted_value)`` pairs; empty when the payload
        is missing, unsuccessful, or not shaped as expected.
    """
    metrics: List[Tuple[str, str]] = []

    execution_result = getattr(response, "execution_result", None)
    results = getattr(execution_result, "results", None) or {}
    if not isinstance(results, dict):
        return metrics
    tool_results: Dict[str, Dict] = results.get("tool_results") or {}
    if not isinstance(tool_results, dict):
        return metrics

    measurement = tool_results.get("echo_measurement_prediction")
    if not isinstance(measurement, dict) or measurement.get("status") != "success":
        return metrics

    entries = measurement.get("measurements") or []
    if not isinstance(entries, (list, tuple)) or not entries:
        return metrics
    first_entry = entries[0]
    data = first_entry.get("measurements") if isinstance(first_entry, dict) else None
    if not isinstance(data, dict):
        return metrics

    # Fall through to the next key when an entry exists but is unusable,
    # instead of stopping at the first key that is merely present.
    for key in _EF_KEYS:
        text = _format_measurement(data.get(key), precision=1)
        if text is not None:
            metrics.append(("Ejection Fraction", text))
            break

    for key, label, precision in _EXTRA_METRICS:
        text = _format_measurement(data.get(key), precision=precision)
        if text is not None:
            metrics.append((label, text))

    return metrics


def main() -> None:
    """Render the page, run the agent on the uploaded study, show results."""
    st.set_page_config(page_title="EchoPilot Agent", page_icon="🫀", layout="wide")
    st.title("EchoPilot · Echocardiography Co-Pilot")
    st.caption("Upload a study, ask a focused question, and EchoPilot will run the appropriate tools to answer.")

    upload_col, info_col = st.columns([2, 1])
    with upload_col:
        uploaded_video = st.file_uploader(
            "Echo video file",
            type=["mp4", "mov", "m4v", "avi", "wmv"],
            help="Standard ultrasound formats are supported.",
        )
        default_question = "Estimate the ejection fraction and note any major abnormalities."
        query = st.text_area("Clinical question", value=default_question, height=120)

    with info_col:
        st.markdown("### How it works")
        st.write(
            "- EchoPilot uses a ReAct loop to decide which tools to call.\n"
            "- It may segment chambers, compute EchoPrime measurements, or run disease classifiers.\n"
            "- Results are summarized below; raw tool logs are hidden for clarity."
        )

    response = None
    display_video: Path | None = None

    run_clicked = st.button(
        "Run Analysis",
        type="primary",
        use_container_width=True,
        disabled=uploaded_video is None or not query.strip(),
    )

    if run_clicked:
        agent = load_agent()
        video_path, temp_dir = _persist_upload(uploaded_video)
        try:
            temp_display_dir = PROJECT_ROOT / "temp"
            temp_display_dir.mkdir(parents=True, exist_ok=True)
            display_target = temp_display_dir / f"display_{uuid.uuid4().hex}.mp4"

            converted = convert_video_to_h264(str(video_path), str(display_target))
            # NOTE(review): guards against a falsy return on conversion
            # failure; confirm convert_video_to_h264's failure contract.
            display_video = Path(converted) if converted else None

            with st.spinner("EchoPilot is analyzing the study..."):
                response = agent.process_query(query.strip(), str(video_path))
        finally:
            # Remove the original upload even when conversion or analysis
            # raises, so failed runs do not accumulate temp directories.
            shutil.rmtree(temp_dir, ignore_errors=True)

    if response:
        st.success("Analysis complete")
        metrics = _extract_key_metrics(response)

        container = st.container()
        video_col, metrics_col = container.columns([2, 1])

        if display_video and display_video.exists():
            with video_col:
                st.video(str(display_video))

        if metrics:
            with metrics_col:
                st.markdown("#### Key Measurements")
                for label, value in metrics:
                    st.metric(label, value)

        st.divider()
        st.markdown("#### EchoPilot Response")
        st.chat_message("user").write(query.strip())
        st.chat_message("assistant").write(response.response_text)


if __name__ == "__main__":
    main()