Use sparse matrices
Files changed:
- app.py  +5 −1
- data_utils.py  +24 −21
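Switches the formula index in build_formula_index from a dense NumPy array built out of parsed formula strings to an element-count matrix over species_at_sites, stored as a SciPy CSR sparse matrix, and caches both the preprocessed dataframe (pickle) and the index (save_npz) under CACHE_PATH so the index is built once rather than recomputed on every startup.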
app.py
CHANGED

@@ -18,7 +18,6 @@ from components import (
     get_upload_div,
 )
 from data_utils import (
-    build_embeddings_index,
     build_formula_index,
     get_crystal_plot,
     get_dataset,
@@ -29,6 +28,11 @@ from data_utils import (
 EMPTY_DATA = False
 CACHE_PATH = None

+if CACHE_PATH is not None:
+    import os
+
+    os.makedirs(CACHE_PATH, exist_ok=True)
+
 dataset = get_dataset()

 display_columns_query = [
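CACHE_PATH is still None by default, so the new block is a no-op until a deployment points it at writable storage; with exist_ok=True the directory creation is safe to repeat on every restart. A minimal sketch of one way to wire it up, assuming a FORMULA_CACHE environment variable (the variable name is hypothetical, not something this repo defines):

import os

# Hypothetical env var: unset -> None -> caching stays disabled,
# matching the CACHE_PATH = None default above.
CACHE_PATH = os.environ.get("FORMULA_CACHE")

if CACHE_PATH is not None:
    os.makedirs(CACHE_PATH, exist_ok=True)  # idempotent; safe across restarts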
|
data_utils.py
CHANGED

@@ -72,6 +72,7 @@ mapping_table_idx_dataset_idx = {}


 def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
+    print("Building formula index")
     if empty_data:
         return np.zeros((1, 1)), {}

@@ -80,40 +81,42 @@ def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=F
     use_dataset = dataset.select(index_range)

     # Preprocessing step to create an index for the dataset
-
-    train_df = pickle.load(open(f"{cache_path}/train_df.pkl", "rb"))
+    from scipy.sparse import load_npz

-
+    if cache_path is not None and os.path.exists(f"{cache_path}/train_df.pkl"):
+        train_df = pickle.load(open(f"{cache_path}/train_df.pkl", "rb"))
+        dataset_index = load_npz(f"{cache_path}/dataset_index.npz")
     else:
         train_df = use_dataset.select_columns(
-            ["
+            ["species_at_sites", "immutable_id", "functional"]
         ).to_pandas()

-
-    extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
-    extracted["count"] = extracted["count"].replace("", "1").astype(int)
-
-    wide_df = (
-        extracted.reset_index().pivot_table(  # Move index to columns for pivoting
-            index="level_0",  # original row index
-            columns="element",
-            values="count",
-            aggfunc="sum",
-            fill_value=0,
-        )
-    )
+        import tqdm

-    all_elements =
-
+        all_elements = {
+            str(el.symbol): i for i, el in enumerate(periodictable.elements)
+        }  # full element list
+        dataset_index = np.zeros((len(train_df), len(all_elements)))

-
+        for idx, species in tqdm.tqdm(enumerate(train_df["species_at_sites"].values)):
+            for el in species:
+                dataset_index[idx, all_elements[el]] += 1

     dataset_index = dataset_index / np.sum(dataset_index, axis=1)[:, None]
     dataset_index = (
         dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
     )  # Normalize vectors

+    from scipy.sparse import csr_matrix, save_npz
+
+    dataset_index = csr_matrix(dataset_index)
+
+    if cache_path is not None:
+        pickle.dump(train_df, open(f"{cache_path}/train_df.pkl", "wb"))
+        save_npz(f"{cache_path}/dataset_index.npz", dataset_index)
+
     immutable_id_to_idx = train_df["immutable_id"].to_dict()
+    del train_df
     immutable_id_to_idx = {v: k for k, v in immutable_id_to_idx.items()}

     return dataset_index, immutable_id_to_idx
@@ -162,7 +165,7 @@ def search_materials(
         numb = int(numb) if numb else 1
         query_vector[map_periodic_table[el]] = numb

-    similarity =
+    similarity = dataset_index.dot(query_vector) / (np.linalg.norm(query_vector))
     indices = np.argsort(similarity)[::-1][:top_k]

     options = [dataset[int(i)] for i in indices]