Spaces:
Running
Running
Joshua Lochner
committed on
Commit
·
d34e3fe
1
Parent(s):
7dbc778
Fix training arguments dataclasses
Browse files
- src/shared.py +32 -10
- src/train.py +11 -4
- src/train_classifier.py +23 -18
src/shared.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
from transformers.trainer_utils import get_last_checkpoint as glc
|
| 2 |
-
from transformers import TrainingArguments
|
| 3 |
import os
|
| 4 |
from utils import re_findall
|
| 5 |
import logging
|
|
@@ -76,14 +76,15 @@ _SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
|
|
| 76 |
SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
|
| 77 |
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
def extract_sponsor_matches(texts):
|
| 80 |
-
|
| 81 |
-
for text in texts:
|
| 82 |
-
if CustomTokens.NO_SEGMENT.value in text:
|
| 83 |
-
to_return.append([])
|
| 84 |
-
else:
|
| 85 |
-
to_return.append(re_findall(SEGMENT_MATCH_RE, text))
|
| 86 |
-
return to_return
|
| 87 |
|
| 88 |
|
| 89 |
@dataclass
|
|
@@ -134,6 +135,22 @@ class DatasetArguments:
|
|
| 134 |
},
|
| 135 |
)
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
def __post_init__(self):
|
| 138 |
if self.train_file is None or self.validation_file is None:
|
| 139 |
raise ValueError(
|
|
@@ -234,7 +251,7 @@ def load_datasets(dataset_args: DatasetArguments):
|
|
| 234 |
|
| 235 |
|
| 236 |
@dataclass
|
| 237 |
-
class CustomTrainingArguments(OutputArguments, TrainingArguments):
|
| 238 |
seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']
|
| 239 |
|
| 240 |
num_train_epochs: float = field(
|
|
@@ -242,7 +259,7 @@ class CustomTrainingArguments(OutputArguments, TrainingArguments):
|
|
| 242 |
|
| 243 |
save_steps: int = field(default=5000, metadata={
|
| 244 |
'help': 'Save checkpoint every X updates steps.'})
|
| 245 |
-
eval_steps: int = field(default=
|
| 246 |
'help': 'Run an evaluation every X steps.'})
|
| 247 |
logging_steps: int = field(default=5000, metadata={
|
| 248 |
'help': 'Log every X updates steps.'})
|
|
@@ -311,6 +328,11 @@ class CustomTrainingArguments(OutputArguments, TrainingArguments):
|
|
| 311 |
)
|
| 312 |
|
| 313 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
logging.basicConfig()
|
| 315 |
logger = logging.getLogger(__name__)
|
| 316 |
|
|
|
|
| 1 |
from transformers.trainer_utils import get_last_checkpoint as glc
|
| 2 |
+
from transformers import Seq2SeqTrainingArguments, TrainingArguments
|
| 3 |
import os
|
| 4 |
from utils import re_findall
|
| 5 |
import logging
|
|
|
|
| 76 |
SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
|
| 77 |
|
| 78 |
|
| 79 |
+
def extract_sponsor_matches_from_text(text):
|
| 80 |
+
if CustomTokens.NO_SEGMENT.value in text:
|
| 81 |
+
return []
|
| 82 |
+
else:
|
| 83 |
+
return re_findall(SEGMENT_MATCH_RE, text)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
def extract_sponsor_matches(texts):
|
| 87 |
+
return list(map(extract_sponsor_matches_from_text, texts))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
|
| 90 |
@dataclass
|
|
|
|
| 135 |
},
|
| 136 |
)
|
| 137 |
|
| 138 |
+
c_train_file: Optional[str] = field(
|
| 139 |
+
default='c_train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
|
| 140 |
+
)
|
| 141 |
+
c_validation_file: Optional[str] = field(
|
| 142 |
+
default='c_valid.json',
|
| 143 |
+
metadata={
|
| 144 |
+
'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
|
| 145 |
+
},
|
| 146 |
+
)
|
| 147 |
+
c_test_file: Optional[str] = field(
|
| 148 |
+
default='c_test.json',
|
| 149 |
+
metadata={
|
| 150 |
+
'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
|
| 151 |
+
},
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
def __post_init__(self):
|
| 155 |
if self.train_file is None or self.validation_file is None:
|
| 156 |
raise ValueError(
|
|
|
|
| 251 |
|
| 252 |
|
| 253 |
@dataclass
|
| 254 |
+
class AdditionalTrainingArguments:
|
| 255 |
seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']
|
| 256 |
|
| 257 |
num_train_epochs: float = field(
|
|
|
|
| 259 |
|
| 260 |
save_steps: int = field(default=5000, metadata={
|
| 261 |
'help': 'Save checkpoint every X updates steps.'})
|
| 262 |
+
eval_steps: int = field(default=25000, metadata={
|
| 263 |
'help': 'Run an evaluation every X steps.'})
|
| 264 |
logging_steps: int = field(default=5000, metadata={
|
| 265 |
'help': 'Log every X updates steps.'})
|
|
|
|
| 328 |
)
|
| 329 |
|
| 330 |
|
| 331 |
+
@dataclass
|
| 332 |
+
class CustomTrainingArguments(OutputArguments, AdditionalTrainingArguments):
|
| 333 |
+
pass
|
| 334 |
+
|
| 335 |
+
|
| 336 |
logging.basicConfig()
|
| 337 |
logger = logging.getLogger(__name__)
|
| 338 |
|
src/train.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
-
from preprocess import PreprocessingDatasetArguments
|
| 2 |
from shared import (
|
| 3 |
CustomTokens,
|
|
|
|
| 4 |
prepare_datasets,
|
| 5 |
load_datasets,
|
| 6 |
CustomTrainingArguments,
|
|
@@ -17,13 +17,15 @@ from transformers import (
|
|
| 17 |
DataCollatorForSeq2Seq,
|
| 18 |
HfArgumentParser,
|
| 19 |
Seq2SeqTrainer,
|
|
|
|
| 20 |
)
|
| 21 |
|
| 22 |
from transformers.utils import check_min_version
|
| 23 |
from transformers.utils.versions import require_version
|
|
|
|
| 24 |
|
| 25 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
| 26 |
-
check_min_version('4.
|
| 27 |
require_version('datasets>=1.8.0',
|
| 28 |
'To fix: pip install -r requirements.txt')
|
| 29 |
|
|
@@ -40,6 +42,11 @@ logging.basicConfig(
|
|
| 40 |
)
|
| 41 |
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
def main():
|
| 44 |
|
| 45 |
# See all possible arguments in src/transformers/training_args.py
|
|
@@ -48,8 +55,8 @@ def main():
|
|
| 48 |
|
| 49 |
hf_parser = HfArgumentParser((
|
| 50 |
ModelArguments,
|
| 51 |
-
|
| 52 |
-
|
| 53 |
))
|
| 54 |
model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
|
| 55 |
|
|
|
|
|
|
|
| 1 |
from shared import (
|
| 2 |
CustomTokens,
|
| 3 |
+
DatasetArguments,
|
| 4 |
prepare_datasets,
|
| 5 |
load_datasets,
|
| 6 |
CustomTrainingArguments,
|
|
|
|
| 17 |
DataCollatorForSeq2Seq,
|
| 18 |
HfArgumentParser,
|
| 19 |
Seq2SeqTrainer,
|
| 20 |
+
Seq2SeqTrainingArguments,
|
| 21 |
)
|
| 22 |
|
| 23 |
from transformers.utils import check_min_version
|
| 24 |
from transformers.utils.versions import require_version
|
| 25 |
+
from dataclasses import dataclass
|
| 26 |
|
| 27 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
| 28 |
+
check_min_version('4.17.0')
|
| 29 |
require_version('datasets>=1.8.0',
|
| 30 |
'To fix: pip install -r requirements.txt')
|
| 31 |
|
|
|
|
| 42 |
)
|
| 43 |
|
| 44 |
|
| 45 |
+
@dataclass
|
| 46 |
+
class Seq2SeqTrainingArguments(CustomTrainingArguments, Seq2SeqTrainingArguments):
|
| 47 |
+
pass
|
| 48 |
+
|
| 49 |
+
|
| 50 |
def main():
|
| 51 |
|
| 52 |
# See all possible arguments in src/transformers/training_args.py
|
|
|
|
| 55 |
|
| 56 |
hf_parser = HfArgumentParser((
|
| 57 |
ModelArguments,
|
| 58 |
+
DatasetArguments,
|
| 59 |
+
Seq2SeqTrainingArguments
|
| 60 |
))
|
| 61 |
model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
|
| 62 |
|
src/train_classifier.py
CHANGED
|
@@ -4,7 +4,7 @@
|
|
| 4 |
import logging
|
| 5 |
import os
|
| 6 |
import sys
|
| 7 |
-
from dataclasses import dataclass
|
| 8 |
from typing import Optional
|
| 9 |
|
| 10 |
import datasets
|
|
@@ -16,11 +16,20 @@ from transformers import (
|
|
| 16 |
EvalPrediction,
|
| 17 |
HfArgumentParser,
|
| 18 |
Trainer,
|
|
|
|
| 19 |
set_seed,
|
| 20 |
)
|
| 21 |
from transformers.utils import check_min_version
|
| 22 |
from transformers.utils.versions import require_version
|
| 23 |
-
from shared import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
from model import get_model_tokenizer, ModelArguments
|
| 25 |
|
| 26 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
|
@@ -32,23 +41,19 @@ os.environ['WANDB_DISABLED'] = 'true'
|
|
| 32 |
logger = logging.getLogger(__name__)
|
| 33 |
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
@dataclass
|
| 36 |
class ClassifierDatasetArguments(DatasetArguments):
|
| 37 |
-
train_file: Optional[str] =
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
|
| 44 |
-
},
|
| 45 |
-
)
|
| 46 |
-
test_file: Optional[str] = field(
|
| 47 |
-
default='c_test.json',
|
| 48 |
-
metadata={
|
| 49 |
-
'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
|
| 50 |
-
},
|
| 51 |
-
)
|
| 52 |
|
| 53 |
|
| 54 |
def main():
|
|
@@ -59,7 +64,7 @@ def main():
|
|
| 59 |
hf_parser = HfArgumentParser((
|
| 60 |
ModelArguments,
|
| 61 |
ClassifierDatasetArguments,
|
| 62 |
-
|
| 63 |
))
|
| 64 |
model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
|
| 65 |
|
|
|
|
| 4 |
import logging
|
| 5 |
import os
|
| 6 |
import sys
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
from typing import Optional
|
| 9 |
|
| 10 |
import datasets
|
|
|
|
| 16 |
EvalPrediction,
|
| 17 |
HfArgumentParser,
|
| 18 |
Trainer,
|
| 19 |
+
TrainingArguments,
|
| 20 |
set_seed,
|
| 21 |
)
|
| 22 |
from transformers.utils import check_min_version
|
| 23 |
from transformers.utils.versions import require_version
|
| 24 |
+
from shared import (
|
| 25 |
+
CATEGORIES,
|
| 26 |
+
DatasetArguments,
|
| 27 |
+
prepare_datasets,
|
| 28 |
+
load_datasets,
|
| 29 |
+
CustomTrainingArguments,
|
| 30 |
+
train_from_checkpoint,
|
| 31 |
+
get_last_checkpoint
|
| 32 |
+
)
|
| 33 |
from model import get_model_tokenizer, ModelArguments
|
| 34 |
|
| 35 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
|
|
|
| 41 |
logger = logging.getLogger(__name__)
|
| 42 |
|
| 43 |
|
| 44 |
+
@dataclass
|
| 45 |
+
class ClassifierTrainingArguments(CustomTrainingArguments, TrainingArguments):
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
|
| 49 |
@dataclass
|
| 50 |
class ClassifierDatasetArguments(DatasetArguments):
|
| 51 |
+
train_file: Optional[str] = DatasetArguments.__dataclass_fields__[
|
| 52 |
+
'c_train_file']
|
| 53 |
+
validation_file: Optional[str] = DatasetArguments.__dataclass_fields__[
|
| 54 |
+
'c_validation_file']
|
| 55 |
+
test_file: Optional[str] = DatasetArguments.__dataclass_fields__[
|
| 56 |
+
'c_test_file']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
def main():
|
|
|
|
| 64 |
hf_parser = HfArgumentParser((
|
| 65 |
ModelArguments,
|
| 66 |
ClassifierDatasetArguments,
|
| 67 |
+
ClassifierTrainingArguments
|
| 68 |
))
|
| 69 |
model_args, dataset_args, training_args = hf_parser.parse_args_into_dataclasses()
|
| 70 |
|