import gradio as gr
import pandas as pd
import pickle

# Load the trained models.
# NOTE: pickle.load can execute arbitrary code; these are local model artifacts
# shipped with the app and must never be loaded from user-supplied files.
with open("cyber_iot_8class.pkl", "rb") as model_file_8class:
    model_8class = pickle.load(model_file_8class)

with open("cyber_iot_34class.pkl", "rb") as model_file_34class:
    model_34class = pickle.load(model_file_34class)

# Use 15 features, drop other features as part of feature engineering.
# These columns are removed from the uploaded CSV before it is passed to either
# model, so the remaining columns must match what the models were trained on.
dropped_features = [
    "flow_duration", "urg_count", "Duration", "Number", "Weight", "Srate",
    "Rate", "ack_count", "fin_count", "TCP", "ack_flag_number",
    "syn_flag_number", "UDP", "ICMP", "fin_flag_number", "rst_flag_number",
    "psh_flag_number", "HTTP", "HTTPS", "IPv", "LLC", "SSH",
    "ece_flag_number", "SMTP", "ARP", "DHCP", "DNS", "cwr_flag_number",
    "Drate", "Telnet", "IRC",
]


def make_prediction(file):
    """Classify every connection (row) in an uploaded CSV file.

    The 8-class model first labels each row. Rows whose label is not
    "Benign" are then passed to the 34-class model for a more detailed
    attack type.

    Args:
        file: Value produced by the ``gr.File`` input. This is either an
            object exposing the uploaded path as ``.name`` or a plain path
            string. It is falsy when the upload has been cleared.

    Returns:
        A ``(benign_malicious, type_of_attack)`` tuple of strings for the two
        output textboxes. Each string holds comma-separated per-row labels.
        ``type_of_attack`` is ``"N/A"`` when every row is benign. Both values
        are empty strings when there is no file or the CSV has no rows.
    """
    if not file:
        # The change event also fires when the upload is cleared.
        return "", ""

    # Depending on the Gradio version the value is a file wrapper with .name
    # or a path string; accept either.
    path = getattr(file, "name", file)
    data = pd.read_csv(path)
    data = data.drop(columns=dropped_features)

    if len(data) == 0:
        # Calling predict on zero samples would raise inside the model.
        gr.Warning("The uploaded CSV file contains no rows.")
        return "", ""

    # First level: one label per row from the 8-class model.
    result_benign_malicious = [str(pred) for pred in model_8class.predict(data)]

    # Select the non-benign rows with a boolean mask. Dropping rows one at a
    # time by position shifts the remaining rows and removes the wrong ones.
    is_malicious = [label != "Benign" for label in result_benign_malicious]

    if not any(is_malicious):
        result_type_of_attack = ["N/A"]
        gr.Info("All connections are fine/Benign")
    else:
        # Second level: detailed attack type, only for the non-benign rows.
        detailed_prediction = model_34class.predict(data.loc[is_malicious])
        result_type_of_attack = [str(pred) for pred in detailed_prediction]
        gr.Warning("Malicious attack detected!")

    # Join element by element instead of using str(ndarray), which numpy
    # truncates with "..." for large arrays.
    return ", ".join(result_benign_malicious), ", ".join(result_type_of_attack)


# Define the Gradio interface
with gr.Blocks() as cyberIoT:
    # Add title and description using Markdown
    gr.Markdown("# IoT Cybersecurity Connection Prediction")
    gr.Markdown("Upload a .csv file to check whether it is a Non-Attack(Benign) or Attack(Non-Benign) connection. If Malicious, the type of attack will be further classified.")

    with gr.Row():
        with gr.Column():
            # Gradio expects extensions with a leading dot (e.g. ".csv").
            file_input = gr.File(label="Upload PCAP CSV File", file_types=[".csv"])
        with gr.Column():
            output_benign_malicious = gr.Textbox(
                label="Benign/Malicious(BruteForce,DDoS,DoS,Mirai,Recon,Spoofing,Web)",
                placeholder="Initial prediction will appear here...",
            )
            output_type_of_attack = gr.Textbox(
                label="Type of Attack",
                placeholder="Further classification will appear here if applicable...",
            )

    # When a file is uploaded (or cleared), make predictions and display the results
    file_input.change(
        make_prediction,
        inputs=file_input,
        outputs=[output_benign_malicious, output_type_of_attack],
    )
    gr.ClearButton(components=[file_input, output_benign_malicious, output_type_of_attack])

# Launch the Gradio app
cyberIoT.launch(share=True)