{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyarrow.parquet as pq \n",
    "from transformers import AutoTokenizer, PreTrainedTokenizerFast\n",
    "from rich import progress"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 训练集数据训练tokenizer,小于16G内存的机器容易OOM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pq_file = '../data/my_dataset.shuffle.parquet'\n",
    "pf = pq.read_table(pq_file)\n",
    "\n",
    "def get_training_corpus():\n",
    "    buffer = []\n",
    "    for prompt, response in progress.track(zip(pf['prompt'], pf['response']), total=pf.num_rows):\n",
    "\n",
    "        buffer.append(\n",
    "            f\"{prompt.as_py()}\\n{response.as_py()}\"\n",
    "        )\n",
    "\n",
    "        if len(buffer) >= 1000:\n",
    "             yield buffer\n",
    "             buffer = []\n",
    "\n",
    "    if buffer: yield buffer\n",
    "iter_training_corpus = get_training_corpus()"
   ]
  },
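  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note (added for illustration, not part of the original flow): `pq.read_table` materializes the whole table in memory, which is what makes low-RAM machines OOM. A streaming variant is sketched below, assuming `pyarrow.parquet.ParquetFile.iter_batches` is available; it reads the parquet file batch by batch instead of loading it all at once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# minimal streaming sketch (assumption): avoid materializing the full table with pq.read_table\n",
    "def get_training_corpus_streaming(path=pq_file, batch_size=1000):\n",
    "    parquet_file = pq.ParquetFile(path)\n",
    "    for batch in parquet_file.iter_batches(batch_size=batch_size, columns=['prompt', 'response']):\n",
    "        prompts = batch.column('prompt').to_pylist()\n",
    "        responses = batch.column('response').to_pylist()\n",
    "        # same text format as get_training_corpus above: prompt and response joined by a newline\n",
    "        yield [f\"{p}\\n{r}\" for p, r in zip(prompts, responses)]\n",
    "\n",
    "# iter_training_corpus = get_training_corpus_streaming()"
   ]
  },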
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## step 1: 加载T5模型自带的tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "old_tokenizer = AutoTokenizer.from_pretrained('t5-base')"
   ]
  },
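  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick look (added for illustration) at what the pretrained t5-base tokenizer provides; as noted in the supplementary section below, it already defines its own special tokens, so none have to be specified manually when training from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inspect the vocabulary size and the special tokens that t5-base already defines\n",
    "print(len(old_tokenizer))\n",
    "print(old_tokenizer.special_tokens_map)"
   ]
  },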
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  step 2: 加载Wiki中文语料,1.6GB\n",
    "备注: 全量预训练语料文本大小约7GB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lines = []\n",
    "with open('../data/raw_data/wiki.simple.txt', 'r', encoding='utf-8') as f:\n",
    "    lines = f.readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(lines)"
   ]
  },
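  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note (added for illustration): `readlines()` keeps the whole file in memory, which works for the 1.6 GB Wikipedia subset but may not scale to the full ~7 GB corpus. The sketch below is an assumed alternative that streams lines lazily from disk; the step 3 generator could iterate over it instead of the in-memory `lines` list."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# lazy line-reading sketch (assumption): stream lines instead of holding them all in memory\n",
    "def iter_lines(path='../data/raw_data/wiki.simple.txt'):\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            yield line"
   ]
  },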
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## step 3 定义一个语料的迭代生成器\n",
    "一个文本块(段落)的最小长度为2048,迭代一次返回1000个文本块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_training_corpus():\n",
    "    buffer = []\n",
    "    i = 0 \n",
    "    txt = []\n",
    "    len_cnt = 0\n",
    "    for line in progress.track(lines):\n",
    "        \n",
    "        len_cnt += len(line)\n",
    "        txt.append(line)\n",
    "        if len_cnt >= 2048:\n",
    "            buffer.append(\n",
    "                ''.join(txt)\n",
    "            )\n",
    "            txt = []\n",
    "            len_cnt = 0\n",
    "      \n",
    "        if len(buffer) >= 1000:\n",
    "             yield buffer\n",
    "             buffer = []\n",
    "             i += 1\n",
    "\n",
    "    #  yield last  buffer\n",
    "    if len(buffer) > 0:\n",
    "        yield buffer\n",
    "\n",
    "iter_training_corpus = get_training_corpus()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in get_training_corpus():\n",
    "    print(len(i))\n",
    "    print([len(t) for t in i][0:20])\n",
    "    break\n",
    "## 1000\n",
    "## [2104, 2053, 2176, 2224, 2172, 2068, 2054, 2258, 2058, 2085, 2142, 2274, 2184, 2246, 2144, 2223, 2075, 2058, 2164,  2178]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## step 4: 训练tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = old_tokenizer.train_new_from_iterator(iter_training_corpus, vocab_size=40960)\n",
    "\n",
    "# cpu计算密集型任务 13600K大概需要1个小时,最大内存占用20G"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## step 5: 保存训练好的tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.save_pretrained('../model_save/my_tokenizer_wiki')"
   ]
  },
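  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick round-trip check (added for illustration): the tokenizer saved above can be loaded back from disk with `AutoTokenizer.from_pretrained` and used to encode and decode a sample sentence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reload the tokenizer that was just saved and run a small encode/decode round trip\n",
    "reloaded_tokenizer = AutoTokenizer.from_pretrained('../model_save/my_tokenizer_wiki')\n",
    "\n",
    "sample = '这是一段中英混输的句子, (chinese and English, here are words.)'\n",
    "ids = reloaded_tokenizer.encode(sample)\n",
    "print(ids)\n",
    "print(reloaded_tokenizer.decode(ids, skip_special_tokens=True))"
   ]
  },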
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 补充内容: 自定义模型、及特殊字符训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import PreTrainedTokenizerFast\n",
    "from tokenizers.pre_tokenizers import Whitespace, Punctuation, Digits, ByteLevel, Metaspace\n",
    "from tokenizers.normalizers import NFKC\n",
    "from tokenizers import Tokenizer, decoders\n",
    "from tokenizers.models import BPE\n",
    "import tokenizers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 字符级别的 BPE toeknizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BPE(unk_token=\"[UNK]\")\n",
    "tokenizer = Tokenizer(model)\n",
    "\n",
    "# 用兼容等价分解合并对utf编码进行等价组合,比如全角A转换为半角A\n",
    "tokenizer.normalizer = tokenizers.normalizers.Sequence([NFKC()])\n",
    "\n",
    "# 标点符号,数字,及Metaspace预分割(否则decode出来没有空格)\n",
    "tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Sequence(\n",
    "    [Punctuation(), Digits(individual_digits=True), Metaspace()])\n",
    "\n",
    "tokenizer.add_special_tokens([\"[PAD]\",\"[EOS]\",\"[SEP]\",\"[BOS]\", \"[CLS]\", \"[MASK]\", \"[UNK]\"])\n",
    "tokenizer.decoder = decoders.Metaspace()"
   ]
  },
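  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Aside (not in the original notebook): a raw `tokenizers.Tokenizer` built this way can also be trained directly with a `BpeTrainer`, instead of being wrapped in `PreTrainedTokenizerFast` and trained via `train_new_from_iterator` as done below. A minimal sketch, assuming the same corpus iterator and vocabulary size:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# alternative training sketch (assumption): train the raw Tokenizer directly with a BpeTrainer\n",
    "from tokenizers.trainers import BpeTrainer\n",
    "\n",
    "trainer = BpeTrainer(\n",
    "    vocab_size=40960,\n",
    "    special_tokens=[\"[PAD]\", \"[EOS]\", \"[SEP]\", \"[BOS]\", \"[CLS]\", \"[MASK]\", \"[UNK]\"],\n",
    ")\n",
    "\n",
    "# train_from_iterator accepts an iterator over strings or lists of strings;\n",
    "# left commented out so the shared iter_training_corpus generator is not consumed here\n",
    "# tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)"
   ]
  },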
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 字节级别(ByteLevel) BPE toeknizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# byte BPE n不需要unk_token\n",
    "model = BPE() \n",
    "tokenizer = Tokenizer(model)\n",
    "\n",
    "tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=False)\n",
    "\n",
    "tokenizer.add_special_tokens([\"[PAD]\",\"[EOS]\",\"[SEP]\",\"[BOS]\", \"[CLS]\", \"[MASK]\", \"[UNK]\"])\n",
    "tokenizer.decoder = decoders.ByteLevel(add_prefix_space=True, use_regex=True)\n",
    "tokenizer.post_processor = tokenizers.processors.ByteLevel(trim_offsets=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PreTrainedTokenizerFast类无法从 tokenizer 对象推断出哪个标记是掩码标记、[CLS] 标记等,需要手动指定\n",
    "# 上文的通过from_pretrained('t5-base')定义的old_tokenizer,自带了特殊标记,不用指定\n",
    "# 到这一步和上文 step 4 一致了\n",
    "old_tokenizer = PreTrainedTokenizerFast(\n",
    "    tokenizer_object=tokenizer,\n",
    "    unk_token=\"[UNK]\",\n",
    "    pad_token=\"[PAD]\",\n",
    "    cls_token=\"[CLS]\",\n",
    "    sep_token=\"[SEP]\",\n",
    "    mask_token=\"[MASK]\",\n",
    "    bos_token='[BOS]',\n",
    "    eos_token='[EOS]',                  \n",
    ")\n",
    "tokenizer = old_tokenizer.train_new_from_iterator(iter_training_corpus, vocab_size=40960)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add \\t \\n if char level tokenizer\n",
    "# if '\\t' not in tokenizer.vcoab:\n",
    "#     tokenizer.add_tokens(['\\t'])\n",
    "# if '\\n' not in tokenizer.vcoab:\n",
    "#     tokenizer.add_tokens(['\\n'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.save_pretrained('../model_save/my_tokenizer_wiki')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "txt = '这是一段中英混输的句子, (chinese and English, here are words.)'\n",
    "# toeknize\n",
    "tokens = tokenizer.tokenize(txt)\n",
    "print(tokens)\n",
    "# 字级别输出:\n",
    "# ['▁这是', '一段', '中英', '混', '输', '的', '句子', '▁,', '▁(', '▁ch', 'inese', '▁and', '▁Eng', 'lish', '▁,', '▁h', 'ere', '▁', 'are', '▁w', 'ord', 's', '▁.', '▁)']\n",
    "\n",
    "# Byte级别输出\n",
    "# ['Ġè¿Ļæĺ¯', 'ä¸Ģ段', 'ä¸Ńèĭ±', 'æ··', 'è¾ĵ', 'çļĦ', 'åı¥åŃIJ', 'Ġ,', 'Ġ(', 'Ġch', 'inese', 'Ġand', 'ĠEng', 'lish', 'Ġ,', 'Ġh', 'ere', 'Ġare', 'Ġw', 'ord', 's', 'Ġ.', 'Ġ)']\n",
    "\n",
    "# decode\n",
    "ids = tokenizer.encode(txt)\n",
    "tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}