|
from typing import List, Literal |
|
import numpy as np |
|
|
|
|
|
def generate_parameters_with_timesteps( |
|
start: int, |
|
num: int, |
|
stop: int = None, |
|
method: Literal["linear", "two_stage", "three_stage", "fix_two_stage"] = "linear", |
|
n_fix_start: int = 3, |
|
) -> List[float]: |
|
if stop is None or start == stop: |
|
params = [start] * num |
|
else: |
|
if method == "linear": |
|
params = generate_linear_parameters(start, stop, num) |
|
elif method == "two_stage": |
|
params = generate_two_stages_parameters(start, stop, num) |
|
elif method == "three_stage": |
|
params = generate_three_stages_parameters(start, stop, num) |
|
elif method == "fix_two_stage": |
|
params = generate_fix_two_stages_parameters(start, stop, num, n_fix_start) |
|
else: |
|
raise ValueError( |
|
f"now only support linear, two_stage, three_stage, but given{method}" |
|
) |
|
return params |
|
|
|
|
|
def generate_linear_parameters(start, stop, num): |
|
parames = list( |
|
np.linspace( |
|
start=start, |
|
stop=stop, |
|
num=num, |
|
) |
|
) |
|
return parames |
|
|
|
|
|
def generate_two_stages_parameters(start, stop, num): |
|
num_start = num // 2 |
|
num_end = num - num_start |
|
parames = [start] * num_start + [stop] * num_end |
|
return parames |
|
|
|
|
|
def generate_fix_two_stages_parameters(start, stop, num, n_fix_start: int) -> List: |
|
num_start = n_fix_start |
|
num_end = num - num_start |
|
parames = [start] * num_start + [stop] * num_end |
|
return parames |
|
|
|
|
|
def generate_three_stages_parameters(start, stop, num): |
|
middle = (start + stop) // 2 |
|
num_start = num // 3 |
|
num_middle = num_start |
|
num_end = num - num_start - num_middle |
|
parames = [start] * num_start + [middle] * num_middle + [stop] * num_end |
|
return parames |
|
|