"""Rephrase dataset captions with Meta-Llama-3-8B-Instruct across GPUs.

Intended to be started via `accelerate launch` with one process per GPU:
each process rewrites one caption per step, and the last process merges
the per-rank outputs into a checkpoint file so interrupted runs can
resume.
"""
import json
import os

import torch
import transformers
from accelerate import PartialState
from tqdm import tqdm

from distributed_inference.prompt_config import prompt_dict
from distributed_inference.utils import repharse  # sic: spelled this way in utils


def main(num_gpu=1):
    # One process per GPU; num_gpu is assumed to match the number of
    # processes started by `accelerate launch`.
    distributed_state = PartialState()
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=distributed_state.device,
    )

    # System prompt that tells the model how to rephrase captions.
    instruction = prompt_dict['rephrase_with_ex']
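    # Input datasets; every record is expected to carry a 'caption' field.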
    file_list = ['./data/count_dataset_pwiseg_0710.json']
    for file in file_list:
        with open(file, 'r') as f:
            data = json.load(f)
        output_name = file.replace('.json', '_rephrased.json')

        # Resume support: if a partial output exists, keep its already-
        # rephrased entries and continue from the first new index.
        existing_data = []
        if os.path.exists(output_name):
            with open(output_name, 'r') as f:
                existing_data = json.load(f)
            data[:len(existing_data)] = existing_data
        start_index = len(existing_data)
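        # Walk the remaining records num_gpu captions at a time, one
        # caption per process.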
        for i in tqdm(range(start_index, len(data), num_gpu)):
            indices = list(range(i, min(i + num_gpu, len(data))))
            caption_list = [data[j]['caption'] for j in indices]
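            # apply_padding=True repeats the last caption so every rank
            # receives an item even when the final batch is short.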
            with distributed_state.split_between_processes(caption_list, apply_padding=True) as caption:
                # Each rank sees a one-element slice of caption_list.
                messages = [
                    {"role": "system", "content": instruction},
                    {"role": "user", "content": caption[0]},
                ]
                prompt = pipeline.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
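                # Llama-3 chat models end each turn with <|eot_id|>, so treat
                # it as a stop token alongside the tokenizer's default EOS.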
                terminators = [
                    pipeline.tokenizer.eos_token_id,
                    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
                ]
                outputs = pipeline(
                    prompt,
                    max_new_tokens=500,
                    eos_token_id=terminators,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                )
                # The pipeline returns prompt + completion; keep only the
                # newly generated text.
                results = outputs[0]["generated_text"][len(prompt):]
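                # Post-process the generation and stage it in a scratch file
                # named after this process's rank.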
                output_data = repharse(results)
                with open(f'{distributed_state.process_index}.json', 'w') as f:
                    json.dump(output_data, f, indent=4)

            # Wait until every rank has written its scratch file.
            distributed_state.wait_for_everyone()
            if distributed_state.is_last_process:
                # Merge the per-rank scratch files back into the dataset and
                # checkpoint everything processed so far; on a short final
                # batch only the first len(indices) files are real results.
                output_data = data[:indices[-1] + 1].copy()
                for j in range(len(indices)):
                    with open(f'{j}.json', 'r') as f:
                        temp_output = json.load(f)
                    output_data[i + j]['caption'] = temp_output
                with open(output_name, 'w') as f:
                    json.dump(output_data, f, indent=4)
            distributed_state.wait_for_everyone()
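            # Clean up scratch files from a single rank so processes don't
            # race to delete the same paths.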
            if distributed_state.is_main_process:
                for j in range(distributed_state.num_processes):
                    os.remove(f'{j}.json')
            # Hold all ranks until cleanup finishes so the next batch's
            # scratch files aren't removed by mistake.
            distributed_state.wait_for_everyone()


if __name__ == '__main__':
    main(num_gpu=4)