Ahmedik95316 commited on
Commit
2179021
Β·
1 Parent(s): a12e15f

Update app/streamlit_app.py

Browse files
Files changed (1) hide show
  1. app/streamlit_app.py +136 -136
app/streamlit_app.py CHANGED
@@ -1,136 +1,136 @@
1
- # app/streamlit_app.py
2
-
3
- import streamlit as st
4
- import requests
5
- import json
6
- import pandas as pd
7
- import altair as alt
8
- import time
9
- import subprocess
10
- import sys
11
- from pathlib import Path
12
-
13
- # Add root to sys.path for imports if needed
14
- sys.path.append(str(Path(__file__).resolve().parent.parent))
15
-
16
- # ---- Constants ----
17
- # API_URL = "http://127.0.0.1:8000/predict"
18
- API_URL = "http://localhost:8000/predict""
19
- CUSTOM_DATA_PATH = Path(__file__).parent.parent / "data" / "custom_upload.csv"
20
- METADATA_PATH = Path(__file__).parent.parent / "model" / "metadata.json"
21
- ACTIVITY_LOG_PATH = Path(__file__).parent.parent / "logs" / "activity_log.json"
22
- DRIFT_LOG_PATH = Path(__file__).parent.parent / "logs" / "monitoring_log.json"
23
-
24
- # ---- Streamlit UI ----
25
- st.set_page_config(page_title="Fake News Detector", layout="centered")
26
- st.title("πŸ“° Fake News Detector")
27
- st.markdown("Enter a news article's headline or content to predict if it's **Fake** or **Real**.")
28
-
29
- # ---- Prediction Form ----
30
- with st.form(key="predict_form"):
31
- user_input = st.text_area("News Text", height=150)
32
- submit = st.form_submit_button("🧠 Predict")
33
-
34
- if submit:
35
- if not user_input.strip():
36
- st.warning("Please enter some text.")
37
- else:
38
- try:
39
- response = requests.post(API_URL, json={"text": user_input})
40
- if response.status_code == 200:
41
- result = response.json()
42
- pred = result["prediction"]
43
- prob = result["confidence"]
44
- st.success(f"🧾 Prediction: **{pred}**")
45
- st.info(f"πŸ“ˆ Confidence: {prob * 100:.2f}%")
46
- else:
47
- st.error(f"API Error: {response.status_code}")
48
- except Exception as e:
49
- st.error(f"❌ Failed to connect to FastAPI: {e}")
50
-
51
- # ---- Upload + Train ----
52
- st.header("πŸ“€ Train with Your Own CSV")
53
-
54
- with st.expander("Upload CSV to Retrain Model (columns: `text`, `label`)"):
55
- uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])
56
- if uploaded_file:
57
- try:
58
- df_custom = pd.read_csv(uploaded_file)
59
- if "text" not in df_custom.columns or "label" not in df_custom.columns:
60
- st.error("CSV must contain 'text' and 'label' columns.")
61
- else:
62
- st.success("βœ… File looks good. Starting training...")
63
-
64
- # Save CSV
65
- df_custom.to_csv(CUSTOM_DATA_PATH, index=False)
66
-
67
- # Progress bar animation
68
- progress_bar = st.progress(0)
69
- status_text = st.empty()
70
- for percent in range(0, 101, 10):
71
- progress_bar.progress(percent)
72
- status_text.text(f"Training Progress: {percent}%")
73
- time.sleep(0.2)
74
-
75
- # Trigger training subprocess
76
- result = subprocess.run(
77
- [sys.executable, "model/train.py", "--data_path", str(CUSTOM_DATA_PATH), "--output_path", "model/custom_model.pt"],
78
- capture_output=True, text=True
79
- )
80
-
81
- if result.returncode == 0:
82
- acc = float(result.stdout.strip())
83
- new_version = "custom_" + time.strftime("%H%M%S")
84
- metadata = {
85
- "model_version": new_version,
86
- "test_accuracy": round(acc, 4),
87
- "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S")
88
- }
89
- with open(METADATA_PATH, "w") as f:
90
- json.dump(metadata, f, indent=2)
91
- status_text.text("πŸŽ‰ Training complete!")
92
- st.success(f"New model trained with accuracy: {acc:.4f}")
93
- else:
94
- st.error("Training failed.")
95
- st.text(result.stderr)
96
- except Exception as e:
97
- st.error(f"Error reading file: {e}")
98
-
99
- # ---- Sidebar Info ----
100
- st.sidebar.header("πŸ“Š Model Info")
101
- if METADATA_PATH.exists():
102
- with open(METADATA_PATH) as f:
103
- meta = json.load(f)
104
- st.sidebar.markdown(f"**Version**: `{meta['model_version']}`")
105
- st.sidebar.markdown(f"**Accuracy**: `{meta['test_accuracy']}`")
106
- st.sidebar.markdown(f"**Updated**: `{meta['timestamp'].split('T')[0]}`")
107
- else:
108
- st.sidebar.warning("No metadata found.")
109
-
110
- # ---- Activity Log ----
111
- st.sidebar.header("πŸ“œ Activity Log")
112
- if ACTIVITY_LOG_PATH.exists():
113
- with open(ACTIVITY_LOG_PATH) as f:
114
- activity_log = json.load(f)
115
- for entry in reversed(activity_log[-5:]):
116
- st.sidebar.text(f"{entry['timestamp']} - {entry['event']}")
117
- else:
118
- st.sidebar.info("No recent logs found.")
119
-
120
- # ---- Drift Chart ----
121
- st.sidebar.header("πŸ“‰ Drift Monitoring")
122
- if DRIFT_LOG_PATH.exists():
123
- drift_df = pd.read_json(DRIFT_LOG_PATH)
124
- drift_df["timestamp"] = pd.to_datetime(drift_df["timestamp"])
125
- drift_df["status"] = drift_df["drift_detected"].map({True: "Drift", False: "Stable"})
126
-
127
- chart = alt.Chart(drift_df).mark_line(point=True).encode(
128
- x="timestamp:T",
129
- y=alt.Y("test_accuracy:Q", title="Test Accuracy"),
130
- color="status:N",
131
- tooltip=["timestamp", "test_accuracy", "status"]
132
- ).properties(title="Model Performance & Drift", height=250)
133
-
134
- st.sidebar.altair_chart(chart, use_container_width=True)
135
- else:
136
- st.sidebar.info("No drift data available.")
 
1
+ # app/streamlit_app.py
2
+
3
+ import streamlit as st
4
+ import requests
5
+ import json
6
+ import pandas as pd
7
+ import altair as alt
8
+ import time
9
+ import subprocess
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ # Add root to sys.path for imports if needed
14
+ sys.path.append(str(Path(__file__).resolve().parent.parent))
15
+
16
+ # ---- Constants ----
17
+ # API_URL = "http://127.0.0.1:8000/predict"
18
+ API_URL = "http://localhost:8000/predict"
19
+ CUSTOM_DATA_PATH = Path(__file__).parent.parent / "data" / "custom_upload.csv"
20
+ METADATA_PATH = Path(__file__).parent.parent / "model" / "metadata.json"
21
+ ACTIVITY_LOG_PATH = Path(__file__).parent.parent / "logs" / "activity_log.json"
22
+ DRIFT_LOG_PATH = Path(__file__).parent.parent / "logs" / "monitoring_log.json"
23
+
24
+ # ---- Streamlit UI ----
25
+ st.set_page_config(page_title="Fake News Detector", layout="centered")
26
+ st.title("πŸ“° Fake News Detector")
27
+ st.markdown("Enter a news article's headline or content to predict if it's **Fake** or **Real**.")
28
+
29
+ # ---- Prediction Form ----
30
+ with st.form(key="predict_form"):
31
+ user_input = st.text_area("News Text", height=150)
32
+ submit = st.form_submit_button("🧠 Predict")
33
+
34
+ if submit:
35
+ if not user_input.strip():
36
+ st.warning("Please enter some text.")
37
+ else:
38
+ try:
39
+ response = requests.post(API_URL, json={"text": user_input})
40
+ if response.status_code == 200:
41
+ result = response.json()
42
+ pred = result["prediction"]
43
+ prob = result["confidence"]
44
+ st.success(f"🧾 Prediction: **{pred}**")
45
+ st.info(f"πŸ“ˆ Confidence: {prob * 100:.2f}%")
46
+ else:
47
+ st.error(f"API Error: {response.status_code}")
48
+ except Exception as e:
49
+ st.error(f"❌ Failed to connect to FastAPI: {e}")
50
+
51
+ # ---- Upload + Train ----
52
+ st.header("πŸ“€ Train with Your Own CSV")
53
+
54
+ with st.expander("Upload CSV to Retrain Model (columns: `text`, `label`)"):
55
+ uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])
56
+ if uploaded_file:
57
+ try:
58
+ df_custom = pd.read_csv(uploaded_file)
59
+ if "text" not in df_custom.columns or "label" not in df_custom.columns:
60
+ st.error("CSV must contain 'text' and 'label' columns.")
61
+ else:
62
+ st.success("βœ… File looks good. Starting training...")
63
+
64
+ # Save CSV
65
+ df_custom.to_csv(CUSTOM_DATA_PATH, index=False)
66
+
67
+ # Progress bar animation
68
+ progress_bar = st.progress(0)
69
+ status_text = st.empty()
70
+ for percent in range(0, 101, 10):
71
+ progress_bar.progress(percent)
72
+ status_text.text(f"Training Progress: {percent}%")
73
+ time.sleep(0.2)
74
+
75
+ # Trigger training subprocess
76
+ result = subprocess.run(
77
+ [sys.executable, "model/train.py", "--data_path", str(CUSTOM_DATA_PATH), "--output_path", "model/custom_model.pt"],
78
+ capture_output=True, text=True
79
+ )
80
+
81
+ if result.returncode == 0:
82
+ acc = float(result.stdout.strip())
83
+ new_version = "custom_" + time.strftime("%H%M%S")
84
+ metadata = {
85
+ "model_version": new_version,
86
+ "test_accuracy": round(acc, 4),
87
+ "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S")
88
+ }
89
+ with open(METADATA_PATH, "w") as f:
90
+ json.dump(metadata, f, indent=2)
91
+ status_text.text("πŸŽ‰ Training complete!")
92
+ st.success(f"New model trained with accuracy: {acc:.4f}")
93
+ else:
94
+ st.error("Training failed.")
95
+ st.text(result.stderr)
96
+ except Exception as e:
97
+ st.error(f"Error reading file: {e}")
98
+
99
+ # ---- Sidebar Info ----
100
+ st.sidebar.header("πŸ“Š Model Info")
101
+ if METADATA_PATH.exists():
102
+ with open(METADATA_PATH) as f:
103
+ meta = json.load(f)
104
+ st.sidebar.markdown(f"**Version**: `{meta['model_version']}`")
105
+ st.sidebar.markdown(f"**Accuracy**: `{meta['test_accuracy']}`")
106
+ st.sidebar.markdown(f"**Updated**: `{meta['timestamp'].split('T')[0]}`")
107
+ else:
108
+ st.sidebar.warning("No metadata found.")
109
+
110
+ # ---- Activity Log ----
111
+ st.sidebar.header("πŸ“œ Activity Log")
112
+ if ACTIVITY_LOG_PATH.exists():
113
+ with open(ACTIVITY_LOG_PATH) as f:
114
+ activity_log = json.load(f)
115
+ for entry in reversed(activity_log[-5:]):
116
+ st.sidebar.text(f"{entry['timestamp']} - {entry['event']}")
117
+ else:
118
+ st.sidebar.info("No recent logs found.")
119
+
120
+ # ---- Drift Chart ----
121
+ st.sidebar.header("πŸ“‰ Drift Monitoring")
122
+ if DRIFT_LOG_PATH.exists():
123
+ drift_df = pd.read_json(DRIFT_LOG_PATH)
124
+ drift_df["timestamp"] = pd.to_datetime(drift_df["timestamp"])
125
+ drift_df["status"] = drift_df["drift_detected"].map({True: "Drift", False: "Stable"})
126
+
127
+ chart = alt.Chart(drift_df).mark_line(point=True).encode(
128
+ x="timestamp:T",
129
+ y=alt.Y("test_accuracy:Q", title="Test Accuracy"),
130
+ color="status:N",
131
+ tooltip=["timestamp", "test_accuracy", "status"]
132
+ ).properties(title="Model Performance & Drift", height=250)
133
+
134
+ st.sidebar.altair_chart(chart, use_container_width=True)
135
+ else:
136
+ st.sidebar.info("No drift data available.")