#!/usr/bin/env python3
"""Smoke test for FlaxHubertForSequenceClassification.

Loads the pretrained Hubert weights, runs a dummy batch through the
sequence-classification head, and prints the output logits shape.
"""
import numpy as np

from hubert_for_sequence_classification import (
    FlaxHubertForSequenceClassification,
    FlaxHubertModel,
)

# Ugly save/reload dance: loading the PT checkpoint directly into the
# classification head hits a transformers bug, so we first materialize the
# base model as a Flax checkpoint on disk and reload it into the
# classification class. See https://github.com/huggingface/transformers/issues/12532
model = FlaxHubertModel.from_pretrained("facebook/hubert-large-ll60k", from_pt=True)
model.save_pretrained("./")
model = FlaxHubertForSequenceClassification.from_pretrained("./")

# Dummy batch: 2 raw-audio sequences of 1024 samples each, all ones.
dummy_input = np.ones((2, 1024), dtype=np.float32)
logits = model(dummy_input).logits

# Expected output shape is (batch_size, num_labels) == (2, 2).
print("output shape", logits.shape)