#%%
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer

from text_processing import split_into_words, Word
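
# Flag low-confidence words in a text by scoring each token's log-probability
# under a causal LM, then sample candidate replacements for each flagged word.
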
def load_model_and_tokenizer(model_name: str) -> tuple[PreTrainedModel, PreTrainedTokenizer, torch.device]:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model, tokenizer, device
def process_input_text(input_text: str, tokenizer: PreTrainedTokenizer, device: torch.device):
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    return inputs, input_ids
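
# Teacher-forced scoring: the logits at position i predict the token at
# position i + 1, so scores are aligned against input_ids shifted by one
# and the first token receives no score.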
def calculate_log_probabilities(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, inputs, input_ids) -> list[tuple[str, float]]:
    with torch.no_grad():
        outputs = model(**inputs)
    # Drop the last position: its prediction targets a token beyond the input.
    logits = outputs.logits[0, :-1, :]
    log_probs = torch.log_softmax(logits, dim=-1)
    # Gather the log-probability assigned to each actual next token.
    token_log_probs = log_probs[range(log_probs.shape[0]), input_ids[0][1:]]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return list(zip(tokens[1:], token_log_probs.tolist()))
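
# Sample num_samples short continuations of the text preceding a flagged word
# and keep the first whitespace-delimited word of each continuation.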
def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix: str, device: torch.device, num_samples: int = 5) -> list[str]:
    input_context = tokenizer(prefix, return_tensors="pt").to(device)
    input_ids = input_context["input_ids"]
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=input_context["attention_mask"],
            max_new_tokens=5,
            num_return_sequences=num_samples,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    new_words = []
    for i in range(num_samples):
        # Decode only the newly generated tokens, not the echoed prefix.
        generated_ids = outputs[i][input_ids.shape[-1]:]
        text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        # A sample can decode to whitespace or an empty string, in which case
        # split() yields nothing; skip it rather than raise an IndexError.
        words = text.split()
        if words:
            new_words.append(words[0])
    return new_words
#%%
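# Load the model, score the sample sentence, and flag words whose
# log-probability falls below the threshold. "rifused" is a misspelling of
# "refused", giving the script a low-probability word to catch.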
model_name = "mistralai/Mistral-7B-v0.1"
model, tokenizer, device = load_model_and_tokenizer(model_name)

input_text = "He asked me to prostrate myself before the king, but I rifused."
inputs, input_ids = process_input_text(input_text, tokenizer, device)
result = calculate_log_probabilities(model, tokenizer, inputs, input_ids)

words: list[Word] = split_into_words([token for token, _ in result], [logprob for _, logprob in result])
log_prob_threshold = -5.0
low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
#%%
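# For each flagged word, rebuild the context preceding it and sample
# replacement candidates from the model.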
start_time = time.time()
for word in low_prob_words:
    # Prefix is the context strictly before the flagged word; including any of
    # the word's own tokens would bias the sampled replacements toward it.
    prefix_tokens = [token for token, _ in result][:word.first_token_index]
    prefix = tokenizer.convert_tokens_to_string(prefix_tokens)
    replacements = generate_replacements(model, tokenizer, prefix, device)
    print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
    print(f"Proposed replacements: {replacements}")
    print()
end_time = time.time()
print(f"Time taken for the loop: {end_time - start_time:.4f} seconds")