File size: 1,121 Bytes
67b1c6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from src.model import GraphformerModel
from pathlib import Path
from loguru import logger


class ModelTrainer:
    """Thin orchestration wrapper around ``GraphformerModel``.

    Construction ensures the output directory exists and builds the
    underlying model. ``train_and_evaluate`` then delegates to the model's
    ``run_model`` method, adding start/finish/error logging around it.
    """

    def __init__(self, df, output_path, epochs: int = 100, test_size: float = 0.3):
        """Create the output directory and initialize the GraphformerModel.

        Args:
            df: Input data, forwarded unchanged to ``GraphformerModel``
                (presumably a pandas DataFrame — confirm against the model).
            output_path: Directory for model outputs. It is created (with
                parents) if missing, and the original value is forwarded to
                the model.
            epochs: Forwarded to ``GraphformerModel`` as ``epochs``.
            test_size: Forwarded to ``GraphformerModel`` as ``test_size``
                (assumed to be a held-out fraction in (0, 1) — TODO confirm).
        """
        self.df = df
        self.output_path = output_path
        self.epochs = epochs
        self.test_size = test_size

        # Create the output directory up front so the model can write into it.
        Path(self.output_path).mkdir(parents=True, exist_ok=True)

        # Initialize the GraphformerModel with the same configuration.
        self.model = GraphformerModel(
            df=self.df,
            output_path=self.output_path,
            epochs=self.epochs,
            test_size=self.test_size,
        )

        logger.info(f"Initialized ModelTrainer with output_path: {self.output_path} and epochs: {self.epochs}")

    def train_and_evaluate(self):
        """Run the model's training/evaluation pipeline via ``run_model``.

        Raises:
            Exception: Any exception raised by ``GraphformerModel.run_model``
                is logged (with traceback) and re-raised unchanged.
        """
        try:
            logger.info("Starting model training and evaluation")
            self.model.run_model()
            logger.info("GraphformerModel training and evaluation completed successfully")
        except Exception as e:
            # logger.exception logs at ERROR level *and* attaches the
            # traceback; logger.error would record only str(e).
            logger.exception(f"Error during GraphformerModel training and evaluation: {e}")
            raise