soumyaprabhamaiti's picture
Initial commit
8aed5f0
raw
history blame contribute delete
776 Bytes
from dataclasses import dataclass
import os
from dotenv import load_dotenv
load_dotenv()
@dataclass
class PetSegTrainConfig:
EPOCHS = 5
BATCH_SIZE = 8
FAST_DEV_RUN = False
TOTAL_SAMPLES = 100
LEARNING_RATE = 1e-3
TRAIN_VAL_TEST_DATA_PATH = "./data/train_val_test"
DEPTHWISE_SEP = False
CHANNELS_LIST = [16, 32, 64, 128, 256]
DESCRIPTION_TEXT = None
@dataclass
class PetSegWebappConfig:
MODEL_WEIGHTS_GDRIVE_FILE_ID = os.environ.get("MODEL_WEIGHTS_GDRIVE_FILE_ID")
MODEL_WEIGHTS_LOCAL_PATH = os.environ.get(
"MODEL_WEIGHTS_LOCAL_PATH", "pet-segmentation-pytorch_epoch=4-step=1840.ckpt"
)
DOWNLOAD_MODEL_WEIGTHS_FROM_GDRIVE = (
os.environ.get("DOWNLOAD_MODEL_WEIGTHS_FROM_GDRIVE", "True") == "True"
)