Spaces: Running on Zero
| from typing import List, Optional | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from transformers import ( | |
| AutoModelWithLMHead, | |
| AutoTokenizer, | |
| GPT2Model, | |
| GPT2Tokenizer, | |
| LogitsProcessorList, | |
| PreTrainedModel, | |
| PreTrainedTokenizer, | |
| TemperatureLogitsWarper, | |
| TopKLogitsWarper, | |
| ) | |
| from mario_gpt.prompter import Prompter | |
# HuggingFace Hub id of the pretrained MarioGPT checkpoint
# (GPT-2 based, trained with a 700-token context window).
PRETRAINED_MODEL_PATH = "shyamsn97/Mario-GPT2-700-context-length"
class MarioLM:
    """GPT-2 based language model that generates Mario levels token-by-token,
    conditioned on prompt embeddings produced by a :class:`Prompter`.

    Parameters
    ----------
    lm : Optional[PreTrainedModel]
        Language model used for generation. When ``None``, the pretrained
        MarioGPT checkpoint (``PRETRAINED_MODEL_PATH``) is loaded.
    tokenizer : Optional[PreTrainedTokenizer]
        Tokenizer matching ``lm``. When ``None``, the pretrained tokenizer
        is loaded.
    context_len : int
        Maximum number of tokens kept in the generation context window.
    prompter : Optional[Prompter]
        Converts text prompts to encoder hidden states. When ``None``, a
        default ``Prompter`` built from the tokenizer is used.
    """

    def __init__(
        self,
        lm: Optional[PreTrainedModel] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        context_len: int = 700,
        prompter: Optional[Prompter] = None,
    ):
        self.context_len = context_len
        # Fall back to the pretrained checkpoint/tokenizer/prompter when the
        # caller does not supply one.
        self.lm = lm if lm is not None else self.load_pretrained_lm()
        self.tokenizer = (
            tokenizer if tokenizer is not None else self.load_pretrained_tokenizer()
        )
        self.prompter = prompter if prompter is not None else Prompter(self.tokenizer)

    @property
    def device(self) -> torch.device:
        """Device the underlying language model lives on.

        Bug fix: declared as a property (it was a plain method), because
        ``sample`` reads ``self.device`` without calling it — passing a bound
        method to ``Tensor.to`` would raise at runtime.
        """
        return self.lm.device

    def to(self, device: torch.device):
        """Move the language model to ``device``; returns ``self`` (fluent)."""
        self.lm = self.lm.to(device)
        return self

    def load_pretrained_lm(self) -> GPT2Model:
        """Download/load the pretrained MarioGPT language model."""
        print(f"Using {PRETRAINED_MODEL_PATH} model")
        return AutoModelWithLMHead.from_pretrained(PRETRAINED_MODEL_PATH)

    def load_pretrained_tokenizer(self) -> GPT2Tokenizer:
        """Download/load the tokenizer paired with the pretrained model."""
        print(f"Using {PRETRAINED_MODEL_PATH} tokenizer")
        return AutoTokenizer.from_pretrained(PRETRAINED_MODEL_PATH)

    def sample_step(
        self,
        seed: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temperature: float = 2.0,
    ):
        """Sample one next token for each sequence in ``seed``.

        Parameters
        ----------
        seed : torch.Tensor
            Token ids forming the generation context — assumed
            ``(batch, seq_len)``; TODO confirm against callers.
        encoder_hidden_states : torch.Tensor
            Prompt embeddings the model conditions on; returned unchanged.
        temperature : float
            Softmax temperature; higher values flatten the distribution.

        Returns
        -------
        (next_tokens, encoder_hidden_states)
            Sampled next-token ids (one per batch row) and the unchanged
            encoder hidden states, so callers can thread them through.
        """
        logits_processor = LogitsProcessorList()
        # Restrict sampling to the top-16 logits (the level "alphabet"),
        # then rescale by temperature.
        logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(16),  # number of characters
                TemperatureLogitsWarper(temperature),
            ]
        )
        with torch.no_grad():
            attention_mask = torch.ones_like(seed).to(seed.device)
            out = self.lm(
                input_ids=seed,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                token_type_ids=None,
            )
            logits = out.logits.detach()
            if logits.dim() == 2:
                # Unbatched output: add batch and sequence dimensions.
                logits = logits.view(1, 1, -1)
            # Only the final position's logits matter for the next token.
            next_token_logits = logits[:, -1, :]
            next_token_scores = logits_processor(seed, next_token_logits)
            next_token_scores = logits_warper(seed, next_token_scores)
            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            return next_tokens, encoder_hidden_states

    def sample(
        self,
        seed: Optional[torch.Tensor] = None,
        prompts: Optional[List[str]] = None,
        num_steps: int = 1,
        temperature: float = 2.0,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        use_tqdm: bool = False,
    ):
        """Autoregressively sample ``num_steps`` level tokens.

        Parameters
        ----------
        seed : Optional[torch.Tensor]
            Starting token ids; defaults to a single ``"X"`` token.
        prompts : Optional[List[str]]
            Text prompts, one per batch row; ignored when
            ``encoder_hidden_states`` is given. When both are ``None``, a
            random prompt is sampled per row.
        num_steps : int
            Number of tokens to generate.
        temperature : float
            Sampling temperature forwarded to ``sample_step``.
        encoder_hidden_states : Optional[torch.Tensor]
            Precomputed prompt embeddings; computed from ``prompts`` or
            sampled when ``None``.
        use_tqdm : bool
            Show a progress bar while sampling.

        Returns
        -------
        torch.Tensor
            ``seed`` concatenated with the generated tokens.
        """
        # Leave headroom below the full context window; 28 = two 14-tile
        # level columns (presumably — TODO confirm).
        context_len = self.context_len - 28
        # Bug fix: remember the caller's train/eval mode. The original
        # unconditionally called self.lm.train() at the end, clobbering
        # eval mode and skipping restoration entirely on exceptions.
        was_training = self.lm.training
        self.lm.eval()
        try:
            with torch.no_grad():
                if seed is None:
                    # Default seed: the single "X" token.
                    seed = self.tokenizer("X", return_tensors="pt").input_ids.view(1, 1)
                out = seed.to(self.device)
                if encoder_hidden_states is None:
                    if prompts is not None:
                        encoder_hidden_states = torch.stack(
                            [self.prompter.output_hidden(prompt) for prompt in prompts]
                        )
                    else:
                        # No prompts: sample a random prompt per batch row.
                        encoder_hidden_states = torch.stack(
                            [
                                self.prompter(sample_prompt=True)[1]
                                for _ in range(seed.shape[0])
                            ]
                        )
                encoder_hidden_states = encoder_hidden_states.to(
                    self.device
                )  # b x 1 x hidden_dim
                encoder_hidden_states = encoder_hidden_states.view(seed.shape[0], 1, -1)
                bar = tqdm(range(num_steps)) if use_tqdm else range(num_steps)
                for _ in bar:
                    inp = out.clone()
                    if len(out.shape) > 0 and out.shape[-1] > context_len:
                        # Truncate to the most recent tokens, padding the
                        # window so it stays aligned to whole level columns.
                        diff = inp.shape[-1] % 14  # height of mario level
                        ctx = context_len + diff
                        inp = inp[:, -ctx:].clone()
                    next_tokens, encoder_hidden_states = self.sample_step(
                        inp,
                        encoder_hidden_states=encoder_hidden_states,
                        temperature=temperature,
                    )
                    out = torch.cat([out, next_tokens.unsqueeze(-1)], dim=-1)
                    if use_tqdm:
                        bar.set_description(
                            f"shape: {inp.shape}, {out.shape} first: {inp[0][0]}, last: {out[0][-1]}"
                        )
                if use_tqdm:
                    bar.close()
        finally:
            # Restore whatever mode the model was in before sampling.
            self.lm.train(was_training)
        return out