|
|
|
from completions import * |
|
from expand_llm import * |
|
from expand import * |
|
|
|
|
|
# Load the model, its tokenizer, and the target device once at script start.
model, tokenizer, device = load_model()


# Sample Russian sentence; "предвыполнить" is a made-up word, so its tokens
# should score low and be picked up by the low-probability filter below.
input_text = "Здравствуйте, я хочу предвыполнить заказ"

inputs: BatchEncoding = tokenize(input_text, tokenizer, device)
|
|
|
|
|
# Per-token log-probabilities under the model, as (token_id, logprob) pairs.
token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)

# Group token-level scores back into whole words.
words = split_into_words(token_probs, tokenizer)

# Words scoring below this log-probability are treated as "surprising"
# and become candidates for expansion.
log_prob_threshold = -5.0
low_prob_words = [
    (position, candidate)
    for position, candidate in enumerate(words)
    if candidate.logprob < log_prob_threshold
]

# The left context preceding each flagged word, used to seed expansions.
contexts = [candidate.context for _, candidate in low_prob_words]
|
|
|
|
|
# Expander that proposes continuations for each context using the model.
expander = LLMBatchExpander(model, tokenizer)


# One Series per flagged context; the budget caps how much log-probability
# mass each series may spend while expanding.
# (Replaces a manual for-and-append loop with the idiomatic comprehension.)
series = [Series(id=i, tokens=x, budget=5.0) for i, x in enumerate(contexts)]


batch = Batch(items=series)


# Criterion deciding when a candidate expansion is complete (e.g. a word
# boundary), derived from the tokenizer's vocabulary.
stopping_criterion = create_stopping_criterion_llm(tokenizer)


# Run the batched expansion; returns a CompletedBatch of results.
expanded = expand(batch, expander, stopping_criterion)
|
|
|
|
|
def print_expansions(expansions: CompletedBatch):
    """Print every expansion of every completed series.

    For each expansion, emits one line with the originating series id,
    the raw expansion objects, and the text decoded via the module-level
    ``tokenizer``.
    """
    for completed in expansions.items:
        for candidate in completed.expansions:
            token_ids = [step.token for step in candidate]
            decoded = tokenizer.decode(token_ids)
            print(f"{completed.series.id}: {candidate} {decoded}")
|
|
|
# Report the proposed expansions for every low-probability word.
print_expansions(expanded)
|
|
|
|