# redPajama-3b-zAgile-base / hp_validation.py
# dtorres-zAgile's picture
# Upload 25 files
# 0be3778
import logging
from typing import Any
from typing import Dict
from typing import Union
# Configure root logging at INFO level when this module is imported.
logging.basicConfig(level=logging.INFO)

# ---------------------------------------------------------------------------
# Payload keys
# ---------------------------------------------------------------------------

# Prompt text.
TEXT_INPUTS = "text_inputs"

# Length controls.
MAX_LENGTH = "max_length"
MIN_LENGTH = "min_length"
MAX_NEW_TOKENS = "max_new_tokens"
MIN_NEW_TOKENS = "min_new_tokens"
LENGTH_PENALTY = "length_penalty"
MAX_TIME = "max_time"

# Beam search and number of returned sequences.
NUM_RETURN_SEQUENCES = "num_return_sequences"
NUM_BEAMS = "num_beams"
EARLY_STOPPING = "early_stopping"
NO_REPEAT_NGRAM_SIZE = "no_repeat_ngram_size"

# Sampling.
DO_SAMPLE = "do_sample"
TOP_P = "top_p"
TOP_K = "top_k"
TEMPERATURE = "temperature"
SEED = "seed"

# Output shaping.
RETURN_FULL_TEXT = "return_full_text"
STOPPING_CRITERIA = "stopping_criteria"

# Every key a payload may contain. Kept as a list, in this exact order,
# because the list is rendered verbatim in the "invalid key" error message.
ALL_PARAM_NAMES = [
    TEXT_INPUTS,
    MAX_LENGTH,
    NUM_RETURN_SEQUENCES,
    NUM_BEAMS,
    TOP_P,
    EARLY_STOPPING,
    DO_SAMPLE,
    NO_REPEAT_NGRAM_SIZE,
    TOP_K,
    TEMPERATURE,
    SEED,
    MIN_LENGTH,
    MAX_NEW_TOKENS,
    MIN_NEW_TOKENS,
    LENGTH_PENALTY,
    MAX_TIME,
    RETURN_FULL_TEXT,
    STOPPING_CRITERIA,
]

# ---------------------------------------------------------------------------
# Parameter bounds (all inclusive)
# ---------------------------------------------------------------------------

# Smallest accepted value for max_length and min_length.
LENGTH_MIN = 1
# Smallest accepted value for max_new_tokens and min_new_tokens.
NEW_TOKENS_MIN = 0
NUM_RETURN_SEQUENCE_MIN = 1
NUM_BEAMS_MIN = 1
# top_p must lie inside [TOP_P_MIN, TOP_P_MAX].
TOP_P_MIN = 0
TOP_P_MAX = 1
TEMPERATURE_MIN = 0
# NOTE(review): the two bounds below are not referenced by the validation code
# in this part of the file - confirm whether they are used elsewhere.
NO_REPEAT_NGRAM_SIZE_MIN = 1
TOP_K_MIN = 0
def is_list_of_strings(parameter: Any) -> bool:
    """Tell whether ``parameter`` is a non-empty list containing only strings.

    Args:
        parameter: the value to inspect.

    Returns:
        True when ``parameter`` is truthy, is a ``list`` and every element is
        a ``str``. False otherwise, which includes the empty list and
        non-list containers such as tuples.
    """
    # Truthiness is tested first, so falsy values (None, [], "", 0) are
    # rejected before the type check.
    if not parameter or not isinstance(parameter, list):
        return False
    for item in parameter:
        if not isinstance(item, str):
            return False
    return True
def _validate_payload(payload: Dict[str, Any]) -> None:
    """Validate the generation parameters of a decoded input payload.

    Checks run in the order below and the first failure is reported:

    1. Every key must be listed in ``ALL_PARAM_NAMES``.
    2. ``text_inputs`` must be present.
    3. ``max_length``, ``num_return_sequences``, ``num_beams`` and ``seed``
       must be plain ``int`` values (booleans are rejected).
    4. ``max_length``, ``min_length``, ``max_new_tokens``, ``min_new_tokens``,
       ``num_return_sequences`` and ``num_beams`` must not fall below their
       minimum.
    5. ``num_return_sequences`` must not exceed ``num_beams`` when both are
       given.
    6. ``top_p`` must lie in ``[TOP_P_MIN, TOP_P_MAX]`` and ``temperature``
       must be at least ``TEMPERATURE_MIN``.
    7. ``do_sample`` must be a boolean.
    8. ``stopping_criteria`` must satisfy ``is_list_of_strings``.

    Args:
        payload: a decoded input payload (dictionary of input parameter names
            and values).

    Raises:
        ValueError: if any of the checks fails.

    Note:
        Only the parameters in step 3 are type-checked. A non-numeric value
        for any other bounded parameter surfaces as the ``TypeError`` raised
        by the comparison itself.
    """
    # For all parameters used in text2text generation task, please see
    # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
    for param_name in payload:
        if param_name not in ALL_PARAM_NAMES:
            raise ValueError(f"Input payload contains an invalid key '{param_name}'. Valid keys are {ALL_PARAM_NAMES}.")

    if TEXT_INPUTS not in payload:
        raise ValueError(f"Input payload must contain {TEXT_INPUTS} key.")

    for param_name in (MAX_LENGTH, NUM_RETURN_SEQUENCES, NUM_BEAMS, SEED):
        # Exact type comparison on purpose: isinstance() would accept bool,
        # because bool is a subclass of int.
        if param_name in payload and type(payload[param_name]) is not int:
            raise ValueError(f"{param_name} must be an integer, got {payload[param_name]}.")

    # (parameter, inclusive minimum) pairs, checked in this order so the first
    # reported violation does not depend on the payload's key order.
    lower_bounds = (
        (MAX_LENGTH, LENGTH_MIN),
        (MIN_LENGTH, LENGTH_MIN),
        (MAX_NEW_TOKENS, NEW_TOKENS_MIN),
        (MIN_NEW_TOKENS, NEW_TOKENS_MIN),
        (NUM_RETURN_SEQUENCES, NUM_RETURN_SEQUENCE_MIN),
        (NUM_BEAMS, NUM_BEAMS_MIN),
    )
    for param_name, minimum in lower_bounds:
        if param_name in payload and payload[param_name] < minimum:
            raise ValueError(f"{param_name} must be at least {minimum}, got {payload[param_name]}.")

    if NUM_RETURN_SEQUENCES in payload and NUM_BEAMS in payload:
        if payload[NUM_RETURN_SEQUENCES] > payload[NUM_BEAMS]:
            raise ValueError(
                f"{NUM_BEAMS} must be at least {NUM_RETURN_SEQUENCES}. Instead got "
                f"{NUM_BEAMS}={payload[NUM_BEAMS]} and {NUM_RETURN_SEQUENCES}="
                f"{payload[NUM_RETURN_SEQUENCES]}."
            )

    if TOP_P in payload:
        if payload[TOP_P] < TOP_P_MIN or payload[TOP_P] > TOP_P_MAX:
            # Fixed: the message used to name TOP_K although TOP_P is the
            # parameter being checked.
            raise ValueError(f"{TOP_P} must be in range [{TOP_P_MIN},{TOP_P_MAX}], got {payload[TOP_P]}")

    if TEMPERATURE in payload:
        if payload[TEMPERATURE] < TEMPERATURE_MIN:
            raise ValueError(
                f"{TEMPERATURE} must be a float with value at least {TEMPERATURE_MIN}, got {payload[TEMPERATURE]}."
            )

    if DO_SAMPLE in payload and not isinstance(payload[DO_SAMPLE], bool):
        raise ValueError(f"{DO_SAMPLE} must be a boolean, got {payload[DO_SAMPLE]}.")

    if STOPPING_CRITERIA in payload and not is_list_of_strings(payload[STOPPING_CRITERIA]):
        # Fixed: the message used to start with the offending value instead
        # of the parameter name.
        raise ValueError(f"{STOPPING_CRITERIA} must be a list of strings, got {payload[STOPPING_CRITERIA]}")
def _update_num_beams(payload: Dict[str, Union[str, float, int]]) -> Dict[str, Union[str, float, int]]:
    """Default num_beams to num_return_sequences when only the latter is given.

    The payload is modified in place and the same object is returned.

    Args:
        payload (Dict): dictionary of input text and parameters

    Returns:
        payload (Dict): the same dictionary, with num_beams filled in when it
            was missing and num_return_sequences was present
    """
    if NUM_RETURN_SEQUENCES in payload:
        # setdefault leaves an explicitly supplied num_beams untouched.
        payload.setdefault(NUM_BEAMS, payload[NUM_RETURN_SEQUENCES])
    return payload