Rogerjs's picture
Update app.py
10e0076 verified
raw
history blame
6.42 kB
import os
import mne
import numpy as np
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Instruction-tuned causal LM that turns the numeric EEG band summary into
# plain-language text. Loaded once, at import time, in float16; placement of
# the weights across available devices is delegated to device_map="auto".
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
_model_load_options = {
    "trust_remote_code": True,
    "torch_dtype": torch.float16,
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_options)
def compute_band_power(psd, freqs, fmin, fmax):
    """Return the mean PSD value inside the inclusive band ``[fmin, fmax]``.

    The mean is taken over every element of ``psd`` whose frequency bin lies
    in the band, i.e. across all channels (leading axes) and all selected
    bins. This is an average spectral density, not power integrated over the
    band.

    Args:
        psd: Array whose LAST axis indexes frequency, e.g. ``(n_channels,
            n_freqs)`` as produced by ``Spectrum.get_data``; a 1-D
            ``(n_freqs,)`` array is accepted too.
        freqs: 1-D array of bin frequencies matching the last axis of ``psd``.
        fmin: Lower band edge (inclusive), same units as ``freqs``.
        fmax: Upper band edge (inclusive), same units as ``freqs``.

    Returns:
        The band mean as a Python ``float``.

    Raises:
        ValueError: If no frequency bin falls within ``[fmin, fmax]`` (this
            also covers ``fmin > fmax``). Previously this case silently
            returned NaN, which then leaked into the generated summary.
    """
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        raise ValueError(
            f"No frequency bins fall within [{fmin}, {fmax}]; check the PSD "
            f"frequency range and the sampling frequency."
        )
    # Ellipsis selects along the last (frequency) axis for any leading shape.
    return float(psd[..., freq_mask].mean())
def inspect_file(file):
    """
    Inspect the uploaded file to determine available columns.

    If FIF: just inform that it's an MNE file and no time column is needed.
    If CSV: return a list of all columns (both numeric and non-numeric) and a
    preview of the first rows.

    Args:
        file: The upload. This can be a path (``str`` / ``os.PathLike``), an
            object exposing the path as ``.name`` (the tempfile-style wrapper
            some Gradio versions pass), or None.

    Returns:
        ``(message, columns, preview)``. ``columns`` is empty for FIF,
        unsupported, or unreadable files.
    """
    if file is None:
        return "No file uploaded.", [], "No preview available."
    # Gradio may pass a plain path string or a wrapper carrying the path in .name.
    if isinstance(file, (str, os.PathLike)):
        file_path = os.fspath(file)
    else:
        file_path = file.name
    file_ext = os.path.splitext(file_path)[1].lower()

    if file_ext == ".fif":
        return (
            "FIF file detected. No need for time column selection. The file's sampling frequency will be used.",
            [],
            "FIF file doesn't require column inspection."
        )
    if file_ext != ".csv":
        return "Unsupported file format.", [], "No preview available."

    try:
        df = pd.read_csv(file_path, nrows=5)
    except Exception as e:
        # UI boundary: report any parse/IO failure to the user instead of crashing.
        return f"Error reading CSV: {e}", [], "Could not read CSV preview."

    head = df.head()
    try:
        preview = head.to_markdown()
    except ImportError:
        # DataFrame.to_markdown requires the optional 'tabulate' package;
        # fall back to a plain-text table in a code fence.
        preview = f"```\n{head.to_string()}\n```"
    return (
        "CSV file detected. Select a time column if available, or choose (No time column) and specify a default frequency.",
        list(df.columns),
        preview
    )
def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
    """
    Load EEG data from a FIF or CSV file into an MNE Raw object.

    If FIF: use MNE's ``read_raw_fif`` directly; the file's own sampling
    frequency is kept.
    If CSV:
      - If ``time_col`` is given and present in the file, the sampling
        frequency is ``1 / mean(diff(time))``. The time values are assumed
        to be in seconds (TODO confirm for your exports).
      - Otherwise ``default_sfreq`` is used.
      - Every remaining numeric column becomes an 'eeg' channel.

    Args:
        file_path: Path to a ``.fif`` or ``.csv`` file. The extension is
            matched case-insensitively.
        default_sfreq: Fallback sampling frequency in Hz for CSV input.
        time_col: Name of the CSV time column, or None/"" to skip it. A name
            that is absent from the file is ignored.

    Returns:
        The preloaded raw recording.

    Raises:
        ValueError: If the extension is unsupported, or the CSV has no
            numeric channel columns.
    """
    file_ext = os.path.splitext(file_path)[1].lower()
    if file_ext == '.fif':
        return mne.io.read_raw_fif(file_path, preload=True)
    if file_ext != '.csv':
        raise ValueError("Unsupported file format. Provide a FIF or CSV file.")

    df = pd.read_csv(file_path)
    if time_col and time_col in df.columns:
        sfreq = _estimate_sfreq(df[time_col], default_sfreq)
        df = df.drop(columns=[time_col])
    else:
        sfreq = default_sfreq
    # Last-resort guard, e.g. when the caller passed a non-positive default.
    if not np.isfinite(sfreq) or sfreq <= 0:
        sfreq = 256.0

    data_df = _numeric_columns(df)
    if data_df.shape[1] == 0:
        raise ValueError("CSV file contains no numeric EEG channel columns.")

    # NOTE(review): MNE interprets 'eeg' channel data as volts. CSV values are
    # passed through unscaled, so microvolt exports are off by 1e6 — confirm units.
    ch_names = [str(col) for col in data_df.columns]
    data = data_df.to_numpy(dtype=float).T  # (n_channels, n_samples)
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * len(ch_names))
    return mne.io.RawArray(data, info)


def _estimate_sfreq(time, default_sfreq):
    """Return 1 / mean sample spacing of ``time``, or ``default_sfreq`` if unusable."""
    # Non-numeric entries become NaN, and then fall through to the default below.
    values = np.asarray(pd.to_numeric(time, errors="coerce"), dtype=float)
    if len(values) < 2:
        return default_sfreq
    dt = np.mean(np.diff(values))
    if not np.isfinite(dt) or dt <= 0:
        return default_sfreq
    return 1.0 / dt


def _numeric_columns(df):
    """Return ``df`` restricted to columns pandas considers numeric (bool included)."""
    keep = [col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])]
    return df[keep]
def analyze_eeg(file, default_sfreq, time_col):
    """Compute alpha/beta band power for an uploaded EEG file and have the LLM interpret it.

    Args:
        file: The upload. This can be a path (``str`` / ``os.PathLike``), an
            object exposing the path as ``.name``, or None.
        default_sfreq: Sampling frequency in Hz as entered in the UI (string
            or number). Unparseable, non-finite or non-positive values fall
            back to 256 Hz.
        time_col: CSV time-column name, or "(No time column)" / None to use
            ``default_sfreq``.

    Returns:
        The model's interpretation (without the echoed prompt), or a short
        message if no file was uploaded.
    """
    if file is None:
        return "No file uploaded."
    # Gradio may pass a plain path string or a wrapper carrying the path in .name.
    if isinstance(file, (str, os.PathLike)):
        file_path = os.fspath(file)
    else:
        file_path = file.name
    if time_col == "(No time column)":
        time_col = None

    try:
        fs = float(default_sfreq)
    except (TypeError, ValueError):
        fs = 256.0
    if not np.isfinite(fs) or fs <= 0:
        fs = 256.0

    raw = load_eeg_data(file_path, default_sfreq=fs, time_col=time_col)
    psd_obj = raw.compute_psd(fmin=1, fmax=40, method='welch')
    psd, freqs = psd_obj.get_data(return_freqs=True)
    alpha_power = compute_band_power(psd, freqs, 8, 12)
    beta_power = compute_band_power(psd, freqs, 13, 30)

    # '.3g' rather than '.3f': PSDs of data in volts are ~1e-12 and would
    # otherwise print as 0.000. Only data-derived facts go into the prompt;
    # the previous hard-coded "stable alpha / elevated beta" claim was
    # independent of the recording.
    data_summary = f"Alpha power: {alpha_power:.3g}, Beta power: {beta_power:.3g}."
    if beta_power > 0:
        data_summary += f" Alpha/beta power ratio: {alpha_power / beta_power:.2f}."

    prompt = (
        "You are a neuroscientist analyzing EEG features.\n"
        f"Data Summary: {data_summary}\n"
        "Provide a concise, user-friendly interpretation of these findings in simple terms.\n"
    )
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[-1]
    # Pass only input_ids/attention_mask: some tokenizers also emit
    # token_type_ids, which generate() may reject as unused model kwargs.
    outputs = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=200,  # budget for the answer only, not prompt + answer
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decoder-only models return prompt + continuation; keep only the new tokens.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
def preview_file(file):
    """Inspect ``file`` and refresh the time-column dropdown.

    Returns the status message, a Gradio update that repopulates the dropdown
    with "(No time column)" followed by any detected columns (preselecting
    "(No time column)"), and the preview markdown.
    """
    no_time_choice = "(No time column)"
    message, columns, preview_md = inspect_file(file)
    dropdown_update = gr.update(choices=[no_time_choice, *columns], value=no_time_choice)
    return message, dropdown_update, preview_md
# Gradio UI. Inside a gr.Blocks context, components are laid out in the order
# they are created, so keep statement order when editing this block.
with gr.Blocks() as demo:
    gr.Markdown("# NeuroNarrative-Lite: EEG Summary with Flexible Preprocessing")
    gr.Markdown(
        "Upload an EEG file (FIF or CSV). If it's CSV, click 'Inspect File' to preview columns. "
        "Select a time column if available or '(No time column)' if not. "
        "If no time column is chosen, provide a default sampling frequency. "
        "Then click 'Run Analysis'."
    )
    file_input = gr.File(label="Upload your EEG data (FIF or CSV)")

    # Step 1: inspect the upload; preview_file fills the dropdown's choices.
    preview_button = gr.Button("Inspect File")
    msg_output = gr.Markdown()
    # Created without choices. If the user never clicks "Inspect File", its
    # value reaches analyze_eeg unset (presumably None — TODO confirm).
    cols_dropdown = gr.Dropdown(label="Select Time Column (optional)", allow_custom_value=True, interactive=True)
    preview_output = gr.Markdown()
    preview_button.click(
        preview_file,
        inputs=[file_input],
        outputs=[msg_output, cols_dropdown, preview_output]
    )

    # Passed to analyze_eeg as text and converted with float(). The UI default
    # is 100 Hz, while analyze_eeg/load_eeg_data fall back to 256 Hz for
    # non-positive values.
    default_sfreq_input = gr.Textbox(label="Default Sampling Frequency (Hz) if no time column", value="100")

    # Step 2: band-power analysis followed by LLM interpretation.
    analyze_button = gr.Button("Run Analysis")
    result_output = gr.Textbox(label="Analysis Summary")
    analyze_button.click(
        analyze_eeg,
        inputs=[file_input, default_sfreq_input, cols_dropdown],
        outputs=[result_output]
    )

if __name__ == "__main__":
    demo.launch()