|
from src.model import GraphformerModel |
|
from pathlib import Path |
|
from loguru import logger |
|
|
|
|
|
class ModelTrainer:
    """Thin orchestration wrapper around ``GraphformerModel``.

    On construction it ensures the output directory exists and builds a
    ``GraphformerModel`` configured with the given data and hyper-parameters;
    ``train_and_evaluate`` then runs that model with start/finish/error logging.
    """

    def __init__(self, df, output_path, epochs: int = 100, test_size: float = 0.3) -> None:
        """Prepare the output directory and construct the underlying model.

        Args:
            df: Input data, passed unchanged to ``GraphformerModel`` as ``df``
                (presumably a pandas DataFrame -- TODO confirm against
                ``src.model``).
            output_path: Directory for model outputs. It is created, including
                missing parents, if it does not already exist; an existing
                directory is left as-is.
            epochs: Number of training epochs, forwarded to the model.
            test_size: Test-split size, forwarded to the model (presumably a
                fraction of ``df`` -- verify against ``GraphformerModel``).
        """
        self.df = df
        self.output_path = output_path
        self.epochs = epochs
        self.test_size = test_size

        # Create the directory before building the model so the model can write
        # into it from its own constructor if it needs to.
        Path(self.output_path).mkdir(parents=True, exist_ok=True)

        self.model = GraphformerModel(
            df=self.df,
            output_path=self.output_path,
            epochs=self.epochs,
            test_size=self.test_size,
        )

        logger.info(
            "Initialized ModelTrainer with output_path: {}, epochs: {} and test_size: {}",
            self.output_path,
            self.epochs,
            self.test_size,
        )

    def train_and_evaluate(self):
        """Run the wrapped model's training/evaluation cycle.

        Returns:
            Whatever ``GraphformerModel.run_model()`` returns. Callers that
            ignore the return value are unaffected.

        Raises:
            Exception: Any exception raised by ``run_model`` is logged together
                with its full traceback and then re-raised unchanged.
        """
        logger.info("Starting model training and evaluation")
        try:
            result = self.model.run_model()
        except Exception:
            # logger.exception logs at ERROR level *and* attaches the traceback,
            # which a plain logger.error(...) with the message alone would lose.
            # The bare `raise` preserves the original exception for callers.
            logger.exception("Error during GraphformerModel training and evaluation")
            raise
        logger.info("GraphformerModel training and evaluation completed successfully")
        return result
|
|
|
|
|
|