miguelmuzo's picture
Upload 426 files
3de0e37 verified
# -*- coding: utf-8 -*-
"""
# File name: get_mean_code.py
# Time : 2022/2/22 17:22
# Author: [email protected]
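# Description: For each category index 0..18 and each layer file listed in
#   layers_list, load every style code matching
#   styles_test/style_codes/*/<category>/<layer>, pick the medoid (the code
#   with the smallest summed Euclidean distance to all the others) and save
#   it to styles_test/mean_style_code/median/<category>/<layer>. Despite the
#   file name, the saved code is a medoid, not an arithmetic mean.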
"""
from glob import glob
import numpy as np
import os
# Style-code file names to aggregate for every category.
layers_list = ['ACE.npy']


def medoid_index(codes):
    """Return the index of the medoid of a stack of style codes.

    The medoid is the code whose summed Euclidean distance to all other
    codes is smallest. Unlike an arithmetic mean it is always one of the
    input codes.

    Args:
        codes: array-like of shape (N, ...). Each entry is flattened to a
            vector before distances are computed.

    Returns:
        int: index into ``codes`` of the medoid (the first one on ties).

    Raises:
        ValueError: if ``codes`` contains no entries.
    """
    vectors = np.asarray(codes, dtype=np.float64)
    if vectors.shape[0] == 0:
        raise ValueError('medoid_index() needs at least one code')
    vectors = vectors.reshape(vectors.shape[0], -1)
    # Pairwise squared distances via |a|^2 + |b|^2 - 2 a.b. Work in float64
    # to limit cancellation error, then clip at 0 because rounding can still
    # leave tiny negative values that would turn into NaN under the sqrt.
    sq_norms = np.einsum('ij,ij->i', vectors, vectors)
    sq_dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (vectors @ vectors.T)
    np.maximum(sq_dist, 0.0, out=sq_dist)
    return int(np.sqrt(sq_dist).sum(axis=1).argmin())


def compute_median_style_codes(style_root='styles_test/style_codes',
                               save_root='styles_test/mean_style_code/median',
                               num_categories=19,
                               layers=None):
    """Save the medoid style code of every (category, layer) pair.

    For each category index in ``range(num_categories)`` and each layer file
    name, every file matching ``<style_root>/*/<category>/<layer>`` is loaded
    with ``np.load``. The medoid of those codes (see ``medoid_index``) is
    written, with its original dtype, to ``<save_root>/<category>/<layer>``.
    Pairs with no matching files are skipped and nothing is written for them.

    Args:
        style_root: directory holding ``*/<category>/<layer>`` code files.
        save_root: directory under which medoids are written.
        num_categories: number of category indices to process, starting at 0.
        layers: iterable of layer file names; defaults to ``layers_list``.

    Returns:
        list of str: paths of the files that were written, in write order.
    """
    if layers is None:
        layers = layers_list
    written = []
    for cat_i in range(num_categories):
        for layer_j in layers:
            pattern = os.path.join(style_root, '*', str(cat_i), layer_j)
            # glob() order depends on the filesystem; sort it so the medoid
            # chosen on distance ties is reproducible across machines.
            paths = sorted(glob(pattern))
            if not paths:
                continue
            codes = np.stack([np.load(p) for p in paths])
            feature = codes[medoid_index(codes)]
            save_folder = os.path.join(save_root, str(cat_i))
            os.makedirs(save_folder, exist_ok=True)
            save_name = os.path.join(save_folder, layer_j)
            np.save(save_name, feature)
            written.append(save_name)
    return written


if __name__ == '__main__':
    compute_median_style_codes()
    # Completion marker printed by the original script; kept unchanged.
    print(100)