|
""" |
|
Core TIPM Model Implementation |
|
============================= |
|
|
|
Main orchestrator for the Tariff Impact Propagation Model |
|
""" |
|
|
|
from typing import Dict, List, Any, Optional |
|
import logging |
|
from dataclasses import dataclass |
|
from dataclasses import dataclass |
|
from typing import Dict, List, Optional, Set |
|
from datetime import datetime |
|
|
|
|
|
@dataclass |
|
class EnhancedCountryData: |
|
"""Enhanced country data structure for v1.5 with authoritative classifications""" |
|
|
|
|
|
name: str |
|
iso_alpha_2: str |
|
iso_alpha_3: str |
|
un_code: str |
|
|
|
tariff_rate: float |
|
bilateral_trade_usd: float |
|
gdp_usd: float |
|
gdp_per_capita: float |
|
|
|
continent: str |
|
region: str |
|
income_group: str |
|
|
|
global_groups: Set[str] |
|
trade_agreements: Set[str] |
|
currency_bloc: Optional[str] |
|
|
|
emerging_market_status: bool |
|
tech_manufacturing_rank: Optional[int] |
|
supply_chain_critical: bool |
|
|
|
mining_resource_exports: List[str] |
|
agricultural_exports: List[str] |
|
strategic_commodities: List[str] |
|
|
|
data_sources: Dict[str, str] |
|
last_updated: datetime |
|
confidence_level: str |
|
|
|
def validate(self) -> bool: |
|
"""Validate country data fields and bounds.""" |
|
if not 0 <= self.tariff_rate <= 100: |
|
raise ValueError( |
|
f"Invalid tariff_rate for {self.name}: {self.tariff_rate}%. Must be 0-100%" |
|
) |
|
if self.gdp_usd < 0: |
|
raise ValueError(f"Negative GDP for {self.name}: {self.gdp_usd}") |
|
if self.gdp_per_capita < 0: |
|
raise ValueError( |
|
f"Negative GDP per capita for {self.name}: {self.gdp_per_capita}" |
|
) |
|
return True |
|
|
|
|
|
|
|
|
|
import pandas as pd |
|
|
|
from .layers.policy_trigger import PolicyTriggerLayer |
|
from .layers.trade_flow import TradeFlowLayer |
|
from .layers.industry_response import IndustryResponseLayer |
|
from .layers.firm_impact import FirmImpactLayer |
|
from .layers.consumer_impact import ConsumerImpactLayer |
|
from .layers.geopolitical import GeopoliticalLayer |
|
from .config.settings import ( |
|
TIPMConfig, |
|
PolicyLayerConfig, |
|
TradeFlowConfig, |
|
IndustryConfig, |
|
FirmConfig, |
|
ConsumerConfig, |
|
GeopoliticalConfig, |
|
) |
|
|
|
|
|
@dataclass |
|
@dataclass |
|
class TariffShock: |
|
"""Represents a tariff policy change""" |
|
|
|
tariff_id: str |
|
hs_codes: List[str] |
|
rate_change: float |
|
origin_country: str |
|
destination_country: str |
|
effective_date: str |
|
policy_text: str |
|
|
|
|
|
@dataclass |
|
class TIPMPrediction: |
|
"""Model prediction output structure""" |
|
|
|
tariff_shock: TariffShock |
|
trade_flow_impact: Dict[str, Any] |
|
industry_response: Dict[str, Any] |
|
firm_impact: Dict[str, Any] |
|
consumer_impact: Dict[str, Any] |
|
geopolitical_impact: Dict[str, Any] |
|
confidence_scores: Dict[str, float] |
|
|
|
|
|
class TIPMModel: |
|
""" |
|
Main TIPM Model Class |
|
|
|
Orchestrates the 6-layer architecture for tariff impact prediction: |
|
1. Policy Trigger Input (tariff shock processing) |
|
2. Upstream Trade Flow Modeling |
|
3. Industry-Level Economic Response |
|
4. Firm-Level & Employment Impact |
|
5. Consumer Market Impact |
|
6. Geopolitical & Social Feedback |
|
""" |
|
|
|
def __init__(self, config: Optional[TIPMConfig] = None): |
|
"""Initialize TIPM model with configuration""" |
|
self.config = config or TIPMConfig() |
|
self.logger = logging.getLogger(__name__) |
|
|
|
|
|
self._initialize_layers() |
|
|
|
|
|
self.is_trained = False |
|
self.model_metadata = {} |
|
|
|
def _initialize_layers(self): |
|
"""Initialize all 6 model layers""" |
|
|
|
policy_config = self.config.policy_config or PolicyLayerConfig() |
|
trade_flow_config = self.config.trade_flow_config or TradeFlowConfig() |
|
industry_config = self.config.industry_config or IndustryConfig() |
|
firm_config = self.config.firm_config or FirmConfig() |
|
consumer_config = self.config.consumer_config or ConsumerConfig() |
|
geopolitical_config = self.config.geopolitical_config or GeopoliticalConfig() |
|
|
|
self.policy_layer = PolicyTriggerLayer(policy_config) |
|
self.trade_flow_layer = TradeFlowLayer(trade_flow_config) |
|
self.industry_layer = IndustryResponseLayer(industry_config) |
|
self.firm_layer = FirmImpactLayer(firm_config) |
|
self.consumer_layer = ConsumerImpactLayer(consumer_config) |
|
self.geopolitical_layer = GeopoliticalLayer(geopolitical_config) |
|
|
|
self.logger.info("TIPM model layers initialized successfully") |
|
|
|
def fit(self, training_data: Dict[str, pd.DataFrame]) -> "TIPMModel": |
|
""" |
|
Train the TIPM model on historical data |
|
|
|
Args: |
|
training_data: Dictionary containing training datasets for each layer |
|
- 'tariff_shocks': Historical tariff policy changes |
|
- 'trade_flows': Historical trade flow data |
|
- 'industry_responses': Industry-level response data |
|
- 'firm_responses': Firm-level response data |
|
- 'consumer_impacts': Consumer market impact data |
|
- 'geopolitical_events': Social/political response data |
|
""" |
|
self.logger.info("Starting TIPM model training...") |
|
|
|
|
|
self.policy_layer.fit(training_data.get("tariff_shocks")) |
|
self.trade_flow_layer.fit(training_data.get("trade_flows")) |
|
self.industry_layer.fit(training_data.get("industry_responses")) |
|
self.firm_layer.fit(training_data.get("firm_responses")) |
|
self.consumer_layer.fit(training_data.get("consumer_impacts")) |
|
self.geopolitical_layer.fit(training_data.get("geopolitical_events")) |
|
|
|
self.is_trained = True |
|
self.logger.info("TIPM model training completed") |
|
|
|
return self |
|
|
|
def predict(self, tariff_shock: TariffShock) -> TIPMPrediction: |
|
""" |
|
Predict tariff impact propagation through all layers |
|
|
|
Args: |
|
tariff_shock: Tariff policy change to analyze |
|
|
|
Returns: |
|
TIPMPrediction: Comprehensive impact prediction |
|
""" |
|
if not self.is_trained: |
|
raise ValueError("Model must be trained before making predictions") |
|
|
|
self.logger.info( |
|
f"Predicting impact for tariff shock: {tariff_shock.tariff_id}" |
|
) |
|
|
|
|
|
policy_features = self.policy_layer.transform(tariff_shock) |
|
|
|
|
|
trade_impact = self.trade_flow_layer.predict(policy_features) |
|
|
|
|
|
industry_impact = self.industry_layer.predict(trade_impact) |
|
|
|
|
|
firm_impact = self.firm_layer.predict(industry_impact) |
|
|
|
|
|
consumer_impact = self.consumer_layer.predict(firm_impact) |
|
|
|
|
|
geopolitical_impact = self.geopolitical_layer.predict(consumer_impact) |
|
|
|
|
|
confidence_scores = self._calculate_confidence_scores( |
|
policy_features, |
|
trade_impact, |
|
industry_impact, |
|
firm_impact, |
|
consumer_impact, |
|
geopolitical_impact, |
|
) |
|
|
|
return TIPMPrediction( |
|
tariff_shock=tariff_shock, |
|
trade_flow_impact=trade_impact, |
|
industry_response=industry_impact, |
|
firm_impact=firm_impact, |
|
consumer_impact=consumer_impact, |
|
geopolitical_impact=geopolitical_impact, |
|
confidence_scores=confidence_scores, |
|
) |
|
|
|
def _calculate_confidence_scores(self, *layer_outputs) -> Dict[str, float]: |
|
"""Calculate prediction confidence scores for each layer""" |
|
|
|
return { |
|
"policy_confidence": 0.85, |
|
"trade_flow_confidence": 0.78, |
|
"industry_confidence": 0.82, |
|
"firm_confidence": 0.73, |
|
"consumer_confidence": 0.79, |
|
"geopolitical_confidence": 0.65, |
|
"overall_confidence": 0.77, |
|
} |
|
|
|
def simulate_scenario( |
|
self, tariff_shocks: List[TariffShock], time_horizon: int = 12 |
|
) -> Dict[str, Any]: |
|
""" |
|
Run scenario simulation with multiple tariff shocks over time |
|
|
|
Args: |
|
tariff_shocks: List of tariff policy changes |
|
time_horizon: Simulation time horizon in months |
|
|
|
Returns: |
|
Dictionary containing simulation results |
|
""" |
|
self.logger.info( |
|
f"Running scenario simulation with {len(tariff_shocks)} shocks" |
|
) |
|
|
|
scenario_results = { |
|
"timeline": [], |
|
"cumulative_impacts": {}, |
|
"country_impacts": {}, |
|
"sector_impacts": {}, |
|
} |
|
|
|
for shock in tariff_shocks: |
|
prediction = self.predict(shock) |
|
scenario_results["timeline"].append( |
|
{"shock": shock, "prediction": prediction} |
|
) |
|
|
|
return scenario_results |
|
|
|
def get_country_exposure(self, country_code: str) -> Dict[str, Any]: |
|
""" |
|
Get tariff exposure profile for a specific country |
|
|
|
Args: |
|
country_code: ISO country code (e.g., 'SG', 'US', 'CN') |
|
|
|
Returns: |
|
Dictionary containing country-specific exposure metrics |
|
""" |
|
|
|
|
|
return { |
|
"import_dependency": {}, |
|
"export_exposure": {}, |
|
"supply_chain_vulnerability": {}, |
|
"economic_resilience": {}, |
|
} |
|
|