{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_nSGBgG98qRC"
},
"outputs": [],
"source": [
"# Trainer\n",
"import torch\n",
"from tqdm import tqdm\n",
"\n",
"iterator = tqdm(dataloader, desc=\"Training\", postfix={\"train_loss\":0.0})\n",
"\n",
"for item in iterator:\n",
" item = tokenizer.bos_token + \" \" + item[0] + \" \" + tokenizer.eos_token\n",
" encoded_inp = tokenizer(item, return_tensors='pt').input_ids.to(\"cuda\")\n",
" logits = mamba_model(encoded_inp)\n",
"\n",
" labels = encoded_inp.to(logits.device)\n",
" shift_logits = logits[:, :-1, :].contiguous()\n",
" labels = labels[:, 1:].contiguous()\n",
" loss_fct = torch.nn.CrossEntropyLoss()\n",
" loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))\n",
"\n",
" optimizer.zero_grad(set_to_none=True)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # moving data's from gpu to cpu\n",
" loss = loss.detach().cpu().numpy()\n",
" logits = logits.detach().cpu().numpy()\n",
" labels = labels.detach().cpu().numpy()\n",
" encoded_inp = encoded_inp.detach().cpu().numpy()\n",
" shift_logits = shift_logits.detach().cpu().numpy()\n",
"\n",
" iterator.set_postfix({\"train_loss\": loss.item()}, refresh=False)"
]
},
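{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, self-contained sanity check of the shift used in the trainer (toy shapes, illustrative only): the logits at position t are scored against the token at position t + 1, so the last logit and the first token are dropped.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of the next-token shift (not part of training).\n",
"import torch\n",
"\n",
"ids = torch.tensor([[10, 11, 12, 13]])   # (batch=1, seq_len=4)\n",
"toy_logits = torch.randn(1, 4, 50)       # (batch, seq_len, vocab=50)\n",
"\n",
"shift_logits = toy_logits[:, :-1, :]     # predictions for positions 0..2\n",
"targets = ids[:, 1:]                     # ground truth: tokens 11, 12, 13\n",
"\n",
"loss = torch.nn.functional.cross_entropy(\n",
"    shift_logits.reshape(-1, 50), targets.reshape(-1))\n",
"print(loss)"
]
},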
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "feaR0XKtOGug"
},
"outputs": [],
"source": [
"# Inference\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"def generate(model,\n",
" tokenizer,\n",
" prompt: str,\n",
" n_tokens_to_gen: int = 200,\n",
" sample: bool = True,\n",
" top_k: int = 40):\n",
" model.eval()\n",
"\n",
" input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(\"cuda\")\n",
"\n",
" for token_n in range(n_tokens_to_gen):\n",
" with torch.no_grad():\n",
" indices_to_input = input_ids\n",
" next_token_logits = mamba_model(indices_to_input)[:, -1]\n",
"\n",
" probs = F.softmax(next_token_logits, dim=-1)\n",
" (batch, vocab_size) = probs.shape\n",
"\n",
" if top_k is not None:\n",
" (values, indices) = torch.topk(probs, k=top_k)\n",
" probs[probs < values[:, -1, None]] = 0\n",
" probs = probs / probs.sum(axis=1, keepdims=True)\n",
"\n",
" if sample:\n",
" next_indices = torch.multinomial(probs, num_samples=1)\n",
" else:\n",
" next_indices = torch.argmax(probs, dim=-1)[:, None]\n",
"\n",
" input_ids = torch.cat([input_ids, next_indices], dim=1)\n",
"\n",
" output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]\n",
"\n",
" return output_completions"
]
}
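,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A usage sketch for `generate`; the prompt and token count are placeholders:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example call (sketch): sample 50 tokens from the model defined above.\n",
"print(generate(mamba_model, tokenizer, \"Mamba is a state space model that\", n_tokens_to_gen=50))"
]
}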
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}