Spaces:
Runtime error
Runtime error
Commit
·
dd97a63
1
Parent(s):
97dc735
Update free_lunch_utils.py
Browse files
- free_lunch_utils.py +25 -2
free_lunch_utils.py
CHANGED
|
@@ -93,13 +93,36 @@ def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
|
|
| 93 |
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 94 |
#print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
# --------------- FreeU code -----------------------
|
| 97 |
# Only operate on the first two stages
|
| 98 |
if hidden_states.shape[1] == 1280:
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
| 101 |
if hidden_states.shape[1] == 640:
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
| 104 |
# ---------------------------------------------------------
|
| 105 |
|
|
|
|
| 93 |
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 94 |
#print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
|
| 95 |
|
| 96 |
+
# # --------------- FreeU code -----------------------
|
| 97 |
+
# # Only operate on the first two stages
|
| 98 |
+
# if hidden_states.shape[1] == 1280:
|
| 99 |
+
# hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
| 100 |
+
# res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
| 101 |
+
# if hidden_states.shape[1] == 640:
|
| 102 |
+
# hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
| 103 |
+
# res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
| 104 |
+
# # ---------------------------------------------------------
|
| 105 |
+
|
| 106 |
# --------------- FreeU code -----------------------
|
| 107 |
# Only operate on the first two stages
|
| 108 |
if hidden_states.shape[1] == 1280:
|
| 109 |
+
hidden_mean = hidden_states.mean(1).unsqueeze(1)
|
| 110 |
+
B = hidden_mean.shape[0]
|
| 111 |
+
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
| 112 |
+
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
| 113 |
+
|
| 114 |
+
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
|
| 115 |
+
|
| 116 |
+
hidden_states[:,:640] = hidden_states[:,:640] * ((self.b1 - 1 ) * hidden_mean + 1)
|
| 117 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
| 118 |
if hidden_states.shape[1] == 640:
|
| 119 |
+
hidden_mean = hidden_states.mean(1).unsqueeze(1)
|
| 120 |
+
B = hidden_mean.shape[0]
|
| 121 |
+
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
| 122 |
+
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
| 123 |
+
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
|
| 124 |
+
|
| 125 |
+
hidden_states[:,:320] = hidden_states[:,:320] * ((self.b2 - 1 ) * hidden_mean + 1)
|
| 126 |
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
| 127 |
# ---------------------------------------------------------
|
| 128 |
|