Upload the pre-trained model and pre-training, inference, downstream, and utility scripts
- .gitignore +2 -0
- downstream.py +146 -0
- inference.py +52 -0
- input_preprocess.py +1020 -0
- lwm_model.py +154 -0
- main.py +120 -0
- models/model.pth +3 -0
- train.py +446 -0
- utils.py +247 -0
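A minimal usage sketch of how the uploaded pieces fit together (the entry points and argument names are taken from downstream.py, inference.py, and input_preprocess.py below; the checkpoint path uses the models/model.pth file shipped in this commit, while downstream.py refers to the checkpoint by its training-run name, so adjust the path to your setup):

import torch
import lwm_model
from input_preprocess import tokenizer, scenarios_list
from inference import lwm_inference

# Tokenize one DeepMIMO scenario into LWM input sequences ([CLS] + channel patches)
scenario = [scenarios_list()[18]]
preprocessed_data, labels, raw_chs = tokenizer(scenario, bs_idxs=[3], load_data=False,
                                               task='LoS/NLoS Classification')

# Load the pre-trained LWM backbone from the uploaded checkpoint
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = lwm_model.lwm().to(device)
state_dict = torch.load("models/model.pth", map_location=device)
model.load_state_dict({k.replace("module.", ""): v for k, v in state_dict.items()})

# Extract CLS embeddings to feed a downstream model
cls_emb = lwm_inference(model, preprocessed_data, input_type="cls_emb", device=device)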
.gitignore
ADDED
@@ -0,0 +1,2 @@
__pycache__*
/images
downstream.py
ADDED
@@ -0,0 +1,146 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 10 11:11:58 2025

This script evaluates downstream task performance by comparing models trained
on raw channel representations versus those trained on LWM embeddings.

@author: Sadjad Alikhani
"""
#%% IMPORT PACKAGES & MODULES
from input_preprocess import tokenizer, scenarios_list
from inference import lwm_inference
from utils import prepare_loaders
from train import finetune
import lwm_model
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
#%% DOWNSTREAM DATA GENERATION
n_beams = 16
task = ['Beam Prediction', 'LoS/NLoS Classification'][1]
task_type = ["classification", "regression"][0]
visualization_method = ["pca", "umap", "tsne"][2]
input_types = ["cls_emb", "channel_emb", "raw"]
train_ratios = [.001, .01, .05, .1, .25, .5, .8]
fine_tuning_status = [None, ["layers.8", "layers.9", "layers.10", "layers.11"], "full"]
selected_scenario_names = [scenarios_list()[18]]
preprocessed_data, labels, raw_chs = tokenizer(
    selected_scenario_names,
    bs_idxs=[3],
    load_data=False,
    task=task,
    n_beams=n_beams)
#%% LOAD THE MODEL
gpu_ids = [0]
device = torch.device("cuda:0")
model = lwm_model.lwm().to(device)

model_name = "lwm_epoch50_train0.0077_val0.0060_masking0.40.pth"
state_dict = torch.load(f"models/{model_name}", map_location=device)
new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)

model = nn.DataParallel(model, gpu_ids)
print(f"Model loaded successfully on GPU {device.index}")
#%% 2D EMBEDDING SPACE VISUALIZATION BEFORE FINE-TUNING
chs = lwm_inference(
    model,
    preprocessed_data,
    input_type="cls_emb",
    device=device,
    batch_size=64,
    visualization=False,
    labels=labels,
    visualization_method=visualization_method)
#%% FINE-TUNE
results = np.zeros((len(fine_tuning_status), len(input_types), len(train_ratios)))
for fine_tuning_stat_idx, fine_tuning_stat in enumerate(fine_tuning_status):
    for input_type_idx, input_type in enumerate(input_types):

        if input_type == "raw" and fine_tuning_stat is not None:
            continue

        selected_patches_idxs = None
        for train_ratio_idx, train_ratio in enumerate(train_ratios):

            print(f"\nfine-tuning status: {fine_tuning_stat}")
            print(f"input type: {input_type}")
            print(f"train ratio: {train_ratio}\n")

            # PREPARE LOADERS
            train_loader, val_loader, samples, target = prepare_loaders(
                preprocessed_data=preprocessed_data,
                labels=labels,
                selected_patches_idxs=selected_patches_idxs,
                input_type=input_type,
                task_type=task_type,
                train_ratio=train_ratio,
                batch_size=128,
                seed=42
            )

            # FINE-TUNE LWM
            fine_tuned_model, best_model_path, train_losses, val_losses, f1_scores, attn_maps_ft = finetune(
                base_model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                task_type=task_type,
                input_type=input_type,
                num_classes=n_beams if task == 'Beam Prediction' else 2 if task == 'LoS/NLoS Classification' else None,
                output_dim=target.shape[-1] if task_type == 'regression' else None,
                use_custom_head=True,
                fine_tune_layers=fine_tuning_stat,
                optimizer_config={"lr": 1e-3},
                epochs=15,
                device=device,
                task=task
            )

            results[fine_tuning_stat_idx][input_type_idx][train_ratio_idx] = f1_scores[-1]

markers = ['o', 's', 'D']
labels = ['CLS Emb', 'CHS Emb', 'Raw']
fine_tuning_status_labels = ['No FT', 'Partial FT', 'Full FT']
line_styles = ['-', '--', ':']
colors = plt.cm.viridis(np.linspace(0, 0.8, len(labels)))
plt.figure(figsize=(12, 8), dpi=500)
for ft_idx, (ft_status_label, line_style) in enumerate(zip(fine_tuning_status_labels, line_styles)):
    for idx, (marker, label, color) in enumerate(zip(markers, labels, colors)):
        # For "Raw Channels," only plot the "No Fine-Tuning" case
        if label == "Raw" and ft_status_label != "No FT":
            continue
        # Simplify the label for "Raw Channels" without fine-tuning
        plot_label = label if label != "Raw Channels" or ft_status_label != "No Fine-Tuning" else "Raw Channels"
        plt.plot(
            train_ratios,
            results[ft_idx, idx],
            marker=marker,
            linestyle=line_style,
            label=f"{plot_label} ({ft_status_label})" if label != "Raw Channels" else plot_label,
            color=color,
            linewidth=3,
            markersize=9
        )
plt.xscale('log')
plt.xlabel("Train Ratio", fontsize=20)
plt.ylabel("F1-Score", fontsize=20)
plt.legend(fontsize=17, loc="best")
plt.grid(True, linestyle="--", alpha=0.7)
plt.xticks(fontsize=17)
plt.yticks(fontsize=17)
plt.tight_layout()
plt.show()
#%% 2D EMBEDDING SPACE VISUALIZATION AFTER FINE-TUNING
chs = lwm_inference(
    fine_tuned_model.model,
    preprocessed_data,
    input_type="cls_emb",
    device=device,
    batch_size=64,
    visualization=False,
    labels=labels,
    visualization_method=visualization_method)
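For reading the sweep above: `results` is indexed as results[fine-tuning status, input type, train ratio] and stores the final-epoch F1-score of each run; the "raw" rows are only filled for the no-fine-tuning case because of the `continue` in the loop, which is why the plot skips them. A small sketch (not part of the script) for printing the table after the sweep:

import numpy as np

for ft_idx, ft_name in enumerate(['No FT', 'Partial FT', 'Full FT']):
    for in_idx, in_name in enumerate(['cls_emb', 'channel_emb', 'raw']):
        if in_name == 'raw' and ft_name != 'No FT':
            continue  # never trained in the sweep, left as zeros
        print(f"{ft_name:10s} {in_name:12s}", np.round(results[ft_idx, in_idx], 3))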
inference.py
ADDED
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 15 18:27:17 2024

This script performs LWM inference on raw channel representations.

@author: Sadjad Alikhani
"""
import torch
from torch.utils.data import DataLoader, TensorDataset
from utils import visualize_embeddings
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
#%%
def lwm_inference(model, data, input_type="cls_emb", device="cpu", batch_size=64, visualization=False, labels=None, visualization_method="t-sne"):

    if input_type == "raw":
        output_total = data
    else:
        dataset = TensorDataset(data)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        embeddings = []
        model.eval()
        with torch.no_grad():
            with tqdm(dataloader, desc="Inference", unit="batch") as t:
                for batch in t:

                    input_ids = batch[0].to(device)
                    output = model(input_ids)[0]

                    if input_type == "cls_emb":
                        batch_embeddings = output[:, 0, :]
                        embeddings.append(batch_embeddings)
                    elif input_type == "channel_emb":
                        batch_embeddings = output[:, 1:, :]
                        embeddings.append(batch_embeddings)

        output_total = torch.cat(embeddings, dim=0).float()

    if visualization:
        visualize_embeddings(output_total.view(output_total.size(0), -1),
                             labels,
                             method=visualization_method,
                             label="Embedding Space")
        visualize_embeddings(data.view(data.size(0), -1),
                             labels,
                             method=visualization_method,
                             label="Original Space")

    return output_total
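A small usage sketch for lwm_inference (the dummy tensor is illustrative only; the shapes assume the default LWM configuration of 32-dimensional patches, a 513-token sequence, and d_model=128, and that the encoder's first output is (batch, sequence, d_model), which is how the slicing above treats it):

import torch
import lwm_model

model = lwm_model.lwm()
dummy = torch.randn(10, 513, 32)  # (batch, [CLS] + 512 patches, patch length)

cls_emb = lwm_inference(model, dummy, input_type="cls_emb", device="cpu")     # -> (10, 128)
ch_emb = lwm_inference(model, dummy, input_type="channel_emb", device="cpu")  # -> (10, 512, 128)
raw = lwm_inference(model, dummy, input_type="raw", device="cpu")             # returned unchanged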
input_preprocess.py
ADDED
@@ -0,0 +1,1020 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 13 16:13:29 2024

This script generates preprocessed data from wireless communication scenarios,
including channel generation, patch generation, masking, and preparing raw
channels for the Transformer-based LWM model.

@author: Sadjad Alikhani
"""
import numpy as np
import os
from tqdm import tqdm
import time
import pickle
import DeepMIMOv3
import torch
from collections import defaultdict
from utils import generate_gaussian_noise, plot_coverage
#%% Scenarios List
def scenarios_list():
    scen_list = np.array([
        'city_0_newyork', 'city_1_losangeles', 'city_2_chicago', 'city_3_houston',
        'city_4_phoenix', 'city_5_philadelphia', 'city_6_miami', 'city_7_sandiego',
        'city_8_dallas', 'city_9_sanfrancisco', 'city_10_austin', 'city_11_santaclara',
        'city_12_fortworth', 'city_13_columbus', 'city_14_charlotte', 'city_15_indianapolis',
        'city_16_sanfrancisco', 'city_17_seattle', 'city_18_denver', 'city_19_oklahoma',
        'asu_campus1_v1', 'asu_campus1_v2', 'asu_campus1_v3', 'asu_campus1_v4',
        'asu_campus1_v5', 'asu_campus1_v6', 'asu_campus1_v7', 'asu_campus1_v8',
        'asu_campus1_v9', 'asu_campus1_v10', 'asu_campus1_v11', 'asu_campus1_v12',
        'asu_campus1_v13', 'asu_campus1_v14', 'asu_campus1_v15', 'asu_campus1_v16',
        'asu_campus1_v17', 'asu_campus1_v18', 'asu_campus1_v19', 'asu_campus1_v20',
        'Boston5G_3p5_v1', 'Boston5G_3p5_v2', 'Boston5G_3p5_v3', 'Boston5G_3p5_v4',
        'Boston5G_3p5_v5', 'Boston5G_3p5_v6', 'Boston5G_3p5_v7', 'Boston5G_3p5_v8',
        'Boston5G_3p5_v9', 'Boston5G_3p5_v10', 'Boston5G_3p5_v11', 'Boston5G_3p5_v12',
        'Boston5G_3p5_v13', 'Boston5G_3p5_v14', 'Boston5G_3p5_v15', 'Boston5G_3p5_v16',
        'Boston5G_3p5_v17', 'Boston5G_3p5_v18', 'Boston5G_3p5_v19', 'Boston5G_3p5_v20',
        'O1_3p5_v1', 'O1_3p5_v2', 'O1_3p5_v3', 'O1_3p5_v4', 'O1_3p5_v5',
        'O1_3p5_v6', 'O1_3p5_v7', 'O1_3p5_v8', 'O1_3p5_v9', 'O1_3p5_v10',
        'O1_3p5_v11', 'O1_3p5_v12', 'O1_3p5_v13', 'O1_3p5_v14', 'O1_3p5_v15',
        'O1_3p5_v16', 'O1_3p5_v17', 'O1_3p5_v18', 'O1_3p5_v19', 'O1_3p5_v20',
        'asu_campus1', 'O1_3p5', 'Boston5G_3p5',
        'city_0_newyork_v16x64', 'city_1_losangeles_v16x64', 'city_2_chicago_v16x64',
        'city_3_houston_v16x64', 'city_4_phoenix_v16x64', 'city_5_philadelphia_v16x64',
        'city_6_miami_v16x64', 'city_7_sandiego_v16x64', 'city_8_dallas_v16x64',
        'city_9_sanfrancisco_v16x64'
    ])
    return scen_list
#%% Token Generation
def patch_gen(N_ROWS=4, N_COLUMNS=4, selected_scenario_names=None,
              manual_data=None, bs_idxs=[1,2,3], load_data=False,
              save_dir="data", task="LoS/NLoS Classification",
              n_beams=64, o1_bs_idx=[4]):

    os.makedirs(save_dir, exist_ok=True)

    if manual_data is not None:
        patches = patch_maker(np.expand_dims(np.array(manual_data), axis=1))
    else:
        deepmimo_data = []
        for scenario_name in selected_scenario_names:
            if "O1" in scenario_name:  # make an exception for the BS indices of the O1 scenario
                if o1_bs_idx is None:
                    bs_idxs = [4, 15]
                else:
                    bs_idxs = o1_bs_idx
            for bs_idx in bs_idxs:
                if has_version_suffix(scenario_name) and bs_idx in [2, 3]:
                    continue
                if not load_data:
                    print(f"\nGenerating data for scenario: {scenario_name}, BS #{bs_idx}")
                    data, n_ant_bs, n_subcarriers = DeepMIMO_data_gen(scenario_name, bs_idx)
                    file_name = f"{save_dir}/{scenario_name}_ant{n_ant_bs}_sub{n_subcarriers}_bs{bs_idx}.npy"
                    np.save(file_name, data)
                    print(f"Data saved to {file_name}")
                    deepmimo_data.append(data)
                else:
                    n_ant_bs, n_subcarriers = parametersv2(scenario_name, bs_idx)
                    print(f"\nLoading data for scenario: {scenario_name}, BS #{bs_idx}")
                    file_name = f"{save_dir}/{scenario_name}_ant{n_ant_bs}_sub{n_subcarriers}_bs{bs_idx}.npy"
                    data = np.load(file_name, allow_pickle=True).item()
                    print(f"Data loaded from {file_name}")
                    deepmimo_data.append(data)

    cleaned_deepmimo_data = [deepmimo_data_cleaning(deepmimo_data[scenario_idx]) for scenario_idx in range(len(deepmimo_data))]  # n_scenarios * n_bs_idxs
    patches = [patch_maker(cleaned_deepmimo_data[scenario_idx], N_ROWS, N_COLUMNS) for scenario_idx in range(len(deepmimo_data))]
    raw_chs = torch.tensor(cleaned_deepmimo_data[0]).squeeze(1)
    raw_chs = raw_chs.view(raw_chs.size(0), -1)
    raw_chs = torch.hstack((raw_chs.real, raw_chs.imag))

    if task:
        labels = [label_gen(task, deepmimo_data[scenario_idx], selected_scenario_names[scenario_idx], n_beams=n_beams) for scenario_idx in range(len(deepmimo_data))]
        return patches, torch.tensor(labels[0]), raw_chs.view(raw_chs.size(0), -1)
    else:
        return patches, raw_chs.view(raw_chs.size(0), -1)
#%%
def tokenizer(selected_scenario_names,
              bs_idxs=[1,2,3],
              load_data=False,
              task="LoS/NLoS Classification",
              n_beams=64,
              MAX_LEN=513,
              masking_percent=.40,
              mask=False,
              seed=42,
              snr=None):

    patches, labels, raw_chs = patch_gen(
        selected_scenario_names=selected_scenario_names,
        bs_idxs=bs_idxs,
        load_data=load_data,
        task=task,
        n_beams=n_beams
    )

    patches = [patch for patch_list in patches for patch in patch_list]
    print("Total number of samples:", len(patches))

    grouped_data = defaultdict(list)  # Group samples by sequence length
    grouped_data_2 = []

    for user_idx in tqdm(range(len(patches)), desc="Processing items"):
        patch_size = patches[user_idx].shape[1]
        n_patches = patches[user_idx].shape[0]
        n_masks_half = int(masking_percent * n_patches)

        word2id = {
            '[CLS]': 0.2 * np.ones((patch_size)),
            '[MASK]': 0.1 * np.ones((patch_size))
        }

        sample = make_sample(
            user_idx, patches, word2id, n_patches, n_masks_half, patch_size, MAX_LEN, mask=mask, seed=seed
        )

        if mask:
            seq_length = len(sample[0])
            grouped_data[seq_length].append(sample)
        else:
            grouped_data_2.append(sample)

    if mask:
        # Normalize keys to 0, 1, 2, ...
        normalized_grouped_data = {i: grouped_data[key] for i, key in enumerate(sorted(grouped_data.keys()))}
    else:
        normalized_grouped_data = torch.stack(grouped_data_2, dim=0)
        # normalized_grouped_data = grouped_data_2
        if snr is not None:
            normalized_grouped_data += generate_gaussian_noise(normalized_grouped_data, snr)
        # normalized_grouped_data = {i: grouped_data[key] for i, key in enumerate(sorted(grouped_data.keys()))}

    return normalized_grouped_data, labels, raw_chs
#%% REMOVE ZERO CHANNELS AND SCALE
def deepmimo_data_cleaning(deepmimo_data):
    idxs = np.where(deepmimo_data['user']['LoS'] != -1)[0]
    cleaned_deepmimo_data = deepmimo_data['user']['channel'][idxs]
    return np.array(cleaned_deepmimo_data) * 1e6
#%%
def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, MAX_LEN, mask=True, seed=None):

    if seed is not None:
        np.random.seed(seed)

    # Step 1: Retrieve tokens and prepend [CLS]
    tokens = patch[user_idx]
    input_ids = np.vstack((word2id['[CLS]'], tokens))

    # Step 2: Mask real and imaginary patches
    tokens_size = int(n_patches)  # int(n_patches / 2)
    masked_pos = np.random.choice(range(1, tokens_size), size=n_masks, replace=False)

    masked_tokens = []
    for pos in masked_pos:
        original_masked_tokens = input_ids[pos].copy()
        masked_tokens.append(original_masked_tokens)
        if mask:
            rnd_num = np.random.rand()
            if rnd_num < 0.1:
                input_ids[pos] = np.random.rand(patch_size)  # Replace with random values
            elif rnd_num < 0.9:
                input_ids[pos] = word2id['[MASK]']  # Replace with [MASK]

    if not mask:
        return torch.tensor(input_ids)
    else:
        return [input_ids, masked_tokens, masked_pos]
#%% Patch GENERATION
def patch_maker(original_ch, patch_rows, patch_cols):
    # Step 1: Remove the singleton channel dimension
    n_samples, _, n_rows, n_cols = original_ch.shape  # Unpack shape
    original_ch = original_ch[:, 0]  # Remove the singleton dimension

    # Step 2: Split into real and imaginary parts and interleave them
    flat_real = original_ch.real
    flat_imag = original_ch.imag

    # Interleave real and imaginary parts along the last axis
    interleaved = np.empty((n_samples, n_rows, n_cols * 2), dtype=np.float32)
    interleaved[:, :, 0::2] = flat_real
    interleaved[:, :, 1::2] = flat_imag

    # Step 3: Compute the number of patches along rows and columns
    n_patches_rows = int(np.ceil(n_rows / patch_rows))
    n_patches_cols = int(np.ceil(n_cols / patch_cols))

    # Step 4: Pad the matrix if necessary to make it divisible by patch size
    padded_rows = n_patches_rows * patch_rows - n_rows
    padded_cols = n_patches_cols * patch_cols - n_cols
    if padded_rows > 0 or padded_cols > 0:
        interleaved = np.pad(
            interleaved,
            ((0, 0), (0, padded_rows), (0, padded_cols * 2)),  # Double padding for the interleaved axis
            mode='constant',
            constant_values=0,
        )

    # Step 5: Create patches by dividing into blocks
    n_samples, padded_rows, padded_cols = interleaved.shape
    padded_cols //= 2  # Adjust for interleaving (real and imaginary parts count as one)
    patches = []

    for i in range(0, padded_rows, patch_rows):
        for j in range(0, padded_cols, patch_cols):
            patch = interleaved[:, i:i + patch_rows, j * 2:(j + patch_cols) * 2]
            patches.append(patch.reshape(n_samples, -1))  # Flatten each patch

    # Step 6: Stack patches to form the final array
    patches = np.stack(patches, axis=1)  # Shape: (num_samples, n_patches, patch_rows * patch_cols * 2)

    return patches
#%% Data Generation for Scenario Areas
def DeepMIMO_data_gen(scenario, bs_idx):
    import DeepMIMOv3
    parameters, row_column_users = get_parameters(scenario, bs_idx)
    deepMIMO_dataset = DeepMIMOv3.generate_data(parameters)

    if "O1" in scenario:
        hops = [2, 2]
    else:
        hops = [1, 1]

    uniform_idxs = uniform_sampling(deepMIMO_dataset, hops, len(parameters['user_rows']),
                                    users_per_row=row_column_users[scenario]['n_per_row'])
    data = select_by_idx(deepMIMO_dataset, uniform_idxs)[0]

    n_ant_bs = parameters['bs_antenna']['shape'][0]
    n_subcarriers = parameters['OFDM']['subcarriers']

    return data, n_ant_bs, n_subcarriers
#%%
def parametersv2(scenario, bs_idx):
    parameters, _ = get_parameters(scenario, bs_idx)
    n_ant_bs = parameters['bs_antenna']['shape'][0]
    n_subcarriers = parameters['OFDM']['subcarriers']
    return n_ant_bs, n_subcarriers
#%%%
def get_parameters(scenario, bs_idx=1):

    n_ant_ue = 1
    scs = 30e3

    row_column_users = scenario_prop()

    parameters = DeepMIMOv3.default_params()
    parameters['dataset_folder'] = './scenarios'
    parameters['scenario'] = scenario.split("_v")[0]

    n_ant_bs = row_column_users[scenario]['n_ant_bs']
    n_subcarriers = row_column_users[scenario]['n_subcarriers']
    parameters['active_BS'] = np.array([bs_idx])

    if isinstance(row_column_users[scenario]['n_rows'], int):
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'])
    else:
        parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'][0],
                                            row_column_users[scenario]['n_rows'][1])

    parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1])  # Horizontal, Vertical
    parameters['bs_antenna']['rotation'] = np.array([0, 0, -135])  # (x, y, z)
    parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
    parameters['enable_BS2BS'] = False
    parameters['OFDM']['subcarriers'] = n_subcarriers
    parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)

    parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9
    parameters['num_paths'] = 20

    return parameters, row_column_users
#%% Sampling and Data Selection
def uniform_sampling(dataset, sampling_div, n_rows, users_per_row):
    cols = np.arange(users_per_row, step=sampling_div[0])
    rows = np.arange(n_rows, step=sampling_div[1])
    uniform_idxs = np.array([j + i * users_per_row for i in rows for j in cols])
    return uniform_idxs

def select_by_idx(dataset, idxs):
    dataset_t = []  # Trimmed dataset
    for bs_idx in range(len(dataset)):
        dataset_t.append({})
        for key in dataset[bs_idx].keys():
            dataset_t[bs_idx]['location'] = dataset[bs_idx]['location']
            dataset_t[bs_idx]['user'] = {k: dataset[bs_idx]['user'][k][idxs] for k in dataset[bs_idx]['user']}
    return dataset_t
#%%
def inverse_patch_maker(patches, original_shape, patch_rows, patch_cols):
    """
    Reconstructs the original channel matrix from patches.

    Args:
        patches (numpy array): Patches of shape (num_samples, n_patches, patch_rows * patch_cols * 2).
        original_shape (tuple): Original shape of the channel matrix (num_samples, 1, n_rows, n_cols).
        patch_rows (int): Number of rows in each patch.
        patch_cols (int): Number of columns in each patch.

    Returns:
        numpy array: Reconstructed complex-valued channel matrix of shape (num_samples, 1, n_rows, n_cols).
    """
    n_samples, n_patches, patch_size = patches.shape
    _, _, n_rows, n_cols = original_shape

    # Ensure patch dimensions match
    assert patch_rows * patch_cols * 2 == patch_size, "Patch size mismatch with provided dimensions."

    # Compute the number of patches along rows and columns
    n_patches_rows = int(np.ceil(n_rows / patch_rows))
    n_patches_cols = int(np.ceil(n_cols / patch_cols))

    # Reassemble interleaved array from patches
    interleaved = np.zeros((n_samples, n_patches_rows * patch_rows, n_patches_cols * patch_cols * 2), dtype=np.float32)
    patch_idx = 0

    for i in range(n_patches_rows):
        for j in range(n_patches_cols):
            patch = patches[:, patch_idx, :].reshape(n_samples, patch_rows, patch_cols * 2)
            interleaved[:, i * patch_rows:(i + 1) * patch_rows, j * patch_cols * 2:(j + 1) * patch_cols * 2] = patch
            patch_idx += 1

    # Remove padding if necessary
    interleaved = interleaved[:, :n_rows, :n_cols * 2]

    # Separate real and imaginary parts
    flat_real = interleaved[:, :, 0::2]
    flat_imag = interleaved[:, :, 1::2]

    # Reconstruct the complex-valued original channel
    reconstructed = flat_real + 1j * flat_imag

    # Add the singleton channel dimension back
    reconstructed = reconstructed[:, np.newaxis, :, :]  # Shape: (num_samples, 1, n_rows, n_cols)

    return reconstructed
#%%
def label_gen(task, data, scenario, n_beams=64):

    idxs = np.where(data['user']['LoS'] != -1)[0]

    if task == 'LoS/NLoS Classification':
        label = data['user']['LoS'][idxs]

        losChs = np.where(data['user']['LoS'] == -1, np.nan, data['user']['LoS'])
        plot_coverage(data['user']['location'], losChs, cbar_title='LoS status')

    elif task == 'Beam Prediction':
        parameters, row_column_users = get_parameters(scenario, bs_idx=1)
        n_users = len(data['user']['channel'])
        n_subbands = 1
        fov = 180

        # Set up beamformers
        beam_angles = np.around(np.arange(-fov/2, fov/2+.1, fov/(n_beams-1)), 2)

        F1 = np.array([steering_vec(parameters['bs_antenna']['shape'],
                                    phi=azi*np.pi/180,
                                    kd=2*np.pi*parameters['bs_antenna']['spacing']).squeeze()
                       for azi in beam_angles])

        full_dbm = np.zeros((n_beams, n_subbands, n_users), dtype=float)
        for ue_idx in tqdm(range(n_users), desc='Computing the channel for each user'):
            if data['user']['LoS'][ue_idx] == -1:
                full_dbm[:, :, ue_idx] = np.nan
            else:
                chs = F1 @ data['user']['channel'][ue_idx]
                full_linear = np.abs(np.mean(chs.squeeze().reshape((n_beams, n_subbands, -1)), axis=-1))
                full_dbm[:, :, ue_idx] = np.around(20*np.log10(full_linear) + 30, 1)

        best_beams = np.argmax(np.mean(full_dbm, axis=1), axis=0)
        best_beams = best_beams.astype(float)
        best_beams[np.isnan(full_dbm[0, 0, :])] = np.nan
        # max_bf_pwr = np.max(np.mean(full_dbm, axis=1), axis=0)

        plot_coverage(data['user']['location'], best_beams, tx_pos=data['location'],
                      tx_ori=parameters['bs_antenna']['rotation']*np.pi/180,
                      cbar_title='Best beam index')

        label = best_beams[idxs]

    return label.astype(int)
#%%
def steering_vec(array, phi=0, theta=0, kd=np.pi):
    idxs = DeepMIMOv3.ant_indices(array)
    resp = DeepMIMOv3.array_response(idxs, phi, theta+np.pi/2, kd)
    return resp / np.linalg.norm(resp)
#%%
import re
def has_version_suffix(s):
    pattern = r"_v([1-9]|1[0-9]|20)$"
    return bool(re.search(pattern, s))
#%%
def scenario_prop():
    row_column_users = {
        'city_0_newyork': {'n_rows': 109, 'n_per_row': 291, 'n_ant_bs': 8, 'n_subcarriers': 32},
        'city_1_losangeles': {'n_rows': 142, 'n_per_row': 201, 'n_ant_bs': 8, 'n_subcarriers': 64},
        'city_2_chicago': {'n_rows': 139, 'n_per_row': 200, 'n_ant_bs': 8, 'n_subcarriers': 128},
        'city_3_houston': {'n_rows': 154, 'n_per_row': 202, 'n_ant_bs': 8, 'n_subcarriers': 256},
        'city_4_phoenix': {'n_rows': 198, 'n_per_row': 214, 'n_ant_bs': 8, 'n_subcarriers': 512},
        'city_5_philadelphia': {'n_rows': 239, 'n_per_row': 164, 'n_ant_bs': 8, 'n_subcarriers': 1024},
        'city_6_miami': {'n_rows': 199, 'n_per_row': 216, 'n_ant_bs': 16, 'n_subcarriers': 32},
        'city_7_sandiego': {'n_rows': 207, 'n_per_row': 176, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_8_dallas': {'n_rows': 207, 'n_per_row': 190, 'n_ant_bs': 16, 'n_subcarriers': 128},
        'city_9_sanfrancisco': {'n_rows': 196, 'n_per_row': 206, 'n_ant_bs': 16, 'n_subcarriers': 256},
        'city_10_austin': {'n_rows': 255, 'n_per_row': 137, 'n_ant_bs': 16, 'n_subcarriers': 512},
        'city_11_santaclara': {'n_rows': 117, 'n_per_row': 285, 'n_ant_bs': 32, 'n_subcarriers': 32},
        'city_12_fortworth': {'n_rows': 214, 'n_per_row': 179, 'n_ant_bs': 32, 'n_subcarriers': 64},
        'city_13_columbus': {'n_rows': 178, 'n_per_row': 240, 'n_ant_bs': 32, 'n_subcarriers': 128},
        'city_14_charlotte': {'n_rows': 216, 'n_per_row': 177, 'n_ant_bs': 32, 'n_subcarriers': 256},
        'city_15_indianapolis': {'n_rows': 200, 'n_per_row': 196, 'n_ant_bs': 64, 'n_subcarriers': 32},
        'city_16_sanfrancisco': {'n_rows': 201, 'n_per_row': 208, 'n_ant_bs': 64, 'n_subcarriers': 64},
        'city_17_seattle': {'n_rows': 185, 'n_per_row': 205, 'n_ant_bs': 64, 'n_subcarriers': 128},
        'city_18_denver': {'n_rows': 212, 'n_per_row': 204, 'n_ant_bs': 128, 'n_subcarriers': 32},
        'city_19_oklahoma': {'n_rows': 204, 'n_per_row': 188, 'n_ant_bs': 128, 'n_subcarriers': 64},
        'asu_campus1_v1': {'n_rows': [0, 1*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 32},
        'asu_campus1_v2': {'n_rows': [1*int(321/20), 2*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 64},
        'asu_campus1_v3': {'n_rows': [2*int(321/20), 3*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 128},
        'asu_campus1_v4': {'n_rows': [3*int(321/20), 4*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 256},
        'asu_campus1_v5': {'n_rows': [4*int(321/20), 5*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 512},
        'asu_campus1_v6': {'n_rows': [5*int(321/20), 6*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 8, 'n_subcarriers': 1024},
        'asu_campus1_v7': {'n_rows': [6*int(321/20), 7*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 16, 'n_subcarriers': 32},
        'asu_campus1_v8': {'n_rows': [7*int(321/20), 8*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'asu_campus1_v9': {'n_rows': [8*int(321/20), 9*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 16, 'n_subcarriers': 128},
        'asu_campus1_v10': {'n_rows': [9*int(321/20), 10*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 16, 'n_subcarriers': 256},
        'asu_campus1_v11': {'n_rows': [10*int(321/20), 11*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 16, 'n_subcarriers': 512},
        'asu_campus1_v12': {'n_rows': [11*int(321/20), 12*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 32, 'n_subcarriers': 32},
        'asu_campus1_v13': {'n_rows': [12*int(321/20), 13*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 32, 'n_subcarriers': 64},
        'asu_campus1_v14': {'n_rows': [13*int(321/20), 14*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 32, 'n_subcarriers': 128},
        'asu_campus1_v15': {'n_rows': [14*int(321/20), 15*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 32, 'n_subcarriers': 256},
        'asu_campus1_v16': {'n_rows': [15*int(321/20), 16*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 64, 'n_subcarriers': 32},
        'asu_campus1_v17': {'n_rows': [16*int(321/20), 17*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 64, 'n_subcarriers': 64},
        'asu_campus1_v18': {'n_rows': [17*int(321/20), 18*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 64, 'n_subcarriers': 128},
        'asu_campus1_v19': {'n_rows': [18*int(321/20), 19*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 128, 'n_subcarriers': 32},
        'asu_campus1_v20': {'n_rows': [19*int(321/20), 20*int(321/20)], 'n_per_row': 411, 'n_ant_bs': 128, 'n_subcarriers': 64},
        'Boston5G_3p5_v1': {'n_rows': [812, 812 + 1*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 32},
        'Boston5G_3p5_v2': {'n_rows': [812 + 1*int((1622-812)/20), 812 + 2*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 64},
        'Boston5G_3p5_v3': {'n_rows': [812 + 2*int((1622-812)/20), 812 + 3*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 128},
        'Boston5G_3p5_v4': {'n_rows': [812 + 3*int((1622-812)/20), 812 + 4*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 256},
        'Boston5G_3p5_v5': {'n_rows': [812 + 4*int((1622-812)/20), 812 + 5*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 512},
        'Boston5G_3p5_v6': {'n_rows': [812 + 5*int((1622-812)/20), 812 + 6*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 8, 'n_subcarriers': 1024},
        'Boston5G_3p5_v7': {'n_rows': [812 + 6*int((1622-812)/20), 812 + 7*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 16, 'n_subcarriers': 32},
        'Boston5G_3p5_v8': {'n_rows': [812 + 7*int((1622-812)/20), 812 + 8*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'Boston5G_3p5_v9': {'n_rows': [812 + 8*int((1622-812)/20), 812 + 9*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 16, 'n_subcarriers': 128},
        'Boston5G_3p5_v10': {'n_rows': [812 + 9*int((1622-812)/20), 812 + 10*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 16, 'n_subcarriers': 256},
        'Boston5G_3p5_v11': {'n_rows': [812 + 10*int((1622-812)/20), 812 + 11*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 16, 'n_subcarriers': 512},
        'Boston5G_3p5_v12': {'n_rows': [812 + 11*int((1622-812)/20), 812 + 12*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 32, 'n_subcarriers': 32},
        'Boston5G_3p5_v13': {'n_rows': [812 + 12*int((1622-812)/20), 812 + 13*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 32, 'n_subcarriers': 64},
        'Boston5G_3p5_v14': {'n_rows': [812 + 13*int((1622-812)/20), 812 + 14*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 32, 'n_subcarriers': 128},
        'Boston5G_3p5_v15': {'n_rows': [812 + 14*int((1622-812)/20), 812 + 15*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 32, 'n_subcarriers': 256},
        'Boston5G_3p5_v16': {'n_rows': [812 + 15*int((1622-812)/20), 812 + 16*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 64, 'n_subcarriers': 32},
        'Boston5G_3p5_v17': {'n_rows': [812 + 16*int((1622-812)/20), 812 + 17*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 64, 'n_subcarriers': 64},
        'Boston5G_3p5_v18': {'n_rows': [812 + 17*int((1622-812)/20), 812 + 18*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 64, 'n_subcarriers': 128},
        'Boston5G_3p5_v19': {'n_rows': [812 + 18*int((1622-812)/20), 812 + 19*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 128, 'n_subcarriers': 32},
        'Boston5G_3p5_v20': {'n_rows': [812 + 19*int((1622-812)/20), 812 + 20*int((1622-812)/20)], 'n_per_row': 595, 'n_ant_bs': 128, 'n_subcarriers': 64},
        'O1_3p5_v1': {'n_rows': [0*int(3852/12), 1*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 32},
        'O1_3p5_v2': {'n_rows': [1*int(3852/12), 2*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 64},
        'O1_3p5_v3': {'n_rows': [2*int(3852/12), 3*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 128},
        'O1_3p5_v4': {'n_rows': [3*int(3852/12), 4*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 256},
        'O1_3p5_v5': {'n_rows': [4*int(3852/12), 5*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 512},
        'O1_3p5_v6': {'n_rows': [5*int(3852/12), 6*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 8, 'n_subcarriers': 1024},
        'O1_3p5_v7': {'n_rows': [6*int(3852/12), 7*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 16, 'n_subcarriers': 32},
        'O1_3p5_v8': {'n_rows': [7*int(3852/12), 8*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'O1_3p5_v9': {'n_rows': [8*int(3852/12), 9*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 16, 'n_subcarriers': 128},
        'O1_3p5_v10': {'n_rows': [9*int(3852/12), 10*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 16, 'n_subcarriers': 256},
        'O1_3p5_v11': {'n_rows': [10*int(3852/12), 11*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 16, 'n_subcarriers': 512},
        'O1_3p5_v12': {'n_rows': [11*int(3852/12), 12*int(3852/12)], 'n_per_row': 181, 'n_ant_bs': 32, 'n_subcarriers': 32},
        'O1_3p5_v13': {'n_rows': [12*int(3852/12)+0*int(1351/10), 12*int(3852/12)+1*int(1351/10)], 'n_per_row': 361, 'n_ant_bs': 32, 'n_subcarriers': 64},
        'O1_3p5_v14': {'n_rows': [12*int(3852/12)+1*int(1351/10), 12*int(3852/12)+2*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 32, 'n_subcarriers': 128},
        'O1_3p5_v15': {'n_rows': [12*int(3852/12)+2*int(1351/10), 12*int(3852/12)+3*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 32, 'n_subcarriers': 256},
        'O1_3p5_v16': {'n_rows': [12*int(3852/12)+3*int(1351/10), 12*int(3852/12)+4*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 64, 'n_subcarriers': 32},
        'O1_3p5_v17': {'n_rows': [12*int(3852/12)+4*int(1351/10), 12*int(3852/12)+5*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 64, 'n_subcarriers': 64},
        'O1_3p5_v18': {'n_rows': [12*int(3852/12)+5*int(1351/10), 12*int(3852/12)+6*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 64, 'n_subcarriers': 128},
        'O1_3p5_v19': {'n_rows': [12*int(3852/12)+6*int(1351/10), 12*int(3852/12)+7*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 128, 'n_subcarriers': 32},
        'O1_3p5_v20': {'n_rows': [12*int(3852/12)+7*int(1351/10), 12*int(3852/12)+8*int(1351/10)], 'n_per_row': 181, 'n_ant_bs': 128, 'n_subcarriers': 64},
        'city_0_newyork_v16x64': {'n_rows': 109, 'n_per_row': 291, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_1_losangeles_v16x64': {'n_rows': 142, 'n_per_row': 201, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_2_chicago_v16x64': {'n_rows': 139, 'n_per_row': 200, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_3_houston_v16x64': {'n_rows': 154, 'n_per_row': 202, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_4_phoenix_v16x64': {'n_rows': 198, 'n_per_row': 214, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_5_philadelphia_v16x64': {'n_rows': 239, 'n_per_row': 164, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_6_miami_v16x64': {'n_rows': 199, 'n_per_row': 216, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_7_sandiego_v16x64': {'n_rows': 207, 'n_per_row': 176, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_8_dallas_v16x64': {'n_rows': 207, 'n_per_row': 190, 'n_ant_bs': 16, 'n_subcarriers': 64},
        'city_9_sanfrancisco_v16x64': {'n_rows': 196, 'n_per_row': 206, 'n_ant_bs': 16, 'n_subcarriers': 64}
    }
    return row_column_users
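A quick shape check for patch_maker with a toy complex channel (a standalone sketch; the 4x4 patch size matches the N_ROWS/N_COLUMNS defaults of patch_gen):

import numpy as np
from input_preprocess import patch_maker

n_users, n_ant, n_sub = 5, 16, 64
ch = np.random.randn(n_users, 1, n_ant, n_sub) + 1j * np.random.randn(n_users, 1, n_ant, n_sub)

patches = patch_maker(ch, patch_rows=4, patch_cols=4)
# Each 4x4 complex block becomes one token of length 4*4*2 = 32 (real/imag interleaved),
# so a 16x64 channel yields (16/4)*(64/4) = 64 tokens per user.
print(patches.shape)  # (5, 64, 32)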
lwm_model.py
ADDED
@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 13 19:23:54 2024

This script defines the LWM model architecture.

@author: Sadjad Alikhani
"""
#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
#%%
class LayerNormalization(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias


class Embedding(nn.Module):
    def __init__(self, element_length, d_model, max_len=513):
        super().__init__()
        self.element_length = element_length
        self.d_model = d_model
        self.proj = nn.Linear(element_length, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.norm = LayerNormalization(d_model)

    def forward(self, x):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos_encodings = self.pos_embed(pos)
        tok_emb = self.proj(x.float())
        embedding = tok_emb + pos_encodings
        return self.norm(embedding)


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q, K, V):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        self.d_k = d_model // n_heads
        self.d_v = d_model // n_heads
        self.n_heads = n_heads
        self.W_Q = nn.Linear(d_model, self.d_k * n_heads)
        self.W_K = nn.Linear(d_model, self.d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.linear = nn.Linear(n_heads * self.d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_k)

    def forward(self, Q, K, V):
        residual, batch_size = Q, Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        context, attn = self.scaled_dot_attn(q_s, k_s, v_s)
        output = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
        output = self.linear(output)
        return residual + self.dropout(output), attn


class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.fc2(self.dropout(F.relu(self.fc1(x))))


class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff, dropout)
        self.norm1 = LayerNormalization(d_model)
        self.norm2 = LayerNormalization(d_model)

    def forward(self, enc_inputs):
        # Self-attention with Add & Norm
        attn_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs)
        attn_outputs = self.norm1(enc_inputs + attn_outputs)  # Add & Norm

        # Feed-forward with Add & Norm
        ff_outputs = self.pos_ffn(attn_outputs)
        enc_outputs = self.norm2(attn_outputs + ff_outputs)  # Add & Norm

        return enc_outputs, attn


class lwm(nn.Module):
    def __init__(self, element_length=32, d_model=128, n_layers=12, max_len=513, n_heads=8, dropout=0.1):
        super().__init__()
        self.embedding = Embedding(element_length, d_model, max_len)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, n_heads, d_model*4, dropout) for _ in range(n_layers)]
        )
        self.linear = nn.Linear(d_model, d_model)
        self.norm = LayerNormalization(d_model)

        embed_weight = self.embedding.proj.weight
        _, n_dim = embed_weight.size()
        self.decoder = nn.Linear(d_model, n_dim, bias=False)
        self.decoder_bias = nn.Parameter(torch.zeros(n_dim))

    @classmethod
    def from_pretrained(cls, ckpt_name='model_weights.pth', device='cuda'):
        model = cls().to(device)
        model.load_state_dict(torch.load(ckpt_name, map_location=device))
        print(f"Model loaded successfully from {ckpt_name}")
        return model

    def forward(self, input_ids, masked_pos=None):
        # Step 1: Embedding
        output = self.embedding(input_ids)
        attention_maps = []

        # Step 2: Pass through the encoder layers
        for layer in self.layers:
            output, attn = layer(output)
            attention_maps.append(attn)

        # If masked_pos is provided, perform masked token prediction
        if masked_pos is not None:
            masked_pos = masked_pos.long()[:, :, None].expand(-1, -1, output.size(-1))
            h_masked = torch.gather(output, 1, masked_pos)
            h_masked = self.norm(F.relu(self.linear(h_masked)))
            logits_lm = self.decoder(h_masked) + self.decoder_bias
            return logits_lm, output, attention_maps
        else:
            return output, attention_maps
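Note (illustrative): a minimal usage sketch for the encoder defined above, assuming the default hyperparameters (element_length=32, d_model=128); the input tensor and masked positions are arbitrary example values, not part of the uploaded files.

# Illustrative sketch only, not part of the uploaded files.
import torch
from lwm_model import lwm

model = lwm()
x = torch.randn(2, 33, 32)                    # (batch, patches incl. CLS, element_length)
embeddings, attn_maps = model(x)              # no masking: per-patch embeddings of size d_model=128
masked_pos = torch.tensor([[5, 7], [3, 9]])   # arbitrary masked-patch indices
logits_lm, embeddings, attn_maps = model(x, masked_pos)
print(embeddings.shape, logits_lm.shape)      # torch.Size([2, 33, 128]) torch.Size([2, 2, 32])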
main.py
ADDED
@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 21 13:24:21 2024

This script pre-trains the LWM model

@author: salikha4
"""
import torch
import torch.nn as nn
from torch.utils.data import random_split
from input_preprocess import tokenizer, scenarios_list
from utils import create_dataloader, count_parameters
import numpy as np
import lwm_model
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import AdamW
from train import train_lwm
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
#%% SETTINGS
EPOCHS = 50
BATCH_SIZE = 128
VAL_BATCH_SIZE = 64
WARMUP_EPOCHS = 5
BASE_LR = 5e-4
N_ROWS = 4
N_COLUMNS = 4
ELEMENT_LENGTH = N_ROWS*N_COLUMNS*2
D_MODEL = 128
MAX_LEN = 513
N_LAYERS = 12
WEIGHT_DECAY = 0.05
BETA1 = 0.9
BETA2 = 0.999
MASK_PERCENT = 0.40
N_HEADS = 8
DROPOUT = 0.1
#%% GENERATE DATASET
bs_idxs = [1, 2, 3]
selected_scenario_names = scenarios_list()[:80]
preprocessed_data = tokenizer(
    selected_scenario_names,
    MAX_LEN,
    masking_percent=MASK_PERCENT,
    mask=True,
    seed=42
)
#%% SPLIT DATASET
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
train_ratio = 0.8
val_ratio = 0.2
train_data = {}
val_data = {}
test_data = {}
for key, samples in preprocessed_data.items():
    print(f"key: {key}")
    total_samples = len(samples)
    train_size = int(train_ratio * total_samples)
    val_size = int(val_ratio * total_samples)
    test_size = total_samples - val_size - train_size

    train_data[key], val_data[key], test_data[key] = random_split(
        samples, [train_size, val_size, test_size]
    )
train_loaders = create_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loaders = create_dataloader(val_data, batch_size=VAL_BATCH_SIZE, shuffle=False)
#%% INITIALIZE MODEL
load_model = True
gpu_ids = [0]
device = torch.device("cuda:0")
model = lwm_model.lwm().to(device)

if load_model:
    model_name = "lwm_epoch50_train0.0077_val0.0060_masking0.40.pth"
    state_dict = torch.load(f"models/{model_name}", map_location=device)
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict)

model = nn.DataParallel(model, gpu_ids)
print(f"Model loaded successfully on GPU {device.index}")

n_parameters = count_parameters(model)
print(f"Number of trainable parameters: {n_parameters:,}")
#%% OPTIMIZER AND SCHEDULER
BASE_LR = 5e-5
MIN_LR = 1e-8
TOTAL_STEPS = sum(len(loader) for loader in train_loaders.values()) * EPOCHS
WARMUP_STEPS = sum(len(loader) for loader in train_loaders.values()) * WARMUP_EPOCHS

optimizer = AdamW(
    model.parameters(),
    lr=BASE_LR,
    betas=(BETA1, BETA2),
    weight_decay=WEIGHT_DECAY
)
def lr_lambda(current_step):
    if current_step < WARMUP_STEPS:
        # Linear warmup
        return current_step / WARMUP_STEPS
    else:
        # Scaled cosine decay
        scaled_progress = (current_step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
        return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
#%% PRE-TRAIN THE MODEL
pretrained_model = train_lwm(
    model,
    train_loaders,
    val_loaders,
    optimizer,
    scheduler,
    EPOCHS,
    device=device
)
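Note (illustrative): the warmup-plus-cosine schedule above can be sanity-checked in isolation. This sketch re-implements lr_lambda with made-up step counts (they are not values from a real run) and prints the resulting learning rates.

# Illustrative sketch only, not part of the uploaded files.
import numpy as np

BASE_LR, MIN_LR = 5e-5, 1e-8
WARMUP_STEPS, TOTAL_STEPS = 100, 1000        # example step counts

def lr_lambda(current_step):
    if current_step < WARMUP_STEPS:
        return current_step / WARMUP_STEPS   # linear warmup to a multiplier of 1.0
    scaled_progress = (current_step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    cosine_decay = 0.5 * (1 + np.cos(np.pi * scaled_progress))
    return cosine_decay * (BASE_LR - MIN_LR) / BASE_LR + MIN_LR / BASE_LR

for step in (0, 50, 100, 550, 1000):
    print(step, BASE_LR * lr_lambda(step))   # 0 -> 0, mid-warmup -> 2.5e-5, end of warmup -> 5e-5, then decays towards MIN_LR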
models/model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:485611f1a0f819f9c673827b8e613887b39672e97072bd7a412866b49d8dd40f
size 9960738
train.py
ADDED
@@ -0,0 +1,446 @@
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 20 09:32:12 2024

This script contains the LWM pre-training and task-specific fine-tuning functions.

@author: Sadjad Alikhani
"""
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import csv
from utils import count_parameters
import time
#%% LOSS FUNCTION
def nmse_loss(y_pred, y_true):
    y_pred_flat = y_pred.view(y_pred.size(0), -1)
    y_true_flat = y_true.view(y_true.size(0), -1)
    mse = torch.sum((y_true_flat - y_pred_flat)**2, dim=-1)
    normalization = torch.sum(y_true_flat**2, dim=-1)
    return mse / normalization
#%%
def train_lwm(model, train_loaders, val_loaders, optimizer, scheduler, epochs, device, save_dir="models", log_file="training_log.csv"):

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Initialize CSV log
    if not os.path.exists(log_file):
        with open(log_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train NMSE", "Validation NMSE", "Learning Rate", "Best Model"])

    train_nmse_losses = []
    val_nmse_losses = []
    best_val_nmse = float('inf')

    for epoch in range(epochs):
        model.train()
        train_nmse = 0.0
        train_samples = 0

        # Training loop across all sequence-length buckets
        print(f"\nEpoch {epoch + 1}/{epochs} [Training]")
        for length, train_loader in train_loaders.items():
            print(f"Processing sequences of length {length}")
            with tqdm(train_loader, desc=f"Length {length} [Training]", unit="batch") as t:
                for batch in t:
                    optimizer.zero_grad()

                    # Move data to device
                    input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]

                    # Forward pass
                    logits_lm, _, _ = model(input_ids, masked_pos)

                    # Compute NMSE
                    loss = torch.sum(nmse_loss(masked_tokens, logits_lm))
                    loss.backward()
                    optimizer.step()
                    scheduler.step()

                    train_nmse += loss.item()
                    train_samples += input_ids.shape[0]

                    # Update progress bar
                    t.set_postfix({"nmse": train_nmse/train_samples, "lr": scheduler.get_last_lr()[0]})

        # Average NMSE over training samples
        train_nmse /= max(train_samples, 1)
        train_nmse_losses.append(train_nmse)

        if epoch % 2 == 0:
            # Validation loop across all sequence-length buckets
            model.eval()
            val_nmse = 0.0
            val_samples = 0
            with torch.no_grad():
                print(f"\nEpoch {epoch + 1}/{epochs} [Validation]")
                for length, val_loader in val_loaders.items():
                    print(f"Processing sequences of length {length}")
                    with tqdm(val_loader, desc=f"Length {length} [Validation]", unit="batch") as t:
                        for batch in t:
                            # Move data to device
                            input_ids, masked_tokens, masked_pos = [b.to(device) for b in batch]

                            # Forward pass
                            logits_lm, _, _ = model(input_ids, masked_pos)

                            # Compute NMSE
                            loss = torch.sum(nmse_loss(masked_tokens, logits_lm))
                            val_nmse += loss.item()
                            val_samples += input_ids.shape[0]

                            # Update progress bar
                            t.set_postfix({"nmse": val_nmse/val_samples})

            # Average NMSE over validation samples
            val_nmse /= max(val_samples, 1)
            val_nmse_losses.append(val_nmse)

            # Save model if validation NMSE improves
            is_best_model = False
            if val_nmse < best_val_nmse:
                best_val_nmse = val_nmse
                model_path = os.path.join(save_dir, f"lwm_epoch{epoch+1}_train{train_nmse:.4f}_val{val_nmse:.4f}.pth")
                torch.save(model.state_dict(), model_path)
                print(f"Model saved: {model_path}")
                is_best_model = True

            # Log the results
            print(f"  Train NMSE: {train_nmse:.4f}")
            print(f"  Validation NMSE: {val_nmse:.4f}")
            print(f"  Learning Rate: {scheduler.get_last_lr()[0]:.6e}")

            # Append to CSV log
            with open(log_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([epoch + 1, train_nmse, val_nmse, scheduler.get_last_lr()[0], is_best_model])

        # Plot losses after each epoch
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(train_nmse_losses) + 1), train_nmse_losses, label="Train NMSE")
        plt.plot(range(1, len(val_nmse_losses) + 1), val_nmse_losses, label="Validation NMSE")
        plt.xlabel("Epochs")
        plt.ylabel("NMSE")
        plt.title("Training and Validation NMSE Loss")
        plt.legend()
        plt.grid(True)
        plt.show()

    print("Training and validation complete.")
    return model
#%% FINE-TUNE
from torch.cuda.amp import GradScaler, autocast

# Define the ClassificationHead (single linear layer)
class ClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)


# Define the RegressionHead (single linear layer)
class RegressionHead(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        return self.fc(x)


class CustomClassificationHead(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            # nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        return self.classifier(x)


class CustomRegressionHead(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.regressor = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, output_dim)
        )

    def forward(self, x):
        return self.regressor(x)


def custom_heads(input_dim, num_classes=None, output_dim=None, task_type="classification"):
    """
    Creates a custom head for classification or regression tasks.
    Users should modify the class implementations for further customization.

    Args:
        input_dim (int): Input dimension of the head.
        num_classes (int): Number of classes for classification tasks. Ignored for regression.
        output_dim (int): Output dimension for regression tasks. Ignored for classification.
        task_type (str): "classification" or "regression".

    Returns:
        nn.Module: Custom head for the specified task.
    """
    if task_type == "classification":
        if num_classes is None:
            raise ValueError("num_classes must be specified for classification tasks.")
        return CustomClassificationHead(input_dim=input_dim, num_classes=num_classes)
    elif task_type == "regression":
        return CustomRegressionHead(input_dim=input_dim, output_dim=output_dim)
    else:
        raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")
#%%
# Fine-tuning wrapper for the base model
class FineTuningWrapper(nn.Module):
    def __init__(self, model, task_head, fine_tune_layers="full"):
        super().__init__()
        self.model = model
        self.task_head = task_head

        # Freeze all layers initially
        for param in self.model.parameters():
            param.requires_grad = False

        # Handle fine-tuning layers
        if fine_tune_layers is not None:
            if fine_tune_layers == "full":
                # Unfreeze all layers if "full" is specified
                for param in self.model.parameters():
                    param.requires_grad = True
            else:
                # Get a list of all available layer names in the model
                available_layers = [name for name, _ in self.model.named_parameters()]

                # Validate that the specified layers exist in the model
                for layer in fine_tune_layers:
                    if not any(layer in lname for lname in available_layers):
                        raise ValueError(
                            f"Layer '{layer}' not found in the model. "
                            f"Available layers: {available_layers}"
                        )

                # Unfreeze only the specified layers
                for name, param in self.model.named_parameters():
                    if any(layer in name for layer in fine_tune_layers):
                        param.requires_grad = True

    def forward(self, x, input_type="cls_emb"):
        if input_type == "raw":
            task_input = x.view(x.size(0), -1)
        else:
            embeddings, attn_maps = self.model(x)  # Get embeddings from the base model
            if input_type == "cls_emb":
                task_input = embeddings[:, 0, :]  # CLS token
            elif input_type == "chs_emb":
                chs_emb = embeddings[:, 1:, :]
                task_input = chs_emb.view(chs_emb.size(0), -1)  # Alternatively: embeddings.mean(dim=1) for mean pooling

        return self.task_head(task_input), 0 if input_type == "raw" else attn_maps
#%%
# Fine-tuning function
from sklearn.metrics import f1_score
def finetune(
    base_model,
    train_loader,
    val_loader=None,
    task_type="classification",
    input_type="cls_emb",
    num_classes=None,
    output_dim=None,
    use_custom_head=False,
    fine_tune_layers=None,
    optimizer_config=None,
    criterion=None,
    epochs=10,
    device="cuda",
    task="Beam Prediction"
):
    """
    Configures and fine-tunes the base model with user-defined settings, saving results and models.
    """
    # Create results folder
    time_now = f"{time.time():.0f}"
    results_folder = f"results/{task}/{time_now}"
    os.makedirs(results_folder, exist_ok=True)
    log_file = os.path.join(results_folder, "training_log.csv")

    # Initialize the CSV log
    with open(log_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Task", "Input", "Epoch", "Train Loss", "Validation Loss", "F1-Score (Classification)", "Learning Rate", "Time"])

    # Peek at one validation batch to infer the input dimensions
    for batch in val_loader:
        input_data, targets = batch[0].to(device), batch[1].to(device)
        break

    if input_type == "cls_emb":
        n_patches = 1
        patch_size = 128
    elif input_type == "channel_emb":
        n_patches = input_data.shape[1] - 1
        patch_size = 128
    elif input_type == "raw":
        n_patches = input_data.shape[1]
        patch_size = 32
        # patch_size = 1

    if use_custom_head:
        custom_head = custom_heads(input_dim=n_patches*patch_size,
                                   num_classes=num_classes,
                                   output_dim=output_dim,
                                   task_type=task_type)

    # Handle DataParallel models
    if isinstance(base_model, nn.DataParallel):
        base_model = base_model.module

    # Set up the task-specific head
    if use_custom_head:
        task_head = custom_head
    elif task_type == "classification":
        if num_classes is None:
            raise ValueError("num_classes must be specified for classification tasks.")
        task_head = ClassificationHead(input_dim=n_patches*patch_size, num_classes=num_classes)  # input_dim=base_model.embedding.d_model
    elif task_type == "regression":
        task_head = RegressionHead(input_dim=n_patches*patch_size)  # input_dim=base_model.embedding.d_model
    else:
        raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")

    # Wrap the model with the fine-tuning head
    wrapper = FineTuningWrapper(base_model, task_head, fine_tune_layers=fine_tune_layers)
    wrapper = wrapper.to(device)

    print(f'Number of head parameters: {count_parameters(wrapper)}')

    # Set default optimizer config if not provided
    if optimizer_config is None:
        optimizer_config = {"lr": 1e-4}
    # Set up the optimizer
    optimizer = torch.optim.Adam(wrapper.parameters(), **optimizer_config)
    # Set up the scheduler for learning rate decay
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)  # Example: multiply the LR by 0.2 every 10 epochs

    # Set up the loss criterion
    if criterion is None:
        criterion = nn.CrossEntropyLoss() if task_type == "classification" else nn.MSELoss()

    scaler = GradScaler()
    train_losses, val_losses, f1_scores = [], [], []
    best_val_loss = float("inf")
    best_model_path = None

    for epoch in range(epochs):
        # Training loop
        wrapper.train()
        epoch_loss = 0.0

        with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}") as progress_bar:
            for batch in progress_bar:
                input_data, targets = batch[0].to(device), batch[1].to(device)
                optimizer.zero_grad()

                with autocast():
                    outputs, attn_maps = wrapper(input_data, input_type=input_type)
                    loss = criterion(outputs, targets)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_postfix({"Loss": loss.item()})

        avg_train_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation loop
        if val_loader:
            wrapper.eval()
            val_loss = 0.0
            all_preds, all_targets = [], []

            with torch.no_grad():
                for batch in val_loader:
                    input_data, targets = batch[0].to(device), batch[1].to(device)
                    with autocast():
                        outputs, _ = wrapper(input_data, input_type=input_type)
                        loss = criterion(outputs, targets)

                    val_loss += loss.item()

                    if task_type == "classification":
                        preds = torch.argmax(outputs, dim=1).cpu().numpy()
                        all_preds.extend(preds)
                        all_targets.extend(targets.cpu().numpy())

            avg_val_loss = val_loss / len(val_loader)
            val_losses.append(avg_val_loss)

            time_now = f"{time.time():.0f}"
            # Save the best model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model_path = os.path.join(results_folder, f"{input_type}_epoch{epoch+1}_valLoss{avg_val_loss:.4f}_{time_now}.pth")
                torch.save(wrapper.state_dict(), best_model_path)
                print(f"Model saved at {best_model_path} with validation loss: {best_val_loss:.4f}")

            # Compute F1-score for classification tasks
            f1 = None
            if task_type == "classification":
                f1 = f1_score(all_targets, all_preds, average="macro")
                print(f"Epoch {epoch + 1}, Validation F1-Score: {f1:.4f}")
                f1_scores.append(f1)

        scheduler.step()

        # Log results
        with open(log_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([task, input_type, epoch + 1, avg_train_loss, avg_val_loss, f1 if f1 is not None else "-", scheduler.get_last_lr()[0], f"{time_now}"])

    # Plot training and validation losses
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, epochs + 1), train_losses, label="Training Loss")
    plt.plot(range(1, epochs + 1), val_losses, label="Validation Loss", linestyle="--")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss")
    plt.legend()
    plt.grid(True)
    # plt.savefig(os.path.join(results_folder, "loss_curve.png"))
    plt.show()

    return wrapper, best_model_path, train_losses, val_losses, f1_scores if task_type == "classification" else 0, attn_maps
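Note (illustrative): a small numeric check of the NMSE loss defined at the top of train.py; the tensors are made-up values showing that the loss is the per-sample squared error normalized by the per-sample power of the reference argument.

# Illustrative sketch only, not part of the uploaded files.
import torch
from train import nmse_loss

y_true = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y_pred = torch.tensor([[1.0, 1.0], [3.0, 2.0]])
# per-sample squared errors: [1.0, 4.0]; per-sample reference powers: [5.0, 25.0]
print(nmse_loss(y_pred, y_true))   # tensor([0.2000, 0.1600])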
utils.py
ADDED
@@ -0,0 +1,247 @@
from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
import torch
import numpy as np
#%%
def create_dataloader(grouped_data, batch_size, shuffle):

    dataloaders = {}

    for seq_length, group in grouped_data.items():

        print(f"dataloader in progress ...\nkey: {seq_length}")

        ## Uncomment the following line if you run out of memory during pre-training
        # batch_size = batch_size // 8 if seq_length >= 5 else batch_size

        # Unpack samples for the current group
        input_ids, masked_tokens, masked_pos = zip(*group)

        # Convert to tensors
        input_ids_tensor = torch.tensor(input_ids, dtype=torch.float32)
        masked_tokens_tensor = torch.tensor(masked_tokens, dtype=torch.float32)
        masked_pos_tensor = torch.tensor(masked_pos, dtype=torch.long)

        # Create TensorDataset and DataLoader
        dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
        dataloaders[seq_length] = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True)

    return dataloaders
#%%
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
#%%
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap

def visualize_embeddings(embeddings, labels, method="pca", label=None):
    """
    Visualize embeddings using PCA, UMAP, or t-SNE with color-coded labels.

    Args:
        embeddings (torch.Tensor or np.ndarray): Embeddings to visualize, shape (n_samples, n_features).
        labels (torch.Tensor or np.ndarray): Class labels corresponding to embeddings, shape (n_samples,).
        method (str): Dimensionality reduction method ('pca', 'umap', or 'tsne').
        label (str): Label used in the plot title.
    """
    # Convert to numpy if input is a torch.Tensor
    if isinstance(embeddings, torch.Tensor):
        embeddings = embeddings.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()

    # Apply the selected dimensionality reduction method
    if method.lower() == "pca":
        reducer = PCA(n_components=2)
    elif method.lower() == "umap":
        reducer = umap.UMAP(n_components=2, n_neighbors=16, random_state=42)
    elif method.lower() == "tsne":
        reducer = TSNE(n_components=2, random_state=42, init="random")
    else:
        raise ValueError("Invalid method. Choose from 'pca', 'umap', or 'tsne'.")

    reduced_embeddings = reducer.fit_transform(embeddings)

    # Create a scatter plot with color-coding based on labels
    plt.figure(figsize=(10, 8))
    num_classes = len(np.unique(labels))
    colors = plt.cm.get_cmap("tab10", num_classes)

    for class_idx in range(num_classes):
        class_points = reduced_embeddings[labels == class_idx]
        plt.scatter(
            class_points[:, 0], class_points[:, 1],
            label=f"Class {class_idx}",
            alpha=0.6
        )

    # Customize the plot
    plt.title(f"{label} ({method.upper()})")
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    plt.legend()
    plt.show()
#%%
def generate_gaussian_noise(data, snr_db):
    """
    Generate Gaussian noise given an SNR so it can be applied to the data.

    Args:
        data (torch.Tensor): Input data tensor of shape (n_samples, seq_len, feature_dim).
        snr_db (float): Signal-to-Noise Ratio in decibels (dB).

    Returns:
        torch.Tensor: Noise tensor with the same shape as the input data.
    """
    # Separate the input data to exclude the first channel
    a = data[:, 1:, :]  # Shape: (n_samples, seq_len-1, feature_dim)
    flat_data = a.view(a.size(0), -1)  # Flatten data to calculate power
    signal_power = torch.mean(flat_data**2, dim=1, keepdim=True)  # Shape: (n_samples, 1)
    snr_linear = 10 ** (snr_db / 10)
    noise_power = signal_power / snr_linear
    noise = torch.randn_like(flat_data) * torch.sqrt(noise_power)
    noise = noise.view_as(a)
    noise = torch.cat((torch.zeros_like(data[:, :1, :]), noise), dim=1)  # Add zero noise for the first channel

    return noise
#%%
def plot_coverage(rxs, cov_map, dpi=200, figsize=(6,4), cbar_title=None, title=False,
                  scat_sz=.5, tx_pos=None, tx_ori=None, legend=False, lims=None,
                  proj_3D=False, equal_aspect=False, tight=True, cmap='tab20'):

    plt_params = {'cmap': cmap}
    if lims:
        plt_params['vmin'], plt_params['vmax'] = lims[0], lims[1]

    n = 3 if proj_3D else 2  # number of coordinates to consider: 2 = xy | 3 = xyz

    xyz = {'x': rxs[:,0], 'y': rxs[:,1]}
    if proj_3D:
        xyz['zs'] = rxs[:,2]

    fig, ax = plt.subplots(dpi=dpi, figsize=figsize,
                           subplot_kw={'projection': '3d'} if proj_3D else {})

    im = plt.scatter(**xyz, c=cov_map, s=scat_sz, marker='s', **plt_params)

    cbar = plt.colorbar(im, label='' if not cbar_title else cbar_title)

    plt.xlabel('x (m)')
    plt.ylabel('y (m)')

    # TX position
    if tx_pos is not None:
        ax.scatter(*tx_pos[:n], marker='P', c='r', label='TX')

    # TX orientation
    if tx_ori is not None and tx_pos is not None:  # ori = [azi, el]
        # positive azimuths point left (like positive angles in a unit circle)
        # positive elevations point up
        r = 30  # ref size of pointing direction
        tx_lookat = np.copy(tx_pos)
        tx_lookat[:2] += r * np.array([np.cos(tx_ori[2]), np.sin(tx_ori[2])])  # azimuth
        tx_lookat[2] += r * np.sin(tx_ori[1])  # elevation

        line_components = [[tx_pos[i], tx_lookat[i]] for i in range(n)]
        line = {key: val for key, val in zip(['xs', 'ys', 'zs'], line_components)}
        if n == 2:
            ax.plot(line_components[0], line_components[1], c='k', alpha=.5, zorder=3)
        else:
            ax.plot(**line, c='k', alpha=.5, zorder=3)

    if title:
        ax.set_title(title)

    if legend:
        plt.legend(loc='upper center', ncols=10, framealpha=.5)

    if tight:
        s = 1
        mins, maxs = np.min(rxs, axis=0)-s, np.max(rxs, axis=0)+s
        if not proj_3D:
            plt.xlim([mins[0], maxs[0]])
            plt.ylim([mins[1], maxs[1]])
        else:
            ax.axes.set_xlim3d([mins[0], maxs[0]])
            ax.axes.set_ylim3d([mins[1], maxs[1]])
            if tx_pos is None:
                ax.axes.set_zlim3d([mins[2], maxs[2]])
            else:
                ax.axes.set_zlim3d([np.min([mins[2], tx_pos[2]]),
                                    np.max([mins[2], tx_pos[2]])])

    if equal_aspect and not proj_3D:  # disrupts the plot
        plt.axis('scaled')

    return fig, ax, cbar
#%%
def prepare_loaders(
    preprocessed_data,
    labels=None,
    selected_patches_idxs=None,
    input_type="raw",
    task_type="classification",
    feature_selection=False,
    train_ratio=0.8,
    batch_size=64,
    seed=42  # Default seed for reproducibility
):
    """
    Prepares datasets and data loaders for training and validation.

    Args:
        preprocessed_data (torch.Tensor): The input data, either raw or preprocessed.
        labels (torch.Tensor, optional): The labels for classification tasks.
        selected_patches_idxs (torch.Tensor, optional): Indices of selected patches for feature selection.
        input_type (str): "raw" or "processed" to specify input data type.
        task_type (str): "classification" or "regression".
        feature_selection (bool): Whether to perform feature selection based on selected_patches_idxs.
        train_ratio (float): Proportion of data to use for training (remaining for validation).
        batch_size (int): Batch size for data loaders.
        seed (int): Random seed for reproducibility.

    Returns:
        tuple: (train_loader, val_loader, samples, target)
    """
    # Set random seed for reproducibility
    torch.manual_seed(seed)

    # Prepare samples
    if input_type == "raw":
        if feature_selection and selected_patches_idxs is not None:
            batch_indices = torch.arange(preprocessed_data.size(0)).unsqueeze(1)  # Shape: [batch_size, 1]
            samples = torch.tensor(preprocessed_data[batch_indices, selected_patches_idxs], dtype=torch.float32)
        else:
            samples = torch.tensor(preprocessed_data[:, 1:], dtype=torch.float32)  # raw_chs
    else:
        samples = torch.tensor(preprocessed_data, dtype=torch.float32)

    # Prepare dataset
    if task_type == "classification":
        if labels is None:
            raise ValueError("Labels are required for classification tasks.")
        labels = torch.tensor(labels, dtype=torch.long)
        dataset = TensorDataset(samples, labels)
        target = 0  # REVISE if needed
    elif task_type == "regression":
        target = samples[:, 1:, :].view(samples.size(0), -1)  # Reshape for regression targets
        dataset = TensorDataset(samples, target)
    else:
        raise ValueError("Invalid task_type. Choose 'classification' or 'regression'.")

    # Set random seed for reproducibility
    generator = torch.Generator().manual_seed(seed)

    # Split dataset into training and validation
    n_samples = len(dataset)
    train_size = int(train_ratio * n_samples)
    val_size = n_samples - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=generator)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=generator)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")
    return train_loader, val_loader, samples, target
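Note (illustrative): a minimal sketch of the noise helper above; the shapes and the 10 dB target are example values, and the check simply confirms the generated noise lands close to the requested SNR while the first patch stays noise-free.

# Illustrative sketch only, not part of the uploaded files.
import torch
from utils import generate_gaussian_noise

data = torch.randn(4, 33, 32)               # (n_samples, seq_len, feature_dim)
noise = generate_gaussian_noise(data, snr_db=10)
noisy = data + noise

sig_p = (data[:, 1:, :] ** 2).mean()
noi_p = (noise[:, 1:, :] ** 2).mean()
print(10 * torch.log10(sig_p / noi_p))      # close to 10 dB
print(noise[:, 0, :].abs().max())           # 0.0: the first patch is left untouched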