File size: 5,440 Bytes
8f51ef2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
#!/usr/bin/env python3
"""Minimal Streamlit interface for the EchoPilot ReAct agent."""
from __future__ import annotations

import shutil
import tempfile
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import streamlit as st

from agents import get_intelligent_agent
from config import Config
from utils.video_utils import convert_video_to_h264
# Absolute directory holding this script (symlinks resolved); the app's
# scratch "temp" folder is created beneath it.
PROJECT_ROOT = Path(__file__).resolve().parents[0]
@st.cache_resource(show_spinner=False)
def load_agent():
    """Construct the EchoPilot agent on ``Config.DEVICE``.

    Wrapped in ``st.cache_resource`` so the instance is built once and then
    handed back to later calls instead of being re-created each time.
    """
    agent_cls, _unused = get_intelligent_agent()
    target_device = Config.DEVICE
    return agent_cls(device=target_device)
def _persist_upload(upload) -> Tuple[Path, Path]:
    """Write an uploaded file into a fresh temporary directory.

    Args:
        upload: Uploaded-file object exposing ``name`` and ``getbuffer()``
            (the value returned by ``st.file_uploader`` in this app).

    Returns:
        ``(video_path, temp_dir)``: the written file and the directory that
        contains it. The caller owns ``temp_dir`` and must remove it.

    The file keeps the upload's extension (``.mp4`` when there is none) but a
    fixed ``input`` stem, so the user-supplied name never picks the location.
    If writing fails, the temporary directory is removed before re-raising.
    """
    suffix = Path(upload.name or "input.mp4").suffix or ".mp4"
    temp_dir = Path(tempfile.mkdtemp(prefix="echopilot_"))
    video_path = temp_dir / f"input{suffix}"
    try:
        video_path.write_bytes(upload.getbuffer())
    except BaseException:
        # Don't leave an orphaned (possibly half-written) directory behind.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return video_path, temp_dir
def _format_metric_value(info, precision: int) -> Optional[str]:
    """Render one measurement entry as display text.

    Args:
        info: Expected to be a dict with a ``"value"`` and an optional
            ``"unit"``; anything that is not a dict is treated as missing.
        precision: Decimal places used when ``value`` parses as a number.

    Returns:
        Text such as ``"2.10 cm"`` or ``"55.0%"``, or ``None`` when ``info``
        is not a dict or carries no value.
    """
    if not isinstance(info, dict):
        return None
    value = info.get("value")
    if value is None:
        return None
    try:
        value_str = f"{float(value):.{precision}f}"
    except (TypeError, ValueError):
        # Values that can't be parsed as numbers are shown verbatim.
        value_str = str(value)
    # `or ""` also covers an explicit None unit, which must not render as "None".
    unit = str(info.get("unit") or "").strip()
    if not unit:
        return value_str
    # A percent sign attaches directly to the number; other units get a space.
    separator = "" if unit == "%" else " "
    return f"{value_str}{separator}{unit}"


def _extract_key_metrics(response) -> List[Tuple[str, str]]:
    """Pull headline measurements out of an agent response for the metric panel.

    Reads ``response.execution_result.results["tool_results"]
    ["echo_measurement_prediction"]`` and, when that tool reported
    ``status == "success"``, formats selected entries from the first
    measurement record.

    Returns:
        ``(label, display_value)`` pairs in display order. The list is empty
        when the tool did not run, did not succeed, or produced nothing usable.
    """
    execution_result = getattr(response, "execution_result", None)
    results = getattr(execution_result, "results", None) or {}
    tool_results: Dict[str, Dict] = results.get("tool_results") or {}
    measurement = tool_results.get("echo_measurement_prediction")
    if not isinstance(measurement, dict) or measurement.get("status") != "success":
        return []

    entries = measurement.get("measurements") or []
    if not entries or not isinstance(entries[0], dict):
        return []
    data = entries[0].get("measurements") or {}
    if not isinstance(data, dict):
        return []

    metrics: List[Tuple[str, str]] = []

    # EF may be reported under either key; show the first one that is usable.
    for key in ("ejection_fraction", "EF"):
        text = _format_metric_value(data.get(key), precision=1)
        if text is not None:
            metrics.append(("Ejection Fraction", text))
            break

    for key, label, precision in (
        ("pulmonary_artery_pressure_continuous", "Pulmonary Artery Pressure", 1),
        # NOTE(review): the key name suggests a dilated-IVC finding rather
        # than a diameter; confirm the label matches what the tool reports.
        ("dilated_ivc", "IVC Diameter", 2),
    ):
        text = _format_metric_value(data.get(key), precision)
        if text is not None:
            metrics.append((label, text))
    return metrics
def _run_analysis(uploaded_video, question: str) -> Tuple[object, Path]:
    """Persist the upload, make a display copy, and run the agent on it.

    Args:
        uploaded_video: The object returned by ``st.file_uploader``.
        question: The already-stripped clinical question.

    Returns:
        ``(response, display_video)``: the value returned by
        ``agent.process_query`` and the path returned by
        ``convert_video_to_h264`` for on-page playback.

    The temporary copy of the upload is removed in a ``finally`` block, so it
    does not leak when conversion or analysis raises.
    """
    agent = load_agent()
    video_path, temp_dir = _persist_upload(uploaded_video)
    try:
        display_dir = PROJECT_ROOT / "temp"
        display_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): display copies are never deleted and accumulate in
        # PROJECT_ROOT/temp across runs; consider a cleanup policy.
        display_target = display_dir / f"display_{uuid.uuid4().hex}.mp4"
        with st.spinner("EchoPilot is analyzing the study..."):
            # Conversion runs under the spinner too, so the page shows
            # progress for the whole operation.
            display_video = Path(convert_video_to_h264(str(video_path), str(display_target)))
            response = agent.process_query(question, str(video_path))
    finally:
        # Clean up the original upload to save disk space, even on failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
    return response, display_video


def _render_results(response, display_video: Optional[Path], question: str) -> None:
    """Show the study video, extracted key metrics, and the agent's answer."""
    st.success("Analysis complete")
    metrics = _extract_key_metrics(response)
    video_col, metrics_col = st.container().columns([2, 1])
    if display_video is not None and display_video.exists():
        with video_col:
            st.video(str(display_video))
    if metrics:
        with metrics_col:
            st.markdown("#### Key Measurements")
            for label, value in metrics:
                st.metric(label, value)
    st.divider()
    st.markdown("#### EchoPilot Response")
    st.chat_message("user").write(question)
    st.chat_message("assistant").write(response.response_text)


def main() -> None:
    """Render the EchoPilot page and analyse the uploaded study on request."""
    st.set_page_config(page_title="EchoPilot Agent", page_icon="🫀", layout="wide")
    st.title("EchoPilot · Echocardiography Co-Pilot")
    st.caption("Upload a study, ask a focused question, and EchoPilot will run the appropriate tools to answer.")

    upload_col, info_col = st.columns([2, 1])
    with upload_col:
        uploaded_video = st.file_uploader(
            "Echo video file",
            type=["mp4", "mov", "m4v", "avi", "wmv"],
            help="Standard ultrasound formats are supported.",
        )
        default_question = "Estimate the ejection fraction and note any major abnormalities."
        query = st.text_area("Clinical question", value=default_question, height=120)
    with info_col:
        st.markdown("### How it works")
        st.write(
            "- EchoPilot uses a ReAct loop to decide which tools to call.\n"
            "- It may segment chambers, compute EchoPrime measurements, or run disease classifiers.\n"
            "- Results are summarized below; raw tool logs are hidden for clarity."
        )

    question = query.strip()
    run_clicked = st.button(
        "Run Analysis",
        type="primary",
        use_container_width=True,
        disabled=not uploaded_video or not question,
    )
    if not run_clicked:
        return

    # NOTE(review): results live only in this run's locals; if they should
    # survive later Streamlit reruns, keep them in st.session_state.
    response, display_video = _run_analysis(uploaded_video, question)
    if response is not None:
        _render_results(response, display_video, question)
if __name__ == "__main__":
    # Render the app only when this file runs as the main module,
    # not when it is imported.
    main()
|