Spaces:
Running
Running
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from shared import CustomTokens, device | |
| from functools import lru_cache | |
| import pickle | |
| import os | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| class ModelArguments: | |
| """ | |
| Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. | |
| """ | |
| model_name_or_path: str = field( | |
| default=None, | |
| # default='google/t5-v1_1-small', # t5-small | |
| metadata={ | |
| 'help': 'Path to pretrained model or model identifier from huggingface.co/models' | |
| } | |
| ) | |
| # config_name: Optional[str] = field( # TODO remove? | |
| # default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'} | |
| # ) | |
| # tokenizer_name: Optional[str] = field( | |
| # default=None, metadata={ | |
| # 'help': 'Pretrained tokenizer name or path if not the same as model_name' | |
| # } | |
| # ) | |
| cache_dir: Optional[str] = field( | |
| default=None, | |
| metadata={ | |
| 'help': 'Where to store the pretrained models downloaded from huggingface.co' | |
| }, | |
| ) | |
| use_fast_tokenizer: bool = field( # TODO remove? | |
| default=True, | |
| metadata={ | |
| 'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.' | |
| }, | |
| ) | |
| model_revision: str = field( # TODO remove? | |
| default='main', | |
| metadata={ | |
| 'help': 'The specific model version to use (can be a branch name, tag name or commit id).' | |
| }, | |
| ) | |
| use_auth_token: bool = field( | |
| default=False, | |
| metadata={ | |
| 'help': 'Will use the token generated when running `transformers-cli login` (necessary to use this script ' | |
| 'with private models).' | |
| }, | |
| ) | |
| resize_position_embeddings: Optional[bool] = field( | |
| default=None, | |
| metadata={ | |
| 'help': "Whether to automatically resize the position embeddings if `max_source_length` exceeds the model's position embeddings." | |
| }, | |
| ) | |
| def get_classifier_vectorizer(classifier_args): | |
| classifier_path = os.path.join( | |
| classifier_args.classifier_dir, classifier_args.classifier_file) | |
| with open(classifier_path, 'rb') as fp: | |
| classifier = pickle.load(fp) | |
| vectorizer_path = os.path.join( | |
| classifier_args.classifier_dir, classifier_args.vectorizer_file) | |
| with open(vectorizer_path, 'rb') as fp: | |
| vectorizer = pickle.load(fp) | |
| return classifier, vectorizer | |
| def get_model_tokenizer(model_name_or_path, cache_dir=None): | |
| if model_name_or_path is None: | |
| raise ValueError('Invalid model_name_or_path.') | |
| # Load pretrained model and tokenizer | |
| model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, cache_dir=cache_dir) | |
| model.to(device()) | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| model_name_or_path, max_length=model.config.d_model, cache_dir=cache_dir) | |
| # Ensure model and tokenizer contain the custom tokens | |
| CustomTokens.add_custom_tokens(tokenizer) | |
| model.resize_token_embeddings(len(tokenizer)) | |
| # TODO add this back: means that different models will have different training data | |
| # Currently we only send 512 tokens to the model each time... | |
| # Adjust based on dimensions of model | |
| # tokenizer.model_max_length = model.config.d_model | |
| return model, tokenizer | |