Rogerjs commited on
Commit
56dc0d1
·
verified ·
1 Parent(s): 9f5f6bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -6
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import gradio as gr
2
  import mne
3
  import numpy as np
 
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import torch
 
6
 
7
  # Load an open-source LLM model with no additional training
8
  model_name = "tiiuae/falcon-7b-instruct"
@@ -18,13 +20,51 @@ def compute_band_power(psd, freqs, fmin, fmax):
18
  """Compute mean band power in the given frequency range."""
19
  freq_mask = (freqs >= fmin) & (freqs <= fmax)
20
  # Take the mean across channels and frequencies
21
- band_psd = psd[:, freq_mask].mean()
22
  return float(band_psd)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def process_eeg(file):
25
- # Load EEG data using MNE
26
- # This expects a .fif file containing raw EEG data
27
- raw = mne.io.read_raw_fif(file.name, preload=True)
28
 
29
  # Compute PSD (Power Spectral Density) between 1 and 40 Hz
30
  psd, freqs = mne.time_frequency.psd_welch(raw, fmin=1, fmax=40)
@@ -57,10 +97,12 @@ Provide a concise, user-friendly interpretation of these findings in simple term
57
 
58
  iface = gr.Interface(
59
  fn=process_eeg,
60
- inputs=gr.File(label="Upload your EEG data (FIF format)"),
61
  outputs="text",
62
  title="NeuroNarrative-Lite: EEG Summary",
63
- description="Upload EEG data to receive a text-based summary from an open-source language model. No training required!"
 
 
64
  )
65
 
66
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import mne
3
  import numpy as np
4
+ import pandas as pd
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
  import torch
7
+ import os
8
 
9
  # Load an open-source LLM model with no additional training
10
  model_name = "tiiuae/falcon-7b-instruct"
 
20
  """Compute mean band power in the given frequency range."""
21
  freq_mask = (freqs >= fmin) & (freqs <= fmax)
22
  # Take the mean across channels and frequencies
23
+ band_psd = psd[:, freq_mask].mean()
24
  return float(band_psd)
25
 
26
+ def load_eeg_data(file_path):
27
+ """
28
+ Load EEG data from a file.
29
+ If FIF file is detected, use MNE's read_raw_fif.
30
+ If CSV file is detected, load via pandas and create a RawArray.
31
+ """
32
+ _, file_ext = os.path.splitext(file_path)
33
+ file_ext = file_ext.lower()
34
+
35
+ if file_ext == '.fif':
36
+ raw = mne.io.read_raw_fif(file_path, preload=True)
37
+ elif file_ext == '.csv':
38
+ # Assume first column is 'time', and subsequent columns are channels
39
+ df = pd.read_csv(file_path)
40
+ if 'time' not in df.columns:
41
+ raise ValueError("CSV must contain a 'time' column for timestamps.")
42
+
43
+ time = df['time'].values
44
+ data = df.drop(columns=['time']).values.T # shape: (n_channels, n_samples)
45
+
46
+ # Estimate sampling frequency from time vector (assuming uniform)
47
+ # This is a simplistic approach: we take 1 / average time step.
48
+ # Make sure time is in seconds
49
+ if len(time) < 2:
50
+ raise ValueError("Not enough time points in CSV.")
51
+ sfreq = 1.0 / np.mean(np.diff(time))
52
+
53
+ # Create MNE Info
54
+ ch_names = list(df.columns)
55
+ ch_names.remove('time')
56
+ ch_types = ['eeg'] * len(ch_names)
57
+ info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
58
+
59
+ raw = mne.io.RawArray(data, info)
60
+ else:
61
+ raise ValueError("Unsupported file format. Please provide a FIF or CSV file.")
62
+
63
+ return raw
64
+
65
  def process_eeg(file):
66
+ # Load EEG data
67
+ raw = load_eeg_data(file.name)
 
68
 
69
  # Compute PSD (Power Spectral Density) between 1 and 40 Hz
70
  psd, freqs = mne.time_frequency.psd_welch(raw, fmin=1, fmax=40)
 
97
 
98
  iface = gr.Interface(
99
  fn=process_eeg,
100
+ inputs=gr.File(label="Upload your EEG data (FIF or CSV)"),
101
  outputs="text",
102
  title="NeuroNarrative-Lite: EEG Summary",
103
+ description=("Upload EEG data in FIF (MNE native) or CSV format. "
104
+ "The system extracts basic EEG features and generates "
105
+ "a human-readable summary using an open-source language model.")
106
  )
107
 
108
  if __name__ == "__main__":