Erland committed on
Commit d2fdb73 · verified · 1 parent: 8e0ecfd

Update README.md with weight comparison and hardware info

Files changed (1):
  1. README.md +4 -4
README.md CHANGED
@@ -5,12 +5,12 @@ tags:
  - flax
  - text-generation
  - transformers
- - meta-llama/Llama-3.2-1B
+ - meta-llama/Llama-3.2-1B # Add the specific model name as a tag
  ---
 
  # meta-llama/Llama-3.2-1B - JAX/Flax
 
- This repository contains the JAX/Flax version of the meta-llama/Llama-3.2-1B model, originally a PyTorch model from {original_model_org}. This conversion enables efficient inference and training on TPUs and GPUs using the JAX/Flax framework.
+ This repository contains the JAX/Flax version of the meta-llama/Llama-3.2-1B model, originally a PyTorch model from meta-llama. This conversion enables efficient inference and training on TPUs and GPUs using the JAX/Flax framework.
 
  ## Model Description
 
@@ -27,7 +27,7 @@ This model was converted from the original PyTorch implementation to JAX/Flax. T
 
  ### Important Note about `max_position_embeddings`
 
- During the conversion process, it was necessary to modify the `max_position_embeddings` parameter in the model's configuration. The original value of {original_max_pos_embed} led to out-of-memory (OOM) errors on the hardware used for conversion. To resolve this, `max_position_embeddings` was adjusted to {new_max_pos_embed}.
+ During the conversion process, it was necessary to modify the `max_position_embeddings` parameter in the model's configuration. The original value of 131072 led to out-of-memory (OOM) errors on the hardware used for conversion. To resolve this, `max_position_embeddings` was adjusted to 32768.
 
  **Implications of this change:**
 
@@ -205,7 +205,7 @@ The conversion process was performed on the following hardware configuration:
  * **Transformers version:** 4.47.0
  * **GPU:** NVIDIA A100-SXM4-40GB
 
- This conversion took approximately 130.21 seconds to complete.
+ This conversion took approximately 52.90 seconds to complete.
 
  ## Usage
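The `max_position_embeddings` change described in the diff (131072 reduced to 32768) can be reproduced from user code rather than by editing the checkpoint's `config.json`. A minimal sketch, assuming the Hugging Face `transformers` library is installed; constructing the config locally here is purely illustrative and is not the converter's actual code:

```python
from transformers import LlamaConfig

# Hypothetical sketch: a Llama-style config carrying the reduced context
# length that the commit says was used to avoid OOM during conversion.
config = LlamaConfig(max_position_embeddings=32768)

print(config.max_position_embeddings)  # 32768

# When loading the real (gated) checkpoint, the same override can be passed
# to from_pretrained, which forwards unknown kwargs to the config, e.g.:
# model = FlaxAutoModelForCausalLM.from_pretrained(
#     "meta-llama/Llama-3.2-1B", max_position_embeddings=32768
# )
```

Note that lowering `max_position_embeddings` caps the usable context window; inputs longer than 32768 tokens would exceed what the converted configuration declares.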