Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,269 Bytes
4c346eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
{
"cells": [
{
"cell_type": "markdown",
"id": "072120f9",
"metadata": {},
"source": [
"If you would like to modify a base model to add our custom reasoning tokens,\n",
"here's how to do it.\n",
"\n",
"First, install the `add-tokens` extra via\n",
"`pip install ether0[add-tokens]`, which provides the `transformers` package.\n",
"\n",
"Then, configure the following inputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a2fb6296",
"metadata": {},
"outputs": [],
"source": [
"# Model name/revisions for Hugging Face Hub\n",
"input_model_name = \"mistralai/Mistral-Small-24B-Instruct-2501\"\n",
"input_model_revision: str | None = None\n",
"output_model_name = \"FILL ME IN\"\n",
"output_model_revision: str | None = None\n",
"output_model_is_private = True\n",
"tokenizer_only = False # Set True to only update the tokenizer\n",
"push_to_hf = False # Set True to push to Hugging Face Hub\n",
"\n",
"# Chat template file that uses the new tokens\n",
"chat_template_path = \"updated_mistral_chat_template.jinja\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99927d80",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8e15d3fb5e864e1286cf94fc588e504d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/10 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n",
"The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n"
]
}
],
"source": [
"# Load the tokenizer, register the custom reasoning tokens, install the new\n",
"# chat template, and (unless tokenizer_only) resize the model's embeddings to\n",
"# match -- optionally pushing both artifacts to the Hugging Face Hub.\n",
"from pathlib import Path\n",
"\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"from ether0.model_prompts import ANSWER_END, ANSWER_START, THINK_END, THINK_START\n",
"\n",
"REASONING_TOKENS_TO_ADD = [\n",
"    THINK_START,\n",
"    THINK_END,\n",
"    ANSWER_START,\n",
"    ANSWER_END,\n",
"]\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\n",
"    input_model_name, revision=input_model_revision\n",
")\n",
"# NOTE: reasoning tokens are normal (not special) tokens so they aren't\n",
"# removed when passing skip_special_tokens=True to a tokenizer\n",
"num_added_tokens = tokenizer.add_tokens(REASONING_TOKENS_TO_ADD)\n",
"# 0 added means the tokens were already in the vocab (e.g. a re-run, or an\n",
"# already-augmented base model); the embedding resize below is then a no-op.\n",
"print(f\"Added {num_added_tokens} of {len(REASONING_TOKENS_TO_ADD)} reasoning tokens.\")\n",
"tokenizer.chat_template = Path(chat_template_path).read_text(encoding=\"utf-8\")\n",
"if push_to_hf:\n",
"    tokenizer.push_to_hub(\n",
"        output_model_name,\n",
"        revision=output_model_revision,\n",
"        private=output_model_is_private,\n",
"    )\n",
"\n",
"if not tokenizer_only:\n",
"    model = AutoModelForCausalLM.from_pretrained(\n",
"        input_model_name, revision=input_model_revision\n",
"    )\n",
"    # Pad the embedding matrix to a multiple of 64 rows for GPU matmul efficiency.\n",
"    # SEE: https://www.thonking.ai/p/what-shapes-do-matrix-multiplications\n",
"    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)\n",
"    if push_to_hf:\n",
"        model.push_to_hub(\n",
"            output_model_name,\n",
"            revision=output_model_revision,\n",
"            private=output_model_is_private,\n",
"        )"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|