diff --git "a/Experiments.ipynb" "b/Experiments.ipynb" --- "a/Experiments.ipynb" +++ "b/Experiments.ipynb" @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "id": "c2807819", "metadata": {}, "outputs": [], @@ -20,6 +20,27 @@ "import wandb\n" ] }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cdd08230-c057-4a6e-83b9-435b2c0fbaaf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'1.13.0+cu117'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.__version__" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -30,13 +51,14 @@ "name": "stderr", "output_type": "stream", "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmattricesound\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ - "Tracking run with wandb version 0.13.5" + "Tracking run with wandb version 0.13.6" ], "text/plain": [ "" @@ -48,7 +70,7 @@ { "data": { "text/html": [ - "Run data is saved locally in /Users/matthewrice/Developer/remfx/wandb/run-20221203_231820-3tdw8zp6" + "Run data is saved locally in /home/jovyan/RemFx/wandb/run-20221209_160820-9wzgwfl3" ], "text/plain": [ "" @@ -60,7 +82,7 @@ { "data": { "text/html": [ - "Syncing run ruby-oath-2 to Weights & Biases (docs)
" + "Syncing run fast-snowflake-6 to Weights & Biases (docs)
" ], "text/plain": [ "" @@ -72,10 +94,10 @@ { "data": { "text/html": [ - "" + "" ], "text/plain": [ - "" + "" ] }, "execution_count": 2, @@ -89,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 3, "id": "8d7eacfc", "metadata": {}, "outputs": [], @@ -100,17 +122,17 @@ }, { "cell_type": "code", - "execution_count": 19, - "id": "a70e9942", + "execution_count": 6, + "id": "d8f78b50-b8f5-4008-b986-fb02590a9cd1", "metadata": {}, "outputs": [], "source": [ - "model = AudioDiffusionModel(in_channels=1)" + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 9, "id": "cdc0fb64", "metadata": {}, "outputs": [], @@ -131,12 +153,12 @@ " resampled_x = F.pad(resampled_x, (0, LENGTH - resampled_x.shape[1]))\n", " elif resampled_x.shape[1] > LENGTH:\n", " resampled_x = resampled_x[:, :LENGTH]\n", - " return resampled_x" + " return resampled_x.to(device)" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 10, "id": "148c2a96", "metadata": {}, "outputs": [], @@ -146,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": 237, + "execution_count": 11, "id": "670c94a5", "metadata": {}, "outputs": [ @@ -165,7 +187,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 12, "id": "e1c83600", "metadata": {}, "outputs": [], @@ -175,7 +197,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 13, "id": "4d46f992", "metadata": {}, "outputs": [], @@ -186,7 +208,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 14, "id": "1103e520", "metadata": {}, "outputs": [ @@ -196,7 +218,7 @@ "torch.Size([1, 131072])" ] }, - "execution_count": 25, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -207,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 15, "id": "6b0f1575", "metadata": {}, "outputs": [], @@ -217,61 +239,1253 @@ }, { "cell_type": "code", - "execution_count": 28, - "id": "a6a2bb97", + "execution_count": 39, + "id": "314fd8af-a813-436e-9ca5-29dc3a5ad460", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "eff19abd-304c-449e-9fb5-4e9ce4d4b19c", + "metadata": {}, + "outputs": [], + "source": [ + "model = AudioDiffusionModel(in_channels=1, \n", + " patch_size=1,\n", + " multipliers=[1, 2, 4, 4, 4, 4, 4],\n", + " factors=[2, 2, 2, 2, 2, 2],\n", + " num_blocks=[2, 2, 2, 2, 2, 2],\n", + " attentions=[0, 0, 0, 0, 0, 0]\n", + " )\n", + "model = model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "75dd6e95-5e31-43f5-a0f8-05c7e13e7a14", "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - " 14%|█████████▏ | 7/50 [7:29:41<59:56:20, 5018.16s/it]wandb: Network error (ConnectionError), entering retry loop.\n", - " 14%|█████████▏ | 7/50 [8:13:48<50:33:21, 4232.58s/it]\n" + "300\n", + "310\n", + "320\n", + "330\n", + "340\n", + "350\n", + "360\n", + "370\n", + "380\n", + "390\n", + "400\n", + "410\n", + "420\n", + "430\n", + "440\n", + "450\n", + "460\n", + "470\n", + "480\n", + "490\n", + "500\n", + "510\n", + "520\n", + "530\n", + "540\n", + "550\n", + "560\n", + "570\n", + "580\n", + "590\n", + "600\n", + "610\n", + "620\n", + "630\n", + "640\n", + "650\n", + "660\n", + "670\n", + "680\n", + "690\n", + "700\n", + "710\n", + "720\n", + "730\n", + "740\n", + "750\n", + "760\n", + "770\n", + "780\n", + "790\n", + "800\n", + "810\n", + "820\n", + "830\n", + "840\n", + "850\n", + "860\n", + "870\n", + "880\n", + "890\n", + "900\n", + "910\n", + "920\n", + "930\n", + "940\n", + "950\n", + "960\n", + "970\n", + "980\n", + "990\n", + "1000\n", + "1010\n", + "1020\n", + "1030\n", + "1040\n", + "1050\n", + "1060\n", + "1070\n", + "1080\n", + "1090\n", + "1100\n", + "1110\n", + "1120\n", + "1130\n", + "1140\n", + "1150\n", + "1160\n", + "1170\n", + "1180\n", + "1190\n", + "1200\n", + "1210\n", + "1220\n", + "1230\n", + "1240\n", + "1250\n", + "1260\n", + "1270\n", + "1280\n", + "1290\n", + "1300\n", + "1310\n", + "1320\n", + "1330\n", + "1340\n", + "1350\n", + "1360\n", + "1370\n", + "1380\n", + "1390\n", + "1400\n", + "1410\n", + "1420\n", + "1430\n", + "1440\n", + "1450\n", + "1460\n", + "1470\n", + "1480\n", + "1490\n", + "1500\n", + "1510\n", + "1520\n", + "1530\n", + "1540\n", + "1550\n", + "1560\n", + "1570\n", + "1580\n", + "1590\n", + "1600\n", + "1610\n", + "1620\n", + "1630\n", + "1640\n", + "1650\n", + "1660\n", + "1670\n", + "1680\n", + "1690\n", + "1700\n", + "1710\n", + "1720\n", + "1730\n", + "1740\n", + "1750\n", + "1760\n", + "1770\n", + "1780\n", + "1790\n", + "1800\n", + "1810\n", + "1820\n", + "1830\n", + "1840\n", + "1850\n", + "1860\n", + "1870\n", + "1880\n", + "1890\n", + "1900\n", + "1910\n", + "1920\n", + "1930\n", + "1940\n", + "1950\n", + "1960\n", + "1970\n", + "1980\n", + "1990\n", + "2000\n", + "2010\n", + "2020\n", + "2030\n", + "2040\n", + "2050\n", + "2060\n", + "2070\n", + "2080\n", + "2090\n", + "2100\n", + "2110\n", + "2120\n", + "2130\n", + "2140\n", + "2150\n", + "2160\n", + "2170\n", + "2180\n", + "2190\n", + "2200\n", + "2210\n", + "2220\n", + "2230\n", + "2240\n", + "2250\n", + "2260\n", + "2270\n", + "2280\n", + "2290\n", + "2300\n", + "2310\n", + "2320\n", + "2330\n", + "2340\n", + "2350\n", + "2360\n", + "2370\n", + "2380\n", + "2390\n", + "2400\n", + "2410\n", + "2420\n", + "2430\n", + "2440\n", + "2450\n", + "2460\n", + "2470\n", + "2480\n", + "2490\n", + "2500\n", + "2510\n", + "2520\n", + "2530\n", + "2540\n", + "2550\n", + "2560\n", + "2570\n", + "2580\n", + "2590\n", + "2600\n", + "2610\n", + "2620\n", + "2630\n", + "2640\n", + "2650\n", + "2660\n", + "2670\n", + "2680\n", + "2690\n", + "2700\n", + "2710\n", + "2720\n", + "2730\n", + "2740\n", + "2750\n", + "2760\n", + "2770\n", + "2780\n", + "2790\n", + "2800\n", + "2810\n", + "2820\n", + "2830\n", + "2840\n", + "2850\n", + "2860\n", + "2870\n", + "2880\n", + "2890\n", + "2900\n", + "2910\n", + "2920\n", + "2930\n", + "2940\n", + "2950\n", + "2960\n", + "2970\n", + "2980\n", + "2990\n", + "3000\n", + "3010\n", + "3020\n", + "3030\n", + "3040\n", + "3050\n", + "3060\n", + "3070\n", + "3080\n", + "3090\n", + "3100\n", + "3110\n", + "3120\n", + "3130\n", + "3140\n", + "3150\n", + "3160\n", + "3170\n", + "3180\n", + "3190\n", + "3200\n", + "3210\n", + "3220\n", + "3230\n", + "3240\n", + "3250\n", + "3260\n", + "3270\n", + "3280\n", + "3290\n", + "3300\n", + "3310\n", + "3320\n", + "3330\n", + "3340\n", + "3350\n", + "3360\n", + "3370\n", + "3380\n", + "3390\n", + "3400\n", + "3410\n", + "3420\n", + "3430\n", + "3440\n", + "3450\n", + "3460\n", + "3470\n", + "3480\n", + "3490\n", + "3500\n", + "3510\n", + "3520\n", + "3530\n", + "3540\n", + "3550\n", + "3560\n", + "3570\n", + "3580\n", + "3590\n", + "3600\n", + "3610\n", + "3620\n", + "3630\n", + "3640\n", + "3650\n", + "3660\n", + "3670\n", + "3680\n", + "3690\n", + "3700\n", + "3710\n", + "3720\n", + "3730\n", + "3740\n", + "3750\n", + "3760\n", + "3770\n", + "3780\n", + "3790\n", + "3800\n", + "3810\n", + "3820\n", + "3830\n", + "3840\n", + "3850\n", + "3860\n", + "3870\n", + "3880\n", + "3890\n", + "3900\n", + "3910\n", + "3920\n", + "3930\n", + "3940\n", + "3950\n", + "3960\n", + "3970\n", + "3980\n", + "3990\n", + "4000\n", + "4010\n", + "4020\n", + "4030\n", + "4040\n", + "4050\n", + "4060\n", + "4070\n", + "4080\n", + "4090\n", + "4100\n", + "4110\n", + "4120\n", + "4130\n", + "4140\n", + "4150\n", + "4160\n", + "4170\n", + "4180\n", + "4190\n", + "4200\n", + "4210\n", + "4220\n", + "4230\n", + "4240\n", + "4250\n", + "4260\n", + "4270\n", + "4280\n", + "4290\n", + "4300\n", + "4310\n", + "4320\n", + "4330\n", + "4340\n", + "4350\n", + "4360\n", + "4370\n", + "4380\n", + "4390\n", + "4400\n", + "4410\n", + "4420\n", + "4430\n", + "4440\n", + "4450\n", + "4460\n", + "4470\n", + "4480\n", + "4490\n", + "4500\n", + "4510\n", + "4520\n", + "4530\n", + "4540\n", + "4550\n", + "4560\n", + "4570\n", + "4580\n", + "4590\n", + "4600\n", + "4610\n", + "4620\n", + "4630\n", + "4640\n", + "4650\n", + "4660\n", + "4670\n", + "4680\n", + "4690\n", + "4700\n", + "4710\n", + "4720\n", + "4730\n", + "4740\n", + "4750\n", + "4760\n", + "4770\n", + "4780\n", + "4790\n", + "4800\n", + "4810\n", + "4820\n", + "4830\n", + "4840\n", + "4850\n", + "4860\n", + "4870\n", + "4880\n", + "4890\n", + "4900\n", + "4910\n", + "4920\n", + "4930\n", + "4940\n", + "4950\n", + "4960\n", + "4970\n", + "4980\n", + "4990\n", + "5000\n", + "5010\n", + "5020\n", + "5030\n", + "5040\n", + "5050\n", + "5060\n", + "5070\n", + "5080\n", + "5090\n", + "5100\n", + "5110\n", + "5120\n", + "5130\n", + "5140\n", + "5150\n", + "5160\n", + "5170\n", + "5180\n", + "5190\n", + "5200\n", + "5210\n", + "5220\n", + "5230\n", + "5240\n", + "5250\n", + "5260\n", + "5270\n", + "5280\n", + "5290\n", + "5300\n", + "5310\n", + "5320\n", + "5330\n", + "5340\n", + "5350\n", + "5360\n", + "5370\n", + "5380\n", + "5390\n", + "5400\n", + "5410\n", + "5420\n", + "5430\n", + "5440\n", + "5450\n", + "5460\n", + "5470\n", + "5480\n", + "5490\n", + "5500\n", + "5510\n", + "5520\n", + "5530\n", + "5540\n", + "5550\n", + "5560\n", + "5570\n", + "5580\n", + "5590\n", + "5600\n", + "5610\n", + "5620\n", + "5630\n", + "5640\n", + "5650\n", + "5660\n", + "5670\n", + "5680\n", + "5690\n", + "5700\n", + "5710\n", + "5720\n", + "5730\n", + "5740\n", + "5750\n", + "5760\n", + "5770\n", + "5780\n", + "5790\n", + "5800\n", + "5810\n", + "5820\n", + "5830\n", + "5840\n", + "5850\n", + "5860\n", + "5870\n", + "5880\n", + "5890\n", + "5900\n", + "5910\n", + "5920\n", + "5930\n", + "5940\n", + "5950\n", + "5960\n", + "5970\n", + "5980\n", + "5990\n", + "6000\n", + "6010\n", + "6020\n", + "6030\n", + "6040\n", + "6050\n", + "6060\n", + "6070\n", + "6080\n", + "6090\n", + "6100\n", + "6110\n", + "6120\n", + "6130\n", + "6140\n", + "6150\n", + "6160\n", + "6170\n", + "6180\n", + "6190\n", + "6200\n", + "6210\n", + "6220\n", + "6230\n", + "6240\n", + "6250\n", + "6260\n", + "6270\n", + "6280\n", + "6290\n", + "6300\n", + "6310\n", + "6320\n", + "6330\n", + "6340\n", + "6350\n", + "6360\n", + "6370\n", + "6380\n", + "6390\n", + "6400\n", + "6410\n", + "6420\n", + "6430\n", + "6440\n", + "6450\n", + "6460\n", + "6470\n", + "6480\n", + "6490\n", + "6500\n", + "6510\n", + "6520\n", + "6530\n", + "6540\n", + "6550\n", + "6560\n", + "6570\n", + "6580\n", + "6590\n", + "6600\n", + "6610\n", + "6620\n", + "6630\n", + "6640\n", + "6650\n", + "6660\n", + "6670\n", + "6680\n", + "6690\n", + "6700\n", + "6710\n", + "6720\n", + "6730\n", + "6740\n", + "6750\n", + "6760\n", + "6770\n", + "6780\n", + "6790\n", + "6800\n", + "6810\n", + "6820\n", + "6830\n", + "6840\n", + "6850\n", + "6860\n", + "6870\n", + "6880\n", + "6890\n", + "6900\n", + "6910\n", + "6920\n", + "6930\n", + "6940\n", + "6950\n", + "6960\n", + "6970\n", + "6980\n", + "6990\n", + "7000\n", + "7010\n", + "7020\n", + "7030\n", + "7040\n", + "7050\n", + "7060\n", + "7070\n", + "7080\n", + "7090\n", + "7100\n", + "7110\n", + "7120\n", + "7130\n", + "7140\n", + "7150\n", + "7160\n", + "7170\n", + "7180\n", + "7190\n", + "7200\n", + "7210\n", + "7220\n", + "7230\n", + "7240\n", + "7250\n", + "7260\n", + "7270\n", + "7280\n", + "7290\n", + "7300\n", + "7310\n", + "7320\n", + "7330\n", + "7340\n", + "7350\n", + "7360\n", + "7370\n", + "7380\n", + "7390\n", + "7400\n", + "7410\n", + "7420\n", + "7430\n", + "7440\n", + "7450\n", + "7460\n", + "7470\n", + "7480\n", + "7490\n", + "7500\n", + "7510\n", + "7520\n", + "7530\n", + "7540\n", + "7550\n", + "7560\n", + "7570\n", + "7580\n", + "7590\n", + "7600\n", + "7610\n", + "7620\n", + "7630\n", + "7640\n", + "7650\n", + "7660\n", + "7670\n", + "7680\n", + "7690\n", + "7700\n", + "7710\n", + "7720\n", + "7730\n", + "7740\n", + "7750\n", + "7760\n", + "7770\n", + "7780\n", + "7790\n", + "7800\n", + "7810\n", + "7820\n", + "7830\n", + "7840\n", + "7850\n", + "7860\n", + "7870\n", + "7880\n", + "7890\n", + "7900\n", + "7910\n", + "7920\n", + "7930\n", + "7940\n", + "7950\n", + "7960\n", + "7970\n", + "7980\n", + "7990\n", + "8000\n", + "8010\n", + "8020\n", + "8030\n", + "8040\n", + "8050\n", + "8060\n", + "8070\n", + "8080\n", + "8090\n", + "8100\n", + "8110\n", + "8120\n", + "8130\n", + "8140\n", + "8150\n", + "8160\n", + "8170\n", + "8180\n", + "8190\n", + "8200\n", + "8210\n", + "8220\n", + "8230\n", + "8240\n", + "8250\n", + "8260\n", + "8270\n", + "8280\n", + "8290\n", + "8300\n", + "8310\n", + "8320\n", + "8330\n", + "8340\n", + "8350\n", + "8360\n", + "8370\n", + "8380\n", + "8390\n", + "8400\n", + "8410\n", + "8420\n", + "8430\n", + "8440\n", + "8450\n", + "8460\n", + "8470\n", + "8480\n", + "8490\n", + "8500\n", + "8510\n", + "8520\n", + "8530\n", + "8540\n", + "8550\n", + "8560\n", + "8570\n", + "8580\n", + "8590\n", + "8600\n", + "8610\n", + "8620\n", + "8630\n", + "8640\n", + "8650\n", + "8660\n", + "8670\n", + "8680\n", + "8690\n", + "8700\n", + "8710\n", + "8720\n", + "8730\n", + "8740\n", + "8750\n", + "8760\n", + "8770\n", + "8780\n", + "8790\n", + "8800\n", + "8810\n", + "8820\n", + "8830\n", + "8840\n", + "8850\n", + "8860\n", + "8870\n", + "8880\n", + "8890\n", + "8900\n", + "8910\n", + "8920\n", + "8930\n", + "8940\n", + "8950\n", + "8960\n", + "8970\n", + "8980\n", + "8990\n" ] - }, + } + ], + "source": [ + "fs = 22050\n", + "t = 2 ** 18 / 22050\n", + "samples = torch.arange(t * fs) / fs\n", + "\n", + "for i in range(300, 8000):\n", + " f = i\n", + " signal1 = torch.sin(2 * torch.pi * f * samples)\n", + " signal2 = torch.sin(2 * torch.pi * (f*2) * samples)\n", + " stacked_signal = torch.stack((signal1, signal2)).unsqueeze(1)\n", + " stacked_signal = stacked_signal.to(device)\n", + " loss = model(stacked_signal)\n", + " loss.backward() \n", + " if i % 10 == 0:\n", + " print(i)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "bda06495-0546-4474-ba5c-bf55e4887329", + "metadata": {}, + "outputs": [], + "source": [ + "# Sample 2 sources given start noise\n", + "noise = torch.randn(2, 1, 2 ** 18)\n", + "noise = noise.to(device)\n", + "sampled = model.sample(\n", + " noise=noise,\n", + " num_steps=10 # Suggested range: 2-50\n", + ") # [2, 1, 2 ** 18]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2d025c1e-3618-4801-9b9b-b4e50e41dcf7", + "metadata": {}, + "outputs": [], + "source": [ + "z = sampled[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "583d4d28-7b1b-463b-8642-4975b36f38f2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 262144])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "z.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "eeec47b7-4b99-4239-9c61-fd36ad881876", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Audio(z.cpu(), rate=22050)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "4d87215c-4f2d-410b-ac33-7cc1d9f73fac", + "metadata": {}, + "outputs": [ { - "ename": "KeyboardInterrupt", - "evalue": "", + "ename": "NameError", + "evalue": "name 'sig' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[28], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(epochs)):\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m data:\n\u001b[0;32m----> 4\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m%\u001b[39m \u001b[38;5;241m5\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/audio_diffusion_pytorch/model.py:36\u001b[0m, in \u001b[0;36mModel1d.forward\u001b[0;34m(self, x, **kwargs)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x: Tensor, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m---> 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdiffusion\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/audio_diffusion_pytorch/diffusion.py:674\u001b[0m, in \u001b[0;36mXDiffusion.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 673\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 674\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdiffusion\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/audio_diffusion_pytorch/diffusion.py:161\u001b[0m, in \u001b[0;36mVDiffusion.forward\u001b[0;34m(self, x, noise, **kwargs)\u001b[0m\n\u001b[1;32m 158\u001b[0m x_target \u001b[38;5;241m=\u001b[39m noise \u001b[38;5;241m*\u001b[39m alpha \u001b[38;5;241m-\u001b[39m x \u001b[38;5;241m*\u001b[39m beta\n\u001b[1;32m 160\u001b[0m \u001b[38;5;66;03m# Denoise and return loss\u001b[39;00m\n\u001b[0;32m--> 161\u001b[0m x_denoised \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdenoise_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_noisy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msigmas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mmse_loss(x_denoised, x_target)\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/audio_diffusion_pytorch/diffusion.py:143\u001b[0m, in \u001b[0;36mVDiffusion.denoise_fn\u001b[0;34m(self, x_noisy, sigmas, sigma, **kwargs)\u001b[0m\n\u001b[1;32m 141\u001b[0m batch_size, device \u001b[38;5;241m=\u001b[39m x_noisy\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m], x_noisy\u001b[38;5;241m.\u001b[39mdevice\n\u001b[1;32m 142\u001b[0m sigmas \u001b[38;5;241m=\u001b[39m to_batch(x\u001b[38;5;241m=\u001b[39msigma, xs\u001b[38;5;241m=\u001b[39msigmas, batch_size\u001b[38;5;241m=\u001b[39mbatch_size, device\u001b[38;5;241m=\u001b[39mdevice)\n\u001b[0;32m--> 143\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnet\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_noisy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msigmas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/audio_diffusion_pytorch/modules.py:1099\u001b[0m, in \u001b[0;36mUNet1d.forward\u001b[0;34m(self, x, time, features, channels_list, embedding)\u001b[0m\n\u001b[1;32m 1097\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, downsample \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdownsamples):\n\u001b[1;32m 1098\u001b[0m channels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_channels(channels_list, layer\u001b[38;5;241m=\u001b[39mi \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m-> 1099\u001b[0m x, skips \u001b[38;5;241m=\u001b[39m \u001b[43mdownsample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1100\u001b[0m \u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmapping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmapping\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mchannels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchannels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43membedding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43membedding\u001b[49m\n\u001b[1;32m 1101\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1102\u001b[0m skips_list \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m [skips]\n\u001b[1;32m 1104\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbottleneck(x, mapping\u001b[38;5;241m=\u001b[39mmapping, embedding\u001b[38;5;241m=\u001b[39membedding)\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/audio_diffusion_pytorch/modules.py:670\u001b[0m, in \u001b[0;36mDownsampleBlock1d.forward\u001b[0;34m(self, x, mapping, channels, embedding)\u001b[0m\n\u001b[1;32m 668\u001b[0m skips \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 669\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m block \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mblocks:\n\u001b[0;32m--> 670\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mblock\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmapping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmapping\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 671\u001b[0m skips \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m [x] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_skip \u001b[38;5;28;01melse\u001b[39;00m []\n\u001b[1;32m 673\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_transformer:\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/audio_diffusion_pytorch/modules.py:199\u001b[0m, in \u001b[0;36mResnetBlock1d.forward\u001b[0;34m(self, x, mapping)\u001b[0m\n\u001b[1;32m 196\u001b[0m assert_message \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcontext mapping required if context_mapping_features > 0\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_mapping \u001b[38;5;241m^\u001b[39m exists(mapping)), assert_message\n\u001b[0;32m--> 199\u001b[0m h \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mblock1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 201\u001b[0m scale_shift \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_mapping:\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/audio_diffusion_pytorch/modules.py:124\u001b[0m, in \u001b[0;36mConvBlock1d.forward\u001b[0;34m(self, x, scale_shift)\u001b[0m\n\u001b[1;32m 122\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m*\u001b[39m (scale \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m+\u001b[39m shift\n\u001b[1;32m 123\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactivation(x)\n\u001b[0;32m--> 124\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mproject\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/torch/nn/modules/module.py:1190\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1186\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1187\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1188\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1189\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/torch/nn/modules/conv.py:313\u001b[0m, in \u001b[0;36mConv1d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 313\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_conv_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Developer/remfx/env/lib/python3.9/site-packages/torch/nn/modules/conv.py:309\u001b[0m, in \u001b[0;36mConv1d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mzeros\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mconv1d(F\u001b[38;5;241m.\u001b[39mpad(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpadding_mode),\n\u001b[1;32m 307\u001b[0m weight, bias, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstride,\n\u001b[1;32m 308\u001b[0m _single(\u001b[38;5;241m0\u001b[39m), \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdilation, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgroups)\n\u001b[0;32m--> 309\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv1d\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 310\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdilation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgroups\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn [14], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m Audio(\u001b[43msig\u001b[49m[\u001b[38;5;241m0\u001b[39m], rate\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m22050\u001b[39m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'sig' is not defined" + ] + } + ], + "source": [ + "Audio(sig[0], rate=22050)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "2ccb733d-706a-4535-93b6-73ae2469de8a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Audio(stacked_signal[1].cpu(), rate=22050)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "0377cc63-846b-4acf-8fa9-f1d4a2b07be4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7999" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "i" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "dcf6a106-7967-470a-932e-156b00e46ab2", + "metadata": {}, + "outputs": [], + "source": [ + "f = 4000\n", + "signal1 = torch.sin(2 * torch.pi * f * samples)\n", + "signal2 = torch.sin(2 * torch.pi * (f*2) * samples)" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "id": "fac2d679-9e68-4bcc-8119-745435d128ed", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 72, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Audio(signal1.cpu(), rate=22050)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "ddf58e57-4660-4e1a-83e3-5909da3b42fe", + "metadata": {}, + "outputs": [], + "source": [ + "fs = 22050\n", + "f = 440\n", + "t = 2 ** 18 / 22050\n", + "samples = torch.arange(t * fs) / fs\n", + "signal = torch.sin(2 * torch.pi * f * samples)\n", + "sig = torch.stack((signal, signal)).unsqueeze(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "faef7cc2-94b0-4b85-919f-0339542570c7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 1, 262144])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sig.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "6cd94fea-3d4c-4a5b-bcba-2220fb3e9414", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "16384.0" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "262144 / 16" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "id": "a62143ce-e47b-49e8-979f-e9241068d744", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([264600])" + ] + }, + "execution_count": 89, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "signal.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "e79b1b33-1905-4ae6-9dbe-73b68eec1dc5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Audio(sig[0], rate=22050)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "a6a2bb97", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 500/500 [15:47:21<00:00, 113.68s/it] \n" ] } ], "source": [ - "epochs = 50\n", + "epochs = 500\n", "for i in tqdm(range(epochs)):\n", " for batch in data:\n", " loss = model(batch)\n", " loss.backward()\n", - " if i % 5 == 0:\n", + " if i % 10 == 0:\n", " wandb.log({\"loss\": loss})\n", " with torch.no_grad():\n", - " noise = torch.randn(1, 1, 2**17)\n", + " noise = torch.randn(1, 1, 2**17).to(device)\n", " sampled = model.sample(noise=noise, num_steps=40)\n", " z = sampled.view(-1)\n", - " wandb.log({f\"Audio_{i}\": wandb.Audio(z.numpy(), sample_rate=SAMPLE_RATE)})\n", + " wandb.log({f\"Audio_{i}\": wandb.Audio(z.cpu().numpy(), sample_rate=SAMPLE_RATE)})\n", " \n", " \n", " " @@ -308,7 +1522,7 @@ }, { "cell_type": "code", - "execution_count": 261, + "execution_count": 32, "id": "fc8becc0", "metadata": {}, "outputs": [], @@ -319,32 +1533,27 @@ }, { "cell_type": "code", - "execution_count": 262, - "id": "7731b0e7", + "execution_count": null, + "id": "2c2296ba-7e43-4155-a754-349a7ee5f519", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 262, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Audio(z, rate=22050)" - ] + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "887fc2c1-de1a-4847-86ca-88b7c59f45fb", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55e3555b-3f88-4a33-9fc8-a47bf5f28df7", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -352,18 +1561,6 @@ "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.13" } }, "nbformat": 4,