# Cyber_IoT / app.py
# Last commit 5878037 by wookimchye: "Update app.py" (verified)
import gradio as gr
import pandas as pd
import pickle
# Load the two trained classifiers that ship next to this script:
#   * model_8class  - first stage: "Benign" or one of the attack families
#   * model_34class - second stage: detailed attack type for non-Benign rows
# NOTE(review): pickle.load executes arbitrary code from the file; this is
# acceptable only because these .pkl files are bundled with the app, never
# user-supplied.
def _load_model(path):
    """Unpickle and return the object stored in the file at *path*."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


model_8class = _load_model("cyber_iot_8class.pkl")
model_34class = _load_model("cyber_iot_34class.pkl")
# Feature engineering: these columns are removed from the uploaded CSV before
# prediction, leaving the 15 features the models use (per the original note).
# Written as one whitespace-separated string and split into a list; the order
# matches the original listing.
dropped_features = """
    flow_duration urg_count Duration Number Weight Srate Rate
    ack_count fin_count TCP ack_flag_number syn_flag_number UDP ICMP
    fin_flag_number rst_flag_number psh_flag_number HTTP HTTPS IPv LLC
    SSH ece_flag_number SMTP ARP DHCP DNS cwr_flag_number Drate Telnet IRC
""".split()
def make_prediction(file):
    """Classify every connection (row) of an uploaded CSV file.

    Two-stage pipeline: ``model_8class`` labels each row as "Benign" or an
    attack family, then ``model_34class`` assigns a detailed attack type to
    exactly those rows that stage 1 did not label "Benign".

    Args:
        file: Value of the ``gr.File`` component, or a falsy value when there
            is no upload. Either an object with a ``.name`` attribute (a
            temp-file wrapper) or a plain filepath string is accepted.

    Returns:
        A pair of strings for the two output Textboxes:
        ``(first_stage_labels, detailed_attack_types)``, each comma-separated.
        The second string is "N/A" when every row is predicted Benign.
        Returns ``("", "")`` when no file is given.
    """
    if not file:
        return "", ""

    # NOTE(review): depending on the Gradio version, gr.File yields either a
    # temp-file wrapper exposing ``.name`` or a filepath string; accept both.
    path = getattr(file, "name", file)
    data = pd.read_csv(path)
    # Keep only the features the models were trained on.
    data = data.drop(columns=dropped_features)

    # Stage 1: one "Benign"/attack-family label per row, in row order.
    first_stage = [str(pred) for pred in model_8class.predict(data)]
    attack_mask = [pred != "Benign" for pred in first_stage]

    if not any(attack_mask):
        detailed = ["N/A"]
        gr.Info("All connections are fine/Benign")
    else:
        # Stage 2 only sees rows flagged as attacks. A positional boolean
        # mask removes exactly the Benign rows, however many there are.
        attack_rows = data.loc[attack_mask]
        # Join element-wise instead of str(ndarray): numpy's array repr wraps
        # lines and elides long arrays with "...".
        detailed = [str(pred) for pred in model_34class.predict(attack_rows)]
        gr.Warning("Malicious attack detected!")

    return ", ".join(first_stage), ", ".join(detailed)
# Define the Gradio interface: a CSV upload on the left, the two prediction
# Textboxes on the right, and a button that clears all three components.
with gr.Blocks() as cyberIoT:
    # Add title and description using Markdown
    gr.Markdown("# IoT Cybersecurity Connection Prediction")
    gr.Markdown("Upload a .csv file to check whether it is a Non-Attack(Benign) or Attack(Non-Benign) connection. If Malicious, the type of attack will be further classified.")
    with gr.Row():
        with gr.Column():
            # Gradio documents file_types entries as extensions with a leading
            # dot (e.g. ".csv") or category names such as "image" / "text".
            file_input = gr.File(label="Upload PCAP CSV File", file_types=[".csv"])
        with gr.Column():
            output_benign_malicious = gr.Textbox(label="Benign/Malicious(BruteForce,DDoS,DoS,Mirai,Recon,Spoofing,Web)", placeholder="Initial prediction will appear here...")
            output_type_of_attack = gr.Textbox(label="Type of Attack", placeholder="Further classification will appear here if applicable...")
    # Re-run the prediction whenever the uploaded file changes. If the upload
    # is cleared, make_prediction receives a falsy value and returns "", "".
    file_input.change(make_prediction, inputs=file_input, outputs=[output_benign_malicious, output_type_of_attack])
    gr.ClearButton(components=[file_input, output_benign_malicious, output_type_of_attack])
# Launch the Gradio app.
# This is a module-level statement, so the server starts as soon as this file
# is executed or imported.
# NOTE(review): share=True asks Gradio for a public share link in addition to
# the local server; confirm this is intended for the deployment target.
cyberIoT.launch(share=True)