Echo / streamlit_app.py
moein99's picture
Initial Echo Space
8f51ef2
#!/usr/bin/env python3
"""Minimal Streamlit interface for the EchoPilot ReAct agent."""
from __future__ import annotations

import shutil
import tempfile
import uuid
from pathlib import Path
from typing import Dict, List, Tuple

import streamlit as st

from agents import get_intelligent_agent
from config import Config
from utils.video_utils import convert_video_to_h264
# Absolute path of the directory containing this script. main() places
# display copies of uploaded videos under PROJECT_ROOT / "temp".
PROJECT_ROOT: Path = Path(__file__).resolve().parent
@st.cache_resource(show_spinner=False)
def load_agent():
    """Construct the EchoPilot agent once and reuse it.

    ``st.cache_resource`` memoizes the returned instance, so later reruns (and
    sessions) get the same agent rather than rebuilding it. The agent class is
    resolved through ``get_intelligent_agent()``, whose second return value is
    not needed here.
    """
    agent_cls, _unused = get_intelligent_agent()
    agent = agent_cls(device=Config.DEVICE)
    return agent
def _persist_upload(upload) -> Tuple[Path, Path]:
    """Write an uploaded file into a fresh temporary directory.

    Args:
        upload: The object returned by ``st.file_uploader``. Only its ``name``
            attribute and ``getbuffer()`` method are used.

    Returns:
        ``(video_path, temp_dir)``: the written file and the newly created
        directory that contains it. The caller owns ``temp_dir`` and is
        responsible for deleting it.

    Raises:
        Whatever ``upload.getbuffer()`` or the write raises. The temporary
        directory is removed before the exception propagates.
    """
    # Only the extension of the client-supplied name is kept; the stem is the
    # fixed "input", so the uploaded name never becomes part of the path.
    suffix = Path(upload.name or "input.mp4").suffix or ".mp4"
    temp_dir = Path(tempfile.mkdtemp(prefix="echopilot_"))
    video_path = temp_dir / f"input{suffix}"
    try:
        video_path.write_bytes(upload.getbuffer())
    except BaseException:
        # Don't leave an orphaned directory behind when the write fails.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return video_path, temp_dir
def _extract_key_metrics(response) -> List[Tuple[str, str]]:
    """Pull headline measurements out of an agent response for display.

    Reads ``response.execution_result.results["tool_results"]
    ["echo_measurement_prediction"]``. When that tool reports
    ``status == "success"``, it formats selected entries from the first
    measurement record.

    Args:
        response: The value returned by ``agent.process_query``. Only the
            ``execution_result.results`` attribute chain is read. A missing
            attribute or a ``None`` value is treated as "no results".

    Returns:
        A list of ``(label, formatted_value)`` pairs in display order:
        ejection fraction, pulmonary artery pressure, IVC. Metrics that are
        absent or have no value are skipped, so the list may be empty.
    """
    metrics: List[Tuple[str, str]] = []

    execution_result = getattr(response, "execution_result", None)
    results = getattr(execution_result, "results", None) or {}
    tool_results: Dict[str, Dict] = results.get("tool_results") or {}
    measurement = tool_results.get("echo_measurement_prediction")
    if not isinstance(measurement, dict) or measurement.get("status") != "success":
        return metrics

    entries = measurement.get("measurements") or []
    if not entries or not isinstance(entries[0], dict):
        return metrics
    # Only the first measurement record is summarized.
    data = entries[0].get("measurements") or {}
    if not isinstance(data, dict):
        return metrics

    def _format_metric(key: str, label: str, precision: int = 1) -> bool:
        """Append ``(label, "<value> <unit>")`` for ``data[key]``; return True if added."""
        info = data.get(key)
        if not isinstance(info, dict):
            return False
        value = info.get("value")
        if value is None:
            return False
        try:
            value_str = f"{float(value):.{precision}f}"
        except (TypeError, ValueError):
            # Non-numeric values are shown verbatim.
            value_str = str(value)
        # Separate number and unit with a single space; a missing or None
        # unit contributes nothing (never the text "None").
        unit = str(info.get("unit") or "").strip()
        unit_str = f" {unit}" if unit else ""
        metrics.append((label, f"{value_str}{unit_str}"))
        return True

    # EF may be reported under either key; use the first one that yields a value.
    for ef_key in ("ejection_fraction", "EF"):
        if _format_metric(ef_key, "Ejection Fraction", precision=1):
            break
    _format_metric("pulmonary_artery_pressure_continuous", "Pulmonary Artery Pressure", precision=1)
    # NOTE(review): the key is "dilated_ivc" but the label says "IVC Diameter";
    # confirm what this value represents with the measurement tool.
    _format_metric("dilated_ivc", "IVC Diameter", precision=2)
    return metrics
def _run_analysis(uploaded_video, question: str):
    """Run the agent on one uploaded study.

    Args:
        uploaded_video: The object returned by ``st.file_uploader``.
        question: The already-stripped clinical question.

    Returns:
        ``(response, display_video)``: the value returned by
        ``agent.process_query``, and the path returned by
        ``convert_video_to_h264`` (wrapped in ``Path``) for on-page playback.
    """
    with st.spinner("EchoPilot is analyzing the study..."):
        agent = load_agent()
        video_path, temp_dir = _persist_upload(uploaded_video)
        try:
            display_dir = PROJECT_ROOT / "temp"
            display_dir.mkdir(parents=True, exist_ok=True)
            # NOTE(review): nothing in this file deletes these display copies,
            # so they accumulate across runs.
            display_target = display_dir / f"display_{uuid.uuid4().hex}.mp4"
            display_video = Path(convert_video_to_h264(str(video_path), str(display_target)))
            # The agent analyzes the original upload, not the converted copy.
            response = agent.process_query(question, str(video_path))
        finally:
            # Remove the original upload even if conversion or analysis raised.
            shutil.rmtree(temp_dir, ignore_errors=True)
    return response, display_video


def _render_results(response, display_video: Path | None, question: str) -> None:
    """Show the study video, the headline metrics, and the agent's answer."""
    st.success("Analysis complete")
    metrics = _extract_key_metrics(response)
    video_col, metrics_col = st.container().columns([2, 1])
    if display_video is not None and display_video.exists():
        with video_col:
            st.video(str(display_video))
    if metrics:
        with metrics_col:
            st.markdown("#### Key Measurements")
            for label, value in metrics:
                st.metric(label, value)
    st.divider()
    st.markdown("#### EchoPilot Response")
    st.chat_message("user").write(question)
    st.chat_message("assistant").write(response.response_text)


def main() -> None:
    """Build the EchoPilot page and, when requested, run and show an analysis.

    Streamlit re-executes this function on every widget interaction. Results
    are rendered only on the rerun triggered by "Run Analysis". They are not
    stored in ``st.session_state``, so any later interaction clears them.
    """
    st.set_page_config(page_title="EchoPilot Agent", page_icon="🫀", layout="wide")
    st.title("EchoPilot · Echocardiography Co-Pilot")
    st.caption("Upload a study, ask a focused question, and EchoPilot will run the appropriate tools to answer.")

    upload_col, info_col = st.columns([2, 1])
    with upload_col:
        uploaded_video = st.file_uploader(
            "Echo video file",
            type=["mp4", "mov", "m4v", "avi", "wmv"],
            help="Standard ultrasound formats are supported.",
        )
        default_question = "Estimate the ejection fraction and note any major abnormalities."
        query = st.text_area("Clinical question", value=default_question, height=120)
    with info_col:
        st.markdown("### How it works")
        st.write(
            "- EchoPilot uses a ReAct loop to decide which tools to call.\n"
            "- It may segment chambers, compute EchoPrime measurements, or run disease classifiers.\n"
            "- Results are summarized below; raw tool logs are hidden for clarity."
        )

    question = query.strip()
    response = None
    display_video: Path | None = None
    run_clicked = st.button(
        "Run Analysis",
        type="primary",
        use_container_width=True,
        disabled=not uploaded_video or not question,
    )
    if run_clicked and uploaded_video is not None:
        response, display_video = _run_analysis(uploaded_video, question)

    if response:
        _render_results(response, display_video, question)
# Render the app when this file is executed as the main script.
# NOTE(review): the app is presumably launched with `streamlit run`, which is
# expected to execute the file as __main__ so this guard fires; confirm.
if __name__ == "__main__":
    main()