alan918727 commited on
Commit
633ef66
·
verified ·
1 Parent(s): dfb48ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -12
app.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
5
  from matminer.featurizers.conversions import StrToComposition
6
  from matminer.featurizers.base import MultipleFeaturizer
7
  from matminer.featurizers import composition as cf
 
8
 
9
  # Define feature calculators
10
  feature_calculators = MultipleFeaturizer([
@@ -49,24 +50,85 @@ def get_features_single(formula):
49
  mlmd = mlmdd_single(formula)
50
  ext_mag = generate_single(formula)
51
  result = pd.concat([ext_mag, mlmd], axis=1)
52
- return result.iloc[:, :10] # Select the first 10 features
53
 
54
  def predict_features(formula):
55
- """
56
- Gradio prediction function to return top 10 features for the input formula.
57
- """
58
  try:
 
59
  df = get_features_single(formula)
60
- return df.head(1).to_dict(orient="records")[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  except Exception as e:
62
  return {"Error": str(e)}
63
 
64
- # Define Gradio interface
65
- iface = gr.Interface(
66
- fn=predict_features,
67
- inputs=gr.Textbox(label="Enter Chemical Formula", placeholder="E.g., BrIPtZrS2"),
68
- outputs=gr.JSON(label="Top 10 Features")
69
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  if __name__ == "__main__":
72
- iface.launch()
 
5
  from matminer.featurizers.conversions import StrToComposition
6
  from matminer.featurizers.base import MultipleFeaturizer
7
  from matminer.featurizers import composition as cf
8
+ import joblib
9
 
10
  # Define feature calculators
11
  feature_calculators = MultipleFeaturizer([
 
50
  mlmd = mlmdd_single(formula)
51
  ext_mag = generate_single(formula)
52
  result = pd.concat([ext_mag, mlmd], axis=1)
53
+ return result
54
 
55
  def predict_features(formula):
 
 
 
56
  try:
57
+ # Generate features for the input formula
58
  df = get_features_single(formula)
59
+ X_user = df.iloc[:, 2:].fillna(0)
60
+
61
+ # Load saved model
62
+ model_path = "saved_model/lgbm_model.pkl"
63
+ loaded_model = joblib.load(model_path)
64
+
65
+ # Load saved LabelEncoder
66
+ label_encoder_path = "saved_model/label_encoder.pkl"
67
+ label_encoder = joblib.load(label_encoder_path)
68
+
69
+ # Load Layer Group Mapping
70
+ mapping_path = "saved_model/layer_group_mapping.pkl"
71
+ layer_group_mapping = joblib.load(mapping_path)
72
+
73
+ # Predict probabilities
74
+ prediction_probs = loaded_model.predict_proba(X_user)[0]
75
+
76
+ # Get top 5 predictions
77
+ top_5_indices = np.argsort(prediction_probs)[-5:][::-1]
78
+ top_5_numbers = label_encoder.inverse_transform(top_5_indices)
79
+ top_5_names = [layer_group_mapping.get(num, "Unknown Layer Group") for num in top_5_numbers]
80
+ top_5_probs = prediction_probs[top_5_indices]
81
+
82
+ # Prepare top 5 results as list of lists
83
+ top_5_results = [
84
+ [num, name, f"{prob:.2%}"]
85
+ for num, name, prob in zip(top_5_numbers, top_5_names, top_5_probs)
86
+ ]
87
+
88
+ # Predict using the loaded model
89
+ prediction = loaded_model.predict(X_user)
90
+
91
+ # Decode Layer Group Number
92
+ decoded_prediction = label_encoder.inverse_transform(prediction)[0]
93
+
94
+ # Map to Layer Group Name
95
+ layer_group_name = layer_group_mapping.get(decoded_prediction, "Unknown Layer Group")
96
+
97
+ return {
98
+ "Formula": formula,
99
+ "Predicted Layer Group Number": decoded_prediction,
100
+ "Predicted Layer Group Name": layer_group_name,
101
+ "Top 5 Predictions": top_5_results
102
+ }
103
  except Exception as e:
104
  return {"Error": str(e)}
105
 
106
+
107
+ # Define a more visually appealing Gradio interface
108
+ def gradio_ui():
109
+ with gr.Blocks() as demo:
110
+ gr.Markdown("""
111
+ # 🔬 2D Material Layer Group Predictor
112
+ Enter a chemical formula of 2D Material below to get the predicted layer group information along with the top 5 probable groups.
113
+ """)
114
+
115
+ with gr.Row():
116
+ formula_input = gr.Textbox(label="Enter Chemical Formula", placeholder="E.g., BrIPtZrS2")
117
+ predict_button = gr.Button("Predict")
118
+
119
+ with gr.Row():
120
+ final_prediction = gr.Textbox(label="Final Prediction (Layer Group Name)", interactive=False)
121
+
122
+ with gr.Row():
123
+ top_5_table = gr.Dataframe(headers=["Layer Group Number", "Layer Group Name", "Probability"], interactive=False)
124
+
125
+ def update_output(formula):
126
+ result = predict_features(formula)
127
+ return result["Predicted Layer Group Name"], result["Top 5 Predictions"]
128
+
129
+ predict_button.click(fn=update_output, inputs=[formula_input], outputs=[final_prediction, top_5_table])
130
+
131
+ return demo
132
 
133
  if __name__ == "__main__":
134
+ gradio_ui().launch()