# Hugging Face Space (status: Sleeping)
import functools
import pickle

import gradio as gr
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository hosting the trained model and preprocessing artifacts.
HF_REPO_ID = "alperugurcan/poisonous-mushrooms"

# Maps each key of the dict returned by load_artifacts() to the pickle file
# in the repository that it is loaded from.
_ARTIFACT_FILES = {
    'model': "model.pkl",
    'target_encoder': "target_encoder.pkl",
    'feature_encoders': "label_encoders.pkl",
    'numeric_columns': "numeric_columns.pkl",
    'categorical_columns': "categorical_columns.pkl",
}


@functools.lru_cache(maxsize=None)
def _load_artifacts_cached(repo_id):
    """Download and unpickle every artifact of *repo_id* once per process.

    The result is memoized because predict_mushroom calls load_artifacts()
    on every prediction; without the cache each request would re-run five
    hf_hub_download calls and re-unpickle the model.
    """
    artifacts = {}
    for key, filename in _ARTIFACT_FILES.items():
        path = hf_hub_download(repo_id=repo_id, filename=filename)
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only point repo_id at a repository you trust.
        with open(path, 'rb') as f:
            artifacts[key] = pickle.load(f)
    return artifacts


def load_artifacts(repo_id=HF_REPO_ID):
    """Return the model and preprocessing artifacts stored in *repo_id*.

    Args:
        repo_id: Hugging Face Hub repository to read the pickle files from.

    Returns:
        A new dict with the keys 'model', 'target_encoder',
        'feature_encoders', 'numeric_columns' and 'categorical_columns'.
        The dict itself is fresh on every call, but the objects inside it
        are shared across calls (they come from a per-process cache).
    """
    return dict(_load_artifacts_cached(repo_id))
# Input columns kept numeric; every other column is label-encoded
# (when an encoder exists for it) before prediction.
_NUMERIC_FEATURES = frozenset({'cap-diameter', 'stem-height', 'stem-width'})


def _build_input_frame(cap_diameter, gill_spacing, stem_root, veil_color, season):
    """Return a one-row DataFrame holding every feature column, in a fixed order.

    Only five features come from the caller; the rest are filled with fixed
    placeholder values. NOTE(review): the column order presumably mirrors the
    training data layout -- confirm against the training pipeline.
    """
    row = {
        'cap-diameter': cap_diameter,
        'cap-shape': 'x',  # default value
        'cap-surface': 's',  # default value
        'cap-color': 'n',  # default value
        'does-bruise-or-bleed': 'f',  # default value
        'gill-attachment': 'f',  # default value
        'gill-spacing': gill_spacing,
        'gill-color': 'n',  # default value
        'stem-height': 10.0,  # default value
        'stem-width': 5.0,  # default value
        'stem-root': stem_root,
        'stem-surface': 's',  # default value
        'stem-color': 'w',  # default value
        'veil-type': 'p',  # default value
        'veil-color': veil_color,
        'has-ring': 't',  # default value
        'ring-type': 'p',  # default value
        'spore-print-color': 'n',  # default value
        'habitat': 'u',  # default value
        'season': season,
    }
    return pd.DataFrame({col: [value] for col, value in row.items()})


def _encode_categoricals(df, feature_encoders):
    """Encode, in place, each non-numeric column of *df* that has an encoder.

    Values are cast to str before encoding. Columns without an entry in
    *feature_encoders* are left untouched.

    Raises:
        ValueError: if an encoder rejects a value. This presumably means an
            unseen label for a scikit-learn LabelEncoder -- verify. The error
            is re-raised naming the offending feature.
    """
    for col in df.columns:
        if col in _NUMERIC_FEATURES or col not in feature_encoders:
            continue
        try:
            df[col] = feature_encoders[col].transform(df[col].astype(str))
        except ValueError as err:
            raise ValueError(
                f"Could not encode value {df[col].iloc[0]!r} for feature '{col}'"
            ) from err
    return df


def predict_mushroom(cap_diameter, gill_spacing, stem_root, veil_color, season):
    """Classify a mushroom as edible or poisonous from five user-supplied features.

    Args:
        cap_diameter: Cap diameter value passed straight into the model input.
        gill_spacing: Single-letter gill-spacing code.
        stem_root: Single-letter stem-root code.
        veil_color: Single-letter veil-color code.
        season: Single-letter season code.

    Returns:
        "Edible" when the decoded class label is 'e', otherwise "Poisonous".

    Raises:
        ValueError: if a categorical value cannot be encoded.
    """
    artifacts = load_artifacts()
    df = _build_input_frame(cap_diameter, gill_spacing, stem_root, veil_color, season)
    df = _encode_categoricals(df, artifacts['feature_encoders'])

    prediction = artifacts['model'].predict(df)[0]
    class_prediction = artifacts['target_encoder'].inverse_transform([prediction])[0]
    return "Edible" if class_prediction == 'e' else "Poisonous"
def _coded_dropdown(choices, label, info):
    """Build a dropdown of single-letter codes whose initial value is the first choice."""
    return gr.Dropdown(choices=choices, value=choices[0], label=label, info=info)


_cap_diameter_input = gr.Slider(
    minimum=2.0,
    maximum=20.0,
    value=10.0,
    step=0.5,
    label="Cap Diameter (cm)",
    info="Slide to select mushroom cap width",
)

# Listed in the same order as predict_mushroom's parameters, because
# Gradio passes the input values to fn positionally.
_mushroom_inputs = [
    _cap_diameter_input,
    _coded_dropdown(['c', 'w'], "Gill Spacing", "c: close, w: wide"),
    _coded_dropdown(
        ['b', 'e', 'c', 'r'],
        "Stem Root",
        "b: bulbous, e: equal, c: club, r: rooted",
    ),
    _coded_dropdown(
        ['w', 'n', 'o', 'y'],
        "Veil Color",
        "w: white, n: brown, o: orange, y: yellow",
    ),
    _coded_dropdown(
        ['s', 'u', 'a', 'w'],
        "Season",
        "s: spring, u: summer, a: autumn, w: winter",
    ),
]

iface = gr.Interface(
    fn=predict_mushroom,
    inputs=_mushroom_inputs,
    outputs=gr.Label(label="Prediction"),
    title="Mushroom Edibility Classifier",
    description="""
Predict if a mushroom is edible or poisonous using its 5 most important characteristics.
WARNING: This is a demonstration only. Never eat wild mushrooms based on this prediction!
""",
)

if __name__ == "__main__":
    iface.launch()