LLaMA-8x265M-MoE

💻 Code

👋 Very nice to meet you here~

❤️ This repo contains the model LLaMA-8x265M-MoE (970M parameters in total), which activates 2 out of 8 experts (332M parameters). The model is trained from scratch in FP32 precision. We first train it on the Wikipedia dataset for 1 epoch, and then on 10% of the C4 dataset (10 data shards among 1024 data shards) for 1 epoch. It is NOT fine-tuned on instruction pairs, so it may not work well as a chatbot.

📢 This series also includes a dense version (without the MoE structure); see 🤗 this repo.

1. 🚀QuickStart

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# The MoE architecture ships as custom code, so trust_remote_code=True is required.
model_dir = "JuncaiL/llama-8x265m-moe"
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
model.eval()
model.to("cuda:0")

input_text = "Beijing is a famous city"
inputs = tokenizer(input_text, return_tensors="pt", return_token_type_ids=False)
inputs = inputs.to("cuda:0")

# Greedy decoding: do_sample=False replaces temperature=0.0, which sampling mode rejects.
pred = model.generate(**inputs, max_length=50, do_sample=False)
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
# Beijing is a famous city in China. It is the capital of the Beijing Province and the largest city in China. It is also the home of the world’s largest city, Beijing.
# The city is the

2. 📑Checkpoint Details and Evaluation

Model Parameter

| Model | #Experts | #Activated Experts | #Params | #Activated Params | FLOPs (T) per sample (seq=2048) | Model Weights |
|---|---|---|---|---|---|---|
| 265M | - | - | 265M | 265M | 0.48 | 🤗 llama-265m |
| 8 $\times$ 265M MoE | 8 | 2 | 970M | 332M | 0.76 | 🤗 llama-8x265m-moe |
| llama-7b | - | - | 7B | 7B | 25.29 | |

Model Evaluation

We use the "average number of tokens verified" $N$ (see reference link) as the metric to evaluate these models. Given the same input to the small speculative model and to llama-7b, $N$ counts, starting from the first predicted token, how many consecutive tokens in the small speculative model's output match the output of llama-7b.

  • Average number of tokens verified

| Dataset | 8 $\times$ 265M MoE | GPT without MoE |
|---|---|---|
| tatsu-lab/alpaca | 3.2362 | 3.0334 |
| alespalla/chatbot_instruction_prompts | 3.2031 | 3.0823 |
| web_questions | 2.7201 | 2.5541 |
| MohamedRashad/ChatGPT-prompts | 3.0954 | 2.9768 |
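The metric described above can be sketched in a few lines of Python. This is only an illustration of the definition, not the evaluation script used for the table: the function names `matched_prefix_length` and `average_tokens_verified` are hypothetical, and the token IDs in the usage example are made up.

```python
def matched_prefix_length(draft_tokens, target_tokens):
    """Count how many consecutive tokens, starting from the first predicted
    token, are identical in the draft output and the target output."""
    count = 0
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break
        count += 1
    return count


def average_tokens_verified(pairs):
    """Average the matched prefix length over (draft, target) output pairs."""
    lengths = [matched_prefix_length(d, t) for d, t in pairs]
    return sum(lengths) / len(lengths)


# Made-up token IDs: the first pair agrees on 2 tokens, the second on 4.
pairs = [
    ([11, 12, 13, 14], [11, 12, 99, 14]),
    ([21, 22, 23, 24], [21, 22, 23, 24]),
]
print(average_tokens_verified(pairs))  # 3.0
```

Note that a later match (token 14 in the first pair) does not count, because only the unbroken run from the first predicted token is considered.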

Suppose that the small speculative model has a hit rate $p$ for the next token when given the same input. Then we have

$$1p + 2p^2 + 3p^3 + \dots = N$$

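The series on the left sums to $\frac{p}{(1-p)^2}$ for $0 < p < 1$, so solving for $p$ and taking the root in $(0, 1)$ gives the hit rate: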

$$p = 1 + \frac{1-\sqrt{1+4N}}{2N}$$

  • Hit Rate

| Dataset | 8 $\times$ 265M MoE | GPT without MoE |
|---|---|---|
| tatsu-lab/alpaca | 0.578 | 0.567 |
| alespalla/chatbot_instruction_prompts | 0.576 | 0.570 |
| web_questions | 0.550 | 0.540 |
| MohamedRashad/ChatGPT-prompts | 0.571 | 0.565 |
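As a quick check, the hit rates in the table above can be reproduced from the corresponding $N$ values with the closed-form expression for $p$. The snippet below is a small sketch; the helper names `hit_rate` and `series_value` are hypothetical and only serve this illustration.

```python
import math


def hit_rate(n):
    """Closed-form hit rate: p = 1 + (1 - sqrt(1 + 4N)) / (2N)."""
    return 1.0 + (1.0 - math.sqrt(1.0 + 4.0 * n)) / (2.0 * n)


def series_value(p, terms=500):
    """Numerically evaluate 1p + 2p^2 + 3p^3 + ... (truncated)."""
    return sum(k * p ** k for k in range(1, terms + 1))


# N values for tatsu-lab/alpaca taken from the table of verified tokens.
p_moe = hit_rate(3.2362)
p_dense = hit_rate(3.0334)
print(round(p_moe, 3))    # 0.578
print(round(p_dense, 3))  # 0.567

# Plugging p back into the series recovers N (up to truncation error).
print(abs(series_value(p_moe) - 3.2362) < 1e-6)  # True
```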

3. 🚧Limitation and Future Plans

For the MoE model, we only report how closely this small speculative model approximates the output of llama-7b. To achieve low latency in practice, our MoE implementation still needs improvement: this version computes the MoE output expert by expert (sequentially), and the computation of these experts needs to be fused.
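To make the "expert by expert" pattern concrete, here is a toy, pure-Python sketch of a top-2-of-8 MoE layer that loops over experts sequentially. It is not the code in this repo: tokens are plain floats, experts are simple callables, and the gating rule (a softmax over the two selected gate logits) is an assumption made for illustration. All names (`top_k_route`, `moe_forward_sequential`) are hypothetical.

```python
import math


def softmax(values):
    """Numerically stable softmax over a short list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_route(gate_logits, k=2):
    """Pick the k highest-scoring experts and normalize their weights.
    Assumption: weights are a softmax over the selected logits only."""
    order = sorted(range(len(gate_logits)),
                   key=lambda i: gate_logits[i], reverse=True)
    chosen = order[:k]
    weights = softmax([gate_logits[i] for i in chosen])
    return list(zip(chosen, weights))


def moe_forward_sequential(tokens, gate_logits, experts, k=2):
    """Visit experts one at a time; each processes only its routed tokens.
    This outer loop is the sequential step that a fused kernel would remove."""
    routes = [top_k_route(logits, k) for logits in gate_logits]
    outputs = [0.0] * len(tokens)
    for expert_index, expert in enumerate(experts):
        for token_index, route in enumerate(routes):
            for chosen, weight in route:
                if chosen == expert_index:
                    outputs[token_index] += weight * expert(tokens[token_index])
    return outputs


# Toy setup: expert i multiplies its input by i; one token routed to experts 3 and 5.
experts = [lambda x, scale=i: scale * x for i in range(8)]
tokens = [2.0]
gate_logits = [[0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0]]
print(moe_forward_sequential(tokens, gate_logits, experts))  # [8.0]
```

With equal gate logits, experts 3 and 5 each get weight 0.5, so the output is 0.5 * 6.0 + 0.5 * 10.0 = 8.0. In a real implementation the inner loops would be batched tensor operations, and the limitation described above corresponds to the outer loop over experts.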

Acknowledgment

  1. My implementation of the MoE structure is based on the repo https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8
  2. My inspiration for speculative inference comes from the paper "SpecInfer: Accelerating Generative Large Language Model Serving with Tree-based Speculative Inference and Verification" (link). I greatly appreciate the help and suggestions from the SpecInfer group. ❤️

Citation

@misc{specmoe-2024,
  title={SpecMoE: Building A Speculative MoE Model To Accelerate Inference},
  author={Juncai Liu},
  year={2024},
  month={March},
  url={https://github.com/JuncaiL/SpecMoE/}
}

Contact

If you are interested in this project or have any questions about it, please feel free to contact me.

[email protected] (before June 30, 2024) or [email protected] (After June 30, 2024)
