import copy
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn.functional import cross_entropy, binary_cross_entropy
from tqdm.auto import tqdm

from utils import Config, extract_spans, generate_targets
from representation import TransformerRepresentation
from layers import SpanEnumerationLayer

DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _prefix_pos_tags(sentences, pos):
    """Return copies of ``sentences`` with each word prefixed by its POS marker.

    Word ``i`` becomes ``'[<pos_i>] <word_i>'``. Words past the end of the
    matching POS list are kept unchanged: a plain ``zip`` would silently drop
    them, which loses tokens appended after the sentence (e.g. the span copies
    added by ``EntNet.forward``).
    """
    return [
        [f'[{tag}] {word}' for word, tag in zip(words, tags)] + list(words[len(tags):])
        for words, tags in zip(sentences, pos)
    ]


class SpanNet(nn.Module):
    """Span detector on top of a transformer encoder.

    ``repr_type`` selects the output head:

    * ``'token_classification'``: each word is scored over the tags
      ``['B', 'I', 'O']`` and trained with cross-entropy;
    * ``'span_enumeration'``: ``SpanEnumerationLayer`` produces candidate spans,
      each gets one logit, trained with binary cross-entropy on its sigmoid.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.config = Config()
        self.config.pos = kwargs.get('pos', None)  # POS tag vocabulary; each tag is added as a special token
        self.config.dp = kwargs.get('dp', 0.3)  # dropout probability of the output head
        self.config.transformer_model_name = kwargs.get('transformer_model_name', 'bert-base-uncased')
        self.config.token_pooling = kwargs.get('token_pooling', 'sum')
        self.device = kwargs.get('device', DEFAULT_DEVICE)
        self.config.repr_type = kwargs.get('repr_type', 'token_classification')
        assert self.config.repr_type in ['token_classification', 'span_enumeration'], \
            f'Invalid representation type: {self.config.repr_type!r}'

        self.transformer = TransformerRepresentation(
            model_name=self.config.transformer_model_name, device=self.device).to(self.device)
        self.transformer_dim = self.transformer.embedding_dim
        if self.config.pos:
            self.transformer.add_special_tokens([f'[{p}]' for p in self.config.pos])

        self.span_tags = ['B', 'I', 'O']
        self.enumeration_layer = SpanEnumerationLayer()
        # token classification scores every span tag; span enumeration emits one logit per span
        output_size = {'token_classification': len(self.span_tags), 'span_enumeration': 1}
        self.span_output_layer = nn.Sequential(
            nn.Linear(self.transformer_dim, self.transformer_dim),
            nn.ReLU(),
            nn.Dropout(p=self.config.dp),
            nn.Linear(self.transformer_dim, output_size[self.config.repr_type]))

    def to_dict(self):
        """Return a dict holding the model config and its ``state_dict``."""
        return {
            'model_config': self.config.__dict__,
            'model_state_dict': self.state_dict()
        }

    @classmethod
    def load_model(cls, model_path, device=DEFAULT_DEVICE):
        """Rebuild a model written by :meth:`save_model` and switch it to eval mode.

        ``device`` is used both to map the stored tensors and to construct the
        model, so the returned model actually lives on ``device``.
        """
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        res = torch.load(model_path, device)
        model = cls(device=device, **res['model_config'])
        model.load_state_dict(res['model_state_dict'], strict=False)
        model.eval()
        return model

    @classmethod
    def preds_to_sequences(cls, predictions, enumerations, length, threshold=0.5):
        """Decode one sentence's span-enumeration scores into a B/I/O tag list.

        Spans are visited from highest to lowest score. A span is kept when its
        probability (sigmoid of its score) exceeds ``threshold`` and it does not
        overlap a span kept earlier. Kept spans are tagged ``'B'`` on their first
        word and ``'I'`` on the remaining words; every other word is ``'O'``.

        NOTE(review): the body of this method was corrupted in the source
        (the text between a ``<`` and a later ``>`` was lost, leaving a syntax
        error). This is a reconstruction of the evident intent -- sort by score,
        drop clashing spans, tag the survivors. The probability ``threshold`` is
        an assumption; confirm against the intended decoder. Pass
        ``threshold=None`` to keep every non-clashing span.

        Args:
            predictions: one score (logit) per candidate span, aligned with
                ``enumerations`` (``forward`` passes 1-element tensors).
            enumerations: candidate spans as inclusive ``(start, end)`` word
                indices.
            length: number of words in the sentence.
            threshold: minimum probability for a span to be kept, or ``None``.

        Returns:
            A list of ``length`` tags drawn from ``'B'``, ``'I'`` and ``'O'``.
        """
        scored = [
            (torch.sigmoid(torch.as_tensor(predictions[idx], dtype=torch.float)).item(), enumerations[idx])
            for idx in range(len(enumerations))
        ]
        # stable sort: equal scores keep their enumeration order (no span is lost on ties)
        scored.sort(key=lambda item: item[0], reverse=True)

        tagged_seq = ['O'] * length
        for prob, span in scored:
            if threshold is not None and prob <= threshold:
                break  # sorted descending, so no later span can pass either
            # int() also accepts 0-d tensors and keeps list arithmetic below on Python ints
            start, end = int(span[0]), int(span[1])
            if start < 0 or end < start or end >= length:
                continue  # span does not fit inside the sentence
            if any(tag != 'O' for tag in tagged_seq[start:end + 1]):
                continue  # clashes with a higher-scoring span already kept
            tagged_seq[start] = 'B'
            tagged_seq[start + 1:end + 1] = ['I'] * (end - start)
        return tagged_seq

    def save_model(self, output_path):
        """Write :meth:`to_dict` to ``output_path`` with ``torch.save``."""
        torch.save(self.to_dict(), output_path)

    def _extract_sentence_vectors(self, sentences, pos=None):
        """Encode pre-tokenised sentences and return the transformer's pooled token vectors.

        When both ``pos`` and ``config.pos`` are set, each word is prefixed
        with its POS marker token before encoding.
        """
        if pos and self.config.pos:
            sentences = _prefix_pos_tags(sentences, pos)
        outs = self.transformer(sentences, is_pretokenized=True, token_pooling=self.config.token_pooling)
        return outs.pooled_tokens

    def forward(self, sentences, pos=None, tags=None, **kwargs):
        """Predict span tags for a batch of pre-tokenised sentences.

        Returns a dict with ``'pred_tags'`` (one B/I/O list per sentence) and,
        when ``tags`` is given, ``'loss'``. Keyword flags:
        ``output_word_vecs`` adds the encoder vectors under ``'word_vecs'``;
        ``output_span_scores`` adds the raw per-word / per-span scores under
        ``'span_scores'``.
        """
        out_dict = {}
        embs = self._extract_sentence_vectors(sentences, pos)
        if kwargs.get('output_word_vecs', False):
            out_dict['word_vecs'] = embs

        lens = [len(s) for s in embs]
        enumerations = None
        if self.config.repr_type == 'span_enumeration':
            embs, enumerations = self.enumeration_layer(embs, lens)
            lens = [len(e) for e in enumerations]

        input_layer = pad_sequence(embs, batch_first=True)
        # cut each row back to its real length so padding positions get no score
        span_scores = [torch.unbind(f)[:l] for f, l in zip(self.span_output_layer(input_layer), lens)]
        if kwargs.get('output_span_scores', False):
            out_dict['span_scores'] = span_scores

        if self.config.repr_type == 'token_classification':
            out_dict['pred_tags'] = self.from_span_scores(span_scores)
        else:
            # span_scores is already split per sentence, aligned with enumerations
            out_dict['pred_tags'] = [
                self.preds_to_sequences(scores, enum, len(sent))
                for scores, enum, sent in zip(span_scores, enumerations, sentences)]

        if tags is None:
            return out_dict

        if self.config.repr_type == 'span_enumeration':
            targets = generate_targets(enumerations, tags)
            targets = torch.Tensor([t for st in targets for t in st]).to(self.device)
            # torch.cat keeps the autograd graph; rebuilding via torch.Tensor(...) would cut it
            score_list = [s for sc in span_scores for s in sc]
            flat_scores = torch.cat(score_list) if score_list else torch.zeros(0, device=self.device)
            span_loss = binary_cross_entropy(flat_scores.sigmoid(), targets)
        else:
            # limit the targets of each sentence to the words not truncated during tokenization
            targets = torch.cat(
                [torch.tensor([self.span_tags.index(t[0]) for t, _ in zip(tg, sc)], dtype=torch.long)
                 for tg, sc in zip(tags, span_scores)]).to(self.device)
            flat_scores = torch.stack([s for tg, sc in zip(tags, span_scores) for _, s in zip(tg, sc)])
            span_loss = cross_entropy(flat_scores, targets)

        out_dict['loss'] = span_loss
        return out_dict

    def from_span_scores(self, span_scores):
        """Map per-word tag scores (token classification) to lists of B/I/O tags."""
        return [[self.span_tags[int(torch.argmax(s))] for s in sc] for sc in span_scores]


class EntNet(nn.Module):
    """Entity typer that labels the spans found by a ``SpanNet`` with ``ent_tags``.

    Spans come from ``span_net`` or from the ``pred_tags`` keyword of
    :meth:`forward`. Each span is copied after the sentence between separator
    tokens. It is represented by concatenating two sums: its word vectors
    inside the sentence, and the vectors of that separator-delimited copy.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.config = Config()
        self.span_net = kwargs.get('span_net')
        self.config.tune_span_net = kwargs.get('tune_span_net', False)
        self.config.use_span_emb = kwargs.get('use_span_emb', False)
        self.config.use_ent_markers = kwargs.get('use_ent_markers', False)
        # it is possible to tune span_net without using its embeddings
        if self.span_net is not None and not self.config.tune_span_net:
            for p in self.span_net.parameters():
                p.requires_grad = False

        self.config.ent_tags = self.ent_tags = kwargs.get('ent_tags')
        self.config.pos = kwargs.get('pos', None)
        self.config.dp = kwargs.get('dp', 0.3)
        self.config.transformer_model_name = kwargs.get('transformer_model_name', 'bert-base-uncased')
        self.config.token_pooling = kwargs.get('token_pooling', 'first')
        self.device = kwargs.get('device', DEFAULT_DEVICE)

        self.transformer = TransformerRepresentation(
            model_name=self.config.transformer_model_name, device=self.device).to(self.device)
        self.transformer_dim = self.transformer.embedding_dim
        self.transformer.add_special_tokens(['[ENT]', '[/ENT]'])
        self.transformer.add_special_tokens(['[INFO]', '[/INFO]'])
        if self.config.pos:
            self.transformer.add_special_tokens(['[' + p + ']' for p in self.config.pos])

        # input is [in-sentence span sum ; appended-copy sum], hence 2 * transformer_dim
        self.ent_output_layer = nn.Sequential(
            nn.Linear(2 * self.transformer_dim, 2 * self.transformer_dim),
            nn.ReLU(),
            nn.Dropout(p=self.config.dp),
            nn.Linear(2 * self.transformer_dim, len(self.config.ent_tags)))

    def to_dict(self):
        """Return a dict with this model's config, the span net's config (or None) and the ``state_dict``."""
        return {
            'model_config': self.config.__dict__,
            'span_net_config': self.span_net.config.__dict__ if self.span_net is not None else None,
            'model_state_dict': self.state_dict()
        }

    @classmethod
    def load_model(cls, model_path, device=DEFAULT_DEVICE):
        """Rebuild a model (and its span net, if any) saved by :meth:`save_model`, in eval mode.

        Both networks are constructed on ``device``.
        """
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        res = torch.load(model_path, device)
        span_net = (SpanNet(device=device, **res['span_net_config'])
                    if res['span_net_config'] is not None else None)
        model = cls(span_net=span_net, device=device, **res['model_config'])
        model.load_state_dict(res['model_state_dict'])
        model.eval()
        return model

    def save_model(self, output_path):
        """Write :meth:`to_dict` to ``output_path`` with ``torch.save``."""
        torch.save(self.to_dict(), output_path)

    def _extract_sentence_vectors(self, sentences, pos=None, ent_bounds=None):
        """Encode pre-tokenised sentences and return the transformer's pooled token vectors.

        Optionally prefixes words with POS markers and, when
        ``config.use_ent_markers`` is set, wraps each span in ``ent_bounds``
        with ``[ENT]`` ... ``[/ENT]``. The caller's lists are never modified.
        """
        if pos and self.config.pos:
            sentences = _prefix_pos_tags(sentences, pos)
        if ent_bounds and self.config.use_ent_markers:
            # copy so the markers do not leak into the caller's word lists
            sentences = [list(sent) for sent in sentences]
            for sent, sent_ents in zip(sentences, ent_bounds):
                for ent in sent_ents:
                    sent[ent[0]] = f'[ENT] {sent[ent[0]]}'
                    sent[ent[1]] = f'{sent[ent[1]]} [/ENT]'
        outs = self.transformer(sentences, is_pretokenized=True, token_pooling=self.config.token_pooling)
        return outs.pooled_tokens

    def forward(self, sentences, pos=None, tags=None, **kwargs):
        """Type the predicted spans of a batch of pre-tokenised sentences.

        Spans come from ``kwargs['pred_tags']`` when given, otherwise from
        ``self.span_net``. Without ``tags`` the result holds ``'pred_tags'``:
        BIO entity sequences, one tag per word of the input sentence. With
        ``tags`` it holds ``'loss'``, which is None when there are no spans,
        plus the span net's loss when ``tune_span_net`` is set and that loss
        exists. ``output_word_vecs`` and ``output_ent_scores`` add the encoder
        vectors / per-span scores; ``'bounds'`` always holds the predicted span
        boundaries.
        """
        out_dict = {}
        span_out = None
        pred_span_seqs = kwargs.get('pred_tags', None)
        if pred_span_seqs is None:
            span_out = self.span_net(sentences, pos=pos, output_word_vecs=self.config.use_span_emb,
                                     tags=tags if self.config.tune_span_net else None)
            pred_span_seqs = span_out['pred_tags']
        bounds = [[e[1] for e in extract_spans(t, tagless=True)[3]] for t in pred_span_seqs]

        if tags is not None:
            gold_spans = [[e for e in extract_spans(t, tagless=True)[3]] for t in tags]
            # a predicted span takes the label (g[0]) of the single gold span with identical
            # boundaries; zero or several matches give 'O'
            matches = [[[g[0] for g in golds if p[0] == g[1][0] and p[1] == g[1][1]] for p in preds]
                       for preds, golds in zip(bounds, gold_spans)]
            targets = [[span_matches[0] if len(span_matches) == 1 else 'O' for span_matches in sent_matches]
                       for sent_matches in matches]

        # keep the un-extended sentences: predicted sequences must match their length
        original_sentences = sentences
        sep = self.transformer.tokenizer.sep_token
        # append "[SEP] span words" for every predicted span, then a closing [SEP]
        sentences = [sent + [t for bd in sent_bounds for t in [sep] + sent[bd[0]:bd[1] + 1]] + [sep]
                     for sent, sent_bounds in zip(sentences, bounds)]
        sep_ids = [[i for i, s in enumerate(sent) if s == sep] for sent in sentences]

        embs = self._extract_sentence_vectors(sentences, pos, bounds)
        if kwargs.get('output_word_vecs', False):
            out_dict['word_vecs'] = embs

        n_tags = len(self.ent_tags)
        ent_scores = []
        for e, bd, spi in zip(embs, bounds, sep_ids):
            if not bd:
                # same rank as non-empty scores so torch.cat below never mixes 1-D and 2-D tensors
                ent_scores.append(torch.zeros((0, n_tags), device=self.device))
                continue
            span_vecs = torch.stack([
                torch.cat((torch.sum(e[b[0]:b[1] + 1], dim=0),            # words inside the sentence
                           torch.sum(e[spi[i]:spi[i + 1] + 1], dim=0)))   # [SEP] copy [SEP]
                for i, b in enumerate(bd)])
            ent_scores.append(self.ent_output_layer(span_vecs))

        if kwargs.get('output_ent_scores', False):
            out_dict['ent_scores'] = ent_scores
        out_dict['bounds'] = bounds

        if tags is None:
            out_dict['pred_tags'] = self.from_ent_scores(ent_scores, original_sentences, bounds)
            return out_dict

        ent_targs = torch.tensor([self.ent_tags.index(t) for targ in targets for t in targ],
                                 dtype=torch.long).to(self.device)
        ent_preds = torch.cat(ent_scores) if ent_scores else torch.zeros((0, n_tags), device=self.device)
        if not len(ent_preds):
            out_dict['loss'] = None
            return out_dict

        ent_loss = cross_entropy(ent_preds, ent_targs)
        out_dict['loss'] = ent_loss
        # span_out is None when spans were supplied through kwargs['pred_tags']
        if self.config.tune_span_net and span_out is not None and span_out.get('loss') is not None:
            out_dict['loss'] = out_dict['loss'] + span_out['loss']
        return out_dict

    def from_ent_scores(self, ent_scores, sentences, bounds):
        """Rebuild BIO entity sequences from per-span entity scores.

        Each span in ``bounds`` (inclusive ``(start, end)`` word indices) gets
        its arg-max entity tag ``t`` as ``B-t`` / ``I-t``; spans whose arg-max
        is ``'O'``, and all words outside spans, are ``'O'``. Each output list
        has one tag per word of the matching entry in ``sentences``.
        """
        max_tags = [[self.ent_tags[int(torch.argmax(e))] for e in es] for es in ent_scores]
        combined_sequences = []
        for mt, bnd, sent in zip(max_tags, bounds, sentences):
            seq = ['O'] * len(sent)
            for tag, b in zip(mt, bnd):
                seq[b[0]] = 'O' if tag == 'O' else f'B-{tag}'
                for i in range(b[0] + 1, b[1] + 1):
                    seq[i] = 'O' if tag == 'O' else f'I-{tag}'
            combined_sequences.append(seq)
        return combined_sequences