from functools import partial
from pathlib import Path
from typing import Any, Dict, Tuple

import numpy as np
import torch
import torchvision.transforms as T
from kornia.geometry.conversions import convert_points_to_homogeneous
from tqdm.auto import tqdm

from common.data.utils import yards
from common.infer.base import InferDataModule, InferModule

from ..cam_distr.tv_main_center import get_cam_distr, get_dist_distr
from ..cam_modules import CameraParameterWLensDistDictZScore, SNProjectiveCamera
from ..data.dataset import InferenceDatasetCalibration
from ..data.utils import custom_list_collate
from ..models.segmentation import InferenceSegmentationModel
from ..sn_segmentation.src.custom_extremities import (
    generate_class_synthesis,
    get_line_extremities,
)
from ..utils.io import detach_dict, tensor2list
from ..utils.linalg import distance_line_pointcloud_3d, distance_point_pointcloud
from ..utils.objects_3d import (
    SoccerPitchLineCircleSegments,
    SoccerPitchSNCircleCentralSplit,
)


class TvCalibInferModule(InferModule):
    """TvCalib-based camera calibration: segment pitch markings, extract
    keypoints, then optimize camera parameters against the 3D pitch model."""

    def __init__(
        self,
        segmentation_checkpoint: Path,
        image_shape: Tuple[int, int] = (720, 1280),
        optim_steps: int = 2000,
        lens_dist: bool = False,
        playfield_size: Tuple[int, int] = (105, 68),
        make_images: bool = False,
    ):
        super().__init__()
        self.image_shape = image_shape
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.make_images = make_images

        # Post-process raw segmentation maps: skeletonize each class and
        # sample a fixed number of extremity points per line / circle segment.
        self.fn_generate_class_synthesis = partial(generate_class_synthesis, radius=4)
        self.fn_get_line_extremities = partial(
            get_line_extremities,
            maxdist=30,
            width=455,
            height=256,
            num_points_lines=4,
            num_points_circles=8,
        )

        self.model_seg = InferenceSegmentationModel(
            segmentation_checkpoint,
            self.device,
        )

        # 3D pitch model, split into line and circle segments; a CPU copy is
        # kept alongside the device-resident one.
        self.object3d = SoccerPitchLineCircleSegments(
            device=self.device,
            base_field=SoccerPitchSNCircleCentralSplit(),
        )
        self.object3dcpu = SoccerPitchLineCircleSegments(
            device="cpu",
            base_field=SoccerPitchSNCircleCentralSplit(),
        )
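
        # Lines and circles are kept as separate segment sets because they use
        # different loss terms during calibration: point-to-line distances for
        # straight markings vs. point-to-point-cloud distances for circle arcs
        # (see TVCalibModule.forward below).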

        batch_size_calib = 1
        self.model_calib = TVCalibModule(
            self.object3d,
            get_cam_distr(1.96, batch_size_calib, 1),
            get_dist_distr(batch_size_calib, 1) if lens_dist else None,
            (image_shape[0], image_shape[1]),
            optim_steps,
            self.device,
            log_per_step=False,
            tqdm_kwargs=None,
        )
        # The segmentation network expects 256x455 inputs.
        self.resize = T.Resize(size=(256, 455))
        # Shift pitch coordinates from a center origin to a corner origin.
        self.offset = np.array([
            [1, 0, playfield_size[0] / 2.0],
            [0, 1, playfield_size[1] / 2.0],
            [0, 0, 1],
        ])
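
        # Convention example (with the default 105 x 68 m pitch): applying
        # `self.offset` to the center-origin point (0, 0) yields (52.5, 34),
        # i.e. coordinates measured from a pitch corner.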

    def setup(self, datamodule: InferDataModule):
        # No datamodule-specific preparation is needed.
        pass

    def predict(self, x: Any) -> Dict:
        """
        1. Run segmentation and extract keypoints.
        2. Calibrate the camera from the selected keypoints.
        """
        keypoints = self._segment(x["image"])
        homo = self._calibrate(keypoints)

        result = {
            "homography": homo,
        }

        if self.make_images:
            # Preview rendering is not implemented here; it would require a
            # previewer helper (`self.previewer`) that this module never
            # initializes, e.g.:
            # image_720p = self.previewer.to_image(x["image"].clone().detach().cpu())
            pass

        return result

    def _segment(self, image):
        """Segment pitch markings and extract keypoints from the class maps."""
        image = self.resize(image)
        with torch.no_grad():
            sem_lines = self.model_seg.inference(
                image.unsqueeze(0).to(self.device)
            )
        sem_lines = sem_lines.detach().cpu().numpy().astype(np.uint8)

        # Skeletonize the predicted classes, then pick extremity points per
        # line / circle segment.
        skeletons_batch = self.fn_generate_class_synthesis(sem_lines[0])
        keypoints_raw_batch = self.fn_get_line_extremities(skeletons_batch)

        return keypoints_raw_batch

    def _calibrate(self, keypoints):
        """Optimize camera parameters for the given keypoints and return the
        homography (in corner-origin, yard coordinates), or None on failure."""
        ds = InferenceDatasetCalibration(
            [keypoints],
            self.image_shape[1],
            self.image_shape[0],
            self.object3d,
        )

        _batch_size = 1
        x_dict = custom_list_collate([ds[0]])
        try:
            per_sample_loss, cam, _ = self.model_calib.self_optim_batch(x_dict)
            output_dict = tensor2list(
                detach_dict({**cam.get_parameters(_batch_size), **per_sample_loss})
            )

            homo = output_dict["homography"][0]
            if len(homo) > 0:
                homo = np.array(homo[0])

                # Rescale both pitch axes from meters to yards.
                to_yards = np.array([
                    [yards(1), 0, 0],
                    [0, yards(1), 0],
                    [0, 0, 1],
                ])

                # Corner-origin shift first, then the unit conversion.
                homo = to_yards @ self.offset @ homo
            else:
                homo = None
        except Exception as e:
            print(f"Calibration failed: {e}")
            homo = None

        return homo
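
    # Sketch of consuming the homography (assuming, as the composition with
    # `self.offset` suggests, that it maps image points to pitch-plane points):
    #
    #   p = H @ np.array([u, v, 1.0])
    #   x_yd, y_yd = p[:2] / p[2]  # corner-origin pitch position, in yards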


class TVCalibModule(torch.nn.Module):
    def __init__(
        self,
        model3d,
        cam_distr,
        dist_distr,
        image_dim: Tuple[int, int],
        optim_steps: int,
        device="cpu",
        tqdm_kwargs=None,
        log_per_step=False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.image_height, self.image_width = image_dim
        self.principal_point = (self.image_width / 2, self.image_height / 2)
        self.model3d = model3d
        self.cam_param_dict = CameraParameterWLensDistDictZScore(
            cam_distr, dist_distr, device=device
        )

        self.lens_distortion_active = dist_distr is not None
        self.optim_steps = optim_steps
        self._device = device

        # Optimized parameters of the previous call, used to warm-start the
        # next optimization (e.g. for consecutive video frames).
        self.previous_params = None

        self.optim = torch.optim.AdamW(
            self.cam_param_dict.param_dict.parameters(), lr=0.1, weight_decay=0.01
        )
        self.Scheduler = partial(
            torch.optim.lr_scheduler.OneCycleLR,
            max_lr=0.05,
            total_steps=self.optim_steps,
            pct_start=0.5,
        )

        # Lens distortion gets its own optimizer and schedule, with a much
        # smaller learning rate.
        if self.lens_distortion_active:
            self.optim_lens_distortion = torch.optim.AdamW(
                self.cam_param_dict.param_dict_dist.parameters(), lr=1e-3, weight_decay=0.01
            )
            self.Scheduler_lens_distortion = partial(
                torch.optim.lr_scheduler.OneCycleLR,
                max_lr=1e-3,
                total_steps=self.optim_steps,
                pct_start=0.33,
                optimizer=self.optim_lens_distortion,
            )

        self.tqdm_kwargs = tqdm_kwargs if tqdm_kwargs is not None else {}

        self.hparams = {"optim": str(self.optim), "scheduler": str(self.Scheduler)}
        self.log_per_step = log_per_step
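
    # The camera is parametrized through CameraParameterWLensDistDictZScore,
    # which (as the name suggests) holds z-score offsets relative to the prior
    # distributions `cam_distr` / `dist_distr`; the optimizers above therefore
    # search a normalized parameter space.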

    def forward(self, x):
        # Build the camera from the current parameter estimates.
        phi_hat, psi_hat = self.cam_param_dict()
        cam = SNProjectiveCamera(
            phi_hat,
            psi_hat,
            self.principal_point,
            self.image_width,
            self.image_height,
            device=self._device,
            nan_check=False,
        )

        # Detected line keypoints in NDC: (batch_size, T_l, 3, S_l, N_l).
        points_px_lines_true = x["lines__ndc_projected_selection_shuffled"].to(self._device)
        batch_size, T_l, _, S_l, N_l = points_px_lines_true.shape

        # Detected circle keypoints in NDC: (batch_size, T_c, 3, S_c, N_c).
        points_px_circles_true = x["circles__ndc_projected_selection_shuffled"].to(self._device)
        _, T_c, _, S_c, N_c = points_px_circles_true.shape
        assert T_c == T_l

        # Project the 3D endpoints of every model line segment to NDC.
        points3d_lines_keypoints = self.model3d.line_segments
        points3d_lines_keypoints = points3d_lines_keypoints.reshape(3, S_l * 2).transpose(0, 1)
        points_px_lines_keypoints = convert_points_to_homogeneous(
            cam.project_point2ndc(points3d_lines_keypoints, lens_distortion=False)
        )

        if batch_size < cam.batch_dim:
            points_px_lines_keypoints = points_px_lines_keypoints[:batch_size]

        points_px_lines_keypoints = points_px_lines_keypoints.view(batch_size, T_l, S_l, 2, 3)

        # Segment endpoints, one line per segment.
        lp1 = points_px_lines_keypoints[..., 0, :].unsqueeze(-2)
        lp2 = points_px_lines_keypoints[..., 1, :].unsqueeze(-2)

        pc = (
            points_px_lines_true.view(batch_size, T_l, 3, S_l * N_l)
            .transpose(2, 3)
            .view(batch_size, T_l, S_l, N_l, 3)
        )

        if self.lens_distortion_active:
            # Undistort the detected points before measuring distances.
            pc = pc.view(batch_size, T_l, S_l * N_l, 3)
            pc = pc.detach().clone()
            pc[..., :2] = cam.undistort_points(
                pc[..., :2], cam.intrinsics_ndc, num_iters=1
            )
            pc = pc.view(batch_size, T_l, S_l, N_l, 3)

        # Perpendicular distance of each detected point to its model line.
        distances_px_lines_raw = distance_line_pointcloud_3d(
            e1=lp2 - lp1, r1=lp1, pc=pc, reduce=None
        )
        distances_px_lines_raw = distances_px_lines_raw.unsqueeze(-3)

        # Circle segments are represented as dense point clouds; project the
        # model points and measure point-to-point-cloud distances.
        # (Re-binds S_c; N_c_star is the number of sampled model points.)
        points3d_circles_pc = self.model3d.circle_segments
        _, S_c, N_c_star = points3d_circles_pc.shape
        points3d_circles_pc = points3d_circles_pc.reshape(3, S_c * N_c_star).transpose(0, 1)
        points_px_circles_pc = cam.project_point2ndc(points3d_circles_pc, lens_distortion=False)

        if batch_size < cam.batch_dim:
            points_px_circles_pc = points_px_circles_pc[:batch_size]

        if self.lens_distortion_active:
            points_px_circles_true = points_px_circles_true.view(
                batch_size, T_c, 3, S_c * N_c
            ).transpose(2, 3)
            points_px_circles_true = points_px_circles_true.detach().clone()
            points_px_circles_true[..., :2] = cam.undistort_points(
                points_px_circles_true[..., :2], cam.intrinsics_ndc, num_iters=1
            )
            points_px_circles_true = points_px_circles_true.transpose(2, 3).view(
                batch_size, T_c, 3, S_c, N_c
            )

        distances_px_circles_raw = distance_point_pointcloud(
            points_px_circles_true, points_px_circles_pc.view(batch_size, T_c, S_c, N_c_star, 2)
        )

        distances_dict = {
            "loss_ndc_lines": distances_px_lines_raw,
            "loss_ndc_circles": distances_px_circles_raw,
        }
        return distances_dict, cam
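
    # Loss sketch: for a detected point p and a projected model line segment
    # with endpoints a and b, the line term is the point-to-line distance
    #     d(p; a, b) = ||(p - a) x (b - a)|| / ||b - a||
    # (via distance_line_pointcloud_3d); the circle term is the distance from
    # p to the nearest projected circle point (via distance_point_pointcloud).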

    def self_optim_batch(self, x, *args, **kwargs):
        scheduler = self.Scheduler(self.optim)
        if self.lens_distortion_active:
            scheduler_lens_distortion = self.Scheduler_lens_distortion()

        # Warm-start from the previous frame's solution when available.
        if self.previous_params is not None:
            print("Initializing from the previous frame's parameters")
            update_dict = {}
            for k, v in self.previous_params.items():
                update_dict[k] = v.detach().clone()
            self.cam_param_dict.initialize(update_dict)
        else:
            print("First frame: initializing from zero")
            self.cam_param_dict.initialize(None)

        self.optim.zero_grad()
        if self.lens_distortion_active:
            self.optim_lens_distortion.zero_grad()

        keypoint_masks = {
            "loss_ndc_lines": x["lines__is_keypoint_mask"].to(self._device),
            "loss_ndc_circles": x["circles__is_keypoint_mask"].to(self._device),
        }
        num_actual_points = {
            "loss_ndc_circles": keypoint_masks["loss_ndc_circles"].sum(dim=(-1, -2)),
            "loss_ndc_lines": keypoint_masks["loss_ndc_lines"].sum(dim=(-1, -2)),
        }

        per_sample_loss = {}
        per_sample_loss["mask_lines"] = keypoint_masks["loss_ndc_lines"]
        per_sample_loss["mask_circles"] = keypoint_masks["loss_ndc_circles"]

        per_step_info = {"loss": [], "lr": []}

        # Early-stopping bookkeeping: stop once the loss reaches the target,
        # or once it has not improved by more than the tolerance for
        # `loss_patience` consecutive steps.
        loss_target = 0.001
        loss_patience = 10
        loss_tolerance = 1e-4
        loss_history = []
        best_loss = float("inf")
        steps_without_improvement = 0

        with tqdm(range(self.optim_steps), **self.tqdm_kwargs) as pbar:
            for step in pbar:
                self.optim.zero_grad()
                if self.lens_distortion_active:
                    self.optim_lens_distortion.zero_grad()

                distances_dict, cam = self(x)

                # Masked mean distance per sample and per loss term.
                losses = {}
                for key_dist, distances in distances_dict.items():
                    distances[~keypoint_masks[key_dist]] = 0.0
                    per_sample_loss[f"{key_dist}_distances_raw"] = distances
                    distances_reduced = distances.sum(dim=(-1, -2))
                    distances_reduced = distances_reduced / num_actual_points[key_dist]
                    distances_reduced[num_actual_points[key_dist] == 0] = 0.0
                    distances_reduced = distances_reduced.squeeze(-1)
                    per_sample_loss[key_dist] = distances_reduced
                    loss = distances_reduced.mean(dim=-1)
                    loss = loss.sum()
                    losses[key_dist] = loss

                loss_total_dist = losses["loss_ndc_lines"] + losses["loss_ndc_circles"]
                loss_total = loss_total_dist
                current_loss = loss_total.item()

                loss_history.append(current_loss)
                if current_loss < best_loss - loss_tolerance:
                    best_loss = current_loss
                    steps_without_improvement = 0
                else:
                    steps_without_improvement += 1

                if current_loss <= loss_target or steps_without_improvement >= loss_patience:
                    break

                if self.log_per_step:
                    per_step_info["lr"].append(scheduler.get_last_lr())
                    per_step_info["loss"].append(distances_reduced)
                if step % 50 == 0:
                    pbar.set_postfix(
                        loss=f"{loss_total_dist.item():.5f}",
                        loss_lines=f'{losses["loss_ndc_lines"].item():.3f}',
                        loss_circles=f'{losses["loss_ndc_circles"].item():.3f}',
                    )

                loss_total.backward()
                self.optim.step()
                scheduler.step()
                if self.lens_distortion_active:
                    self.optim_lens_distortion.step()
                    scheduler_lens_distortion.step()

        # Cache the optimized parameters to warm-start the next call.
        self.previous_params = {}
        for k, v in self.cam_param_dict.param_dict.items():
            self.previous_params[k] = v.detach().clone()

        per_sample_loss["loss_ndc_total"] = torch.sum(
            torch.stack([per_sample_loss[key_dist] for key_dist in distances_dict.keys()], dim=0),
            dim=0,
        )

        if self.log_per_step:
            per_step_info["loss"] = torch.stack(per_step_info["loss"], dim=-1)
            per_step_info["lr"] = torch.tensor(per_step_info["lr"])
        return per_sample_loss, cam, per_step_info
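

# Usage sketch (hypothetical checkpoint path; `frame` stands for a
# (3, 720, 1280) image tensor prepared by the surrounding pipeline):
#
#   module = TvCalibInferModule(Path("checkpoints/segmentation.pt"))
#   out = module.predict({"image": frame})
#   H = out["homography"]  # 3x3 numpy array in yard coordinates, or None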