Spaces:
Build error
Build error
Fix code style with black and isort
Browse files- a3c/eval.py +3 -4
- a3c/net.py +5 -5
- a3c/play.py +6 -8
- a3c/shared_adam.py +8 -10
- a3c/train.py +40 -14
- a3c/utils.py +1 -1
- a3c/worker.py +52 -43
- api_rest/api.py +17 -15
- main.py +41 -41
- rs_wordle_player/firebase_connector.py +20 -19
- rs_wordle_player/rs_wordle_player.py +3 -2
- rs_wordle_player/selenium_player.py +14 -15
- wordle_env/__init__.py +4 -8
- wordle_env/const.py +1 -1
- wordle_env/state.py +16 -25
- wordle_env/test_wordle.py +1 -2
- wordle_env/wordle.py +25 -26
- wordle_env/words.py +9 -5
- wordle_game.py +26 -25
a3c/eval.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
import torch
|
| 3 |
|
| 4 |
from .net import GreedyNet
|
|
@@ -14,9 +15,7 @@ def evaluate_checkpoints(dir, env):
|
|
| 14 |
wins, guesses = evaluate(env, pretrained_model_path)
|
| 15 |
results[checkpoint] = wins, guesses
|
| 16 |
return dict(
|
| 17 |
-
sorted(results.items(), key=lambda x: (
|
| 18 |
-
x[1][0], -x[1][1]), reverse=True
|
| 19 |
-
)
|
| 20 |
)
|
| 21 |
|
| 22 |
|
|
@@ -39,4 +38,4 @@ def evaluate(env, pretrained_model_path):
|
|
| 39 |
took {n_win_guesses/n_wins} guesses per win, "
|
| 40 |
f"{n_guesses / N} including losses."
|
| 41 |
)
|
| 42 |
-
return n_wins/N*100, n_win_guesses/n_wins
|
|
|
|
| 1 |
import os
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
|
| 5 |
from .net import GreedyNet
|
|
|
|
| 15 |
wins, guesses = evaluate(env, pretrained_model_path)
|
| 16 |
results[checkpoint] = wins, guesses
|
| 17 |
return dict(
|
| 18 |
+
sorted(results.items(), key=lambda x: (x[1][0], -x[1][1]), reverse=True)
|
|
|
|
|
|
|
| 19 |
)
|
| 20 |
|
| 21 |
|
|
|
|
| 38 |
took {n_win_guesses/n_wins} guesses per win, "
|
| 39 |
f"{n_guesses / N} including losses."
|
| 40 |
)
|
| 41 |
+
return n_wins / N * 100, n_win_guesses / n_wins
|
a3c/net.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
-
import numpy as np
|
| 5 |
|
| 6 |
|
| 7 |
class Net(nn.Module):
|
|
@@ -23,15 +23,15 @@ class Net(nn.Module):
|
|
| 23 |
word_array = np.zeros((word_width, len(word_list)))
|
| 24 |
for i, word in enumerate(word_list):
|
| 25 |
for j, c in enumerate(word):
|
| 26 |
-
word_array[j*26 + (ord(c) - ord(
|
| 27 |
self.words = torch.Tensor(word_array)
|
| 28 |
|
| 29 |
def forward(self, x):
|
| 30 |
values = self.v1(x.float())
|
| 31 |
logits = torch.log_softmax(
|
| 32 |
-
torch.tensordot(self.actor_head(values), self.words,
|
| 33 |
-
|
| 34 |
-
|
| 35 |
values = self.v4(values)
|
| 36 |
return logits, values
|
| 37 |
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch.nn.functional as F
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class Net(nn.Module):
|
|
|
|
| 23 |
word_array = np.zeros((word_width, len(word_list)))
|
| 24 |
for i, word in enumerate(word_list):
|
| 25 |
for j, c in enumerate(word):
|
| 26 |
+
word_array[j * 26 + (ord(c) - ord("A")), i] = 1
|
| 27 |
self.words = torch.Tensor(word_array)
|
| 28 |
|
| 29 |
def forward(self, x):
|
| 30 |
values = self.v1(x.float())
|
| 31 |
logits = torch.log_softmax(
|
| 32 |
+
torch.tensordot(self.actor_head(values), self.words, dims=((1,), (0,))),
|
| 33 |
+
dim=-1,
|
| 34 |
+
)
|
| 35 |
values = self.v4(values)
|
| 36 |
return logits, values
|
| 37 |
|
a3c/play.py
CHANGED
|
@@ -1,15 +1,18 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
import torch
|
| 3 |
from dotenv import load_dotenv
|
|
|
|
| 4 |
from wordle_env.state import update_from_mask
|
|
|
|
| 5 |
from .net import GreedyNet
|
| 6 |
from .utils import v_wrap
|
| 7 |
|
| 8 |
|
| 9 |
def get_play_model_path():
|
| 10 |
load_dotenv()
|
| 11 |
-
model_name = os.getenv(
|
| 12 |
-
model_checkpoint_dir = os.path.join(
|
| 13 |
return os.path.join(model_checkpoint_dir, model_name)
|
| 14 |
|
| 15 |
|
|
@@ -28,12 +31,7 @@ def get_initial_state(env):
|
|
| 28 |
return state
|
| 29 |
|
| 30 |
|
| 31 |
-
def suggest(
|
| 32 |
-
env,
|
| 33 |
-
words,
|
| 34 |
-
states,
|
| 35 |
-
pretrained_model_path
|
| 36 |
-
) -> str:
|
| 37 |
"""
|
| 38 |
Given a list of words and masks, return the next suggested word
|
| 39 |
|
|
|
|
| 1 |
import os
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
from dotenv import load_dotenv
|
| 5 |
+
|
| 6 |
from wordle_env.state import update_from_mask
|
| 7 |
+
|
| 8 |
from .net import GreedyNet
|
| 9 |
from .utils import v_wrap
|
| 10 |
|
| 11 |
|
| 12 |
def get_play_model_path():
|
| 13 |
load_dotenv()
|
| 14 |
+
model_name = os.getenv("RS_WORDLE_MODEL_NAME")
|
| 15 |
+
model_checkpoint_dir = os.path.join("checkpoints", "best_models")
|
| 16 |
return os.path.join(model_checkpoint_dir, model_name)
|
| 17 |
|
| 18 |
|
|
|
|
| 31 |
return state
|
| 32 |
|
| 33 |
|
| 34 |
+
def suggest(env, words, states, pretrained_model_path) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
"""
|
| 36 |
Given a list of words and masks, return the next suggested word
|
| 37 |
|
a3c/shared_adam.py
CHANGED
|
@@ -6,20 +6,18 @@ import torch
|
|
| 6 |
|
| 7 |
|
| 8 |
class SharedAdam(torch.optim.Adam):
|
| 9 |
-
def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
|
| 10 |
-
weight_decay=0):
|
| 11 |
super(SharedAdam, self).__init__(
|
| 12 |
-
params, lr=lr,
|
| 13 |
-
betas=betas, eps=eps, weight_decay=weight_decay
|
| 14 |
)
|
| 15 |
# State initialization
|
| 16 |
for group in self.param_groups:
|
| 17 |
-
for p in group[
|
| 18 |
state = self.state[p]
|
| 19 |
-
state[
|
| 20 |
-
state[
|
| 21 |
-
state[
|
| 22 |
|
| 23 |
# share in memory
|
| 24 |
-
state[
|
| 25 |
-
state[
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
class SharedAdam(torch.optim.Adam):
|
| 9 |
+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0):
|
|
|
|
| 10 |
super(SharedAdam, self).__init__(
|
| 11 |
+
params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
|
|
|
|
| 12 |
)
|
| 13 |
# State initialization
|
| 14 |
for group in self.param_groups:
|
| 15 |
+
for p in group["params"]:
|
| 16 |
state = self.state[p]
|
| 17 |
+
state["step"] = 0
|
| 18 |
+
state["exp_avg"] = torch.zeros_like(p.data)
|
| 19 |
+
state["exp_avg_sq"] = torch.zeros_like(p.data)
|
| 20 |
|
| 21 |
# share in memory
|
| 22 |
+
state["exp_avg"].share_memory_()
|
| 23 |
+
state["exp_avg_sq"].share_memory_()
|
a3c/train.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
import os
|
| 2 |
-
import numpy as np
|
| 3 |
import random
|
|
|
|
|
|
|
| 4 |
import torch
|
| 5 |
import torch.multiprocessing as mp
|
| 6 |
-
|
| 7 |
from .net import Net
|
|
|
|
| 8 |
from .worker import Worker
|
| 9 |
|
| 10 |
|
|
@@ -25,12 +27,12 @@ def train(
|
|
| 25 |
env,
|
| 26 |
max_ep,
|
| 27 |
model_checkpoint_dir,
|
| 28 |
-
gamma=0
|
| 29 |
seed=100,
|
| 30 |
pretrained_model_path=None,
|
| 31 |
save=False,
|
| 32 |
min_reward=9.9,
|
| 33 |
-
every_n_save=100
|
| 34 |
):
|
| 35 |
os.environ["OMP_NUM_THREADS"] = "1"
|
| 36 |
if not os.path.exists(model_checkpoint_dir):
|
|
@@ -45,18 +47,40 @@ def train(
|
|
| 45 |
if pretrained_model_path:
|
| 46 |
gnet.load_state_dict(torch.load(pretrained_model_path))
|
| 47 |
gnet.share_memory() # share the global parameters in multiprocessing
|
| 48 |
-
opt = SharedAdam(
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# parallel training
|
| 54 |
workers = [
|
| 55 |
Worker(
|
| 56 |
-
max_ep,
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
]
|
| 61 |
[w.start() for w in workers]
|
| 62 |
res = [] # record episode reward to plot
|
|
@@ -68,6 +92,8 @@ def train(
|
|
| 68 |
break
|
| 69 |
[w.join() for w in workers]
|
| 70 |
if save:
|
| 71 |
-
torch.save(
|
| 72 |
-
|
|
|
|
|
|
|
| 73 |
return global_ep, win_ep, gnet, res
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import random
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
import torch
|
| 6 |
import torch.multiprocessing as mp
|
| 7 |
+
|
| 8 |
from .net import Net
|
| 9 |
+
from .shared_adam import SharedAdam
|
| 10 |
from .worker import Worker
|
| 11 |
|
| 12 |
|
|
|
|
| 27 |
env,
|
| 28 |
max_ep,
|
| 29 |
model_checkpoint_dir,
|
| 30 |
+
gamma=0.0,
|
| 31 |
seed=100,
|
| 32 |
pretrained_model_path=None,
|
| 33 |
save=False,
|
| 34 |
min_reward=9.9,
|
| 35 |
+
every_n_save=100,
|
| 36 |
):
|
| 37 |
os.environ["OMP_NUM_THREADS"] = "1"
|
| 38 |
if not os.path.exists(model_checkpoint_dir):
|
|
|
|
| 47 |
if pretrained_model_path:
|
| 48 |
gnet.load_state_dict(torch.load(pretrained_model_path))
|
| 49 |
gnet.share_memory() # share the global parameters in multiprocessing
|
| 50 |
+
opt = SharedAdam(
|
| 51 |
+
gnet.parameters(), lr=1e-4, betas=(0.92, 0.999)
|
| 52 |
+
) # global optimizer
|
| 53 |
+
global_ep, global_ep_r, res_queue, win_ep = (
|
| 54 |
+
mp.Value("i", 0),
|
| 55 |
+
mp.Value("d", 0.0),
|
| 56 |
+
mp.Queue(),
|
| 57 |
+
mp.Value("i", 0),
|
| 58 |
+
)
|
| 59 |
|
| 60 |
# parallel training
|
| 61 |
workers = [
|
| 62 |
Worker(
|
| 63 |
+
max_ep,
|
| 64 |
+
gnet,
|
| 65 |
+
opt,
|
| 66 |
+
global_ep,
|
| 67 |
+
global_ep_r,
|
| 68 |
+
res_queue,
|
| 69 |
+
i,
|
| 70 |
+
env,
|
| 71 |
+
n_s,
|
| 72 |
+
n_a,
|
| 73 |
+
words_list,
|
| 74 |
+
word_width,
|
| 75 |
+
win_ep,
|
| 76 |
+
model_checkpoint_dir,
|
| 77 |
+
gamma,
|
| 78 |
+
pretrained_model_path,
|
| 79 |
+
save,
|
| 80 |
+
min_reward,
|
| 81 |
+
every_n_save,
|
| 82 |
+
)
|
| 83 |
+
for i in range(mp.cpu_count())
|
| 84 |
]
|
| 85 |
[w.start() for w in workers]
|
| 86 |
res = [] # record episode reward to plot
|
|
|
|
| 92 |
break
|
| 93 |
[w.join() for w in workers]
|
| 94 |
if save:
|
| 95 |
+
torch.save(
|
| 96 |
+
gnet.state_dict(),
|
| 97 |
+
os.path.join(model_checkpoint_dir, f"model_{env.unwrapped.spec.id}.pth"),
|
| 98 |
+
)
|
| 99 |
return global_ep, win_ep, gnet, res
|
a3c/utils.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
-
import torch
|
| 2 |
import numpy as np
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
def v_wrap(np_array, dtype=np.float32):
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
|
| 4 |
|
| 5 |
def v_wrap(np_array, dtype=np.float32):
|
a3c/worker.py
CHANGED
|
@@ -2,40 +2,42 @@
|
|
| 2 |
Worker class implementation of the a3c discrete algorithm
|
| 3 |
"""
|
| 4 |
import os
|
| 5 |
-
|
| 6 |
import numpy as np
|
|
|
|
| 7 |
import torch.multiprocessing as mp
|
| 8 |
from torch import nn
|
|
|
|
| 9 |
from .net import Net
|
| 10 |
from .utils import v_wrap
|
| 11 |
|
| 12 |
|
| 13 |
class Worker(mp.Process):
|
| 14 |
def __init__(
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
):
|
| 36 |
super(Worker, self).__init__()
|
| 37 |
self.max_ep = max_ep
|
| 38 |
-
self.name =
|
| 39 |
self.g_ep = global_ep
|
| 40 |
self.g_ep_r = global_ep_r
|
| 41 |
self.res_queue = res_queue
|
|
@@ -57,7 +59,7 @@ class Worker(mp.Process):
|
|
| 57 |
while self.g_ep.value < self.max_ep:
|
| 58 |
s = self.env.reset()
|
| 59 |
buffer_s, buffer_a, buffer_r = [], [], []
|
| 60 |
-
ep_r = 0.
|
| 61 |
while True:
|
| 62 |
a = self.lnet.choose_action(v_wrap(s[None, :]))
|
| 63 |
s_, r, done, _ = self.env.step(a)
|
|
@@ -68,11 +70,9 @@ class Worker(mp.Process):
|
|
| 68 |
|
| 69 |
if done: # update global and assign to local net
|
| 70 |
# sync
|
| 71 |
-
self.push_and_pull(done, s_, buffer_s,
|
| 72 |
-
buffer_a, buffer_r)
|
| 73 |
goal_word = self.word_list[self.env.goal_word]
|
| 74 |
-
self.record(ep_r, goal_word,
|
| 75 |
-
self.word_list[a], len(buffer_a))
|
| 76 |
self.save_model()
|
| 77 |
buffer_s, buffer_a, buffer_r = [], [], []
|
| 78 |
break
|
|
@@ -81,22 +81,22 @@ class Worker(mp.Process):
|
|
| 81 |
|
| 82 |
def push_and_pull(self, done, s_, bs, ba, br):
|
| 83 |
if done:
|
| 84 |
-
v_s_ = 0.
|
| 85 |
else:
|
| 86 |
-
v_s_ = self.lnet.forward(v_wrap(
|
| 87 |
-
s_[None, :]))[-1].data.numpy()[0, 0]
|
| 88 |
|
| 89 |
buffer_v_target = []
|
| 90 |
-
for r in br[::-1]:
|
| 91 |
v_s_ = r + self.gamma * v_s_
|
| 92 |
buffer_v_target.append(v_s_)
|
| 93 |
buffer_v_target.reverse()
|
| 94 |
|
| 95 |
loss = self.lnet.loss_func(
|
| 96 |
v_wrap(np.vstack(bs)),
|
| 97 |
-
v_wrap(np.array(ba), dtype=np.int64)
|
| 98 |
-
ba[0].dtype == np.int64
|
| 99 |
-
v_wrap(np.
|
|
|
|
| 100 |
)
|
| 101 |
|
| 102 |
# calculate local gradients and push local parameters to global
|
|
@@ -110,16 +110,21 @@ class Worker(mp.Process):
|
|
| 110 |
self.lnet.load_state_dict(self.gnet.state_dict())
|
| 111 |
|
| 112 |
def save_model(self):
|
| 113 |
-
if (
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
def record(self, ep_r, goal_word, action, action_number):
|
| 119 |
with self.g_ep.get_lock():
|
| 120 |
self.g_ep.value += 1
|
| 121 |
with self.g_ep_r.get_lock():
|
| 122 |
-
if self.g_ep_r.value == 0
|
| 123 |
self.g_ep_r.value = ep_r
|
| 124 |
else:
|
| 125 |
self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
|
|
@@ -129,9 +134,13 @@ class Worker(mp.Process):
|
|
| 129 |
if self.g_ep.value % 100 == 0:
|
| 130 |
print(
|
| 131 |
self.name,
|
| 132 |
-
"Ep:",
|
|
|
|
| 133 |
"| Ep_r: %.0f" % self.g_ep_r.value,
|
| 134 |
-
"| Goal :",
|
| 135 |
-
|
| 136 |
-
"|
|
|
|
|
|
|
|
|
|
|
| 137 |
)
|
|
|
|
| 2 |
Worker class implementation of the a3c discrete algorithm
|
| 3 |
"""
|
| 4 |
import os
|
| 5 |
+
|
| 6 |
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
import torch.multiprocessing as mp
|
| 9 |
from torch import nn
|
| 10 |
+
|
| 11 |
from .net import Net
|
| 12 |
from .utils import v_wrap
|
| 13 |
|
| 14 |
|
| 15 |
class Worker(mp.Process):
|
| 16 |
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
max_ep,
|
| 19 |
+
gnet,
|
| 20 |
+
opt,
|
| 21 |
+
global_ep,
|
| 22 |
+
global_ep_r,
|
| 23 |
+
res_queue,
|
| 24 |
+
name,
|
| 25 |
+
env,
|
| 26 |
+
N_S,
|
| 27 |
+
N_A,
|
| 28 |
+
words_list,
|
| 29 |
+
word_width,
|
| 30 |
+
winning_ep,
|
| 31 |
+
model_checkpoint_dir,
|
| 32 |
+
gamma=0.0,
|
| 33 |
+
pretrained_model_path=None,
|
| 34 |
+
save=False,
|
| 35 |
+
min_reward=9.9,
|
| 36 |
+
every_n_save=100,
|
| 37 |
):
|
| 38 |
super(Worker, self).__init__()
|
| 39 |
self.max_ep = max_ep
|
| 40 |
+
self.name = "w%02i" % name
|
| 41 |
self.g_ep = global_ep
|
| 42 |
self.g_ep_r = global_ep_r
|
| 43 |
self.res_queue = res_queue
|
|
|
|
| 59 |
while self.g_ep.value < self.max_ep:
|
| 60 |
s = self.env.reset()
|
| 61 |
buffer_s, buffer_a, buffer_r = [], [], []
|
| 62 |
+
ep_r = 0.0
|
| 63 |
while True:
|
| 64 |
a = self.lnet.choose_action(v_wrap(s[None, :]))
|
| 65 |
s_, r, done, _ = self.env.step(a)
|
|
|
|
| 70 |
|
| 71 |
if done: # update global and assign to local net
|
| 72 |
# sync
|
| 73 |
+
self.push_and_pull(done, s_, buffer_s, buffer_a, buffer_r)
|
|
|
|
| 74 |
goal_word = self.word_list[self.env.goal_word]
|
| 75 |
+
self.record(ep_r, goal_word, self.word_list[a], len(buffer_a))
|
|
|
|
| 76 |
self.save_model()
|
| 77 |
buffer_s, buffer_a, buffer_r = [], [], []
|
| 78 |
break
|
|
|
|
| 81 |
|
| 82 |
def push_and_pull(self, done, s_, bs, ba, br):
|
| 83 |
if done:
|
| 84 |
+
v_s_ = 0.0 # terminal
|
| 85 |
else:
|
| 86 |
+
v_s_ = self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
|
|
|
|
| 87 |
|
| 88 |
buffer_v_target = []
|
| 89 |
+
for r in br[::-1]: # reverse buffer r
|
| 90 |
v_s_ = r + self.gamma * v_s_
|
| 91 |
buffer_v_target.append(v_s_)
|
| 92 |
buffer_v_target.reverse()
|
| 93 |
|
| 94 |
loss = self.lnet.loss_func(
|
| 95 |
v_wrap(np.vstack(bs)),
|
| 96 |
+
v_wrap(np.array(ba), dtype=np.int64)
|
| 97 |
+
if ba[0].dtype == np.int64
|
| 98 |
+
else v_wrap(np.vstack(ba)),
|
| 99 |
+
v_wrap(np.array(buffer_v_target)[:, None]),
|
| 100 |
)
|
| 101 |
|
| 102 |
# calculate local gradients and push local parameters to global
|
|
|
|
| 110 |
self.lnet.load_state_dict(self.gnet.state_dict())
|
| 111 |
|
| 112 |
def save_model(self):
|
| 113 |
+
if (
|
| 114 |
+
self.save
|
| 115 |
+
and self.g_ep_r.value >= self.min_reward
|
| 116 |
+
and self.g_ep.value % self.every_n_save == 0
|
| 117 |
+
):
|
| 118 |
+
torch.save(
|
| 119 |
+
self.gnet.state_dict(),
|
| 120 |
+
os.path.join(self.model_checkpoint_dir, f"model_{self.g_ep.value}.pth"),
|
| 121 |
+
)
|
| 122 |
|
| 123 |
def record(self, ep_r, goal_word, action, action_number):
|
| 124 |
with self.g_ep.get_lock():
|
| 125 |
self.g_ep.value += 1
|
| 126 |
with self.g_ep_r.get_lock():
|
| 127 |
+
if self.g_ep_r.value == 0.0:
|
| 128 |
self.g_ep_r.value = ep_r
|
| 129 |
else:
|
| 130 |
self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
|
|
|
|
| 134 |
if self.g_ep.value % 100 == 0:
|
| 135 |
print(
|
| 136 |
self.name,
|
| 137 |
+
"Ep:",
|
| 138 |
+
self.g_ep.value,
|
| 139 |
"| Ep_r: %.0f" % self.g_ep_r.value,
|
| 140 |
+
"| Goal :",
|
| 141 |
+
goal_word,
|
| 142 |
+
"| Action: ",
|
| 143 |
+
action,
|
| 144 |
+
"| Actions: ",
|
| 145 |
+
action_number,
|
| 146 |
)
|
api_rest/api.py
CHANGED
|
@@ -1,30 +1,32 @@
|
|
| 1 |
import random
|
| 2 |
-
|
| 3 |
-
from flask import Flask,
|
| 4 |
from flask_cors import cross_origin
|
| 5 |
-
|
|
|
|
| 6 |
from wordle_env.wordle import get_env
|
|
|
|
| 7 |
|
| 8 |
app = Flask(__name__)
|
| 9 |
|
| 10 |
|
| 11 |
def validate_goal_word(word):
|
| 12 |
if not word:
|
| 13 |
-
return True,
|
| 14 |
if word.upper() not in target_vocabulary:
|
| 15 |
-
return True,
|
| 16 |
-
return False,
|
| 17 |
|
| 18 |
|
| 19 |
-
@app.route(
|
| 20 |
-
@cross_origin(origin=
|
| 21 |
def get_play():
|
| 22 |
# Get the goal word from the request
|
| 23 |
-
word = request.args.get(
|
| 24 |
|
| 25 |
error, msge = validate_goal_word(word)
|
| 26 |
if error:
|
| 27 |
-
return jsonify({
|
| 28 |
|
| 29 |
word = word.upper()
|
| 30 |
env = get_env()
|
|
@@ -32,16 +34,16 @@ def get_play():
|
|
| 32 |
# Call the play function with the goal word
|
| 33 |
# and return the attempts and the result
|
| 34 |
won, attempts = play(env, model_path, word)
|
| 35 |
-
return jsonify({
|
| 36 |
|
| 37 |
|
| 38 |
-
@app.route(
|
| 39 |
-
@cross_origin(origin=
|
| 40 |
def get_word():
|
| 41 |
# Get a random word from the target vocabulary used to train the model
|
| 42 |
word = random.choice(target_vocabulary)
|
| 43 |
word = word.upper()
|
| 44 |
-
return jsonify({
|
| 45 |
|
| 46 |
|
| 47 |
def create_app(settings_override=None):
|
|
@@ -58,5 +60,5 @@ def create_app(settings_override=None):
|
|
| 58 |
return app
|
| 59 |
|
| 60 |
|
| 61 |
-
if __name__ ==
|
| 62 |
app.run(debug=True)
|
|
|
|
| 1 |
import random
|
| 2 |
+
|
| 3 |
+
from flask import Flask, jsonify, request
|
| 4 |
from flask_cors import cross_origin
|
| 5 |
+
|
| 6 |
+
from a3c.play import get_play_model_path, play
|
| 7 |
from wordle_env.wordle import get_env
|
| 8 |
+
from wordle_env.words import target_vocabulary
|
| 9 |
|
| 10 |
app = Flask(__name__)
|
| 11 |
|
| 12 |
|
| 13 |
def validate_goal_word(word):
|
| 14 |
if not word:
|
| 15 |
+
return True, "Goal word not provided"
|
| 16 |
if word.upper() not in target_vocabulary:
|
| 17 |
+
return True, "Goal word not in vocabulary"
|
| 18 |
+
return False, ""
|
| 19 |
|
| 20 |
|
| 21 |
+
@app.route("/play_word", methods=["GET"])
|
| 22 |
+
@cross_origin(origin="*", headers=["Content-Type", "Authorization"])
|
| 23 |
def get_play():
|
| 24 |
# Get the goal word from the request
|
| 25 |
+
word = request.args.get("goal_word")
|
| 26 |
|
| 27 |
error, msge = validate_goal_word(word)
|
| 28 |
if error:
|
| 29 |
+
return jsonify({"error": msge}), 400
|
| 30 |
|
| 31 |
word = word.upper()
|
| 32 |
env = get_env()
|
|
|
|
| 34 |
# Call the play function with the goal word
|
| 35 |
# and return the attempts and the result
|
| 36 |
won, attempts = play(env, model_path, word)
|
| 37 |
+
return jsonify({"attempts": attempts, "won": won})
|
| 38 |
|
| 39 |
|
| 40 |
+
@app.route("/word", methods=["GET"])
|
| 41 |
+
@cross_origin(origin="*", headers=["Content-Type", "Authorization"])
|
| 42 |
def get_word():
|
| 43 |
# Get a random word from the target vocabulary used to train the model
|
| 44 |
word = random.choice(target_vocabulary)
|
| 45 |
word = word.upper()
|
| 46 |
+
return jsonify({"word": word})
|
| 47 |
|
| 48 |
|
| 49 |
def create_app(settings_override=None):
|
|
|
|
| 60 |
return app
|
| 61 |
|
| 62 |
|
| 63 |
+
if __name__ == "__main__":
|
| 64 |
app.run(debug=True)
|
main.py
CHANGED
|
@@ -3,23 +3,33 @@
|
|
| 3 |
import argparse
|
| 4 |
import os
|
| 5 |
import time
|
|
|
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
-
|
| 8 |
from a3c.eval import evaluate, evaluate_checkpoints
|
| 9 |
from a3c.play import suggest
|
|
|
|
| 10 |
from wordle_env.wordle import get_env
|
| 11 |
|
| 12 |
|
| 13 |
def training_mode(args, env, model_checkpoint_dir):
|
| 14 |
max_ep = args.games
|
| 15 |
start_time = time.time()
|
| 16 |
-
pretrained_model_path =
|
| 17 |
-
model_checkpoint_dir, args.model_name
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
global_ep, win_ep, gnet, res = train(
|
| 20 |
-
env,
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
)
|
| 24 |
print("--- %.0f seconds ---" % (time.time() - start_time))
|
| 25 |
print_results(global_ep, win_ep, res)
|
|
@@ -34,8 +44,8 @@ def evaluation_mode(args, env, model_checkpoint_dir):
|
|
| 34 |
|
| 35 |
def play_mode(args, env, model_checkpoint_dir):
|
| 36 |
print("Play mode")
|
| 37 |
-
words = [word.strip() for word in args.words.split(
|
| 38 |
-
states = [state.strip() for state in args.states.split(
|
| 39 |
pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
|
| 40 |
word = suggest(env, words, states, pretrained_model_path)
|
| 41 |
print(word)
|
|
@@ -45,8 +55,8 @@ def print_results(global_ep, win_ep, res):
|
|
| 45 |
print("Jugadas:", global_ep.value)
|
| 46 |
print("Ganadas:", win_ep.value)
|
| 47 |
plt.plot(res)
|
| 48 |
-
plt.ylabel(
|
| 49 |
-
plt.xlabel(
|
| 50 |
plt.show()
|
| 51 |
|
| 52 |
|
|
@@ -55,90 +65,80 @@ if __name__ == "__main__":
|
|
| 55 |
parser.add_argument(
|
| 56 |
"enviroment",
|
| 57 |
help="Enviroment (type of wordle game) used for training, \
|
| 58 |
-
example: WordleEnvFull-v0"
|
| 59 |
)
|
| 60 |
parser.add_argument(
|
| 61 |
"--models_dir",
|
| 62 |
help="Directory where models are saved (default=checkpoints)",
|
| 63 |
-
default=
|
| 64 |
)
|
| 65 |
-
subparsers = parser.add_subparsers(help=
|
| 66 |
|
| 67 |
parser_train = subparsers.add_parser(
|
| 68 |
-
|
| 69 |
-
help='Train a model from scratch or train from pretrained model'
|
| 70 |
)
|
| 71 |
parser_train.add_argument(
|
| 72 |
-
"--games",
|
| 73 |
-
"-g",
|
| 74 |
-
help="Number of games to train",
|
| 75 |
-
type=int,
|
| 76 |
-
required=True
|
| 77 |
)
|
| 78 |
parser_train.add_argument(
|
| 79 |
"--model_name",
|
| 80 |
"-m",
|
| 81 |
help="If want to train from a pretrained model, \
|
| 82 |
-
the name of the pretrained model file"
|
| 83 |
)
|
| 84 |
parser_train.add_argument(
|
| 85 |
"--gamma",
|
| 86 |
help="Gamma hyperparameter (discount factor) value",
|
| 87 |
type=float,
|
| 88 |
-
default=0.
|
| 89 |
)
|
| 90 |
parser_train.add_argument(
|
| 91 |
-
"--seed",
|
| 92 |
-
help="Seed used for random numbers generation",
|
| 93 |
-
type=int,
|
| 94 |
-
default=100
|
| 95 |
)
|
| 96 |
parser_train.add_argument(
|
| 97 |
"--save",
|
| 98 |
-
|
| 99 |
help="Save instances of the model while training",
|
| 100 |
-
action=
|
| 101 |
)
|
| 102 |
parser_train.add_argument(
|
| 103 |
"--min_reward",
|
| 104 |
help="The minimun global reward value achieved for saving the model",
|
| 105 |
type=float,
|
| 106 |
-
default=9.9
|
| 107 |
)
|
| 108 |
parser_train.add_argument(
|
| 109 |
"--every_n_save",
|
| 110 |
help="Check every n training steps to save the model",
|
| 111 |
type=int,
|
| 112 |
-
default=100
|
| 113 |
)
|
| 114 |
parser_train.set_defaults(func=training_mode)
|
| 115 |
|
| 116 |
parser_eval = subparsers.add_parser(
|
| 117 |
-
|
|
|
|
| 118 |
parser_eval.set_defaults(func=evaluation_mode)
|
| 119 |
|
| 120 |
parser_play = subparsers.add_parser(
|
| 121 |
-
|
| 122 |
-
help=
|
| 123 |
-
and the model will try to predict the goal word
|
| 124 |
)
|
| 125 |
parser_play.add_argument(
|
| 126 |
-
"--words",
|
| 127 |
-
"-w",
|
| 128 |
-
help="List of words played in the wordle game",
|
| 129 |
-
required=True
|
| 130 |
)
|
| 131 |
parser_play.add_argument(
|
| 132 |
"--states",
|
| 133 |
"-st",
|
| 134 |
help="List of states returned by playing each of the words",
|
| 135 |
-
required=True
|
| 136 |
)
|
| 137 |
parser_play.add_argument(
|
| 138 |
"--model_name",
|
| 139 |
"-m",
|
| 140 |
help="Name of the pretrained model file thich will play the game",
|
| 141 |
-
required=True
|
| 142 |
)
|
| 143 |
parser_play.set_defaults(func=play_mode)
|
| 144 |
|
|
|
|
| 3 |
import argparse
|
| 4 |
import os
|
| 5 |
import time
|
| 6 |
+
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
+
|
| 9 |
from a3c.eval import evaluate, evaluate_checkpoints
|
| 10 |
from a3c.play import suggest
|
| 11 |
+
from a3c.train import train
|
| 12 |
from wordle_env.wordle import get_env
|
| 13 |
|
| 14 |
|
| 15 |
def training_mode(args, env, model_checkpoint_dir):
|
| 16 |
max_ep = args.games
|
| 17 |
start_time = time.time()
|
| 18 |
+
pretrained_model_path = (
|
| 19 |
+
os.path.join(model_checkpoint_dir, args.model_name)
|
| 20 |
+
if args.model_name
|
| 21 |
+
else args.model_name
|
| 22 |
+
)
|
| 23 |
global_ep, win_ep, gnet, res = train(
|
| 24 |
+
env,
|
| 25 |
+
max_ep,
|
| 26 |
+
model_checkpoint_dir,
|
| 27 |
+
args.gamma,
|
| 28 |
+
args.seed,
|
| 29 |
+
pretrained_model_path,
|
| 30 |
+
args.save,
|
| 31 |
+
args.min_reward,
|
| 32 |
+
args.every_n_save,
|
| 33 |
)
|
| 34 |
print("--- %.0f seconds ---" % (time.time() - start_time))
|
| 35 |
print_results(global_ep, win_ep, res)
|
|
|
|
| 44 |
|
| 45 |
def play_mode(args, env, model_checkpoint_dir):
|
| 46 |
print("Play mode")
|
| 47 |
+
words = [word.strip() for word in args.words.split(",")]
|
| 48 |
+
states = [state.strip() for state in args.states.split(",")]
|
| 49 |
pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
|
| 50 |
word = suggest(env, words, states, pretrained_model_path)
|
| 51 |
print(word)
|
|
|
|
| 55 |
print("Jugadas:", global_ep.value)
|
| 56 |
print("Ganadas:", win_ep.value)
|
| 57 |
plt.plot(res)
|
| 58 |
+
plt.ylabel("Moving average ep reward")
|
| 59 |
+
plt.xlabel("Step")
|
| 60 |
plt.show()
|
| 61 |
|
| 62 |
|
|
|
|
| 65 |
parser.add_argument(
|
| 66 |
"enviroment",
|
| 67 |
help="Enviroment (type of wordle game) used for training, \
|
| 68 |
+
example: WordleEnvFull-v0",
|
| 69 |
)
|
| 70 |
parser.add_argument(
|
| 71 |
"--models_dir",
|
| 72 |
help="Directory where models are saved (default=checkpoints)",
|
| 73 |
+
default="checkpoints",
|
| 74 |
)
|
| 75 |
+
subparsers = parser.add_subparsers(help="sub-command help")
|
| 76 |
|
| 77 |
parser_train = subparsers.add_parser(
|
| 78 |
+
"train", help="Train a model from scratch or train from pretrained model"
|
|
|
|
| 79 |
)
|
| 80 |
parser_train.add_argument(
|
| 81 |
+
"--games", "-g", help="Number of games to train", type=int, required=True
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
)
|
| 83 |
parser_train.add_argument(
|
| 84 |
"--model_name",
|
| 85 |
"-m",
|
| 86 |
help="If want to train from a pretrained model, \
|
| 87 |
+
the name of the pretrained model file",
|
| 88 |
)
|
| 89 |
parser_train.add_argument(
|
| 90 |
"--gamma",
|
| 91 |
help="Gamma hyperparameter (discount factor) value",
|
| 92 |
type=float,
|
| 93 |
+
default=0.0,
|
| 94 |
)
|
| 95 |
parser_train.add_argument(
|
| 96 |
+
"--seed", help="Seed used for random numbers generation", type=int, default=100
|
|
|
|
|
|
|
|
|
|
| 97 |
)
|
| 98 |
parser_train.add_argument(
|
| 99 |
"--save",
|
| 100 |
+
"-s",
|
| 101 |
help="Save instances of the model while training",
|
| 102 |
+
action="store_true",
|
| 103 |
)
|
| 104 |
parser_train.add_argument(
|
| 105 |
"--min_reward",
|
| 106 |
help="The minimun global reward value achieved for saving the model",
|
| 107 |
type=float,
|
| 108 |
+
default=9.9,
|
| 109 |
)
|
| 110 |
parser_train.add_argument(
|
| 111 |
"--every_n_save",
|
| 112 |
help="Check every n training steps to save the model",
|
| 113 |
type=int,
|
| 114 |
+
default=100,
|
| 115 |
)
|
| 116 |
parser_train.set_defaults(func=training_mode)
|
| 117 |
|
| 118 |
parser_eval = subparsers.add_parser(
|
| 119 |
+
"eval", help="Evaluate saved models for the enviroment"
|
| 120 |
+
)
|
| 121 |
parser_eval.set_defaults(func=evaluation_mode)
|
| 122 |
|
| 123 |
parser_play = subparsers.add_parser(
|
| 124 |
+
"play",
|
| 125 |
+
help="Give the model a word and the state result \
|
| 126 |
+
and the model will try to predict the goal word",
|
| 127 |
)
|
| 128 |
parser_play.add_argument(
|
| 129 |
+
"--words", "-w", help="List of words played in the wordle game", required=True
|
|
|
|
|
|
|
|
|
|
| 130 |
)
|
| 131 |
parser_play.add_argument(
|
| 132 |
"--states",
|
| 133 |
"-st",
|
| 134 |
help="List of states returned by playing each of the words",
|
| 135 |
+
required=True,
|
| 136 |
)
|
| 137 |
parser_play.add_argument(
|
| 138 |
"--model_name",
|
| 139 |
"-m",
|
| 140 |
help="Name of the pretrained model file thich will play the game",
|
| 141 |
+
required=True,
|
| 142 |
)
|
| 143 |
parser_play.set_defaults(func=play_mode)
|
| 144 |
|
rs_wordle_player/firebase_connector.py
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
import os
|
| 2 |
-
import firebase_admin
|
| 3 |
-
from firebase_admin import credentials
|
| 4 |
-
from firebase_admin import firestore
|
| 5 |
from datetime import datetime
|
| 6 |
-
from dotenv import load_dotenv
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
class FirebaseConnector():
|
| 10 |
|
|
|
|
| 11 |
def __init__(self):
|
| 12 |
load_dotenv()
|
| 13 |
cert_path = self.get_credentials_path()
|
|
@@ -20,32 +19,34 @@ class FirebaseConnector():
|
|
| 20 |
return db
|
| 21 |
|
| 22 |
def get_credentials_path(self):
|
| 23 |
-
credentials_path = os.getenv(
|
| 24 |
return credentials_path
|
| 25 |
|
| 26 |
def get_user(self):
|
| 27 |
-
user = os.getenv(
|
| 28 |
return user
|
| 29 |
|
| 30 |
def get_state_from_fb_result(self, firebase_result):
|
| 31 |
-
result_number_map = {
|
| 32 |
-
'misplaced': '1',
|
| 33 |
-
'correct': '2'}
|
| 34 |
char_result_map = map(
|
| 35 |
lambda char_res: result_number_map[char_res], firebase_result
|
| 36 |
)
|
| 37 |
-
return
|
| 38 |
|
| 39 |
def today(self):
|
| 40 |
-
return datetime.today().strftime(
|
| 41 |
|
| 42 |
def today_user_results(self):
|
| 43 |
-
daily_results_col =
|
| 44 |
currentUser = self.get_user()
|
| 45 |
# Execute the query and get the first result
|
| 46 |
-
docs =
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
return docs
|
| 50 |
|
| 51 |
def today_user_attempts(self):
|
|
@@ -53,10 +54,10 @@ class FirebaseConnector():
|
|
| 53 |
attempted_words = []
|
| 54 |
if len(docs) > 0:
|
| 55 |
doc = docs[0]
|
| 56 |
-
attempted_words = doc.to_dict().get(
|
| 57 |
return attempted_words
|
| 58 |
|
| 59 |
def today_word(self):
|
| 60 |
-
words_col =
|
| 61 |
doc = self.db.collection(words_col).document(self.today())
|
| 62 |
-
return doc.get().get(
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
| 2 |
from datetime import datetime
|
|
|
|
| 3 |
|
| 4 |
+
import firebase_admin
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
from firebase_admin import credentials, firestore
|
| 7 |
|
|
|
|
| 8 |
|
| 9 |
+
class FirebaseConnector:
|
| 10 |
def __init__(self):
|
| 11 |
load_dotenv()
|
| 12 |
cert_path = self.get_credentials_path()
|
|
|
|
| 19 |
return db
|
| 20 |
|
| 21 |
def get_credentials_path(self):
|
| 22 |
+
credentials_path = os.getenv("RS_FIREBASE_CREDENTIALS_PATH")
|
| 23 |
return credentials_path
|
| 24 |
|
| 25 |
def get_user(self):
|
| 26 |
+
user = os.getenv("RS_WORDLE_USER")
|
| 27 |
return user
|
| 28 |
|
| 29 |
def get_state_from_fb_result(self, firebase_result):
|
| 30 |
+
result_number_map = {"incorrect": "0", "misplaced": "1", "correct": "2"}
|
|
|
|
|
|
|
| 31 |
char_result_map = map(
|
| 32 |
lambda char_res: result_number_map[char_res], firebase_result
|
| 33 |
)
|
| 34 |
+
return "".join(char_result_map)
|
| 35 |
|
| 36 |
def today(self):
|
| 37 |
+
return datetime.today().strftime("%Y%m%d")
|
| 38 |
|
| 39 |
def today_user_results(self):
|
| 40 |
+
daily_results_col = "dailyResults"
|
| 41 |
currentUser = self.get_user()
|
| 42 |
# Execute the query and get the first result
|
| 43 |
+
docs = (
|
| 44 |
+
self.db.collection(daily_results_col)
|
| 45 |
+
.where("user.email", "==", currentUser)
|
| 46 |
+
.where("date", "==", self.today())
|
| 47 |
+
.limit(1)
|
| 48 |
+
.get()
|
| 49 |
+
)
|
| 50 |
return docs
|
| 51 |
|
| 52 |
def today_user_attempts(self):
|
|
|
|
| 54 |
attempted_words = []
|
| 55 |
if len(docs) > 0:
|
| 56 |
doc = docs[0]
|
| 57 |
+
attempted_words = doc.to_dict().get("attemptedWords")
|
| 58 |
return attempted_words
|
| 59 |
|
| 60 |
def today_word(self):
|
| 61 |
+
words_col = "words"
|
| 62 |
doc = self.db.collection(words_col).document(self.today())
|
| 63 |
+
return doc.get().get("word")
|
rs_wordle_player/rs_wordle_player.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from a3c.play import get_play_model_path, suggest
|
| 2 |
from wordle_env.wordle import get_env
|
|
|
|
| 3 |
from .firebase_connector import FirebaseConnector
|
| 4 |
from .selenium_player import SeleniumPlayer
|
| 5 |
|
|
@@ -17,7 +18,7 @@ def get_attempts(fb_connector):
|
|
| 17 |
|
| 18 |
def is_game_finished(states):
|
| 19 |
if states:
|
| 20 |
-
return states[-1] ==
|
| 21 |
return False
|
| 22 |
|
| 23 |
|
|
@@ -49,5 +50,5 @@ def play():
|
|
| 49 |
return words, won
|
| 50 |
|
| 51 |
|
| 52 |
-
if __name__ ==
|
| 53 |
print(play())
|
|
|
|
| 1 |
from a3c.play import get_play_model_path, suggest
|
| 2 |
from wordle_env.wordle import get_env
|
| 3 |
+
|
| 4 |
from .firebase_connector import FirebaseConnector
|
| 5 |
from .selenium_player import SeleniumPlayer
|
| 6 |
|
|
|
|
| 18 |
|
| 19 |
def is_game_finished(states):
|
| 20 |
if states:
|
| 21 |
+
return states[-1] == "22222" or len(states) == 6
|
| 22 |
return False
|
| 23 |
|
| 24 |
|
|
|
|
| 50 |
return words, won
|
| 51 |
|
| 52 |
|
| 53 |
+
if __name__ == "__main__":
|
| 54 |
print(play())
|
rs_wordle_player/selenium_player.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import time
|
|
|
|
| 3 |
from dotenv import load_dotenv
|
| 4 |
from selenium import webdriver
|
| 5 |
from selenium.common.exceptions import UnexpectedAlertPresentException
|
|
@@ -8,8 +9,7 @@ from selenium.webdriver.common.by import By
|
|
| 8 |
from selenium.webdriver.common.keys import Keys
|
| 9 |
|
| 10 |
|
| 11 |
-
class SeleniumPlayer
|
| 12 |
-
|
| 13 |
def __init__(self):
|
| 14 |
self.wordle_url = self.get_wordle_url()
|
| 15 |
self.driver = self.get_driver()
|
|
@@ -24,22 +24,22 @@ class SeleniumPlayer():
|
|
| 24 |
|
| 25 |
def get_wordle_url(self):
|
| 26 |
load_dotenv()
|
| 27 |
-
return os.getenv(
|
| 28 |
|
| 29 |
def get_credentials(self):
|
| 30 |
load_dotenv()
|
| 31 |
-
username = os.getenv(
|
| 32 |
-
password = os.getenv(
|
| 33 |
return username, password
|
| 34 |
|
| 35 |
def logged_in(self):
|
| 36 |
-
return self.driver.current_url != self.wordle_url +
|
| 37 |
|
| 38 |
def log_in(self):
|
| 39 |
if not self.logged_in():
|
| 40 |
time.sleep(2)
|
| 41 |
-
login_div = self.driver.find_element(By.CLASS_NAME,
|
| 42 |
-
login_btns = login_div.find_elements(By.TAG_NAME,
|
| 43 |
login_btn = login_btns[0]
|
| 44 |
login_btn.click()
|
| 45 |
time.sleep(10)
|
|
@@ -47,32 +47,31 @@ class SeleniumPlayer():
|
|
| 47 |
login_window = self.driver.window_handles[1]
|
| 48 |
self.driver.switch_to.window(login_window)
|
| 49 |
username, password = self.get_credentials()
|
| 50 |
-
element = self.driver.find_element(By.ID,
|
| 51 |
element.send_keys(username)
|
| 52 |
element.send_keys(Keys.ENTER)
|
| 53 |
time.sleep(10)
|
| 54 |
-
element = self.driver.find_element(By.NAME,
|
| 55 |
element.send_keys(password)
|
| 56 |
element.send_keys(Keys.ENTER)
|
| 57 |
self.driver.switch_to.window(wordle_window)
|
| 58 |
time.sleep(5)
|
| 59 |
onboard_div = self.driver.find_element(
|
| 60 |
-
By.CLASS_NAME,
|
| 61 |
-
'onboarding-modal-container'
|
| 62 |
)
|
| 63 |
-
onboard_btn = onboard_div.find_elements(By.TAG_NAME,
|
| 64 |
onboard_btn[-1].click()
|
| 65 |
|
| 66 |
def play_word(self, word):
|
| 67 |
try:
|
| 68 |
-
element = self.driver.find_element(By.TAG_NAME,
|
| 69 |
# simulate typing the letters in the word into the input field
|
| 70 |
element.send_keys(word)
|
| 71 |
# simulate pressing the Enter key
|
| 72 |
element.send_keys(Keys.ENTER)
|
| 73 |
time.sleep(5)
|
| 74 |
except UnexpectedAlertPresentException:
|
| 75 |
-
print(
|
| 76 |
|
| 77 |
def finish(self):
|
| 78 |
self.driver.quit()
|
|
|
|
| 1 |
import os
|
| 2 |
import time
|
| 3 |
+
|
| 4 |
from dotenv import load_dotenv
|
| 5 |
from selenium import webdriver
|
| 6 |
from selenium.common.exceptions import UnexpectedAlertPresentException
|
|
|
|
| 9 |
from selenium.webdriver.common.keys import Keys
|
| 10 |
|
| 11 |
|
| 12 |
+
class SeleniumPlayer:
|
|
|
|
| 13 |
def __init__(self):
|
| 14 |
self.wordle_url = self.get_wordle_url()
|
| 15 |
self.driver = self.get_driver()
|
|
|
|
| 24 |
|
| 25 |
def get_wordle_url(self):
|
| 26 |
load_dotenv()
|
| 27 |
+
return os.getenv("RS_WORDLE_URL")
|
| 28 |
|
| 29 |
def get_credentials(self):
|
| 30 |
load_dotenv()
|
| 31 |
+
username = os.getenv("RS_WORDLE_USER")
|
| 32 |
+
password = os.getenv("RS_WORDLE_PASSWORD")
|
| 33 |
return username, password
|
| 34 |
|
| 35 |
def logged_in(self):
|
| 36 |
+
return self.driver.current_url != self.wordle_url + "/login"
|
| 37 |
|
| 38 |
def log_in(self):
|
| 39 |
if not self.logged_in():
|
| 40 |
time.sleep(2)
|
| 41 |
+
login_div = self.driver.find_element(By.CLASS_NAME, "login-button")
|
| 42 |
+
login_btns = login_div.find_elements(By.TAG_NAME, "button")
|
| 43 |
login_btn = login_btns[0]
|
| 44 |
login_btn.click()
|
| 45 |
time.sleep(10)
|
|
|
|
| 47 |
login_window = self.driver.window_handles[1]
|
| 48 |
self.driver.switch_to.window(login_window)
|
| 49 |
username, password = self.get_credentials()
|
| 50 |
+
element = self.driver.find_element(By.ID, "identifierId")
|
| 51 |
element.send_keys(username)
|
| 52 |
element.send_keys(Keys.ENTER)
|
| 53 |
time.sleep(10)
|
| 54 |
+
element = self.driver.find_element(By.NAME, "password")
|
| 55 |
element.send_keys(password)
|
| 56 |
element.send_keys(Keys.ENTER)
|
| 57 |
self.driver.switch_to.window(wordle_window)
|
| 58 |
time.sleep(5)
|
| 59 |
onboard_div = self.driver.find_element(
|
| 60 |
+
By.CLASS_NAME, "onboarding-modal-container"
|
|
|
|
| 61 |
)
|
| 62 |
+
onboard_btn = onboard_div.find_elements(By.TAG_NAME, "button")
|
| 63 |
onboard_btn[-1].click()
|
| 64 |
|
| 65 |
def play_word(self, word):
|
| 66 |
try:
|
| 67 |
+
element = self.driver.find_element(By.TAG_NAME, "html")
|
| 68 |
# simulate typing the letters in the word into the input field
|
| 69 |
element.send_keys(word)
|
| 70 |
# simulate pressing the Enter key
|
| 71 |
element.send_keys(Keys.ENTER)
|
| 72 |
time.sleep(5)
|
| 73 |
except UnexpectedAlertPresentException:
|
| 74 |
+
print("Won game alert on screen")
|
| 75 |
|
| 76 |
def finish(self):
|
| 77 |
self.driver.quit()
|
wordle_env/__init__.py
CHANGED
|
@@ -1,13 +1,9 @@
|
|
| 1 |
-
from gym.envs.registration import (
|
| 2 |
-
registry,
|
| 3 |
-
register,
|
| 4 |
-
make,
|
| 5 |
-
spec,
|
| 6 |
-
load_env_plugins as _load_env_plugins,
|
| 7 |
-
)
|
| 8 |
import os
|
| 9 |
-
from . import wordle
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
register(
|
| 13 |
id="WordleEnv100OneAction-v0",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
|
| 3 |
+
from gym.envs.registration import load_env_plugins as _load_env_plugins
|
| 4 |
+
from gym.envs.registration import make, register, registry, spec
|
| 5 |
+
|
| 6 |
+
from . import wordle
|
| 7 |
|
| 8 |
register(
|
| 9 |
id="WordleEnv100OneAction-v0",
|
wordle_env/const.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
WORDLE_CHARS =
|
| 2 |
WORDLE_N = 5
|
| 3 |
REWARD = 10
|
| 4 |
CHAR_REWARD = 0.1
|
|
|
|
| 1 |
+
WORDLE_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
| 2 |
WORDLE_N = 5
|
| 3 |
REWARD = 10
|
| 4 |
CHAR_REWARD = 0.1
|
wordle_env/state.py
CHANGED
|
@@ -13,11 +13,11 @@ where status has codes
|
|
| 13 |
"""
|
| 14 |
import collections
|
| 15 |
from typing import List, Tuple
|
|
|
|
| 16 |
import numpy as np
|
| 17 |
|
| 18 |
from .const import CHAR_REWARD, WORDLE_CHARS, WORDLE_N
|
| 19 |
|
| 20 |
-
|
| 21 |
WordleState = np.ndarray
|
| 22 |
|
| 23 |
|
|
@@ -27,8 +27,8 @@ def get_nvec(max_turns: int):
|
|
| 27 |
|
| 28 |
def new(max_turns: int) -> WordleState:
|
| 29 |
return np.array(
|
| 30 |
-
[max_turns] + [0, 0, 0] * WORDLE_N * len(WORDLE_CHARS),
|
| 31 |
-
|
| 32 |
|
| 33 |
|
| 34 |
def remaining_steps(state: WordleState) -> int:
|
|
@@ -40,11 +40,7 @@ SOMEWHERE = 1
|
|
| 40 |
YES = 2
|
| 41 |
|
| 42 |
|
| 43 |
-
def update_from_mask(
|
| 44 |
-
state: WordleState,
|
| 45 |
-
word: str,
|
| 46 |
-
mask: List[int]
|
| 47 |
-
) -> WordleState:
|
| 48 |
"""
|
| 49 |
return a copy of state that has been updated to new state
|
| 50 |
|
|
@@ -84,14 +80,14 @@ def update_from_mask(
|
|
| 84 |
# Need to check this first in case there's prior maybe + yes
|
| 85 |
if c in prior_maybe:
|
| 86 |
# Then the maybe could be anywhere except here
|
| 87 |
-
state[offset+3*i:offset+3*i+3] = [1, 0, 0]
|
| 88 |
elif c in prior_yes:
|
| 89 |
# No maybe, definitely a yes,
|
| 90 |
# so it's zero everywhere except the yesses
|
| 91 |
for j in range(WORDLE_N):
|
| 92 |
# Only flip no if previously was maybe
|
| 93 |
-
if state[offset + 3 * j:offset + 3 * j + 3][1] == 1:
|
| 94 |
-
state[offset + 3 * j:offset + 3 * j + 3] = [1, 0, 0]
|
| 95 |
else:
|
| 96 |
# Just straight up no
|
| 97 |
_set_all_no(state, offset)
|
|
@@ -115,7 +111,7 @@ def get_mask(word: str, goal_word: str) -> List[int]:
|
|
| 115 |
mask[i] = 1
|
| 116 |
counts[c] -= 1
|
| 117 |
else:
|
| 118 |
-
for j in range(i+1, len(mask)):
|
| 119 |
if mask[j] == 2:
|
| 120 |
continue
|
| 121 |
mask[j] = 0
|
|
@@ -136,11 +132,7 @@ def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:
|
|
| 136 |
return update_from_mask(state, word, mask)
|
| 137 |
|
| 138 |
|
| 139 |
-
def update(
|
| 140 |
-
state: WordleState,
|
| 141 |
-
word: str,
|
| 142 |
-
goal_word: str
|
| 143 |
-
) -> Tuple[WordleState, float]:
|
| 144 |
state = state.copy()
|
| 145 |
reward = 0
|
| 146 |
state[0] -= 1
|
|
@@ -158,8 +150,7 @@ def update(
|
|
| 158 |
cint = ord(c) - ord(WORDLE_CHARS[0])
|
| 159 |
offset = 1 + cint * WORDLE_N * 3
|
| 160 |
if goal_word[i] != c:
|
| 161 |
-
if
|
| 162 |
-
goal_word.count(c) > processed_letters.count(c)):
|
| 163 |
# Char at position i = no,
|
| 164 |
# and in other positions maybe except it had a value before,
|
| 165 |
# other chars stay as they are
|
|
@@ -184,27 +175,27 @@ def _set_if_cero(state, offset, value):
|
|
| 184 |
# but only if it didnt have a value before
|
| 185 |
for char_idx in range(0, WORDLE_N * 3, 3):
|
| 186 |
char_offset = offset + char_idx
|
| 187 |
-
if tuple(state[char_offset: char_offset + 3]) == (0, 0, 0):
|
| 188 |
-
state[char_offset: char_offset + 3] = value
|
| 189 |
|
| 190 |
|
| 191 |
def _set_yes(state, offset, char_int, char_pos):
|
| 192 |
# char at position char_pos = yes,
|
| 193 |
# all other chars at position char_pos == no
|
| 194 |
pos_offset = 3 * char_pos
|
| 195 |
-
state[offset + pos_offset:offset + pos_offset + 3] = [0, 0, 1]
|
| 196 |
for ocint in range(len(WORDLE_CHARS)):
|
| 197 |
if ocint != char_int:
|
| 198 |
oc_offset = 1 + ocint * WORDLE_N * 3
|
| 199 |
yes_index = oc_offset + pos_offset
|
| 200 |
-
state[yes_index:yes_index + 3] = [1, 0, 0]
|
| 201 |
|
| 202 |
|
| 203 |
def _set_no(state, offset, char_pos):
|
| 204 |
# Set offset character = no at char_pos position
|
| 205 |
-
state[offset + 3 * char_pos:offset + 3 * char_pos + 3] = [1, 0, 0]
|
| 206 |
|
| 207 |
|
| 208 |
def _set_all_no(state, offset):
|
| 209 |
# Set offset character = no at all positions
|
| 210 |
-
state[offset:offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
|
|
|
|
| 13 |
"""
|
| 14 |
import collections
|
| 15 |
from typing import List, Tuple
|
| 16 |
+
|
| 17 |
import numpy as np
|
| 18 |
|
| 19 |
from .const import CHAR_REWARD, WORDLE_CHARS, WORDLE_N
|
| 20 |
|
|
|
|
| 21 |
WordleState = np.ndarray
|
| 22 |
|
| 23 |
|
|
|
|
| 27 |
|
| 28 |
def new(max_turns: int) -> WordleState:
|
| 29 |
return np.array(
|
| 30 |
+
[max_turns] + [0, 0, 0] * WORDLE_N * len(WORDLE_CHARS), dtype=np.int32
|
| 31 |
+
)
|
| 32 |
|
| 33 |
|
| 34 |
def remaining_steps(state: WordleState) -> int:
|
|
|
|
| 40 |
YES = 2
|
| 41 |
|
| 42 |
|
| 43 |
+
def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleState:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
"""
|
| 45 |
return a copy of state that has been updated to new state
|
| 46 |
|
|
|
|
| 80 |
# Need to check this first in case there's prior maybe + yes
|
| 81 |
if c in prior_maybe:
|
| 82 |
# Then the maybe could be anywhere except here
|
| 83 |
+
state[offset + 3 * i : offset + 3 * i + 3] = [1, 0, 0]
|
| 84 |
elif c in prior_yes:
|
| 85 |
# No maybe, definitely a yes,
|
| 86 |
# so it's zero everywhere except the yesses
|
| 87 |
for j in range(WORDLE_N):
|
| 88 |
# Only flip no if previously was maybe
|
| 89 |
+
if state[offset + 3 * j : offset + 3 * j + 3][1] == 1:
|
| 90 |
+
state[offset + 3 * j : offset + 3 * j + 3] = [1, 0, 0]
|
| 91 |
else:
|
| 92 |
# Just straight up no
|
| 93 |
_set_all_no(state, offset)
|
|
|
|
| 111 |
mask[i] = 1
|
| 112 |
counts[c] -= 1
|
| 113 |
else:
|
| 114 |
+
for j in range(i + 1, len(mask)):
|
| 115 |
if mask[j] == 2:
|
| 116 |
continue
|
| 117 |
mask[j] = 0
|
|
|
|
| 132 |
return update_from_mask(state, word, mask)
|
| 133 |
|
| 134 |
|
| 135 |
+
def update(state: WordleState, word: str, goal_word: str) -> Tuple[WordleState, float]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
state = state.copy()
|
| 137 |
reward = 0
|
| 138 |
state[0] -= 1
|
|
|
|
| 150 |
cint = ord(c) - ord(WORDLE_CHARS[0])
|
| 151 |
offset = 1 + cint * WORDLE_N * 3
|
| 152 |
if goal_word[i] != c:
|
| 153 |
+
if c in goal_word and goal_word.count(c) > processed_letters.count(c):
|
|
|
|
| 154 |
# Char at position i = no,
|
| 155 |
# and in other positions maybe except it had a value before,
|
| 156 |
# other chars stay as they are
|
|
|
|
| 175 |
# but only if it didnt have a value before
|
| 176 |
for char_idx in range(0, WORDLE_N * 3, 3):
|
| 177 |
char_offset = offset + char_idx
|
| 178 |
+
if tuple(state[char_offset : char_offset + 3]) == (0, 0, 0):
|
| 179 |
+
state[char_offset : char_offset + 3] = value
|
| 180 |
|
| 181 |
|
| 182 |
def _set_yes(state, offset, char_int, char_pos):
|
| 183 |
# char at position char_pos = yes,
|
| 184 |
# all other chars at position char_pos == no
|
| 185 |
pos_offset = 3 * char_pos
|
| 186 |
+
state[offset + pos_offset : offset + pos_offset + 3] = [0, 0, 1]
|
| 187 |
for ocint in range(len(WORDLE_CHARS)):
|
| 188 |
if ocint != char_int:
|
| 189 |
oc_offset = 1 + ocint * WORDLE_N * 3
|
| 190 |
yes_index = oc_offset + pos_offset
|
| 191 |
+
state[yes_index : yes_index + 3] = [1, 0, 0]
|
| 192 |
|
| 193 |
|
| 194 |
def _set_no(state, offset, char_pos):
|
| 195 |
# Set offset character = no at char_pos position
|
| 196 |
+
state[offset + 3 * char_pos : offset + 3 * char_pos + 3] = [1, 0, 0]
|
| 197 |
|
| 198 |
|
| 199 |
def _set_all_no(state, offset):
|
| 200 |
# Set offset character = no at all positions
|
| 201 |
+
state[offset : offset + 3 * WORDLE_N] = [1, 0, 0] * WORDLE_N
|
wordle_env/test_wordle.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import pytest
|
| 2 |
|
| 3 |
-
from . import wordle
|
| 4 |
-
from . import state
|
| 5 |
|
| 6 |
TESTWORDS = [
|
| 7 |
"APPAA",
|
|
|
|
| 1 |
import pytest
|
| 2 |
|
| 3 |
+
from . import state, wordle
|
|
|
|
| 4 |
|
| 5 |
TESTWORDS = [
|
| 6 |
"APPAA",
|
wordle_env/wordle.py
CHANGED
|
@@ -1,24 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gym
|
| 2 |
from gym import spaces
|
| 3 |
-
from typing import Optional, List
|
| 4 |
|
| 5 |
from . import state
|
| 6 |
-
from .const import
|
| 7 |
-
|
| 8 |
from .words import complete_vocabulary, target_vocabulary
|
| 9 |
|
| 10 |
-
import random
|
| 11 |
-
|
| 12 |
|
| 13 |
def _load_words(
|
| 14 |
-
limit: Optional[int] = None,
|
| 15 |
-
complete: Optional[bool] = False
|
| 16 |
) -> List[str]:
|
| 17 |
words = complete_vocabulary if complete else target_vocabulary
|
| 18 |
return words if not limit else words[:limit]
|
| 19 |
|
| 20 |
|
| 21 |
-
def get_env(env_id=
|
| 22 |
return gym.make(env_id)
|
| 23 |
|
| 24 |
|
|
@@ -42,13 +40,16 @@ class WordleEnvBase(gym.Env):
|
|
| 42 |
Initial state with turn 0, all chars Unvisited
|
| 43 |
"""
|
| 44 |
|
| 45 |
-
def __init__(
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
| 49 |
assert all(
|
| 50 |
len(w) == WORDLE_N for w in words
|
| 51 |
-
), f
|
| 52 |
self.words = words
|
| 53 |
self.max_turns = max_turns
|
| 54 |
self.allowable_words = allowable_words
|
|
@@ -57,8 +58,7 @@ class WordleEnvBase(gym.Env):
|
|
| 57 |
self.allowable_words = len(self.words)
|
| 58 |
|
| 59 |
self.action_space = spaces.Discrete(self.words_as_action_space())
|
| 60 |
-
self.observation_space = spaces.MultiDiscrete(
|
| 61 |
-
state.get_nvec(self.max_turns))
|
| 62 |
|
| 63 |
self.done = True
|
| 64 |
self.goal_word: int = -1
|
|
@@ -79,15 +79,15 @@ class WordleEnvBase(gym.Env):
|
|
| 79 |
word = self.words[action]
|
| 80 |
goal_word = self.words[self.goal_word]
|
| 81 |
# assert word in self.words, f'{word} not in words list'
|
| 82 |
-
self.state, r = self.state_updater(
|
| 83 |
-
|
| 84 |
-
|
| 85 |
|
| 86 |
reward = r
|
| 87 |
if action == self.goal_word:
|
| 88 |
self.done = True
|
| 89 |
# reward = REWARD
|
| 90 |
-
if state.remaining_steps(self.state) == self.max_turns-1:
|
| 91 |
reward = 0 # -10*REWARD # No reward for guessing off the bat
|
| 92 |
else:
|
| 93 |
reward = REWARD
|
|
@@ -100,7 +100,7 @@ class WordleEnvBase(gym.Env):
|
|
| 100 |
def reset(self):
|
| 101 |
self.state = state.new(self.max_turns)
|
| 102 |
self.done = False
|
| 103 |
-
random_word = random.choice(self.words[:self.allowable_words])
|
| 104 |
self.goal_word = self.words.index(random_word)
|
| 105 |
return self.state.copy()
|
| 106 |
|
|
@@ -121,8 +121,7 @@ class WordleEnv100OneAction(WordleEnvBase):
|
|
| 121 |
|
| 122 |
class WordleEnv100WithMask(WordleEnvBase):
|
| 123 |
def __init__(self):
|
| 124 |
-
super().__init__(words=_load_words(100),
|
| 125 |
-
mask_based_state_updates=True)
|
| 126 |
|
| 127 |
|
| 128 |
class WordleEnv100TwoAction(WordleEnvBase):
|
|
@@ -142,8 +141,7 @@ class WordleEnv100FullAction(WordleEnvBase):
|
|
| 142 |
|
| 143 |
class WordleEnv1000WithMask(WordleEnvBase):
|
| 144 |
def __init__(self):
|
| 145 |
-
super().__init__(words=_load_words(1000),
|
| 146 |
-
mask_based_state_updates=True)
|
| 147 |
|
| 148 |
|
| 149 |
class WordleEnv1000FullAction(WordleEnvBase):
|
|
@@ -158,5 +156,6 @@ class WordleEnvFull(WordleEnvBase):
|
|
| 158 |
|
| 159 |
class WordleEnvRealWithMask(WordleEnvBase):
|
| 160 |
def __init__(self):
|
| 161 |
-
super().__init__(
|
| 162 |
-
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
import gym
|
| 5 |
from gym import spaces
|
|
|
|
| 6 |
|
| 7 |
from . import state
|
| 8 |
+
from .const import REWARD, WORDLE_CHARS, WORDLE_N
|
|
|
|
| 9 |
from .words import complete_vocabulary, target_vocabulary
|
| 10 |
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def _load_words(
|
| 13 |
+
limit: Optional[int] = None, complete: Optional[bool] = False
|
|
|
|
| 14 |
) -> List[str]:
|
| 15 |
words = complete_vocabulary if complete else target_vocabulary
|
| 16 |
return words if not limit else words[:limit]
|
| 17 |
|
| 18 |
|
| 19 |
+
def get_env(env_id="WordleEnvFull-v0"):
|
| 20 |
return gym.make(env_id)
|
| 21 |
|
| 22 |
|
|
|
|
| 40 |
Initial state with turn 0, all chars Unvisited
|
| 41 |
"""
|
| 42 |
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
words: List[str],
|
| 46 |
+
max_turns: int = 6,
|
| 47 |
+
allowable_words: Optional[int] = None,
|
| 48 |
+
mask_based_state_updates: bool = False,
|
| 49 |
+
):
|
| 50 |
assert all(
|
| 51 |
len(w) == WORDLE_N for w in words
|
| 52 |
+
), f"Not all words of length {WORDLE_N}, {words}"
|
| 53 |
self.words = words
|
| 54 |
self.max_turns = max_turns
|
| 55 |
self.allowable_words = allowable_words
|
|
|
|
| 58 |
self.allowable_words = len(self.words)
|
| 59 |
|
| 60 |
self.action_space = spaces.Discrete(self.words_as_action_space())
|
| 61 |
+
self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))
|
|
|
|
| 62 |
|
| 63 |
self.done = True
|
| 64 |
self.goal_word: int = -1
|
|
|
|
| 79 |
word = self.words[action]
|
| 80 |
goal_word = self.words[self.goal_word]
|
| 81 |
# assert word in self.words, f'{word} not in words list'
|
| 82 |
+
self.state, r = self.state_updater(
|
| 83 |
+
state=self.state, word=word, goal_word=goal_word
|
| 84 |
+
)
|
| 85 |
|
| 86 |
reward = r
|
| 87 |
if action == self.goal_word:
|
| 88 |
self.done = True
|
| 89 |
# reward = REWARD
|
| 90 |
+
if state.remaining_steps(self.state) == self.max_turns - 1:
|
| 91 |
reward = 0 # -10*REWARD # No reward for guessing off the bat
|
| 92 |
else:
|
| 93 |
reward = REWARD
|
|
|
|
| 100 |
def reset(self):
|
| 101 |
self.state = state.new(self.max_turns)
|
| 102 |
self.done = False
|
| 103 |
+
random_word = random.choice(self.words[: self.allowable_words])
|
| 104 |
self.goal_word = self.words.index(random_word)
|
| 105 |
return self.state.copy()
|
| 106 |
|
|
|
|
| 121 |
|
| 122 |
class WordleEnv100WithMask(WordleEnvBase):
|
| 123 |
def __init__(self):
|
| 124 |
+
super().__init__(words=_load_words(100), mask_based_state_updates=True)
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
class WordleEnv100TwoAction(WordleEnvBase):
|
|
|
|
| 141 |
|
| 142 |
class WordleEnv1000WithMask(WordleEnvBase):
|
| 143 |
def __init__(self):
|
| 144 |
+
super().__init__(words=_load_words(1000), mask_based_state_updates=True)
|
|
|
|
| 145 |
|
| 146 |
|
| 147 |
class WordleEnv1000FullAction(WordleEnvBase):
|
|
|
|
| 156 |
|
| 157 |
class WordleEnvRealWithMask(WordleEnvBase):
|
| 158 |
def __init__(self):
|
| 159 |
+
super().__init__(
|
| 160 |
+
words=_load_words(), allowable_words=2315, mask_based_state_updates=True
|
| 161 |
+
)
|
wordle_env/words.py
CHANGED
|
@@ -7,7 +7,7 @@ _COMPLETE_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/\
|
|
| 7 |
_TARGET_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/\
|
| 8 |
94f3c0303ba6a7768b47583aff36654d/raw/\
|
| 9 |
d9cddf5e16140df9e14f19c2de76a0ef36fd2748/wordle-La.txt"
|
| 10 |
-
_DOWNLOADS_DIR =
|
| 11 |
_COMPLETE_VOCABULARY_FILENAME = "complete_vocabulary.txt"
|
| 12 |
_TARGET_VOCABULARY_FILENAME = "target_vocabulary.txt"
|
| 13 |
|
|
@@ -24,7 +24,11 @@ def _retrieve_vocabulary(url, filename, dir):
|
|
| 24 |
|
| 25 |
|
| 26 |
target_vocabulary = _retrieve_vocabulary(
|
| 27 |
-
_TARGET_VOCABULARY_URL, _TARGET_VOCABULARY_FILENAME, _DOWNLOADS_DIR
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
_TARGET_VOCABULARY_URL = "https://gist.githubusercontent.com/scholtes/\
|
| 8 |
94f3c0303ba6a7768b47583aff36654d/raw/\
|
| 9 |
d9cddf5e16140df9e14f19c2de76a0ef36fd2748/wordle-La.txt"
|
| 10 |
+
_DOWNLOADS_DIR = "."
|
| 11 |
_COMPLETE_VOCABULARY_FILENAME = "complete_vocabulary.txt"
|
| 12 |
_TARGET_VOCABULARY_FILENAME = "target_vocabulary.txt"
|
| 13 |
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
target_vocabulary = _retrieve_vocabulary(
|
| 27 |
+
_TARGET_VOCABULARY_URL, _TARGET_VOCABULARY_FILENAME, _DOWNLOADS_DIR
|
| 28 |
+
)
|
| 29 |
+
complete_vocabulary = (
|
| 30 |
+
_retrieve_vocabulary(
|
| 31 |
+
_COMPLETE_VOCABULARY_URL, _COMPLETE_VOCABULARY_FILENAME, _DOWNLOADS_DIR
|
| 32 |
+
)
|
| 33 |
+
+ target_vocabulary
|
| 34 |
+
)
|
wordle_game.py
CHANGED
|
@@ -1,30 +1,28 @@
|
|
| 1 |
-
from rich.prompt import Prompt
|
| 2 |
-
from rich.console import Console
|
| 3 |
from random import choice
|
| 4 |
-
from wordle_env.words import target_vocabulary, complete_vocabulary
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
| 11 |
|
| 12 |
-
WELCOME_MESSAGE = f
|
| 13 |
PLAYER_INSTRUCTIONS = "You may start guessing\n"
|
| 14 |
GUESS_STATEMENT = "\nEnter your guess"
|
| 15 |
ALLOWED_GUESSES = 6
|
| 16 |
|
| 17 |
|
| 18 |
def correct_place(letter):
|
| 19 |
-
return f
|
| 20 |
|
| 21 |
|
| 22 |
def correct_letter(letter):
|
| 23 |
-
return f
|
| 24 |
|
| 25 |
|
| 26 |
def incorrect_letter(letter):
|
| 27 |
-
return f
|
| 28 |
|
| 29 |
|
| 30 |
def check_guess(guess, answer):
|
|
@@ -34,19 +32,20 @@ def check_guess(guess, answer):
|
|
| 34 |
for i, letter in enumerate(guess):
|
| 35 |
if answer[i] == guess[i]:
|
| 36 |
guessed[i] = correct_place(letter)
|
| 37 |
-
wordle_pattern.append(SQUARES[
|
| 38 |
processed_letters.append(letter)
|
| 39 |
for i, letter in enumerate(guess):
|
| 40 |
if answer[i] != guess[i]:
|
| 41 |
-
if
|
| 42 |
-
|
|
|
|
| 43 |
guessed[i] = correct_letter(letter)
|
| 44 |
-
wordle_pattern.append(SQUARES[
|
| 45 |
else:
|
| 46 |
guessed[i] = incorrect_letter(letter)
|
| 47 |
-
wordle_pattern.append(SQUARES[
|
| 48 |
processed_letters.append(letter)
|
| 49 |
-
return
|
| 50 |
|
| 51 |
|
| 52 |
def game(console, chosen_word):
|
|
@@ -57,12 +56,15 @@ def game(console, chosen_word):
|
|
| 57 |
|
| 58 |
while not end_of_game:
|
| 59 |
guess = Prompt.ask(GUESS_STATEMENT).upper()
|
| 60 |
-
while (
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
| 62 |
if guess in already_guessed:
|
| 63 |
console.print("[red]You've already guessed this word!!\n[/]")
|
| 64 |
else:
|
| 65 |
-
console.print(
|
| 66 |
guess = Prompt.ask(GUESS_STATEMENT).upper()
|
| 67 |
already_guessed.append(guess)
|
| 68 |
guessed, pattern = check_guess(guess, chosen_word)
|
|
@@ -74,14 +76,13 @@ def game(console, chosen_word):
|
|
| 74 |
end_of_game = True
|
| 75 |
if len(already_guessed) == ALLOWED_GUESSES and guess != chosen_word:
|
| 76 |
console.print(f"\n[red]WORDLE X/{ALLOWED_GUESSES}[/]")
|
| 77 |
-
console.print(f
|
| 78 |
else:
|
| 79 |
-
console.print(
|
| 80 |
-
f"\n[green]WORDLE {len(already_guessed)}/{ALLOWED_GUESSES}[/]\n")
|
| 81 |
console.print(*full_wordle_pattern, sep="\n")
|
| 82 |
|
| 83 |
|
| 84 |
-
if __name__ ==
|
| 85 |
console = Console()
|
| 86 |
chosen_word = choice(target_vocabulary)
|
| 87 |
console.print(WELCOME_MESSAGE)
|
|
|
|
|
|
|
|
|
|
| 1 |
from random import choice
|
|
|
|
| 2 |
|
| 3 |
+
from rich.console import Console
|
| 4 |
+
from rich.prompt import Prompt
|
| 5 |
+
|
| 6 |
+
from wordle_env.words import complete_vocabulary, target_vocabulary
|
| 7 |
+
|
| 8 |
+
SQUARES = {"correct_place": "🟩", "correct_letter": "🟨", "incorrect_letter": "⬛"}
|
| 9 |
|
| 10 |
+
WELCOME_MESSAGE = f"\n[white on blue] WELCOME TO WORDLE [/]\n"
|
| 11 |
PLAYER_INSTRUCTIONS = "You may start guessing\n"
|
| 12 |
GUESS_STATEMENT = "\nEnter your guess"
|
| 13 |
ALLOWED_GUESSES = 6
|
| 14 |
|
| 15 |
|
| 16 |
def correct_place(letter):
|
| 17 |
+
return f"[black on green]{letter}[/]"
|
| 18 |
|
| 19 |
|
| 20 |
def correct_letter(letter):
|
| 21 |
+
return f"[black on yellow]{letter}[/]"
|
| 22 |
|
| 23 |
|
| 24 |
def incorrect_letter(letter):
|
| 25 |
+
return f"[black on white]{letter}[/]"
|
| 26 |
|
| 27 |
|
| 28 |
def check_guess(guess, answer):
|
|
|
|
| 32 |
for i, letter in enumerate(guess):
|
| 33 |
if answer[i] == guess[i]:
|
| 34 |
guessed[i] = correct_place(letter)
|
| 35 |
+
wordle_pattern.append(SQUARES["correct_place"])
|
| 36 |
processed_letters.append(letter)
|
| 37 |
for i, letter in enumerate(guess):
|
| 38 |
if answer[i] != guess[i]:
|
| 39 |
+
if letter in answer and answer.count(letter) > processed_letters.count(
|
| 40 |
+
letter
|
| 41 |
+
):
|
| 42 |
guessed[i] = correct_letter(letter)
|
| 43 |
+
wordle_pattern.append(SQUARES["correct_letter"])
|
| 44 |
else:
|
| 45 |
guessed[i] = incorrect_letter(letter)
|
| 46 |
+
wordle_pattern.append(SQUARES["incorrect_letter"])
|
| 47 |
processed_letters.append(letter)
|
| 48 |
+
return "".join(guessed), "".join(wordle_pattern)
|
| 49 |
|
| 50 |
|
| 51 |
def game(console, chosen_word):
|
|
|
|
| 56 |
|
| 57 |
while not end_of_game:
|
| 58 |
guess = Prompt.ask(GUESS_STATEMENT).upper()
|
| 59 |
+
while (
|
| 60 |
+
len(guess) != 5
|
| 61 |
+
or guess in already_guessed
|
| 62 |
+
or guess not in complete_vocabulary
|
| 63 |
+
):
|
| 64 |
if guess in already_guessed:
|
| 65 |
console.print("[red]You've already guessed this word!!\n[/]")
|
| 66 |
else:
|
| 67 |
+
console.print("[red]Please enter a valid 5-letter word!!\n[/]")
|
| 68 |
guess = Prompt.ask(GUESS_STATEMENT).upper()
|
| 69 |
already_guessed.append(guess)
|
| 70 |
guessed, pattern = check_guess(guess, chosen_word)
|
|
|
|
| 76 |
end_of_game = True
|
| 77 |
if len(already_guessed) == ALLOWED_GUESSES and guess != chosen_word:
|
| 78 |
console.print(f"\n[red]WORDLE X/{ALLOWED_GUESSES}[/]")
|
| 79 |
+
console.print(f"\n[green]Correct Word: {chosen_word}[/]")
|
| 80 |
else:
|
| 81 |
+
console.print(f"\n[green]WORDLE {len(already_guessed)}/{ALLOWED_GUESSES}[/]\n")
|
|
|
|
| 82 |
console.print(*full_wordle_pattern, sep="\n")
|
| 83 |
|
| 84 |
|
| 85 |
+
if __name__ == "__main__":
|
| 86 |
console = Console()
|
| 87 |
chosen_word = choice(target_vocabulary)
|
| 88 |
console.print(WELCOME_MESSAGE)
|