---
base_model: unsloth/llama-3-8b-bnb-4bit
language:
- en
license: apache-2.0
tags:
- text-generation-inference
- transformers
- unsloth
- llama
- trl
datasets:
- Studeni/robot-instructions
pipeline_tag: text-generation
---
# Llama 3 8B Robot Instruction Model (4-bit)
## Model Description
This model is a fine-tuned version of Llama 3 8B, optimized with Unsloth and quantized to 4-bit.
It is designed to convert casual user input into function calls for controlling industrial robots.
The aim is to lower the barrier for people without programming skills, letting them control robots with simple text instructions.
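For example, an instruction like "Move the tool 10 cm up" should map to a JSON list of calls in the shape below. The `function`/`kwargs` field names match the parsing code in the usage examples; the instruction, parameter names, and values shown here are illustrative, so consult the dataset for the exact schema:
```json
[
  {"function": "move_tcp", "kwargs": {"x": 0.0, "y": 0.0, "z": 0.1}}
]
```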
## Model Details
- **Model ID:** Studeni/llama-3-8b-bnb-4bit-robot-instruct
- **Architecture:** Llama 3 8B
- **Quantization:** 4-bit
- **Frameworks:** Transformers, PEFT, Unsloth
## Usage
### Using Unsloth Library
```python
import json

from datasets import load_dataset
from unsloth import FastLanguageModel

# Load a sample from the test split of the dataset
repo_id = "Studeni/robot-instructions"
dataset = load_dataset(repo_id, split="test")
test_input = dataset[0]["input"]
test_output = dataset[0]["output"]
print(f"User input: {test_input}\nGround truth: {test_output}")

# Prompt template
robot_instruct_prompt = """
### Instruction:
Transform input into list of function calls for controlling industrial robots.
### Input:
{}
### Response:
{}
"""

# Model parameters
lora_id = "Studeni/llama-3-8b-bnb-4bit-robot-instruct"
max_seq_length = 2048
dtype = None  # Auto-detection. Use float16 for Tesla T4/V100, bfloat16 for Ampere+
load_in_4bit = True

# Load the model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=lora_id,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
FastLanguageModel.for_inference(model)  # Enable Unsloth's fast inference mode

# Tokenize the input text
inputs = tokenizer(
    [robot_instruct_prompt.format(test_input, "")],
    return_tensors="pt",
).to("cuda")

# Run generation (increase max_new_tokens if the JSON output gets truncated)
outputs = model.generate(**inputs, max_new_tokens=64, use_cache=True)
text_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Extract the function calls and parse them as JSON
function_call = text_output[0].split("### Response:")[-1].strip()
function_call = json.loads(function_call)
for f in function_call:
    print(f"Function to call: {f['function']}")
    print(f"Input parameters: {f['kwargs']}")
```
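Note that `json.loads` raises a `json.JSONDecodeError` if the model emits anything other than a valid JSON list after the `### Response:` marker, so in practice it is worth wrapping the parse in a `try`/`except` block.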
### Using Transformers and Peft
```python
import json

from datasets import load_dataset
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Load a sample from the test split of the dataset
repo_id = "Studeni/robot-instructions"
dataset = load_dataset(repo_id, split="test")
test_input = dataset[0]["input"]
test_output = dataset[0]["output"]
print(f"User input: {test_input}\nGround truth: {test_output}")

# Prompt template
robot_instruct_prompt = """
### Instruction:
Transform input into list of function calls for controlling industrial robots.
### Input:
{}
### Response:
{}
"""

# Model parameters
lora_id = "Studeni/llama-3-8b-bnb-4bit-robot-instruct"
load_in_4bit = True

# Load the model (LoRA adapter on top of the 4-bit base model) and tokenizer
model = AutoPeftModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=lora_id,
    load_in_4bit=load_in_4bit,
)
tokenizer = AutoTokenizer.from_pretrained(lora_id)

# Tokenize the input text
inputs = tokenizer(
    [robot_instruct_prompt.format(test_input, "")],
    return_tensors="pt",
).to("cuda")

# Run generation
outputs = model.generate(**inputs, max_new_tokens=256, use_cache=True)
text_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Extract the function calls and parse them as JSON
function_call = text_output[0].split("### Response:")[-1].strip()
function_call = json.loads(function_call)
for f in function_call:
    print(f"Function to call: {f['function']}")
    print(f"Input parameters: {f['kwargs']}")
```
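This path loads the LoRA adapter through the standard Transformers and PEFT APIs, so it works in environments where Unsloth is not installed; the trade-off is that you give up Unsloth's faster inference kernels.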
## Limitations and Future Work 🚨
This model is currently a work in progress and supports only three basic functions: `move_tcp`, `move_joint`, and `get_joint_values`.
Future iterations will add a more comprehensive dataset covering more complex commands and capabilities, better human-labeled data, and more thorough performance evaluation.
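For reference, here is a minimal sketch of how the parsed calls could be dispatched to a robot. The `RobotClient` class and its method bodies are hypothetical stand-ins for whatever controller SDK you actually use; only the three function names and the `function`/`kwargs` output format come from this model:
```python
import json


# Hypothetical stand-in for a real robot controller SDK
class RobotClient:
    def move_tcp(self, **kwargs):
        print(f"Moving TCP with {kwargs}")

    def move_joint(self, **kwargs):
        print(f"Moving joints with {kwargs}")

    def get_joint_values(self, **kwargs):
        print("Reading joint values")
        return []


def dispatch(calls_json: str, robot: RobotClient) -> None:
    """Parse the model's JSON output and invoke the matching robot methods."""
    allowed = {"move_tcp", "move_joint", "get_joint_values"}
    for call in json.loads(calls_json):
        name = call["function"]
        if name not in allowed:
            raise ValueError(f"Unsupported function: {name}")
        getattr(robot, name)(**call.get("kwargs", {}))


# Example with a hand-written payload in the model's output format
dispatch('[{"function": "get_joint_values", "kwargs": {}}]', RobotClient())
```
Whitelisting the allowed function names before dispatching keeps a malformed or hallucinated model output from calling arbitrary methods on the controller.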
## Contributions and Collaborations 🤝
We welcome contributions and collaborations to help improve and expand the capabilities of this model. Whether you are interested in adding more complex functions, improving the dataset, or enhancing the model's performance, your input is valuable.
Feel free to connect with me on [LinkedIn](https://www.linkedin.com/in/milutin-studen/).
---
This Llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.
[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)