Rogerjs commited on
Commit
10e0076
·
verified ·
1 Parent(s): b7e02dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -21
app.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  import pandas as pd
5
  import gradio as gr
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
- from mne.time_frequency import psd_welch
8
  import torch
9
 
10
  # Load LLM
@@ -18,6 +17,8 @@ model = AutoModelForCausalLM.from_pretrained(
18
  )
19
 
20
  def compute_band_power(psd, freqs, fmin, fmax):
 
 
21
  freq_mask = (freqs >= fmin) & (freqs <= fmax)
22
  band_psd = psd[:, freq_mask].mean()
23
  return float(band_psd)
@@ -35,19 +36,16 @@ def inspect_file(file):
35
  file_ext = file_ext.lower()
36
 
37
  if file_ext == ".fif":
38
- # FIF files: MNE compatible, no columns needed
39
  return (
40
  "FIF file detected. No need for time column selection. The file's sampling frequency will be used.",
41
  [],
42
  "FIF file doesn't require column inspection."
43
  )
44
  elif file_ext == ".csv":
45
- # Read a small portion of the CSV to determine columns
46
  try:
47
  df = pd.read_csv(file_path, nrows=5)
48
  except Exception as e:
49
  return f"Error reading CSV: {e}", [], "Could not read CSV preview."
50
-
51
  cols = list(df.columns)
52
  preview = df.head().to_markdown()
53
  return (
@@ -62,7 +60,7 @@ def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
62
  """
63
  Load EEG data with flexibility.
64
  If FIF: Use MNE's read_raw_fif directly.
65
- If CSV:
66
  - If time_col is given and present in the file, use it.
67
  - Otherwise, assume default_sfreq.
68
  """
@@ -79,19 +77,14 @@ def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
79
  # Use the selected time column to compute sfreq
80
  time = df[time_col].values
81
  data_df = df.drop(columns=[time_col])
82
-
83
- # Drop non-numeric columns
84
  for col in data_df.columns:
85
  if not pd.api.types.is_numeric_dtype(data_df[col]):
86
  data_df = data_df.drop(columns=[col])
87
 
88
  if len(time) < 2:
89
- # Not enough time points to compute sfreq, fallback
90
  sfreq = default_sfreq
91
  else:
92
- # Compute sfreq from time
93
  dt = np.mean(np.diff(time))
94
- # Ensure dt is positive
95
  if dt <= 0:
96
  sfreq = default_sfreq
97
  else:
@@ -104,13 +97,11 @@ def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
104
  data_df = df
105
  sfreq = default_sfreq
106
 
107
- # Ensure sfreq is positive
108
  if sfreq <= 0:
109
- sfreq = 256.0 # fallback if something odd happens
110
 
111
  ch_names = list(data_df.columns)
112
- data = data_df.values.T # shape: (n_channels, n_samples)
113
-
114
  ch_types = ['eeg'] * len(ch_names)
115
  info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
116
  raw = mne.io.RawArray(data, info)
@@ -130,8 +121,9 @@ def analyze_eeg(file, default_sfreq, time_col):
130
 
131
  raw = load_eeg_data(file.name, default_sfreq=fs, time_col=time_col)
132
 
133
- # Use the directly imported psd_welch function
134
- psd, freqs = psd_welch(raw, fmin=1, fmax=40)
 
135
  alpha_power = compute_band_power(psd, freqs, 8, 12)
136
  beta_power = compute_band_power(psd, freqs, 13, 30)
137
 
@@ -154,11 +146,8 @@ Provide a concise, user-friendly interpretation of these findings in simple term
154
  def preview_file(file):
155
  msg, cols, preview = inspect_file(file)
156
  # Always include (No time column) as the first choice
157
- # If no columns were found, we still have (No time column) as an option
158
  cols = ["(No time column)"] + cols
159
  default_value = "(No time column)"
160
-
161
- # Return an update dict for the dropdown
162
  return msg, gr.update(choices=cols, value=default_value), preview
163
 
164
  with gr.Blocks() as demo:
@@ -173,11 +162,14 @@ with gr.Blocks() as demo:
173
  file_input = gr.File(label="Upload your EEG data (FIF or CSV)")
174
  preview_button = gr.Button("Inspect File")
175
  msg_output = gr.Markdown()
176
- # Allow custom values in case something goes off
177
  cols_dropdown = gr.Dropdown(label="Select Time Column (optional)", allow_custom_value=True, interactive=True)
178
  preview_output = gr.Markdown()
179
 
180
- preview_button.click(preview_file, inputs=[file_input], outputs=[msg_output, cols_dropdown, preview_output])
 
 
 
 
181
 
182
  default_sfreq_input = gr.Textbox(label="Default Sampling Frequency (Hz) if no time column", value="100")
183
  analyze_button = gr.Button("Run Analysis")
 
4
  import pandas as pd
5
  import gradio as gr
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
7
  import torch
8
 
9
  # Load LLM
 
17
  )
18
 
19
  def compute_band_power(psd, freqs, fmin, fmax):
20
+ # psd shape: (n_channels, n_freqs)
21
+ # freqs shape: (n_freqs,)
22
  freq_mask = (freqs >= fmin) & (freqs <= fmax)
23
  band_psd = psd[:, freq_mask].mean()
24
  return float(band_psd)
 
36
  file_ext = file_ext.lower()
37
 
38
  if file_ext == ".fif":
 
39
  return (
40
  "FIF file detected. No need for time column selection. The file's sampling frequency will be used.",
41
  [],
42
  "FIF file doesn't require column inspection."
43
  )
44
  elif file_ext == ".csv":
 
45
  try:
46
  df = pd.read_csv(file_path, nrows=5)
47
  except Exception as e:
48
  return f"Error reading CSV: {e}", [], "Could not read CSV preview."
 
49
  cols = list(df.columns)
50
  preview = df.head().to_markdown()
51
  return (
 
60
  """
61
  Load EEG data with flexibility.
62
  If FIF: Use MNE's read_raw_fif directly.
63
+ If CSV:
64
  - If time_col is given and present in the file, use it.
65
  - Otherwise, assume default_sfreq.
66
  """
 
77
  # Use the selected time column to compute sfreq
78
  time = df[time_col].values
79
  data_df = df.drop(columns=[time_col])
 
 
80
  for col in data_df.columns:
81
  if not pd.api.types.is_numeric_dtype(data_df[col]):
82
  data_df = data_df.drop(columns=[col])
83
 
84
  if len(time) < 2:
 
85
  sfreq = default_sfreq
86
  else:
 
87
  dt = np.mean(np.diff(time))
 
88
  if dt <= 0:
89
  sfreq = default_sfreq
90
  else:
 
97
  data_df = df
98
  sfreq = default_sfreq
99
 
 
100
  if sfreq <= 0:
101
+ sfreq = 256.0
102
 
103
  ch_names = list(data_df.columns)
104
+ data = data_df.values.T # (n_channels, n_samples)
 
105
  ch_types = ['eeg'] * len(ch_names)
106
  info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
107
  raw = mne.io.RawArray(data, info)
 
121
 
122
  raw = load_eeg_data(file.name, default_sfreq=fs, time_col=time_col)
123
 
124
+ # Use raw.compute_psd instead of psd_welch
125
+ psd_obj = raw.compute_psd(fmin=1, fmax=40, method='welch')
126
+ psd, freqs = psd_obj.get_data(return_freqs=True)
127
  alpha_power = compute_band_power(psd, freqs, 8, 12)
128
  beta_power = compute_band_power(psd, freqs, 13, 30)
129
 
 
146
  def preview_file(file):
147
  msg, cols, preview = inspect_file(file)
148
  # Always include (No time column) as the first choice
 
149
  cols = ["(No time column)"] + cols
150
  default_value = "(No time column)"
 
 
151
  return msg, gr.update(choices=cols, value=default_value), preview
152
 
153
  with gr.Blocks() as demo:
 
162
  file_input = gr.File(label="Upload your EEG data (FIF or CSV)")
163
  preview_button = gr.Button("Inspect File")
164
  msg_output = gr.Markdown()
 
165
  cols_dropdown = gr.Dropdown(label="Select Time Column (optional)", allow_custom_value=True, interactive=True)
166
  preview_output = gr.Markdown()
167
 
168
+ preview_button.click(
169
+ preview_file,
170
+ inputs=[file_input],
171
+ outputs=[msg_output, cols_dropdown, preview_output]
172
+ )
173
 
174
  default_sfreq_input = gr.Textbox(label="Default Sampling Frequency (Hz) if no time column", value="100")
175
  analyze_button = gr.Button("Run Analysis")