import streamlit as st |
import pandas as pd |
import numpy as np |
import plotly.express as px |
# Use Streamlit's wide page layout so the two-column view built further down
# (results on the left, filters on the right) can use the full browser width.
# Streamlit expects set_page_config before any other st.* call, so keep this
# above the st.markdown calls that follow.
st.set_page_config(layout="wide")
import streamlit as st |
import pandas as pd |
import matplotlib.pyplot as plt |
import streamlit as st |
import pandas as pd |
import matplotlib.pyplot as plt |
import streamlit as st |
import pandas as pd |
import matplotlib.pyplot as plt |
import streamlit as st |
import pandas as pd |
import plotly.express as px |
import plotly.graph_objects as go |
# Global CSS injected once at startup: a universal padding/margin reset plus
# helper classes, most of which are used by the HTML markup further down
# (centered-title, fixed-col, scroller).
_base_css = """
<style>
*{
padding:0;
margin:0;
}
.fixed-col {
position: fixed;
top: 4rem;
right: 0;
width: 30%;
padding-left: 0rem;
background: white;
z-index: 100;
}
body {
margin: 0;
padding: 0;
}
.maint {
margin: auto;
margin-bottom:1.5rem;
}
.centered-title {
text-align: center;
}
.scroller {
margin-top: 2rem; /* Adjust as necessary to avoid overlap */
}
</style>
"""

# Side/top padding for the children of `.main` (presumably Streamlit's main
# content container -- the selector is version-dependent, verify on upgrade).
margins_css = """
<style>
.main > div {
padding-left: 3rem;
padding-right:3rem;
padding-top:0.4rem;
}
</style>
"""

# Inject the base stylesheet first, then the margin overrides.
for _css_snippet in (_base_css, margins_css):
    st.markdown(_css_snippet, unsafe_allow_html=True)
# Options for the filter dropdowns; list order is the dropdown order.
models = ["SSD300", "SSD512", "DETR"]
pruning_methods = ["VIB Pruning", "Transfer Pruning"]
datasets = ["VOC", "SPARK"]

# Hyperparameter rows shown in the "Hyperparameters" table.
# SSD models are keyed by pruning method, with rows of
#   (model, kl factor, pruning epochs, finetuning epochs).
# DETR is keyed by dataset, with rows of
#   (model, kl backbone, kl transformer, pruning epochs, finetuning epochs).
# "-" is displayed verbatim (presumably "not applicable").
hyperparameters = {
    "SSD300": {
        "Transfer Pruning": [
            ("SSD300-ResNet50", "-", "-", 120),
            ("SSD300-ITPCC-A", "-", "-", 120),
            ("SSD300-ITPCC-B", "-", "-", 120),
            ("SSD300-ITPCC-C", "-", "-", 120),
        ],
        "VIB Pruning": [
            ("SSD300-ResNet50", "-", "-", 120),
            ("SSD300-VIB-v1", "0.0001", 240, 100),
            ("SSD300-VIB-v2", "0.0002", 240, 100),
        ],
    },
    "SSD512": {
        "Transfer Pruning": [
            ("SSD512-ResNet50", "-", "-", 120),
            ("SSD512-ITPCC-A", "-", "-", 120),
            ("SSD512-ITPCC-B", "-", "-", 120),
            ("SSD512-ITPCC-C", "-", "-", 120),
        ],
        "VIB Pruning": [
            ("SSD512-ResNet50", "-", "-", 120),
            ("SSD512-VIB-v1", "0.0003", 200, 100),
        ],
    },
    "DETR": {
        "SPARK": [
            ("DETR-baseline", "-", "-", "-", 20),
            ("DETR-SPARK-A", "-", "-", 30, 40),
            ("DETR-SPARK-B", "-", "-", 30, 40),
        ],
        "VOC": [
            ("DETR-baseline", "-", "-", "-", 130),
            ("DETR-VOC-A", "0.0001", "0.00001", 80, 200),
            ("DETR-VOC-B", "0.00005", "0.0001", 80, 200),
        ],
    },
}

# Result columns per (model, pruning method | dataset). Each table holds
# parallel lists of display strings: model name, mAP, FLOPs, FLOPs reduction,
# params, params reduction. The keys mirror the hyperparameters table above.
results_data = {
    "SSD300": {
        "VIB Pruning": dict(
            model=["SSD300-ResNet50", "SSD300-VIB-v1", "SSD300-VIB-v2"],
            map=["77.79", "78.71", "77.41"],
            flops=["11.1", "5.04", "3.49"],
            flopsd=["0.0%", "54.55%", "68.54%"],
            params=["49.2", "19.84", "11.18"],
            paramsd=["0.0%", "59.68%", "77.28%"],
        ),
        "Transfer Pruning": dict(
            model=["SSD300-ResNet50", "SSD300-ITPCC-A", "SSD300-ITPCC-B", "SSD300-ITPCC-C"],
            map=["77.79", "77.86", "77.06", "75.08"],
            flops=["11.1", "6.85", "5.08", "3.38"],
            flopsd=["0.0%", "38.2%", "54.2%", "69.5%"],
            params=["49.2", "32.5", "25.7", "19.4"],
            paramsd=["0.0%", "33.94%", "47.77%", "60.5%"],
        ),
    },
    "SSD512": {
        "VIB Pruning": dict(
            model=["SSD512-ResNet50", "SSD512-VIB-v1"],
            map=["80.9", "81.43"],
            flops=["46.24", "9.73"],
            flopsd=["0.0%", "78.94%"],
            params=["58.52", "27.2"],
            paramsd=["0.0%", "53.42%"],
        ),
        "Transfer Pruning": dict(
            model=["SSD512-ResNet50", "SSD512-ITPCC-A", "SSD512-ITPCC-B", "SSD512-ITPCC-C"],
            map=["80.9", "81.05", "80.45", "78.82"],
            flops=["46.2", "31.42", "25.6", "20.1"],
            flopsd=["0.0%", "31.9%", "44.6%", "56.5%"],
            params=["58.5", "41.8", "35.0", "28.7"],
            paramsd=["0.0%", "28.5%", "40.17%", "50.1%"],
        ),
    },
    "DETR": {
        "SPARK": dict(
            model=["DETR-baseline", "DETR-SPARK-A", "DETR-SPARK-B"],
            map=["96.77", "94.5", "95.18"],
            flops=["85", "56", "58"],
            flopsd=["0.0%", "34.1%", "31.7%"],
            params=["41.2", "23.3", "26.6"],
            paramsd=["0.0%", "47.3%", "45.4%"],
        ),
        "VOC": dict(
            model=["DETR-baseline", "DETR-VOC-A", "DETR-VOC-B"],
            map=["79.34", "77.2", "78.0"],
            flops=["85", "55", "60"],
            flopsd=["0.0%", "35.29%", "29.41%"],
            params=["41.2", "21.71", "22.47"],
            paramsd=["0.0%", "42.65%", "35.5%"],
        ),
    },
}
st.markdown('<h1 class="centered-title">Variational Information bottleneck pruning for Object detection</h1>', unsafe_allow_html=True)

# Two columns: col1 (wider) shows results, col2 shows filters + hyperparameters.
col1, col2 = st.columns([5.2, 4.8])

with col2:
    # NOTE(review): each st.markdown call renders as its own element, so this
    # opening tag (and the `scroller` one below) likely does not wrap the
    # widgets that follow -- verify the fixed-col CSS actually takes effect.
    st.markdown('<div class="fixed-col">', unsafe_allow_html=True)
    st.subheader('Filters')
    model = st.selectbox('Select model:', models)

    # SSD models are grouped by pruning method; DETR is grouped by dataset.
    # Only the selector for the active grouping is created/defined.
    _is_ssd_model = model in ('SSD300', 'SSD512')
    if _is_ssd_model:
        pruning = st.selectbox('Select pruning method:', pruning_methods)
        _group_key = pruning
        _hp_columns = ['Model', 'kl factor', 'Pruning Epochs', 'Finetuning Epochs']
    else:
        dataset = st.selectbox('Select dataset:', datasets)
        _group_key = dataset
        _hp_columns = ['Model', 'kl backbone', 'kl transformer', 'Pruning Epochs', 'Finetuning Epochs']
    hyperparameter_data = hyperparameters[model][_group_key]

    st.markdown('<div class="scroller">', unsafe_allow_html=True)
    st.subheader('Hyperparameters')
    st.markdown('<br>', unsafe_allow_html=True)

    # Render the table as raw HTML so the index column can be hidden.
    df_hyperparams = pd.DataFrame(hyperparameter_data, columns=_hp_columns)
    _hp_table_html = df_hyperparams.style.hide(axis="index").to_html()
    st.markdown(_hp_table_html, unsafe_allow_html=True)

    # Emit the two closing tags matching the `scroller` and `fixed-col` openers.
    for _ in range(2):
        st.markdown('</div>', unsafe_allow_html=True)
with col1:
    st.subheader('Results')
    results = results_data[model]

    # SSD models are keyed by the selected pruning method, DETR by dataset.
    # The filter column defines only the matching variable, and the conditional
    # expression evaluates only the chosen branch.
    ff = pruning if model in ['SSD300', 'SSD512'] else dataset
    selected_results = results[ff]

    df_results = pd.DataFrame({
        'Model': selected_results['model'],
        'mAP': selected_results['map'],
        'FLOPs': selected_results['flops'],
        'down by': selected_results['flopsd'],
        'Params': selected_results['params'],
        # The trailing space keeps this key distinct from the 'down by' column above.
        'down by ': selected_results['paramsd'],
    })
    # Render as raw HTML so the index column can be hidden.
    st.markdown(df_results.style.hide(axis="index").to_html(), unsafe_allow_html=True)

    st.markdown("<div style='margin-top: 30px;'></div>", unsafe_allow_html=True)
    st.markdown(
        """
<h2 id="evolution-graphs" style="margin-bottom: 0px; padding-bottom: 0px;">
Evolution Graphs
</h2>
""",
        unsafe_allow_html=True
    )

    # The results tables store display strings; convert them to floats so the
    # bar heights are numeric and don't depend on Plotly's type inference.
    model_names = selected_results['model']
    flops_values = [float(value) for value in selected_results['flops']]
    params_values = [float(value) for value in selected_results['params']]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=model_names,
        y=flops_values,
        name='FLOPs',
        marker_color='orange'
    ))
    fig.add_trace(go.Bar(
        x=model_names,
        y=params_values,
        name='Params',
        marker_color='green'
    ))
    fig.update_layout(
        barmode='group',
        title='FLOPs and Params per model',
        xaxis_title='Model',
        yaxis_title='Count',
        legend_title='Metric',
        height=300
    )
    st.plotly_chart(fig, use_container_width=True)