# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import List

import torch

from llama.tokenizer import Tokenizer
from llama.model import Transformer


class LLaMA:
    def __init__(self, model: Transformer, tokenizer: Tokenizer, vision_model=None):
        self.model = model
        self.tokenizer = tokenizer
        self.vision_model = vision_model

    def generate(
        self,
        prompts: List[str],
        imgs=None,
        max_gen_len: int = 512,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
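        """Generate a completion for every prompt in the batch.

        If `imgs` is given and a vision model is attached, the images are
        encoded into vision tokens and the Transformer is run in 'caption'
        mode; otherwise generation runs in plain 'instruct' mode. Sampling
        uses temperature scaling plus top-p (nucleus) filtering, and falls
        back to greedy argmax when `temperature` is not positive.
        """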
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        mode = 'instruct'
        vision_tokens = None
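        # When images are supplied, encode them into vision tokens and switch
        # the model into captioning mode.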
        if imgs is not None and self.vision_model is not None:
            vision_tokens = self.vision_model(imgs)
            mode = 'caption'

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        min_prompt_size = min([len(t) for t in prompt_tokens])
        max_prompt_size = max([len(t) for t in prompt_tokens])
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
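        # Positions that already hold prompt tokens are masked so they are
        # never overwritten; decoding starts at the end of the shortest prompt
        # so every row in the batch advances in lockstep.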
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
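        # One position per step: only the slice [prev_pos:cur_pos) of tokens is
        # fed to the model, the usual pattern when the Transformer caches keys
        # and values for positions it has already processed.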
        for cur_pos in range(start_pos, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos, vision_tokens, mode)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

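        # Detokenize: drop the prompt prefix and truncate at the first EOS.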
        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded


def sample_top_p(probs, p):
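    """Top-p (nucleus) sampling.

    Keeps the smallest set of highest-probability tokens whose cumulative
    probability exceeds `p`, renormalizes over that set, and samples one
    token id per row (returned with shape (batch, 1)).
    """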
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
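

# A minimal, self-contained sketch (not part of the original module) that
# exercises sample_top_p on a toy batch of distributions; running the full
# LLaMA.generate path would require real Transformer/Tokenizer checkpoints
# and a GPU, so it is not shown here.
if __name__ == "__main__":
    torch.manual_seed(0)
    toy_logits = torch.randn(2, 10)                      # batch of 2, toy vocab of 10
    toy_probs = torch.softmax(toy_logits / 0.8, dim=-1)  # temperature-scaled probabilities
    picked = sample_top_p(toy_probs, p=0.95)             # one sampled token id per row
    print(picked.shape)  # torch.Size([2, 1])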
