santit96's picture
Fix code style with black and isort
c412087
import random
from typing import List, Optional
import gym
from gym import spaces
from . import state
from .const import REWARD, WORDLE_CHARS, WORDLE_N
from .words import complete_vocabulary, target_vocabulary
def _load_words(
    limit: Optional[int] = None, complete: Optional[bool] = False
) -> List[str]:
    """Return the vocabulary used to build an environment.

    Args:
        limit: Maximum number of words to return, taken from the start of
            the list. A falsy value (``None`` or ``0``) returns the whole
            list.
        complete: When truthy, draw from ``complete_vocabulary``; otherwise
            draw from ``target_vocabulary``.

    Returns:
        The selected vocabulary itself (not a copy) when no limit applies,
        otherwise a new list holding its first ``limit`` entries.
    """
    if complete:
        vocabulary = complete_vocabulary
    else:
        vocabulary = target_vocabulary

    if limit:
        return vocabulary[:limit]
    return vocabulary
def get_env(env_id="WordleEnvFull-v0"):
    """Create an environment by passing ``env_id`` to ``gym.make``.

    Args:
        env_id: Environment id handed to ``gym.make``; defaults to
            ``"WordleEnvFull-v0"``.

    Returns:
        Whatever ``gym.make`` returns for that id.
    """
    environment = gym.make(env_id)
    return environment
class WordleEnvBase(gym.Env):
    """Gym environment for Wordle played over a fixed word list.

    Actions:
        An integer index into ``words``; the action space is
        ``Discrete(words_as_action_space())``, i.e. one action per word.

    Observations:
        A ``MultiDiscrete`` space whose dimensions come from
        ``state.get_nvec(max_turns)``. The encoding itself lives in the
        ``state`` module. Per the original author's notes (not verifiable
        from this file): a turn counter plus, for each letter A-Z, a
        No/Maybe/Yes marker for each of the ``WORDLE_N`` positions.

    Reward:
        * Every step starts from the reward returned by the state updater.
        * Playing the goal word ends the episode with ``REWARD``, except
          that a hit which leaves ``max_turns - 1`` steps remaining yields
          0 (author's note: no reward for guessing off the bat).
        * Running out of steps without a hit ends the episode with
          ``-REWARD``.

    Starting state:
        ``reset()`` builds a fresh state and picks a random goal word from
        the first ``allowable_words`` entries of ``words``.
    """

    def __init__(
        self,
        words: List[str],
        max_turns: int = 6,
        allowable_words: Optional[int] = None,
        mask_based_state_updates: bool = False,
    ):
        """Configure the vocabulary, spaces and state-update strategy.

        Args:
            words: Playable words; each must have length ``WORDLE_N``.
            max_turns: Value passed to ``state.get_nvec`` and ``state.new``
                and used in the first-turn reward check.
            allowable_words: How many leading entries of ``words`` may be
                chosen as the goal. A falsy value means all of them.
            mask_based_state_updates: Use ``state.update_mask`` instead of
                ``state.update`` to advance the state.
        """
        assert all(
            len(w) == WORDLE_N for w in words
        ), f"Not all words of length {WORDLE_N}, {words}"

        self.words = words
        self.max_turns = max_turns
        self.mask_based_state_updates = mask_based_state_updates
        # A falsy limit (None or 0) falls back to the full word list.
        self.allowable_words = allowable_words if allowable_words else len(words)

        self.action_space = spaces.Discrete(self.words_as_action_space())
        self.observation_space = spaces.MultiDiscrete(state.get_nvec(self.max_turns))

        # step() refuses to run until reset() clears this flag.
        self.done = True
        # Index into self.words; stays -1 until reset() or a set_goal_* call.
        self.goal_word: int = -1
        self.state: Optional[state.WordleState] = None

        self.state_updater = (
            state.update_mask if self.mask_based_state_updates else state.update
        )

    def step(self, action: int):
        """Play the word at index ``action`` and advance the episode.

        Args:
            action: Index into ``self.words`` of the guessed word.

        Returns:
            A tuple ``(state_copy, reward, done, info)`` where ``info`` is
            ``{"goal_id": <goal word index>}``.

        Raises:
            ValueError: If the episode is already finished (or was never
                started with ``reset()``).
        """
        if self.done:
            raise ValueError(
                "You are calling 'step()' even though this "
                "environment has already returned done = True. You "
                "should always call 'reset()' once you receive 'done = "
                "True' -- any further steps are undefined behavior."
            )

        guess = self.words[action]
        target = self.words[self.goal_word]
        self.state, reward = self.state_updater(
            state=self.state, word=guess, goal_word=target
        )

        if action == self.goal_word:
            self.done = True
            # No reward for guessing off the bat.
            solved_off_the_bat = (
                state.remaining_steps(self.state) == self.max_turns - 1
            )
            reward = 0 if solved_off_the_bat else REWARD
        elif state.remaining_steps(self.state) == 0:
            # Out of steps without hitting the goal.
            self.done = True
            reward = -REWARD

        info = {"goal_id": self.goal_word}
        return self.state.copy(), reward, self.done, info

    def reset(self):
        """Start a new episode with a randomly chosen goal word.

        Returns:
            A copy of the freshly created state.
        """
        self.state = state.new(self.max_turns)
        self.done = False
        # Goals come only from the leading `allowable_words` entries.
        candidates = self.words[: self.allowable_words]
        self.goal_word = self.words.index(random.choice(candidates))
        return self.state.copy()

    def set_goal_word(self, goal_word: str):
        """Set the goal to ``goal_word``, stored as its index in ``words``.

        ``list.index`` raises ``ValueError`` when the word is not present.
        """
        self.goal_word = self.words.index(goal_word)

    def set_goal_encoded(self, goal_encoded: int):
        """Set the goal directly from an index into ``words``."""
        self.goal_word = goal_encoded

    def words_as_action_space(self):
        """Return the number of discrete actions: one per word."""
        return len(self.words)
class WordleEnv100OneAction(WordleEnvBase):
    """Preset: ``_load_words(100)`` vocabulary, goal limited to the first word."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, allowable_words=1)
class WordleEnv100WithMask(WordleEnvBase):
    """Preset: ``_load_words(100)`` vocabulary with mask-based state updates."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, mask_based_state_updates=True)
class WordleEnv100TwoAction(WordleEnvBase):
    """Preset: ``_load_words(100)`` vocabulary, goal from the first two words."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, allowable_words=2)
class WordleEnv100fiftyAction(WordleEnvBase):
    """Preset: ``_load_words(100)`` vocabulary, goal from the first 50 words."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, allowable_words=50)
class WordleEnv100FullAction(WordleEnvBase):
    """Preset: ``_load_words(100)`` vocabulary, goal from the first 100 words."""

    def __init__(self):
        vocabulary = _load_words(limit=100)
        super().__init__(words=vocabulary, allowable_words=100)
class WordleEnv1000WithMask(WordleEnvBase):
    """Preset: ``_load_words(1000)`` vocabulary with mask-based state updates."""

    def __init__(self):
        vocabulary = _load_words(limit=1000)
        super().__init__(words=vocabulary, mask_based_state_updates=True)
class WordleEnv1000FullAction(WordleEnvBase):
    """Preset: ``_load_words(1000)`` vocabulary, goal from the first 1000 words."""

    def __init__(self):
        vocabulary = _load_words(limit=1000)
        super().__init__(words=vocabulary, allowable_words=1000)
class WordleEnvFull(WordleEnvBase):
    """Preset: the whole list returned by ``_load_words()``, base-class defaults."""

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(words=vocabulary)
class WordleEnvRealWithMask(WordleEnvBase):
    """Preset: whole ``_load_words()`` list, mask-based updates, goal capped at 2315.

    NOTE(review): 2315 is presumably the size of the official Wordle answer
    list -- confirm against the vocabulary files.
    """

    def __init__(self):
        vocabulary = _load_words()
        super().__init__(
            words=vocabulary,
            allowable_words=2315,
            mask_based_state_updates=True,
        )