Ahmedik95316 committed on
Commit
7e70d4f
Β·
1 Parent(s): 6fcb89a

Update model/retrain.py

Browse files

Added line 46 to load the new data and changed the `train_model(df)` signature to `train_model()`, removing the method's parameter.

Files changed (1) hide show
  1. model/retrain.py +102 -100
model/retrain.py CHANGED
@@ -1,101 +1,103 @@
1
- import pandas as pd
2
- from pathlib import Path
3
- from sklearn.feature_extraction.text import TfidfVectorizer
4
- from sklearn.linear_model import LogisticRegression
5
- from sklearn.model_selection import train_test_split
6
- from sklearn.metrics import accuracy_score
7
- import joblib
8
- import json
9
- import hashlib
10
- import datetime
11
- import shutil
12
-
13
- # Paths
14
- BASE_DIR = Path(__file__).resolve().parent
15
- DATA_DIR = BASE_DIR.parent / "data"
16
- LOGS_DIR = BASE_DIR.parent / "logs"
17
-
18
- COMBINED = DATA_DIR / "combined_dataset.csv"
19
- SCRAPED = DATA_DIR / "scraped_real.csv"
20
- GENERATED = DATA_DIR / "generated_fake.csv"
21
-
22
- PROD_MODEL = BASE_DIR / "model.pkl"
23
- PROD_VECTORIZER = BASE_DIR / "vectorizer.pkl"
24
-
25
- CANDIDATE_MODEL = BASE_DIR / "model_candidate.pkl"
26
- CANDIDATE_VECTORIZER = BASE_DIR / "vectorizer_candidate.pkl"
27
-
28
- METADATA_PATH = BASE_DIR / "metadata.json"
29
-
30
- def hash_file(path: Path):
31
- return hashlib.md5(path.read_bytes()).hexdigest()
32
-
33
- def load_new_data():
34
- dfs = [pd.read_csv(COMBINED)]
35
- if SCRAPED.exists():
36
- dfs.append(pd.read_csv(SCRAPED))
37
- if GENERATED.exists():
38
- dfs.append(pd.read_csv(GENERATED))
39
- df = pd.concat(dfs, ignore_index=True)
40
- df.dropna(subset=["text"], inplace=True)
41
- df = df[df["text"].str.strip() != ""]
42
- return df
43
-
44
- def train_model(df):
45
- X = df["text"]
46
- y = df["label"]
47
- X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)
48
-
49
- vec = TfidfVectorizer(stop_words="english", max_features=5000)
50
- X_train_vec = vec.fit_transform(X_train)
51
- X_test_vec = vec.transform(X_test)
52
-
53
- model = LogisticRegression(max_iter=1000)
54
- model.fit(X_train_vec, y_train)
55
- acc = accuracy_score(y_test, model.predict(X_test_vec))
56
-
57
- return model, vec, acc, len(X_train), len(X_test)
58
-
59
- def load_metadata():
60
- if METADATA_PATH.exists():
61
- with open(METADATA_PATH) as f:
62
- return json.load(f)
63
- return None
64
-
65
- def bump_version(version: str) -> str:
66
- major, minor = map(int, version.replace("v", "").split("."))
67
- return f"v{major}.{minor+1}"
68
-
69
- def main():
70
- print("πŸ”„ Retraining candidate model...")
71
- df = load_new_data()
72
- model, vec, acc, train_size, test_size = train_model(df)
73
-
74
- print(f"πŸ“Š Candidate Accuracy: {acc:.4f}")
75
- joblib.dump(model, CANDIDATE_MODEL)
76
- joblib.dump(vec, CANDIDATE_VECTORIZER)
77
-
78
- metadata = load_metadata()
79
- prod_acc = metadata["test_accuracy"] if metadata else 0
80
- model_version = bump_version(metadata["model_version"]) if metadata else "v1.0"
81
-
82
- if acc > prod_acc:
83
- print("βœ… Candidate outperforms production. Promoting model...")
84
- shutil.copy(CANDIDATE_MODEL, PROD_MODEL)
85
- shutil.copy(CANDIDATE_VECTORIZER, PROD_VECTORIZER)
86
- metadata = {
87
- "model_version": model_version,
88
- "data_version": hash_file(COMBINED),
89
- "train_size": train_size,
90
- "test_size": test_size,
91
- "test_accuracy": round(acc, 4),
92
- "timestamp": datetime.datetime.now().isoformat()
93
- }
94
- with open(METADATA_PATH, "w") as f:
95
- json.dump(metadata, f, indent=2)
96
- print(f"🟒 Model promoted. Version: {model_version}")
97
- else:
98
- print("🟑 Candidate did not outperform production. Keeping existing model.")
99
-
100
- if __name__ == "__main__":
 
 
101
  main()
 
1
+ import pandas as pd
2
+ from pathlib import Path
3
+ from sklearn.feature_extraction.text import TfidfVectorizer
4
+ from sklearn.linear_model import LogisticRegression
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.metrics import accuracy_score
7
+ import joblib
8
+ import json
9
+ import hashlib
10
+ import datetime
11
+ import shutil
12
+
13
+ # Paths
14
+ BASE_DIR = Path(__file__).resolve().parent
15
+ DATA_DIR = BASE_DIR.parent / "data"
16
+ LOGS_DIR = BASE_DIR.parent / "logs"
17
+
18
+ COMBINED = DATA_DIR / "combined_dataset.csv"
19
+ SCRAPED = DATA_DIR / "scraped_real.csv"
20
+ GENERATED = DATA_DIR / "generated_fake.csv"
21
+
22
+ PROD_MODEL = BASE_DIR / "model.pkl"
23
+ PROD_VECTORIZER = BASE_DIR / "vectorizer.pkl"
24
+
25
+ CANDIDATE_MODEL = BASE_DIR / "model_candidate.pkl"
26
+ CANDIDATE_VECTORIZER = BASE_DIR / "vectorizer_candidate.pkl"
27
+
28
+ METADATA_PATH = BASE_DIR / "metadata.json"
29
+
30
+ def hash_file(path: Path):
31
+ return hashlib.md5(path.read_bytes()).hexdigest()
32
+
33
+ def load_new_data():
34
+ dfs = [pd.read_csv(COMBINED)]
35
+ if SCRAPED.exists():
36
+ dfs.append(pd.read_csv(SCRAPED))
37
+ if GENERATED.exists():
38
+ dfs.append(pd.read_csv(GENERATED))
39
+ df = pd.concat(dfs, ignore_index=True)
40
+ df.dropna(subset=["text"], inplace=True)
41
+ df = df[df["text"].str.strip() != ""]
42
+ return df
43
+
44
+ def train_model():
45
+ # Load the new data
46
+ df = load_new_data()
47
+ X = df["text"]
48
+ y = df["label"]
49
+ X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)
50
+
51
+ vec = TfidfVectorizer(stop_words="english", max_features=5000)
52
+ X_train_vec = vec.fit_transform(X_train)
53
+ X_test_vec = vec.transform(X_test)
54
+
55
+ model = LogisticRegression(max_iter=1000)
56
+ model.fit(X_train_vec, y_train)
57
+ acc = accuracy_score(y_test, model.predict(X_test_vec))
58
+
59
+ return model, vec, acc, len(X_train), len(X_test)
60
+
61
+ def load_metadata():
62
+ if METADATA_PATH.exists():
63
+ with open(METADATA_PATH) as f:
64
+ return json.load(f)
65
+ return None
66
+
67
+ def bump_version(version: str) -> str:
68
+ major, minor = map(int, version.replace("v", "").split("."))
69
+ return f"v{major}.{minor+1}"
70
+
71
+ def main():
72
+ print("πŸ”„ Retraining candidate model...")
73
+ df = load_new_data()
74
+ model, vec, acc, train_size, test_size = train_model(df)
75
+
76
+ print(f"πŸ“Š Candidate Accuracy: {acc:.4f}")
77
+ joblib.dump(model, CANDIDATE_MODEL)
78
+ joblib.dump(vec, CANDIDATE_VECTORIZER)
79
+
80
+ metadata = load_metadata()
81
+ prod_acc = metadata["test_accuracy"] if metadata else 0
82
+ model_version = bump_version(metadata["model_version"]) if metadata else "v1.0"
83
+
84
+ if acc > prod_acc:
85
+ print("βœ… Candidate outperforms production. Promoting model...")
86
+ shutil.copy(CANDIDATE_MODEL, PROD_MODEL)
87
+ shutil.copy(CANDIDATE_VECTORIZER, PROD_VECTORIZER)
88
+ metadata = {
89
+ "model_version": model_version,
90
+ "data_version": hash_file(COMBINED),
91
+ "train_size": train_size,
92
+ "test_size": test_size,
93
+ "test_accuracy": round(acc, 4),
94
+ "timestamp": datetime.datetime.now().isoformat()
95
+ }
96
+ with open(METADATA_PATH, "w") as f:
97
+ json.dump(metadata, f, indent=2)
98
+ print(f"🟒 Model promoted. Version: {model_version}")
99
+ else:
100
+ print("🟑 Candidate did not outperform production. Keeping existing model.")
101
+
102
+ if __name__ == "__main__":
103
  main()