import json
import os

import torch
import transformers
from accelerate import PartialState
from tqdm import tqdm

from distributed_inference.prompt_config import prompt_dict
from distributed_inference.utils import repharse
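
# Distributed caption generation for the PWISeg surgical-tool images: each
# process runs its own Llama-3-8B-Instruct text-generation pipeline on one
# GPU, and the last process merges the per-image outputs into one JSON file.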


def main(num_gpu=1):
    # One process per GPU; each process loads its own full copy of the model.
    distributed_state = PartialState()
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=distributed_state.device,
    )
|
    instruction = 'You are an AI visual assistant, and you are looking at a picture of many surgical tools.'
    user_prompt = prompt_dict['question_generation_with_ex']

    file_list = ['./pwiseg_info.json']
    img_path = '/mnt1/wjl/InternLM-XComposer/data/pwiseg/train/'
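
    # Assumed input format: pwiseg_info.json maps each image filename to the
    # tool-position annotations that are formatted into the user prompt below.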
    for file in file_list:
        with open(file, 'r') as f:
            data = json.load(f)
        data_keys = list(data.keys())
        data_keys.sort()

        output_name = '4dor_count_dataset_pwiseg_llama3.json'
        # Resume from an earlier run: entries already in the output file are
        # skipped, since one entry is appended per processed image key.
        output_data = []
        if os.path.exists(output_name):
            with open(output_name, 'r') as f:
                output_data = json.load(f)
        start_index = len(output_data)
        print(start_index)
|
        for i in tqdm(range(start_index, len(data_keys), num_gpu)):
            indices = list(range(i, min(i + num_gpu, len(data_keys))))
            input_list = [data[data_keys[j]] for j in indices]
            temp_output = {}

            # Each process receives a one-element slice of input_list;
            # apply_padding repeats the last element so every process has
            # work even when the final batch is short.
            with distributed_state.split_between_processes(input_list, apply_padding=True) as s_input:
                messages = [
                    {"role": "system", "content": instruction},
                    # s_input is a per-process list; unwrap its single element.
                    {"role": "user", "content": user_prompt.format(position=s_input[0])},
                ]
                prompt = pipeline.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                # Stop on either the model's EOS token or Llama-3's
                # end-of-turn token.
                terminators = [
                    pipeline.tokenizer.eos_token_id,
                    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
                ]
|
                outputs = pipeline(
                    prompt,
                    max_new_tokens=500,
                    eos_token_id=terminators,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                )
                # Keep only the newly generated text, then clean it up.
                results = outputs[0]["generated_text"][len(prompt):]
                results = repharse(results)
                # Stage the result in a temp file named after the process
                # index so the last process can gather all of them.
                with open(f'{distributed_state.process_index}.json', 'w') as f:
                    json.dump(results, f, indent=4)
|
            distributed_state.wait_for_everyone()
            if distributed_state.is_last_process:
                # Gather per-process outputs; only the first len(indices)
                # files correspond to real (unpadded) inputs.
                for j in range(len(indices)):
                    with open(f'{j}.json', 'r') as f:
                        temp_output = json.load(f)
                    llava_dict = {}
                    llava_dict["id"] = data_keys[indices[j]]
                    llava_dict["image"] = os.path.join(img_path, data_keys[indices[j]])
                    llava_dict["caption"] = temp_output
                    output_data.append(llava_dict)
                with open(output_name, 'w') as f:
                    json.dump(output_data, f, indent=4)

            distributed_state.wait_for_everyone()
            # Let a single process delete the temp files to avoid racing
            # removals; padded processes may have written extra files.
            if distributed_state.is_main_process:
                for j in range(distributed_state.num_processes):
                    if os.path.exists(f'{j}.json'):
                        os.remove(f'{j}.json')
|

if __name__ == '__main__':
    # num_gpu must match the number of processes started by the launcher.
    main(num_gpu=4)
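
# Example launch (a sketch; the script filename here is a placeholder):
#   accelerate launch --num_processes 4 generate_pwiseg_captions.py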