Rogerjs committed on
Commit
49be262
·
verified ·
1 Parent(s): 7fd8d2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -39
app.py CHANGED
@@ -6,28 +6,27 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
6
  import torch
7
  import os
8
 
9
- # Load an open-source LLM model with no additional training
10
  model_name = "tiiuae/falcon-7b-instruct"
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  model_name,
14
  trust_remote_code=True,
15
  torch_dtype=torch.float16,
16
- device_map="auto" # Automatically selects CPU/GPU if available
17
  )
18
 
19
  def compute_band_power(psd, freqs, fmin, fmax):
20
- """Compute mean band power in the given frequency range."""
21
  freq_mask = (freqs >= fmin) & (freqs <= fmax)
22
- # Take the mean across channels and frequencies
23
  band_psd = psd[:, freq_mask].mean()
24
  return float(band_psd)
25
 
26
- def load_eeg_data(file_path):
27
  """
28
- Load EEG data from a file.
29
- If FIF file is detected, use MNE's read_raw_fif.
30
- If CSV file is detected, load via pandas and create a RawArray.
 
 
31
  """
32
  _, file_ext = os.path.splitext(file_path)
33
  file_ext = file_ext.lower()
@@ -35,24 +34,33 @@ def load_eeg_data(file_path):
35
  if file_ext == '.fif':
36
  raw = mne.io.read_raw_fif(file_path, preload=True)
37
  elif file_ext == '.csv':
38
- # Assume first column is 'time', and subsequent columns are channels
39
  df = pd.read_csv(file_path)
40
- if 'time' not in df.columns:
41
- raise ValueError("CSV must contain a 'time' column for timestamps.")
42
-
43
- time = df['time'].values
44
- data = df.drop(columns=['time']).values.T # shape: (n_channels, n_samples)
45
-
46
- # Estimate sampling frequency from time vector (assuming uniform)
47
- # This is a simplistic approach: we take 1 / average time step.
48
- # Make sure time is in seconds
49
- if len(time) < 2:
50
- raise ValueError("Not enough time points in CSV.")
51
- sfreq = 1.0 / np.mean(np.diff(time))
52
-
53
- # Create MNE Info
54
- ch_names = list(df.columns)
55
- ch_names.remove('time')
 
 
 
 
 
 
 
 
 
 
56
  ch_types = ['eeg'] * len(ch_names)
57
  info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
58
 
@@ -62,31 +70,24 @@ def load_eeg_data(file_path):
62
 
63
  return raw
64
 
65
- def process_eeg(file):
66
- # Load EEG data
67
- raw = load_eeg_data(file.name)
68
 
69
- # Compute PSD (Power Spectral Density) between 1 and 40 Hz
70
  psd, freqs = mne.time_frequency.psd_welch(raw, fmin=1, fmax=40)
71
-
72
- # Compute simple band powers
73
  alpha_power = compute_band_power(psd, freqs, 8, 12)
74
  beta_power = compute_band_power(psd, freqs, 13, 30)
75
 
76
- # Create a short summary of the extracted features
77
  data_summary = (
78
  f"Alpha power: {alpha_power:.3f}, Beta power: {beta_power:.3f}. "
79
  f"The EEG shows stable alpha rhythms and slightly elevated beta activity."
80
  )
81
 
82
- # Prepare the prompt for the language model
83
  prompt = f"""You are a neuroscientist analyzing EEG features.
84
  Data Summary: {data_summary}
85
 
86
  Provide a concise, user-friendly interpretation of these findings in simple terms.
87
  """
88
 
89
- # Generate the summary using the LLM
90
  inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
91
  outputs = model.generate(
92
  inputs, max_length=200, do_sample=True, top_k=50, top_p=0.95
@@ -97,12 +98,18 @@ Provide a concise, user-friendly interpretation of these findings in simple term
97
 
98
  iface = gr.Interface(
99
  fn=process_eeg,
100
- inputs=gr.File(label="Upload your EEG data (FIF or CSV)"),
 
 
 
 
101
  outputs="text",
102
- title="NeuroNarrative-Lite: EEG Summary",
103
- description=("Upload EEG data in FIF (MNE native) or CSV format. "
104
- "The system extracts basic EEG features and generates "
105
- "a human-readable summary using an open-source language model.")
 
 
106
  )
107
 
108
  if __name__ == "__main__":
 
6
  import torch
7
  import os
8
 
 
9
# Checkpoint identifier handed to both from_pretrained() calls below.
model_name = "tiiuae/falcon-7b-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE: trust_remote_code=True allows Python code shipped with the checkpoint
# repository to run locally -- keep it only for repositories you trust.
# device_map="auto" delegates weight placement to the library, and
# float16 weights halve memory use relative to float32.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
17
 
18
def compute_band_power(psd, freqs, fmin, fmax):
    """Return the mean PSD value inside the closed band [fmin, fmax].

    Args:
        psd: Power spectral density values. The last axis must line up with
            ``freqs``; a 2-D array with frequencies on axis 1 gives the same
            result as before.
        freqs: Frequency of each PSD bin, in the same units as the bounds.
        fmin: Lower edge of the band (inclusive).
        fmax: Upper edge of the band (inclusive).

    Returns:
        The mean over every selected bin and every leading axis as a Python
        float, or ``nan`` when no entry of ``freqs`` lies inside the band.
    """
    psd = np.asarray(psd)
    freqs = np.asarray(freqs)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)

    if not freq_mask.any():
        # Averaging an empty selection also yields nan, but only after NumPy
        # emits a "Mean of empty slice" RuntimeWarning; return it explicitly.
        return float("nan")

    # Ellipsis indexing selects along the last axis, so 1-D and 3-D PSD
    # arrays are accepted in addition to the 2-D case.
    return float(psd[..., freq_mask].mean())
22
 
23
+ def load_eeg_data(file_path, default_sfreq=256.0, time_col='time'):
24
  """
25
+ Load EEG data from a file with flexible CSV handling.
26
+ - If FIF: Use read_raw_fif.
27
+ - If CSV:
28
+ * If `time_col` is present, use it as time.
29
+ * Otherwise, assume a default sfreq and treat all columns as channels.
30
  """
31
  _, file_ext = os.path.splitext(file_path)
32
  file_ext = file_ext.lower()
 
34
  if file_ext == '.fif':
35
  raw = mne.io.read_raw_fif(file_path, preload=True)
36
  elif file_ext == '.csv':
 
37
  df = pd.read_csv(file_path)
38
+
39
+ # Remove non-numeric columns except time_col
40
+ for col in df.columns:
41
+ if col != time_col:
42
+ # Drop non-numeric columns if any
43
+ if not pd.api.types.is_numeric_dtype(df[col]):
44
+ df = df.drop(columns=[col])
45
+
46
+ if time_col in df.columns:
47
+ # Use the provided time column
48
+ time = df[time_col].values
49
+ data_df = df.drop(columns=[time_col])
50
+
51
+ if len(time) < 2:
52
+ raise ValueError("Not enough time points to estimate sampling frequency.")
53
+ sfreq = 1.0 / np.mean(np.diff(time))
54
+ else:
55
+ # No explicit time column, assume uniform sampling at default_sfreq
56
+ sfreq = default_sfreq
57
+ data_df = df
58
+
59
+ # Channels are all remaining columns
60
+ ch_names = list(data_df.columns)
61
+ data = data_df.values.T # shape: (n_channels, n_samples)
62
+
63
+ # Create MNE info
64
  ch_types = ['eeg'] * len(ch_names)
65
  info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
66
 
 
70
 
71
  return raw
72
 
73
+ def process_eeg(file, default_sfreq, time_col):
74
+ raw = load_eeg_data(file.name, default_sfreq=float(default_sfreq), time_col=time_col)
 
75
 
 
76
  psd, freqs = mne.time_frequency.psd_welch(raw, fmin=1, fmax=40)
 
 
77
  alpha_power = compute_band_power(psd, freqs, 8, 12)
78
  beta_power = compute_band_power(psd, freqs, 13, 30)
79
 
 
80
  data_summary = (
81
  f"Alpha power: {alpha_power:.3f}, Beta power: {beta_power:.3f}. "
82
  f"The EEG shows stable alpha rhythms and slightly elevated beta activity."
83
  )
84
 
 
85
  prompt = f"""You are a neuroscientist analyzing EEG features.
86
  Data Summary: {data_summary}
87
 
88
  Provide a concise, user-friendly interpretation of these findings in simple terms.
89
  """
90
 
 
91
  inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
92
  outputs = model.generate(
93
  inputs, max_length=200, do_sample=True, top_k=50, top_p=0.95
 
98
 
99
# Gradio front end. The three widgets are passed to process_eeg positionally,
# in this order: uploaded file, fallback sampling frequency, time column name.
# Both text boxes deliver strings; process_eeg converts the frequency itself.
iface = gr.Interface(
    fn=process_eeg,
    inputs=[
        gr.File(label="Upload your EEG data (FIF or CSV)"),
        gr.Textbox(
            label="Default Sampling Frequency if no time column (Hz)",
            value="256",
        ),
        gr.Textbox(
            label="Time column name (if exists)",
            value="time",
        ),
    ],
    outputs="text",
    title="NeuroNarrative-Lite: EEG Summary (Flexible CSV Handling)",
    description=(
        "Upload EEG data in FIF or CSV format. "
        "If CSV, either include a 'time' column or specify a default sampling frequency. "
        "Non-numeric columns will be removed (except the chosen time column)."
    ),
)
114
 
115
  if __name__ == "__main__":