File size: 9,372 Bytes
bdb955e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import argparse
import cv2
import numpy as np
import torch
from pathlib import Path
import time
import traceback

# Assurez-vous que le répertoire tvcalib est dans le PYTHONPATH
# ou exécutez depuis le répertoire tvcalib_image_processor
from tvcalib.infer.module import TvCalibInferModule
# Importer les fonctions de visualisation et les constantes de modulation
from visualizer import (
    create_minimap_view, 
    create_minimap_with_offset_skeletons,
    DYNAMIC_SCALE_MIN_MODULATION, # Importer les constantes
    DYNAMIC_SCALE_MAX_MODULATION
)
# Importer la fonction d'extraction des données joueurs
from pose_estimator import get_player_data

# Constantes
IMAGE_SHAPE = (720, 1280)  # Hauteur, Largeur
SEGMENTATION_MODEL_PATH = Path("models/segmentation/train_59.pt")

def preprocess_image_tvcalib(image_bgr):
    """Prétraite l'image BGR pour TvCalib et retourne le tenseur et l'image RGB redimensionnée."""
    if image_bgr is None:
        raise ValueError("Impossible de charger l'image")

    # 1. Redimensionner en 720p si nécessaire
    h, w = image_bgr.shape[:2]
    if h != IMAGE_SHAPE[0] or w != IMAGE_SHAPE[1]:
        print(f"Redimensionnement de l'image vers {IMAGE_SHAPE[1]}x{IMAGE_SHAPE[0]}")
        image_bgr_resized = cv2.resize(image_bgr, (IMAGE_SHAPE[1], IMAGE_SHAPE[0]), interpolation=cv2.INTER_LINEAR)
    else:
        image_bgr_resized = image_bgr

    # 2. Convertir en RGB (pour TvCalib ET pour la visualisation originale)
    image_rgb_resized = cv2.cvtColor(image_bgr_resized, cv2.COLOR_BGR2RGB)

    # 3. Normalisation spécifique pour le modèle pré-entraîné (pour TvCalib)
    image_tensor = torch.from_numpy(image_rgb_resized).float()
    image_tensor = image_tensor.permute(2, 0, 1)  # HWC -> CHW
    mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    image_tensor = (image_tensor / 255.0 - mean) / std

    # 4. Déplacer sur le bon device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_tensor = image_tensor.to(device)

    # Retourner le tenseur pour TvCalib, l'image BGR et RGB redimensionnée
    return image_tensor, image_bgr_resized, image_rgb_resized

def main():
    parser = argparse.ArgumentParser(description="Exécute la méthode TvCalib sur une seule image.")
    parser.add_argument("image_path", type=str, help="Chemin vers l'image à traiter.")
    parser.add_argument("--output_homography", type=str, default=None, help="Chemin optionnel pour sauvegarder la matrice d'homographie (.npy).")
    parser.add_argument("--optim_steps", type=int, default=500, help="Nombre d'étapes d'optimisation pour la calibration (l'arrêt anticipé est désactivé).")
    parser.add_argument("--target_avg_scale", type=float, default=1, 
                        help="Facteur d'échelle MOYEN CIBLE pour dessiner les squelettes sur la minimap (défaut: 0.35). Le script ajuste l'échelle de base pour tenter d'atteindre cette moyenne.")

    args = parser.parse_args()

    if not Path(args.image_path).exists():
        print(f"Erreur : Fichier image introuvable : {args.image_path}")
        return

    if not SEGMENTATION_MODEL_PATH.exists():
        print(f"Erreur : Modèle de segmentation introuvable : {SEGMENTATION_MODEL_PATH}")
        print("Assurez-vous d'avoir copié train_59.pt dans le dossier models/segmentation/")
        return

    print("Initialisation de TvCalibInferModule...")
    try:
        model = TvCalibInferModule(
            segmentation_checkpoint=SEGMENTATION_MODEL_PATH,
            image_shape=IMAGE_SHAPE,
            optim_steps=args.optim_steps,
            lens_dist=False  # Gardons cela simple pour l'instant
        )
        print(f"✓ Modèle chargé sur {next(model.model_calib.parameters()).device}")
    except Exception as e:
        print(f"Erreur lors de l'initialisation du modèle : {e}")
        return

    print(f"Traitement de l'image : {args.image_path}")
    try:
        # Charger l'image (en BGR par défaut avec OpenCV)
        image_bgr_orig = cv2.imread(args.image_path)
        if image_bgr_orig is None:
            raise FileNotFoundError(f"Impossible de lire le fichier image: {args.image_path}")

        # Prétraiter l'image
        start_preprocess = time.time()
        image_tensor, image_bgr_resized, image_rgb_resized = preprocess_image_tvcalib(image_bgr_orig)
        print(f"Temps de prétraitement TvCalib : {time.time() - start_preprocess:.3f}s")

        # Exécuter la segmentation
        print("Exécution de la segmentation...")
        start_segment = time.time()
        with torch.no_grad():
            keypoints = model._segment(image_tensor)
        print(f"Temps de segmentation : {time.time() - start_segment:.3f}s")

        # Exécuter la calibration (optimisation)
        print("Exécution de la calibration (optimisation)...")
        start_calibrate = time.time()
        homography = model._calibrate(keypoints)
        print(f"Temps de calibration : {time.time() - start_calibrate:.3f}s")

        if homography is not None:
            print("\n--- Homographie Calculée ---")
            if isinstance(homography, torch.Tensor):
                homography_np = homography.detach().cpu().numpy()
            else:
                homography_np = homography
            print(homography_np)

            if args.output_homography:
                try:
                    np.save(args.output_homography, homography_np)
                    print(f"\nHomographie sauvegardée dans : {args.output_homography}")
                except Exception as e:
                    print(f"Erreur lors de la sauvegarde de l'homographie : {e}")

            # --- Extraction des données joueurs --- 
            print("\nExtraction des données joueurs (pose+couleur)...")
            start_pose = time.time()
            player_list = get_player_data(image_bgr_resized) 
            print(f"Temps d'extraction données joueurs : {time.time() - start_pose:.3f}s ({len(player_list)} joueurs trouvés)")
            
            # --- Calcul de l'échelle de base estimée --- 
            print("\nCalcul de l'échelle de base pour atteindre la cible...")
            target_average_scale = args.target_avg_scale
            
            # Calculer la modulation moyenne attendue (hypothèse: joueur moyen au centre Y=0.5)
            # Logique inversée actuelle : MIN + (MAX - MIN) * (1.0 - norm_y)
            avg_modulation_expected = DYNAMIC_SCALE_MIN_MODULATION + \
                                      (DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - 0.5)
            
            estimated_base_scale = target_average_scale # Valeur par défaut si modulation = 0
            if avg_modulation_expected != 0:
                estimated_base_scale = target_average_scale / avg_modulation_expected
            else:
                print("Avertissement : Modulation moyenne attendue nulle, impossible d'estimer l'échelle de base.")
                
            print(f"  Modulation dynamique moyenne attendue (pour Y=0.5) : {avg_modulation_expected:.3f}")
            print(f"  Échelle de base interne estimée pour cible {target_average_scale:.3f} : {estimated_base_scale:.3f}")

            # --- Génération des DEUX minimaps --- 
            print("\nGénération des minimaps (Originale et Squelettes Décalés)...")
            
            # 1. Minimap avec l'image originale (RGB)
            minimap_original = create_minimap_view(image_rgb_resized, homography_np)
            
            # 2. Minimap avec les squelettes
            # Utiliser l'échelle de base ESTIMÉE
            minimap_offset_skeletons, actual_avg_scale = create_minimap_with_offset_skeletons(
                player_list, 
                homography_np, 
                base_skeleton_scale=estimated_base_scale # Utiliser l'estimation
            )

            # Afficher la cible et le résultat réel
            if actual_avg_scale is not None:
                print(f"\nÉchelle moyenne CIBLE demandée (--target_avg_scale) : {target_average_scale:.3f}")
                print(f"Échelle moyenne FINALE RÉELLEMENT appliquée (basée sur joueurs réels) : {actual_avg_scale:.3f}")
            
            # --- Affichage des résultats --- 
            print("\nAffichage des résultats. Appuyez sur une touche pour quitter.")
            
            # Afficher la minimap originale
            if minimap_original is not None:
                cv2.imshow("Minimap avec Projection Originale", minimap_original)
            else:
                 print("N'a pas pu générer la minimap originale.")

            # Afficher la minimap avec les squelettes décalés
            if minimap_offset_skeletons is not None:
                cv2.imshow("Minimap avec Squelettes Decales", minimap_offset_skeletons)
            else:
                 print("N'a pas pu générer la minimap squelettes décalés.")

            cv2.waitKey(0) # Attend qu'une touche soit pressée

        else:
            print("\nAucune homographie n'a pu être calculée.")

    except Exception as e:
        print(f"Erreur lors du traitement de l'image : {e}")
        traceback.print_exc()
    finally:
        print("Fermeture des fenêtres OpenCV.")
        cv2.destroyAllWindows()

if __name__ == "__main__":
    main()