import json
import os

import torch
import transformers
from accelerate import PartialState
from tqdm import tqdm

from distributed_inference.prompt_config import prompt_dict
from distributed_inference.utils import repharse  # project helper that post-processes the model output


def main(num_gpu=1):
    # One process per GPU; `num_gpu` must match the number of processes
    # started by the Accelerate launcher.
    distributed_state = PartialState()

    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=distributed_state.device,
    )
    instruction = prompt_dict['rephrase_with_ex']

    # file_list = ['4dor_count_dataset_0702.json', '4dor_phase_dataset_0702.json']
    # file_list = ['4dor_recognition_dataset_0702.json']
    file_list = ['./data/count_dataset_pwiseg_0710.json']

    for file in file_list:
        with open(file, 'r') as f:
            data = json.load(f)

        # Resume from a partially written output file if one exists.
        output_name = file.replace('.json', '_rephrased.json')
        existing_data = []
        if os.path.exists(output_name):
            with open(output_name, 'r') as f:
                existing_data = json.load(f)
            data[:len(existing_data)] = existing_data
        start_index = len(existing_data)

        # Process `num_gpu` captions per step, one caption per process.
        for i in tqdm(range(start_index, len(data), num_gpu)):
            indices = list(range(i, min(i + num_gpu, len(data))))
            caption_list = [data[j]['caption'] for j in indices]

            # `apply_padding=True` repeats the last caption so every process
            # receives an input even when fewer captions than processes remain.
            with distributed_state.split_between_processes(
                caption_list, apply_padding=True
            ) as caption:
                # Each process gets a one-element slice; unwrap it so the
                # message content is a string, not a list.
                messages = [
                    {"role": "system", "content": instruction},
                    {"role": "user", "content": caption[0]},
                ]
                prompt = pipeline.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                terminators = [
                    pipeline.tokenizer.eos_token_id,
                    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
                ]
                outputs = pipeline(
                    prompt,
                    max_new_tokens=500,
                    eos_token_id=terminators,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                )
                # Keep only the newly generated text, then post-process it.
                results = outputs[0]["generated_text"][len(prompt):]
                rephrased = repharse(results)

                # Each process writes its result to a rank-indexed scratch file.
                with open(f'{distributed_state.process_index}.json', 'w') as f:
                    json.dump(rephrased, f, indent=4)

            distributed_state.wait_for_everyone()

            # The last process gathers the scratch files and checkpoints
            # everything produced so far, enabling the resume logic above.
            if distributed_state.is_last_process:
                output_data = data[:indices[-1] + 1].copy()
                for j in range(len(indices)):
                    with open(f'{j}.json', 'r') as f:
                        temp_output = json.load(f)
                    output_data[i + j]['caption'] = temp_output
                with open(output_name, 'w') as f:
                    json.dump(output_data, f, indent=4)

            distributed_state.wait_for_everyone()

            # Only one process may delete the scratch files; if every process
            # ran this loop, all but the first would hit FileNotFoundError.
            if distributed_state.is_main_process:
                for j in range(len(indices)):
                    os.remove(f'{j}.json')
            # Barrier so no process starts writing the next batch's scratch
            # files before the old ones are removed.
            distributed_state.wait_for_everyone()


if __name__ == '__main__':
    main(num_gpu=4)
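
# Usage sketch (the script path below is illustrative, not from the source):
# PartialState attaches to the process group created by the Accelerate
# launcher, so start one process per GPU and keep --num_processes in sync
# with the `num_gpu` argument passed to main():
#
#   accelerate launch --num_processes 4 distributed_inference/rephrase.py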