|
|
|
from typing import List, Optional, Tuple

from transformers import PretrainedConfig
|
|
|
class PLPQConfig(PretrainedConfig):
    """Hugging Face configuration for the ``PLPQ`` model.

    Stores each hyperparameter passed to ``__init__`` as an instance
    attribute of the same name and forwards any remaining keyword arguments
    to ``PretrainedConfig.__init__``.

    List-valued options default to ``None`` and are replaced by a freshly
    built list on every call. Separate config instances therefore never
    share (and can never accidentally mutate) a common default list object.

    Args:
        image_size: Input image size; defaults to ``[512, 512]``.
            Presumably ``[height, width]`` -- confirm against the model code.
        patch_size: Patch size; defaults to ``16``.
        dropout: Dropout rate used by the model; defaults to ``0.0``.
        levels: Quantization levels; defaults to ``[8, 8, 8, 5, 5, 5]``.
            NOTE(review): looks like per-dimension levels of a finite scalar
            quantizer -- confirm against the model code.
        num_quantizers: Number of quantizers; defaults to ``4``.
        num_in_channels: Number of input channels; defaults to ``3``.
        num_out_channels: Number of output channels; defaults to ``3``.
        use_wavelets: Wavelet toggle read by the model; defaults to ``True``.
        encoder_blocks: Encoder block specifications; defaults to ``[]``.
            The layout of each inner list is defined by the model code,
            which is not visible here.
        decoder_blocks: Decoder block specifications; defaults to ``[]``.
            Same caveat as ``encoder_blocks``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier used by the ``transformers`` Auto* registries.
    model_type: str = "PLPQ"

    def __init__(
        self,
        image_size: Optional[List[int]] = None,
        patch_size: int = 16,
        dropout: float = 0.0,
        levels: Optional[List[int]] = None,
        num_quantizers: int = 4,
        num_in_channels: int = 3,
        num_out_channels: int = 3,
        use_wavelets: bool = True,
        encoder_blocks: Optional[List[List]] = None,
        decoder_blocks: Optional[List[List]] = None,
        **kwargs,
    ):
        # Build default lists per call: a list literal in the signature would
        # be created once and shared (and mutable) across every instance.
        self.image_size = [512, 512] if image_size is None else image_size
        self.patch_size = patch_size
        self.dropout = dropout
        self.levels = [8, 8, 8, 5, 5, 5] if levels is None else levels
        self.num_quantizers = num_quantizers
        self.num_in_channels = num_in_channels
        self.num_out_channels = num_out_channels
        self.use_wavelets = use_wavelets
        self.encoder_blocks = [] if encoder_blocks is None else encoder_blocks
        self.decoder_blocks = [] if decoder_blocks is None else decoder_blocks

        # Remaining options (standard PretrainedConfig settings) are handled
        # by the base class.
        super().__init__(**kwargs)
|
|