peterkros committed on
Commit
ef81b49
·
verified ·
1 Parent(s): 01c16dd

Upload vaccine_stockout_predictor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vaccine_stockout_predictor.py +218 -0
vaccine_stockout_predictor.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Vaccine Stock-Out Prediction Pipeline
2
+
3
+ import pandas as pd
4
+ import joblib
5
+ import os
6
+ from typing import Dict, Tuple, Optional
7
+
8
class VaccineStockoutPredictor:
    """
    A vaccine stock-out prediction system for multiple vaccines across different countries.

    One pre-trained model per vaccine type is loaded with joblib from a models
    directory, and per-store reference data (location, population, min/max
    stock levels) is read from a CSV file.

    Supports 8 different vaccine types:
    - BCG (Bacille Calmette-Guérin)
    - HepB (Hepatitis B)
    - bOPV (bivalent Oral Polio Vaccine)
    - Penta (Pentavalent)
    - PCV (Pneumococcal Conjugate Vaccine)
    - Rota (Rotavirus)
    - IPV (Inactivated Polio Vaccine)
    - TT/Td/DT (Tetanus Toxoid/Tetanus-Diphtheria)

    Prediction failures are reported as ``{'error': message}`` dictionaries
    rather than raised exceptions.
    """

    # Model file name -> vaccine name used as the key in ``self.models``.
    MODEL_FILES = {
        'BCG_model.joblib': 'BCG',
        'HepB_model.joblib': 'HepB',
        'bOPV_model.joblib': 'bOPV',
        'Penta_model.joblib': 'Penta',
        'PCV_model.joblib': 'PCV',
        'Rota_model.joblib': 'Rota',
        'IPV_model.joblib': 'IPV',
        'TT_Td_DT_model.joblib': 'TT/Td/DT',
    }

    # Keyword arguments accepted by ``predict_stockout_risk``; used by
    # ``batch_predict`` to validate each entry before dispatching it.
    _PREDICT_PARAMS = (
        'country_name',
        'sc_level',
        'store_name',
        'vaccine_type',
        'current_stock',
    )

    def __init__(self,
                 models_dir: str = 'models',
                 reference_path: str = 'reference_data.csv'):
        """
        Initialize the predictor with all trained models and reference data.

        Args:
            models_dir: Directory containing the ``*_model.joblib`` files.
                Relative paths are resolved against the current working
                directory.
            reference_path: Path of the reference CSV with store information.
        """
        self.models_dir = models_dir
        self.reference_path = reference_path
        self.models = {}
        self.reference_data = None
        self._load_models()
        self._load_reference_data()

    def _load_models(self):
        """Load every model file found in the models directory into ``self.models``.

        Missing files are reported with a warning and skipped, so the
        predictor may end up supporting only a subset of the vaccines.
        """
        for filename, vaccine_name in self.MODEL_FILES.items():
            model_path = os.path.join(self.models_dir, filename)
            if os.path.exists(model_path):
                # SECURITY: joblib.load unpickles arbitrary objects; only
                # point models_dir at model files from a trusted source.
                self.models[vaccine_name] = joblib.load(model_path)
                print(f"Loaded model for {vaccine_name}")
            else:
                print(f"Warning: Model file {filename} not found")

    def _load_reference_data(self):
        """Load the reference CSV with store information into ``self.reference_data``.

        Loading is best-effort: on failure the error is printed and
        ``self.reference_data`` stays ``None``; the query methods then return
        empty lists and predictions return an error dictionary.
        """
        try:
            self.reference_data = pd.read_csv(self.reference_path)
        except Exception as e:  # deliberate best-effort boundary, see docstring
            print(f"Error loading reference data: {e}")
        else:
            print(f"Loaded reference data with {len(self.reference_data)} stores")

    @staticmethod
    def _to_native(value):
        """Return numpy scalars as plain Python values; pass anything else through.

        Values read out of a DataFrame row can be numpy scalars, and numpy
        integers are rejected by ``json.dumps``.
        """
        return value.item() if hasattr(value, 'item') else value

    def get_available_vaccines(self) -> list:
        """Return the vaccine types for which a model was loaded."""
        return list(self.models)

    def get_available_countries(self) -> list:
        """Return the sorted country names present in the reference data.

        Returns an empty list when the reference data is not loaded.
        """
        if self.reference_data is None:
            return []
        # Drop missing names first: sorting a mix of str and NaN raises TypeError.
        countries = self.reference_data['CountryName'].dropna().unique().tolist()
        return sorted(countries)

    def get_available_stores(self, country: str, sc_level: Optional[str] = None) -> list:
        """
        Return the sorted store names for a country.

        Args:
            country: Country name to filter on.
            sc_level: Optional supply chain level; when given (and non-empty)
                only stores at that level are returned.

        Returns:
            Sorted list of store names, empty when the reference data is not
            loaded or nothing matches.
        """
        if self.reference_data is None:
            return []
        mask = self.reference_data['CountryName'] == country
        if sc_level:
            mask &= self.reference_data['SCLevel'] == sc_level
        # Drop missing names first: sorting a mix of str and NaN raises TypeError.
        stores = self.reference_data[mask]['StoreName'].dropna().unique().tolist()
        return sorted(stores)

    def predict_stockout_risk(self,
                              country_name: str,
                              sc_level: str,
                              store_name: str,
                              vaccine_type: str,
                              current_stock: int) -> Dict:
        """
        Predict stock-out risk for a specific vaccine at a specific store.

        Args:
            country_name: Name of the country
            sc_level: Supply chain level (Central, Subnational, LD)
            store_name: Name of the store
            vaccine_type: Type of vaccine (BCG, HepB, bOPV, Penta, PCV, Rota, IPV, TT/Td/DT)
            current_stock: Current stock level

        Returns:
            Dictionary containing prediction results (all values are plain
            Python types), or ``{'error': message}`` when the vaccine type is
            unsupported, the reference data is not loaded, the store is not
            found, or the store has no usable min/max stock data.
        """
        # Validate inputs
        if vaccine_type not in self.models:
            return {
                'error': f"Vaccine type '{vaccine_type}' not supported. Available types: {list(self.models.keys())}"
            }

        if self.reference_data is None:
            return {'error': 'Reference data not loaded'}

        # Find store information
        store_info = self.reference_data[
            (self.reference_data['CountryName'] == country_name) &
            (self.reference_data['SCLevel'] == sc_level) &
            (self.reference_data['StoreName'] == store_name)
        ]

        if len(store_info) == 0:
            return {
                'error': f"Store '{store_name}' not found in {country_name} at {sc_level} level"
            }

        # If several rows match, the first one is used.
        store_info = store_info.iloc[0]

        # NOTE(review): for 'TT/Td/DT' these names contain slashes
        # ('TT/Td/DT_Min'), unlike the underscore form of the model file name
        # -- confirm this matches the reference CSV columns.
        vaccine_min_col = f'{vaccine_type}_Min'
        vaccine_max_col = f'{vaccine_type}_Max'

        if vaccine_min_col not in store_info or vaccine_max_col not in store_info:
            return {
                'error': f"Min/Max data not available for {vaccine_type} at this store"
            }

        min_stock = self._to_native(store_info[vaccine_min_col])
        max_stock = self._to_native(store_info[vaccine_max_col])

        # A missing value would silently turn the utilization into NaN.
        if pd.isna(min_stock) or pd.isna(max_stock):
            return {
                'error': f"Min/Max data not available for {vaccine_type} at this store"
            }

        # Calculate utilization
        if max_stock <= min_stock:
            utilization = 0.5  # Default value if range is invalid
        else:
            # NOTE(review): formula kept as-is; it presumably mirrors the
            # feature definition used when the models were trained -- confirm.
            utilization = float(current_stock / (max_stock - min_stock + 1))

        # Create input data
        input_data = {
            'CountryName': country_name,
            'SCLevel': sc_level,
            'StoreName': store_name,
            'Population': store_info['Population'],
            'DistanceToParent': store_info['DistanceToParent'],
            'Latitude': store_info['Latitude'],
            'Longitude': store_info['Longitude'],
            'Average_Utilization': store_info['Average_Utilization'],
            f'{vaccine_type}_Utilization': utilization
        }
        input_df = pd.DataFrame([input_data])

        # Make prediction
        model = self.models[vaccine_type]
        prediction = int(model.predict(input_df)[0])
        if hasattr(model, 'predict_proba'):
            # Column 1 is taken as the probability of the stock-out class.
            probability = float(model.predict_proba(input_df)[0][1])
        else:
            probability = None

        return {
            'vaccine_type': vaccine_type,
            'country': country_name,
            'sc_level': sc_level,
            'store': store_name,
            'current_stock': current_stock,
            'min_stock': min_stock,
            'max_stock': max_stock,
            'utilization': utilization,
            'stockout_risk': prediction,
            'risk_probability': probability,
            'risk_level': 'High' if prediction == 1 else 'Low',
            'recommendation': self._get_recommendation(
                prediction, probability, current_stock, min_stock, max_stock)
        }

    def _get_recommendation(self, prediction: int, probability: Optional[float],
                            current_stock: int, min_stock: int, max_stock: int) -> str:
        """
        Generate a recommendation based on prediction results.

        Args:
            prediction: Predicted class; 1 means high stock-out risk.
            probability: Predicted risk probability (currently not used).
            current_stock: Current stock level.
            min_stock: Minimum stock level of the store.
            max_stock: Maximum stock level of the store.

        Returns:
            Human-readable recommendation text.
        """
        if prediction == 1:  # High risk
            if current_stock <= min_stock:
                return "URGENT: Stock level below minimum. Immediate restocking required."
            return "High risk of stock-out. Consider restocking soon."

        # Low risk
        if current_stock >= max_stock:
            return "Stock level above maximum. Consider redistribution."
        return "Stock level adequate. Monitor regularly."

    def batch_predict(self, predictions_list: list) -> list:
        """
        Perform batch predictions for multiple stores/vaccines.

        Args:
            predictions_list: List of dictionaries with prediction parameters
                (the keyword arguments of ``predict_stockout_risk``).

        Returns:
            List of prediction results, one per input entry and in the same
            order. A malformed entry yields an ``{'error': message}`` result
            instead of aborting the whole batch.
        """
        results = []
        for pred_params in predictions_list:
            try:
                params = dict(pred_params)
            except (TypeError, ValueError):
                results.append({
                    'error': f"Prediction parameters must be a dictionary, got {type(pred_params).__name__}"
                })
                continue

            missing = [key for key in self._PREDICT_PARAMS if key not in params]
            unexpected = [key for key in params if key not in self._PREDICT_PARAMS]
            if missing or unexpected:
                results.append({
                    'error': f"Invalid prediction parameters (missing: {missing}, unexpected: {unexpected})"
                })
                continue

            results.append(self.predict_stockout_risk(**params))
        return results
203
+
204
+
205
# Example usage
if __name__ == "__main__":
    # Imported locally: json is needed only by this demo and is not part of
    # the module's import block.
    import json

    def _to_builtin(value):
        """Convert numpy scalars for json.dumps, which cannot encode them itself."""
        if hasattr(value, 'item'):
            return value.item()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    predictor = VaccineStockoutPredictor()

    # Example prediction
    result = predictor.predict_stockout_risk(
        country_name="Afghanistan",
        sc_level="Subnational",
        store_name="Kabul",
        vaccine_type="BCG",
        current_stock=50000,
    )

    # The result can hold values read from the reference table (min/max
    # stock), so fall back to _to_builtin for anything json cannot encode.
    print(json.dumps(result, indent=2, default=_to_builtin))