|
""" |
|
Enhanced TIPM Core with Real Data Integration |
|
=========================================== |
|
|
|
This module extends the TIPM core model to work with real-world datasets |
|
from authoritative sources like UN Comtrade, World Bank, OECD, etc. |
|
""" |
|
|
|
import pandas as pd |
|
import numpy as np |
|
from datetime import datetime, timedelta |
|
from typing import Dict, List, Optional, Tuple, Any |
|
import logging |
|
from dataclasses import dataclass |
|
|
|
from .core import TIPMModel, TariffShock, TIPMPrediction |
|
from .config.settings import TIPMConfig |
|
from .data_connectors import DataIntegrationManager |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@dataclass
class RealDataConfig:
    """Configuration for real data integration.

    Groups cache behavior, preferred data sources, and data-quality
    thresholds used by the real-data TIPM model.
    """

    # --- Cache & fallback behavior ---
    # Directory where fetched datasets are cached on disk.
    data_cache_dir: str = "real_data_cache"
    # Cached data older than this many hours is refetched.
    max_cache_age_hours: int = 24
    # When True, synthetic data is generated for layers lacking real data.
    fallback_to_synthetic: bool = True
    # Minimum row count before a dataset is flagged as insufficient.
    min_data_points: int = 10

    # --- Preferred data sources (identifiers understood by the data manager) ---
    primary_trade_source: str = "comtrade"
    primary_economic_source: str = "worldbank"
    primary_sentiment_source: str = "gdelt"

    # --- Data-quality thresholds ---
    # NOTE(review): of these, only min_data_points is actually enforced in
    # this module (_assess_data_quality); the coverage/missing-ratio values
    # are reported via get_data_provenance but never checked -- confirm
    # whether enforcement was intended.
    min_trade_coverage: float = 0.7
    min_temporal_coverage: float = 0.8
    max_missing_data_ratio: float = 0.3
|
|
|
|
|
class RealDataTIPMModel(TIPMModel): |
|
""" |
|
Enhanced TIPM model that integrates real-world datasets |
|
""" |
|
|
|
def __init__( |
|
self, |
|
config: Optional[TIPMConfig] = None, |
|
real_data_config: Optional[RealDataConfig] = None, |
|
): |
|
super().__init__(config) |
|
|
|
self.real_data_config = real_data_config or RealDataConfig() |
|
self.data_manager = DataIntegrationManager(self.real_data_config.data_cache_dir) |
|
|
|
|
|
self.real_datasets = {} |
|
self.data_quality_metrics = {} |
|
self.last_data_update = None |
|
|
|
logger.info("Initialized RealDataTIPMModel with real data connectors") |
|
|
|
def load_real_data( |
|
self, |
|
countries: List[str], |
|
hs_codes: List[str], |
|
years: List[int], |
|
force_refresh: bool = False, |
|
) -> Dict[str, pd.DataFrame]: |
|
""" |
|
Load real datasets for the specified parameters |
|
|
|
Args: |
|
countries: List of country codes (ISO3 or UN codes) |
|
hs_codes: List of HS product codes |
|
years: List of years to analyze |
|
force_refresh: Force refresh of cached data |
|
""" |
|
|
|
logger.info( |
|
f"Loading real data for {len(countries)} countries, {len(hs_codes)} products, {len(years)} years" |
|
) |
|
|
|
|
|
if ( |
|
force_refresh |
|
or self.last_data_update is None |
|
or datetime.now() - self.last_data_update |
|
> timedelta(hours=self.real_data_config.max_cache_age_hours) |
|
): |
|
|
|
logger.info("Fetching fresh data from external sources...") |
|
|
|
|
|
raw_datasets = self.data_manager.fetch_comprehensive_dataset( |
|
countries=countries, |
|
hs_codes=hs_codes, |
|
years=years, |
|
tariff_query=f"tariff {' '.join(hs_codes)} trade", |
|
) |
|
|
|
|
|
self.real_datasets = self._process_raw_datasets( |
|
raw_datasets, countries, hs_codes, years |
|
) |
|
self.data_quality_metrics = self._assess_data_quality(self.real_datasets) |
|
self.last_data_update = datetime.now() |
|
|
|
logger.info("Real data loading completed") |
|
|
|
return self.real_datasets |
|
|
|
def _process_raw_datasets( |
|
self, |
|
raw_datasets: Dict[str, pd.DataFrame], |
|
countries: List[str], |
|
hs_codes: List[str], |
|
years: List[int], |
|
) -> Dict[str, pd.DataFrame]: |
|
"""Process raw datasets into TIPM-compatible format""" |
|
|
|
processed = {} |
|
|
|
|
|
if "trade_flows" in raw_datasets and not raw_datasets["trade_flows"].empty: |
|
processed["trade_flows"] = self._process_trade_flows( |
|
raw_datasets["trade_flows"], countries, hs_codes, years |
|
) |
|
|
|
|
|
if "tariff_rates" in raw_datasets and not raw_datasets["tariff_rates"].empty: |
|
processed["tariff_shocks"] = self._process_tariff_data( |
|
raw_datasets["tariff_rates"], countries, hs_codes, years |
|
) |
|
|
|
|
|
if ( |
|
"economic_indicators" in raw_datasets |
|
and not raw_datasets["economic_indicators"].empty |
|
): |
|
processed.update( |
|
self._process_economic_indicators( |
|
raw_datasets["economic_indicators"], countries, years |
|
) |
|
) |
|
|
|
|
|
if ( |
|
"news_sentiment" in raw_datasets |
|
and not raw_datasets["news_sentiment"].empty |
|
): |
|
processed["geopolitical_events"] = self._process_sentiment_data( |
|
raw_datasets["news_sentiment"], countries, years |
|
) |
|
|
|
|
|
if ( |
|
"political_events" in raw_datasets |
|
and not raw_datasets["political_events"].empty |
|
): |
|
processed["political_stability"] = self._process_political_events( |
|
raw_datasets["political_events"], countries, years |
|
) |
|
|
|
return processed |
|
|
|
    def _process_trade_flows(
        self,
        df: pd.DataFrame,
        countries: List[str],
        hs_codes: List[str],
        years: List[int],
    ) -> pd.DataFrame:
        """Process UN Comtrade trade flow data into the TIPM schema.

        Maps Comtrade column names onto TIPM names, backfills missing
        required columns with zeros, filters to the requested
        countries/products/years, and returns only the TIPM columns.
        Returns an empty DataFrame on empty input or any processing error.
        """

        if df.empty:
            return pd.DataFrame()

        try:

            processed = df.copy()

            # Comtrade -> TIPM column mapping. Both "primaryValue" and
            # "tradeValue" target "trade_value"; since the dict preserves
            # insertion order, "tradeValue" overwrites "primaryValue" when
            # both columns are present.
            column_mapping = {
                "reporterCode": "origin_country",
                "partnerCode": "destination_country",
                "cmdCode": "hs_code",
                "period": "year",
                "primaryValue": "trade_value",
                "netWeight": "trade_volume",
                "tradeValue": "trade_value",
            }

            # Copy (not rename) so the original columns stay available until
            # the final column selection below.
            for old_col, new_col in column_mapping.items():
                if old_col in processed.columns:
                    processed[new_col] = processed[old_col]

            # Backfill any still-missing required column with 0 so the
            # filtering and final selection below cannot raise KeyError.
            required_cols = [
                "origin_country",
                "destination_country",
                "hs_code",
                "year",
                "trade_value",
            ]
            for col in required_cols:
                if col not in processed.columns:
                    processed[col] = 0

            # Synthetic logistics columns: transport cost assumed at 5% of
            # trade value and a flat 30-day lead time -- placeholders, not
            # sourced from the data. TODO confirm these defaults.
            processed["transport_cost"] = (
                processed["trade_value"] * 0.05
            )
            processed["lead_time"] = 30

            # Restrict rows to the requested scope; str casts make the
            # comparison robust to mixed int/str codes. (The existence check
            # is always true here because of the backfill above.)
            if "origin_country" in processed.columns and "hs_code" in processed.columns:
                processed = processed[
                    (
                        processed["origin_country"]
                        .astype(str)
                        .isin([str(c) for c in countries])
                    )
                    & (
                        processed["hs_code"]
                        .astype(str)
                        .isin([str(h) for h in hs_codes])
                    )
                    & (processed["year"].isin(years))
                ]

            # Drop rows without a trade value, then coerce the remainder to
            # numeric (non-numeric entries become 0).
            processed = processed.dropna(subset=["trade_value"])
            processed["trade_value"] = pd.to_numeric(
                processed["trade_value"], errors="coerce"
            ).fillna(0)

            logger.info(f"Processed {len(processed)} trade flow records")
            return processed[required_cols + ["transport_cost", "lead_time"]]

        except Exception as e:
            # Best-effort: log and return an empty frame rather than
            # propagate irregularities in the source data.
            logger.error(f"Error processing trade flows: {e}")
            return pd.DataFrame()
|
|
|
def _process_tariff_data( |
|
self, |
|
df: pd.DataFrame, |
|
countries: List[str], |
|
hs_codes: List[str], |
|
years: List[int], |
|
) -> pd.DataFrame: |
|
"""Process tariff rate data from WITS""" |
|
|
|
if df.empty: |
|
return pd.DataFrame() |
|
|
|
try: |
|
processed = df.copy() |
|
|
|
|
|
processed["policy_text"] = processed.apply( |
|
lambda row: f"Tariff rate of {row.get('tariff_rate', 0):.1%} imposed on imports from {row.get('country', 'unknown')} for product {row.get('product_code', 'unknown')}", |
|
axis=1, |
|
) |
|
|
|
|
|
processed["effective_date"] = processed["year"].astype(str) + "-01-01" |
|
|
|
|
|
processed["hs_codes"] = processed.get("product_code", "").astype(str) |
|
processed["tariff_rates"] = processed.get("tariff_rate", 0.15) |
|
processed["countries"] = ( |
|
processed.get("country", "") + ",US" |
|
) |
|
|
|
required_cols = [ |
|
"policy_text", |
|
"effective_date", |
|
"hs_codes", |
|
"tariff_rates", |
|
"countries", |
|
] |
|
|
|
logger.info(f"Processed {len(processed)} tariff policy records") |
|
return processed[required_cols] |
|
|
|
except Exception as e: |
|
logger.error(f"Error processing tariff data: {e}") |
|
return pd.DataFrame() |
|
|
|
    def _process_economic_indicators(
        self, df: pd.DataFrame, countries: List[str], years: List[int]
    ) -> Dict[str, pd.DataFrame]:
        """Process World Bank economic indicators into TIPM datasets.

        Splits the input by indicator code and maps three known indicators
        onto TIPM layers:
          - FP.CPI.TOTL     -> "consumer_impacts"
          - SL.UEM.TOTL.ZS  -> "firm_responses"
          - NY.GDP.MKTP.CD  -> "industry_responses"
        Unknown indicator codes are ignored. Returns {} on empty input or
        any processing error.
        """

        if df.empty:
            return {}

        try:
            # One output DataFrame per recognized indicator.
            processed_datasets = {}

            for indicator_code in df["indicator_code"].unique():
                indicator_data = df[df["indicator_code"] == indicator_code].copy()

                if indicator_code == "FP.CPI.TOTL":
                    consumer_data = []
                    for _, row in indicator_data.iterrows():
                        consumer_data.append(
                            {
                                "product_category": "general_goods",
                                # NOTE(review): this divides the CPI *index
                                # level* by 100 and calls it a price change;
                                # an actual change would need a year-over-year
                                # delta -- confirm intent.
                                "price_change": float(row.get("value", 0))
                                / 100,
                                "country": row.get("countryiso3code", "unknown"),
                                # World Bank "date" field carries the year.
                                "year": row.get("date", 2023),
                            }
                        )

                    if consumer_data:
                        processed_datasets["consumer_impacts"] = pd.DataFrame(
                            consumer_data
                        )

                elif indicator_code == "SL.UEM.TOTL.ZS":
                    firm_data = []
                    for _, row in indicator_data.iterrows():
                        # Unemployment arrives as a percentage; convert to a
                        # fraction (default 5% when the value is missing).
                        unemployment_rate = float(row.get("value", 5)) / 100
                        firm_data.append(
                            {
                                # One aggregate pseudo-firm per country.
                                "firm_id": f"aggregate_{row.get('countryiso3code', 'unknown')}",
                                "response_metric": unemployment_rate,
                                "country": row.get("countryiso3code", "unknown"),
                                "year": row.get("date", 2023),
                            }
                        )

                    if firm_data:
                        processed_datasets["firm_responses"] = pd.DataFrame(firm_data)

                elif indicator_code == "NY.GDP.MKTP.CD":
                    industry_data = []
                    for _, row in indicator_data.iterrows():
                        # NOTE(review): NY.GDP.MKTP.CD is GDP in current US$;
                        # dividing by 1e12 yields GDP in trillions, not a
                        # growth rate despite the variable name -- confirm.
                        gdp_growth = float(row.get("value", 0)) / 1e12
                        industry_data.append(
                            {
                                "industry_code": "aggregate",
                                "response_metric": gdp_growth,
                                "country": row.get("countryiso3code", "unknown"),
                                "year": row.get("date", 2023),
                            }
                        )

                    if industry_data:
                        processed_datasets["industry_responses"] = pd.DataFrame(
                            industry_data
                        )

            logger.info(
                f"Processed economic indicators into {len(processed_datasets)} datasets"
            )
            return processed_datasets

        except Exception as e:
            # Best-effort: log and return an empty mapping on failure.
            logger.error(f"Error processing economic indicators: {e}")
            return {}
|
|
|
def _process_sentiment_data( |
|
self, df: pd.DataFrame, countries: List[str], years: List[int] |
|
) -> pd.DataFrame: |
|
"""Process GDELT news sentiment data""" |
|
|
|
if df.empty: |
|
return pd.DataFrame() |
|
|
|
try: |
|
processed = df.copy() |
|
|
|
|
|
processed["date"] = pd.to_datetime(processed["date"], errors="coerce") |
|
processed["year"] = processed["date"].dt.year |
|
processed["month"] = processed["date"].dt.month |
|
|
|
|
|
sentiment_agg = ( |
|
processed.groupby(["country", "year"]) |
|
.agg( |
|
{ |
|
"tone": "mean", |
|
"social_image_shares": "sum", |
|
"title": lambda x: " ".join(x[:5]), |
|
} |
|
) |
|
.reset_index() |
|
) |
|
|
|
|
|
geo_events = [] |
|
for _, row in sentiment_agg.iterrows(): |
|
geo_events.append( |
|
{ |
|
"event_text": f"News coverage: {row['title'][:100]}...", |
|
"sentiment": row["tone"] / 100, |
|
"country": row["country"], |
|
"year": row["year"], |
|
"social_engagement": row["social_image_shares"], |
|
} |
|
) |
|
|
|
result = pd.DataFrame(geo_events) |
|
logger.info(f"Processed {len(result)} geopolitical sentiment records") |
|
return result |
|
|
|
except Exception as e: |
|
logger.error(f"Error processing sentiment data: {e}") |
|
return pd.DataFrame() |
|
|
|
def _process_political_events( |
|
self, df: pd.DataFrame, countries: List[str], years: List[int] |
|
) -> pd.DataFrame: |
|
"""Process ACLED political events data""" |
|
|
|
if df.empty: |
|
return pd.DataFrame() |
|
|
|
try: |
|
processed = df.copy() |
|
|
|
|
|
if "event_date" in processed.columns: |
|
processed["event_date"] = pd.to_datetime( |
|
processed["event_date"], errors="coerce" |
|
) |
|
processed["year"] = processed["event_date"].dt.year |
|
|
|
|
|
event_counts = ( |
|
processed.groupby(["country", "year"]) |
|
.agg({"event_type": "count", "fatalities": "sum"}) |
|
.reset_index() |
|
) |
|
|
|
|
|
stability_data = [] |
|
for _, row in event_counts.iterrows(): |
|
|
|
stability_score = max(0, 1 - (row["event_type"] / 100)) |
|
|
|
stability_data.append( |
|
{ |
|
"country": row["country"], |
|
"year": row["year"], |
|
"political_stability": stability_score, |
|
"protest_events": row["event_type"], |
|
"fatalities": row.get("fatalities", 0), |
|
} |
|
) |
|
|
|
result = pd.DataFrame(stability_data) |
|
logger.info(f"Processed {len(result)} political stability records") |
|
return result |
|
|
|
except Exception as e: |
|
logger.error(f"Error processing political events: {e}") |
|
return pd.DataFrame() |
|
|
|
def _assess_data_quality(self, datasets: Dict[str, pd.DataFrame]) -> Dict[str, Any]: |
|
"""Assess quality of loaded real datasets""" |
|
|
|
quality_metrics = {} |
|
|
|
for name, df in datasets.items(): |
|
if df.empty: |
|
quality_metrics[name] = { |
|
"quality_score": 0.0, |
|
"record_count": 0, |
|
"completeness": 0.0, |
|
"issues": ["No data available"], |
|
} |
|
continue |
|
|
|
|
|
total_cells = df.shape[0] * df.shape[1] |
|
missing_cells = df.isnull().sum().sum() |
|
completeness = ( |
|
(total_cells - missing_cells) / total_cells if total_cells > 0 else 0 |
|
) |
|
|
|
|
|
temporal_coverage = 1.0 |
|
if "year" in df.columns: |
|
expected_years = set(range(2020, 2025)) |
|
actual_years = set(df["year"].unique()) |
|
temporal_coverage = len( |
|
actual_years.intersection(expected_years) |
|
) / len(expected_years) |
|
|
|
|
|
quality_score = (completeness + temporal_coverage) / 2 |
|
|
|
|
|
issues = [] |
|
if completeness < 0.8: |
|
issues.append(f"High missing data rate: {(1-completeness)*100:.1f}%") |
|
if temporal_coverage < 0.5: |
|
issues.append(f"Poor temporal coverage: {temporal_coverage*100:.1f}%") |
|
if len(df) < self.real_data_config.min_data_points: |
|
issues.append(f"Insufficient data points: {len(df)}") |
|
|
|
quality_metrics[name] = { |
|
"quality_score": quality_score, |
|
"record_count": len(df), |
|
"completeness": completeness, |
|
"temporal_coverage": temporal_coverage, |
|
"issues": issues, |
|
} |
|
|
|
return quality_metrics |
|
|
|
def fit_with_real_data( |
|
self, |
|
countries: List[str], |
|
hs_codes: List[str], |
|
years: List[int], |
|
force_refresh: bool = False, |
|
) -> "RealDataTIPMModel": |
|
""" |
|
Train TIPM model using real datasets |
|
|
|
Args: |
|
countries: Country codes to analyze |
|
hs_codes: Product codes to analyze |
|
years: Years of data to use for training |
|
force_refresh: Force refresh of cached data |
|
""" |
|
|
|
logger.info("Training TIPM with real datasets...") |
|
|
|
|
|
real_datasets = self.load_real_data(countries, hs_codes, years, force_refresh) |
|
|
|
|
|
training_data = {} |
|
|
|
for layer_name, dataset_name in [ |
|
("tariff_shocks", "tariff_shocks"), |
|
("trade_flows", "trade_flows"), |
|
("industry_responses", "industry_responses"), |
|
("firm_responses", "firm_responses"), |
|
("consumer_impacts", "consumer_impacts"), |
|
("geopolitical_events", "geopolitical_events"), |
|
]: |
|
|
|
if ( |
|
dataset_name in real_datasets |
|
and not real_datasets[dataset_name].empty |
|
and self.data_quality_metrics.get(dataset_name, {}).get( |
|
"quality_score", 0 |
|
) |
|
> 0.3 |
|
): |
|
|
|
|
|
training_data[layer_name] = real_datasets[dataset_name] |
|
logger.info( |
|
f"Using real data for {layer_name}: {len(real_datasets[dataset_name])} records" |
|
) |
|
|
|
elif self.real_data_config.fallback_to_synthetic: |
|
|
|
training_data[layer_name] = self._generate_fallback_data( |
|
layer_name, countries, hs_codes, years |
|
) |
|
logger.info(f"Using synthetic fallback for {layer_name}") |
|
|
|
else: |
|
|
|
training_data[layer_name] = pd.DataFrame() |
|
logger.warning(f"No data available for {layer_name}") |
|
|
|
|
|
self.fit(training_data) |
|
|
|
|
|
self.model_metadata.update( |
|
{ |
|
"data_sources": list(real_datasets.keys()), |
|
"data_quality": self.data_quality_metrics, |
|
"training_countries": countries, |
|
"training_hs_codes": hs_codes, |
|
"training_years": years, |
|
"last_trained": datetime.now().isoformat(), |
|
} |
|
) |
|
|
|
logger.info("Real data training completed") |
|
return self |
|
|
|
def _generate_fallback_data( |
|
self, |
|
layer_name: str, |
|
countries: List[str], |
|
hs_codes: List[str], |
|
years: List[int], |
|
) -> pd.DataFrame: |
|
"""Generate synthetic fallback data when real data is unavailable""" |
|
|
|
np.random.seed(42) |
|
|
|
if layer_name == "tariff_shocks": |
|
return pd.DataFrame( |
|
{ |
|
"policy_text": [f"Fallback tariff policy for {hs_codes[0]}"] |
|
* len(countries), |
|
"effective_date": ["2023-01-01"] * len(countries), |
|
"hs_codes": [",".join(hs_codes)] * len(countries), |
|
"tariff_rates": np.random.uniform(0.1, 0.3, len(countries)), |
|
"countries": [f"{c},US" for c in countries], |
|
} |
|
) |
|
|
|
elif layer_name == "trade_flows": |
|
n_records = len(countries) * len(hs_codes) * len(years) |
|
return pd.DataFrame( |
|
{ |
|
"hs_code": np.tile(hs_codes, len(countries) * len(years)), |
|
"origin_country": np.repeat(countries, len(hs_codes) * len(years)), |
|
"destination_country": ["US"] * n_records, |
|
"trade_value": np.random.lognormal( |
|
15, 1, n_records |
|
), |
|
"year": np.tile(np.repeat(years, len(hs_codes)), len(countries)), |
|
"transport_cost": np.random.uniform(10000, 100000, n_records), |
|
"lead_time": np.random.uniform(20, 60, n_records), |
|
} |
|
) |
|
|
|
elif layer_name == "industry_responses": |
|
return pd.DataFrame( |
|
{ |
|
"industry_code": ["electronics", "manufacturing", "services"], |
|
"response_metric": np.random.normal(0.1, 0.05, 3), |
|
} |
|
) |
|
|
|
elif layer_name == "firm_responses": |
|
return pd.DataFrame( |
|
{ |
|
"firm_id": [f"firm_{i}" for i in range(5)], |
|
"response_metric": np.random.normal(0.1, 0.05, 5), |
|
} |
|
) |
|
|
|
elif layer_name == "consumer_impacts": |
|
return pd.DataFrame( |
|
{ |
|
"product_category": ["electronics", "general_goods"], |
|
"price_change": np.random.uniform(0.02, 0.08, 2), |
|
} |
|
) |
|
|
|
elif layer_name == "geopolitical_events": |
|
return pd.DataFrame( |
|
{"event_text": ["Fallback geopolitical event"], "sentiment": [0.0]} |
|
) |
|
|
|
return pd.DataFrame() |
|
|
|
def get_data_provenance(self) -> Dict[str, Any]: |
|
"""Get information about data sources and quality""" |
|
|
|
return { |
|
"last_update": ( |
|
self.last_data_update.isoformat() if self.last_data_update else None |
|
), |
|
"datasets_available": list(self.real_datasets.keys()), |
|
"data_quality_metrics": self.data_quality_metrics, |
|
"model_metadata": self.model_metadata, |
|
"configuration": { |
|
"cache_directory": self.real_data_config.data_cache_dir, |
|
"fallback_enabled": self.real_data_config.fallback_to_synthetic, |
|
"quality_thresholds": { |
|
"min_data_points": self.real_data_config.min_data_points, |
|
"min_trade_coverage": self.real_data_config.min_trade_coverage, |
|
"max_missing_data_ratio": self.real_data_config.max_missing_data_ratio, |
|
}, |
|
}, |
|
} |
|
|
|
|
|
|
|
if __name__ == "__main__":
    # BUG FIX: json.dumps was called below but `json` was never imported
    # anywhere in this module, so the demo crashed with NameError. The
    # import is local to this guarded block since only the demo needs it.
    import json

    real_model = RealDataTIPMModel()

    # Demo run: US (840) and China (156); HS 8517/8471 telecom & computing.
    real_model.fit_with_real_data(
        countries=["840", "156"],
        hs_codes=["8517", "8471"],
        years=[2022, 2023, 2024],
    )

    # Print where the data came from and how good it was.
    provenance = real_model.get_data_provenance()
    print("Data Provenance:")
    print(json.dumps(provenance, indent=2))
|
|