import copy
import json
import os
import random

import numpy as np
from torch.utils.data import Dataset

from prompt import sft_prompt, all_prompt


class BaseDataset(Dataset):
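    """Common loading and caching logic shared by the task-specific datasets below."""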

    def __init__(self, args):
        super().__init__()

        self.args = args
        self.dataset = args.dataset
        self.data_path = os.path.join(args.data_path, self.dataset)

        self.max_his_len = args.max_his_len
        self.his_sep = args.his_sep
        self.index_file = args.index_file
        self.add_prefix = args.add_prefix

        # Lazily-built caches, filled on first use.
        self.new_tokens = None
        self.allowed_tokens = None
        self.all_items = None

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
            self.indices = json.load(f)

    def get_new_tokens(self):
        if self.new_tokens is not None:
            return self.new_tokens

        self.new_tokens = set()
        for index in self.indices.values():
            for token in index:
                self.new_tokens.add(token)
        self.new_tokens = sorted(self.new_tokens)

        return self.new_tokens

    def get_all_items(self):
        if self.all_items is not None:
            return self.all_items

        self.all_items = set()
        for index in self.indices.values():
            self.all_items.add("".join(index))

        return self.all_items

    def get_prefix_allowed_tokens_fn(self, tokenizer):
        if self.allowed_tokens is None:
            self.allowed_tokens = {}
            for index in self.indices.values():
                for i, token in enumerate(index):
                    # [1] skips the special token the tokenizer prepends to a
                    # single index token; this assumes such a leading token exists.
                    token_id = tokenizer(token)["input_ids"][1]
                    if i not in self.allowed_tokens:
                        self.allowed_tokens[i] = set()
                    self.allowed_tokens[i].add(token_id)
            # After the final index position, only EOS may be generated.
            self.allowed_tokens[len(self.allowed_tokens)] = {tokenizer.eos_token_id}
        sep = tokenizer("Response:")["input_ids"][1:]

        def prefix_allowed_tokens_fn(batch_id, sentence):
            sentence = sentence.tolist()
            reversed_sent = sentence[::-1]
            # Scan backwards for the "Response:" marker; the offset i equals the
            # number of tokens generated since it, i.e. the current index position.
            for i in range(len(reversed_sent)):
                if reversed_sent[i:i + len(sep)] == sep[::-1]:
                    return list(self.allowed_tokens[i])

        return prefix_allowed_tokens_fn

    def _process_data(self):
        raise NotImplementedError
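
# Illustrative wiring sketch (an assumption, not part of this module): the
# helpers above are typically used with a Hugging Face tokenizer and model
# roughly as follows, where `dataset`, `tokenizer`, `model`, and `inputs`
# are hypothetical names:
#
#     tokenizer.add_tokens(dataset.get_new_tokens())
#     model.resize_token_embeddings(len(tokenizer))
#     outputs = model.generate(
#         **inputs,
#         prefix_allowed_tokens_fn=dataset.get_prefix_allowed_tokens_fn(tokenizer),
#     )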


class SeqRecDataset(BaseDataset):
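    """Sequential recommendation: predict the next item index from a user's interaction history."""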

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["seqrec"]

        self._load_data()
        self._remap_items()

        if self.mode == 'train':
            self.inter_data = self._process_train_data()
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
        else:
            raise NotImplementedError

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        # Replace each raw item id with its concatenated index tokens.
        self.remapped_inters = dict()
        for uid, items in self.inters.items():
            new_items = ["".join(self.indices[str(i)]) for i in items]
            self.remapped_inters[uid] = new_items

    def _process_train_data(self):
        inter_data = []
        for uid in self.remapped_inters:
            # Leave-last-two-out split: the final two items per user are held
            # out for validation and test.
            items = self.remapped_inters[uid][:-2]
            for i in range(1, len(items)):
                one_data = dict()
                one_data["item"] = items[i]
                history = items[:i]
                if self.max_his_len > 0:
                    history = history[-self.max_his_len:]
                if self.add_prefix:
                    history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
                one_data["inters"] = self.his_sep.join(history)
                inter_data.append(one_data)

        return inter_data

    def _process_valid_data(self):
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid]
            one_data = dict()
            # The second-to-last item is the validation target.
            one_data["item"] = items[-2]
            history = items[:-2]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            if self.add_prefix:
                history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
            one_data["inters"] = self.his_sep.join(history)
            inter_data.append(one_data)

        return inter_data

    def _process_test_data(self):
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid]
            one_data = dict()
            # The last item is the test target.
            one_data["item"] = items[-1]
            history = items[:-1]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            if self.add_prefix:
                history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
            one_data["inters"] = self.his_sep.join(history)
            inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        self.valid_text_data = []
        if self.sample_valid:
            # Pair each interaction with `prompt_sample_num` randomly chosen prompts.
            all_prompt_ids = range(len(self.prompts))
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
                for prompt_id in prompt_ids:
                    prompt = self.prompts[prompt_id]
                    input, output = self._get_text_data(d, prompt)
                    self.valid_text_data.append({"input_ids": input, "labels": output})
        else:
            # Use a single fixed validation prompt.
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                input, output = self._get_text_data(d, prompt)
                self.valid_text_data.append({"input_ids": input, "labels": output})

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        if self.mode == 'valid':
            return self.valid_text_data[index]

        idx = index // self.prompt_sample_num
        d = self.inter_data[idx]

        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id

        prompt = self.prompts[prompt_id]

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)


class FusionSeqRecDataset(BaseDataset):
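    """Sequential recommendation fused with item text: histories carry both item indices and quoted titles, and targets include title/description features."""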

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["fusionseqrec"]

        self._load_data()

        if self.mode == 'train':
            self.inter_data = self._process_train_data()
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
        else:
            raise NotImplementedError

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
            self.item_feat = json.load(f)

    def _process_train_data(self):
        inter_data = []
        for uid in self.inters:
            # Leave-last-two-out split, as in SeqRecDataset.
            items = self.inters[uid][:-2]
            for i in range(1, len(items)):
                one_data = dict()
                one_data["item"] = "".join(self.indices[str(items[i])])
                one_data["title"] = self.item_feat[str(items[i])]["title"].strip().strip(".!?,;:`")
                one_data["description"] = self.item_feat[str(items[i])]["description"]
                history = items[:i]
                if self.max_his_len > 0:
                    history = history[-self.max_his_len:]
                inters = ["".join(self.indices[str(j)]) for j in history]
                inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]

                if self.add_prefix:
                    inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
                    inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]

                one_data["inters"] = self.his_sep.join(inters)
                one_data["inter_titles"] = self.his_sep.join(inter_titles)
                inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def _process_valid_data(self):
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid]
            one_data = dict()
            one_data["item"] = "".join(self.indices[str(items[-2])])
            one_data["title"] = self.item_feat[str(items[-2])]["title"].strip().strip(".!?,;:`")
            one_data["description"] = self.item_feat[str(items[-2])]["description"]

            history = items[:-2]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            inters = ["".join(self.indices[str(j)]) for j in history]
            inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]

            if self.add_prefix:
                inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
                inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]

            one_data["inters"] = self.his_sep.join(inters)
            one_data["inter_titles"] = self.his_sep.join(inter_titles)
            inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def _process_test_data(self):
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid]
            one_data = dict()
            one_data["item"] = "".join(self.indices[str(items[-1])])
            one_data["title"] = self.item_feat[str(items[-1])]["title"].strip().strip(".!?,;:`")
            one_data["description"] = self.item_feat[str(items[-1])]["description"]

            history = items[:-1]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            inters = ["".join(self.indices[str(j)]) for j in history]
            inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]

            if self.add_prefix:
                inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
                inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]

            one_data["inters"] = self.his_sep.join(inters)
            one_data["inter_titles"] = self.his_sep.join(inter_titles)
            inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        self.valid_text_data = []
        if self.sample_valid:
            all_prompt_ids = range(len(self.prompts))
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
                for prompt_id in prompt_ids:
                    prompt = self.prompts[prompt_id]
                    input, output = self._get_text_data(d, prompt)
                    self.valid_text_data.append({"input_ids": input, "labels": output})
        else:
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                input, output = self._get_text_data(d, prompt)
                self.valid_text_data.append({"input_ids": input, "labels": output})

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        if self.mode == 'valid':
            return self.valid_text_data[index]

        idx = index // self.prompt_sample_num
        d = self.inter_data[idx]

        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id

        prompt = self.prompts[prompt_id]

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)


class ItemFeatDataset(BaseDataset):
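    """Aligns item text features with item indices; `task` selects the prompt set (e.g. "item2index")."""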

    def __init__(self, args, task="item2index", prompt_sample_num=1, sample_num=-1):
        super().__init__(args)

        self.task = task.lower()
        self.prompt_sample_num = prompt_sample_num
        self.sample_num = sample_num

        self.prompts = all_prompt[self.task]

        self._load_data()
        self.feat_data = self._process_data()

    def _load_data(self):
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
            self.item_feat = json.load(f)

    def _process_data(self):
        feat_data = []
        for iid in self.item_feat:
            # Note: `feat` references the entry in item_feat, which is mutated in place.
            feat = self.item_feat[iid]
            feat["item"] = "".join(self.indices[iid])
            feat["title"] = feat["title"].strip().strip(".!?,;:`")
            feat_data.append(feat)

        if self.sample_num > 0:
            all_idx = range(len(feat_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            feat_data = np.array(feat_data)[sample_idx].tolist()

        return feat_data

    def __len__(self):
        return len(self.feat_data) * self.prompt_sample_num

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        return input, output

    def __getitem__(self, index):
        idx = index // self.prompt_sample_num
        d = self.feat_data[idx]

        prompt_id = random.randint(0, len(self.prompts) - 1)
        prompt = self.prompts[prompt_id]

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)


class ItemSearchDataset(BaseDataset):
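    """Item search: recover the target item index from a user's explicit preferences and vague-intention queries."""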

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["itemsearch"]

        self._load_data()
        self.search_data = self._process_data()

    def _load_data(self):
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
            self.user_info = json.load(f)

    def _process_data(self):
        search_data = []
        user_explicit_preference = self.user_info["user_explicit_preference"]
        user_vague_intention = self.user_info["user_vague_intention"]
        if self.mode == 'train':
            user_vague_intention = user_vague_intention["train"]
        elif self.mode == 'test':
            user_vague_intention = user_vague_intention["test"]
        else:
            raise NotImplementedError

        for uid in user_explicit_preference.keys():
            one_data = {}
            user_ep = user_explicit_preference[uid]
            user_vi = user_vague_intention[uid]["querys"]
            one_data["explicit_preferences"] = user_ep
            one_data["user_related_intention"] = user_vi[0]
            one_data["item_related_intention"] = user_vi[1]

            iid = user_vague_intention[uid]["item"]
            inters = user_vague_intention[uid]["inters"]

            one_data["item"] = "".join(self.indices[str(iid)])

            if self.max_his_len > 0:
                inters = inters[-self.max_his_len:]
            inters = ["".join(self.indices[str(i)]) for i in inters]
            if self.add_prefix:
                inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]

            one_data["inters"] = self.his_sep.join(inters)

            search_data.append(one_data)

        if self.sample_num > 0:
            all_idx = range(len(search_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            search_data = np.array(search_data)[sample_idx].tolist()

        return search_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.search_data) * self.prompt_sample_num
        elif self.mode == 'test':
            return len(self.search_data)
        else:
            # Unreachable: _process_data already rejects other modes.
            raise NotImplementedError

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        idx = index // self.prompt_sample_num
        d = self.search_data[idx]

        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id

        prompt = self.prompts[prompt_id]

        # Sample one explicit preference and one of the two vague-intention
        # queries for this access; `d` is mutated in place.
        d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
        all_querys = [d["user_related_intention"], d["item_related_intention"]]
        d["query"] = random.choice(all_querys)

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)


class PreferenceObtainDataset(BaseDataset):
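    """Preference elicitation: derive a user's explicit preference from their interaction history."""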

    def __init__(self, args, prompt_sample_num=1, sample_num=-1):
        super().__init__(args)

        self.prompt_sample_num = prompt_sample_num
        self.sample_num = sample_num

        self.prompts = all_prompt["preferenceobtain"]

        self._load_data()
        self._remap_items()

        self.preference_data = self._process_data()

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
            self.user_info = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        self.remapped_inters = dict()
        for uid, items in self.inters.items():
            new_items = ["".join(self.indices[str(i)]) for i in items]
            self.remapped_inters[uid] = new_items

    def _process_data(self):
        preference_data = []
        user_explicit_preference = self.user_info["user_explicit_preference"]

        for uid in user_explicit_preference.keys():
            one_data = {}
            inters = self.remapped_inters[uid][:-3]
            user_ep = user_explicit_preference[uid]

            if self.max_his_len > 0:
                inters = inters[-self.max_his_len:]
            if self.add_prefix:
                inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]

            one_data["explicit_preferences"] = user_ep
            one_data["inters"] = self.his_sep.join(inters)

            preference_data.append(one_data)

        if self.sample_num > 0:
            all_idx = range(len(preference_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            preference_data = np.array(preference_data)[sample_idx].tolist()

        return preference_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        return len(self.preference_data) * self.prompt_sample_num

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        return input, output

    def __getitem__(self, index):
        idx = index // self.prompt_sample_num

        d = self.preference_data[idx]
        prompt_id = random.randint(0, len(self.prompts) - 1)
        prompt = self.prompts[prompt_id]

        # Sample one explicit preference per access; `d` is mutated in place.
        d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)


class SeqRecTestDataset(BaseDataset):
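    """Test-only variant of SeqRecDataset: a single fixed prompt, with the raw target text as labels."""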

    def __init__(self, args, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompt = all_prompt["seqrec"][self.prompt_id]

        self._load_data()
        self._remap_items()

        self.inter_data = self._process_test_data()

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        self.remapped_inters = dict()
        for uid, items in self.inters.items():
            new_items = ["".join(self.indices[str(i)]) for i in items]
            self.remapped_inters[uid] = new_items

    def _process_test_data(self):
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid]
            one_data = dict()
            one_data["item"] = items[-1]
            history = items[:-1]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            if self.add_prefix:
                history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
            one_data["inters"] = self.his_sep.join(history)
            inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id
        self.prompt = all_prompt["seqrec"][self.prompt_id]

    def __len__(self):
        return len(self.inter_data)

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        # At test time only the input prompt and the raw target are needed.
        input = sft_prompt.format(instruction=instruction, response="")

        return input, response

    def __getitem__(self, index):
        d = self.inter_data[index]
        input, target = self._get_text_data(d, self.prompt)

        return dict(input_ids=input, labels=target)
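

# Minimal construction sketch (an assumption, not part of the original code):
# the datasets above expect an argparse-style namespace exposing the fields
# read in BaseDataset.__init__ plus the mode-specific ones. All values below
# are hypothetical.
#
#     import argparse
#
#     args = argparse.Namespace(
#         dataset="Instruments",
#         data_path="./data",
#         index_file="./data/Instruments/Instruments.index.json",
#         max_his_len=20,
#         his_sep=", ",
#         add_prefix=False,
#         sample_valid=True,
#         valid_prompt_id=0,
#     )
#     train_set = SeqRecDataset(args, mode="train", prompt_sample_num=2)
#     sample = train_set[0]  # {"input_ids": <prompt text>, "labels": <prompt with target>}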