File size: 4,269 Bytes
4c346eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "072120f9",
   "metadata": {},
   "source": [
    "If you would like to modify a base model to add our custom reasoning tokens,\n",
    "here's how to do it.\n",
    "\n",
    "First, install the `add-tokens` extra, which provides the `transformers` package:\n",
    "`pip install \"ether0[add-tokens]\"`\n",
    "(the quotes stop shells such as zsh from expanding the square brackets).\n",
    "\n",
    "Then, configure the following inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2fb6296",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Source model on the Hugging Face Hub ---\n",
    "input_model_name = \"mistralai/Mistral-Small-24B-Instruct-2501\"\n",
    "input_model_revision: str | None = None\n",
    "\n",
    "# --- Destination repo on the Hugging Face Hub ---\n",
    "# Placeholder: replace with the target repo name before pushing\n",
    "output_model_name = \"FILL ME IN\"\n",
    "output_model_revision: str | None = None\n",
    "output_model_is_private = True\n",
    "\n",
    "# --- Run switches ---\n",
    "# True: only the tokenizer is updated, the model step is skipped\n",
    "tokenizer_only = False\n",
    "# True: results are pushed to the Hugging Face Hub\n",
    "push_to_hf = False\n",
    "\n",
    "# Path to the chat template file that uses the new tokens\n",
    "chat_template_path = \"updated_mistral_chat_template.jinja\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99927d80",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e15d3fb5e864e1286cf94fc588e504d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n",
      "The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "from ether0.model_prompts import ANSWER_END, ANSWER_START, THINK_END, THINK_START\n",
    "\n",
    "# Tokens appended to the tokenizer vocabulary, in insertion order\n",
    "REASONING_TOKENS_TO_ADD = [THINK_START, THINK_END, ANSWER_START, ANSWER_END]\n",
    "\n",
    "# --- Tokenizer: extend the vocabulary and install the chat template ---\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    input_model_name,\n",
    "    revision=input_model_revision,\n",
    ")\n",
    "# NOTE: add_tokens keeps the reasoning tokens as normal (not special) tokens,\n",
    "# so decoding with skip_special_tokens=True does not strip them\n",
    "tokenizer.add_tokens(REASONING_TOKENS_TO_ADD)\n",
    "tokenizer.chat_template = Path(chat_template_path).read_text(encoding=\"utf-8\")\n",
    "if push_to_hf:\n",
    "    tokenizer.push_to_hub(\n",
    "        output_model_name,\n",
    "        revision=output_model_revision,\n",
    "        private=output_model_is_private,\n",
    "    )\n",
    "\n",
    "# --- Model: skipped entirely when tokenizer_only is set ---\n",
    "if not tokenizer_only:\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        input_model_name,\n",
    "        revision=input_model_revision,\n",
    "    )\n",
    "    # Resize the embeddings to the new vocabulary size, padded to a multiple of 64\n",
    "    # SEE: https://www.thonking.ai/p/what-shapes-do-matrix-multiplications\n",
    "    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)\n",
    "\n",
    "if push_to_hf and not tokenizer_only:\n",
    "    model.push_to_hub(\n",
    "        output_model_name,\n",
    "        revision=output_model_revision,\n",
    "        private=output_model_is_private,\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}