Spaces:
Runtime error
Runtime error
Commit
·
9014040
1
Parent(s):
d0f5598
feat
Browse files
app.py
CHANGED
|
@@ -112,21 +112,43 @@ model.eval()
|
|
| 112 |
def infer_solar_image_heatmap(img):
|
| 113 |
# Pré-processamento da imagem
|
| 114 |
img_gray = img.convert("L").resize((224, 224))
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
with torch.no_grad():
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
# Pegar o embedding da saída
|
| 123 |
-
emb = outputs.squeeze().numpy()
|
| 124 |
-
heatmap = emb - emb.min()
|
| 125 |
-
heatmap /= heatmap.max() + 1e-8
|
| 126 |
|
| 127 |
-
#
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
plt.tight_layout()
|
| 131 |
return plt.gcf()
|
| 132 |
|
|
|
|
| 112 |
def infer_solar_image_heatmap(img):
|
| 113 |
# Pré-processamento da imagem
|
| 114 |
img_gray = img.convert("L").resize((224, 224))
|
| 115 |
+
img_np = np.array(img_gray)
|
| 116 |
+
ts_tensor = (
|
| 117 |
+
torch.tensor(img_np, dtype=torch.float32)
|
| 118 |
+
.unsqueeze(0)
|
| 119 |
+
.unsqueeze(0)
|
| 120 |
+
.unsqueeze(2)
|
| 121 |
+
/ 255.0
|
| 122 |
+
) # [B=1,C=1,T=1,H=224,W=224]
|
| 123 |
+
batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1, 1))}
|
| 124 |
+
|
| 125 |
+
# Inferência (retorna tokens [1, L, D] com finetune=True)
|
| 126 |
with torch.no_grad():
|
| 127 |
+
tokens = model(batch).squeeze(0).cpu() # [L, D]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
+
# Remover o componente estático de posição para evitar mapa "igual" entre imagens
|
| 130 |
+
try:
|
| 131 |
+
pos = model.embedding.pos_embed.squeeze(0).to(tokens.dtype).cpu() # [L, D]
|
| 132 |
+
if pos.shape == tokens.shape:
|
| 133 |
+
tokens = tokens - pos
|
| 134 |
+
except Exception:
|
| 135 |
+
pass
|
| 136 |
+
|
| 137 |
+
# Agregar energia por patch (L2) e remontar 14x14
|
| 138 |
+
L, D = tokens.shape
|
| 139 |
+
side = int(L ** 0.5) # 14 para 224/16
|
| 140 |
+
heat_vec = torch.sqrt((tokens**2).mean(dim=1)) # [L]
|
| 141 |
+
heat = heat_vec.reshape(side, side).numpy()
|
| 142 |
+
|
| 143 |
+
# Normalizar e upsample p/ 224x224 (nearest para simplicidade)
|
| 144 |
+
heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
|
| 145 |
+
heat224 = np.kron(heat, np.ones((224 // side, 224 // side)))
|
| 146 |
+
|
| 147 |
+
# Overlay sobre a imagem original
|
| 148 |
+
plt.figure(figsize=(5, 5))
|
| 149 |
+
plt.imshow(img_np, cmap="gray")
|
| 150 |
+
plt.imshow(heat224, cmap="inferno", alpha=0.5, vmin=0.0, vmax=1.0)
|
| 151 |
+
plt.axis("off")
|
| 152 |
plt.tight_layout()
|
| 153 |
return plt.gcf()
|
| 154 |
|