|
|
|
|
|
"""Minimal Streamlit interface for the EchoPilot ReAct agent.""" |
|
|
|
|
|
from __future__ import annotations

import shutil
import tempfile
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import streamlit as st

from agents import get_intelligent_agent
from config import Config
from utils.video_utils import convert_video_to_h264
|
|
|
|
|
# Absolute directory holding this script; used as the base for the display-video scratch folder.
PROJECT_ROOT = Path(__file__).resolve().parents[0]
|
|
|
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_agent():
    """Construct the EchoPilot agent on ``Config.DEVICE``.

    ``st.cache_resource`` memoizes the result, so the agent is built once and
    the same instance is reused on later reruns instead of being reloaded.
    """
    # get_intelligent_agent() yields a pair whose first element is the agent class.
    agent_cls, _unused = get_intelligent_agent()
    device = Config.DEVICE
    return agent_cls(device=device)
|
|
|
|
|
|
|
|
def _persist_upload(upload) -> Tuple[Path, Path]:
    """Write an uploaded file into a fresh temporary directory.

    Args:
        upload: Uploaded-file object exposing ``name`` and ``getbuffer()``
            (presumably Streamlit's ``UploadedFile`` — confirm if reused elsewhere).

    Returns:
        ``(video_path, temp_dir)``: the written file and the directory that
        contains it. The caller owns ``temp_dir`` and must remove it.

    Raises:
        Whatever reading the upload or writing the file raises. The temporary
        directory is removed before re-raising, so a failed write does not leak it.
    """
    # Only the suffix of the client-supplied name is reused (never the name itself),
    # so the file always lands inside temp_dir; fall back to .mp4 when absent.
    suffix = Path(upload.name or "input.mp4").suffix or ".mp4"
    temp_dir = Path(tempfile.mkdtemp(prefix="echopilot_"))
    video_path = temp_dir / f"input{suffix}"
    try:
        video_path.write_bytes(upload.getbuffer())
    except BaseException:
        # BaseException so script interrupts/reruns also clean up before propagating.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return video_path, temp_dir
|
|
|
|
|
|
|
|
def _extract_key_metrics(response) -> List[Tuple[str, str]]:
    """Pull headline measurements out of an agent response for display.

    Reads ``response.execution_result.results["tool_results"]["echo_measurement_prediction"]``.
    When that tool reported ``status == "success"``, it formats a fixed set of values
    from the ``measurements`` dict of the tool's first measurement entry.

    Args:
        response: Object returned by ``agent.process_query``. Missing or ``None``
            links anywhere along that path yield an empty result instead of raising.

    Returns:
        ``(label, formatted_value)`` pairs in display order (EF, PA pressure, IVC).
        The list is empty when nothing usable was found.
    """
    data = _measurement_data(response)
    metrics: List[Tuple[str, str]] = []

    # Either key may carry the EF. Take the first one that yields a displayable value,
    # so an unusable "ejection_fraction" entry does not hide a valid "EF".
    for key in ("ejection_fraction", "EF"):
        formatted = _format_measurement(data.get(key), precision=1)
        if formatted is not None:
            metrics.append(("Ejection Fraction", formatted))
            break

    for key, label, precision in (
        ("pulmonary_artery_pressure_continuous", "Pulmonary Artery Pressure", 1),
        # NOTE(review): the key name suggests a dilation flag/probability rather than
        # a diameter; confirm this label against the tool's output schema.
        ("dilated_ivc", "IVC Diameter", 2),
    ):
        formatted = _format_measurement(data.get(key), precision=precision)
        if formatted is not None:
            metrics.append((label, formatted))

    return metrics


def _measurement_data(response) -> Dict:
    """Return the first successful EchoPrime measurement dict in *response*, or ``{}``."""
    execution_result = getattr(response, "execution_result", None)
    results = getattr(execution_result, "results", None) or {}
    tool_results = results.get("tool_results") or {}
    measurement = tool_results.get("echo_measurement_prediction")
    if not isinstance(measurement, dict) or measurement.get("status") != "success":
        return {}
    entries = measurement.get("measurements") or []
    if not entries or not isinstance(entries[0], dict):
        return {}
    data = entries[0].get("measurements")
    return data if isinstance(data, dict) else {}


def _format_measurement(info, precision: int = 1) -> Optional[str]:
    """Render a ``{"value": ..., "unit": ...}`` entry as text, or ``None`` if unusable.

    Values accepted by ``float()`` are shown with *precision* decimals. Other
    values are shown verbatim. A non-empty unit is appended after a single space.
    """
    if not isinstance(info, dict):
        return None
    value = info.get("value")
    if value is None:
        return None
    try:
        value_str = f"{float(value):.{precision}f}"
    except (TypeError, ValueError):
        value_str = str(value)
    # `or ""` also covers an explicit None unit, which would otherwise print "None".
    unit = info.get("unit") or ""
    return f"{value_str} {unit}" if unit else value_str
|
|
|
|
|
|
|
|
def main() -> None:
    """Render the EchoPilot page and, when requested, run the agent on the uploaded study."""
    st.set_page_config(page_title="EchoPilot Agent", page_icon="🫀", layout="wide")
    st.title("EchoPilot · Echocardiography Co-Pilot")
    st.caption("Upload a study, ask a focused question, and EchoPilot will run the appropriate tools to answer.")

    upload_col, info_col = st.columns([2, 1])
    with upload_col:
        uploaded_video = st.file_uploader(
            "Echo video file",
            type=["mp4", "mov", "m4v", "avi", "wmv"],
            help="Standard ultrasound formats are supported.",
        )
        default_question = "Estimate the ejection fraction and note any major abnormalities."
        query = st.text_area("Clinical question", value=default_question, height=120)
    with info_col:
        st.markdown("### How it works")
        st.write(
            "- EchoPilot uses a ReAct loop to decide which tools to call.\n"
            "- It may segment chambers, compute EchoPrime measurements, or run disease classifiers.\n"
            "- Results are summarized below; raw tool logs are hidden for clarity."
        )

    question = query.strip()
    run_clicked = st.button(
        "Run Analysis",
        type="primary",
        use_container_width=True,
        disabled=not uploaded_video or not question,
    )
    if not run_clicked or not uploaded_video:
        return

    response, preview_bytes = _run_analysis(uploaded_video, question)
    if response is None:
        return
    _render_results(response, question, preview_bytes)


def _run_analysis(uploaded_video, question: str):
    """Run the agent on the upload and return ``(response, preview_bytes)``.

    The on-disk copy of the upload and the converted preview file are both
    deleted before returning, even when the agent raises. The preview is
    returned as bytes so it can still be displayed afterwards.
    """
    with st.spinner("EchoPilot is analyzing the study..."):
        agent = load_agent()
        video_path, temp_dir = _persist_upload(uploaded_video)
        preview_target = PROJECT_ROOT / "temp" / f"display_{uuid.uuid4().hex}.mp4"
        try:
            preview_bytes = _build_preview(video_path, preview_target)
            response = agent.process_query(question, str(video_path))
        finally:
            # rmtree also copes with subdirectories the tools might create.
            shutil.rmtree(temp_dir, ignore_errors=True)
            preview_target.unlink(missing_ok=True)
    return response, preview_bytes


def _build_preview(video_path: Path, preview_target: Path) -> Optional[bytes]:
    """Convert the study via ``convert_video_to_h264`` and return the file's bytes.

    Returns ``None`` when the converter produced no file or raised. A failure
    is reported with ``st.warning`` rather than aborting the analysis.
    """
    preview_target.parent.mkdir(parents=True, exist_ok=True)
    try:
        converted = Path(convert_video_to_h264(str(video_path), str(preview_target)))
        return converted.read_bytes() if converted.exists() else None
    except Exception as exc:  # the preview is cosmetic; surface the error but keep going
        st.warning(f"Could not prepare a video preview: {exc}")
        return None


def _render_results(response, question: str, preview_bytes: Optional[bytes]) -> None:
    """Show the preview video, key measurements, and the agent's answer."""
    st.success("Analysis complete")
    metrics = _extract_key_metrics(response)

    video_col, metrics_col = st.container().columns([2, 1])
    if preview_bytes:
        with video_col:
            st.video(preview_bytes)
    if metrics:
        with metrics_col:
            st.markdown("#### Key Measurements")
            for label, value in metrics:
                st.metric(label, value)

    st.divider()
    st.markdown("#### EchoPilot Response")
    st.chat_message("user").write(question)
    st.chat_message("assistant").write(response.response_text)
|
|
|
|
|
|
|
|
# Render the page when this file is executed as the main script.
if __name__ == "__main__":
    main()
|
|
|