|
|
|
|
|
# Spacing fixes applied to detokenized text: each key is a tokenization
# artifact (punctuation separated from its neighbour by a space) and each
# value is the text it is replaced with. Insertion order is the order in
# which the replacements are applied by prettify().
replace_dict = {
    ' .': '.',
    ' ,': ',',
    ' ;': ';',
    ' :': ':',
    '( ': '(',
    ' )': ')',
    " '": "'",
}
|
|
|
|
|
def get_recipe(ids, vocab):
    """Map token ids to their vocabulary entries.

    Args:
        ids: iterable of keys/indices valid for ``vocab``.
        vocab: any object supporting ``vocab[id_]`` lookup.

    Returns:
        A new list holding ``vocab[id_]`` for every id, in input order.

    Raises:
        Whatever ``vocab[id_]`` raises for an unknown id
        (e.g. ``IndexError`` for a list, ``KeyError`` for a dict).
    """
    return [vocab[id_] for id_ in ids]
|
|
|
|
|
def get_ingrs(ids, ingr_vocab_list):
    """Translate ingredient ids to names, stopping at the first padding entry.

    Args:
        ids: iterable of keys/indices valid for ``ingr_vocab_list``.
        ingr_vocab_list: object supporting ``ingr_vocab_list[idx]`` lookup.

    Returns:
        A list of the names preceding the first ``'<pad>'`` entry (the pad
        itself is excluded). If no pad is met, all names are returned.
    """
    gen_ingrs = []
    for ingr_idx in ids:
        ingr_name = ingr_vocab_list[ingr_idx]
        if ingr_name == '<pad>':
            # Everything after the first pad is ignored and never looked up.
            break
        gen_ingrs.append(ingr_name)
    return gen_ingrs
|
|
|
|
|
def prettify(toks, replace_dict):
    """Turn a token sequence into a list of cleaned-up sentences.

    The tokens are joined with single spaces, everything from the first
    ``'<end>'`` marker onward is discarded, and the remainder is split on
    ``'<eoi>'`` markers. Each piece is stripped, passed through
    ``str.capitalize`` (first character upper-cased, the rest lower-cased)
    and then rewritten using ``replace_dict``.

    Args:
        toks: iterable of string tokens.
        replace_dict: mapping of substring -> replacement, applied in the
            mapping's iteration order.

    Returns:
        List of non-empty processed sentences, in order of appearance.
    """
    text = ' '.join(toks).split('<end>')[0]

    pretty_sentences = []
    for sentence in text.split('<eoi>'):
        sentence = sentence.strip().capitalize()
        for old, new in replace_dict.items():
            sentence = sentence.replace(old, new)
        # Pieces that end up empty (e.g. consecutive markers) are dropped.
        if sentence:
            pretty_sentences.append(sentence)
    return pretty_sentences
|
|
|
|
|
def colorized_list(ingrs, ingrs_gt, colorize=False):
    """Optionally wrap each entry of ``ingrs`` in ANSI colour escape codes.

    Args:
        ingrs: iterable of strings to render.
        ingrs_gt: container tested with ``word in ingrs_gt`` to decide the
            colour of each entry.
        colorize: when false, ``ingrs`` is returned as-is (same object).

    Returns:
        ``ingrs`` itself when ``colorize`` is false; otherwise a new list in
        which every word is padded with one space on each side and wrapped
        in the ``1;30;42`` escape sequence if it is found in ``ingrs_gt``
        or the ``1;30;41`` sequence if it is not, followed by a reset code.
    """
    if not colorize:
        return ingrs

    hit_prefix = '\033[1;30;42m '
    miss_prefix = '\033[1;30;41m '
    reset_suffix = ' \x1b[0m'
    return [
        (hit_prefix if word in ingrs_gt else miss_prefix) + word + reset_suffix
        for word in ingrs
    ]
|
|
|
|
|
def prepare_output(ids, gen_ingrs, ingr_vocab_list, vocab):
    """Decode a generated recipe and run basic sanity checks on it.

    Args:
        ids: token ids of the generated recipe, decoded with ``vocab``.
        gen_ingrs: ingredient ids decoded with ``ingr_vocab_list``, or
            ``None`` to skip ingredient decoding.
        ingr_vocab_list: lookup used by ``get_ingrs``.
        vocab: lookup used by ``get_recipe``.

    Returns:
        A tuple ``(outs, valid)`` where

        * ``outs`` is ``{'title', 'recipe', 'ingrs'}``: the first prettified
          sentence, the remaining sentences, and the decoded ingredient
          names (or ``None``).
        * ``valid`` is ``{'is_valid', 'reason', 'score'}``. ``score`` is the
          ratio of distinct tokens to total tokens before ``'<end>'``
          (0.0 when there are none).

        The result is flagged invalid when no ``'<end>'`` token exists, no
        sentence could be extracted, ``score <= 0.3``, two instructions are
        identical, or the same token (other than ``'<eoi>'``) occurs twice
        in a row. Later checks overwrite the ``'no eos found'`` reason.
    """
    toks = get_recipe(ids, vocab)
    is_valid = True
    reason = 'All ok.'

    # list.index raises ValueError when the marker is missing; catch only
    # that so unrelated errors are not silently swallowed.
    try:
        cut = toks.index('<end>')
    except ValueError:
        toks_trunc = toks
        is_valid = False
        reason = 'no eos found'
    else:
        toks_trunc = toks[:cut]

    # Diversity score; guard against an empty sequence (e.g. '<end>' is the
    # very first token) which would otherwise divide by zero.
    if toks_trunc:
        score = float(len(set(toks_trunc))) / float(len(toks_trunc))
    else:
        score = 0.0

    # Immediate repetition of a token; '<eoi>' separators are exempt.
    found_repeat = any(
        prev == cur and prev != '<eoi>'
        for prev, cur in zip(toks_trunc, toks_trunc[1:])
    )

    sentences = prettify(toks, replace_dict)
    if sentences:
        title = sentences[0]
        instructions = sentences[1:]
    else:
        # Nothing usable was generated; avoid indexing an empty list.
        title = ''
        instructions = []

    if gen_ingrs is not None:
        gen_ingrs = get_ingrs(gen_ingrs, ingr_vocab_list)

    if not sentences:
        reason = 'Empty recipe.'
        is_valid = False
    elif score <= 0.3:
        reason = 'Diversity score.'
        is_valid = False
    elif len(instructions) != len(set(instructions)):
        reason = 'Repeated instructions.'
        is_valid = False
    elif found_repeat:
        reason = 'Found word repeat.'
        is_valid = False

    valid = {'is_valid': is_valid, 'reason': reason, 'score': score}
    outs = {'title': title, 'recipe': instructions, 'ingrs': gen_ingrs}

    return outs, valid
|
|