# llava_finetune/distributed_inference/distribute_llama_question_generation.py
# lyclyc52's picture
# Update: add distributed inference code
# f168581
# raw
# history blame
# 3.75 kB
import torch
from accelerate import PartialState
import transformers
import torch
import json
from tqdm import tqdm
import os
from distributed_inference.utils import repharse
from distributed_inference.prompt_config import prompt_dict
def main(num_gpu=1):
    """Run Llama-3 question generation over the input JSON, one batch per step.

    Entries of each input file are processed in sorted key order, ``num_gpu``
    at a time. Every process writes its generated (and ``repharse``-d) text to
    ``<process_index>.json``; the last process then collects those files,
    appends one record (``id``, ``image``, ``caption``) per batch item to the
    output JSON and rewrites it. If the output file already exists, processing
    resumes after the records it already contains.

    Args:
        num_gpu: Number of entries handled per step. Result file ``j.json`` is
            attributed to the j-th entry of the batch, so this must not exceed
            the number of launched processes.

    Raises:
        ValueError: If ``num_gpu`` is larger than the number of processes
            reported by ``PartialState``.
    """
    distributed_state = PartialState()

    # Each process writes exactly one result file per step and the collector
    # reads files 0..len(batch)-1, so a batch larger than the process count
    # would reference files nobody writes. Fail before loading the model.
    if num_gpu > distributed_state.num_processes:
        raise ValueError(
            f"num_gpu={num_gpu} exceeds the number of launched processes "
            f"({distributed_state.num_processes}); launch one process per "
            "batch item or lower num_gpu."
        )

    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=distributed_state.device
    )
    # Stop tokens do not change between steps, so build them once.
    terminators = [
        pipeline.tokenizer.eos_token_id,
        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    instruction = ' You are an AI visual assistant, and you are looking at a picture of many surgical tools.'
    user_prompt = prompt_dict['question_generation_with_ex']

    # Alternative inputs used previously:
    # file_list = [ '4dor_count_dataset_0702.json', '4dor_phase_dataset_0702.json']
    # file_list = ['4dor_recognition_dataset_0702.json']
    file_list = ['./pwiseg_info.json']
    img_path = '/mnt1/wjl/InternLM-XComposer/data/pwiseg/train/'

    # Per-process scratch file used to hand results to the collecting process.
    rank_file = f'{distributed_state.process_index}.json'

    for file in file_list:
        with open(file, 'r') as f:
            data = json.load(f)
        data_keys = sorted(data.keys())

        # NOTE(review): the output name is the same for every input file, so
        # with several inputs the resume offset below would count records
        # from earlier files too - confirm before extending file_list.
        output_name = '4dor_count_dataset_pwiseg_llama3.json'
        output_data = []
        if os.path.exists(output_name):
            with open(output_name, 'r') as f:
                output_data = json.load(f)

        # Resume after the records that were already written.
        start_index = len(output_data)
        print(start_index)

        for i in tqdm(range(start_index, len(data_keys), num_gpu)):
            indices = list(range(i, min(i + num_gpu, len(data_keys))))
            input_list = [data[data_keys[j]] for j in indices]

            with distributed_state.split_between_processes(input_list, apply_padding=True) as s_input:
                # NOTE(review): s_input is this process's share of the batch
                # (a list), and the whole list is formatted into the prompt;
                # confirm that is intended rather than a single element.
                messages = [
                    {"role": "system", "content": instruction},
                    {"role": "user", "content": user_prompt.format(position=s_input)},
                ]
                prompt = pipeline.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                outputs = pipeline(
                    prompt,
                    max_new_tokens=500,
                    eos_token_id=terminators,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                )
                # The pipeline echoes the prompt; keep only the completion.
                results = outputs[0]["generated_text"][len(prompt):]
                results = repharse(results)
                with open(rank_file, 'w') as f:
                    json.dump(results, f, indent=4)

            # All result files must exist before the collector reads them.
            distributed_state.wait_for_everyone()

            if distributed_state.is_last_process:
                for j, key_index in enumerate(indices):
                    with open(f'{j}.json', 'r') as f:
                        caption = json.load(f)
                    key = data_keys[key_index]
                    output_data.append({
                        "id": key,
                        "image": os.path.join(img_path, key),
                        "caption": caption,
                    })
                # Write to a temporary file and swap it in, so an interrupted
                # run never leaves a truncated output file that breaks resume.
                tmp_name = output_name + '.tmp'
                with open(tmp_name, 'w') as f:
                    json.dump(output_data, f, indent=4)
                os.replace(tmp_name, output_name)

            # The collector must be done reading before anything is deleted.
            distributed_state.wait_for_everyone()

            # Each process removes only its own scratch file: no two processes
            # touch the same path, and a slower process can no longer delete a
            # faster one's result for the next step.
            try:
                os.remove(rank_file)
            except FileNotFoundError:
                pass
if __name__ == "__main__":
    # Script entry point: process four entries per step. main() reads one
    # result file per batch entry, each written by the process with that
    # index, so this expects at least four launched processes.
    main(num_gpu=4)