|
# TokenButler demo: generate with the base Llama model, load trained predictor
# weights into the Butler variant, compare the two generations, and save the
# merged model.
from transformers import LlamaForCausalLM, LlamaConfig, AutoTokenizer
import torch
import os

question = "A $y$-intercept is a point on the graph that lies on the $y$-axis, so $x = 0$. Hence, the number $y$-intercepts corresponds to the number of real solutions of the quadratic equation $y^2 - 4y - 1 = 0$. The discriminant of this quadratic equation is $(-4)^2 + 4 \cdot 1 \cdot (-1) = 20$, which is positive, so the quadratic has two distinct real roots. Therefore, the number of $y$-intercepts is $\boxed{2}$. \n \n [asy] \n size(150); \n real ticklen=3; \n real tickspace=2; \n \n real ticklength=0.1cm; \n real axisarrowsize=0.14cm; \n pen axispen=black+1.3bp; \n real vectorarrowsize=0.2cm; \n real tickdown=-0.5; \n real tickdownlength=-0.15inch; \n real tickdownbase=0.3; \n real wholetickdown=tickdown; \n void rr_cartesian_axes(real xleft, real xright, real ybottom, real ytop, real xstep=1, real ystep=1, bool \n \n useticks=false, bool complexplane=false, bool usegrid=true) { \n \n import graph; \n \n real i; \n \n if(complexplane) { \n \n label('$\textnormal{Re}$',(xright,0),SE); \n \n label('$\textnormal{Im}$',(0,ytop),NW); \n \n } else { \n \n label('$x$',(xright+0.4,-0.5)); \n \n label('$y$',(-0.5,ytop+0.2)); \n \n } \n \n ylimits(ybottom,ytop); \n \n xlimits( xleft, xright); \n \n real[] TicksArrx,TicksArry; \n \n for(i=xleft+xstep; i<xright; i+=xstep) { \n \n if(abs(i) >0.1) { \n \n TicksArrx.push(i); \n \n } \n \n } \n \n for(i=ybottom+ystep; i<ytop; i+=ystep) { \n \n if(abs(i) >0.1) { \n \n TicksArry.push(i); \n \n } \n \n } \n \n if(usegrid) {" |
|
# Trained TokenButler predictor checkpoint (the filename encodes the training
# configuration) and the base model it extends; adjust paths to your environment.
predictor_load_path = "/home/ya255/projects/TokenButler/expt_model/TrainTokenButler_42_finetune_None_None_500_llama_meta-llama_Llama-3.2-3B_L3_3B_2k.csv_L3_3B_2k_False_False_2000_False_redpajama_1024_1_1_20_0.001_1024/16_False_4_1000_ExpPred_fixed_40pc_True_False_0_None_False_False_4_8_2_16_1024_False_False_True_28_0.38571428571428584__best.pt"
base_model_name = "meta-llama/Llama-3.2-3B"
|
|
|
def get_producer_layers(model):
    """
    Traverses the model to find the producer layer(s) (layer_idx=0).
    """
    producer_modules = []
    for module in model.modules():
        if module.__class__.__name__.endswith("AttentionExperimental") and module.layer_idx == 0:
            producer_modules.append(module)
    return producer_modules
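# Note (assumption): TokenButler hosts its shared predictor alongside layer 0's
# attention (the "producer"), so this list typically holds a single module.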
|
|
|
|
|
# Load the base model and tokenizer, then tokenize the demo prompt.
base_model = LlamaForCausalLM.from_pretrained(base_model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

inputs = tokenizer(question, return_tensors="pt")
inputs = {k: v.to(base_model.device) for k, v in inputs.items()}
# Prompt length in tokens; used later to strip the prompt from decoded outputs.
question_length = inputs['attention_mask'].shape[1]
|
|
|
# Sample a continuation from the unmodified base model.
with torch.no_grad():
    base_output_ids = base_model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )

base_output_text = tokenizer.decode(base_output_ids[0][question_length:], skip_special_tokens=True)
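# Note: do_sample=True makes generation stochastic; set a seed (e.g.,
# torch.manual_seed(0)) before each generate() call if reproducible
# side-by-side comparisons are needed.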
|
|
|
|
|
# Snapshot the base weights on CPU and free GPU memory before building the
# Butler model.
base_model_device = base_model.device
base_model.to("cpu")
base_state_dict = base_model.state_dict()
del base_model
torch.cuda.empty_cache()
|
|
|
# The Butler model class and config are provided by the local
# modeling_llama_butler.py; 'config.json' holds the Butler-specific settings.
from modeling_llama_butler import LlamaButlerConfig, LlamaButlerForCausalLM
butler_config = LlamaButlerConfig.from_pretrained('config.json')
|
|
|
butler_model = LlamaButlerForCausalLM(butler_config)
# strict=False: the Butler-specific predictor parameters are absent from the
# base checkpoint and are loaded separately below.
load_result = butler_model.load_state_dict(base_state_dict, strict=False)
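# Optional sanity check: load_state_dict returns (missing_keys, unexpected_keys);
# the missing keys should all belong to the predictor modules filled in below.
print(f"{len(load_result.missing_keys)} missing / {len(load_result.unexpected_keys)} unexpected keys after base weight load")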
|
|
|
model_producer_layers = get_producer_layers(butler_model)
# map_location="cpu" avoids device mismatches if the checkpoint was saved on GPU.
producer_layer_weights = torch.load(predictor_load_path, map_location="cpu")
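# Sanity check (assumes producer_layer_weights is a list of per-layer state
# dicts, as the loop below implies).
if len(producer_layer_weights) != len(model_producer_layers):
    print(f"Warning: {len(producer_layer_weights)} predictor state dict(s) vs. {len(model_producer_layers)} producer layer(s)")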
|
for idx, producer_layer_weight in enumerate(producer_layer_weights):
    try:
        model_producer_layers[idx].load_state_dict(producer_layer_weight, strict=False)
    except Exception as e:
        print(f"Error loading producer layer {idx}: {e}")
        print("\n\nContinuing... !! Bad Perf If Unintentional !!\n\n")
|
|
|
|
|
butler_model.to(base_model_device)
butler_model.eval()
|
|
|
# Sample from the Butler model with the same prompt and sampling settings.
with torch.no_grad():
    butler_output_ids = butler_model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )

butler_output_text = tokenizer.decode(butler_output_ids[0][question_length:], skip_special_tokens=True)
|
|
|
print("\n=== Base Model Output (Newlines Removed For Brevity) ===\n") |
|
print(base_output_text.replace("\n", "")) |
|
print("\n") |
|
print("=== Butler Model Output (Newlines Removed For Brevity) ===\n") |
|
print(butler_output_text.replace("\n", "")) |
|
print("\n") |
|
|
|
OUTPUT_DIR = "." |
|
print(f"\nSaving final merged model to: {OUTPUT_DIR}") |
|
butler_model.save_pretrained(OUTPUT_DIR, safe_serialization=False) |
|
|
|
|
|
print("\nAll done! The folder should now have `pytorch_model.bin` and the updated `config.json`.\n") |
|
|