|
|
""" |
|
|
Copyright $today.year LY Corporation |
|
|
LY Corporation licenses this file to you under the Apache License, |
|
|
version 2.0 (the "License"); you may not use this file except in compliance |
|
|
with the License. You may obtain a copy of the License at: |
|
|
https://www.apache.org/licenses/LICENSE-2.0 |
|
|
Unless required by applicable law or agreed to in writing, software |
|
|
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
|
|
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
|
|
License for the specific language governing permissions and limitations |
|
|
under the License. |
|
|
Moment-DETR (https://github.com/jayleicn/moment_detr) |
|
|
Copyright (c) 2021 Jie Lei |
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy |
|
|
of this software and associated documentation files (the "Software"), to deal |
|
|
in the Software without restriction, including without limitation the rights |
|
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|
|
copies of the Software, and to permit persons to whom the Software is |
|
|
furnished to do so, subject to the following conditions: |
|
|
The above copyright notice and this permission notice shall be included in all |
|
|
copies or substantial portions of the Software. |
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
|
|
SOFTWARE. |
|
|
""" |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
import torch |
|
|
from transformers import PreTrainedModel |
|
|
from lighthouse.feature_extractor.audio_encoder import AudioEncoder |
|
|
from lighthouse.feature_extractor.text_encoder import TextEncoder |
|
|
from lighthouse.models import BasePredictor |
|
|
|
|
|
from .configuration_amdetr import AMDETRConfig |
|
|
|
|
|
|
|
|
class AMDETRPredictorWrapper(BasePredictor, PreTrainedModel):
    """Expose a ``BasePredictor`` through the ``PreTrainedModel`` interface.

    The wrapper is configured by an :class:`AMDETRConfig`. The model is built
    in the constructor; the audio and text encoders are created later by
    :meth:`load_encoders`, which :meth:`encode_audio` calls on demand.
    """

    config_class = AMDETRConfig

    def __init__(self, config: AMDETRConfig, feature_name: str = "clap") -> None:
        """Build the model described by ``config`` and switch it to eval mode.

        Args:
            config: Supplies ``clip_length``, ``device`` and ``model_name``.
            feature_name: Feature identifier later handed to the audio and
                text encoder initializers.
        """
        # Only PreTrainedModel's initializer is invoked here; BasePredictor's
        # __init__ is intentionally not called by this wrapper.
        PreTrainedModel.__init__(self, config)

        model_name = config.model_name

        self._clip_len: float = config.clip_length
        self._device: str = config.device
        # Hard-coded values, not read from the config.
        self._size = 224
        self._moment_num = 10

        # NOTE(review): _initialize_model is not defined in this class —
        # presumably inherited from BasePredictor; it may rely on the
        # attributes assigned above, so keep this ordering.
        network: torch.nn.Module = self._initialize_model(config, model_name)
        self._model: torch.nn.Module = network
        self._model.eval()

        self._feature_name: str = feature_name
        self._model_name: str = model_name

    def load_encoders(self) -> None:
        """Create the audio and text encoders for the stored feature name.

        The vision encoder slot is explicitly set to ``None``.
        """
        feature = self._feature_name

        self._vision_encoder = None
        self._audio_encoder: AudioEncoder = self._initialize_audio_encoder(
            feature,
            pann_path=None,
        )
        self._text_encoder: TextEncoder = self._initialize_text_encoder(feature)

    @torch.no_grad()
    def encode_audio(self, audio_path: str) -> Dict[str, torch.Tensor]:
        """Encode the audio at ``audio_path`` via the parent implementation.

        If either encoder attribute is missing, :meth:`load_encoders` is
        called first. Runs with gradient tracking disabled.

        Args:
            audio_path: Path forwarded unchanged to the parent ``encode_audio``.

        Returns:
            Whatever the parent ``encode_audio`` returns.
        """
        required_attrs = ("_audio_encoder", "_text_encoder")
        encoders_ready = all(hasattr(self, attr) for attr in required_attrs)
        if not encoders_ready:
            self.load_encoders()
        return super().encode_audio(audio_path)
|
|
|