import argparse
import copy
import json
import logging
import os
import pdb
import random
import re
from collections import defaultdict

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import Dataset
from tqdm import tqdm

from prompt import sft_prompt, all_prompt


class BaseDataset(Dataset):
    """Common base for all SFT datasets.

    Holds dataset paths / config read from ``args`` and the item index
    mapping (item id -> list of index tokens), plus helpers used for
    constrained decoding at generation time.
    """

    def __init__(self, args):
        super().__init__()

        self.args = args
        self.dataset = args.dataset
        self.data_path = os.path.join(args.data_path, self.dataset)

        self.max_his_len = args.max_his_len   # max history length; <=0 disables truncation
        self.his_sep = args.his_sep           # separator between history entries
        self.index_file = args.index_file
        self.add_prefix = args.add_prefix     # prefix history items with "1. ", "2. ", ...

        # Lazily-built caches (see the getters below).
        self.new_tokens = None
        self.allowed_tokens = None
        self.all_items = None

    def _load_data(self):
        """Load the item-index mapping (item id -> index token list)."""
        with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
            self.indices = json.load(f)

    def get_new_tokens(self):
        """Return the sorted set of all index tokens (to add to the tokenizer)."""
        if self.new_tokens is not None:
            return self.new_tokens

        self.new_tokens = set()
        for index in self.indices.values():
            for token in index:
                self.new_tokens.add(token)
        self.new_tokens = sorted(self.new_tokens)

        return self.new_tokens

    def get_all_items(self):
        """Return the set of all items, each rendered as its concatenated index string."""
        if self.all_items is not None:
            return self.all_items

        self.all_items = set()
        for index in self.indices.values():
            self.all_items.add("".join(index))

        return self.all_items

    def get_prefix_allowed_tokens_fn(self, tokenizer):
        """Build a ``prefix_allowed_tokens_fn`` for constrained generation.

        Position ``i`` after the "Response:" separator may only emit token ids
        seen at position ``i`` of some item index; one extra position allows EOS.

        NOTE(review): ``tokenizer(token)["input_ids"][1]`` assumes each index
        token encodes as [BOS?, token, ...] — verify against the tokenizer used.
        """
        if self.allowed_tokens is None:
            self.allowed_tokens = {}
            for index in self.indices.values():
                for i, token in enumerate(index):
                    token_id = tokenizer(token)["input_ids"][1]
                    if i not in self.allowed_tokens.keys():
                        self.allowed_tokens[i] = set()
                    self.allowed_tokens[i].add(token_id)
            # After the full index has been produced, only EOS is allowed.
            self.allowed_tokens[len(self.allowed_tokens.keys())] = {tokenizer.eos_token_id}

        # Token ids of the "Response:" marker (BOS dropped); used to locate
        # where generation of the item index starts.
        sep = tokenizer("Response:")["input_ids"][1:]

        def prefix_allowed_tokens_fn(batch_id, sentence):
            sentence = sentence.tolist()
            reversed_sent = sentence[::-1]
            for i in range(len(reversed_sent)):
                if reversed_sent[i:i + len(sep)] == sep[::-1]:
                    return list(self.allowed_tokens[i])
            # NOTE(review): implicitly returns None when the separator is not
            # found; callers rely on the prompt always containing "Response:".

        return prefix_allowed_tokens_fn

    def _process_data(self):
        raise NotImplementedError


class SeqRecDataset(BaseDataset):
    """Sequential recommendation: predict the next item index from history.

    Leave-one-out protocol: train on items[:-2], validate on items[-2],
    test on items[-1].
    """

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num  # prompts sampled per interaction
        self.prompt_id = prompt_id                  # fixed prompt used at test time
        self.sample_num = sample_num                # subsample size for test (-1 = all)

        self.prompts = all_prompt["seqrec"]

        # load data
        self._load_data()
        self._remap_items()

        if self.mode == 'train':
            self.inter_data = self._process_train_data()
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
        else:
            raise NotImplementedError

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Replace raw item ids with their concatenated index strings."""
        self.remapped_inters = dict()
        for uid, items in self.inters.items():
            new_items = ["".join(self.indices[str(i)]) for i in items]
            self.remapped_inters[uid] = new_items

    def _process_train_data(self):
        """One sample per (user, position) over items[:-2], sliding-window style."""
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid][:-2]
            for i in range(1, len(items)):
                one_data = dict()
                # one_data["user"] = uid
                one_data["item"] = items[i]
                history = items[:i]
                if self.max_his_len > 0:
                    history = history[-self.max_his_len:]
                if self.add_prefix:
                    history = [str(k + 1) + ". " + item_idx
                               for k, item_idx in enumerate(history)]
                one_data["inters"] = self.his_sep.join(history)
                inter_data.append(one_data)

        return inter_data

    def _process_valid_data(self):
        """One sample per user: target is items[-2], history is items[:-2]."""
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid]
            one_data = dict()
            # one_data["user"] = uid
            one_data["item"] = items[-2]
            history = items[:-2]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            if self.add_prefix:
                history = [str(k + 1) + ". " + item_idx
                           for k, item_idx in enumerate(history)]
            one_data["inters"] = self.his_sep.join(history)
            inter_data.append(one_data)

        return inter_data

    def _process_test_data(self):
        """One sample per user: target is items[-1]; optionally subsample."""
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid]
            one_data = dict()
            # one_data["user"] = uid
            one_data["item"] = items[-1]
            history = items[:-1]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            if self.add_prefix:
                history = [str(k + 1) + ". " + item_idx
                           for k, item_idx in enumerate(history)]
            one_data["inters"] = self.his_sep.join(history)
            inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        """Pre-render validation text pairs (prompts are fixed, not sampled per epoch)."""
        self.valid_text_data = []
        if self.sample_valid:
            # Randomly pick prompt_sample_num distinct prompts per interaction.
            all_prompt_ids = range(len(self.prompts))
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num,
                                              replace=False)
                for prompt_id in prompt_ids:
                    prompt = self.prompts[prompt_id]
                    input, output = self._get_text_data(d, prompt)
                    self.valid_text_data.append({"input_ids": input, "labels": output})
        else:
            # Use a single fixed validation prompt.
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                input, output = self._get_text_data(d, prompt)
                self.valid_text_data.append({"input_ids": input, "labels": output})

    def _get_text_data(self, data, prompt):
        """Render (input, output) SFT strings; at test time output is the bare response."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        if self.mode == 'valid':
            return self.valid_text_data[index]

        idx = index // self.prompt_sample_num
        d = self.inter_data[idx]

        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id
        else:
            # Fix: previously fell through with prompt_id unbound (UnboundLocalError).
            raise NotImplementedError

        prompt = self.prompts[prompt_id]
        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)


class FusionSeqRecDataset(BaseDataset):
    """Sequential recommendation enriched with item text (title/description)."""

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["fusionseqrec"]

        # load data
        self._load_data()

        if self.mode == 'train':
            self.inter_data = self._process_train_data()
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
        else:
            raise NotImplementedError

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
            self.item_feat = json.load(f)

    def _one_interaction(self, target, history):
        """Build one sample dict for ``target`` given raw-id ``history``."""
        one_data = dict()
        one_data["item"] = "".join(self.indices[str(target)])
        one_data["title"] = self.item_feat[str(target)]["title"].strip().strip(".!?,;:`")
        one_data["description"] = self.item_feat[str(target)]["description"]

        if self.max_his_len > 0:
            history = history[-self.max_his_len:]
        inters = ["".join(self.indices[str(j)]) for j in history]
        inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\""
                        for j in history]

        if self.add_prefix:
            inters = [str(k + 1) + ". " + item_idx
                      for k, item_idx in enumerate(inters)]
            inter_titles = [str(k + 1) + ". " + item_title
                            for k, item_title in enumerate(inter_titles)]

        one_data["inters"] = self.his_sep.join(inters)
        one_data["inter_titles"] = self.his_sep.join(inter_titles)
        return one_data

    def _subsample(self, inter_data):
        """Randomly keep ``sample_num`` samples when subsampling is enabled."""
        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()
        return inter_data

    def _process_train_data(self):
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid][:-2]
            for i in range(1, len(items)):
                inter_data.append(self._one_interaction(items[i], items[:i]))
        return self._subsample(inter_data)

    def _process_valid_data(self):
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid]
            inter_data.append(self._one_interaction(items[-2], items[:-2]))
        return self._subsample(inter_data)

    def _process_test_data(self):
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid]
            inter_data.append(self._one_interaction(items[-1], items[:-1]))
        return self._subsample(inter_data)

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        """Pre-render validation text pairs (see SeqRecDataset._construct_valid_text)."""
        self.valid_text_data = []
        if self.sample_valid:
            all_prompt_ids = range(len(self.prompts))
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num,
                                              replace=False)
                for prompt_id in prompt_ids:
                    prompt = self.prompts[prompt_id]
                    input, output = self._get_text_data(d, prompt)
                    self.valid_text_data.append({"input_ids": input, "labels": output})
        else:
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                input, output = self._get_text_data(d, prompt)
                self.valid_text_data.append({"input_ids": input, "labels": output})

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        if self.mode == 'valid':
            return self.valid_text_data[index]

        idx = index // self.prompt_sample_num
        d = self.inter_data[idx]

        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id
        else:
            # Fix: previously fell through with prompt_id unbound (UnboundLocalError).
            raise NotImplementedError

        prompt = self.prompts[prompt_id]
        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)


class ItemFeatDataset(BaseDataset):
    """Item-feature alignment tasks (e.g. item2index / index2item)."""

    def __init__(self, args, task="item2index", prompt_sample_num=1, sample_num=-1):
        super().__init__(args)

        self.task = task.lower()
        self.prompt_sample_num = prompt_sample_num
        self.sample_num = sample_num

        self.prompts = all_prompt[self.task]

        # load data
        self._load_data()
        self.feat_data = self._process_data()

    def _load_data(self):
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
            self.item_feat = json.load(f)

    def _process_data(self):
        """Attach the index string to each item's feature dict; optionally subsample."""
        feat_data = []
        for iid in self.item_feat:
            feat = self.item_feat[iid]
            index = "".join(self.indices[iid])
            feat["item"] = index
            feat["title"] = feat["title"].strip().strip(".!?,;:`")
            feat_data.append(feat)

        if self.sample_num > 0:
            all_idx = range(len(feat_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            feat_data = np.array(feat_data)[sample_idx].tolist()

        return feat_data

    def __len__(self):
        return len(self.feat_data) * self.prompt_sample_num

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        return input, output

    def __getitem__(self, index):
        idx = index // self.prompt_sample_num
        d = self.feat_data[idx]

        prompt_id = random.randint(0, len(self.prompts) - 1)
        prompt = self.prompts[prompt_id]

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)


class ItemSearchDataset(BaseDataset):
    """Item search from user preference / vague-intention queries."""

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["itemsearch"]

        # load data
        self._load_data()
        self.search_data = self._process_data()

    def _load_data(self):
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
            self.user_info = json.load(f)

    def _process_data(self):
        """Combine explicit preferences and vague intentions into search samples."""
        search_data = []
        user_explicit_preference = self.user_info["user_explicit_preference"]
        user_vague_intention = self.user_info["user_vague_intention"]
        if self.mode == 'train':
            user_vague_intention = user_vague_intention["train"]
        elif self.mode == 'test':
            user_vague_intention = user_vague_intention["test"]
        else:
            raise NotImplementedError

        for uid in user_explicit_preference.keys():
            one_data = {}
            user_ep = user_explicit_preference[uid]
            user_vi = user_vague_intention[uid]["querys"]
            one_data["explicit_preferences"] = user_ep
            # querys[0] is user-related, querys[1] is item-related.
            one_data["user_related_intention"] = user_vi[0]
            one_data["item_related_intention"] = user_vi[1]

            iid = user_vague_intention[uid]["item"]
            inters = user_vague_intention[uid]["inters"]

            index = "".join(self.indices[str(iid)])
            one_data["item"] = index

            if self.max_his_len > 0:
                inters = inters[-self.max_his_len:]
            inters = ["".join(self.indices[str(i)]) for i in inters]
            if self.add_prefix:
                inters = [str(k + 1) + ". " + item_idx
                          for k, item_idx in enumerate(inters)]
            one_data["inters"] = self.his_sep.join(inters)

            search_data.append(one_data)

        if self.sample_num > 0:
            all_idx = range(len(search_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            search_data = np.array(search_data)[sample_idx].tolist()

        return search_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        # Fix: original had redundant test/else branches both returning
        # len(self.search_data); collapsed without behavior change.
        if self.mode == 'train':
            return len(self.search_data) * self.prompt_sample_num
        return len(self.search_data)

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        idx = index // self.prompt_sample_num
        d = self.search_data[idx]

        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id
        else:
            # Fix: previously fell through with prompt_id unbound (UnboundLocalError).
            raise NotImplementedError

        prompt = self.prompts[prompt_id]

        # Randomly pick one explicit preference and one intention per access.
        d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
        all_querys = [d["user_related_intention"], d["item_related_intention"]]
        d["query"] = random.choice(all_querys)

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)


class PreferenceObtainDataset(BaseDataset):
    """Generate a user's explicit preference from their interaction history."""

    def __init__(self, args, prompt_sample_num=1, sample_num=-1):
        super().__init__(args)

        self.prompt_sample_num = prompt_sample_num
        self.sample_num = sample_num

        self.prompts = all_prompt["preferenceobtain"]

        # load data
        self._load_data()
        self._remap_items()
        self.preference_data = self._process_data()

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
            self.user_info = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Replace raw item ids with their concatenated index strings."""
        self.remapped_inters = dict()
        for uid, items in self.inters.items():
            new_items = ["".join(self.indices[str(i)]) for i in items]
            self.remapped_inters[uid] = new_items

    def _process_data(self):
        preference_data = []
        user_explicit_preference = self.user_info["user_explicit_preference"]

        for uid in user_explicit_preference.keys():
            one_data = {}
            # History excludes the last three items (held out elsewhere).
            inters = self.remapped_inters[uid][:-3]
            user_ep = user_explicit_preference[uid]

            if self.max_his_len > 0:
                inters = inters[-self.max_his_len:]
            if self.add_prefix:
                inters = [str(k + 1) + ". " + item_idx
                          for k, item_idx in enumerate(inters)]

            one_data["explicit_preferences"] = user_ep
            one_data["inters"] = self.his_sep.join(inters)

            preference_data.append(one_data)

        if self.sample_num > 0:
            all_idx = range(len(preference_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            preference_data = np.array(preference_data)[sample_idx].tolist()

        return preference_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        return len(self.preference_data) * self.prompt_sample_num

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        return input, output

    def __getitem__(self, index):
        idx = index // self.prompt_sample_num
        d = self.preference_data[idx]

        prompt_id = random.randint(0, len(self.prompts) - 1)
        prompt = self.prompts[prompt_id]

        d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)


class SeqRecTestDataset(BaseDataset):
    """Test-only sequential recommendation dataset with a single fixed prompt."""

    def __init__(self, args, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompt = all_prompt["seqrec"][self.prompt_id]

        # load data
        self._load_data()
        self._remap_items()
        self.inter_data = self._process_test_data()

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Replace raw item ids with their concatenated index strings."""
        self.remapped_inters = dict()
        for uid, items in self.inters.items():
            new_items = ["".join(self.indices[str(i)]) for i in items]
            self.remapped_inters[uid] = new_items

    def _process_test_data(self):
        """One sample per user: target items[-1], history items[:-1]; optional subsample."""
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid]
            one_data = dict()
            # one_data["user"] = uid
            one_data["item"] = items[-1]
            history = items[:-1]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            if self.add_prefix:
                history = [str(k + 1) + ". " + item_idx
                           for k, item_idx in enumerate(history)]
            one_data["inters"] = self.his_sep.join(history)
            inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id
        self.prompt = all_prompt["seqrec"][self.prompt_id]

    def __len__(self):
        return len(self.inter_data)

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")

        return input, response

    def __getitem__(self, index):
        d = self.inter_data[index]

        input, target = self._get_text_data(d, self.prompt)

        return dict(input_ids=input, labels=target)