In [1]:
pip install datasets torch transformers peft

[0mNote: you may need to restart the kernel to use updated packages.


In [4]:
from tqdm.notebook import tqdm

from datasets import load_dataset
import torch
from torch.utils.data import DataLoader

from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
)
from transformers import default_data_collator, Trainer, TrainingArguments

from short_hf import ShortHFModel

### Load Data

In [None]:
# data = load_dataset("pg19", split="validation")  # authors sample 10,000 texts to compute block influences
# dataloader = DataLoader(
#     data,
#     batch_size=2,
#     shuffle=True,
# )

In [5]:
data = load_dataset("wikitext", "wikitext-103-raw-v1", split="validation")  # authors sample 10,000 texts to compute block influences
dataloader = DataLoader(
    data,
    batch_size=1,
    shuffle=True,
)

### Load Model

In [3]:
# !huggingface-cli login
# pip install huggingface_hub
!python3 -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('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')"

In [None]:
#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

In [3]:
!huggingface-cli whoami

asifahmed


In [2]:
# pip install git+https://github.com/tri-ml/linear_open_lm.git


In [6]:
# from open_lm.open_lm_hf import *

MAX_SEQ_LEN = 2048
short_model = ShortHFModel(
    # model_name="tiiuae/falcon-7b",
    model_name="mistralai/Mistral-7B-v0.1",
    layers_path="model.layers",
    n_prune_layers=2
)
# short_model.model



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
short_model.model

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm(

In [None]:
# AutoModelForCausalLM.from_pretrained(
#             pretrained_model_name_or_path=model_dir,
#             local_files_only=True,
#             use_safetensors=True,
#             torch_dtype=torch.bfloat16,
#         )

In [8]:
short_model.model.parameters()

<generator object Module.parameters at 0x7f00b3917840>

In [9]:
pytorch_total_params = sum(p.numel() for p in short_model.model.parameters())
pytorch_total_params

7241732096

In [36]:
 # Save the model state to the specified path.
# model_dir='ShortModelSaved/'
# short_model.model.save_pretrained(
#         save_directory=model_dir,
#         safe_serialization=True,
#     )

In [10]:
short_model.layers[0]

MistralDecoderLayer(
  (self_attn): MistralSdpaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): MistralRotaryEmbedding()
  )
  (mlp): MistralMLP(
    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): MistralRMSNorm()
  (post_attention_layernorm): MistralRMSNorm()
)

In [12]:
# sample generationThe evolution of AI has lead to 
gen = short_model.model.generate(
    short_model.tokenizer(["I am an avid fan of "], return_tensors='pt').input_ids.to("cuda"),
    max_new_tokens=256
)
short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


['I am an avid fan of 3D printing. I have been using 3D printers for over 10 years and have been involved in the development of several 3D printers. I have also been involved in the development of several 3D printing software packages.\n\nI have been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing so

In [2]:
# # sample generation
# gen = short_model.model.generate(
#     short_model.tokenizer(["The evolution of AI has lead to  "], return_tensors='pt').input_ids.to("cuda"),
#     max_new_tokens=256
# )
# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)

### Compute Importances

In [50]:
# for i, batch in enumerate(tqdm(dataloader)):
#     prompts = batch['text']

#     short_model.eval_importance(
#         prompts=prompts,
#         max_seq_len=MAX_SEQ_LEN,
#         stride=256,
#         max_gen_len=0
#     )

In [51]:
# short_model.importances

### Remove unimportant layers

Layers removed when using subset of pg19 val set: [25, 26, 24, 27, 22, 23, 28, 21, 29]

Authors mention that the layer order is quite nuanced and can vary with different datasets. However, relative order suggests similar importance.

In [55]:
# short_model.remove_layers()

In [54]:
# short_model.remove_layers()

In [56]:
# short_model.layers

In [48]:
# # reassign layer_idx to attentions for caching
# for layer_idx, module in enumerate(short_model.layers):
#     module.self_attn.layer_idx = layer_idx

In [20]:
# short_model.model.parameters()

<generator object Module.parameters at 0x7f625768a2d0>

In [68]:
# pytorch_total_params = sum(p.numel() for p in short_model.model.parameters())
# pytorch_total_params

7241732096

As the paper states: \
    - "Our experiments reveal that the effect of layer removal is significantly more pronounced on generative
        tasks compared to multiple-choice tasks. On benchmarks such as GSM8K (Cobbe et al., 2021) and
        HumanEval (Chen et al., 2021), removing 25% of the layers often leads to a severe performance
        drop, with scores approaching zero."

In [53]:
# gen = short_model.model.generate(
#     short_model.tokenizer(["I am an avid fan of  "], return_tensors='pt').input_ids.to("cuda"),
#     max_new_tokens=20,
#     use_cache=True
# )
# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)

In [52]:
# gen = short_model.model.generate(I am an avid fan of 
#     short_model.tokenizer(["The evolution of AI has lead to "], return_tensors='pt').input_ids.to("cuda"),
#     max_new_tokens=20,
#     use_cache=True
# )
# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)

### Compute Angular Importances

In [16]:
for i, batch in enumerate(tqdm(dataloader)):
    prompts = batch['text']

    short_model.eval_importance(
        prompts=prompts,
        max_seq_len=MAX_SEQ_LEN,
        stride=256,
        max_gen_len=0,
        angular=True
    )

  0%|          | 0/3760 [00:00<?, ?it/s]

In [17]:
short_model.importances

[128390.1328125,
 80922.06787109375,
 61075.2890625,
 nan,
 nan,
 56557.81268310547,
 nan,
 52294.552001953125,
 47928.185302734375,
 42335.215576171875,
 40547.564208984375,
 37178.684326171875,
 34713.912841796875,
 33843.728271484375,
 35384.353271484375,
 35603.388427734375,
 35621.970458984375,
 35356.719482421875,
 35365.243896484375,
 34914.025146484375,
 27854.576904296875,
 24398.073974609375,
 20450.390380859375,
 19501.300537109375,
 18430.427490234375,
 18231.873779296875,
 17917.493896484375,
 17806.815185546875,
 21227.195068359375,
 23928.313018798828,
 22738.702880859375,
 86123.783203125]

### Remove unimportant layers

In [18]:
short_model.remove_layers(angular=True)

[27, 28]

In [20]:
short_model.layers[0]

MistralDecoderLayer(
  (self_attn): MistralSdpaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (rotary_emb): MistralRotaryEmbedding()
  )
  (mlp): MistralMLP(
    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): MistralRMSNorm()
  (post_attention_layernorm): MistralRMSNorm()
)

In [21]:
short_model.layers

ModuleList(
  (0-29): 30 x MistralDecoderLayer(
    (self_attn): MistralSdpaAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): MistralRotaryEmbedding()
    )
    (mlp): MistralMLP(
      (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): MistralRMSNorm()
    (post_attention_layernorm): MistralRMSNorm()
  )
)

In [22]:
# reassign layer_idx to attentions for caching
for layer_idx, module in enumerate(short_model.layers):
    module.self_attn.layer_idx = layer_idx

In [23]:
short_model.layers

ModuleList(
  (0-29): 30 x MistralDecoderLayer(
    (self_attn): MistralSdpaAttention(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): MistralRotaryEmbedding()
    )
    (mlp): MistralMLP(
      (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
      (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): MistralRMSNorm()
    (post_attention_layernorm): MistralRMSNorm()
  )
)

In [24]:
gen = short_model.model.generate(
    short_model.tokenizer(["I am an avid fan of "], return_tensors='pt').input_ids.to("cuda"),
    max_new_tokens=256,
    use_cache=True
)
short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


['I am an avid fan of 19th century American literature. I have read all of the classics, and I have also read many of the lesser known works. I have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Jane Austen, William Shakespeare, and William Blake.\n\nI have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Jane Austen, William Shakespeare, and William Blake.\n\nI have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Jane Austen, William Shakespeare, and William Blake.\n\nI have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Ja

In [27]:
# gen = short_model.model.generate(I am an avid fan of 
#     short_model.tokenizer(["The evolution of AI has lead to "], return_tensors='pt').input_ids.to("cuda"),
#     max_new_tokens=256,
#     use_cache=True
# )
# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)


gen = short_model.model.generate(
    short_model.tokenizer(["The evolution of AI has lead to "], return_tensors='pt').input_ids.to("cuda"),
    max_new_tokens=256,
    use_cache=True
)
short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


['The evolution of AI has lead to 3 major types of AI:\n\n1. Strong AI\n2. Weak AI\n3. Super AI\n\nStrong AI is the type of AI that is capable of performing any task that a human can perform. This type of AI is still in the development phase and is not yet available in the market.\n\nWeak AI is the type of AI that is capable of performing a specific task. This type of AI is available in the market and is used in a variety of applications.\n\nSuper AI is the type of AI that is capable of performing any task that a human can perform and is also capable of learning and adapting. This type of AI is still in the development phase and is not yet available in the market.\n\n## What is the difference between AI and AI?\n\nThe difference between AI and AI is that AI is a type of artificial intelligence that is capable of performing a specific task, while AI is a type of artificial intelligence that is capable of performing any task.\n\n## What is the difference between AI and AI?\n\nThe differe

In [28]:
pytorch_total_params = sum(p.numel() for p in short_model.model.parameters())
pytorch_total_params

6805508096

In [35]:
 # Save the model state to the specified path.
model_dir='SmallModelSaved/'
short_model.model.save_pretrained(
        save_directory=model_dir,
        safe_serialization=True,
    )

### Model Healing

In [36]:
# tokenizer = short_model.tokenizer
model = short_model.model

In [37]:
from datasets import load_dataset
# Falcon = load_dataset("csv", data_files="FalconData.csv")
Falcon = load_dataset('csv', data_files={"train": 'FalconData2.csv', "validation": 'FalconDataEval2.csv'})

print('Datset Loaded!')


Datset Loaded!


In [38]:
# Falcon = Falcon.train_test_split(test_size=0.10)

"""Then take a look at an example:"""

Falcon['train'][0]


{'Text': 'School Picture Gallery\nFrance Ski School\nChildren from Year 5 & 6 travelled to France from Newcastle airport to take part in a week of Ski School. The children had already spent 3 weeks learning the basics of skiing at Silksworth Ski School in Sunderland. When the children arrived in France they took part in a daily Ski School, during which the children made OUTSTANDING progress. The children also took part in French activities, explored local landmarks and took part in shopping activities in Chamonix. It was an incredible adventure for the children and staff!'}

In [39]:
Falcon['validation'][0]


{'Text': 'Our Annual Garden Party is a fun-filled event with a ton of landscaping and garden supplies; gardening demonstrations, experts, and vendors; activities for kids; live bands; and local food. It’s been so popular that we’re extending it to TWO DAYS this year!\nFestivities at 10am – 4pm Saturday and 11am – 3pm Sunday\nShopping from 9am – 6pm both days\nThroughout the winter, we collect gently-used and surplus lawn & garden supplies as well as outdoor décor and furniture. Then, we put it all out for your shopping pleasure! The sale begins at 9:00 am Saturday, but folks start lining up outside the gates even earlier, eager to dig through piles of flowerpots and shovels. (If you can’t get there in the morning, don’t worry – the staff continues to bring out items throughout the weekend.)\nThe Garden Sale 1st.\nThere will be prizes for people and pets dressed in garden party finery.\nPhoto by Carrie Delesky\nSo find yourself a dapper suit or fancy hat, and check out all the activitie

In [41]:
"""The next step is to load a DistilGPT2 tokenizer to process the `text` subfield:"""

from transformers import AutoTokenizer, GPT2TokenizerFast

# tokenizer = AutoTokenizer.from_pretrained("distilgpt2")


tokenizer = GPT2TokenizerFast.from_pretrained("Xenova/gpt-4")#, cache_dir=cache_dir)
tokenizer.pad_token = tokenizer.eos_token


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [42]:
Falcon = Falcon.flatten()
Falcon["train"][0]

{'Text': 'School Picture Gallery\nFrance Ski School\nChildren from Year 5 & 6 travelled to France from Newcastle airport to take part in a week of Ski School. The children had already spent 3 weeks learning the basics of skiing at Silksworth Ski School in Sunderland. When the children arrived in France they took part in a daily Ski School, during which the children made OUTSTANDING progress. The children also took part in French activities, explored local landmarks and took part in shopping activities in Chamonix. It was an incredible adventure for the children and staff!'}

In [43]:
def preprocess_function(examples):
    return tokenizer([" ".join(x) for x in examples["Text"]])



tokenized_Falcon = Falcon.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=Falcon["train"].column_names,
)

The OrderedVocab you are attempting to save contains holes for indices [100256, 100261, 100262, 100263, 100266, 100267, 100268, 100269, 100270, 100271, 100272, 100273, 100274, 100275], your vocabulary could be corrupted !


Map (num_proc=4):   0%|          | 0/10000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (10412 > 8192). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (10738 > 8192). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (12860 > 8192). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (23091 > 8192). Running this sequence through the model will result in indexing errors


The OrderedVocab you are attempting to save contains holes for indices [100256, 100261, 100262, 100263, 100266, 100267, 100268, 100269, 100270, 100271, 100272, 100273, 100274, 100275], your vocabulary could be corrupted !


Map (num_proc=4):   0%|          | 0/1000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (9078 > 8192). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (15886 > 8192). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (28727 > 8192). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (8257 > 8192). Running this sequence through the model will result in indexing errors


In [44]:
# block_size = tokenizer.model_max_length
block_size = 2048


def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

"""Apply the `group_texts` function over the entire dataset:"""

lm_dataset = tokenized_Falcon.map(group_texts, batched=True, num_proc=4)


Map (num_proc=4):   0%|          | 0/10000 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [45]:
from transformers import DataCollatorForLanguageModeling

# tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)


In [None]:
# from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
# import torch
# model = AutoModelForCausalLM.from_pretrained("tensorplex-labs/pretraining-sn9-7B-5", torch_dtype=torch.bfloat16)

# print('Model Loaded!')


In [46]:
model.to('cuda')

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-29): 30 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm(

In [47]:
pytorch_total_params = sum(p.numel() for p in model.parameters())
pytorch_total_params

6805508096

In [48]:
training_args = TrainingArguments(
    output_dir="Fine-Tuned-S9-2",
    overwrite_output_dir=True,
    bf16=True,
    # evaluation_strategy="epoch",
    evaluation_strategy="steps",
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    lr_scheduler_type = 'cosine',
    push_to_hub=False,
    save_total_limit = 2,
    # save_strategy = “no”
    load_best_model_at_end=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["validation"],
    # eval_dataset=lm_dataset["test"],
    data_collator=data_collator,
)

In [49]:
# trainer.train()
print('Started Training!')
trainer.train()

Started Training!


[34m[1mwandb[0m: Currently logged in as: [33mthatmlguy[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss


OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU 

In [None]:
import math

eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")


In [29]:
# # referencing https://github.com/meta-llama/llama-recipes/blob/main/recipes/finetuning/huggingface_trainer/peft_finetuning.ipynb
# eval_prompt = """
# Summarize this dialog:
# A: Hi Tom, are you busy tomorrow's afternoon?
# B: I'm pretty sure I am. What's up?
# A: Can you go with me to the animal shelter?.
# B: What do you want to do?
# A: I want to get a puppy for my son.
# B: That will make him so happy.
# A: Yeah, we've discussed it many times. I think he's ready now.
# B: That's good. Raising a dog is a tough issue. Like having a baby ;-) 
# A: I'll get him one of those little dogs.
# B: One that won't grow up too big;-)
# A: And eat too much;-))
# B: Do you know which one he would like?
# A: Oh, yes, I took him there last Monday. He showed me one that he really liked.
# B: I bet you had to drag him away.
# A: He wanted to take it home right away ;-).
# B: I wonder what he'll name it.
# A: He said he'd name it after his dead hamster - Lemmy  - he's  a great Motorhead fan :-)))
# ---
# Summary:
# """

# model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")

# model.eval()
# with torch.no_grad():
#     print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100, use_cache=True)[0], skip_special_tokens=True))

In [30]:
# def get_preprocessed_samsum():
#     dataset = load_dataset("samsum", split="train")

#     prompt = (
#         f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
#     )

#     def apply_prompt_template(sample):
#         return {
#             "prompt": prompt.format(dialog=sample["dialogue"]),
#             "summary": sample["summary"],
#         }

#     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))

#     def tokenize_add_label(sample):
#         prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
#         summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
#         sample = {
#             "input_ids": prompt + summary,
#             "attention_mask" : [1] * (len(prompt) + len(summary)),
#             "labels": [-100] * len(prompt) + summary,
#             }

#         return sample

#     dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))

#     return dataset

In [31]:
# model.train()

# def create_peft_config(model):
#     peft_config = LoraConfig(
#         task_type=TaskType.CAUSAL_LM,
#         inference_mode=False,
#         r=8,
#         lora_alpha=32,
#         lora_dropout=0.05,
#         target_modules = ["q_proj", "v_proj"]
#     )

#     model = get_peft_model(model, peft_config)
#     model.print_trainable_parameters()
#     return model, peft_config

# # create peft config
# model, lora_config = create_peft_config(model)

In [32]:
# output_dir = "tmp/"

# config = {
#     'lora_config': lora_config,
#     'learning_rate': 1e-6,
#     'num_train_epochs': 1,
#     'per_device_train_batch_size': 1,
#     'gradient_checkpointing': False,
# }


In [33]:
# training_args = TrainingArguments(
#     output_dir=output_dir,
#     overwrite_output_dir=True,
#     # logging strategies
#     logging_strategy="steps",
#     logging_steps=10,
#     save_strategy="no",
#     optim="adamw_torch_fused",
#     **{k:v for k,v in config.items() if k != 'lora_config'}
# )

# # Create Trainer instance
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=get_preprocessed_samsum(),
#     data_collator=default_data_collator,
#     callbacks=[],
# )

# # Start training
# trainer.train()

In [34]:
# model.eval()
# with torch.no_grad():
#     print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))