ckadirt committed
Commit 65ffd92 · verified · 1 Parent(s): 6edc6bc

Upload folder using huggingface_hub
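The commit message indicates the folder was pushed with huggingface_hub. As a hedged sketch only (the repo id, local path, and any token handling used for this commit are not shown in the diff; the values below are placeholders), an upload like this is typically done with HfApi.upload_folder:

from huggingface_hub import HfApi

# Push a local working folder to a Hub repo as a single commit.
# folder_path and repo_id are placeholders, not values taken from this commit.
api = HfApi()
api.upload_folder(
    folder_path=".",
    repo_id="<user>/<repo>",
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)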

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +11 -0
  2. cache/models--facebook--musicgen-small/blobs/1bdc99d43eb6c775967df24b65b0a9f847c0907e95664698d93b5a1c35f5090d +3 -0
  3. cache/models--facebook--musicgen-small/blobs/45e996eaadd56e1cdaab46cb5e97d295541d40a3 +10 -0
  4. cache/models--facebook--musicgen-small/blobs/9664ce3a7ca28d1084f10413971b54481198589a +298 -0
  5. cache/models--facebook--musicgen-small/refs/main +1 -0
  6. cache/models--facebook--musicgen-small/snapshots/51027f0bee8489c1750a7b8a4806894ab2e7dc4d/config.json +298 -0
  7. cache/models--facebook--musicgen-small/snapshots/51027f0bee8489c1750a7b8a4806894ab2e7dc4d/generation_config.json +10 -0
  8. cache/models--facebook--musicgen-small/snapshots/51027f0bee8489c1750a7b8a4806894ab2e7dc4d/model.safetensors +3 -0
  9. data/encodec32khz_testing_embeds_sorted.npy +3 -0
  10. data/encodec32khz_training_embeds_sorted.npy +3 -0
  11. data/encodec_test_embeds.npy +3 -0
  12. data/encodec_testing_embeds_sorted.npy +3 -0
  13. data/encodec_training_embeds.npy +3 -0
  14. data/encodec_training_embeds_sorted.npy +3 -0
  15. data/sub-001_Resp_Test.npy +3 -0
  16. data/sub-001_Resp_Test_Mean.npy +3 -0
  17. data/sub-001_Resp_Training.npy +3 -0
  18. src/.ipynb_checkpoints/Copy_of_MusicGen-checkpoint.ipynb +0 -0
  19. src/.ipynb_checkpoints/MLP-model copy-checkpoint.ipynb +582 -0
  20. src/.ipynb_checkpoints/MLP-model-checkpoint.ipynb +449 -0
  21. src/.ipynb_checkpoints/MLPencoder-checkpoint.ipynb +0 -0
  22. src/.ipynb_checkpoints/mlpdummy-checkpoint.py +146 -0
  23. src/.ipynb_checkpoints/musicgen_test copy-checkpoint.ipynb +0 -0
  24. src/Copy_of_MusicGen.ipynb +0 -0
  25. src/MLP-model copy.ipynb +581 -0
  26. src/MLP-model.ipynb +448 -0
  27. src/MLPencoder.ipynb +0 -0
  28. src/b2m-ckpt1 +0 -0
  29. src/b2m-ckpt1.pt +3 -0
  30. src/mlpdummy.py +146 -0
  31. src/musicgen_test copy.ipynb +0 -0
  32. src/musicgen_test.ipynb +0 -0
  33. src/outputs_train0.pt +3 -0
  34. src/outputs_train1.pt +3 -0
  35. src/outputs_train10.pt +3 -0
  36. src/outputs_train11.pt +3 -0
  37. src/outputs_train12.pt +3 -0
  38. src/outputs_train13.pt +3 -0
  39. src/outputs_train14.pt +3 -0
  40. src/outputs_train15.pt +3 -0
  41. src/outputs_train16.pt +3 -0
  42. src/outputs_train17.pt +3 -0
  43. src/outputs_train18.pt +3 -0
  44. src/outputs_train19.pt +3 -0
  45. src/outputs_train2.pt +3 -0
  46. src/outputs_train20.pt +3 -0
  47. src/outputs_train21.pt +3 -0
  48. src/outputs_train22.pt +3 -0
  49. src/outputs_train23.pt +3 -0
  50. src/outputs_train24.pt +3 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ cache/models--facebook--musicgen-small/blobs/1bdc99d43eb6c775967df24b65b0a9f847c0907e95664698d93b5a1c35f5090d filter=lfs diff=lfs merge=lfs -text
+ src/playground.ipynb filter=lfs diff=lfs merge=lfs -text
+ src/wandb/latest-run/run-jggbeix7.wandb filter=lfs diff=lfs merge=lfs -text
+ src/wandb/run-20230831_163543-sqaecwr8/run-sqaecwr8.wandb filter=lfs diff=lfs merge=lfs -text
+ src/wandb/run-20230831_165224-f6iksuh9/run-f6iksuh9.wandb filter=lfs diff=lfs merge=lfs -text
+ src/wandb/run-20230831_200743-4bm6v8ps/run-4bm6v8ps.wandb filter=lfs diff=lfs merge=lfs -text
+ src/wandb/run-20230831_203013-sdb63g4i/run-sdb63g4i.wandb filter=lfs diff=lfs merge=lfs -text
+ src/wandb/run-20230831_220330-mf53e4vk/run-mf53e4vk.wandb filter=lfs diff=lfs merge=lfs -text
+ src/wandb/run-20230901_003519-kxdc1ebl/run-kxdc1ebl.wandb filter=lfs diff=lfs merge=lfs -text
+ src/wandb/run-20230901_011334-pej3kex1/run-pej3kex1.wandb filter=lfs diff=lfs merge=lfs -text
+ src/wandb/run-20230908_042527-jggbeix7/run-jggbeix7.wandb filter=lfs diff=lfs merge=lfs -text
cache/models--facebook--musicgen-small/blobs/1bdc99d43eb6c775967df24b65b0a9f847c0907e95664698d93b5a1c35f5090d ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1bdc99d43eb6c775967df24b65b0a9f847c0907e95664698d93b5a1c35f5090d
+ size 2364427288
cache/models--facebook--musicgen-small/blobs/45e996eaadd56e1cdaab46cb5e97d295541d40a3 ADDED
@@ -0,0 +1,10 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 2048,
+ "decoder_start_token_id": 2048,
+ "do_sample": true,
+ "guidance_scale": 3.0,
+ "max_length": 1500,
+ "pad_token_id": 2048,
+ "transformers_version": "4.31.0.dev0"
+ }
cache/models--facebook--musicgen-small/blobs/9664ce3a7ca28d1084f10413971b54481198589a ADDED
@@ -0,0 +1,298 @@
+ {
+ "_commit_hash": null,
+ "architectures": [
+ "MusicgenForConditionalGeneration"
+ ],
+ "audio_encoder": {
+ "_name_or_path": "facebook/encodec_32khz",
+ "add_cross_attention": false,
+ "architectures": [
+ "EncodecModel"
+ ],
+ "audio_channels": 1,
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": null,
+ "chunk_length_s": null,
+ "chunk_size_feed_forward": 0,
+ "codebook_dim": 128,
+ "codebook_size": 2048,
+ "compress": 2,
+ "cross_attention_hidden_size": null,
+ "decoder_start_token_id": null,
+ "dilation_growth_rate": 2,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": null,
+ "exponential_decay_length_penalty": null,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "hidden_size": 128,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "is_decoder": false,
+ "is_encoder_decoder": false,
+ "kernel_size": 7,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "last_kernel_size": 7,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "min_length": 0,
+ "model_type": "encodec",
+ "no_repeat_ngram_size": 0,
+ "norm_type": "weight_norm",
+ "normalize": false,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_filters": 64,
+ "num_lstm_layers": 2,
+ "num_residual_layers": 1,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "overlap": null,
+ "pad_mode": "reflect",
+ "pad_token_id": null,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "residual_kernel_size": 3,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sampling_rate": 32000,
+ "sep_token_id": null,
+ "suppress_tokens": null,
+ "target_bandwidths": [
+ 2.2
+ ],
+ "task_specific_params": null,
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": "float32",
+ "torchscript": false,
+ "transformers_version": "4.31.0.dev0",
+ "trim_right_ratio": 1.0,
+ "typical_p": 1.0,
+ "upsampling_ratios": [
+ 8,
+ 5,
+ 4,
+ 4
+ ],
+ "use_bfloat16": false,
+ "use_causal_conv": false,
+ "use_conv_shortcut": false
+ },
+ "decoder": {
+ "_name_or_path": "",
+ "activation_dropout": 0.0,
+ "activation_function": "gelu",
+ "add_cross_attention": false,
+ "architectures": null,
+ "attention_dropout": 0.0,
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": 2048,
+ "chunk_size_feed_forward": 0,
+ "classifier_dropout": 0.0,
+ "cross_attention_hidden_size": null,
+ "decoder_start_token_id": null,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "dropout": 0.1,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": null,
+ "exponential_decay_length_penalty": null,
+ "ffn_dim": 4096,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "hidden_size": 1024,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "initializer_factor": 0.02,
+ "is_decoder": false,
+ "is_encoder_decoder": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "layerdrop": 0.0,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "max_position_embeddings": 2048,
+ "min_length": 0,
+ "model_type": "musicgen_decoder",
+ "no_repeat_ngram_size": 0,
+ "num_attention_heads": 16,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_codebooks": 4,
+ "num_hidden_layers": 24,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "pad_token_id": 2048,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "scale_embedding": false,
+ "sep_token_id": null,
+ "suppress_tokens": null,
+ "task_specific_params": null,
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": false,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": null,
+ "torchscript": false,
+ "transformers_version": "4.31.0.dev0",
+ "typical_p": 1.0,
+ "use_bfloat16": false,
+ "use_cache": true,
+ "vocab_size": 2048
+ },
+ "is_encoder_decoder": true,
+ "model_type": "musicgen",
+ "text_encoder": {
+ "_name_or_path": "t5-base",
+ "add_cross_attention": false,
+ "architectures": [
+ "T5ForConditionalGeneration"
+ ],
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": null,
+ "chunk_size_feed_forward": 0,
+ "cross_attention_hidden_size": null,
+ "d_ff": 3072,
+ "d_kv": 64,
+ "d_model": 768,
+ "decoder_start_token_id": 0,
+ "dense_act_fn": "relu",
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "dropout_rate": 0.1,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": 1,
+ "exponential_decay_length_penalty": null,
+ "feed_forward_proj": "relu",
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "initializer_factor": 1.0,
+ "is_decoder": false,
+ "is_encoder_decoder": true,
+ "is_gated_act": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "layer_norm_epsilon": 1e-06,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "min_length": 0,
+ "model_type": "t5",
+ "n_positions": 512,
+ "no_repeat_ngram_size": 0,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_decoder_layers": 12,
+ "num_heads": 12,
+ "num_layers": 12,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_past": true,
+ "output_scores": false,
+ "pad_token_id": 0,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "relative_attention_max_distance": 128,
+ "relative_attention_num_buckets": 32,
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sep_token_id": null,
+ "suppress_tokens": null,
+ "task_specific_params": {
+ "summarization": {
+ "early_stopping": true,
+ "length_penalty": 2.0,
+ "max_length": 200,
+ "min_length": 30,
+ "no_repeat_ngram_size": 3,
+ "num_beams": 4,
+ "prefix": "summarize: "
+ },
+ "translation_en_to_de": {
+ "early_stopping": true,
+ "max_length": 300,
+ "num_beams": 4,
+ "prefix": "translate English to German: "
+ },
+ "translation_en_to_fr": {
+ "early_stopping": true,
+ "max_length": 300,
+ "num_beams": 4,
+ "prefix": "translate English to French: "
+ },
+ "translation_en_to_ro": {
+ "early_stopping": true,
+ "max_length": 300,
+ "num_beams": 4,
+ "prefix": "translate English to Romanian: "
+ }
+ },
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": null,
+ "torchscript": false,
+ "transformers_version": "4.31.0.dev0",
+ "typical_p": 1.0,
+ "use_bfloat16": false,
+ "use_cache": true,
+ "vocab_size": 32128
+ },
+ "torch_dtype": "float32",
+ "transformers_version": null
+ }
cache/models--facebook--musicgen-small/refs/main ADDED
@@ -0,0 +1 @@
+ 51027f0bee8489c1750a7b8a4806894ab2e7dc4d
cache/models--facebook--musicgen-small/snapshots/51027f0bee8489c1750a7b8a4806894ab2e7dc4d/config.json ADDED
@@ -0,0 +1,298 @@
+ {
+ "_commit_hash": null,
+ "architectures": [
+ "MusicgenForConditionalGeneration"
+ ],
+ "audio_encoder": {
+ "_name_or_path": "facebook/encodec_32khz",
+ "add_cross_attention": false,
+ "architectures": [
+ "EncodecModel"
+ ],
+ "audio_channels": 1,
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": null,
+ "chunk_length_s": null,
+ "chunk_size_feed_forward": 0,
+ "codebook_dim": 128,
+ "codebook_size": 2048,
+ "compress": 2,
+ "cross_attention_hidden_size": null,
+ "decoder_start_token_id": null,
+ "dilation_growth_rate": 2,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": null,
+ "exponential_decay_length_penalty": null,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "hidden_size": 128,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "is_decoder": false,
+ "is_encoder_decoder": false,
+ "kernel_size": 7,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "last_kernel_size": 7,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "min_length": 0,
+ "model_type": "encodec",
+ "no_repeat_ngram_size": 0,
+ "norm_type": "weight_norm",
+ "normalize": false,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_filters": 64,
+ "num_lstm_layers": 2,
+ "num_residual_layers": 1,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "overlap": null,
+ "pad_mode": "reflect",
+ "pad_token_id": null,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "residual_kernel_size": 3,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sampling_rate": 32000,
+ "sep_token_id": null,
+ "suppress_tokens": null,
+ "target_bandwidths": [
+ 2.2
+ ],
+ "task_specific_params": null,
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": "float32",
+ "torchscript": false,
+ "transformers_version": "4.31.0.dev0",
+ "trim_right_ratio": 1.0,
+ "typical_p": 1.0,
+ "upsampling_ratios": [
+ 8,
+ 5,
+ 4,
+ 4
+ ],
+ "use_bfloat16": false,
+ "use_causal_conv": false,
+ "use_conv_shortcut": false
+ },
+ "decoder": {
+ "_name_or_path": "",
+ "activation_dropout": 0.0,
+ "activation_function": "gelu",
+ "add_cross_attention": false,
+ "architectures": null,
+ "attention_dropout": 0.0,
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": 2048,
+ "chunk_size_feed_forward": 0,
+ "classifier_dropout": 0.0,
+ "cross_attention_hidden_size": null,
+ "decoder_start_token_id": null,
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "dropout": 0.1,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": null,
+ "exponential_decay_length_penalty": null,
+ "ffn_dim": 4096,
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "hidden_size": 1024,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "initializer_factor": 0.02,
+ "is_decoder": false,
+ "is_encoder_decoder": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "layerdrop": 0.0,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "max_position_embeddings": 2048,
+ "min_length": 0,
+ "model_type": "musicgen_decoder",
+ "no_repeat_ngram_size": 0,
+ "num_attention_heads": 16,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_codebooks": 4,
+ "num_hidden_layers": 24,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_scores": false,
+ "pad_token_id": 2048,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "scale_embedding": false,
+ "sep_token_id": null,
+ "suppress_tokens": null,
+ "task_specific_params": null,
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": false,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": null,
+ "torchscript": false,
+ "transformers_version": "4.31.0.dev0",
+ "typical_p": 1.0,
+ "use_bfloat16": false,
+ "use_cache": true,
+ "vocab_size": 2048
+ },
+ "is_encoder_decoder": true,
+ "model_type": "musicgen",
+ "text_encoder": {
+ "_name_or_path": "t5-base",
+ "add_cross_attention": false,
+ "architectures": [
+ "T5ForConditionalGeneration"
+ ],
+ "bad_words_ids": null,
+ "begin_suppress_tokens": null,
+ "bos_token_id": null,
+ "chunk_size_feed_forward": 0,
+ "cross_attention_hidden_size": null,
+ "d_ff": 3072,
+ "d_kv": 64,
+ "d_model": 768,
+ "decoder_start_token_id": 0,
+ "dense_act_fn": "relu",
+ "diversity_penalty": 0.0,
+ "do_sample": false,
+ "dropout_rate": 0.1,
+ "early_stopping": false,
+ "encoder_no_repeat_ngram_size": 0,
+ "eos_token_id": 1,
+ "exponential_decay_length_penalty": null,
+ "feed_forward_proj": "relu",
+ "finetuning_task": null,
+ "forced_bos_token_id": null,
+ "forced_eos_token_id": null,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1"
+ },
+ "initializer_factor": 1.0,
+ "is_decoder": false,
+ "is_encoder_decoder": true,
+ "is_gated_act": false,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1
+ },
+ "layer_norm_epsilon": 1e-06,
+ "length_penalty": 1.0,
+ "max_length": 20,
+ "min_length": 0,
+ "model_type": "t5",
+ "n_positions": 512,
+ "no_repeat_ngram_size": 0,
+ "num_beam_groups": 1,
+ "num_beams": 1,
+ "num_decoder_layers": 12,
+ "num_heads": 12,
+ "num_layers": 12,
+ "num_return_sequences": 1,
+ "output_attentions": false,
+ "output_hidden_states": false,
+ "output_past": true,
+ "output_scores": false,
+ "pad_token_id": 0,
+ "prefix": null,
+ "problem_type": null,
+ "pruned_heads": {},
+ "relative_attention_max_distance": 128,
+ "relative_attention_num_buckets": 32,
+ "remove_invalid_values": false,
+ "repetition_penalty": 1.0,
+ "return_dict": true,
+ "return_dict_in_generate": false,
+ "sep_token_id": null,
+ "suppress_tokens": null,
+ "task_specific_params": {
+ "summarization": {
+ "early_stopping": true,
+ "length_penalty": 2.0,
+ "max_length": 200,
+ "min_length": 30,
+ "no_repeat_ngram_size": 3,
+ "num_beams": 4,
+ "prefix": "summarize: "
+ },
+ "translation_en_to_de": {
+ "early_stopping": true,
+ "max_length": 300,
+ "num_beams": 4,
+ "prefix": "translate English to German: "
+ },
+ "translation_en_to_fr": {
+ "early_stopping": true,
+ "max_length": 300,
+ "num_beams": 4,
+ "prefix": "translate English to French: "
+ },
+ "translation_en_to_ro": {
+ "early_stopping": true,
+ "max_length": 300,
+ "num_beams": 4,
+ "prefix": "translate English to Romanian: "
+ }
+ },
+ "temperature": 1.0,
+ "tf_legacy_loss": false,
+ "tie_encoder_decoder": false,
+ "tie_word_embeddings": true,
+ "tokenizer_class": null,
+ "top_k": 50,
+ "top_p": 1.0,
+ "torch_dtype": null,
+ "torchscript": false,
+ "transformers_version": "4.31.0.dev0",
+ "typical_p": 1.0,
+ "use_bfloat16": false,
+ "use_cache": true,
+ "vocab_size": 32128
+ },
+ "torch_dtype": "float32",
+ "transformers_version": null
+ }
cache/models--facebook--musicgen-small/snapshots/51027f0bee8489c1750a7b8a4806894ab2e7dc4d/generation_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+ "_from_model_config": true,
+ "bos_token_id": 2048,
+ "decoder_start_token_id": 2048,
+ "do_sample": true,
+ "guidance_scale": 3.0,
+ "max_length": 1500,
+ "pad_token_id": 2048,
+ "transformers_version": "4.31.0.dev0"
+ }
cache/models--facebook--musicgen-small/snapshots/51027f0bee8489c1750a7b8a4806894ab2e7dc4d/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1bdc99d43eb6c775967df24b65b0a9f847c0907e95664698d93b5a1c35f5090d
+ size 2364427288
data/encodec32khz_testing_embeds_sorted.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:365ada4922d3328895b978df9127c44a35a807732157839ae862c23acb174719
+ size 2880128
data/encodec32khz_training_embeds_sorted.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b26b9d0e279143f1cad0c7e3b3879afb0c10fc2ba774c1f35521ac04efde67f4
+ size 5760128
data/encodec_test_embeds.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ba64ae0cc073640c9104ee24814b480bab5fedbefc3e1fc82f7435243090875
+ size 2160128
data/encodec_testing_embeds_sorted.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:691527ea72cbec7b67821d15ce2a179648d25b08b488afec7bcc3580fee2b4d3
+ size 2160128
data/encodec_training_embeds.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:733ea631bafb622c61b325435f3beaa6d92d8cbe1756ea2eab40ce1ca9f14170
+ size 4320128
data/encodec_training_embeds_sorted.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2da1180e4324c6fa148c531160aa179c8f00dd677de076555bf632ccc5f9e09c
+ size 4320128
data/sub-001_Resp_Test.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f38bf483013229e1e2d56111fd041743e0948f649aa041ddb0aa0bfbfe0f7e3
+ size 583526528
data/sub-001_Resp_Test_Mean.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:abc68b4d0b8974b6d3a720d463beb52ee464860a3d324c1c8e696109248a76c7
+ size 145881728
data/sub-001_Resp_Training.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:66aaaf44e962c1f3121e7f06b8a07fc9140d10be48cc71e0e3d7a1feaa6b130b
+ size 1167052928
src/.ipynb_checkpoints/Copy_of_MusicGen-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/.ipynb_checkpoints/MLP-model copy-checkpoint.ipynb ADDED
@@ -0,0 +1,582 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/ckadirt/miniconda3/envs/b2m/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv\n",
19
+ "import lightning as L\n",
20
+ "import numpy as np, pandas as pd, matplotlib.pyplot as plt\n",
21
+ "from pytorch_lightning.loggers import WandbLogger"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 5,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "# create the datasets and dataloaders\n",
31
+ "train_voxels_path = '/home/ckadirt/brain2music/dataset/preproc/sub-001_Resp_Training.npy' # path to training voxels 65000 * 4800 \n",
32
+ "test_voxels_path = '/home/ckadirt/brain2music/dataset/preproc/sub-001_Resp_Test_Mean.npy' # path to test voxels 65000 * 600\n",
33
+ "\n",
34
+ "train_embeddings_path = '/home/ckadirt/brain2music/encodec_training_embeds_150.npy' # path to training embeddings 480 * 2 * 1125\n",
35
+ "test_embeddings_path = '/home/ckadirt/brain2music/encodec_test_embeds_150.npy' # path to test embeddings 600 * 2 * 1125\n",
36
+ "\n",
37
+ "class VoxelsDataset(data.Dataset):\n",
38
+ " def __init__(self, voxels_path, embeddings_path):\n",
39
+ " # transpose the two dimensions of the voxels data to match the embeddings data\n",
40
+ " self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)\n",
41
+ " self.embeddings = torch.from_numpy(np.load(embeddings_path))\n",
42
+ " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n",
43
+ " self.len = len(self.voxels) // 10\n",
44
+ " print(\"The len is \", self.len )\n",
45
+ "\n",
46
+ " def __getitem__(self, index):\n",
47
+ " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n",
48
+ " voxels = self.voxels[index*10:(index+1)*10]\n",
49
+ " embeddings = self.embeddings[index]\n",
50
+ " return voxels, embeddings\n",
51
+ "\n",
52
+ " def __len__(self):\n",
53
+ " return self.len\n",
54
+ " \n",
55
+ "class VoxelsEmbeddinsEncodecDataModule(L.LightningDataModule):\n",
56
+ " def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4):\n",
57
+ " super().__init__()\n",
58
+ " self.train_voxels_path = train_voxels_path\n",
59
+ " self.train_embeddings_path = train_embeddings_path\n",
60
+ " self.test_voxels_path = test_voxels_path\n",
61
+ " self.test_embeddings_path = test_embeddings_path\n",
62
+ " self.batch_size = batch_size\n",
63
+ "\n",
64
+ " def setup(self, stage=None):\n",
65
+ " self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)\n",
66
+ " self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)\n",
67
+ "\n",
68
+ " def train_dataloader(self):\n",
69
+ " return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)\n",
70
+ "\n",
71
+ " def val_dataloader(self):\n",
72
+ " return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)\n",
73
+ "\n",
74
+ "\n"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 33,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "data_module_example = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": 34,
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
+ "data_module_example.setup()\n",
93
+ "train_dataloader = data_module_example.train_dataloader()\n",
94
+ "val_dataset = data_module_example.val_dataloader()"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 38,
100
+ "metadata": {},
101
+ "outputs": [
102
+ {
103
+ "data": {
104
+ "text/plain": [
105
+ "(tensor([], size=(0, 60784)),\n",
106
+ " tensor([[ 302., 244., 660., 854., 660., 480., 854., 618., 618., 854.,\n",
107
+ " 790., 750., 659., 59., 891., 891., 536., 167., 343., 536.,\n",
108
+ " 715., 758., 758., 758., 480., 498., 854., 4., 4., 308.,\n",
109
+ " 270., 342., 342., 660., 342., 854., 342., 435., 549., 150.,\n",
110
+ " 631., 485., 844., 366., 266., 35., 847., 667., 862., 109.,\n",
111
+ " 573., 379., 226., 573., 603., 513., 178., 302., 715., 631.,\n",
112
+ " 342., 258., 244., 302., 715., 854., 854., 294., 366., 660.,\n",
113
+ " 361., 302., 729., 962., 790., 711., 660., 243., 294., 802.,\n",
114
+ " 329., 513., 962., 342., 711., 244., 243., 549., 802., 854.,\n",
115
+ " 750., 81., 342., 381., 854., 603., 790., 109., 294., 513.,\n",
116
+ " 419., 485., 504., 660., 361., 790., 790., 167., 802., 246.,\n",
117
+ " 485., 246., 81., 1023., 149., 81., 943., 504., 755., 414.,\n",
118
+ " 246., 972., 715., 1023., 790., 692., 790., 572., 504., 302.,\n",
119
+ " 308., 853., 631., 657., 790., 361., 660., 715., 686., 213.,\n",
120
+ " 226., 187., 586., 361., 485., 790., 729., 951., 962., 485.],\n",
121
+ " [ 963., 645., 645., 326., 138., 1013., 680., 525., 411., 102.,\n",
122
+ " 462., 466., 698., 409., 289., 923., 878., 415., 386., 604.,\n",
123
+ " 975., 162., 603., 284., 233., 75., 244., 1016., 1016., 242.,\n",
124
+ " 67., 194., 122., 492., 856., 997., 997., 221., 243., 814.,\n",
125
+ " 386., 598., 317., 166., 583., 439., 654., 430., 201., 160.,\n",
126
+ " 813., 716., 312., 664., 204., 462., 375., 451., 67., 535.,\n",
127
+ " 854., 209., 548., 812., 657., 827., 408., 411., 422., 352.,\n",
128
+ " 99., 711., 664., 239., 890., 529., 617., 186., 536., 178.,\n",
129
+ " 29., 930., 187., 973., 354., 450., 468., 273., 995., 653.,\n",
130
+ " 935., 335., 973., 812., 348., 664., 575., 184., 299., 782.,\n",
131
+ " 36., 29., 641., 653., 105., 958., 653., 828., 981., 218.,\n",
132
+ " 1021., 381., 356., 35., 416., 675., 45., 839., 690., 331.,\n",
133
+ " 634., 610., 317., 745., 673., 331., 575., 57., 100., 564.,\n",
134
+ " 590., 492., 902., 53., 73., 332., 1005., 395., 679., 781.,\n",
135
+ " 174., 74., 121., 667., 265., 479., 583., 655., 163., 81.]]))"
136
+ ]
137
+ },
138
+ "execution_count": 38,
139
+ "metadata": {},
140
+ "output_type": "execute_result"
141
+ }
142
+ ],
143
+ "source": [
144
+ "val_dataset.dataset[239]"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": null,
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "class MLP(L.LightningModule):\n",
154
+ " def __init__(self, sizes, residual_conections, dropout):\n",
155
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
156
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
157
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
158
+ " super().__init__()\n",
159
+ " self.sizes = sizes\n",
160
+ " self.residual_conections = residual_conections\n",
161
+ " self.dropout = dropout\n",
162
+ " self.layers = nn.Sequential()\n",
163
+ " for i in range(len(sizes)-1):\n",
164
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
165
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
166
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
167
+ "\n",
168
+ " self.loss = nn.CrossEntropyLoss()\n",
169
+ " self.test_outptus = []\n",
170
+ " self.train_outptus = []\n",
171
+ "\n",
172
+ " def forward(self, x):\n",
173
+ " return self.layers(x)\n",
174
+ " \n",
175
+ " def training_step(self, batch, batch_idx):\n",
176
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
177
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
178
+ " embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250] \n",
179
+ " #take just the first 200 embeddings\n",
180
+ " embeddings = embeddings[:, :200]\n",
181
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
182
+ " voxels = voxels[:, 0:2, :]\n",
183
+ " voxels = voxels.mean(dim=1)\n",
184
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
185
+ " outputs = self(voxels)\n",
186
+ " # the outputs are [batch_size, 200*1024], we need to reshape them to [batch_size, 200, 1024]\n",
187
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
188
+ " loss = self.loss(outputs, embeddings)\n",
189
+ " acuracy = self.tokens_accuracy(outputs, embeddings)\n",
190
+ " self.log('train_loss', loss, sync_dist=True)\n",
191
+ " self.log('train_accuracy', acuracy, sync_dist=True)\n",
192
+ " discrete_outputs = outputs.argmax(dim=1)\n",
193
+ " self.train_outptus.append(discrete_outputs)\n",
194
+ " return loss\n",
195
+ " \n",
196
+ " def tokens_accuracy(self, outputs, embeddings):\n",
197
+ " # outputs is [batch_size, 1024, 200]\n",
198
+ " # embeddings is [batch_size, 200]\n",
199
+ " # we need to get the index of the maximum value of each token\n",
200
+ " outputs = outputs.argmax(dim=1)\n",
201
+ " # now we need to compare the outputs with the embeddings\n",
202
+ " return (outputs == embeddings).float().mean()\n",
203
+ " \n",
204
+ " def on_train_epoch_end(self):\n",
205
+ " self.train_outptus = torch.cat(self.train_outptus)\n",
206
+ " # save the outputs with the current epoch name\n",
207
+ " torch.save(self.train_outptus, 'outputs_train'+str(self.current_epoch)+'.pt')\n",
208
+ " self.train_outptus = []\n",
209
+ " \n",
210
+ " def on_validation_epoch_end(self):\n",
211
+ " self.test_outptus = torch.cat(self.test_outptus)\n",
212
+ " # save the outputs with the current epoch name\n",
213
+ " torch.save(self.test_outptus, 'outputs_validation'+str(self.current_epoch)+'.pt')\n",
214
+ " self.test_outptus = []\n",
215
+ "\n",
216
+ " \n",
217
+ " def validation_step(self, batch, batch_idx):\n",
218
+ " voxels, embeddings = batch\n",
219
+ " embeddings = embeddings.flatten(start_dim=1).long()\n",
220
+ " embeddings = embeddings[:, :200]\n",
221
+ " voxels = voxels[:, 0:2, :]\n",
222
+ " voxels = voxels.mean(dim=1)\n",
223
+ " voxels = voxels.flatten(start_dim=1)\n",
224
+ " outputs = self(voxels)\n",
225
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
226
+ " loss = self.loss(outputs, embeddings)\n",
227
+ " accuracy = self.tokens_accuracy(outputs, embeddings)\n",
228
+ " self.log('val_loss', loss, sync_dist=True)\n",
229
+ " self.log('val_accuracy', accuracy, sync_dist=True)\n",
230
+ " discrete_outputs = outputs.argmax(dim=1)\n",
231
+ " self.test_outptus.append(discrete_outputs)\n",
232
+ " return loss\n",
233
+ " \n",
234
+ " \n",
235
+ " def configure_optimizers(self):\n",
236
+ " return torch.optim.Adam(self.parameters(), lr=1e-6)\n",
237
+ " \n",
238
+ "\n",
239
+ "# create the model\n",
240
+ "sizes = [60784, 500, 500, 150*1024]\n",
241
+ "residual_conections = [[0], [1], [2], [3]]\n",
242
+ "dropout = [0.3, 0.3, 0.3, 0.3]\n",
243
+ "model = MLP(sizes, residual_conections, dropout)\n",
244
+ "\n",
245
+ "# create the data module\n",
246
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)\n",
247
+ "\n",
248
+ "wandb.finish()\n",
249
+ "\n",
250
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
251
+ "\n",
252
+ "# define the trainer\n",
253
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
254
+ "\n",
255
+ "# train the model\n",
256
+ "trainer.fit(model, datamodule=data_module)\n"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "class MLP(L.LightningModule):\n",
266
+ " def __init__(self, sizes, residual_conections, dropout):\n",
267
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
268
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
269
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
270
+ " super().__init__()\n",
271
+ " self.sizes = sizes\n",
272
+ " self.residual_conections = residual_conections\n",
273
+ " self.dropout = dropout\n",
274
+ " self.layers = nn.Sequential()\n",
275
+ " for i in range(len(sizes)-1):\n",
276
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
277
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
278
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
279
+ "\n",
280
+ " self.loss = nn.MSELoss()\n",
281
+ " self.test_outptus = []\n",
282
+ " self.train_outptus = []\n",
283
+ "\n",
284
+ " def forward(self, x):\n",
285
+ " return self.layers(x)\n",
286
+ " \n",
287
+ " def training_step(self, batch, batch_idx):\n",
288
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
289
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
290
+ " embeddings = embeddings.flatten(start_dim=1) # the size is [batch_size, 2250]\n",
291
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
292
+ " voxels = voxels.mean(dim=1)\n",
293
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
294
+ " outputs = self(voxels)\n",
295
+ " loss = self.loss(outputs, embeddings)\n",
296
+ " self.log('train_loss', loss)\n",
297
+ " discrete_outputs = outputs.argmax(dim=1)\n",
298
+ " self.train_outptus.append(discrete_outputs)\n",
299
+ " return loss\n",
300
+ " \n",
301
+ " def on_train_epoch_end(self):\n",
302
+ " self.train_outptus = torch.cat(self.train_outptus)\n",
303
+ " # save the outputs with the current epoch name\n",
304
+ " torch.save(self.train_outptus, 'outputs_train'+str(self.current_epoch)+'.pt')\n",
305
+ " self.train_outptus = []\n",
306
+ " \n",
307
+ " def on_validation_epoch_end(self):\n",
308
+ " self.test_outptus = torch.cat(self.test_outptus)\n",
309
+ " # save the outputs with the current epoch name\n",
310
+ " torch.save(self.test_outptus, 'outputs_validation'+str(self.current_epoch)+'.pt')\n",
311
+ " self.test_outptus = []\n",
312
+ "\n",
313
+ " def validation_step(self, batch, batch_idx):\n",
314
+ " voxels, embeddings = batch\n",
315
+ " embeddings = embeddings.flatten(start_dim=1)\n",
316
+ " voxels = voxels.mean(dim=1)\n",
317
+ " voxels = voxels.flatten(start_dim=1)\n",
318
+ " outputs = self(voxels)\n",
319
+ " loss = self.loss(outputs, embeddings)\n",
320
+ " self.log('val_loss', loss)\n",
321
+ " discrete_outputs = outputs.argmax(dim=1)\n",
322
+ " self.test_outptus.append(discrete_outputs)\n",
323
+ " return loss\n",
324
+ " \n",
325
+ " \n",
326
+ " def configure_optimizers(self):\n",
327
+ " return torch.optim.Adam(self.parameters(), lr=1e-5)\n",
328
+ " \n",
329
+ "\n",
330
+ "# create the model\n",
331
+ "sizes = [60784, 1000, 1000, 150*2*1024]\n",
332
+ "residual_conections = [[0], [1], [2], [3]]\n",
333
+ "dropout = [0.5, 0.5, 0.5, 0.5]\n",
334
+ "model = MLP(sizes, residual_conections, dropout)\n",
335
+ "\n",
336
+ "# create the data module\n",
337
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=32)\n",
338
+ "\n",
339
+ "\n",
340
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
341
+ "\n",
342
+ "# define the trainer\n",
343
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
344
+ "\n",
345
+ "# train the model\n",
346
+ "trainer.fit(model, datamodule=data_module)\n"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": null,
352
+ "metadata": {},
353
+ "outputs": [],
354
+ "source": [
355
+ "class MLP(L.LightningModule):\n",
356
+ " def __init__(self, sizes, residual_conections, dropout):\n",
357
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
358
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
359
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
360
+ " super().__init__()\n",
361
+ " self.sizes = sizes\n",
362
+ " self.residual_conections = residual_conections\n",
363
+ " self.dropout = dropout\n",
364
+ " self.layers = nn.Sequential()\n",
365
+ " for i in range(len(sizes)-1):\n",
366
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
367
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
368
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
369
+ "\n",
370
+ " self.loss = nn.CrossEntropyLoss()\n",
371
+ "\n",
372
+ " def forward(self, x):\n",
373
+ " return self.layers(x)\n",
374
+ " \n",
375
+ " def training_step(self, batch, batch_idx):\n",
376
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
377
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
378
+ " embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250] \n",
379
+ " #take just the first 200 embeddings\n",
380
+ " embeddings = embeddings[:, :200]\n",
381
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
382
+ " voxels = voxels.mean(dim=1)\n",
383
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
384
+ " outputs = self(voxels)\n",
385
+ " # the outputs are [batch_size, 200*1024], we need to reshape them to [batch_size, 200, 1024]\n",
386
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
387
+ " loss = self.loss(outputs, embeddings)\n",
388
+ " acuracy = self.tokens_accuracy(outputs, embeddings)\n",
389
+ " self.log('train_loss', loss)\n",
390
+ " self.log('train_accuracy', acuracy)\n",
391
+ " return loss\n",
392
+ " \n",
393
+ " def tokens_accuracy(self, outputs, embeddings):\n",
394
+ " # outputs is [batch_size, 1024, 200]\n",
395
+ " # embeddings is [batch_size, 200]\n",
396
+ " # we need to get the index of the maximum value of each token\n",
397
+ " outputs = outputs.argmax(dim=1)\n",
398
+ " # now we need to compare the outputs with the embeddings\n",
399
+ " return (outputs == embeddings).float().mean()\n",
400
+ "\n",
401
+ " \n",
402
+ " def validation_step(self, batch, batch_idx):\n",
403
+ " voxels, embeddings = batch\n",
404
+ " embeddings = embeddings.flatten(start_dim=1).long()\n",
405
+ " embeddings = embeddings[:, :200]\n",
406
+ " voxels = voxels.mean(dim=1)\n",
407
+ " voxels = voxels.flatten(start_dim=1)\n",
408
+ " outputs = self(voxels)\n",
409
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
410
+ " loss = self.loss(outputs, embeddings)\n",
411
+ " accuracy = self.tokens_accuracy(outputs, embeddings)\n",
412
+ " self.log('val_loss', loss)\n",
413
+ " self.log('val_accuracy', accuracy)\n",
414
+ " return loss\n",
415
+ " \n",
416
+ " \n",
417
+ " def configure_optimizers(self):\n",
418
+ " return torch.optim.Adam(self.parameters(), lr=1e-5)\n",
419
+ " \n",
420
+ "\n",
421
+ "# create the model\n",
422
+ "sizes = [60784, 1000, 1000, 200*1024]\n",
423
+ "residual_conections = [[0], [1], [2], [3]]\n",
424
+ "dropout = [0.5, 0.5, 0.5, 0.5]\n",
425
+ "model = MLP(sizes, residual_conections, dropout)\n",
426
+ "\n",
427
+ "# create the data module\n",
428
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=2)\n",
429
+ "\n",
430
+ "wandb.finish()\n",
431
+ "\n",
432
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
433
+ "\n",
434
+ "# define the trainer\n",
435
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
436
+ "\n",
437
+ "# train the model\n",
438
+ "trainer.fit(model, datamodule=data_module)\n"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "metadata": {},
445
+ "outputs": [],
446
+ "source": [
447
+ "class MLP(L.LightningModule):\n",
448
+ " def __init__(self, sizes, residual_conections, dropout):\n",
449
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
450
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
451
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
452
+ " super().__init__()\n",
453
+ " self.sizes = sizes\n",
454
+ " self.residual_conections = residual_conections\n",
455
+ " self.dropout = dropout\n",
456
+ " self.layers = nn.Sequential()\n",
457
+ " for i in range(len(sizes)-1):\n",
458
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
459
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
460
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
461
+ "\n",
462
+ " self.loss = nn.CrossEntropyLoss()\n",
463
+ "\n",
464
+ " def forward(self, x):\n",
465
+ " return self.layers(x)\n",
466
+ " \n",
467
+ " def training_step(self, batch, batch_idx):\n",
468
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
469
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
470
+ " embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250] \n",
471
+ " #take just the first 200 embeddings\n",
472
+ " embeddings = embeddings[:, :200]\n",
473
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
474
+ " voxels = voxels[:, 1, :]\n",
475
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
476
+ " outputs = self(voxels)\n",
477
+ " # the outputs are [batch_size, 200*1024], we need to reshape them to [batch_size, 200, 1024]\n",
478
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
479
+ " loss = self.loss(outputs, embeddings)\n",
480
+ " acuracy = self.tokens_accuracy(outputs, embeddings)\n",
481
+ " self.log('train_loss', loss)\n",
482
+ " self.log('train_accuracy', acuracy)\n",
483
+ " return loss\n",
484
+ " \n",
485
+ " def tokens_accuracy(self, outputs, embeddings):\n",
486
+ " # outputs is [batch_size, 1024, 200]\n",
487
+ " # embeddings is [batch_size, 200]\n",
488
+ " # we need to get the index of the maximum value of each token\n",
489
+ " outputs = outputs.argmax(dim=1)\n",
490
+ " # now we need to compare the outputs with the embeddings\n",
491
+ " return (outputs == embeddings).float().mean()\n",
492
+ "\n",
493
+ " \n",
494
+ " def validation_step(self, batch, batch_idx):\n",
495
+ " voxels, embeddings = batch\n",
496
+ " embeddings = embeddings.flatten(start_dim=1).long()\n",
497
+ " embeddings = embeddings[:, :200]\n",
498
+ " voxels = voxels[:, 1, :]\n",
499
+ " voxels = voxels.flatten(start_dim=1)\n",
500
+ " outputs = self(voxels)\n",
501
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
502
+ " loss = self.loss(outputs, embeddings)\n",
503
+ " accuracy = self.tokens_accuracy(outputs, embeddings)\n",
504
+ " self.log('val_loss', loss)\n",
505
+ " self.log('val_accuracy', accuracy)\n",
506
+ " return loss\n",
507
+ " \n",
508
+ " \n",
509
+ " def configure_optimizers(self):\n",
510
+ " return torch.optim.Adam(self.parameters(), lr=1e-6)\n",
511
+ " \n",
512
+ "\n",
513
+ "# create the model\n",
514
+ "sizes = [60784, 1000, 1000, 200*1024]\n",
515
+ "residual_conections = [[0], [1], [2], [3]]\n",
516
+ "dropout = [0.2, 0.2, 0.2, 0.2]\n",
517
+ "model = MLP(sizes, residual_conections, dropout)\n",
518
+ "\n",
519
+ "# create the data module\n",
520
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)\n",
521
+ "\n",
522
+ "wandb.finish()\n",
523
+ "\n",
524
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
525
+ "\n",
526
+ "# define the trainer\n",
527
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
528
+ "\n",
529
+ "# train the model\n",
530
+ "trainer.fit(model, datamodule=data_module)\n"
531
+ ]
532
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": null,
536
+ "metadata": {},
537
+ "outputs": [],
538
+ "source": [
539
+ "model3.eval()\n",
540
+ "outputs = torch.Tensor((480,200))\n",
541
+ "with torch.no_grad():\n",
542
+ " test_dataset = VoxelsDataset(test_voxels_path, test_embeddings_path)\n",
543
+ " dataloader = data.DataLoader(test_dataset, batch_size = 2)\n",
544
+ " for i, (voxels, embeddings) in enumerate(dataloader):\n",
545
+ " voxels = voxels[:, 1, :]\n",
546
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
547
+ " bout = model3(voxels)\n",
548
+ " bout = bout.reshape(-1, 1024, 200)\n",
549
+ " # the 1024 dimension is the number of tokens, we need to get the index of the maximum value of each token\n",
550
+ " bout = bout.argmax(dim=1)\n",
551
+ " # now we need to add the outputs to the outputs tensor\n",
552
+ " outputs[i*2:(i+1)*2] = bout\n",
553
+ " \n",
554
+ " \n",
555
+ "# save the predicted outputs on the current directory\n",
556
+ "torch.save(outputs, 'outputs.pt')"
557
+ ]
558
+ }
559
+ ],
560
+ "metadata": {
561
+ "kernelspec": {
562
+ "display_name": "b2m",
563
+ "language": "python",
564
+ "name": "python3"
565
+ },
566
+ "language_info": {
567
+ "codemirror_mode": {
568
+ "name": "ipython",
569
+ "version": 3
570
+ },
571
+ "file_extension": ".py",
572
+ "mimetype": "text/x-python",
573
+ "name": "python",
574
+ "nbconvert_exporter": "python",
575
+ "pygments_lexer": "ipython3",
576
+ "version": "3.11.4"
577
+ },
578
+ "orig_nbformat": 4
579
+ },
580
+ "nbformat": 4,
581
+ "nbformat_minor": 2
582
+ }
src/.ipynb_checkpoints/MLP-model-checkpoint.ipynb ADDED
@@ -0,0 +1,449 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv\n",
10
+ "import lightning as L\n",
11
+ "import numpy as np, pandas as pd, matplotlib.pyplot as plt\n",
12
+ "from pytorch_lightning.loggers import WandbLogger"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "# create the datasets and dataloaders\n",
22
+ "train_voxels_path = '/home/ckadirt/brain2music/dataset/preproc/sub-001_Resp_Training.npy' # path to training voxels 65000 * 4800 \n",
23
+ "test_voxels_path = '/home/ckadirt/brain2music/dataset/preproc/sub-001_Resp_Test_Mean.npy' # path to test voxels 65000 * 600\n",
24
+ "\n",
25
+ "train_embeddings_path = '/home/ckadirt/brain2music/dataset/Gtanz/audios/sub-001/encodec_embeddings_train.pt' # path to training embeddings 480 * 2 * 1125\n",
26
+ "test_embeddings_path = '/home/ckadirt/brain2music/dataset/Gtanz/audios/sub-001/encodec_embeddings_test.pt' # path to test embeddings 600 * 2 * 1125\n",
27
+ "\n",
28
+ "class VoxelsDataset(data.Dataset):\n",
29
+ " def __init__(self, voxels_path, embeddings_path):\n",
30
+ " # transpose the two dimensions of the voxels data to match the embeddings data\n",
31
+ " self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)\n",
32
+ " self.embeddings = torch.load(embeddings_path)\n",
33
+ " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n",
34
+ " self.len = len(self.voxels) // 10\n",
35
+ "\n",
36
+ " def __getitem__(self, index):\n",
37
+ " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n",
38
+ " voxels = self.voxels[index*10:(index+1)*10]\n",
39
+ " embeddings = self.embeddings[index]\n",
40
+ " return voxels, embeddings\n",
41
+ "\n",
42
+ " def __len__(self):\n",
43
+ " return self.len\n",
44
+ " \n",
45
+ "class VoxelsEmbeddinsEncodecDataModule(L.LightningDataModule):\n",
46
+ " def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=32):\n",
47
+ " super().__init__()\n",
48
+ " self.train_voxels_path = train_voxels_path\n",
49
+ " self.train_embeddings_path = train_embeddings_path\n",
50
+ " self.test_voxels_path = test_voxels_path\n",
51
+ " self.test_embeddings_path = test_embeddings_path\n",
52
+ " self.batch_size = batch_size\n",
53
+ "\n",
54
+ " def setup(self, stage=None):\n",
55
+ " self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)\n",
56
+ " self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)\n",
57
+ "\n",
58
+ " def train_dataloader(self):\n",
59
+ " return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)\n",
60
+ "\n",
61
+ " def test_dataloader(self):\n",
62
+ " return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)\n"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "class MLP(L.LightningModule):\n",
72
+ " def __init__(self, sizes, residual_conections, dropout):\n",
73
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
74
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
75
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
76
+ " super().__init__()\n",
77
+ " self.sizes = sizes\n",
78
+ " self.residual_conections = residual_conections\n",
79
+ " self.dropout = dropout\n",
80
+ " self.layers = nn.ModuleList()\n",
81
+ " for i in range(len(sizes)-1):\n",
82
+ " self.layers.append(nn.Linear(sizes[i], sizes[i+1]))\n",
83
+ " self.relu = nn.ReLU()\n",
84
+ " self.loss = nn.MSELoss()\n",
85
+ "\n",
86
+ " def forward(self, x):\n",
87
+ " x_states = [x]\n",
88
+ " for i in range(len(self.layers)):\n",
89
+ " x = self.layers[i](x)\n",
90
+ " for j in self.residual_conections[i]:\n",
91
+ " x = x + x_states[j]\n",
92
+ " x = self.relu(x)\n",
93
+ " x = nn.Dropout(self.dropout[i])(x)\n",
94
+ " x_states.append(x)\n",
95
+ "\n",
96
+ " return x\n",
97
+ " \n",
98
+ " def training_step(self, batch, batch_idx):\n",
99
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
100
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
101
+ " embeddings = embeddings.flatten(start_dim=1) # the size is [batch_size, 2250]\n",
102
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
103
+ " voxels = voxels.mean(dim=1)\n",
104
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
105
+ " outputs = self(voxels)\n",
106
+ " loss = self.loss(outputs, embeddings)\n",
107
+ " self.log('train_loss', loss)\n",
108
+ " return loss\n",
109
+ " \n",
110
+ " def validation_step(self, batch, batch_idx):\n",
111
+ " voxels, embeddings = batch\n",
112
+ " embeddings = embeddings.flatten(start_dim=1)\n",
113
+ " voxels = voxels.mean(dim=1)\n",
114
+ " voxels = voxels.flatten(start_dim=1)\n",
115
+ " outputs = self(voxels)\n",
116
+ " loss = self.loss(outputs, embeddings)\n",
117
+ " self.log('val_loss', loss)\n",
118
+ " return loss\n",
119
+ " \n",
120
+ " \n",
121
+ " def configure_optimizers(self):\n",
122
+ " return torch.optim.Adam(self.parameters(), lr=1e-3)\n",
123
+ " \n",
124
+ "\n",
125
+ "# create the model\n",
126
+ "sizes = [60784, 1000, 1000, 2250]\n",
127
+ "residual_conections = [[0], [1], [2,1], [3]]\n",
128
+ "dropout = [0.5, 0.5, 0.5, 0.5]\n",
129
+ "model = MLP(sizes, residual_conections, dropout)\n",
130
+ "\n",
131
+ "# create the data module\n",
132
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=32)\n",
133
+ "\n",
134
+ "\n",
135
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
136
+ "\n",
137
+ "# define the trainer\n",
138
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=100, logger=wandb_logger, precision='16-mixed')\n",
139
+ "\n",
140
+ "# train the model\n",
141
+ "trainer.fit(model, data_module)\n"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "class MLP(L.LightningModule):\n",
151
+ " def __init__(self, sizes, residual_conections, dropout):\n",
152
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
153
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
154
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
155
+ " super().__init__()\n",
156
+ " self.sizes = sizes\n",
157
+ " self.residual_conections = residual_conections\n",
158
+ " self.dropout = dropout\n",
159
+ " self.layers = nn.Sequential()\n",
160
+ " for i in range(len(sizes)-1):\n",
161
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
162
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
163
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
164
+ "\n",
165
+ " self.loss = nn.MSELoss()\n",
166
+ "\n",
167
+ " def forward(self, x):\n",
168
+ " return self.layers(x)\n",
169
+ " \n",
170
+ " def training_step(self, batch, batch_idx):\n",
171
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
172
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
173
+ " embeddings = embeddings.flatten(start_dim=1) # the size is [batch_size, 2250]\n",
174
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
175
+ " voxels = voxels.mean(dim=1)\n",
176
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
177
+ " outputs = self(voxels)\n",
178
+ " loss = self.loss(outputs, embeddings)\n",
179
+ " self.log('train_loss', loss)\n",
180
+ " return loss\n",
181
+ " \n",
182
+ " def validation_step(self, batch, batch_idx):\n",
183
+ " voxels, embeddings = batch\n",
184
+ " embeddings = embeddings.flatten(start_dim=1)\n",
185
+ " voxels = voxels.mean(dim=1)\n",
186
+ " voxels = voxels.flatten(start_dim=1)\n",
187
+ " outputs = self(voxels)\n",
188
+ " loss = self.loss(outputs, embeddings)\n",
189
+ " self.log('val_loss', loss)\n",
190
+ " return loss\n",
191
+ " \n",
192
+ " \n",
193
+ " def configure_optimizers(self):\n",
194
+ " return torch.optim.Adam(self.parameters(), lr=1e-5)\n",
195
+ " \n",
196
+ "\n",
197
+ "# create the model\n",
198
+ "sizes = [60784, 1000, 1000, 2250]\n",
199
+ "residual_conections = [[0], [1], [2], [3]]\n",
200
+ "dropout = [0.5, 0.5, 0.5, 0.5]\n",
201
+ "model = MLP(sizes, residual_conections, dropout)\n",
202
+ "\n",
203
+ "# create the data module\n",
204
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=32)\n",
205
+ "\n",
206
+ "\n",
207
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
208
+ "\n",
209
+ "# define the trainer\n",
210
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
211
+ "\n",
212
+ "# train the model\n",
213
+ "trainer.fit(model, datamodule=data_module)\n"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "metadata": {},
220
+ "outputs": [],
221
+ "source": [
222
+ "class MLP(L.LightningModule):\n",
223
+ " def __init__(self, sizes, residual_conections, dropout):\n",
224
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
225
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
226
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
227
+ " super().__init__()\n",
228
+ " self.sizes = sizes\n",
229
+ " self.residual_conections = residual_conections\n",
230
+ " self.dropout = dropout\n",
231
+ " self.layers = nn.Sequential()\n",
232
+ " for i in range(len(sizes)-1):\n",
233
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
234
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
235
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
236
+ "\n",
237
+ " self.loss = nn.CrossEntropyLoss()\n",
238
+ "\n",
239
+ " def forward(self, x):\n",
240
+ " return self.layers(x)\n",
241
+ " \n",
242
+ " def training_step(self, batch, batch_idx):\n",
243
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
244
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
245
+ " embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250] \n",
246
+ " #take just the first 200 embeddings\n",
247
+ " embeddings = embeddings[:, :200]\n",
248
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
249
+ " voxels = voxels.mean(dim=1)\n",
250
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
251
+ " outputs = self(voxels)\n",
252
+ " # the outputs are [batch_size, 200*1024], we need to reshape them to [batch_size, 200, 1024]\n",
253
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
254
+ " loss = self.loss(outputs, embeddings)\n",
255
+ " acuracy = self.tokens_accuracy(outputs, embeddings)\n",
256
+ " self.log('train_loss', loss)\n",
257
+ " self.log('train_accuracy', acuracy)\n",
258
+ " return loss\n",
259
+ " \n",
260
+ " def tokens_accuracy(self, outputs, embeddings):\n",
261
+ " # outputs is [batch_size, 1024, 200]\n",
262
+ " # embeddings is [batch_size, 200]\n",
263
+ " # we need to get the index of the maximum value of each token\n",
264
+ " outputs = outputs.argmax(dim=1)\n",
265
+ " # now we need to compare the outputs with the embeddings\n",
266
+ " return (outputs == embeddings).float().mean()\n",
267
+ "\n",
268
+ " \n",
269
+ " def validation_step(self, batch, batch_idx):\n",
270
+ " voxels, embeddings = batch\n",
271
+ " embeddings = embeddings.flatten(start_dim=1).long()\n",
272
+ " embeddings = embeddings[:, :200]\n",
273
+ " voxels = voxels.mean(dim=1)\n",
274
+ " voxels = voxels.flatten(start_dim=1)\n",
275
+ " outputs = self(voxels)\n",
276
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
277
+ " loss = self.loss(outputs, embeddings)\n",
278
+ " accuracy = self.tokens_accuracy(outputs, embeddings)\n",
279
+ " self.log('val_loss', loss)\n",
280
+ " self.log('val_accuracy', accuracy)\n",
281
+ " return loss\n",
282
+ " \n",
283
+ " \n",
284
+ " def configure_optimizers(self):\n",
285
+ " return torch.optim.Adam(self.parameters(), lr=1e-5)\n",
286
+ " \n",
287
+ "\n",
288
+ "# create the model\n",
289
+ "sizes = [60784, 1000, 1000, 200*1024]\n",
290
+ "residual_conections = [[0], [1], [2], [3]]\n",
291
+ "dropout = [0.5, 0.5, 0.5, 0.5]\n",
292
+ "model = MLP(sizes, residual_conections, dropout)\n",
293
+ "\n",
294
+ "# create the data module\n",
295
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=2)\n",
296
+ "\n",
297
+ "wandb.finish()\n",
298
+ "\n",
299
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
300
+ "\n",
301
+ "# define the trainer\n",
302
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
303
+ "\n",
304
+ "# train the model\n",
305
+ "trainer.fit(model, datamodule=data_module)\n"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": null,
311
+ "metadata": {},
312
+ "outputs": [],
313
+ "source": [
314
+ "class MLP(L.LightningModule):\n",
315
+ " def __init__(self, sizes, residual_conections, dropout):\n",
316
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
317
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
318
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
319
+ " super().__init__()\n",
320
+ " self.sizes = sizes\n",
321
+ " self.residual_conections = residual_conections\n",
322
+ " self.dropout = dropout\n",
323
+ " self.layers = nn.Sequential()\n",
324
+ " for i in range(len(sizes)-1):\n",
325
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
326
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
327
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
328
+ "\n",
329
+ " self.loss = nn.CrossEntropyLoss()\n",
330
+ "\n",
331
+ " def forward(self, x):\n",
332
+ " return self.layers(x)\n",
333
+ " \n",
334
+ " def training_step(self, batch, batch_idx):\n",
335
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
336
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
337
+ " embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250] \n",
338
+ " #take just the first 200 embeddings\n",
339
+ " embeddings = embeddings[:, :200]\n",
340
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
341
+ " voxels = voxels[:, 1, :]\n",
342
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
343
+ " outputs = self(voxels)\n",
344
+ " # the outputs are [batch_size, 200*1024], we need to reshape them to [batch_size, 200, 1024]\n",
345
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
346
+ " loss = self.loss(outputs, embeddings)\n",
347
+ " acuracy = self.tokens_accuracy(outputs, embeddings)\n",
348
+ " self.log('train_loss', loss)\n",
349
+ " self.log('train_accuracy', acuracy)\n",
350
+ " return loss\n",
351
+ " \n",
352
+ " def tokens_accuracy(self, outputs, embeddings):\n",
353
+ " # outputs is [batch_size, 1024, 200]\n",
354
+ " # embeddings is [batch_size, 200]\n",
355
+ " # we need to get the index of the maximum value of each token\n",
356
+ " outputs = outputs.argmax(dim=1)\n",
357
+ " # now we need to compare the outputs with the embeddings\n",
358
+ " return (outputs == embeddings).float().mean()\n",
359
+ "\n",
360
+ " \n",
361
+ " def validation_step(self, batch, batch_idx):\n",
362
+ " voxels, embeddings = batch\n",
363
+ " embeddings = embeddings.flatten(start_dim=1).long()\n",
364
+ " embeddings = embeddings[:, :200]\n",
365
+ " voxels = voxels[:, 1, :]\n",
366
+ " voxels = voxels.flatten(start_dim=1)\n",
367
+ " outputs = self(voxels)\n",
368
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
369
+ " loss = self.loss(outputs, embeddings)\n",
370
+ " accuracy = self.tokens_accuracy(outputs, embeddings)\n",
371
+ " self.log('val_loss', loss)\n",
372
+ " self.log('val_accuracy', accuracy)\n",
373
+ " return loss\n",
374
+ " \n",
375
+ " \n",
376
+ " def configure_optimizers(self):\n",
377
+ " return torch.optim.Adam(self.parameters(), lr=1e-6)\n",
378
+ " \n",
379
+ "\n",
380
+ "# create the model\n",
381
+ "sizes = [60784, 1000, 1000, 200*1024]\n",
382
+ "residual_conections = [[0], [1], [2], [3]]\n",
383
+ "dropout = [0.2, 0.2, 0.2, 0.2]\n",
384
+ "model = MLP(sizes, residual_conections, dropout)\n",
385
+ "\n",
386
+ "# create the data module\n",
387
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)\n",
388
+ "\n",
389
+ "wandb.finish()\n",
390
+ "\n",
391
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
392
+ "\n",
393
+ "# define the trainer\n",
394
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
395
+ "\n",
396
+ "# train the model\n",
397
+ "trainer.fit(model, datamodule=data_module)\n"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "model3.eval()\n",
407
+ "outputs = torch.Tensor((480,200))\n",
408
+ "with torch.no_grad():\n",
409
+ " test_dataset = VoxelsDataset(test_voxels_path, test_embeddings_path)\n",
410
+ " dataloader = data.DataLoader(test_dataset, batch_size = 2)\n",
411
+ " for i, (voxels, embeddings) in enumerate(dataloader):\n",
412
+ " voxels = voxels[:, 1, :]\n",
413
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
414
+ " bout = model3(voxels)\n",
415
+ " bout = bout.reshape(-1, 1024, 200)\n",
416
+ " # the 1024 dimension is the number of tokens, we need to get the index of the maximum value of each token\n",
417
+ " bout = bout.argmax(dim=1)\n",
418
+ " # now we need to add the outputs to the outputs tensor\n",
419
+ " outputs[i*2:(i+1)*2] = bout\n",
420
+ " \n",
421
+ " \n",
422
+ "# save the predicted outputs on the current directory\n",
423
+ "torch.save(outputs, 'outputs.pt')"
424
+ ]
425
+ }
426
+ ],
427
+ "metadata": {
428
+ "kernelspec": {
429
+ "display_name": "b2m",
430
+ "language": "python",
431
+ "name": "python3"
432
+ },
433
+ "language_info": {
434
+ "codemirror_mode": {
435
+ "name": "ipython",
436
+ "version": 3
437
+ },
438
+ "file_extension": ".py",
439
+ "mimetype": "text/x-python",
440
+ "name": "python",
441
+ "nbconvert_exporter": "python",
442
+ "pygments_lexer": "ipython3",
443
+ "version": "3.11.4"
444
+ },
445
+ "orig_nbformat": 4
446
+ },
447
+ "nbformat": 4,
448
+ "nbformat_minor": 2
449
+ }
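A minimal sketch, separate from the notebook above, of the shape convention its CrossEntropyLoss cells rely on: nn.CrossEntropyLoss expects logits of shape [batch, classes, positions] and integer targets of shape [batch, positions], which is why the flat [batch, 200*1024] MLP output is reshaped with the 1024-way codebook in dim 1 (chance-level token accuracy is 1/1024, about 0.001).

import torch
import torch.nn as nn

batch, positions, codebook = 4, 200, 1024
logits_flat = torch.randn(batch, positions * codebook)       # what the MLP head emits
logits = logits_flat.reshape(-1, codebook, positions)        # [4, 1024, 200], classes in dim 1
targets = torch.randint(0, codebook, (batch, positions))     # [4, 200] token ids

loss = nn.CrossEntropyLoss()(logits, targets)
accuracy = (logits.argmax(dim=1) == targets).float().mean()  # same check as tokens_accuracy above
print(loss.item(), accuracy.item())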
src/.ipynb_checkpoints/MLPencoder-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/.ipynb_checkpoints/mlpdummy-checkpoint.py ADDED
@@ -0,0 +1,146 @@
1
+ import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv
2
+ import lightning as L
3
+ import numpy as np, pandas as pd, matplotlib.pyplot as plt
4
+ from pytorch_lightning.loggers import WandbLogger
5
+ import wandb
6
+ import pytorch_lightning as pl
7
+
8
+ torch.set_float32_matmul_precision('medium')
9
+
10
+ # create the datasets and dataloaders
11
+ train_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Training.npy' # path to training voxels 65000 * 4800
12
+ test_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Test_Mean.npy' # path to test voxels 65000 * 600
13
+
14
+ train_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec_training_embeds_sorted.npy' # path to training embeddings 480 * 2 * 1125
15
+ test_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec_testing_embeds_sorted.npy' # path to test embeddings 600 * 2 * 1125
16
+
17
+ class VoxelsDataset(data.Dataset):
18
+ def __init__(self, voxels_path, embeddings_path):
19
+ # transpose the two dimensions of the voxels data to match the embeddings data
20
+ self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)
21
+ self.embeddings = torch.from_numpy(np.load(embeddings_path))
22
+ # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus
23
+ self.len = len(self.voxels) // 10
24
+ print("The len is ", self.len )
25
+
26
+ def __getitem__(self, index):
27
+ # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus
28
+ voxels = self.voxels[index*10:(index+1)*10]
29
+ embeddings = self.embeddings[index]
30
+ return voxels, embeddings
31
+
32
+ def __len__(self):
33
+ return self.len
34
+
35
+ class VoxelsEmbeddinsEncodecDataModule(pl.LightningDataModule):
36
+ def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=8):
37
+ super().__init__()
38
+ self.train_voxels_path = train_voxels_path
39
+ self.train_embeddings_path = train_embeddings_path
40
+ self.test_voxels_path = test_voxels_path
41
+ self.test_embeddings_path = test_embeddings_path
42
+ self.batch_size = batch_size
43
+
44
+ def setup(self, stage=None):
45
+ self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)
46
+ self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)
47
+
48
+ def train_dataloader(self):
49
+ return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
50
+
51
+ def val_dataloader(self):
52
+ return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
53
+
54
+
55
+ class MLP(pl.LightningModule):
56
+ def __init__(self, sizes, residual_conections, dropout):
57
+ # sizes is a list of the sizes of the layers, e.g. [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]
58
+ # residual_conections is a list with the same length as sizes; each element is a list of the indexes of the layers that will receive the output of the layer as input, 0 means that the layer will receive the x inputs, e.g. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]
59
+ # dropout is a list with the same length as sizes; each element is the dropout probability of the layer, e.g. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
60
+ super().__init__()
61
+ self.sizes = sizes
62
+ self.residual_conections = residual_conections
63
+ self.dropout = dropout
64
+ self.layers = nn.Sequential()
65
+ for i in range(len(sizes)-1):
66
+ self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))
67
+ self.layers.add_module('relu'+str(i), nn.ReLU())
68
+ self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))
69
+
70
+ self.loss = nn.CrossEntropyLoss(reduction='mean')
71
+
72
+ def forward(self, x):
73
+ return self.layers(x)
74
+
75
+ def training_step(self, batch, batch_idx):
76
+ voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]
77
+ # flatten the voxels to [batch_size, rest of the dimensions]
78
+ embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250]
79
+ #take just the first 200 embeddings
80
+ # embeddings = embeddings[:, :200]
81
+ # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus
82
+ voxels = voxels.mean(dim=1)
83
+ voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]
84
+ outputs = self(voxels)
85
+ # the outputs are [batch_size, 1125*2*1024]; reshape them to [batch_size, 1024, 2250] so the 1024 codebook logits are in dim 1, as CrossEntropyLoss expects
86
+ outputs = outputs.reshape(-1, 1024, 1125*2)
87
+ # avoid division by zero
88
+ outputs = outputs + 1e-6
89
+ #print(outputs.shape, embeddings.shape)
90
+ #print(outputs[0,0,:10], embeddings[0,:10])
91
+ loss = self.loss(outputs, embeddings)
92
+ #print(loss)
93
+ acuracy = self.tokens_accuracy(outputs, embeddings)
94
+ self.log('train_loss', loss)
95
+ self.log('train_accuracy', acuracy)
96
+ return loss
97
+
98
+ def tokens_accuracy(self, outputs, embeddings):
99
+ # outputs is [batch_size, 1024, 200]
100
+ # embeddings is [batch_size, 200]
101
+ # we need to get the index of the maximum value of each token
102
+ outputs = outputs.argmax(dim=1)
103
+ # now we need to compare the outputs with the embeddings
104
+ return (outputs == embeddings).float().mean()
105
+
106
+
107
+ def validation_step(self, batch, batch_idx):
108
+ voxels, embeddings = batch
109
+ embeddings = embeddings.flatten(start_dim=1).long()
110
+ #embeddings = embeddings[:, :200]
111
+ voxels = voxels.mean(dim=1)
112
+ voxels = voxels.flatten(start_dim=1)
113
+ outputs = self(voxels)
114
+ outputs = outputs.reshape(-1, 1024, 1125*2)
115
+ loss = self.loss(outputs, embeddings)
116
+ accuracy = self.tokens_accuracy(outputs, embeddings)
117
+ self.log('val_loss', loss)
118
+ self.log('val_accuracy', accuracy)
119
+ return loss
120
+
121
+
122
+ def configure_optimizers(self):
123
+ return torch.optim.Adam(self.trainer.model.parameters(), lr=2e-5, weight_decay=3e-3)
124
+
125
+
126
+ # create the model
127
+ sizes = [60784, 1000, 1000, 1125*2*1024]
128
+ residual_conections = [[0], [1], [2], [3]]
129
+ dropout = [0.5, 0.5, 0.5, 0.5]
130
+ model = MLP(sizes, residual_conections, dropout)
131
+
132
+
133
+ # create the data module
134
+ data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)
135
+
136
+ wandb.finish()
137
+ from pytorch_lightning.strategies import DeepSpeedStrategy
138
+
139
+ wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')
140
+
141
+ # define the trainer
142
+ trainer = pl.Trainer(accelerator="gpu", devices = [0,1,2,3,4,5,6,7], max_epochs=1000, logger=wandb_logger, precision='32', strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=8), enable_checkpointing=False, log_every_n_steps=10)
143
+ #trainer = pl.Trainer(accelerator="gpu", devices = [0,1,2,3], max_epochs=1000, logger=wandb_logger, precision='bf16', strategy='fsdp', enable_checkpointing=False, log_every_n_steps=10)
144
+ # train the model
145
+ trainer.fit(model, datamodule=data_module)
146
+
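A minimal sketch, under the same 15 s stimulus / 1.5 s TR assumption as the VoxelsDataset above, of how each dataset item maps onto ten consecutive fMRI volumes; the array sizes below are toy values, not the real data.

import numpy as np

n_trs, n_voxels, trs_per_clip = 40, 6, 10                  # toy sizes for illustration
resp = np.random.rand(n_trs, n_voxels).astype(np.float32)  # stands in for the (time, voxels) matrix

n_clips = n_trs // trs_per_clip                             # __len__ -> 4 clips
clip0 = resp[0 * trs_per_clip:(0 + 1) * trs_per_clip]       # __getitem__(0): first 10 TRs
assert clip0.shape == (trs_per_clip, n_voxels)
print(n_clips, clip0.mean(axis=0).shape)                    # mean over TRs, as in training_step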
src/.ipynb_checkpoints/musicgen_test copy-checkpoint.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Copy_of_MusicGen.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/MLP-model copy.ipynb ADDED
@@ -0,0 +1,581 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/ckadirt/miniconda3/envs/b2m/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv\n",
19
+ "import lightning as L\n",
20
+ "import numpy as np, pandas as pd, matplotlib.pyplot as plt\n",
21
+ "from pytorch_lightning.loggers import WandbLogger"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 5,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "# create the datasets and dataloaders\n",
31
+ "train_voxels_path = '/home/ckadirt/brain2music/dataset/preproc/sub-001_Resp_Training.npy' # path to training voxels 65000 * 4800 \n",
32
+ "test_voxels_path = '/home/ckadirt/brain2music/dataset/preproc/sub-001_Resp_Test_Mean.npy' # path to test voxels 65000 * 600\n",
33
+ "\n",
34
+ "train_embeddings_path = '/home/ckadirt/brain2music/encodec_training_embeds_150.npy' # path to training embeddings 480 * 2 * 1125\n",
35
+ "test_embeddings_path = '/home/ckadirt/brain2music/encodec_test_embeds_150.npy' # path to test embeddings 600 * 2 * 1125\n",
36
+ "\n",
37
+ "class VoxelsDataset(data.Dataset):\n",
38
+ " def __init__(self, voxels_path, embeddings_path):\n",
39
+ " # transpose the two dimensions of the voxels data to match the embeddings data\n",
40
+ " self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)\n",
41
+ " self.embeddings = torch.from_numpy(np.load(embeddings_path))\n",
42
+ " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n",
43
+ " self.len = len(self.voxels) // 10\n",
44
+ " print(\"The len is \", self.len )\n",
45
+ "\n",
46
+ " def __getitem__(self, index):\n",
47
+ " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n",
48
+ " voxels = self.voxels[index*10:(index+1)*10]\n",
49
+ " embeddings = self.embeddings[index]\n",
50
+ " return voxels, embeddings\n",
51
+ "\n",
52
+ " def __len__(self):\n",
53
+ " return self.len\n",
54
+ " \n",
55
+ "class VoxelsEmbeddinsEncodecDataModule(L.LightningDataModule):\n",
56
+ " def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4):\n",
57
+ " super().__init__()\n",
58
+ " self.train_voxels_path = train_voxels_path\n",
59
+ " self.train_embeddings_path = train_embeddings_path\n",
60
+ " self.test_voxels_path = test_voxels_path\n",
61
+ " self.test_embeddings_path = test_embeddings_path\n",
62
+ " self.batch_size = batch_size\n",
63
+ "\n",
64
+ " def setup(self, stage=None):\n",
65
+ " self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)\n",
66
+ " self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)\n",
67
+ "\n",
68
+ " def train_dataloader(self):\n",
69
+ " return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)\n",
70
+ "\n",
71
+ " def val_dataloader(self):\n",
72
+ " return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)\n",
73
+ "\n",
74
+ "\n"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 33,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "data_module_example = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": 34,
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
+ "data_module_example.setup()\n",
93
+ "train_dataloader = data_module_example.train_dataloader()\n",
94
+ "val_dataset = data_module_example.val_dataloader()"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 38,
100
+ "metadata": {},
101
+ "outputs": [
102
+ {
103
+ "data": {
104
+ "text/plain": [
105
+ "(tensor([], size=(0, 60784)),\n",
106
+ " tensor([[ 302., 244., 660., 854., 660., 480., 854., 618., 618., 854.,\n",
107
+ " 790., 750., 659., 59., 891., 891., 536., 167., 343., 536.,\n",
108
+ " 715., 758., 758., 758., 480., 498., 854., 4., 4., 308.,\n",
109
+ " 270., 342., 342., 660., 342., 854., 342., 435., 549., 150.,\n",
110
+ " 631., 485., 844., 366., 266., 35., 847., 667., 862., 109.,\n",
111
+ " 573., 379., 226., 573., 603., 513., 178., 302., 715., 631.,\n",
112
+ " 342., 258., 244., 302., 715., 854., 854., 294., 366., 660.,\n",
113
+ " 361., 302., 729., 962., 790., 711., 660., 243., 294., 802.,\n",
114
+ " 329., 513., 962., 342., 711., 244., 243., 549., 802., 854.,\n",
115
+ " 750., 81., 342., 381., 854., 603., 790., 109., 294., 513.,\n",
116
+ " 419., 485., 504., 660., 361., 790., 790., 167., 802., 246.,\n",
117
+ " 485., 246., 81., 1023., 149., 81., 943., 504., 755., 414.,\n",
118
+ " 246., 972., 715., 1023., 790., 692., 790., 572., 504., 302.,\n",
119
+ " 308., 853., 631., 657., 790., 361., 660., 715., 686., 213.,\n",
120
+ " 226., 187., 586., 361., 485., 790., 729., 951., 962., 485.],\n",
121
+ " [ 963., 645., 645., 326., 138., 1013., 680., 525., 411., 102.,\n",
122
+ " 462., 466., 698., 409., 289., 923., 878., 415., 386., 604.,\n",
123
+ " 975., 162., 603., 284., 233., 75., 244., 1016., 1016., 242.,\n",
124
+ " 67., 194., 122., 492., 856., 997., 997., 221., 243., 814.,\n",
125
+ " 386., 598., 317., 166., 583., 439., 654., 430., 201., 160.,\n",
126
+ " 813., 716., 312., 664., 204., 462., 375., 451., 67., 535.,\n",
127
+ " 854., 209., 548., 812., 657., 827., 408., 411., 422., 352.,\n",
128
+ " 99., 711., 664., 239., 890., 529., 617., 186., 536., 178.,\n",
129
+ " 29., 930., 187., 973., 354., 450., 468., 273., 995., 653.,\n",
130
+ " 935., 335., 973., 812., 348., 664., 575., 184., 299., 782.,\n",
131
+ " 36., 29., 641., 653., 105., 958., 653., 828., 981., 218.,\n",
132
+ " 1021., 381., 356., 35., 416., 675., 45., 839., 690., 331.,\n",
133
+ " 634., 610., 317., 745., 673., 331., 575., 57., 100., 564.,\n",
134
+ " 590., 492., 902., 53., 73., 332., 1005., 395., 679., 781.,\n",
135
+ " 174., 74., 121., 667., 265., 479., 583., 655., 163., 81.]]))"
136
+ ]
137
+ },
138
+ "execution_count": 38,
139
+ "metadata": {},
140
+ "output_type": "execute_result"
141
+ }
142
+ ],
143
+ "source": [
144
+ "val_dataset.dataset[239]"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": null,
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "class MLP(L.LightningModule):\n",
154
+ " def __init__(self, sizes, residual_conections, dropout):\n",
155
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
156
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
157
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
158
+ " super().__init__()\n",
159
+ " self.sizes = sizes\n",
160
+ " self.residual_conections = residual_conections\n",
161
+ " self.dropout = dropout\n",
162
+ " self.layers = nn.Sequential()\n",
163
+ " for i in range(len(sizes)-1):\n",
164
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
165
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
166
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
167
+ "\n",
168
+ " self.loss = nn.CrossEntropyLoss()\n",
169
+ " self.test_outptus = []\n",
170
+ " self.train_outptus = []\n",
171
+ "\n",
172
+ " def forward(self, x):\n",
173
+ " return self.layers(x)\n",
174
+ " \n",
175
+ " def training_step(self, batch, batch_idx):\n",
176
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
177
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
178
+ " embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250] \n",
179
+ " #take just the first 200 embeddings\n",
180
+ " embeddings = embeddings[:, :200]\n",
181
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
182
+ " voxels = voxels[:, 0:2, :]\n",
183
+ " voxels = voxels.mean(dim=1)\n",
184
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
185
+ " outputs = self(voxels)\n",
186
+ " # the outputs are [batch_size, 200*1024], we need to reshape them to [batch_size, 200, 1024]\n",
187
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
188
+ " loss = self.loss(outputs, embeddings)\n",
189
+ " acuracy = self.tokens_accuracy(outputs, embeddings)\n",
190
+ " self.log('train_loss', loss, sync_dist=True)\n",
191
+ " self.log('train_accuracy', acuracy, sync_dist=True)\n",
192
+ " discrete_outputs = outputs.argmax(dim=1)\n",
193
+ " self.train_outptus.append(discrete_outputs)\n",
194
+ " return loss\n",
195
+ " \n",
196
+ " def tokens_accuracy(self, outputs, embeddings):\n",
197
+ " # outputs is [batch_size, 1024, 200]\n",
198
+ " # embeddings is [batch_size, 200]\n",
199
+ " # we need to get the index of the maximum value of each token\n",
200
+ " outputs = outputs.argmax(dim=1)\n",
201
+ " # now we need to compare the outputs with the embeddings\n",
202
+ " return (outputs == embeddings).float().mean()\n",
203
+ " \n",
204
+ " def on_train_epoch_end(self):\n",
205
+ " self.train_outptus = torch.cat(self.train_outptus)\n",
206
+ " # save the outputs with the current epoch name\n",
207
+ " torch.save(self.train_outptus, 'outputs_train'+str(self.current_epoch)+'.pt')\n",
208
+ " self.train_outptus = []\n",
209
+ " \n",
210
+ " def on_validation_epoch_end(self):\n",
211
+ " self.test_outptus = torch.cat(self.test_outptus)\n",
212
+ " # save the outputs with the current epoch name\n",
213
+ " torch.save(self.test_outptus, 'outputs_validation'+str(self.current_epoch)+'.pt')\n",
214
+ " self.test_outptus = []\n",
215
+ "\n",
216
+ " \n",
217
+ " def validation_step(self, batch, batch_idx):\n",
218
+ " voxels, embeddings = batch\n",
219
+ " embeddings = embeddings.flatten(start_dim=1).long()\n",
220
+ " embeddings = embeddings[:, :200]\n",
221
+ " voxels = voxels[:, 0:2, :]\n",
222
+ " voxels = voxels.mean(dim=1)\n",
223
+ " voxels = voxels.flatten(start_dim=1)\n",
224
+ " outputs = self(voxels)\n",
225
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
226
+ " loss = self.loss(outputs, embeddings)\n",
227
+ " accuracy = self.tokens_accuracy(outputs, embeddings)\n",
228
+ " self.log('val_loss', loss, sync_dist=True)\n",
229
+ " self.log('val_accuracy', accuracy, sync_dist=True)\n",
230
+ " discrete_outputs = outputs.argmax(dim=1)\n",
231
+ " self.test_outptus.append(discrete_outputs)\n",
232
+ " return loss\n",
233
+ " \n",
234
+ " \n",
235
+ " def configure_optimizers(self):\n",
236
+ " return torch.optim.Adam(self.parameters(), lr=1e-6)\n",
237
+ " \n",
238
+ "\n",
239
+ "# create the model\n",
240
+ "sizes = [60784, 500, 500, 150*1024]\n",
241
+ "residual_conections = [[0], [1], [2], [3]]\n",
242
+ "dropout = [0.3, 0.3, 0.3, 0.3]\n",
243
+ "model = MLP(sizes, residual_conections, dropout)\n",
244
+ "\n",
245
+ "# create the data module\n",
246
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)\n",
247
+ "\n",
248
+ "wandb.finish()\n",
249
+ "\n",
250
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
251
+ "\n",
252
+ "# define the trainer\n",
253
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
254
+ "\n",
255
+ "# train the model\n",
256
+ "trainer.fit(model, datamodule=data_module)\n"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "class MLP(L.LightningModule):\n",
266
+ " def __init__(self, sizes, residual_conections, dropout):\n",
267
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
268
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
269
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
270
+ " super().__init__()\n",
271
+ " self.sizes = sizes\n",
272
+ " self.residual_conections = residual_conections\n",
273
+ " self.dropout = dropout\n",
274
+ " self.layers = nn.Sequential()\n",
275
+ " for i in range(len(sizes)-1):\n",
276
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
277
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
278
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
279
+ "\n",
280
+ " self.loss = nn.MSELoss()\n",
281
+ " self.test_outptus = []\n",
282
+ " self.train_outptus = []\n",
283
+ "\n",
284
+ " def forward(self, x):\n",
285
+ " return self.layers(x)\n",
286
+ " \n",
287
+ " def training_step(self, batch, batch_idx):\n",
288
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
289
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
290
+ " embeddings = embeddings.flatten(start_dim=1) # the size is [batch_size, 2250]\n",
291
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
292
+ " voxels = voxels.mean(dim=1)\n",
293
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
294
+ " outputs = self(voxels)\n",
295
+ " loss = self.loss(outputs, embeddings)\n",
296
+ " self.log('train_loss', loss)\n",
297
+ " discrete_outputs = outputs.argmax(dim=1)\n",
298
+ " self.train_outptus.append(discrete_outputs)\n",
299
+ " return loss\n",
300
+ " \n",
301
+ " def on_train_epoch_end(self):\n",
302
+ " self.train_outptus = torch.cat(self.train_outptus)\n",
303
+ " # save the outputs with the current epoch name\n",
304
+ " torch.save(self.train_outptus, 'outputs_train'+str(self.current_epoch)+'.pt')\n",
305
+ " self.train_outptus = []\n",
306
+ " \n",
307
+ " def on_validation_epoch_end(self):\n",
308
+ " self.test_outptus = torch.cat(self.test_outptus)\n",
309
+ " # save the outputs with the current epoch name\n",
310
+ " torch.save(self.test_outptus, 'outputs_validation'+str(self.current_epoch)+'.pt')\n",
311
+ " self.test_outptus = []\n",
312
+ "\n",
313
+ " def validation_step(self, batch, batch_idx):\n",
314
+ " voxels, embeddings = batch\n",
315
+ " embeddings = embeddings.flatten(start_dim=1)\n",
316
+ " voxels = voxels.mean(dim=1)\n",
317
+ " voxels = voxels.flatten(start_dim=1)\n",
318
+ " outputs = self(voxels)\n",
319
+ " loss = self.loss(outputs, embeddings)\n",
320
+ " self.log('val_loss', loss)\n",
321
+ " discrete_outputs = outputs.argmax(dim=1)\n",
322
+ " self.test_outptus.append(discrete_outputs)\n",
323
+ " return loss\n",
324
+ " \n",
325
+ " \n",
326
+ " def configure_optimizers(self):\n",
327
+ " return torch.optim.Adam(self.parameters(), lr=1e-5)\n",
328
+ " \n",
329
+ "\n",
330
+ "# create the model\n",
331
+ "sizes = [60784, 1000, 1000, 150*2*1024]\n",
332
+ "residual_conections = [[0], [1], [2], [3]]\n",
333
+ "dropout = [0.5, 0.5, 0.5, 0.5]\n",
334
+ "model = MLP(sizes, residual_conections, dropout)\n",
335
+ "\n",
336
+ "# create the data module\n",
337
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=32)\n",
338
+ "\n",
339
+ "\n",
340
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
341
+ "\n",
342
+ "# define the trainer\n",
343
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
344
+ "\n",
345
+ "# train the model\n",
346
+ "trainer.fit(model, datamodule=data_module)\n"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": null,
352
+ "metadata": {},
353
+ "outputs": [],
354
+ "source": [
355
+ "class MLP(L.LightningModule):\n",
356
+ " def __init__(self, sizes, residual_conections, dropout):\n",
357
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
358
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
359
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
360
+ " super().__init__()\n",
361
+ " self.sizes = sizes\n",
362
+ " self.residual_conections = residual_conections\n",
363
+ " self.dropout = dropout\n",
364
+ " self.layers = nn.Sequential()\n",
365
+ " for i in range(len(sizes)-1):\n",
366
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
367
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
368
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
369
+ "\n",
370
+ " self.loss = nn.CrossEntropyLoss()\n",
371
+ "\n",
372
+ " def forward(self, x):\n",
373
+ " return self.layers(x)\n",
374
+ " \n",
375
+ " def training_step(self, batch, batch_idx):\n",
376
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
377
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
378
+ " embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250] \n",
379
+ " #take just the first 200 embeddings\n",
380
+ " embeddings = embeddings[:, :200]\n",
381
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
382
+ " voxels = voxels.mean(dim=1)\n",
383
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
384
+ " outputs = self(voxels)\n",
385
+ " # the outputs are [batch_size, 200*1024], we need to reshape them to [batch_size, 200, 1024]\n",
386
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
387
+ " loss = self.loss(outputs, embeddings)\n",
388
+ " acuracy = self.tokens_accuracy(outputs, embeddings)\n",
389
+ " self.log('train_loss', loss)\n",
390
+ " self.log('train_accuracy', acuracy)\n",
391
+ " return loss\n",
392
+ " \n",
393
+ " def tokens_accuracy(self, outputs, embeddings):\n",
394
+ " # outputs is [batch_size, 1024, 200]\n",
395
+ " # embeddings is [batch_size, 200]\n",
396
+ " # we need to get the index of the maximum value of each token\n",
397
+ " outputs = outputs.argmax(dim=1)\n",
398
+ " # now we need to compare the outputs with the embeddings\n",
399
+ " return (outputs == embeddings).float().mean()\n",
400
+ "\n",
401
+ " \n",
402
+ " def validation_step(self, batch, batch_idx):\n",
403
+ " voxels, embeddings = batch\n",
404
+ " embeddings = embeddings.flatten(start_dim=1).long()\n",
405
+ " embeddings = embeddings[:, :200]\n",
406
+ " voxels = voxels.mean(dim=1)\n",
407
+ " voxels = voxels.flatten(start_dim=1)\n",
408
+ " outputs = self(voxels)\n",
409
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
410
+ " loss = self.loss(outputs, embeddings)\n",
411
+ " accuracy = self.tokens_accuracy(outputs, embeddings)\n",
412
+ " self.log('val_loss', loss)\n",
413
+ " self.log('val_accuracy', accuracy)\n",
414
+ " return loss\n",
415
+ " \n",
416
+ " \n",
417
+ " def configure_optimizers(self):\n",
418
+ " return torch.optim.Adam(self.parameters(), lr=1e-5)\n",
419
+ " \n",
420
+ "\n",
421
+ "# create the model\n",
422
+ "sizes = [60784, 1000, 1000, 200*1024]\n",
423
+ "residual_conections = [[0], [1], [2], [3]]\n",
424
+ "dropout = [0.5, 0.5, 0.5, 0.5]\n",
425
+ "model = MLP(sizes, residual_conections, dropout)\n",
426
+ "\n",
427
+ "# create the data module\n",
428
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=2)\n",
429
+ "\n",
430
+ "wandb.finish()\n",
431
+ "\n",
432
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
433
+ "\n",
434
+ "# define the trainer\n",
435
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
436
+ "\n",
437
+ "# train the model\n",
438
+ "trainer.fit(model, datamodule=data_module)\n"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "metadata": {},
445
+ "outputs": [],
446
+ "source": [
447
+ "class MLP(L.LightningModule):\n",
448
+ " def __init__(self, sizes, residual_conections, dropout):\n",
449
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
450
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
451
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
452
+ " super().__init__()\n",
453
+ " self.sizes = sizes\n",
454
+ " self.residual_conections = residual_conections\n",
455
+ " self.dropout = dropout\n",
456
+ " self.layers = nn.Sequential()\n",
457
+ " for i in range(len(sizes)-1):\n",
458
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
459
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
460
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
461
+ "\n",
462
+ " self.loss = nn.CrossEntropyLoss()\n",
463
+ "\n",
464
+ " def forward(self, x):\n",
465
+ " return self.layers(x)\n",
466
+ " \n",
467
+ " def training_step(self, batch, batch_idx):\n",
468
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
469
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
470
+ " embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250] \n",
471
+ " #take just the first 200 embeddings\n",
472
+ " embeddings = embeddings[:, :200]\n",
473
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
474
+ " voxels = voxels[:, 1, :]\n",
475
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
476
+ " outputs = self(voxels)\n",
477
+ " # the outputs are [batch_size, 200*1024], we need to reshape them to [batch_size, 200, 1024]\n",
478
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
479
+ " loss = self.loss(outputs, embeddings)\n",
480
+ " acuracy = self.tokens_accuracy(outputs, embeddings)\n",
481
+ " self.log('train_loss', loss)\n",
482
+ " self.log('train_accuracy', acuracy)\n",
483
+ " return loss\n",
484
+ " \n",
485
+ " def tokens_accuracy(self, outputs, embeddings):\n",
486
+ " # outputs is [batch_size, 1024, 200]\n",
487
+ " # embeddings is [batch_size, 200]\n",
488
+ " # we need to get the index of the maximum value of each token\n",
489
+ " outputs = outputs.argmax(dim=1)\n",
490
+ " # now we need to compare the outputs with the embeddings\n",
491
+ " return (outputs == embeddings).float().mean()\n",
492
+ "\n",
493
+ " \n",
494
+ " def validation_step(self, batch, batch_idx):\n",
495
+ " voxels, embeddings = batch\n",
496
+ " embeddings = embeddings.flatten(start_dim=1).long()\n",
497
+ " embeddings = embeddings[:, :200]\n",
498
+ " voxels = voxels[:, 1, :]\n",
499
+ " voxels = voxels.flatten(start_dim=1)\n",
500
+ " outputs = self(voxels)\n",
501
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
502
+ " loss = self.loss(outputs, embeddings)\n",
503
+ " accuracy = self.tokens_accuracy(outputs, embeddings)\n",
504
+ " self.log('val_loss', loss)\n",
505
+ " self.log('val_accuracy', accuracy)\n",
506
+ " return loss\n",
507
+ " \n",
508
+ " \n",
509
+ " def configure_optimizers(self):\n",
510
+ " return torch.optim.Adam(self.parameters(), lr=1e-6)\n",
511
+ " \n",
512
+ "\n",
513
+ "# create the model\n",
514
+ "sizes = [60784, 1000, 1000, 200*1024]\n",
515
+ "residual_conections = [[0], [1], [2], [3]]\n",
516
+ "dropout = [0.2, 0.2, 0.2, 0.2]\n",
517
+ "model = MLP(sizes, residual_conections, dropout)\n",
518
+ "\n",
519
+ "# create the data module\n",
520
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)\n",
521
+ "\n",
522
+ "wandb.finish()\n",
523
+ "\n",
524
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
525
+ "\n",
526
+ "# define the trainer\n",
527
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
528
+ "\n",
529
+ "# train the model\n",
530
+ "trainer.fit(model, datamodule=data_module)\n"
531
+ ]
532
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": null,
536
+ "metadata": {},
537
+ "outputs": [],
538
+ "source": [
539
+ "model3.eval()\n",
540
+ "outputs = torch.Tensor((480,200))\n",
541
+ "with torch.no_grad():\n",
542
+ " test_dataset = VoxelsDataset(test_voxels_path, test_embeddings_path)\n",
543
+ " dataloader = data.DataLoader(test_dataset, batch_size = 2)\n",
544
+ " for i, (voxels, embeddings) in enumerate(dataloader):\n",
545
+ " voxels = voxels[:, 1, :]\n",
546
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
547
+ " bout = model3(voxels)\n",
548
+ " bout = bout.reshape(-1, 1024, 200)\n",
549
+ " # the 1024 dimension is the number of tokens, we need to get the index of the maximum value of each token\n",
550
+ " bout = bout.argmax(dim=1)\n",
551
+ " # now we need to add the outputs to the outputs tensor\n",
552
+ " outputs[i*2:(i+1)*2] = bout\n",
553
+ " \n",
554
+ " \n",
555
+ "# save the predicted outputs on the current directory\n",
556
+ "torch.save(outputs, 'outputs.pt')"
557
+ ]
558
+ }
559
+ ],
560
+ "metadata": {
561
+ "kernelspec": {
562
+ "display_name": "Python 3 (ipykernel)",
563
+ "language": "python",
564
+ "name": "python3"
565
+ },
566
+ "language_info": {
567
+ "codemirror_mode": {
568
+ "name": "ipython",
569
+ "version": 3
570
+ },
571
+ "file_extension": ".py",
572
+ "mimetype": "text/x-python",
573
+ "name": "python",
574
+ "nbconvert_exporter": "python",
575
+ "pygments_lexer": "ipython3",
576
+ "version": "3.10.8"
577
+ }
578
+ },
579
+ "nbformat": 4,
580
+ "nbformat_minor": 4
581
+ }
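A minimal sketch of how one of the per-epoch token dumps written by on_train_epoch_end above (e.g. src/outputs_train0.pt from this upload) could be loaded and sanity-checked; the exact shape and dtype depend on which cell produced the file, so the printout is only illustrative.

import torch

preds = torch.load('outputs_train0.pt', map_location='cpu')         # argmax token ids saved each epoch
print(type(preds), preds.shape, preds.dtype)
print('token id range:', int(preds.min()), '-', int(preds.max()))   # EnCodec codebook ids are 0..1023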
src/MLP-model.ipynb ADDED
@@ -0,0 +1,448 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv\n",
10
+ "import lightning as L\n",
11
+ "import numpy as np, pandas as pd, matplotlib.pyplot as plt\n",
12
+ "from pytorch_lightning.loggers import WandbLogger"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "# create the datasets and dataloaders\n",
22
+ "train_voxels_path = '/home/ckadirt/brain2music/dataset/preproc/sub-001_Resp_Training.npy' # path to training voxels 65000 * 4800 \n",
23
+ "test_voxels_path = '/home/ckadirt/brain2music/dataset/preproc/sub-001_Resp_Test_Mean.npy' # path to test voxels 65000 * 600\n",
24
+ "\n",
25
+ "train_embeddings_path = '/home/ckadirt/brain2music/dataset/Gtanz/audios/sub-001/encodec_embeddings_train.pt' # path to training embeddings 480 * 2 * 1125\n",
26
+ "test_embeddings_path = '/home/ckadirt/brain2music/dataset/Gtanz/audios/sub-001/encodec_embeddings_test.pt' # path to test embeddings 600 * 2 * 1125\n",
27
+ "\n",
28
+ "class VoxelsDataset(data.Dataset):\n",
29
+ " def __init__(self, voxels_path, embeddings_path):\n",
30
+ " # transpose the two dimensions of the voxels data to match the embeddings data\n",
31
+ " self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)\n",
32
+ " self.embeddings = torch.load(embeddings_path)\n",
33
+ " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n",
34
+ " self.len = len(self.voxels) // 10\n",
35
+ "\n",
36
+ " def __getitem__(self, index):\n",
37
+ " # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus\n",
38
+ " voxels = self.voxels[index*10:(index+1)*10]\n",
39
+ " embeddings = self.embeddings[index]\n",
40
+ " return voxels, embeddings\n",
41
+ "\n",
42
+ " def __len__(self):\n",
43
+ " return self.len\n",
44
+ " \n",
45
+ "class VoxelsEmbeddinsEncodecDataModule(L.LightningDataModule):\n",
46
+ " def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=32):\n",
47
+ " super().__init__()\n",
48
+ " self.train_voxels_path = train_voxels_path\n",
49
+ " self.train_embeddings_path = train_embeddings_path\n",
50
+ " self.test_voxels_path = test_voxels_path\n",
51
+ " self.test_embeddings_path = test_embeddings_path\n",
52
+ " self.batch_size = batch_size\n",
53
+ "\n",
54
+ " def setup(self, stage=None):\n",
55
+ " self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)\n",
56
+ " self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)\n",
57
+ "\n",
58
+ " def train_dataloader(self):\n",
59
+ " return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)\n",
60
+ "\n",
61
+ " def test_dataloader(self):\n",
62
+ " return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)\n"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "class MLP(L.LightningModule):\n",
72
+ " def __init__(self, sizes, residual_conections, dropout):\n",
73
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
74
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
75
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
76
+ " super().__init__()\n",
77
+ " self.sizes = sizes\n",
78
+ " self.residual_conections = residual_conections\n",
79
+ " self.dropout = dropout\n",
80
+ " self.layers = nn.ModuleList()\n",
81
+ " for i in range(len(sizes)-1):\n",
82
+ " self.layers.append(nn.Linear(sizes[i], sizes[i+1]))\n",
83
+ " self.relu = nn.ReLU()\n",
84
+ " self.loss = nn.MSELoss()\n",
85
+ "\n",
86
+ " def forward(self, x):\n",
87
+ " x_states = [x]\n",
88
+ " for i in range(len(self.layers)):\n",
89
+ " x = self.layers[i](x)\n",
90
+ " for j in self.residual_conections[i]:\n",
91
+ " x = x + x_states[j]\n",
92
+ " x = self.relu(x)\n",
93
+ " x = nn.Dropout(self.dropout[i])(x)\n",
94
+ " x_states.append(x)\n",
95
+ "\n",
96
+ " return x\n",
97
+ " \n",
98
+ " def training_step(self, batch, batch_idx):\n",
99
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
100
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
101
+ " embeddings = embeddings.flatten(start_dim=1) # the size is [batch_size, 2250]\n",
102
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
103
+ " voxels = voxels.mean(dim=1)\n",
104
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
105
+ " outputs = self(voxels)\n",
106
+ " loss = self.loss(outputs, embeddings)\n",
107
+ " self.log('train_loss', loss)\n",
108
+ " return loss\n",
109
+ " \n",
110
+ " def validation_step(self, batch, batch_idx):\n",
111
+ " voxels, embeddings = batch\n",
112
+ " embeddings = embeddings.flatten(start_dim=1)\n",
113
+ " voxels = voxels.mean(dim=1)\n",
114
+ " voxels = voxels.flatten(start_dim=1)\n",
115
+ " outputs = self(voxels)\n",
116
+ " loss = self.loss(outputs, embeddings)\n",
117
+ " self.log('val_loss', loss)\n",
118
+ " return loss\n",
119
+ " \n",
120
+ " \n",
121
+ " def configure_optimizers(self):\n",
122
+ " return torch.optim.Adam(self.parameters(), lr=1e-3)\n",
123
+ " \n",
124
+ "\n",
125
+ "# create the model\n",
126
+ "sizes = [60784, 1000, 1000, 2250]\n",
127
+ "residual_conections = [[0], [1], [2,1], [3]]\n",
128
+ "dropout = [0.5, 0.5, 0.5, 0.5]\n",
129
+ "model = MLP(sizes, residual_conections, dropout)\n",
130
+ "\n",
131
+ "# create the data module\n",
132
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=32)\n",
133
+ "\n",
134
+ "\n",
135
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
136
+ "\n",
137
+ "# define the trainer\n",
138
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=100, logger=wandb_logger, precision='16-mixed')\n",
139
+ "\n",
140
+ "# train the model\n",
141
+ "trainer.fit(model, data_module)\n"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "class MLP(L.LightningModule):\n",
151
+ " def __init__(self, sizes, residual_conections, dropout):\n",
152
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
153
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
154
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
155
+ " super().__init__()\n",
156
+ " self.sizes = sizes\n",
157
+ " self.residual_conections = residual_conections\n",
158
+ " self.dropout = dropout\n",
159
+ " self.layers = nn.Sequential()\n",
160
+ " for i in range(len(sizes)-1):\n",
161
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
162
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
163
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
164
+ "\n",
165
+ " self.loss = nn.MSELoss()\n",
166
+ "\n",
167
+ " def forward(self, x):\n",
168
+ " return self.layers(x)\n",
169
+ " \n",
170
+ " def training_step(self, batch, batch_idx):\n",
171
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
172
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
173
+ " embeddings = embeddings.flatten(start_dim=1) # the size is [batch_size, 2250]\n",
174
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
175
+ " voxels = voxels.mean(dim=1)\n",
176
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
177
+ " outputs = self(voxels)\n",
178
+ " loss = self.loss(outputs, embeddings)\n",
179
+ " self.log('train_loss', loss)\n",
180
+ " return loss\n",
181
+ " \n",
182
+ " def validation_step(self, batch, batch_idx):\n",
183
+ " voxels, embeddings = batch\n",
184
+ " embeddings = embeddings.flatten(start_dim=1)\n",
185
+ " voxels = voxels.mean(dim=1)\n",
186
+ " voxels = voxels.flatten(start_dim=1)\n",
187
+ " outputs = self(voxels)\n",
188
+ " loss = self.loss(outputs, embeddings)\n",
189
+ " self.log('val_loss', loss)\n",
190
+ " return loss\n",
191
+ " \n",
192
+ " \n",
193
+ " def configure_optimizers(self):\n",
194
+ " return torch.optim.Adam(self.parameters(), lr=1e-5)\n",
195
+ " \n",
196
+ "\n",
197
+ "# create the model\n",
198
+ "sizes = [60784, 1000, 1000, 2250]\n",
199
+ "residual_conections = [[0], [1], [2], [3]]\n",
200
+ "dropout = [0.5, 0.5, 0.5, 0.5]\n",
201
+ "model = MLP(sizes, residual_conections, dropout)\n",
202
+ "\n",
203
+ "# create the data module\n",
204
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=32)\n",
205
+ "\n",
206
+ "\n",
207
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
208
+ "\n",
209
+ "# define the trainer\n",
210
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
211
+ "\n",
212
+ "# train the model\n",
213
+ "trainer.fit(model, datamodule=data_module)\n"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "metadata": {},
220
+ "outputs": [],
221
+ "source": [
222
+ "class MLP(L.LightningModule):\n",
223
+ " def __init__(self, sizes, residual_conections, dropout):\n",
224
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
225
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
226
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
227
+ " super().__init__()\n",
228
+ " self.sizes = sizes\n",
229
+ " self.residual_conections = residual_conections\n",
230
+ " self.dropout = dropout\n",
231
+ " self.layers = nn.Sequential()\n",
232
+ " for i in range(len(sizes)-1):\n",
233
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
234
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
235
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
236
+ "\n",
237
+ " self.loss = nn.CrossEntropyLoss()\n",
238
+ "\n",
239
+ " def forward(self, x):\n",
240
+ " return self.layers(x)\n",
241
+ " \n",
242
+ " def training_step(self, batch, batch_idx):\n",
243
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
244
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
245
+ " embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250] \n",
246
+ " #take just the first 200 embeddings\n",
247
+ " embeddings = embeddings[:, :200]\n",
248
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
249
+ " voxels = voxels.mean(dim=1)\n",
250
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
251
+ " outputs = self(voxels)\n",
252
+ " # the outputs are [batch_size, 200*1024], we need to reshape them to [batch_size, 200, 1024]\n",
253
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
254
+ " loss = self.loss(outputs, embeddings)\n",
255
+ " acuracy = self.tokens_accuracy(outputs, embeddings)\n",
256
+ " self.log('train_loss', loss)\n",
257
+ " self.log('train_accuracy', acuracy)\n",
258
+ " return loss\n",
259
+ " \n",
260
+ " def tokens_accuracy(self, outputs, embeddings):\n",
261
+ " # outputs is [batch_size, 1024, 200]\n",
262
+ " # embeddings is [batch_size, 200]\n",
263
+ " # we need to get the index of the maximum value of each token\n",
264
+ " outputs = outputs.argmax(dim=1)\n",
265
+ " # now we need to compare the outputs with the embeddings\n",
266
+ " return (outputs == embeddings).float().mean()\n",
267
+ "\n",
268
+ " \n",
269
+ " def validation_step(self, batch, batch_idx):\n",
270
+ " voxels, embeddings = batch\n",
271
+ " embeddings = embeddings.flatten(start_dim=1).long()\n",
272
+ " embeddings = embeddings[:, :200]\n",
273
+ " voxels = voxels.mean(dim=1)\n",
274
+ " voxels = voxels.flatten(start_dim=1)\n",
275
+ " outputs = self(voxels)\n",
276
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
277
+ " loss = self.loss(outputs, embeddings)\n",
278
+ " accuracy = self.tokens_accuracy(outputs, embeddings)\n",
279
+ " self.log('val_loss', loss)\n",
280
+ " self.log('val_accuracy', accuracy)\n",
281
+ " return loss\n",
282
+ " \n",
283
+ " \n",
284
+ " def configure_optimizers(self):\n",
285
+ " return torch.optim.Adam(self.parameters(), lr=1e-5)\n",
286
+ " \n",
287
+ "\n",
288
+ "# create the model\n",
289
+ "sizes = [60784, 1000, 1000, 200*1024]\n",
290
+ "residual_conections = [[0], [1], [2], [3]]\n",
291
+ "dropout = [0.5, 0.5, 0.5, 0.5]\n",
292
+ "model = MLP(sizes, residual_conections, dropout)\n",
293
+ "\n",
294
+ "# create the data module\n",
295
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=2)\n",
296
+ "\n",
297
+ "wandb.finish()\n",
298
+ "\n",
299
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
300
+ "\n",
301
+ "# define the trainer\n",
302
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
303
+ "\n",
304
+ "# train the model\n",
305
+ "trainer.fit(model, datamodule=data_module)\n"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": null,
311
+ "metadata": {},
312
+ "outputs": [],
313
+ "source": [
314
+ "class MLP(L.LightningModule):\n",
315
+ " def __init__(self, sizes, residual_conections, dropout):\n",
316
+ " # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]\n",
317
+ " # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]\n",
318
+ " # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]\n",
319
+ " super().__init__()\n",
320
+ " self.sizes = sizes\n",
321
+ " self.residual_conections = residual_conections\n",
322
+ " self.dropout = dropout\n",
323
+ " self.layers = nn.Sequential()\n",
324
+ " for i in range(len(sizes)-1):\n",
325
+ " self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))\n",
326
+ " self.layers.add_module('relu'+str(i), nn.ReLU())\n",
327
+ " self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))\n",
328
+ "\n",
329
+ " self.loss = nn.CrossEntropyLoss()\n",
330
+ "\n",
331
+ " def forward(self, x):\n",
332
+ " return self.layers(x)\n",
333
+ " \n",
334
+ " def training_step(self, batch, batch_idx):\n",
335
+ " voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]\n",
336
+ " # flatten the voxels to [batch_size, rest of the dimensions]\n",
337
+ " embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250] \n",
338
+ " #take just the first 200 embeddings\n",
339
+ " embeddings = embeddings[:, :200]\n",
340
+ " # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus\n",
341
+ " voxels = voxels[:, 1, :]\n",
342
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
343
+ " outputs = self(voxels)\n",
344
+ " # the outputs are [batch_size, 200*1024], we need to reshape them to [batch_size, 200, 1024]\n",
345
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
346
+ " loss = self.loss(outputs, embeddings)\n",
347
+ " acuracy = self.tokens_accuracy(outputs, embeddings)\n",
348
+ " self.log('train_loss', loss)\n",
349
+ " self.log('train_accuracy', acuracy)\n",
350
+ " return loss\n",
351
+ " \n",
352
+ " def tokens_accuracy(self, outputs, embeddings):\n",
353
+ " # outputs is [batch_size, 1024, 200]\n",
354
+ " # embeddings is [batch_size, 200]\n",
355
+ " # we need to get the index of the maximum value of each token\n",
356
+ " outputs = outputs.argmax(dim=1)\n",
357
+ " # now we need to compare the outputs with the embeddings\n",
358
+ " return (outputs == embeddings).float().mean()\n",
359
+ "\n",
360
+ " \n",
361
+ " def validation_step(self, batch, batch_idx):\n",
362
+ " voxels, embeddings = batch\n",
363
+ " embeddings = embeddings.flatten(start_dim=1).long()\n",
364
+ " embeddings = embeddings[:, :200]\n",
365
+ " voxels = voxels[:, 1, :]\n",
366
+ " voxels = voxels.flatten(start_dim=1)\n",
367
+ " outputs = self(voxels)\n",
368
+ " outputs = outputs.reshape(-1, 1024, 200)\n",
369
+ " loss = self.loss(outputs, embeddings)\n",
370
+ " accuracy = self.tokens_accuracy(outputs, embeddings)\n",
371
+ " self.log('val_loss', loss)\n",
372
+ " self.log('val_accuracy', accuracy)\n",
373
+ " return loss\n",
374
+ " \n",
375
+ " \n",
376
+ " def configure_optimizers(self):\n",
377
+ " return torch.optim.Adam(self.parameters(), lr=1e-6)\n",
378
+ " \n",
379
+ "\n",
380
+ "# create the model\n",
381
+ "sizes = [60784, 1000, 1000, 200*1024]\n",
382
+ "residual_conections = [[0], [1], [2], [3]]\n",
383
+ "dropout = [0.2, 0.2, 0.2, 0.2]\n",
384
+ "model = MLP(sizes, residual_conections, dropout)\n",
385
+ "\n",
386
+ "# create the data module\n",
387
+ "data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)\n",
388
+ "\n",
389
+ "wandb.finish()\n",
390
+ "\n",
391
+ "wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')\n",
392
+ "\n",
393
+ "# define the trainer\n",
394
+ "trainer = L.Trainer(devices=2, accelerator=\"gpu\", max_epochs=400, logger=wandb_logger, precision='16-mixed', log_every_n_steps=10)\n",
395
+ "\n",
396
+ "# train the model\n",
397
+ "trainer.fit(model, datamodule=data_module)\n"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "model3.eval()\n",
407
+ "outputs = torch.Tensor((480,200))\n",
408
+ "with torch.no_grad():\n",
409
+ " test_dataset = VoxelsDataset(test_voxels_path, test_embeddings_path)\n",
410
+ " dataloader = data.DataLoader(test_dataset, batch_size = 2)\n",
411
+ " for i, (voxels, embeddings) in enumerate(dataloader):\n",
412
+ " voxels = voxels[:, 1, :]\n",
413
+ " voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]\n",
414
+ " bout = model3(voxels)\n",
415
+ " bout = bout.reshape(-1, 1024, 200)\n",
416
+ " # the 1024 dimension is the number of tokens, we need to get the index of the maximum value of each token\n",
417
+ " bout = bout.argmax(dim=1)\n",
418
+ " # now we need to add the outputs to the outputs tensor\n",
419
+ " outputs[i*2:(i+1)*2] = bout\n",
420
+ " \n",
421
+ " \n",
422
+ "# save the predicted outputs on the current directory\n",
423
+ "torch.save(outputs, 'outputs.pt')"
424
+ ]
425
+ }
426
+ ],
427
+ "metadata": {
428
+ "kernelspec": {
429
+ "display_name": "Python 3 (ipykernel)",
430
+ "language": "python",
431
+ "name": "python3"
432
+ },
433
+ "language_info": {
434
+ "codemirror_mode": {
435
+ "name": "ipython",
436
+ "version": 3
437
+ },
438
+ "file_extension": ".py",
439
+ "mimetype": "text/x-python",
440
+ "name": "python",
441
+ "nbconvert_exporter": "python",
442
+ "pygments_lexer": "ipython3",
443
+ "version": "3.10.8"
444
+ }
445
+ },
446
+ "nbformat": 4,
447
+ "nbformat_minor": 4
448
+ }
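
The last two model variants in MLP-model.ipynb treat decoding as per-position token classification: the MLP emits 200*1024 logits that are reshaped so the 1024-way EnCodec codebook sits on dim 1, which is where nn.CrossEntropyLoss expects the class axis for multi-dimensional targets. A minimal, self-contained sketch of that shape convention (the batch size and tensor names here are illustrative, not taken from the notebook):

import torch
import torch.nn as nn

batch, vocab, seq = 4, 1024, 200
logits = torch.randn(batch, vocab * seq).reshape(-1, vocab, seq)    # [4, 1024, 200]: classes on dim 1
targets = torch.randint(0, vocab, (batch, seq))                     # [4, 200]: integer token ids

loss = nn.CrossEntropyLoss()(logits, targets)                       # averaged per-position cross-entropy
accuracy = (logits.argmax(dim=1) == targets).float().mean()         # same metric as tokens_accuracy above
print(loss.item(), accuracy.item())
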
src/MLPencoder.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/b2m-ckpt1 ADDED
Binary file (184 kB). View file
 
src/b2m-ckpt1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:324b3fec6e3363f45b6e886219effec76e6a8a8b207cb6391eb7c0829b106484
3
+ size 183603
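
b2m-ckpt1.pt is added here only as a Git LFS pointer (about 184 kB of payload). Assuming it was produced with torch.save from one of the models in this repo, which the commit itself does not confirm, it could be inspected with a couple of lines:

import torch

ckpt = torch.load('src/b2m-ckpt1.pt', map_location='cpu')   # hypothetical: the contents are not described in the commit
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])                           # e.g. state_dict keys or Lightning checkpoint fields
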
src/mlpdummy.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, torch, torch.nn as nn, torch.utils.data as data, torchvision as tv
2
+ import lightning as L
3
+ import numpy as np, pandas as pd, matplotlib.pyplot as plt
4
+ from pytorch_lightning.loggers import WandbLogger
5
+ import wandb
6
+ import pytorch_lightning as pl
7
+
8
+ torch.set_float32_matmul_precision('medium')
9
+
10
+ # create the datasets and dataloaders
11
+ train_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Training.npy' # path to training voxels 65000 * 4800
12
+ test_voxels_path = '/fsx/proj-fmri/ckadirt/b2m/data/sub-001_Resp_Test_Mean.npy' # path to test voxels 65000 * 600
13
+
14
+ train_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec_training_embeds_sorted.npy' # path to training embeddings 480 * 2 * 1125
15
+ test_embeddings_path = '/fsx/proj-fmri/ckadirt/b2m/data/encodec_testing_embeds_sorted.npy' # path to test embeddings 600 * 2 * 1125
16
+
17
+ class VoxelsDataset(data.Dataset):
18
+ def __init__(self, voxels_path, embeddings_path):
19
+ # transpose the two dimensions of the voxels data to match the embeddings data
20
+ self.voxels = torch.from_numpy(np.load(voxels_path)).float().transpose(0, 1)
21
+ self.embeddings = torch.from_numpy(np.load(embeddings_path))
22
+ # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus
23
+ self.len = len(self.voxels) // 10
24
+ print("The len is ", self.len )
25
+
26
+ def __getitem__(self, index):
27
+ # as each stimulus has been exposed for 15 seconds and the fMRI data is sampled every 1.5 seconds, we take 10 samples per stimulus
28
+ voxels = self.voxels[index*10:(index+1)*10]
29
+ embeddings = self.embeddings[index]
30
+ return voxels, embeddings
31
+
32
+ def __len__(self):
33
+ return self.len
34
+
35
+ class VoxelsEmbeddinsEncodecDataModule(pl.LightningDataModule):
36
+ def __init__(self, train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=8):
37
+ super().__init__()
38
+ self.train_voxels_path = train_voxels_path
39
+ self.train_embeddings_path = train_embeddings_path
40
+ self.test_voxels_path = test_voxels_path
41
+ self.test_embeddings_path = test_embeddings_path
42
+ self.batch_size = batch_size
43
+
44
+ def setup(self, stage=None):
45
+ self.train_dataset = VoxelsDataset(self.train_voxels_path, self.train_embeddings_path)
46
+ self.test_dataset = VoxelsDataset(self.test_voxels_path, self.test_embeddings_path)
47
+
48
+ def train_dataloader(self):
49
+ return data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
50
+
51
+ def val_dataloader(self):
52
+ return data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
53
+
54
+
55
+ class MLP(pl.LightningModule):
56
+ def __init__(self, sizes, residual_conections, dropout):
57
+ # sizes is a list of the sizes of the layers ej: [4800, 1000, 1000, 1000, 1000, 1000, 1000, 600]
58
+ # residual_conections is a list with the same length as sizes, each element is a list of the indexes of the layers that will recieve the output of the layer as input, 0 means that the layer will recieve the x inputs ej. [[0], [1], [2,1], [3], [4,3], [5], [6,5], [7]]
59
+ # dropout is a list with the same length as sizes, each element is the dropout probability of the layer ej. [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
60
+ super().__init__()
61
+ self.sizes = sizes
62
+ self.residual_conections = residual_conections
63
+ self.dropout = dropout
64
+ self.layers = nn.Sequential()
65
+ for i in range(len(sizes)-1):
66
+ self.layers.add_module('linear'+str(i), nn.Linear(sizes[i], sizes[i+1]))
67
+ self.layers.add_module('relu'+str(i), nn.ReLU())
68
+ self.layers.add_module('dropout'+str(i), nn.Dropout(dropout[i]))
69
+
70
+ self.loss = nn.CrossEntropyLoss(reduction='mean')
71
+
72
+ def forward(self, x):
73
+ return self.layers(x)
74
+
75
+ def training_step(self, batch, batch_idx):
76
+ voxels, embeddings = batch # the sizes are [batch_size, 10, 65000] and [batch_size, 2, 1125]
77
+ # flatten the voxels to [batch_size, rest of the dimensions]
78
+ embeddings = embeddings.flatten(start_dim=1).long() # the size is [batch_size, 2250]
79
+ #take just the first 200 embeddings
80
+ # embeddings = embeddings[:, :200]
81
+ # take the mean of the second dimension of the voxels to get the mean of the 10 samples per stimulus
82
+ voxels = voxels.mean(dim=1)
83
+ voxels = voxels.flatten(start_dim=1) # the size is [batch_size, 65000]
84
+ outputs = self(voxels)
85
+ # the outputs are [batch_size, 200*1024], we need to reshape them to [batch_size, 200, 1024]
86
+ outputs = outputs.reshape(-1, 1024, 1125*2)
87
+ # avoid division by zero
88
+ outputs = outputs + 1e-6
89
+ #print(outputs.shape, embeddings.shape)
90
+ #print(outputs[0,0,:10], embeddings[0,:10])
91
+ loss = self.loss(outputs, embeddings)
92
+ #print(loss)
93
+ acuracy = self.tokens_accuracy(outputs, embeddings)
94
+ self.log('train_loss', loss)
95
+ self.log('train_accuracy', acuracy)
96
+ return loss
97
+
98
+ def tokens_accuracy(self, outputs, embeddings):
99
+ # outputs is [batch_size, 1024, 200]
100
+ # embeddings is [batch_size, 200]
101
+ # we need to get the index of the maximum value of each token
102
+ outputs = outputs.argmax(dim=1)
103
+ # now we need to compare the outputs with the embeddings
104
+ return (outputs == embeddings).float().mean()
105
+
106
+
107
+ def validation_step(self, batch, batch_idx):
108
+ voxels, embeddings = batch
109
+ embeddings = embeddings.flatten(start_dim=1).long()
110
+ #embeddings = embeddings[:, :200]
111
+ voxels = voxels.mean(dim=1)
112
+ voxels = voxels.flatten(start_dim=1)
113
+ outputs = self(voxels)
114
+ outputs = outputs.reshape(-1, 1024, 1125*2)
115
+ loss = self.loss(outputs, embeddings)
116
+ accuracy = self.tokens_accuracy(outputs, embeddings)
117
+ self.log('val_loss', loss)
118
+ self.log('val_accuracy', accuracy)
119
+ return loss
120
+
121
+
122
+ def configure_optimizers(self):
123
+ return torch.optim.Adam(self.trainer.model.parameters(), lr=2e-5, weight_decay=3e-3)
124
+
125
+
126
+ # create the model
127
+ sizes = [60784, 1000, 1000, 1125*2*1024]
128
+ residual_conections = [[0], [1], [2], [3]]
129
+ dropout = [0.5, 0.5, 0.5, 0.5]
130
+ model = MLP(sizes, residual_conections, dropout)
131
+
132
+
133
+ # create the data module
134
+ data_module = VoxelsEmbeddinsEncodecDataModule(train_voxels_path, train_embeddings_path, test_voxels_path, test_embeddings_path, batch_size=4)
135
+
136
+ wandb.finish()
137
+ from pytorch_lightning.strategies import DeepSpeedStrategy
138
+
139
+ wandb_logger = WandbLogger(project='brain2music', entity='ckadirt')
140
+
141
+ # define the trainer
142
+ trainer = pl.Trainer(accelerator="gpu", devices = [0,1,2,3,4,5,6,7], max_epochs=1000, logger=wandb_logger, precision='32', strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=8), enable_checkpointing=False, log_every_n_steps=10)
143
+ #trainer = pl.Trainer(accelerator="gpu", devices = [0,1,2,3], max_epochs=1000, logger=wandb_logger, precision='bf16', strategy='fsdp', enable_checkpointing=False, log_every_n_steps=10)
144
+ # train the model
145
+ trainer.fit(model, datamodule=data_module)
146
+
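
One thing worth noting about mlpdummy.py is the size of its output head: the last linear layer maps 1000 features to 1125*2*1024 logits, so the model is parameter-heavy even though it is a plain MLP, which helps explain the DeepSpeed ZeRO stage-3 strategy in the trainer. A rough count, using only the sizes configured above:

sizes = [60784, 1000, 1000, 1125 * 2 * 1024]                          # as configured in mlpdummy.py
params = sum(i * o + o for i, o in zip(sizes[:-1], sizes[1:]))        # weights + biases for each nn.Linear
print(f'{params / 1e9:.2f}B parameters')                              # ~2.37B, dominated by the 1000 -> 2,304,000 head
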
src/musicgen_test copy.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/musicgen_test.ipynb ADDED
The diff for this file is too large to render. See raw diff
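
The outputs_train*.pt files that follow are Git LFS pointers to saved prediction tensors (roughly 11.5 MB each). Assuming they were written with torch.save the same way outputs.pt is in the notebooks and share a common trailing shape, a sketch for stitching them back into one tensor could look like this (the concatenation axis is an assumption):

import glob
import torch

files = sorted(glob.glob('src/outputs_train*.pt'),
               key=lambda p: int(p.split('outputs_train')[1].split('.pt')[0]))   # numeric, not lexicographic, order
chunks = [torch.load(f, map_location='cpu') for f in files]
all_preds = torch.cat(chunks, dim=0)                                             # assumes chunks stack along dim 0
print(all_preds.shape)
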
 
src/outputs_train0.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c329bb538c038cb0997e4d8a28c1194b985fe749d335acdd217d4584d5a5f33
3
+ size 11505408
src/outputs_train1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b943d8ad8171f749eb949a15b2b533febef14b63766a3822216ab6ca0136a7df
3
+ size 11505408
src/outputs_train10.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44220fbaaae452574cbbf3ee994930e1f82b71a80e87d413fa7c1e9e982ba628
3
+ size 11505411
src/outputs_train11.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c724ecb56118655468b1f029e21b26bccfda1a2fbbaa22384c807e64ab04eba
3
+ size 11505411
src/outputs_train12.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:471d0f66b3f174442b1001bdab8672f22d5f0a8958de547beb332ee7bf3326f2
3
+ size 11505411
src/outputs_train13.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a25c407a91bdd02cbc45ab829c1bb41c74865611d2e0ccab2129f266c3002f3
3
+ size 11505411
src/outputs_train14.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89b0eeb9b9aae498e1caaed017ee6d6f7d8819e388853fbb3f0f510b87b3f54b
3
+ size 11505411
src/outputs_train15.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21c8ee497ee20e58bbb164ae8bdece4d11fcc49b956d9024dd9d445acdd655cc
3
+ size 11505411
src/outputs_train16.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc9c20f0192317d604a06ea1f0bece0e292c9c3b5d7fbe4f38baa048f2c21ea6
3
+ size 11505411
src/outputs_train17.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bcf1d56c5e018539f8df4c27ceae6d4586f7f337597efee39575520c94bcef02
3
+ size 11505411
src/outputs_train18.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f43266b7e5fd07645e2775b51133e13c052530715e0a94885bc44b5ad80943f
3
+ size 11505411
src/outputs_train19.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22e50667a061182431fb5b6e0f1d973942d813a399baad1fa2289197a4ba61e8
3
+ size 11505411
src/outputs_train2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31e397b8e34a6b00ce8ec2930269f78323b5a5717586a941d50fc6d7a11aab09
3
+ size 11505408
src/outputs_train20.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06b1e7abd9058661a98adfc9416e59563bc2c81f5e722748c6a857b8e646dc44
3
+ size 11505411
src/outputs_train21.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:700abd1ea9a42a56deea97287da1ed5991cdac64d3ff2b254c122025e02445de
3
+ size 11505411
src/outputs_train22.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9eb089a91056c1c9fe77240d15c89298c9124e0f82a6c38d29c555cedeb025ec
3
+ size 11505411
src/outputs_train23.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47c6f65fc2cc3728caa62269ced13927b8d7de17190e2c9a0068a7d57181a1aa
3
+ size 11505411
src/outputs_train24.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30752d07d2360df5b65bc9ce02cfe388f85cdb4b68b966e4a701055ba4049025
3
+ size 11505411