Spaces:
Sleeping
Sleeping
add comments
Browse files- src/rubik/moves.py +3 -0
src/rubik/moves.py
CHANGED
|
@@ -135,6 +135,7 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
|
|
| 135 |
inputs = inputs.transpose(0, 1).tolist() # size = (n, 4)
|
| 136 |
outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
|
| 137 |
|
|
|
|
| 138 |
local_to_total = dict(enumerate(indices.tolist()))
|
| 139 |
total_to_local = {ind: i for i, ind in local_to_total.items()}
|
| 140 |
|
|
@@ -142,6 +143,8 @@ def build_permunation_tensor(size: int, axis: int, slice: int, inverse: int) ->
|
|
| 142 |
total_perm = {
|
| 143 |
i: (i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)
|
| 144 |
}
|
|
|
|
|
|
|
| 145 |
perm_indices = torch.tensor(
|
| 146 |
[[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
|
| 147 |
dtype=INT8,
|
|
|
|
| 135 |
inputs = inputs.transpose(0, 1).tolist() # size = (n, 4)
|
| 136 |
outputs = outputs.transpose(0, 1).tolist() # size = (n, 4)
|
| 137 |
|
| 138 |
+
# compute position-based permutation of colors equivalent to rotation converting inputs into outputs
|
| 139 |
local_to_total = dict(enumerate(indices.tolist()))
|
| 140 |
total_to_local = {ind: i for i, ind in local_to_total.items()}
|
| 141 |
|
|
|
|
| 143 |
total_perm = {
|
| 144 |
i: (i if i not in total_to_local else local_to_total[local_perm[total_to_local[i]]]) for i in range(length)
|
| 145 |
}
|
| 146 |
+
|
| 147 |
+
# convert permutation dict into sparse tensor
|
| 148 |
perm_indices = torch.tensor(
|
| 149 |
[[axis] * length, [slice] * length, [inverse] * length, list(total_perm.keys()), list(total_perm.values())],
|
| 150 |
dtype=INT8,
|