adarsh commited on
Commit
1fed04b
·
1 Parent(s): a31a9c7

updated training metrics with pd and sns

Browse files
Files changed (1) hide show
  1. pages/page_1.py +60 -2
pages/page_1.py CHANGED
@@ -1,10 +1,68 @@
1
  import streamlit as st
 
 
 
2
 
3
  def main():
4
  st.title("Training Metrics")
5
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  if __name__ == "__main__":
8
  main()
9
-
10
-
 
1
  import streamlit as st
2
+ import matplotlib.pyplot as plt
3
+ import seaborn as sns
4
+ import pandas as pd
5
 
6
  def main():
7
  st.title("Training Metrics")
8
 
9
 
10
+ st.markdown("#### Sweeps for hyperparameter tuning")
11
+ st.image("assets/img/cogni-bert-12sweeps.png", use_column_width=True )
12
+
13
+ data = [
14
+ {"val_loss": 0.60, "train_loss": 1.24, "val_f1": 0.82},
15
+ {"val_loss": 0.37, "train_loss": 0.44, "val_f1": 0.89},
16
+ {"val_loss": 0.39, "train_loss": 0.23, "val_f1": 0.88},
17
+ {"val_loss": 0.351, "train_loss": 0.13, "val_f1": 0.90},
18
+ {"val_loss": 0.353, "train_loss": 0.071, "val_f1": 0.922},
19
+
20
+
21
+ ]
22
+
23
+ # Convert the list of dictionaries to a Pandas DataFrame
24
+ df = pd.DataFrame(data)
25
+
26
+ df["epoch"] = range(1, len(data)+1 )
27
+
28
+ st.markdown("### Cogni-BERT Best Model Train")
29
+
30
+
31
+ # st.dataframe(data)
32
+
33
+
34
+ col1, col2, col3 = st.columns(3)
35
+
36
+ # Line chart for validation loss with Seaborn
37
+ with col1:
38
+ st.markdown("#### Validation Loss")
39
+ fig, ax = plt.subplots()
40
+ sns.lineplot(x="epoch", y="val_loss", data=df, ax=ax, color='skyblue', marker='o', linestyle='-', linewidth=2, markersize=8)
41
+ plt.xlabel("epoch")
42
+ plt.ylabel("Validation Loss")
43
+ st.pyplot(fig)
44
+
45
+ # Line chart for training loss with Seaborn
46
+ with col2:
47
+ st.markdown("#### Training Loss")
48
+ fig, ax = plt.subplots()
49
+ sns.lineplot(x="epoch", y="train_loss", data=df, ax=ax, color='salmon', marker='s', linestyle='--', linewidth=2, markersize=8)
50
+ plt.xlabel("epoch")
51
+ plt.ylabel("Training Loss")
52
+ st.pyplot(fig)
53
+
54
+ # Line chart for F1 score with Seaborn
55
+ with col3:
56
+ st.markdown("#### Validation F1")
57
+ fig, ax = plt.subplots()
58
+ sns.lineplot(x="epoch", y="val_f1", data=df, ax=ax, color='limegreen', marker='D', linestyle='-.', linewidth=2, markersize=8)
59
+ plt.xlabel("epoch")
60
+ plt.ylabel("F1 Score")
61
+ st.pyplot(fig)
62
+
63
+
64
+ st.markdown('<p style="text-align:center;">Made with ❤️ by <a href="https://www.adarshmaurya.onionreads.com">Adarsh Maurya</a></p>', unsafe_allow_html=True)
65
+
66
+
67
  if __name__ == "__main__":
68
  main()