File size: 6,420 Bytes
1ccf8ee
c48497c
4fb4636
56dc0d1
1ccf8ee
c48497c
 
 
1ccf8ee
c48497c
 
4fb4636
 
 
 
49be262
4fb4636
 
 
10e0076
 
4fb4636
56dc0d1
4fb4636
c48497c
1ccf8ee
 
 
 
 
 
 
 
 
 
 
b7e02dd
1ccf8ee
 
b7e02dd
1ccf8ee
b7e02dd
1ccf8ee
 
 
 
 
 
 
 
 
b7e02dd
1ccf8ee
 
 
 
 
 
49be262
56dc0d1
1ccf8ee
 
10e0076
1ccf8ee
 
56dc0d1
 
 
 
 
 
1ccf8ee
56dc0d1
 
49be262
1ccf8ee
b7e02dd
49be262
 
1ccf8ee
 
 
 
49be262
1ccf8ee
 
b7e02dd
 
 
 
 
49be262
1ccf8ee
 
 
 
49be262
1ccf8ee
49be262
b7e02dd
10e0076
b7e02dd
49be262
10e0076
56dc0d1
 
 
1ccf8ee
56dc0d1
1ccf8ee
 
56dc0d1
 
1ccf8ee
97814aa
 
4fb4636
b7e02dd
 
 
 
 
 
10e0076
 
 
c48497c
 
4fb4636
 
 
 
 
 
 
 
 
 
c48497c
4fb4636
 
 
c48497c
 
 
1ccf8ee
 
b7e02dd
 
 
 
1ccf8ee
 
 
 
b7e02dd
 
 
 
49be262
1ccf8ee
 
 
 
b7e02dd
1ccf8ee
 
10e0076
 
 
 
 
1ccf8ee
b7e02dd
1ccf8ee
 
 
b7e02dd
 
 
 
 
c48497c
 
1ccf8ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import mne
import numpy as np
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the instruction-tuned LLM once at import time; every analysis request reuses it.
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# float16 halves GPU memory, but many CPU kernels have no half-precision
# implementation (e.g. "addmm_impl_cpu_" not implemented for 'Half'), so use
# float32 when no CUDA device is available.
_model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # NOTE(review): trust_remote_code executes Python code shipped in the model
    # repository. Recent transformers releases support Falcon natively, so this
    # flag may be removable — confirm against the pinned transformers version.
    trust_remote_code=True,
    torch_dtype=_model_dtype,
    device_map="auto",
)

def compute_band_power(psd, freqs, fmin, fmax):
    """Return the mean spectral power inside the closed band [fmin, fmax].

    Parameters
    ----------
    psd : array-like, shape (..., n_freqs)
        Power spectral density whose LAST axis is aligned with ``freqs``.
        Typically (n_channels, n_freqs); any leading axes are averaged too.
    freqs : array-like, shape (n_freqs,)
        Frequency of each PSD bin.
    fmin, fmax : float
        Inclusive band edges, in the same units as ``freqs``.

    Returns
    -------
    float
        Mean of ``psd`` over every leading axis and every bin with
        ``fmin <= freq <= fmax``; NaN when no bin falls inside the band.
    """
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    if not freq_mask.any():
        # The mean of an empty slice is NaN anyway; return it explicitly and
        # avoid numpy's "Mean of empty slice" RuntimeWarning.
        return float("nan")
    # `...` indexes the frequency axis regardless of how many leading axes exist.
    return float(psd[..., freq_mask].mean())

def inspect_file(file):
    """
    Inspect the uploaded file to determine available columns.

    Parameters
    ----------
    file : object with a ``.name`` path, str, os.PathLike, or None
        Value from the Gradio ``gr.File`` component. Older Gradio versions
        pass a tempfile-like object exposing ``.name``; filepath mode passes
        the path itself.

    Returns
    -------
    (message, columns, preview) : tuple[str, list, str]
        If FIF: an informational message, no columns, and a fixed preview
        text (MNE files carry their own sampling frequency).
        If CSV: a message, every column name from the header (numeric and
        non-numeric), and a preview of the first rows.
        Otherwise, or on read failure: an error message, no columns, and a
        fallback preview text.
    """
    if file is None:
        return "No file uploaded.", [], "No preview available."
    if isinstance(file, (str, os.PathLike)):
        file_path = os.fspath(file)
    else:
        file_path = file.name
    file_ext = os.path.splitext(file_path)[1].lower()

    if file_ext == ".fif":
        return (
            "FIF file detected. No need for time column selection. The file's sampling frequency will be used.",
            [],
            "FIF file doesn't require column inspection."
        )
    if file_ext != ".csv":
        return "Unsupported file format.", [], "No preview available."

    try:
        df = pd.read_csv(file_path, nrows=5)
    except Exception as e:  # UI boundary: report any read/parse failure instead of crashing
        return f"Error reading CSV: {e}", [], "Could not read CSV preview."
    cols = list(df.columns)
    try:
        preview = df.head().to_markdown()
    except ImportError:
        # DataFrame.to_markdown depends on the optional `tabulate` package;
        # fall back to a fenced plain-text table when it is not installed.
        preview = f"```\n{df.head().to_string()}\n```"
    return (
        "CSV file detected. Select a time column if available, or choose (No time column) and specify a default frequency.",
        cols,
        preview
    )

def _sfreq_from_time(time_values, default_sfreq):
    """Estimate a sampling frequency as 1 / mean(step) from a timestamp column.

    Assumes timestamps are in seconds (TODO confirm for each data source).
    Entries that cannot be parsed as numbers (e.g. datetime strings) become
    NaN, and any step touching them is ignored. Returns ``default_sfreq`` when
    no valid step remains or the mean step is not strictly positive.
    """
    time = np.asarray(pd.to_numeric(time_values, errors="coerce"), dtype=float)
    steps = np.diff(time)
    steps = steps[np.isfinite(steps)]
    if steps.size == 0:
        return default_sfreq
    dt = steps.mean()
    if dt <= 0:
        return default_sfreq
    return 1.0 / dt


def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
    """
    Load EEG data from a FIF or CSV file into an MNE Raw object.

    If FIF: read with ``mne.io.read_raw_fif(..., preload=True)``.
    If CSV (one column per channel, one row per sample):
      - If ``time_col`` is truthy and present in the file, it is used to
        estimate the sampling frequency and is excluded from the channels.
      - Otherwise ``default_sfreq`` is used.
      Non-numeric columns are dropped; each remaining column becomes an
      'eeg' channel named after its header. Values are handed to MNE
      unscaled (MNE treats EEG data as volts).

    Parameters
    ----------
    file_path : str
        Path to a ``.fif`` or ``.csv`` file (extension matched case-insensitively).
    default_sfreq : float
        Sampling frequency in Hz used when no usable time column exists.
    time_col : str or None
        Name of the CSV timestamp column, or None/"" for none.

    Returns
    -------
    mne.io.Raw
        The loaded recording (``RawArray`` for CSV input).

    Raises
    ------
    ValueError
        For unsupported extensions, or a CSV with no numeric channel columns.
    """
    file_ext = os.path.splitext(file_path)[1].lower()

    if file_ext == '.fif':
        return mne.io.read_raw_fif(file_path, preload=True)
    if file_ext != '.csv':
        raise ValueError("Unsupported file format. Provide a FIF or CSV file.")

    df = pd.read_csv(file_path)

    if time_col and time_col in df.columns:
        sfreq = _sfreq_from_time(df[time_col], default_sfreq)
        df = df.drop(columns=[time_col])
    else:
        # No time column used, assume default_sfreq
        sfreq = default_sfreq

    # Last-resort guard: mne.create_info rejects non-positive/non-finite rates.
    if not np.isfinite(sfreq) or sfreq <= 0:
        sfreq = 256.0

    ch_names = [col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])]
    if not ch_names:
        raise ValueError("The CSV file contains no numeric channel columns.")

    data = df[ch_names].to_numpy(dtype=float).T  # (n_channels, n_samples)
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * len(ch_names))
    return mne.io.RawArray(data, info)

def analyze_eeg(file, default_sfreq, time_col):
    """Compute alpha/beta band power for an uploaded EEG file and have the LLM explain it.

    Parameters
    ----------
    file : object with a ``.name`` path, str, os.PathLike, or None
        Value from the Gradio ``gr.File`` component.
    default_sfreq : str or float
        Sampling frequency (Hz) used when no usable time column exists.
        Unparseable, non-finite or non-positive values fall back to 256 Hz.
    time_col : str or None
        CSV timestamp column; the sentinel "(No time column)" means none.

    Returns
    -------
    str
        The model's interpretation (generated text only, prompt excluded),
        or a short message when no file was uploaded or loading failed.
    """
    if file is None:
        return "No file uploaded."
    if time_col == "(No time column)":
        time_col = None

    try:
        fs = float(default_sfreq)
    except (TypeError, ValueError):
        fs = 256.0
    if not np.isfinite(fs) or fs <= 0:
        fs = 256.0

    file_path = file if isinstance(file, (str, os.PathLike)) else file.name
    try:
        raw = load_eeg_data(os.fspath(file_path), default_sfreq=fs, time_col=time_col)
    except (ValueError, OSError) as err:
        return f"Could not load EEG data: {err}"

    # Welch PSD raises if fmax exceeds Nyquist or n_fft exceeds the number of
    # samples; clamp both to what this recording supports (256 is MNE's
    # default n_fft for Welch).
    fmax = min(40.0, raw.info["sfreq"] / 2.0)
    n_fft = min(256, raw.n_times)
    psd_obj = raw.compute_psd(fmin=1, fmax=fmax, method='welch', n_fft=n_fft)
    psd, freqs = psd_obj.get_data(return_freqs=True)
    alpha_power = compute_band_power(psd, freqs, 8, 12)
    beta_power = compute_band_power(psd, freqs, 13, 30)

    # Describe only measured quantities so the LLM is not fed an
    # interpretation the data may contradict.
    if np.isfinite(alpha_power) and np.isfinite(beta_power) and beta_power > 0:
        ratio_text = f"Alpha/beta power ratio: {alpha_power / beta_power:.2f}."
    else:
        ratio_text = "Alpha/beta power ratio: not available."
    # `.3g` keeps tiny SI-unit powers (V^2/Hz) from being printed as 0.000.
    data_summary = (
        f"Alpha power: {alpha_power:.3g}, Beta power: {beta_power:.3g}. "
        f"{ratio_text}"
    )

    prompt = f"""You are a neuroscientist analyzing EEG features.
Data Summary: {data_summary}
Provide a concise, user-friendly interpretation of these findings in simple terms.
"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids,
        # Single unpadded sequence: every position is attended to.
        attention_mask=input_ids.new_ones(input_ids.shape),
        # max_new_tokens bounds the reply itself; max_length would also count
        # the prompt and leave little or no room for an answer.
        max_new_tokens=200,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the continuation.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def preview_file(file):
    """Inspect an upload and refresh the time-column dropdown.

    Returns the status message from ``inspect_file``, a ``gr.update`` whose
    choices are the file's columns preceded by the "(No time column)"
    sentinel (which is also preselected), and the preview text.
    """
    no_time = "(No time column)"
    status, columns, preview_text = inspect_file(file)
    dropdown = gr.update(choices=[no_time, *columns], value=no_time)
    return status, dropdown, preview_text

# Gradio UI. Inside a Blocks context, components are placed on the page in the
# order they are created, so the statement order below defines the layout.
with gr.Blocks() as demo:
    gr.Markdown("# NeuroNarrative-Lite: EEG Summary with Flexible Preprocessing")
    gr.Markdown(
        "Upload an EEG file (FIF or CSV). If it's CSV, click 'Inspect File' to preview columns. "
        "Select a time column if available or '(No time column)' if not. "
        "If no time column is chosen, provide a default sampling frequency. "
        "Then click 'Run Analysis'."
    )

    # Step 1: upload, then optionally inspect the file's columns.
    file_input = gr.File(label="Upload your EEG data (FIF or CSV)")
    preview_button = gr.Button("Inspect File")
    msg_output = gr.Markdown()
    # allow_custom_value=True lets the user type a column name that is not among the choices.
    cols_dropdown = gr.Dropdown(label="Select Time Column (optional)", allow_custom_value=True, interactive=True)
    preview_output = gr.Markdown()

    # "Inspect File" -> preview_file fills the status text, the dropdown choices, and the preview.
    preview_button.click(
        preview_file, 
        inputs=[file_input], 
        outputs=[msg_output, cols_dropdown, preview_output]
    )

    # Step 2: fallback sampling rate and analysis. The Textbox yields a string,
    # which analyze_eeg converts with float().
    default_sfreq_input = gr.Textbox(label="Default Sampling Frequency (Hz) if no time column", value="100")
    analyze_button = gr.Button("Run Analysis")
    result_output = gr.Textbox(label="Analysis Summary")

    # "Run Analysis" -> analyze_eeg(file, sampling-rate text, selected column) -> summary textbox.
    analyze_button.click(
        analyze_eeg, 
        inputs=[file_input, default_sfreq_input, cols_dropdown],
        outputs=[result_output]
    )

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()