import torch import platform class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def __call__(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def getPlatform(): plt = platform.system() if plt=='Darwin': return 'mac' return plt def hasGPU(plt:str): if plt == 'mac': return torch.backends.mps.is_available() return torch.cuda.is_available() def getDevice(plt:str): if plt == 'mac': return torch.device('mps') return torch.device('cuda') def disableWarnings(): import warnings warnings.filterwarnings("ignore", category=UserWarning, module="transformers.utils.generic") warnings.filterwarnings("ignore", category=UserWarning, module="trl.trainer.ppo_config") warnings.filterwarnings("ignore", message="torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly")