update readme
README.md CHANGED
@@ -38,9 +38,19 @@ dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", spl
 dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
 sample = dataset[0]["audio"]
 
-inputs = processor(
-
-
+inputs = processor(
+    sample["array"],
+    return_tensors="pt",
+    sampling_rate=processor.feature_extractor.sampling_rate
+)
+inputs = inputs.to(device, torch_dtype)
+
+# to avoid hallucination loops, we limit the maximum length of the generated text based on the expected number of tokens per second
+token_limit_factor = 6.5 / processor.feature_extractor.sampling_rate  # Maximum of 6.5 tokens per second
+seq_lens = inputs.attention_mask.sum(dim=-1)
+max_length = int((seq_lens * token_limit_factor).max().item())
+
+generated_ids = model.generate(**inputs, max_length=max_length)
 print(processor.decode(generated_ids[0], skip_special_tokens=True))
 ```
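For context on the added lines: `max_length` works out to roughly 6.5 tokens per second of audio, because `seq_lens` (the attention-mask sum) is scaled by `6.5 / sampling_rate`. A minimal sketch of that arithmetic, assuming the attention mask counts raw audio samples, a 16 kHz feature extractor, and a 10-second clip (all illustrative values, not taken from the diff):

```python
# Token-limit arithmetic from the README change, with assumed example values.
sampling_rate = 16_000                    # assumed feature-extractor sampling rate
seq_len = 10 * sampling_rate              # attention-mask sum for an assumed 10-second clip

token_limit_factor = 6.5 / sampling_rate  # maximum of 6.5 tokens per second of audio
max_length = int(seq_len * token_limit_factor)

print(max_length)  # 65 -> generation is capped at 65 tokens for 10 seconds of audio
```

Capping `max_length` this way keeps `model.generate` from running far past the length a transcript of that duration could plausibly need, which is how the change avoids hallucination loops.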