Spaces:
Sleeping
Sleeping
import random | |
from typing import List, Optional | |
import gym | |
from gym import spaces | |
from . import state | |
from .const import REWARD, WORDLE_CHARS, WORDLE_N | |
from .words import complete_vocabulary, target_vocabulary | |
def _load_words(
    limit: Optional[int] = None, complete: Optional[bool] = False
) -> List[str]:
    """Return the word list used to build an environment.

    Selects ``complete_vocabulary`` when ``complete`` is truthy and
    ``target_vocabulary`` otherwise. A falsy ``limit`` (``None`` or ``0``)
    returns the selected list itself; any other value returns the slice
    ``vocabulary[:limit]``.
    """
    if complete:
        vocabulary = complete_vocabulary
    else:
        vocabulary = target_vocabulary

    if not limit:
        return vocabulary
    return vocabulary[:limit]
def get_env(env_id="WordleEnvFull-v0"):
    """Build and return the gym environment registered under ``env_id``."""
    environment = gym.make(env_id)
    return environment
class WordleEnvBase(gym.Env):
    """
    Wordle as a gym environment.

    Actions:
        An index into ``words``; can play any 5 letter word in vocabulary
        * 13k for full vocab
    State space is defined as:
        * 6 possibilities for turns (WORDLE_TURNS)
        * For each in VALID_CHARS [A-Z]
          can be in one of 3^WORDLE_N states: (No, Maybe, Yes)
          for full game, this is (3^5)^26
        Each state has 1 + 5*26 possibilities
    Reward:
        REWARD for guessing the right word (0 when it is guessed on the
        very first turn), -REWARD for not guessing the right word once no
        turns remain. For any other guess, the reward returned by the
        state updater (originally documented as 1 for every letter
        correctly guessed on each try).
    Starting State:
        Random goal word
        Initial state with turn 0, all chars Unvisited
    """

    def __init__(
        self,
        words: List[str],
        max_turns: int = 6,
        allowable_words: Optional[int] = None,
        mask_based_state_updates: bool = False,
    ):
        """Configure the vocabulary, spaces and state-update strategy.

        Args:
            words: Vocabulary; every entry must have length ``WORDLE_N``.
            max_turns: Number of guesses allowed per episode.
            allowable_words: Goal words are drawn from the first
                ``allowable_words`` entries of ``words``; a falsy value
                means the whole vocabulary.
            mask_based_state_updates: Use ``state.update_mask`` instead of
                ``state.update`` to advance the state.

        Raises:
            AssertionError: If any word does not have length ``WORDLE_N``.
        """
        assert all(
            len(w) == WORDLE_N for w in words
        ), f"Not all words of length {WORDLE_N}, {words}"
        self.words = words
        self.max_turns = max_turns
        self.allowable_words = allowable_words
        self.mask_based_state_updates = mask_based_state_updates
        if not self.allowable_words:
            self.allowable_words = len(self.words)

        self.action_space = spaces.Discrete(self.words_as_action_space())
        self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))

        # The environment starts "done" so that step() refuses to run
        # until reset() has picked a goal word and built a fresh state.
        self.done = True
        self.goal_word: int = -1
        self.state: Optional[state.WordleState] = None

        self.state_updater = state.update
        if self.mask_based_state_updates:
            self.state_updater = state.update_mask

    def step(self, action: int):
        """Play the word at index ``action`` and advance the episode.

        Returns:
            A tuple ``(state_copy, reward, done, info)`` where ``info`` is
            ``{"goal_id": <index of the goal word>}``.

        Raises:
            ValueError: If called after the episode finished (or before
                the first ``reset()``).
        """
        if self.done:
            raise ValueError(
                "You are calling 'step()' even though this "
                "environment has already returned done = True. You "
                "should always call 'reset()' once you receive 'done = "
                "True' -- any further steps are undefined behavior."
            )
        word = self.words[action]
        goal_word = self.words[self.goal_word]

        self.state, reward = self.state_updater(
            state=self.state, word=word, goal_word=goal_word
        )

        # Compare the words themselves rather than their indices: an index
        # comparison misses a correct guess made through a negative index
        # or through a duplicate vocabulary entry, leaving the episode
        # running even though the goal word was played.
        if word == goal_word:
            self.done = True
            if state.remaining_steps(self.state) == self.max_turns - 1:
                reward = 0  # No reward for guessing off the bat
            else:
                reward = REWARD
        elif state.remaining_steps(self.state) == 0:
            # Out of turns without finding the goal word.
            self.done = True
            reward = -REWARD

        goal_dict = {"goal_id": self.goal_word}
        return self.state.copy(), reward, self.done, goal_dict

    def reset(self):
        """Start a new episode with a random goal word; return a state copy."""
        self.state = state.new(self.max_turns)
        self.done = False
        # Only the first ``allowable_words`` entries may become the goal.
        random_word = random.choice(self.words[: self.allowable_words])
        self.goal_word = self.words.index(random_word)
        return self.state.copy()

    def set_goal_word(self, goal_word: str):
        """Set the goal by word; raises ValueError if it is not in ``words``."""
        self.goal_word = self.words.index(goal_word)

    def set_goal_encoded(self, goal_encoded: int):
        """Set the goal directly by its index into ``words`` (not validated)."""
        self.goal_word = goal_encoded

    def words_as_action_space(self):
        """Return the number of discrete actions (one per vocabulary word)."""
        return len(self.words)
class WordleEnv100OneAction(WordleEnvBase):
    """100-word vocabulary; the goal is drawn from the first word only."""

    def __init__(self):
        vocabulary = _load_words(100)
        super().__init__(words=vocabulary, allowable_words=1)
class WordleEnv100WithMask(WordleEnvBase):
    """100-word vocabulary using mask-based state updates."""

    def __init__(self):
        vocabulary = _load_words(100)
        super().__init__(words=vocabulary, mask_based_state_updates=True)
class WordleEnv100TwoAction(WordleEnvBase):
    """100-word vocabulary; the goal is drawn from the first two words."""

    def __init__(self):
        vocabulary = _load_words(100)
        super().__init__(words=vocabulary, allowable_words=2)
class WordleEnv100fiftyAction(WordleEnvBase):
    """100-word vocabulary; the goal is drawn from the first fifty words."""

    def __init__(self):
        vocabulary = _load_words(100)
        super().__init__(words=vocabulary, allowable_words=50)
class WordleEnv100FullAction(WordleEnvBase):
    """100-word vocabulary; the goal is drawn from the first 100 words."""

    def __init__(self):
        vocabulary = _load_words(100)
        super().__init__(words=vocabulary, allowable_words=100)
class WordleEnv1000WithMask(WordleEnvBase):
    """1000-word vocabulary using mask-based state updates."""

    def __init__(self):
        vocabulary = _load_words(1000)
        super().__init__(words=vocabulary, mask_based_state_updates=True)
class WordleEnv1000FullAction(WordleEnvBase):
    """1000-word vocabulary; the goal is drawn from the first 1000 words."""

    def __init__(self):
        vocabulary = _load_words(1000)
        super().__init__(words=vocabulary, allowable_words=1000)
class WordleEnvFull(WordleEnvBase):
    """Environment over the whole list returned by ``_load_words()``."""

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(words=vocabulary)
class WordleEnvRealWithMask(WordleEnvBase):
    """Whole ``_load_words()`` list, goals from the first 2315, mask updates."""

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(
            words=vocabulary,
            allowable_words=2315,
            mask_based_state_updates=True,
        )