import random
import pandas as pd
import numpy as np
from tqdm import tqdm
from copy import copy, deepcopy
from collections import Counter

import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup, get_linear_schedule_with_warmup, logging
from transformers.modeling_utils import PreTrainedModel

from .match_groups import MatchGroups
from .scoring import score_predicted
from .scoring_model import SimilarityScore
from .embeddings import Embeddings
from .embedding_model import EmbeddingModel
from .configuration import SimilarityModelConfig

logging.set_verbosity_error()


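# Maps raw string counts to weights via counts**exponent (default 0.5): a string
# observed 100 times gets weight 100**0.5 = 10, so frequent strings are damped
# relative to their raw counts.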
class ExponentWeights:
    def __init__(self, config, **kwargs):
        self.exponent = config.get("weighting_exponent", 0.5)

    def __call__(self, counts):
        return counts**self.exponent


class SimilarityModel(PreTrainedModel):
    """
    A combined embedding/scorer model that produces Embeddings objects
    as its primary output.

    - train() jointly optimizes the embedding_model and score_model using
      contrastive learning to learn from a training MatchGroups.
    """
    config_class = SimilarityModelConfig

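    # Illustrative usage sketch (constructor arguments omitted; see .configuration):
    #
    #     config = SimilarityModelConfig(...)
    #     model = SimilarityModel(config)
    #     model.train(training_groupings, max_epochs=1)   # training_groupings: a MatchGroups
    #     embeddings = model.embed(['ACME Inc', 'ACME Incorporated'])
    #     groups = model.unite_similar(['ACME Inc', 'ACME Incorporated'], threshold=0.5)
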
    def __init__(self, config, **kwargs):
        super().__init__(config)

        self.embedding_model = EmbeddingModel(config.embedding_model_config, **kwargs)
        self.score_model = SimilarityScore(config.score_model_config, **kwargs)
        self.weighting_function = ExponentWeights(config.weighting_function_config, **kwargs)

        self.config = config
        self.to(config.device)

    def to(self, device):
        super().to(device)
        self.embedding_model.to(device)
        self.score_model.to(device)
        return self

    def save(self, savefile):
        torch.save({'metadata': self.config, 'state_dict': self.state_dict()}, savefile)

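    # Note: save() writes {'metadata': config, 'state_dict': ...}; the module-level
    # load_similarity_model() below rebuilds a SimilarityModel from such a file.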
    @torch.no_grad()
    def embed(self, input, to=None, batch_size=64, progress_bar=True, **kwargs):
        """
        Construct an Embeddings object from input strings or a MatchGroups
        """

        if to is None:
            to = self.device

        if isinstance(input, MatchGroups):
            strings = input.strings()
            counts = torch.tensor([input.counts[s] for s in strings], device=self.device).float().to(to)
        else:
            strings = list(input)
            counts = torch.ones(len(strings), device=self.device).float().to(to)

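        # Embed the strings in batches; V is allocated after the first batch,
        # once the embedding dimension is known.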
        input_loader = DataLoader(strings, batch_size=batch_size, num_workers=0)

        self.embedding_model.eval()

        V = None
        batch_start = 0
        with tqdm(total=len(strings), delay=1, desc='Embedding strings', disable=not progress_bar) as pbar:
            for batch_strings in input_loader:

                v = self.embedding_model(batch_strings).detach().to(to)

                if V is None:
                    V = torch.empty(len(strings), v.shape[1], device=to, dtype=v.dtype)

                V[batch_start:batch_start+len(batch_strings), :] = v

                pbar.update(len(batch_strings))
                batch_start += len(batch_strings)

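        # Attach a copy of the score model and weighting function to the returned
        # Embeddings, moved to the requested device.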
        score_model = copy(self.score_model)
        score_model.load_state_dict(self.score_model.state_dict())
        score_model.to(to)

        weighting_function = deepcopy(self.weighting_function)

        return Embeddings(strings=strings,
                          V=V.detach(),
                          counts=counts.detach(),
                          score_model=score_model,
                          weighting_function=weighting_function,
                          device=to)

    def train(self, training_groupings, max_epochs=1, batch_size=8,
              score_decay=0, regularization=0,
              transformer_lr=1e-5, projection_lr=1e-5, score_lr=10, warmup_frac=0.1,
              max_grad_norm=1, dropout=False,
              validation_groupings=None, target='F1', restore_best=True, val_seed=None,
              validation_interval=1000, early_stopping=True, early_stopping_patience=3,
              verbose=False, progress_bar=True,
              **kwargs):
        """
        Train the embedding_model and score_model to predict match probabilities,
        using the training_groupings as a source of "correct" matches.
        The training algorithm uses contrastive learning with hard-positive
        and hard-negative mining to fine-tune the embedding model to place
        matched strings near each other in embedding space, while
        simultaneously calibrating the score_model to predict match
        probabilities as a function of cosine distance.
        """

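        # Early stopping and best-state restoration require a validation set.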
        if validation_groupings is None:
            early_stopping = False
            restore_best = False

        num_training_steps = max_epochs*len(training_groupings)//batch_size
        num_warmup_steps = int(warmup_frac*num_training_steps)

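        # Separate optimizers and schedulers: a cosine schedule for the embedding
        # model and a linear schedule for the score model.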
        if transformer_lr or projection_lr:
            embedding_optimizer = self.embedding_model.config_optimizer(transformer_lr, projection_lr)
            embedding_scheduler = get_cosine_schedule_with_warmup(
                embedding_optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps)
        if score_lr:
            score_optimizer = self.score_model.config_optimizer(score_lr)
            score_scheduler = get_linear_schedule_with_warmup(
                score_optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps)

        step = 0
        self.history = []
        self.validation_scores = []
        for epoch in range(max_epochs):

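            # Re-embed the full training set at the start of each epoch; the cached
            # vectors in V are refreshed in place as batches are re-embedded below.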
            global_embeddings = self.embed(training_groupings)

            strings = global_embeddings.strings
            V = global_embeddings.V
            w = global_embeddings.w

            groups = torch.tensor([global_embeddings.string_map[training_groupings[s]] for s in strings], device=self.device)

            if w is not None:
                w = w/w.mean()

            shuffled_ids = list(range(len(strings)))
            random.shuffle(shuffled_ids)

            if dropout:
                self.embedding_model.train()
            else:
                self.embedding_model.eval()

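            # Each step: re-embed the batch, update the score model against all
            # strings, then update the embedding model on mined hard examples.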
            for batch_start in tqdm(range(0, len(strings), batch_size), desc=f'training epoch {epoch}', disable=not progress_bar):

                h = {'epoch': epoch, 'step': step}

                batch_i = shuffled_ids[batch_start:batch_start+batch_size]

                if len(batch_i) < batch_size:
                    batch_i = batch_i + shuffled_ids[:(batch_size-len(batch_i))]

""" |
|
Find highest loss match for each batch string (global search) |
|
|
|
Note: If we compute V_i with dropout enabled, it will add noise |
|
to the embeddings and prevent the same pairs from being selected |
|
every time. |
|
""" |
|
V_i = self.embedding_model(strings[batch_i]) |
|
|
|
|
|
V[batch_i,:] = V_i.detach() |
|
|
|
                with torch.no_grad():

                    global_X = V_i@V.T
                    global_Y = (groups[batch_i][:, None] == groups[None, :]).float()

                    if w is not None:
                        global_W = torch.outer(w[batch_i], w)
                    else:
                        global_W = None

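                # Score-model update: global_X and global_Y were computed without
                # gradients, so only the score model's parameters are trained here.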
                if score_lr:

                    self.score_model.requires_grad_(True)

                    global_loss = self.score_model.loss(global_X, global_Y, weights=global_W, decay=score_decay)

                    score_optimizer.zero_grad()
                    global_loss.nanmean().backward()
                    torch.nn.utils.clip_grad_norm_(self.score_model.parameters(), max_norm=max_grad_norm)

                    score_optimizer.step()
                    score_scheduler.step()

                    h['score_lr'] = score_optimizer.param_groups[0]['lr']
                    h['global_mean_cos'] = global_X.mean().item()
                    try:
                        h['score_alpha'] = self.score_model.alpha.item()
                    except AttributeError:
                        pass

                else:
                    with torch.no_grad():
                        global_loss = self.score_model.loss(global_X, global_Y)

                h['global_loss'] = global_loss.detach().nanmean().item()

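                # Embedding-model update: freeze the score model and fine-tune the
                # embedding model on hard pairs mined from the global similarities.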
                if (transformer_lr or projection_lr) and step <= num_warmup_steps + num_training_steps:

                    self.score_model.requires_grad_(False)

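                    # Hard-example mining: for each batch string, pick the string
                    # whose pair currently has the highest score-model loss.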
                    with torch.no_grad():
                        batch_j = global_loss.argmax(dim=1).flatten()

                        if w is not None:
                            batch_W = torch.outer(w[batch_i], w[batch_j])
                        else:
                            batch_W = None

                    V_j = self.embedding_model(strings[batch_j.tolist()])

                    V[batch_j, :] = V_j.detach()

                    batch_X = V_i@V_j.T
                    batch_Y = (groups[batch_i][:, None] == groups[batch_j][None, :]).float()
                    h['batch_obs'] = len(batch_i)*len(batch_j)

                    batch_loss = self.score_model.loss(batch_X, batch_Y, weights=batch_W)

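                    # Optional spread-out regularization on non-matching pairs within
                    # the batch: penalizes the mean and the excess second moment
                    # (above 1/d) of their pairwise inner products.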
                    if regularization:

                        gor_Y = (groups[batch_i][:, None] != groups[batch_i][None, :]).float()
                        gor_n = gor_Y.sum()
                        if gor_n > 1:
                            gor_X = (V_i@V_i.T)*gor_Y
                            gor_m1 = 0.5*gor_X.sum()/gor_n
                            gor_m2 = 0.5*(gor_X**2).sum()/gor_n
                            batch_loss += regularization*(gor_m1 + torch.clamp(gor_m2 - 1/self.embedding_model.d, min=0))

                    h['batch_nan'] = torch.isnan(batch_loss.detach()).sum().item()

                    embedding_optimizer.zero_grad()
                    batch_loss.nanmean().backward()

                    torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=max_grad_norm)

                    embedding_optimizer.step()
                    embedding_scheduler.step()

                    h['transformer_lr'] = embedding_optimizer.param_groups[1]['lr']
                    h['projection_lr'] = embedding_optimizer.param_groups[-1]['lr']

                    h['batch_loss'] = batch_loss.detach().mean().item()
                    h['batch_pos_target'] = batch_Y.detach().mean().item()

                self.history.append(h)
                step += 1

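                # Periodic validation: score the current model on the validation
                # groupings and track the best state seen so far.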
                if (validation_groupings is not None) and not (step % validation_interval):

                    validation = len(self.validation_scores)
                    val_scores = self.test(validation_groupings)
                    val_scores['step'] = step - 1
                    val_scores['epoch'] = epoch
                    val_scores['validation'] = validation

                    self.validation_scores.append(val_scores)

                    if verbose:
                        print(f'\nValidation results at step {step} (current epoch {epoch})')
                        for k, v in val_scores.items():
                            print(f' {k}: {v:.4f}')

                        print(list(self.score_model.named_parameters()))

                    if restore_best:
                        if val_scores[target] >= max(s[target] for s in self.validation_scores):
                            best_state = deepcopy({
                                'state_dict': self.state_dict(),
                                'val_scores': val_scores
                            })

                    if early_stopping and (validation - best_state['val_scores']['validation'] > early_stopping_patience):
                        print(f'Stopping training ({early_stopping_patience} validation checks since best validation score)')
                        break

        if restore_best:
            print(f"Restoring to best state (step {best_state['val_scores']['step']}):")
            for k, v in best_state['val_scores'].items():
                print(f' {k}: {v:.4f}')

            self.to('cpu')
            self.load_state_dict(best_state['state_dict'])
            self.to(self.device)

        return pd.DataFrame(self.history)

    def unite_similar(self, input, **kwargs):
        embeddings = self.embed(input, **kwargs)
        return embeddings.unite_similar(**kwargs)

    def test(self, gold_groupings, threshold=0.5, **kwargs):
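        """
        Evaluate predicted groupings against gold_groupings. If threshold is a
        float, return a single dict of scores; if it is an iterable of thresholds,
        return a list of score dicts, one per threshold.
        """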
        embeddings = self.embed(gold_groupings, **kwargs)

        if isinstance(threshold, float):
            predicted = embeddings.unite_similar(threshold=threshold, **kwargs)
            scores = score_predicted(predicted, gold_groupings, use_counts=True)

            return scores

        results = []
        for thres in threshold:
            predicted = embeddings.unite_similar(threshold=thres, **kwargs)

            scores = score_predicted(predicted, gold_groupings, use_counts=True)
            scores['threshold'] = thres
            results.append(scores)

        return results


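# Illustrative round-trip with SimilarityModel.save() (the path below is hypothetical):
#
#     model.save('similarity_model.pt')
#     model = load_similarity_model('similarity_model.pt')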
def load_similarity_model(f, map_location='cpu', *args, **kwargs):
    checkpoint = torch.load(f, map_location=map_location, **kwargs)
    metadata = checkpoint['metadata']
    state_dict = checkpoint['state_dict']

    model = SimilarityModel(config=metadata)
    model.load_state_dict(state_dict)

    return model