########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import torch.nn.functional as F
import json
import math
import random
import os
import sys
import numpy as np
import torch
import lightning as L
from torch.utils.data import Dataset, DataLoader
from lightning_utilities.core.rank_zero import rank_zero_info
from infer.rwkv.utils import PIPELINE

pipeline = PIPELINE('rwkv', "rwkv_vocab_v20230424")

from PIL import Image
import pandas as pd  # used to read parquet files
import librosa
import io
import soundfile as sf
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor()  # convert the PIL image to a tensor
])


def process_conversation_text(conversations):
    # flatten a multi-turn conversation into a single prompt string,
    # using \x16 / \x17 as the turn delimiters expected by the tokenizer
    conversation_text = "\x16"
    for conv in conversations:
        role = conv.get('from', '').lower()
        content = conv.get('value', '')
        if role == 'human':
            conversation_text += f"User: {content}\x17"
        elif role in ['assistant', 'gpt']:
            conversation_text += f"Assistant: {content}\x17"
    return conversation_text


def process_tokens(conversations):
    # tokenize a conversation turn by turn; user turns are masked out of the
    # loss with the label value -100, assistant turns are trained on
    inputs = []
    labels = []
    for conv in conversations:
        role = conv.get('from', '').lower()
        content = conv.get('value', '')
        if role in ['human', 'user']:
            question = f"\x16User: {content}\x17"
            input_ids = torch.tensor(pipeline.encode(question))
            label = torch.full_like(input_ids, -100)
        elif role in ['assistant', 'gpt']:
            answer = f"\x16Assistant: {content}\x17"
            input_ids = torch.tensor(pipeline.encode(answer))
            label = input_ids
        else:
            continue  # skip turns with an unknown role
        inputs.append(input_ids)
        labels.append(label)
    inputs = torch.cat(inputs)
    labels = torch.cat(labels)
    return inputs, labels


def bytes_to_audio(audio_bytes):
    with io.BytesIO(audio_bytes) as buf:
        # decode the raw bytes with soundfile
        audio_array, sr = sf.read(buf)
        # downmix to mono if necessary
        if len(audio_array.shape) > 1:
            audio_array = audio_array.mean(axis=1)
        # make sure the samples are float32
        audio_array = audio_array.astype(np.float32)
        return {
            'array': audio_array,
            'sampling_rate': sr
        }


def get_data_by_l_version(trainer: L.Trainer, args):
    if L.__version__[0] == '2':
        train_data = MyDataModule(args)
    else:
        raise ValueError(f"Unsupported PyTorch Lightning version: {L.__version__}")
    return train_data


class GlobalIndexManager:
    def __init__(self, rank=0, device_num=1, shuffle=True):
        self.current_idx = 0
        self.rank = rank
        self.device_num = device_num
        self.shuffle = shuffle

    def get_next_idx(self, idx_t):
        if self.shuffle:
            idx = idx_t
        else:
            # deterministic round-robin: each rank walks its own stride
            idx = self.current_idx * self.device_num + self.rank
            self.current_idx += 1
        return idx


class MyDataModule(L.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.train_data = None

    def setup(self, stage=None):
        self.train_data = MyDataset(self.args)
        self.args.vocab_size = self.train_data.vocab_size
        self.train_data.real_epoch = self.trainer.current_epoch
        self.train_data.rank = self.trainer.global_rank
        self.train_data.world_size = self.trainer.world_size
        self.train_data.setup(self.trainer.global_rank, self.trainer.world_size,
                              int(self.args.devices), self.args.data_shuffle)

    def train_dataloader(self):
        # must set shuffle=False, persistent_workers=False (because the worker runs in another thread)
        return DataLoader(
            self.train_data,
            shuffle=self.args.data_shuffle,
            pin_memory=True,
            batch_size=self.args.micro_bsz,
            num_workers=1,
            persistent_workers=False,
            drop_last=True
        )


class WorldDataset(Dataset):
    def __init__(self, args, emb=None):
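        """Load the corpus selected by args.data_type from args.data_file.

        A descriptive note added for clarity: each branch below populates
        self.data with an indexable collection of samples, and __getitem__
        decodes one sample per call; emb is stored but not used in this file.
        """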
        self.args = args
        self.rank = 0
        self.real_epoch = 0
        self.world_size = 0
        self.index_manager = None
        self.emb = emb
        if args.data_type in ('wav', 'img'):
            import jsonlines
            # open and read the JSONL answer list
            with jsonlines.open(f'{args.data_file}/answer.jsonl') as file:
                self.data = list(file)
        elif args.data_type in ('hf_img', 'visual'):
            import jsonlines
            # open and read the JSONL chat transcripts
            with jsonlines.open(f'{args.data_file}/chat.jsonl') as file:
                self.data = list(file)
        elif args.data_type == 'visual-r1-cs':
            llava_path = os.path.join(args.data_file, 'vision_r1_llava_cot_full.json')
            mulberry_path = os.path.join(args.data_file, 'vision_r1_mulberry_sft_full.json')
            with open(llava_path, 'r', encoding='utf-8') as file:
                llava_data = json.load(file)
            with open(mulberry_path, 'r', encoding='utf-8') as file:
                mulberry_data = json.load(file)
            # merge the two datasets and tag every item with its source
            for item in llava_data:
                item['_source'] = 'llava_cot'
            for item in mulberry_data:
                item['_source'] = 'mulberry'
            self.data = llava_data + mulberry_data
        elif args.data_type in ('hf', 'qa', 'cnqa', 'cnasr', 'tts'):
            from datasets import load_dataset, concatenate_datasets

            def list_subdirectories(base_path):
                return [
                    name for name in os.listdir(base_path)
                    if os.path.isdir(os.path.join(base_path, name)) and not name.startswith('.')
                ]

            files = list_subdirectories(args.data_file)
            if not files:
                # a single Hugging Face dataset directory
                datasets = load_dataset(args.data_file, split="train")
            else:
                # one dataset per subdirectory, concatenated into one corpus
                datasets = []
                for file in files:
                    dataset = load_dataset(f'{args.data_file}/{file}', split="train")
                    datasets.append(dataset)
                datasets = concatenate_datasets(datasets)
            self.data = datasets
            print(len(datasets))
        elif args.data_type == "jsonl":
            import jsonlines
            with jsonlines.open(args.data_file) as file:
                self.data = list(file)
        else:
            self.data = pd.read_parquet(args.data_file)

    def setup(self, rank, world_size, devices, shuffle):
        self.rank = rank
        self.world_size = world_size
        self.index_manager = GlobalIndexManager(rank=rank, device_num=devices, shuffle=shuffle)

    def __len__(self):
        return self.args.epoch_steps * self.args.micro_bsz

    def __getitem__(self, idx):
        idx = self.index_manager.get_next_idx(idx_t=idx) if self.index_manager else idx
        args = self.args
        if args.data_type == 'wav':
            mod_name = self.data[idx]['file_name']
            data_answer = self.data[idx]['answer']
            mod_path = f'{args.data_file}/{mod_name}'
            audio, _ = librosa.load(mod_path, sr=16000)  # resample to 16 kHz on load
            sign = audio
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
        elif args.data_type == 'hf':
            sample = self.data[idx]
            audio = sample['audio']
            data_answer = sample['text']  # the caption
            audio = librosa.resample(audio['array'], orig_sr=audio['sampling_rate'], target_sr=16000)
            sign = audio
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
        elif args.data_type == 'tts':
            sample = self.data[idx]
            audio = sample['audio']
            data_answer = sample['text']  # the caption
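            # note: every audio branch resamples to 16 kHz; the downstream speech
            # encoder that consumes `sign` is assumed to expect 16 kHz mono input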
            audio = librosa.resample(audio['array'], orig_sr=audio['sampling_rate'], target_sr=16000)
            sign = audio
            # for TTS the text is the prompt, so the Assistant turn is left open
            token = torch.tensor(pipeline.encode(f'User: {data_answer}\x17Assistant:'))
        elif args.data_type == 'qa':
            sample = self.data[idx]
            audio = sample['question_audio']
            data_answer = sample['answer']
            sign = librosa.resample(audio['array'], orig_sr=audio['sampling_rate'], target_sr=16000)
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
        elif args.data_type == 'cnqa':
            sample = self.data[idx]
            audio = sample['audio']
            data_answer = sample['answer']
            sign = librosa.resample(audio['array'], orig_sr=audio['sampling_rate'], target_sr=16000)
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
        elif args.data_type == 'cnasr':
            sample = self.data[idx]
            audio = sample['audio']
            data_answer = sample['transcript']
            sign = librosa.resample(audio['array'], orig_sr=audio['sampling_rate'], target_sr=16000)
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
        elif args.data_type == "jsonl":
            # plain-text branch: pad to ctx_len + 1, then shift by one token to
            # build the (input, target) pair; the mask marks non-pad positions
            ctx_len = args.ctx_len
            req_len = ctx_len + 1
            ctx = self.data[idx]['text']
            token = torch.tensor(pipeline.encode(ctx))
            token_len = len(token)
            pad_len = req_len - token_len
            dix = F.pad(token, (0, pad_len), value=0)
            x = dix[:-1]
            y = dix[1:]
            mask = torch.zeros(req_len - 1)
            mask[:token_len - 1] = 1
            return x, y, mask
        elif args.data_type == "img":
            mod_name = self.data[idx]['file_name']
            data_answer = self.data[idx]['answer']
            mod_path = f'{args.data_file}/{mod_name}'
            token = torch.tensor(pipeline.encode(f'\n\nAssistant: {data_answer}\x17'))
            image = Image.open(mod_path).convert('RGB')
            sign = transform(image)
        elif args.data_type == 'visual':
            img_name = self.data[idx]['image']
            conversation_text = self.data[idx]['conversations']
            mod_path = f'{args.data_file}/images/{img_name}'
            image = Image.open(mod_path).convert('RGB')
            sign = image
            text_tokens, text_labels = process_tokens(conversation_text)
            return sign, text_tokens, text_labels
        elif args.data_type == 'hf_img':
            img_name = self.data[idx]['image']
            conversation_text = self.data[idx]['conversations']
            conversation_text = process_conversation_text(conversation_text)
            mod_path = f'{args.data_file}/images/{img_name}'
            token = torch.tensor(pipeline.encode(conversation_text))
            image = Image.open(mod_path).convert('RGB')
            sign = image
        elif args.data_type == 'visual-r1-cs':
            item = self.data[idx]
            conversations = item['conversations']
            # resolve the image path according to the source dataset
            if item['_source'] == 'llava_cot':
                img_name = item['image']
            else:
                img_name = item['images']
            mod_path = f'{self.args.data_file}/{img_name}'
            image = Image.open(mod_path).convert('RGB')
            sign = image
            # tokenize the conversation turns
            text_tokens, text_labels = process_tokens(conversations)
            return sign, text_tokens, text_labels
        else:
            # parquet fallback: decode the stored audio bytes, then resample
            # (the source audio is assumed to be 48 kHz here)
            data_audio = bytes_to_audio(self.data['question_audio'][idx]['bytes'])
            data_answer = self.data['answer'][idx]
            audio = librosa.resample(data_audio['array'], orig_sr=48000, target_sr=16000)
            sign = audio
            token = torch.tensor(pipeline.encode(f'\x16Assistant: {data_answer}\x17'))
        return sign, token
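

########################################################################################################
# Minimal smoke test: a sketch, not part of the training pipeline. It assumes a hypothetical
# JSONL corpus at ./demo.jsonl (one {"text": ...} object per line) and fills in only the args
# fields that the 'jsonl' branch of WorldDataset actually reads.
########################################################################################################
if __name__ == '__main__':
    from types import SimpleNamespace

    demo_args = SimpleNamespace(
        data_type='jsonl',       # exercise the plain-text branch of __getitem__
        data_file='demo.jsonl',  # hypothetical path; replace with a real corpus
        ctx_len=512,
        epoch_steps=1,
        micro_bsz=1,
    )
    ds = WorldDataset(demo_args)
    ds.setup(rank=0, world_size=1, devices=1, shuffle=False)
    x, y, mask = ds[0]
    # x and y are shifted by one token; mask.sum() counts the real (non-pad) positions
    print(x.shape, y.shape, int(mask.sum()))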