import torch
from accelerate import PartialState
import transformers
import json
from tqdm import tqdm
import os
from distributed_inference.utils import caption_repharse
import random
from distributed_inference.prompt_config import prompt_dict, instruction_dict


def main(num_gpu=1):
    """Generate image captions with Llama-3-8B-Instruct, sharded across GPUs.

    Reads a JSON mapping of image names to position info, prompts the model
    once per image, and appends the rephrased captions to an output JSON in
    LLaVA-style records. Resumes from the existing output file if present.

    Args:
        num_gpu: Number of distributed processes; each batch of ``num_gpu``
            items is split one-per-process via ``split_between_processes``.
    """
    distributed_state = PartialState()
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=distributed_state.device,
    )

    instruction = instruction_dict['general']
    user_prompt = prompt_dict['caption_generation']

    # file_list = [ '4dor_count_dataset_0702.json', '4dor_phase_dataset_0702.json']
    # file_list = ['4dor_recognition_dataset_0702.json']
    file_list = ['./pwiseg_info.json']
    img_path = '/mnt1/wjl/InternLM-XComposer/data/pwiseg/train/'

    for input_file in file_list:
        with open(input_file, 'r') as f:
            data = json.load(f)
        data_keys = list(data.keys())
        data_keys.sort()  # deterministic order so resume indices line up

        output_name = '4dor_caption_dataset_pwiseg_llama3.json'
        output_data = []
        if os.path.exists(output_name):
            with open(output_name, 'r') as f:
                output_data = json.load(f)
        # Resume: one output record per already-processed key.
        start_index = len(output_data)

        for i in tqdm(range(start_index, len(data_keys), num_gpu)):
            # Fresh sampling parameters per batch for caption diversity.
            temperature = random.uniform(0.8, 0.9)
            top_p = random.uniform(0.7, 0.9)

            indices = list(range(i, min(i + num_gpu, len(data_keys))))
            input_list = [data[data_keys[j]] for j in indices]

            # apply_padding repeats the last item so every process gets work;
            # padded duplicates are never read back (readback stops at
            # len(indices)).
            with distributed_state.split_between_processes(
                input_list, apply_padding=True
            ) as s_input:
                messages = [
                    {"role": "system", "content": instruction},
                    {"role": "user", "content": user_prompt.format(position=s_input)},
                ]
                prompt = pipeline.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                terminators = [
                    pipeline.tokenizer.eos_token_id,
                    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
                ]
                outputs = pipeline(
                    prompt,
                    max_new_tokens=500,
                    eos_token_id=terminators,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                )
                # Strip the echoed prompt; keep only the generated completion.
                results = outputs[0]["generated_text"][len(prompt):]
                results = caption_repharse(results)
                # Each process drops its result in a per-rank scratch file.
                with open(f'{distributed_state.process_index}.json', 'w') as f:
                    json.dump(results, f, indent=4)

            distributed_state.wait_for_everyone()

            if distributed_state.is_last_process:
                # Collect per-rank scratch files in order and append records.
                for j in range(len(indices)):
                    with open(f'{j}.json', 'r') as f:
                        temp_output = json.load(f)
                    llava_dict = {
                        "id": data_keys[indices[j]],
                        "image": os.path.join(img_path, data_keys[indices[j]]),
                        "caption": temp_output,
                    }
                    output_data.append(llava_dict)
                with open(output_name, 'w') as f:
                    json.dump(output_data, f, indent=4)

            distributed_state.wait_for_everyone()

            # Cleanup on one process only: having every rank os.remove() the
            # same paths is a race (all but the first raise FileNotFoundError).
            if distributed_state.is_main_process:
                for j in range(num_gpu):
                    if os.path.exists(f'{j}.json'):
                        os.remove(f'{j}.json')
            # Barrier so no rank starts writing next batch's scratch files
            # before cleanup finishes.
            distributed_state.wait_for_everyone()


if __name__ == '__main__':
    main(num_gpu=4)