import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import LlamaTokenizer
from vllm import LLM, SamplingParams


def average_two_model(model_path_1, model_path_2, update_num,
                      base_path='/dccstor/obsidian_llm/yiduo/h100_data/llama-3-8b'):
    """Merge two checkpoints with running-average weights and save the result.

    Every entry of the merged state dict is
    ``update_num / (update_num + 1) * w1 + 1 / (update_num + 1) * w2``,
    where ``w1`` comes from ``model_path_1`` and ``w2`` from ``model_path_2``.
    In other words, ``model_path_1`` is weighted as if it were already an
    average of ``update_num`` models, and ``model_path_2`` is folded in as one
    more model.

    Args:
        model_path_1: Checkpoint weighted by ``update_num / (update_num + 1)``.
            Its tokenizer is copied to the output directory.
        model_path_2: Checkpoint weighted by ``1 / (update_num + 1)``. It must
            have the same state-dict keys as ``model_path_1``.
        update_num: Running-average count. ``update_num == -1`` raises
            ZeroDivisionError. ``update_num == 0`` returns ``model_path_2``'s
            weights unchanged.
        base_path: Checkpoint that is instantiated to receive the averaged
            weights. Its config is the one saved with them.

    Returns:
        Directory the averaged model and tokenizer were saved to.
    """
    # The output path is model_path_1 with the last component of model_path_2
    # appended directly (no separator); the listed substrings are then stripped.
    # NOTE(review): the replacements apply to the WHOLE path, not just the new
    # name, so directory components are rewritten too (e.g. 'h100_data' ->
    # 'h1_data'). Kept unchanged because callers use the returned path; confirm
    # this is intended.
    averaged_model_path = model_path_1 + model_path_2.split('/')[-1]
    for token in ('00', 'random', 'naive_3k', 'shuffle', 'average'):
        averaged_model_path = averaged_model_path.replace(token, '')

    weight_1 = update_num / (update_num + 1)
    weight_2 = 1.0 / (update_num + 1)

    # Take each state dict exactly once. state_dict() rebuilds the full mapping
    # on every call, so calling it per key is quadratic in the number of keys.
    state_dict_1 = AutoModelForCausalLM.from_pretrained(model_path_1).state_dict()
    state_dict_2 = AutoModelForCausalLM.from_pretrained(model_path_2).state_dict()

    avg_state_dict = {
        key: weight_1 * tensor + weight_2 * state_dict_2[key]
        for key, tensor in state_dict_1.items()
    }
    # Drop the source weights before loading another full model, so that three
    # full copies are not resident at the same time.
    del state_dict_1, state_dict_2

    # The base checkpoint supplies the architecture/config that receives the
    # averaged weights; its own weights are overwritten by load_state_dict.
    base_model = AutoModelForCausalLM.from_pretrained(base_path)
    base_model.load_state_dict(avg_state_dict)
    base_model.save_pretrained(averaged_model_path)

    # Assumes both checkpoints share one tokenizer; the first one's is saved.
    tokenizer = AutoTokenizer.from_pretrained(model_path_1)
    tokenizer.save_pretrained(averaged_model_path)
    return averaged_model_path