########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import io
import json
import math
import os
import random
import sys
import librosa
import numpy as np
import pandas as pd  # used to read parquet data files
import soundfile as sf
import torch
import torch.nn.functional as F
import lightning as L
from lightning_utilities.core.rank_zero import rank_zero_info
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from infer.rwkv.utils import PIPELINE
pipeline = PIPELINE('rwkv', "rwkv_vocab_v20230424")
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor()  # convert the PIL image to a tensor
])
def process_conversation_text(conversations):
    # Flatten a multi-turn conversation into a single delimited string.
    conversation_text = "\x16"
    for conv in conversations:
        role = conv.get('from', '').lower()
        content = conv.get('value', '')
        if role in ['human', 'user']:
            conversation_text += f"User: {content}\x17"
        elif role in ['assistant', 'gpt']:
            conversation_text += f"Assistant: {content}\x17"
    return conversation_text
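# Illustrative sketch (not called in training): shows the flat string the
# function above produces. \x16 and \x17 are the delimiter bytes this
# codebase places around dialogue turns.
def _demo_conversation_text():
    conversations = [
        {'from': 'human', 'value': 'What is in the image?'},
        {'from': 'gpt', 'value': 'A cat on a sofa.'},
    ]
    # -> '\x16User: What is in the image?\x17Assistant: A cat on a sofa.\x17'
    return process_conversation_text(conversations)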
def process_tokens(conversations):
    # Build token ids and labels turn by turn: user turns are masked out
    # of the loss with label -100, assistant turns are trained on.
    inputs = []
    labels = []
    for conv in conversations:
        role = conv.get('from', '').lower()
        content = conv.get('value', '')
        if role in ['human', 'user']:
            question = f"\x16User: {content}\x17"
            input_ids = torch.tensor(pipeline.encode(question))
            label = torch.full_like(input_ids, -100)
        elif role in ['assistant', 'gpt']:
            answer = f"\x16Assistant: {content}\x17"
            input_ids = torch.tensor(pipeline.encode(answer))
            label = input_ids
        else:
            continue  # skip turns with an unrecognized role
        inputs.append(input_ids)
        labels.append(label)
    inputs = torch.cat(inputs)
    labels = torch.cat(labels)
    return inputs, labels
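# Illustrative sketch (not called in training): user tokens carry label
# -100 so the loss ignores them; assistant tokens keep their own ids.
def _demo_process_tokens():
    inputs, labels = process_tokens([
        {'from': 'user', 'value': 'Hi'},
        {'from': 'assistant', 'value': 'Hello!'},
    ])
    assert inputs.shape == labels.shape
    assert (labels[labels != -100] == inputs[labels != -100]).all()
    return inputs, labels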
def bytes_to_audio(audio_bytes):
    with io.BytesIO(audio_bytes) as buf:
        # decode the raw bytes with soundfile
        audio_array, sr = sf.read(buf)
        # downmix to mono if necessary
        if len(audio_array.shape) > 1:
            audio_array = audio_array.mean(axis=1)
        # ensure float32 samples
        audio_array = audio_array.astype(np.float32)
        return {
            'array': audio_array,
            'sampling_rate': sr
        }
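# Illustrative sketch (not called in training): round-trip a synthetic
# waveform through bytes_to_audio; the in-memory WAV buffer stands in for
# the raw bytes stored in a parquet column.
def _demo_bytes_to_audio():
    sr = 16000
    wave = np.sin(2 * np.pi * 440 * np.arange(sr) / sr).astype(np.float32)
    buf = io.BytesIO()
    sf.write(buf, wave, sr, format='WAV')
    audio = bytes_to_audio(buf.getvalue())
    assert audio['sampling_rate'] == sr and audio['array'].dtype == np.float32
    return audio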
def get_data_by_l_version(trainer: L.Trainer, args):
if L.__version__[0] == '2':
train_data = MyDataModule(args)
else:
raise ValueError(f"Unsupported PyTorch Lightning version: {L.__version__}")
return train_data
class GlobalIndexManager:
    # Maps a per-worker sequential index to a global dataset index. With
    # shuffle disabled, indices are strided across devices so that ranks
    # never read the same sample.
    def __init__(self, rank=0, device_num=1, shuffle=True):
self.current_idx = 0
self.rank = rank
self.device_num = device_num
self.shuffle = shuffle
def get_next_idx(self, idx_t):
if self.shuffle:
idx = idx_t
else:
idx = self.current_idx * self.device_num + self.rank
self.current_idx += 1
return idx
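# Illustrative sketch (not called in training): with shuffle=False each
# rank walks a strided, disjoint slice of the dataset, e.g. with
# device_num=2: rank 0 -> 0, 2, 4, ...; rank 1 -> 1, 3, 5, ...
def _demo_index_manager():
    mgr = GlobalIndexManager(rank=1, device_num=2, shuffle=False)
    return [mgr.get_next_idx(idx_t=i) for i in range(4)]  # [1, 3, 5, 7]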
class MyDataModule(L.LightningDataModule):
def __init__(self, args):
super().__init__()
self.args = args
self.train_data = None
    def setup(self, stage=None):
        self.train_data = WorldDataset(self.args)
        # WorldDataset does not compute a vocab size itself; keep the value
        # already carried in args unless the dataset provides one.
        self.args.vocab_size = getattr(self.train_data, 'vocab_size', self.args.vocab_size)
        self.train_data.real_epoch = self.trainer.current_epoch
        self.train_data.rank = self.trainer.global_rank
        self.train_data.world_size = self.trainer.world_size
        self.train_data.setup(self.trainer.global_rank, self.trainer.world_size,
                              int(self.args.devices), self.args.data_shuffle)
def train_dataloader(self):
        # persistent_workers must stay False because the worker lives in
        # another thread; shuffling is controlled by args.data_shuffle.
return DataLoader(
self.train_data,
shuffle=self.args.data_shuffle,
pin_memory=True,
batch_size=self.args.micro_bsz,
num_workers=1,
persistent_workers=False,
drop_last=True
)
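# Illustrative sketch (hypothetical wiring, not part of this module's
# training entry point): how the datamodule plugs into a Lightning 2.x
# Trainer. `model` and the fields on `args` (devices, data_file,
# data_type, ...) are assumptions standing in for the real script.
def _demo_fit(model, args):
    trainer = L.Trainer(max_epochs=1, devices=int(args.devices))
    trainer.fit(model=model, datamodule=get_data_by_l_version(trainer, args))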
class WorldDataset(Dataset):
def __init__(self, args, emb=None):
self.args = args
self.rank = 0
self.real_epoch = 0
self.world_size = 0
self.index_manager = None
self.emb = emb
        if args.data_type in ('wav', 'img'):
            import jsonlines
            # read the answer records from a JSON Lines file
            with jsonlines.open(f'{args.data_file}/answer.jsonl') as file:
                self.data = list(file)
        elif args.data_type in ('hf_img', 'visual'):
            import jsonlines
            # read the chat records from a JSON Lines file
            with jsonlines.open(f'{args.data_file}/chat.jsonl') as file:
                self.data = list(file)
        elif args.data_type == 'visual-r1-cs':
            llava_path = os.path.join(args.data_file, 'vision_r1_llava_cot_full.json')
            mulberry_path = os.path.join(args.data_file, 'vision_r1_mulberry_sft_full.json')
            with open(llava_path, 'r', encoding='utf-8') as file:
                llava_data = json.load(file)
            with open(mulberry_path, 'r', encoding='utf-8') as file:
                mulberry_data = json.load(file)
            # merge both datasets and tag each item with its source
            for item in llava_data:
                item['_source'] = 'llava_cot'
            for item in mulberry_data:
                item['_source'] = 'mulberry'
            self.data = llava_data + mulberry_data
        elif args.data_type in ('hf', 'qa', 'cnqa', 'cnasr', 'tts'):
            from datasets import load_dataset, concatenate_datasets
            def list_subdirectories(base_path):
                return [
                    name for name in os.listdir(base_path)
                    if os.path.isdir(os.path.join(base_path, name)) and not name.startswith('.')
                ]
            subdirs = list_subdirectories(args.data_file)
            if not subdirs:
                dataset = load_dataset(args.data_file, split="train")
            else:
                dataset = concatenate_datasets([
                    load_dataset(f'{args.data_file}/{subdir}', split="train")
                    for subdir in subdirs
                ])
            self.data = dataset
            rank_zero_info(f"dataset size: {len(dataset)}")
elif args.data_type == "jsonl":
import jsonlines
with jsonlines.open(args.data_file) as file:
self.data = list(file)
else:
self.data = pd.read_parquet(args.data_file)
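    # Expected on-disk layouts, as read by the branches above:
    #   wav / img       -> {data_file}/answer.jsonl plus the referenced media files
    #   hf_img / visual -> {data_file}/chat.jsonl plus {data_file}/images/
    #   visual-r1-cs    -> vision_r1_llava_cot_full.json and vision_r1_mulberry_sft_full.json plus images
    #   hf / qa / cnqa / cnasr / tts -> a HuggingFace dataset dir, or a dir of dataset subdirs
    #   jsonl           -> a single .jsonl file with a 'text' field per record
    #   otherwise       -> a parquet file with 'question_audio' and 'answer' columns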
def setup(self, rank, world_size, devices, shuffle):
self.rank = rank
self.world_size = world_size
self.index_manager = GlobalIndexManager(rank=rank, device_num=devices, shuffle=shuffle)
    def __len__(self):
        # virtual epoch length: epoch_steps batches of micro_bsz samples per device
        return self.args.epoch_steps * self.args.micro_bsz
def __getitem__(self, idx):
idx = self.index_manager.get_next_idx(idx_t=idx) if self.index_manager else idx
args = self.args
        if args.data_type == 'wav':
            mod_name = self.data[idx]['file_name']
            data_answer = self.data[idx]['answer']
            mod_path = f'{args.data_file}/{mod_name}'
            audio, sample_rate = librosa.load(mod_path, sr=16000)  # resample to 16 kHz on load
            sign = audio
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
        elif args.data_type == 'hf':
            sample = self.data[idx]
            audio = sample['audio']
            data_answer = sample['text']  # caption / transcript
            audio = librosa.resample(audio['array'], orig_sr=audio['sampling_rate'], target_sr=16000)
            sign = audio
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
        elif args.data_type == 'tts':
            sample = self.data[idx]
            audio = sample['audio']
            data_answer = sample['text']  # caption / transcript
            audio = librosa.resample(audio['array'], orig_sr=audio['sampling_rate'], target_sr=16000)
            sign = audio
            token = torch.tensor(pipeline.encode(f'User: {data_answer}\x17Assistant:'))
        elif args.data_type == 'qa':
            sample = self.data[idx]
            audio = sample['question_audio']
            data_answer = sample['answer']
            sign = librosa.resample(audio['array'], orig_sr=audio['sampling_rate'], target_sr=16000)
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
        elif args.data_type == 'cnqa':
            sample = self.data[idx]
            audio = sample['audio']
            data_answer = sample['answer']
            sign = librosa.resample(audio['array'], orig_sr=audio['sampling_rate'], target_sr=16000)
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
        elif args.data_type == 'cnasr':
            sample = self.data[idx]
            audio = sample['audio']
            data_answer = sample['transcript']
            sign = librosa.resample(audio['array'], orig_sr=audio['sampling_rate'], target_sr=16000)
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
elif args.data_type == "jsonl":
ctx_len = args.ctx_len
req_len = ctx_len + 1
ctx = self.data[idx]['text']
token = torch.tensor(pipeline.encode(ctx))
token_len = len(token)
pad_len = req_len - token_len
dix = F.pad(token, (0, pad_len), value=0)
x = dix[:-1]
y = dix[1:]
mask = torch.zeros(req_len - 1)
mask[:token_len - 1] = 1
return x, y, mask
elif args.data_type == "img":
mod_name = self.data[idx]['file_name']
data_answer = self.data[idx]['answer']
mod_path = f'{args.data_file}/{mod_name}'
token = torch.tensor(pipeline.encode(f'\n\nAssistant: {data_answer}\x17'))
image = Image.open(mod_path).convert('RGB')
sign = transform(image)
elif args.data_type == 'visual':
img_name = self.data[idx]['image']
conversation_text = self.data[idx]['conversations']
mod_path = f'{args.data_file}/images/{img_name}'
image = Image.open(mod_path).convert('RGB')
sign = image
text_tokens, text_labels = process_tokens(conversation_text)
return sign, text_tokens, text_labels
elif args.data_type== 'hf_img':
img_name = self.data[idx]['image']
conversation_text = self.data[idx]['conversations']
conversation_text = process_conversation_text(conversation_text)
mod_path = f'{args.data_file}/images/{img_name}'
token = torch.tensor(pipeline.encode(conversation_text))
image = Image.open(mod_path).convert('RGB')
sign = image
        elif args.data_type == 'visual-r1-cs':
            item = self.data[idx]
            conversations = item['conversations']
            # resolve the image path according to the dataset source
            if item['_source'] == 'llava_cot':
                img_name = item['image']
            else:
                img_name = item['images']
            mod_path = f'{self.args.data_file}/{img_name}'
            image = Image.open(mod_path).convert('RGB')
            sign = image
            # tokenize the conversation turns
            text_tokens, text_labels = process_tokens(conversations)
            return sign, text_tokens, text_labels
        else:
            # parquet fallback: decode raw audio bytes stored in the dataframe
            data_audio = bytes_to_audio(self.data['question_audio'][idx]['bytes'])
            data_answer = self.data['answer'][idx]
            # resample to 16 kHz (the source audio is assumed to be 48 kHz)
            audio = librosa.resample(data_audio['array'], orig_sr=48000, target_sr=16000)
            sign = audio
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
return sign, token
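# Illustrative sketch (hypothetical, not used above): with micro_bsz > 1
# the variable-length (sign, token) pairs returned by __getitem__ cannot
# be stacked by the default collate; a pad-to-longest collate_fn like
# this one would have to be passed to the DataLoader.
def _pad_collate(batch):
    signs, tokens = zip(*batch)
    max_len = max(t.size(0) for t in tokens)
    padded = torch.stack([F.pad(t, (0, max_len - t.size(0)), value=0) for t in tokens])
    return list(signs), padded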