# Spaces: Runtime error
# (Page header captured together with the source; not part of the program.)
| import os | |
| import re | |
| import crystal_toolkit.components as ctc | |
| import dash | |
| import dash_mp_components as dmp | |
| import numpy as np | |
| import periodictable | |
| from crystal_toolkit.settings import SETTINGS | |
| from dash import dcc, html | |
| from dash.dependencies import Input, Output, State | |
| from datasets import load_dataset | |
| from pymatgen.core import Structure | |
| from pymatgen.ext.matproj import MPRester | |
# Hugging Face access token read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Maximum number of search hits kept by search_materials().
top_k = 500

# Load only the train split of the dataset, restricted to the columns the app
# reads.  Each column is listed exactly once: the original list requested
# "functional" twice.
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)
# Header text shown in the results table for each dataset field.  The
# insertion order of this dict is also the left-to-right column order.
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}

# Dataset fields shown in the results table, in display order.
display_columns = list(display_names)
# Maps a row position in the currently displayed results table to the matching
# row index in `dataset`; rebuilt by search_materials() on every search.
mapping_table_idx_dataset_idx = {}

# One index column per entry yielded by periodictable.elements, in iteration
# order.  The symbol -> column mapping and the column list are derived from the
# same sequence so they cannot drift apart.
# NOTE(review): depending on the periodictable version this sequence may hold
# more than 118 entries (e.g. a neutron placeholder) - never hard-code the width.
all_elements = [el.symbol for el in periodictable.elements]  # full element list
map_periodic_table = {symbol: col for col, symbol in enumerate(all_elements)}

train_df = dataset.to_pandas()

# Split each descriptive formula into (element, count) pairs, one row per match;
# a missing count means 1.
pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
extracted["count"] = extracted["count"].replace("", "1").astype(int)

# Pivot to one row per dataset entry and one column per element symbol.
wide_df = extracted.reset_index().pivot_table(  # Move index to columns for pivoting
    index="level_0",  # original row index
    columns="element",
    values="count",
    aggfunc="sum",
    fill_value=0,
)

# Re-align on BOTH axes.  Formulas with no regex match are absent from the
# pivot; without restoring them, row i of the matrix would no longer be row i of
# `dataset` and search hits would point at the wrong materials.
wide_df = wide_df.reindex(index=train_df.index, columns=all_elements, fill_value=0)

# Scale every row to unit L2 norm so a dot product with a query vector is a
# cosine similarity.  (Dividing by the row sum first, as before, does not change
# the result of the L2 normalisation.)  All-zero rows are left as zeros instead
# of becoming NaN.
dataset_index = wide_df.to_numpy(dtype=float)
row_norms = np.linalg.norm(dataset_index, axis=1, keepdims=True)
row_norms[row_norms == 0] = 1.0
dataset_index = dataset_index / row_norms
# Create the Dash application; static assets are served from the folder that
# crystal_toolkit's SETTINGS points at.
app = dash.Dash(
    __name__,
    assets_folder=SETTINGS.ASSETS_PATH,
)

# Module-level handle on the app's underlying server, exposed for deployment.
server = app.server
# Define the app layout.
# Style keys are camelCased throughout (Dash/React convention); the original
# mixed camelCase with kebab-case keys such as "margin-top".
layout = html.Div(
    [
        # Page title.
        html.H1(
            html.B("Interactive Crystal Viewer"),
            style={"textAlign": "center", "marginTop": "20px"},
        ),
        # Two side-by-side panes, both empty at start-up; presumably filled
        # with the two values returned by display_material().
        html.Div(
            [
                html.Div(
                    id="structure-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "verticalAlign": "top",
                    },
                ),
                html.Div(
                    id="properties-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "paddingLeft": "4%",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"marginTop": "20px"},
        ),
        # Search box accepting either an element list or a formula.
        html.Div(
            [
                html.Div(
                    [
                        html.H3("Search Materials (eg. 'Ac,Cd,Ge' or 'Ac2CdGe3')"),
                        dmp.MaterialsInput(
                            allowedInputTypes=["elements", "formula"],
                            hidePeriodicTable=False,
                            periodicTableMode="toggle",
                            hideWildcardButton=True,
                            showSubmitButton=True,
                            submitButtonText="Search",
                            type="elements",
                            id="materials-input",
                        ),
                    ],
                    style={
                        "width": "100%",
                        "display": "inline-block",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"marginTop": "20px", "marginBottom": "20px"},
        ),
        # Results table: one column per entry of display_columns; the energy
        # column is numeric and rendered with two decimals.
        html.Div(
            [
                html.Label("Select Material to Display"),
                dash.dash_table.DataTable(
                    id="table",
                    columns=[
                        (
                            {"name": display_names[col], "id": col}
                            if col != "energy"
                            else {
                                "name": display_names[col],
                                "id": col,
                                "type": "numeric",
                                "format": {"specifier": ".2f"},
                            }
                        )
                        for col in display_columns
                    ],
                    # Starts with a single empty placeholder row.
                    data=[{}],
                    style_table={
                        "overflowX": "auto",
                        "height": "220px",
                        "overflowY": "auto",
                    },
                    style_header={"fontWeight": "bold", "backgroundColor": "lightgrey"},
                    style_cell={"textAlign": "center"},
                    style_as_list_view=True,
                ),
            ],
            style={"marginTop": "30px"},
        ),
    ],
    style={
        "marginLeft": "10px",
        "marginRight": "10px",
    },
)
def search_materials(query):
    """Return the dataset rows whose composition best matches *query*.

    The query is turned into an element-count vector and compared with every
    row of ``dataset_index`` by dot product divided by the query norm (a
    cosine similarity when the index rows are unit-normalised).

    Args:
        query: Either a comma-separated element list (``"Ac,Cd,Ge"``; every
            listed element gets weight 1) or a chemical formula
            (``"Ac2CdGe3"``; weights are the element counts, summed when an
            element appears more than once).

    Returns:
        Up to ``top_k`` rows of ``dataset``, best match first.  An empty list
        when the query contains no recognised element symbol.

    Side effects:
        Rebuilds the module-level ``mapping_table_idx_dataset_idx`` so that
        position ``i`` of the returned list maps to its ``dataset`` row index.
    """
    # Take the width from the index itself instead of hard-coding 118: the
    # number of columns follows periodictable.elements, and a mismatched
    # length would make np.dot raise.
    n_columns = dataset_index.shape[1]
    query_vector = np.zeros(n_columns)

    if "," in query:
        # Element list: presence only.
        for symbol in (part.strip() for part in query.split(",")):
            col = map_periodic_table.get(symbol)
            # Unknown or empty symbols are skipped rather than raising KeyError.
            if col is not None and col < n_columns:
                query_vector[col] = 1
    else:
        # Formula: accumulate counts so repeated elements (e.g. "CH3COOH")
        # are summed, matching how the dataset index is built.
        for symbol, count in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            col = map_periodic_table.get(symbol)
            if col is not None and col < n_columns:
                query_vector[col] += int(count) if count else 1

    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # Nothing recognisable in the query: no results, and no stale mapping
        # left behind from a previous search.
        mapping_table_idx_dataset_idx.clear()
        return []

    similarity = np.dot(dataset_index, query_vector) / query_norm
    # NaN scores (possible if an index row is itself NaN) would otherwise sort
    # to the front of the reversed argsort; push them to the end instead.
    similarity = np.where(np.isnan(similarity), -np.inf, similarity)

    indices = np.argsort(similarity)[::-1][:top_k]
    options = [dataset[int(i)] for i in indices]

    mapping_table_idx_dataset_idx.clear()
    mapping_table_idx_dataset_idx.update(
        (position, int(idx)) for position, idx in enumerate(indices)
    )
    return options
# Callback to update the table based on search
# NOTE(review): no callback registration for this function is visible in this
# view of the file - confirm it is wired to the search input and the table.
def on_submit_materials_input(n_clicks, query):
    """Run a search and return the rows to show in the results table.

    Args:
        n_clicks: Submit counter; ``None`` means nothing was submitted yet.
        query: The search text; an empty or missing query yields no rows.

    Returns:
        A list with one dict per hit, holding only the ``display_columns``
        fields; an empty list when there is nothing to search for.
    """
    submitted = n_clicks is not None
    if not (submitted and query):
        return []

    table_rows = []
    for entry in search_materials(query):
        table_rows.append({col: entry[col] for col in display_columns})
    return table_rows
# Callback to display the selected material
# NOTE(review): no callback registration for this function is visible in this
# view of the file - confirm it is wired to the table and the two containers.
def display_material(active_cell):
    """Build the structure view and the property table for a table row.

    Args:
        active_cell: Mapping describing the selected cell (presumably the
            results table's ``active_cell``); only its ``"row"`` entry is
            read.  Falsy when nothing is selected.

    Returns:
        A ``(structure_layout, properties_table)`` pair, or ``("", "")`` when
        no cell is selected or the selected row has no known dataset entry.
    """
    if not active_cell:
        return "", ""

    # A row without a mapping entry (e.g. a click before any search has run,
    # or a selection left over from an earlier result list) shows nothing
    # instead of raising KeyError.
    dataset_idx = mapping_table_idx_dataset_idx.get(active_cell["row"])
    if dataset_idx is None:
        return "", ""
    row = dataset[dataset_idx]

    structure = Structure(
        # Flatten the nested lattice vectors into one flat list.
        [x for y in row["lattice_vectors"] for x in y],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )

    # Create the StructureMoleculeComponent
    structure_component = ctc.StructureMoleculeComponent(structure)

    # Total energy divided by the number of sites; left as None when the
    # energy is missing so the table shows "None" instead of the call failing.
    n_sites = len(row["species_at_sites"])
    energy = row["energy"]
    energy_per_atom = energy / n_sites if energy is not None and n_sites else None

    # Prefer the direct gap and fall back to the indirect one only when the
    # direct value is missing.  A plain `or` would also discard a legitimate
    # 0.0 gap.
    band_gap = row["band_gap_direct"]
    if band_gap is None:
        band_gap = row["band_gap_indirect"]

    # Extract key properties
    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": energy_per_atom,
        "Band Gap (eV)": band_gap,
        "Total Magnetization (μB/f.u.)": row["total_magnetization"],
    }

    # Format properties as an HTML table
    properties_html = html.Table(
        [
            html.Tbody(
                [
                    html.Tr([html.Th(key), html.Td(str(value))])
                    for key, value in properties.items()
                ]
            )
        ],
        style={
            "border": "1px solid black",
            "width": "100%",
            "borderCollapse": "collapse",
        },
    )
    return structure_component.layout(), properties_html
# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, layout)

if __name__ == "__main__":
    # Newer Dash releases expose `app.run` and no longer provide
    # `app.run_server`; older releases may only have `run_server`.  Prefer
    # `run` and fall back, so the script starts on either.
    run_app = getattr(app, "run", None) or app.run_server
    # NOTE(review): debug=True together with host 0.0.0.0 exposes the debug
    # tooling on every network interface - confirm this is intended outside
    # local development.
    run_app(debug=True, port=7860, host="0.0.0.0")