# -*- coding: utf-8 -*-
# @Time    : 2023/3/20 8:02 p.m.
# @Author  : Jianing Wang
# @File    : calibrate.py

import os
import numpy as np
import torch

"""
Use LM to classify label words for calibrating CLS
"""
class CLSCalibrator:
    pass

"""
Use Causal LM to generate label words for calibrating CLS

e.g., use gpt2 to generate a label word with in-context prompts, and calibrate the prediction.
Paper: http://proceedings.mlr.press/v139/zhao21c.html
"""
class CausalCLSCalibrator:

    def __init__(self, model, tokenizer) -> None:
        self.model = model
        self.tokenizer = tokenizer
        # GPT-2-style tokenizers have no pad token by default; fall back to EOS
        # so that batch_encode_plus(..., padding=True) works on batched prompts.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def calibrate(self, all_label_probs, content_free_examples, label2id, mode="diagonal_W"):
        """Perform calibration for de-biasing and obtain calibrated probabilities."""
        p_cf = self.get_content_free_prediction(content_free_examples, label2id)
        num_classes = all_label_probs.shape[1]
        if p_cf is None:
            # do not calibrate
            W = np.identity(num_classes)
            b = np.zeros([num_classes, 1])
        else:
            # calibrate
            if mode == "diagonal_W":
                W = np.linalg.inv(np.identity(num_classes) * p_cf)
                b = np.zeros([num_classes, 1])
            elif mode == "identity_W":
                W = np.identity(num_classes)
                b = -1 * np.expand_dims(p_cf, axis=-1)
            else:
                raise ValueError(f"Unknown calibration mode: {mode}")

        all_calibrate_label_probs = list()
        for label_probs in all_label_probs:
            label_probs = label_probs / np.sum(label_probs)  # normalize to 1
            calibrate_label_probs = np.matmul(W, np.expand_dims(label_probs, axis=-1)) + b
            all_calibrate_label_probs.append(calibrate_label_probs.squeeze().tolist())
        return np.array(all_calibrate_label_probs)

    def get_content_free_prediction(self, content_free_examples, label2id: dict):
        """Query the model with content-free input; return its prediction probability for each label."""
        all_p_y = []
        for content_free_example in content_free_examples:
            content_free_prompt = content_free_example["content_free_prompt"]
            p_y = [0] * len(label2id)
            # each key of label2id is a collection of answer words, mapped to a class id
            for answers, i in label2id.items():
                prob = 0
                for a in answers:
                    prob += np.exp(self.get_causal_cls_prediction(
                        content_free_prompt + " " + a, 0, echo=True, num_log_probs=1
                    )['choices'][0]['logprobs']['token_logprobs'][-1])
                p_y[i] = prob
            all_p_y.append(p_y)

        p_y = np.mean(np.array(all_p_y), axis=0)
        p_y = p_y / np.sum(p_y)  # normalize
        return p_y
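    # Worked example for get_content_free_prediction (illustrative numbers; the
    # label2id layout below is an assumption that mirrors the loop above: each
    # key is a tuple of answer words, each value a class id,
    # e.g. {("bad", "terrible"): 0, ("good",): 1}).
    # For the content-free prompt "N/A", suppose the log-prob of the final
    # answer token is -1.2 for " bad", -2.3 for " terrible" and -0.9 for " good".
    # Then p_y = [exp(-1.2) + exp(-2.3), exp(-0.9)] ~= [0.402, 0.407]; averaged
    # over all content-free prompts and normalized, p_cf ~= [0.497, 0.503].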
    def get_causal_cls_prediction(self, prompt, l=10, num_log_probs=None, echo=False):
        """
        Run GPT-2 locally, but place the outputs into a dict that looks just
        like the JSON returned by the OpenAI API.
        """
        if isinstance(prompt, str):
            prompt = [prompt]  # the code below assumes a list
        input_ids = self.tokenizer.batch_encode_plus(prompt, return_tensors="pt", padding=True)

        # left-truncate so that the prompt plus the l generated tokens fit
        # within GPT-2's 1024-token context window
        if l + len(input_ids['input_ids'][0]) > 1024:
            m = l + len(input_ids['input_ids'][0]) - 1024
            input_ids['input_ids'] = input_ids['input_ids'][:, m:]
            input_ids['attention_mask'] = input_ids['attention_mask'][:, m:]

        # greedily generate l tokens
        if l > 0:
            # the generate function can handle left-padded inputs automatically in HF;
            # total_sequences is the input plus the generated output, shape [batch, len + l]
            total_sequences = self.model.generate(
                input_ids=input_ids['input_ids'].to(self.model.device),
                attention_mask=input_ids['attention_mask'].to(self.model.device),
                max_length=l + len(input_ids['input_ids'][0]),
                do_sample=False
            )
        else:
            assert echo and l == 0
            total_sequences = input_ids['input_ids'].to(self.model.device)

        # they want the probs of the top tokens
        if num_log_probs is not None:
            # we are left padding, so we need to adjust the position IDs
            # (50256 is GPT-2's EOS token, which doubles as the pad token here)
            attention_mask = (total_sequences != 50256).float()
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            # get the logits for the context and the next l tokens
            logits = self.model.forward(
                input_ids=total_sequences,
                attention_mask=attention_mask,
                position_ids=position_ids,
                return_dict=True
            ).logits.detach().cpu()
            if not echo:
                # get the top tokens and probs for the generated l tokens
                probs = torch.softmax(logits[:, -l-1:], dim=2).cpu()
            else:
                # get the top tokens and probs for the context and the generated l tokens
                probs = torch.softmax(logits, dim=2).cpu()
            # top_probs/top_tokens have shape [batch, seq, num_log_probs]
            top_probs, top_tokens = torch.topk(probs, k=num_log_probs)
            logprobs = torch.log(probs)
            top_log_probs = torch.log(top_probs)
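        # Causal-LM logits are shifted one position left relative to the tokens:
        # logits[:, i] is the distribution over the token at position i + 1, and
        # logits[:, -1] scores a token that was never observed. This is why the
        # loops below drop the last position ([:-1]) and pair position index
        # with total_sequences[..., index + 1].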
        # create the return value to resemble OpenAI
        return_json = {}
        choices = []
        for batch_id in range(len(prompt)):
            curr_json = {}
            # text is just the optional context and next l tokens
            if not echo:
                curr_json['text'] = self.tokenizer.decode(total_sequences[batch_id][-l:], skip_special_tokens=True)
            else:
                curr_json['text'] = self.tokenizer.decode(total_sequences[batch_id], skip_special_tokens=True)

            # fill the return json with the top tokens and probs to match the OpenAI return value
            if num_log_probs is not None:
                curr_json['logprobs'] = {}
                curr_json['logprobs']['top_logprobs'] = []
                curr_json['logprobs']['token_logprobs'] = []
                curr_json['logprobs']['tokens'] = []
                if not echo:
                    # cut off the last position because the probs are shifted one over for LMs
                    for current_element_top_log_probs, current_element_top_tokens in zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1]):
                        # tokens is a list of the top token at each position
                        curr_json['logprobs']['tokens'].append(self.tokenizer.decode([current_element_top_tokens[0]]))
                        # token_logprobs is a list of the logprob of the top token at each position
                        curr_json['logprobs']['token_logprobs'].append(current_element_top_log_probs[0].item())
                        # top_logprobs is a list of dicts for the top K tokens,
                        # with each entry being {token_name: log_prob}
                        temp = {}
                        for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
                            temp[self.tokenizer.decode(token.item())] = log_prob.item()
                        curr_json['logprobs']['top_logprobs'].append(temp)
                else:
                    # same as the branch above, with small tweaks
                    # we add 'null' at the front because the GPT models assign no probability
                    # to the first token (they have no beginning-of-sentence token)
                    curr_json['logprobs']['top_logprobs'].append('null')
                    # cut off the last position because the probs are shifted one over for LMs
                    for index, (current_element_top_log_probs, current_element_top_tokens) in enumerate(zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1])):
                        # skip padding tokens
                        if total_sequences[batch_id][index].item() == 50256:
                            continue
                        temp = {}
                        for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
                            temp[self.tokenizer.decode(token.item())] = log_prob.item()
                        curr_json['logprobs']['top_logprobs'].append(temp)
                    for index in range(len(probs[batch_id])):
                        curr_json['logprobs']['tokens'].append(self.tokenizer.decode([total_sequences[batch_id][index]]))
                    curr_json['logprobs']['token_logprobs'].append('null')
                    for index, log_probs_token_position_j in enumerate(logprobs[batch_id][:-1]):
                        # probs are left-shifted for LMs
                        curr_json['logprobs']['token_logprobs'].append(log_probs_token_position_j[total_sequences[batch_id][index + 1]].item())
            choices.append(curr_json)

        # Example outputs:
        #
        # num_tokens_to_predict=1
        # curr_json = {
        #     'text': ' I',  # the generated top token
        #     'logprobs': {
        #         'top_logprobs': [{' I': -3.4267239570617676, '\n': -3.5073862075805664, ...}],  # top-k tokens and their scores
        #         'token_logprobs': [-3.4267239570617676],  # score of the current top token
        #         'tokens': [' I']
        #     }
        # }
        #
        # num_tokens_to_predict=2
        # curr_json = {
        #     'text': '\nThe',  # if two tokens are requested, both are generated
        #     'logprobs': {
        #         'top_logprobs': [  # predicted scores at each of the two positions
        #             {'\n': -3.186706304550171, '\xa0': -3.222092390060425, ' We': -6.781067848205566, ...},
        #             {'The': -2.5251243114471436, '"': -2.857935667037964, ...}
        #         ],
        #         'token_logprobs': [-3.186706304550171, -2.5251243114471436],  # scores of the generated tokens
        #         'tokens': ['\n', 'The']
        #     }
        # }

        return_json['choices'] = choices
        return return_json
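
# ---------------------------------------------------------------------------
# Usage sketch: a minimal, self-contained demonstration of the "diagonal_W"
# calibration above, plus a commented end-to-end example. The checkpoint name
# "gpt2", the content-free prompt "N/A", and the label2id mapping are
# illustrative assumptions, not part of the original file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Suppose the content-free prediction is p_cf = [0.7, 0.3]: the LM is
    # biased toward class 0. diagonal_W builds W = inv(diag(p_cf)), so the
    # content-free input itself would be mapped to the uniform distribution.
    p_cf = np.array([0.7, 0.3])
    W = np.linalg.inv(np.identity(2) * p_cf)
    raw = np.array([0.6, 0.4])                  # uncalibrated label probabilities
    calibrated = np.matmul(W, raw[:, None])     # -> [[0.857], [1.333]]
    calibrated = calibrated / calibrated.sum()  # renormalize for display
    print(calibrated.squeeze())                 # ~[0.391, 0.609]: class 1 wins after de-biasing

    # End-to-end (commented out because it downloads a model):
    # from transformers import GPT2LMHeadModel, GPT2TokenizerFast
    # model = GPT2LMHeadModel.from_pretrained("gpt2")
    # tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    # calibrator = CausalCLSCalibrator(model, tokenizer)
    # all_label_probs = np.array([[0.6, 0.4]])  # from the uncalibrated classifier
    # calibrated = calibrator.calibrate(
    #     all_label_probs,
    #     content_free_examples=[{"content_free_prompt": "N/A"}],
    #     label2id={("bad",): 0, ("good",): 1},
    # )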