import functools
import pickle

import gradio as gr
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download

@functools.lru_cache(maxsize=1)
def load_artifacts():
    """Download and unpickle the model artifacts, memoizing the result.

    Fetches five pickle files from the ``alperugurcan/poisonous-mushrooms``
    Hugging Face Hub repository and deserializes each of them. The result is
    cached after the first successful call, so later calls (one is made per
    prediction) reuse the already-loaded objects instead of contacting the
    Hub and unpickling everything again. Call ``load_artifacts.cache_clear()``
    to force a reload.

    Returns:
        dict: Maps ``'model'``, ``'target_encoder'``, ``'feature_encoders'``
        (the contents of ``label_encoders.pkl``), ``'numeric_columns'`` and
        ``'categorical_columns'`` to the unpickled objects. The same dict
        object is returned on every call, so treat it as read-only.

    Raises:
        Any exception raised by ``hf_hub_download``, ``open`` or
        ``pickle.load`` propagates unchanged; a failed call is not cached.
    """
    repo_id = "alperugurcan/poisonous-mushrooms"

    # Result key -> file name inside the Hub repository. Note that the
    # 'feature_encoders' key is backed by label_encoders.pkl.
    artifact_files = {
        'model': "model.pkl",
        'target_encoder': "target_encoder.pkl",
        'feature_encoders': "label_encoders.pkl",
        'numeric_columns': "numeric_columns.pkl",
        'categorical_columns': "categorical_columns.pkl",
    }

    artifacts = {}
    for key, filename in artifact_files.items():
        local_path = hf_hub_download(repo_id=repo_id, filename=filename)
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. This is only acceptable while the Hub repository above is
        # fully trusted; review before pointing repo_id anywhere else.
        with open(local_path, 'rb') as f:
            artifacts[key] = pickle.load(f)

    return artifacts

def predict_mushroom(cap_diameter, gill_spacing, stem_root, veil_color, season):
    """Classify a mushroom as "Edible" or "Poisonous".

    Only the five arguments come from the caller; every other feature column
    is filled with a fixed default value before the row is handed to the
    model.

    Args:
        cap_diameter: Value for the 'cap-diameter' column.
        gill_spacing: Value for the 'gill-spacing' column.
        stem_root: Value for the 'stem-root' column.
        veil_color: Value for the 'veil-color' column.
        season: Value for the 'season' column.

    Returns:
        str: "Edible" when the decoded class label equals 'e', otherwise
        "Poisonous".
    """
    loaded = load_artifacts()
    classifier = loaded['model']
    class_decoder = loaded['target_encoder']
    encoders = loaded['feature_encoders']

    # One value per feature column. The insertion order below becomes the
    # DataFrame column order that is passed to the model, so keep it as is.
    row = {
        'cap-diameter': cap_diameter,
        'cap-shape': 'x',             # default value
        'cap-surface': 's',           # default value
        'cap-color': 'n',             # default value
        'does-bruise-or-bleed': 'f',  # default value
        'gill-attachment': 'f',       # default value
        'gill-spacing': gill_spacing,
        'gill-color': 'n',            # default value
        'stem-height': 10.0,          # default value
        'stem-width': 5.0,            # default value
        'stem-root': stem_root,
        'stem-surface': 's',          # default value
        'stem-color': 'w',            # default value
        'veil-type': 'p',             # default value
        'veil-color': veil_color,
        'has-ring': 't',              # default value
        'ring-type': 'p',             # default value
        'spore-print-color': 'n',     # default value
        'habitat': 'u',               # default value
        'season': season,
    }
    frame = pd.DataFrame({column: [value] for column, value in row.items()})

    # Encode every non-numeric column that has a fitted encoder; numeric
    # columns and columns without an encoder are passed through untouched.
    numeric_names = ('cap-diameter', 'stem-height', 'stem-width')
    for column in frame.columns:
        if column in numeric_names:
            continue
        if column not in encoders:
            continue
        as_text = frame[column].astype(str)
        frame[column] = encoders[column].transform(as_text)

    encoded_class = classifier.predict(frame)[0]
    decoded_class = class_decoder.inverse_transform([encoded_class])[0]

    if decoded_class == 'e':
        return "Edible"
    return "Poisonous"

# Input widgets, declared in the same order as predict_mushroom's parameters;
# the list passed to gr.Interface below relies on that order.
cap_diameter_input = gr.Slider(
    label="Cap Diameter (cm)",
    info="Slide to select mushroom cap width",
    minimum=2.0,
    maximum=20.0,
    step=0.5,
    value=10.0,
)

gill_spacing_input = gr.Dropdown(
    label="Gill Spacing",
    info="c: close, w: wide",
    choices=['c', 'w'],
    value='c',
)

stem_root_input = gr.Dropdown(
    label="Stem Root",
    info="b: bulbous, e: equal, c: club, r: rooted",
    choices=['b', 'e', 'c', 'r'],
    value='b',
)

veil_color_input = gr.Dropdown(
    label="Veil Color",
    info="w: white, n: brown, o: orange, y: yellow",
    choices=['w', 'n', 'o', 'y'],
    value='w',
)

season_input = gr.Dropdown(
    label="Season",
    info="s: spring, u: summer, a: autumn, w: winter",
    choices=['s', 'u', 'a', 'w'],
    value='s',
)

# Text shown under the title; whitespace is kept exactly as written.
app_description = """
    Predict if a mushroom is edible or poisonous using its 5 most important characteristics.
    WARNING: This is a demonstration only. Never eat wild mushrooms based on this prediction!
    """

# The web UI: five inputs feeding predict_mushroom, one label for the result.
iface = gr.Interface(
    fn=predict_mushroom,
    inputs=[
        cap_diameter_input,
        gill_spacing_input,
        stem_root_input,
        veil_color_input,
        season_input,
    ],
    outputs=gr.Label(label="Prediction"),
    title="Mushroom Edibility Classifier",
    description=app_description,
)

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()