mskov commited on
Commit
3bef3fb
Β·
1 Parent(s): 4afec78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -3
app.py CHANGED
@@ -4,6 +4,7 @@ os.system("pip install transformers==4.27.0")
4
  os.system("pip install numpy==1.23")
5
  from transformers import pipeline, WhisperModel, WhisperTokenizer, WhisperFeatureExtractor, AutoFeatureExtractor, AutoProcessor, WhisperConfig
6
  os.system("pip install jiwer")
 
7
  os.system("pip install datasets[audio]")
8
  from evaluate import evaluator
9
  from datasets import load_dataset, Audio, disable_caching, set_caching_enabled
@@ -15,13 +16,36 @@ huggingface_token = os.environ["huggingface_token"]
15
 
16
  model = WhisperModel.from_pretrained("mskov/whisper_miso", use_auth_token=huggingface_token)
17
  feature_extractor = AutoFeatureExtractor.from_pretrained("mskov/whisper_miso", use_auth_token=huggingface_token)
 
18
 
 
19
 
20
- ds = load_dataset("mskov/miso_test", split="test").cast_column("audio", Audio(sampling_rate=16000))
21
 
22
- print(ds, "and at 0 ", ds[0])
 
23
 
24
- inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  print("check check")
26
  print(inputs)
27
  input_features = inputs.input_features
@@ -29,3 +53,4 @@ decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
29
  last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
30
  list(last_hidden_state.shape)
31
  print(list(last_hidden_state.shape))
 
 
4
  os.system("pip install numpy==1.23")
5
  from transformers import pipeline, WhisperModel, WhisperTokenizer, WhisperFeatureExtractor, AutoFeatureExtractor, AutoProcessor, WhisperConfig
6
  os.system("pip install jiwer")
7
+ from jiwer import wer
8
  os.system("pip install datasets[audio]")
9
  from evaluate import evaluator
10
  from datasets import load_dataset, Audio, disable_caching, set_caching_enabled
 
16
 
17
  model = WhisperModel.from_pretrained("mskov/whisper_miso", use_auth_token=huggingface_token)
18
  feature_extractor = AutoFeatureExtractor.from_pretrained("mskov/whisper_miso", use_auth_token=huggingface_token)
19
+ miso_tokenizer = WhisperTokenizer.from_pretrained("mskov/whisper_miso", use_auth_token=huggingface_token)
20
 
21
+ dataset = load_dataset("mskov/miso_test", split="test").cast_column("audio", Audio(sampling_rate=16000))
22
 
23
+ print(dataset, "and at 0 ", dataset[0])
24
 
25
+ inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt")
26
+ tokenized_dataset = miso_tokenizer(dataset) # Tokenize the dataset
27
 
28
+ input_ids = features.input_ids
29
+ attention_mask = features.attention_mask
30
+
31
+ # Evaluate the model
32
+ model.eval()
33
+ with torch.no_grad():
34
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
35
+
36
+ # Convert predicted token IDs back to text
37
+ predicted_text = tokenizer.batch_decode(outputs.logits.argmax(dim=-1), skip_special_tokens=True)
38
+
39
+ # Get ground truth labels from the dataset
40
+ labels = dataset["audio"] # Replace "labels" with the appropriate key in your dataset
41
+
42
+ # Compute WER
43
+ wer_score = wer(labels, predicted_text)
44
+
45
+ # Print or return WER score
46
+ print(f"Word Error Rate (WER): {wer_score}")
47
+
48
+ '''
49
  print("check check")
50
  print(inputs)
51
  input_features = inputs.input_features
 
53
  last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
54
  list(last_hidden_state.shape)
55
  print(list(last_hidden_state.shape))
56
+ '''