# coding=utf-8
# Copyright 2022 IDEA-CCNL The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Della model. """

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from fengshen.models.deepVAE.configuration_della import DellaModelConfig
from fengshen.models.deepVAE.latent_connector import GPT2ForDecoderLatentConnector, GPT2ForEncoderLatentConnector
from fengshen.models.deepVAE.utils import connect, compute_kl_loss, top_k_top_p_filtering, enforce_repetition_penalty

_CHECKPOINT_FOR_DOC = "della-226M-base"
_CONFIG_FOR_DOC = "DellaModelConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"

Della_model_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "della-226M-base"
]


@dataclass
class DellaModelOutput(ModelOutput):
    logits: torch.FloatTensor = None
    posterior_latents: Optional[Tuple[torch.FloatTensor]] = None
    prior_latent: Optional[Tuple[torch.FloatTensor]] = None


class latent_layer(nn.Module):
    """Fuses the aggregated latent of all lower layers with the latent of the
    current layer through a bias-free, tanh-gated linear combination (an
    RNN-cell-style update applied over the depth dimension)."""

    def __init__(self, input_dim) -> None:
        super().__init__()
        self.W_hh = nn.Linear(input_dim, input_dim, bias=False)
        self.W_ih = nn.Linear(input_dim, input_dim, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, z_lt_lm1, z_lm1):
        # inputs are z_<l-1 (aggregate of the lower-layer latents) and
        # z_{l-1} (the latent of layer l-1); the return statement completes
        # the truncated source with the update implied by the layers
        # declared in __init__
        return self.tanh(self.W_hh(z_lt_lm1) + self.W_ih(z_lm1))
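

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal demonstration of how latent_layer chains per-layer latents: the
# running summary of all lower-layer latents is fused with the latent drawn
# at the current layer. The batch size and hidden width below are assumed
# values for illustration, not constants from this repository.
if __name__ == "__main__":
    _dim = 768                            # hypothetical latent width
    _layer = latent_layer(_dim)
    _z_accum = torch.zeros(2, _dim)       # summary z_<l-1 of lower layers
    _z_curr = torch.randn(2, _dim)        # latent z_{l-1} of the layer below
    # tanh(W_hh * z_<l-1 + W_ih * z_{l-1}) becomes the new running summary
    _z_accum = _layer(_z_accum, _z_curr)
    print(_z_accum.shape)                 # torch.Size([2, 768])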