---
license: mit
language:
- xh
- nr
- zu
- ss
---

Usage:

1. Corrupted span prediction: replace spans of the input with sentinel tokens and let the model generate the missing text.

```python
# Example from https://huggingface.co/docs/transformers/en/model_doc/byt5
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("francois-meyer/nguni-byt5-large")
model = AutoModelForSeq2SeqLM.from_pretrained("francois-meyer/nguni-byt5-large")
# model = T5ForConditionalGeneration.from_pretrained(model_path)

input_ids_prompt = "The dog chases a ball in the park."
input_ids = tokenizer(input_ids_prompt).input_ids

# Corruption: replace two spans with sentinel IDs, which count down from 258 in ByT5.
input_ids = torch.tensor([input_ids[:8] + [258] + input_ids[14:21] + [257] + input_ids[28:]])

# ByT5 generates one byte at a time, so allow a generous output length.
output_ids = model.generate(input_ids, max_length=100)[0].tolist()

# Split the output at each sentinel (258, 257, ...) to recover the predicted spans.
output_ids_list = []
start_token = 0
sentinel_token = 258
while sentinel_token in output_ids:
    split_idx = output_ids.index(sentinel_token)
    output_ids_list.append(output_ids[start_token:split_idx])
    start_token = split_idx
    sentinel_token -= 1

output_ids_list.append(output_ids[start_token:])
output_string = tokenizer.batch_decode(output_ids_list)
print(output_string)
```
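
The slice indices in the example above are tied to one specific sentence. As a rough sketch, the same corruption can be built from byte spans. The helper `mask_byte_spans` is hypothetical (it is not part of `transformers`), and it assumes this model keeps the standard ByT5 vocabulary layout described in the linked documentation: three special tokens, each UTF-8 byte mapped to its value plus 3, EOS ID 1, and sentinel IDs counting down from 258.

```python
def mask_byte_spans(text, spans, first_sentinel=258, byte_offset=3, eos_id=1):
    """Replace each (start, end) byte span of `text` with a descending sentinel ID.

    Hypothetical helper; assumes the standard ByT5 ID layout (byte value + 3, EOS = 1).
    Spans are half-open offsets into the UTF-8 bytes, not into the characters.
    """
    byte_ids = [b + byte_offset for b in text.encode("utf-8")]
    masked = []
    cursor = 0
    sentinel = first_sentinel
    for start, end in sorted(spans):
        if start < cursor or start >= end or end > len(byte_ids):
            raise ValueError(f"invalid or overlapping span: ({start}, {end})")
        masked.extend(byte_ids[cursor:start])  # keep the text before the span
        masked.append(sentinel)                # one sentinel stands in for the span
        sentinel -= 1
        cursor = end
    masked.extend(byte_ids[cursor:])           # keep the text after the last span
    masked.append(eos_id)
    return masked


# Same corruption as the slicing example: mask bytes 8-13 and 21-27.
ids = mask_byte_spans("The dog chases a ball in the park.", [(8, 14), (21, 28)])
print(ids)
```

Wrapping the result as `torch.tensor([ids])` gives an input that can be passed to `model.generate`. Because the offsets count bytes, any character outside ASCII occupies more than one position.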

2. For any other task, fine-tune the model as you would any other T5, mT5, or ByT5 model.
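
As a minimal sketch of what fine-tuning data looks like for a byte-level model, the snippet below turns one source/target pair into padded `input_ids`, `attention_mask`, and `labels`. The function names, the lengths, and the language identification task are hypothetical, and the ID layout (pad 0, EOS 1, byte value plus 3) is assumed to match standard ByT5. In practice the tokenizer does this encoding for you; the sketch only makes the shapes explicit.

```python
PAD_ID, EOS_ID, BYTE_OFFSET = 0, 1, 3  # assumed standard ByT5 layout
IGNORE_INDEX = -100  # label positions with this value are skipped by the loss


def encode(text, max_len):
    """Map UTF-8 bytes to IDs, truncate to leave room for EOS, then append EOS."""
    ids = [b + BYTE_OFFSET for b in text.encode("utf-8")][: max_len - 1]
    return ids + [EOS_ID]


def build_example(source, target, max_source_len=32, max_target_len=16):
    """Build one padded training example (hypothetical helper, not a library API)."""
    src = encode(source, max_source_len)
    tgt = encode(target, max_target_len)
    src_pad = max_source_len - len(src)
    tgt_pad = max_target_len - len(tgt)
    return {
        "input_ids": src + [PAD_ID] * src_pad,
        "attention_mask": [1] * len(src) + [0] * src_pad,
        # Padding in the labels is masked out so it does not contribute to the loss.
        "labels": tgt + [IGNORE_INDEX] * tgt_pad,
    }


# Hypothetical task: predict a language code from a short text.
example = build_example("Sawubona", "zu", max_source_len=12, max_target_len=4)
print(example)
```

Converted to tensors and batched, these three fields can be passed to the model as `model(input_ids=..., attention_mask=..., labels=...)`, which returns the training loss. Note that truncating at a byte boundary can cut a multi-byte character in half, so lengths should be set generously.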