---
license: apache-2.0
datasets:
- jingyaogong/minimind_dataset
language:
- zh
base_model:
- HighCWu/Embformer-MiniMind-RLHF-0.1B
pipeline_tag: text-generation
library_name: transformers
---

# Embformer-MiniMind-R1-0.1B

A 0.1B distilled reasoning model accompanying the research note [Embformer: An Embedding-Weight-Only Transformer Architecture](https://doi.org/10.5281/zenodo.15736957), trained on [jingyaogong/minimind_dataset](https://huggingface.co/datasets/jingyaogong/minimind_dataset) with a sequence length of 512.

Run the following command in a terminal:

```sh
pip install "transformers @ git+https://github.com/huggingface/transformers.git@cb0f604"
```

The following code snippet shows how to use the model to generate content from a given prompt.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HighCWu/Embformer-MiniMind-R1-0.1B"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    cache_dir=".cache"
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
    cache_dir=".cache"
)

# prepare the model input
# "Please explain the concept of 'large language model' to me."
prompt = "请为我讲解“大语言模型”这个概念。"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    input_ids=model_inputs['input_ids'],
    attention_mask=model_inputs['attention_mask'],
    max_new_tokens=8192
)
# keep only the newly generated tokens (drop the prompt tokens)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

print(tokenizer.decode(output_ids, skip_special_tokens=True))
```
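
Because this is a distilled reasoning model, the decoded completion may contain a reasoning trace followed by a final reply. The sketch below separates the two, assuming (not confirmed by this card) that the output follows a `<think>...</think>` / `<answer>...</answer>` template and that these tags survive `skip_special_tokens=True`. The helper `split_reasoning` is hypothetical and not part of the model's or the library's API; adjust the tag names if the actual output differs.

```python
import re

# Assumption: reasoning is wrapped in <think>...</think> and the final reply
# is optionally wrapped in <answer>...</answer>. Change these if needed.
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def split_reasoning(text: str) -> tuple[str, str]:
    """Hypothetical helper: return (reasoning, answer) from a decoded completion.

    If no <think> block is found, reasoning is empty and the whole text
    (minus any <answer> tags) is treated as the answer.
    """
    think = _THINK_RE.search(text)
    reasoning = think.group(1).strip() if think else ""
    # Everything after the reasoning block is a candidate answer.
    rest = text[think.end():] if think else text
    answer = _ANSWER_RE.search(rest)
    final = answer.group(1).strip() if answer else rest.strip()
    return reasoning, final
```

With the snippet above, it could be used as `reasoning, answer = split_reasoning(tokenizer.decode(output_ids, skip_special_tokens=True))`. If the tags are absent, the helper falls back to returning the full text as the answer rather than raising.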
|