Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from .net import GreedyNet | |
from .play import play | |
from .utils import v_wrap | |
def evaluate_checkpoints(dir, env):
    """Evaluate every checkpoint file in a directory and rank the results.

    Each regular file directly inside ``dir`` is passed to ``evaluate`` as a
    pretrained model path; sub-directories and other non-files are skipped.

    Args:
        dir: Directory containing checkpoint files. (The name shadows the
            builtin ``dir`` but is kept so keyword callers keep working.)
        env: Environment forwarded unchanged to ``evaluate``.

    Returns:
        A dict mapping checkpoint filename to ``(wins, guesses)`` as returned
        by ``evaluate``, ordered best-first: highest ``wins`` value, then
        lowest ``guesses`` value. Exact ties keep filename order.
    """
    results = {}
    # os.listdir order is filesystem-dependent; iterating the names in sorted
    # order makes the final ranking reproducible, because the stable sort
    # below keeps insertion order for checkpoints with identical scores.
    for checkpoint in sorted(os.listdir(dir)):
        pretrained_model_path = os.path.join(dir, checkpoint)
        if not os.path.isfile(pretrained_model_path):
            continue
        wins, guesses = evaluate(env, pretrained_model_path)
        results[checkpoint] = (wins, guesses)
    # Higher wins first; among equal wins, fewer guesses first (negated so a
    # single reverse=True sort handles both criteria).
    ranked = sorted(
        results.items(),
        key=lambda item: (item[1][0], -item[1][1]),
        reverse=True,
    )
    return dict(ranked)
def evaluate(env, pretrained_model_path):
    """Play one game per goal word with a pretrained model and report stats.

    The goal words are the first ``allowable_words`` entries of ``words`` on
    ``env.unwrapped``; each game is run via ``play``, whose returned
    ``outcomes`` sequence is counted as one guess per entry.

    Args:
        env: Environment exposing ``unwrapped.words`` and
            ``unwrapped.allowable_words`` (presumably a gym-style wrapper --
            confirm against the caller).
        pretrained_model_path: Checkpoint path, passed through to ``play``.

    Returns:
        ``(win_percentage, guesses_per_win)``. ``win_percentage`` lies in
        ``[0, 100]`` and is ``0.0`` when there are no goal words.
        ``guesses_per_win`` is the mean ``len(outcomes)`` over won games, or
        ``float("inf")`` when no game was won, so callers that rank by fewer
        guesses place such a model last instead of crashing.
    """
    env = env.unwrapped
    goal_words = env.words[: env.allowable_words]
    # Divide by the number of games actually played, which can be smaller
    # than allowable_words if the word list is shorter.
    n_games = len(goal_words)
    n_wins = 0
    n_guesses = 0
    n_win_guesses = 0
    for goal_word in goal_words:
        win, outcomes = play(env, pretrained_model_path, goal_word)
        n_guesses += len(outcomes)
        if win:
            n_wins += 1
            n_win_guesses += len(outcomes)
    win_rate = n_wins / n_games * 100 if n_games else 0.0
    guesses_per_win = n_win_guesses / n_wins if n_wins else float("inf")
    guesses_per_game = n_guesses / n_games if n_games else 0.0
    # Implicit concatenation: a backslash-newline inside the literal would
    # embed the next line's indentation in the printed message.
    print(
        f"Evaluation complete, won {win_rate}% and "
        f"took {guesses_per_win} guesses per win, "
        f"{guesses_per_game} including losses."
    )
    return win_rate, guesses_per_win