---
license: mit
datasets:
- audichandra/bitext_customer_support_llm_dataset_indonesian
language:
- id
---

## Quick Intro

Gajah-7B is the first iteration of an Indonesian AI chatbot. It uses [Merak-7B](https://huggingface.co/Ichsan2895/Merak-7B-v4) as the base model and was fine-tuned with the PEFT QLoRA method on an Indonesian version of the [bitext](https://huggingface.co/datasets/audichandra/bitext_customer_support_llm_dataset_indonesian) customer support dataset for LLMs.

Gajah-7B is licensed under the [MIT](https://opensource.org/license/mit) license to support the open-source initiative and to serve as another example of how to fine-tune a pre-trained model.

You can contact me through my [LinkedIn](https://www.linkedin.com/in/audichandra) or [GitHub](https://github.com/audichandra/Indonesian_AI_Chatbot_Customer_Support) about this model and its applications.

## Installation

You need at least Python 3.10 and PyTorch 2. Install the packages listed in `requirements.txt` with pip, along with optional extras such as flash attention:

```bash
# Run from the directory that contains requirements.txt
pip install -r requirements.txt

# Optional: flash attention
pip install flash-attn
```
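
Before installing the optional extras, it can help to confirm that the interpreter and PyTorch meet the minimums stated above. The following is a minimal sketch; the helper names `meets_minimum` and `check_environment` are hypothetical and not part of this project's code.

```python
import importlib.util
import sys


def meets_minimum(version: tuple, minimum: tuple) -> bool:
    """Return True if `version` is at least `minimum` (compared element-wise)."""
    return tuple(version[: len(minimum)]) >= tuple(minimum)


def check_environment() -> dict:
    """Report whether Python >= 3.10 and PyTorch >= 2 are available."""
    report = {"python_ok": meets_minimum(sys.version_info[:2], (3, 10))}
    if importlib.util.find_spec("torch") is None:
        report["torch_ok"] = False  # PyTorch is not installed at all
    else:
        import torch

        # Versions look like "2.1.0+cu118"; only the major number matters here
        major = int(torch.__version__.split(".")[0])
        report["torch_ok"] = major >= 2
    # flash-attn is optional, so only report whether it is importable
    report["flash_attn_installed"] = importlib.util.find_spec("flash_attn") is not None
    return report


if __name__ == "__main__":
    print(check_environment())
```

A `False` value for `flash_attn_installed` is not an error; it only means the `attn_implementation="flash_attention_2"` option should stay disabled.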

## GPU requirements

**Training**: 8x A40

**Loading**: 1x RTX A500

*Note: the author trained and loaded the model on a cloud GPU platform such as RunPod.*
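
As a rough guide to how much GPU memory loading requires, the weight memory can be estimated from the parameter count and the number of bits stored per parameter. The sketch below assumes a nominal 7 billion parameters (taken from the model name, not measured) and counts weights only. Activations, the KV cache, and quantization overhead are ignored, so real usage will be higher.

```python
def weight_memory_gib(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for the model weights alone, in GiB."""
    total_bytes = num_params * bits_per_param / 8
    return total_bytes / 1024**3


NOMINAL_PARAMS = 7e9  # assumption: "7B" taken at face value

for label, bits in [("bfloat16", 16), ("4-bit (nf4)", 4)]:
    print(f"{label}: ~{weight_memory_gib(NOMINAL_PARAMS, bits):.1f} GiB")
```

Under these assumptions the weights take roughly 13 GiB in bfloat16 and roughly 3.3 GiB at 4 bits.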

## Scripts

**Script for loading the model on multiple GPUs**

```python
import time

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, LlamaTokenizer

# Optional 4-bit quantization; uncomment this line together with the
# quantization_config argument below to reduce VRAM usage.
# BNB_CONFIG = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
model_chat = "audichandra/Gajah-7B"
model1 = AutoModelForCausalLM.from_pretrained(
    model_chat,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # spread the layers across all visible GPUs
    pad_token_id=0,
    attn_implementation="flash_attention_2",  # requires flash-attn
    cache_dir="/workspace",
    # quantization_config=BNB_CONFIG,
)

tokenizer = LlamaTokenizer.from_pretrained(model_chat)


def generate_response(question: str) -> str:
    chat = [
        {"role": "system", "content": "Ada yang bisa saya bantu?"},  # "How can I help you?"
        {"role": "user", "content": question},
    ]

    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    # Move both input_ids and attention_mask to the model's input device
    inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=True).to(model1.device)

    with torch.no_grad():
        outputs = model1.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=512,
        )

    # Decode only the newly generated tokens so the prompt is not echoed back
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


start_time = time.time()
prompt = "bagaimana saya dapat membatalkan pembelian saya?"  # "How can I cancel my purchase?"
print(generate_response(prompt))

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")
```

*You can uncomment `BNB_CONFIG` and the `quantization_config` argument to load the model with 4-bit quantization and use less VRAM, but output quality and generation speed may suffer.*
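
The script above builds its prompt with `tokenizer.apply_chat_template`, which turns the list of role/content messages into a single string. As an illustration only, the sketch below renders messages in the ChatML layout. That Gajah-7B's tokenizer uses this exact layout is an assumption that this README does not confirm, and `render_chatml` is a hypothetical helper. In practice, rely on the template shipped with the tokenizer.

```python
def render_chatml(messages: list[dict], add_generation_prompt: bool = True) -> str:
    """Render chat messages in a ChatML-style layout (assumed, for illustration)."""
    parts = [
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    ]
    if add_generation_prompt:
        # Leave the assistant turn open so the model writes the reply
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


chat = [
    {"role": "system", "content": "Ada yang bisa saya bantu?"},
    {"role": "user", "content": "bagaimana saya dapat membatalkan pembelian saya?"},
]
print(render_chatml(chat))
```

Printing the rendered prompt this way is a quick check that the system and user turns appear in the expected order before any tokens are generated.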

**Script for loading the model on a single GPU**

```python
import time

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, LlamaTokenizer

# Optional 4-bit quantization; see the quantization_config argument below.
# BNB_CONFIG = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
# model_save_path1 = "/workspace/axolotl/merged_model"
model_chat = "audichandra/Gajah-7B"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model1 = AutoModelForCausalLM.from_pretrained(
    model_chat,
    torch_dtype=torch.bfloat16,
    # device_map="auto", pad_token_id=0,
    # attn_implementation="flash_attention_2",
    cache_dir="/workspace",
    # quantization_config=BNB_CONFIG,
).to(device)
tokenizer = LlamaTokenizer.from_pretrained(model_chat)


def generate_response(question: str) -> str:
    chat = [
        {"role": "system", "content": "Ada yang bisa saya bantu?"},  # "How can I help you?"
        {"role": "user", "content": question},
    ]

    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=True)

    inputs = inputs.to(device)  # Ensure inputs are on the same device as the model

    with torch.no_grad():
        outputs = model1.generate(**inputs, max_new_tokens=512)

    # Decode only the newly generated tokens so the prompt is not echoed back
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# Generate a response and measure how long it takes
start_time = time.time()
prompt = "bagaimana saya dapat membatalkan pembelian saya?"  # "How can I cancel my purchase?"
print(generate_response(prompt))

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")
```

*Some features, such as flash attention, might not work on a single GPU.*
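
The scripts above time a single call with `time.time()`. When comparing configurations, for example with and without quantization, averaging several runs with `time.perf_counter()` tends to give steadier numbers. This is a minimal sketch; `average_seconds` is a hypothetical helper, and the lambda in the usage lines is a stand-in for a call such as `generate_response(prompt)`.

```python
import time
from typing import Callable


def average_seconds(fn: Callable[[], object], runs: int = 3) -> float:
    """Call `fn` `runs` times and return the mean wall-clock time per call."""
    if runs < 1:
        raise ValueError("runs must be at least 1")
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs


# Stand-in workload; replace with: lambda: generate_response(prompt)
elapsed = average_seconds(lambda: sum(range(100_000)), runs=3)
print(f"Average time per call: {elapsed:.4f} seconds")
```

The first call after loading a model is often slower than later ones, so it may be worth discarding one warm-up run before measuring.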

[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)

## Citation

```bibtex
@article{Merak,
  title={Merak-7B: The LLM for Bahasa Indonesia},
  author={Muhammad Ichsan},
  publisher={Hugging Face},
  journal={Hugging Face Repository},
  year={2023}
}

@article{dettmers2023qlora,
  title = {QLoRA: Efficient Finetuning of Quantized LLMs},
  author = {Dettmers, Tim and Pagnoni, Artidoro and Holtzman, Ari and Zettlemoyer, Luke},
  journal = {arXiv preprint arXiv:2305.14314},
  year = {2023}
}

@article{axolotl,
  author = {{OpenAccess AI Collective}},
  title = {Axolotl: A Repository for AI Research and Development},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/OpenAccess-AI-Collective/axolotl}}
}
```