davidberenstein1957 HF staff committed on
Commit
b8a81f2
·
1 Parent(s): 187357b

fix distribution of labels

Browse files
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -126,10 +126,12 @@ def generate_dataset(
126
  inputs = []
127
  for _ in range(batch_size):
128
  if multi_label:
129
- k = int(random.gammavariate(2, 2) * len(labels))
130
  else:
131
  k = 1
132
- sampled_labels = random.sample(labels, k)
 
 
133
  random.shuffle(sampled_labels)
134
  inputs.append(
135
  {
 
126
  inputs = []
127
  for _ in range(batch_size):
128
  if multi_label:
129
+ k = int(random.betavariate(alpha=2, beta=3) * len(labels))
130
  else:
131
  k = 1
132
+
133
+ print(k)
134
+ sampled_labels = random.sample(labels, min(k, len(labels)))
135
  random.shuffle(sampled_labels)
136
  inputs.append(
137
  {