Commit a629fc9 · Parent(s): ad51a72

Update whisper/inference.py

Files changed: whisper/inference.py (+36 -1)

whisper/inference.py CHANGED
@@ -3,6 +3,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 import numpy as np
 import argparse
 import torch
+import requests
+from tqdm import tqdm
 
 from whisper.model import Whisper, ModelDimensions
 from whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram
@@ -29,6 +31,37 @@ def load_model(path, device) -> Whisper:
     return model
 
 
+def check_and_download_model():
+    temp_dir = "/tmp"
+    model_path = os.path.join(temp_dir, "large-v2.pt")
+
+    if os.path.exists(model_path):
+        return f"Model already exists: {model_path}"
+
+    url = "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt"
+
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+        total_size = int(response.headers.get('content-length', 0))
+
+        with open(model_path, 'wb') as f, tqdm(
+            desc=model_path,
+            total=total_size,
+            unit='iB',
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as pbar:
+            for data in response.iter_content(chunk_size=1024):
+                size = f.write(data)
+                pbar.update(size)
+
+        return f"Model download finished: {model_path}"
+
+    except Exception as e:
+        return f"An error occurred: {e}"
+
+
 def pred_ppg(whisper: Whisper, wavPath, ppgPath, device):
     audio = load_audio(wavPath)
     audln = audio.shape[0]
@@ -74,5 +107,7 @@ if __name__ == "__main__":
     ppgPath = args.ppg
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-
+
+    _ = check_and_download_model()
+    whisper = load_model("/tmp/large-v2.pt", device)
     pred_ppg(whisper, wavPath, ppgPath, device)
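
In sum, the commit adds two imports (requests and tqdm), a check_and_download_model() helper that streams the Whisper large-v2 checkpoint into /tmp with a tqdm progress bar, and two lines in the __main__ block that call the helper and then load the model from /tmp/large-v2.pt. A minimal sketch of exercising the new flow from a Python shell, assuming the repository root is on sys.path so that whisper.inference is importable as the file names suggest:

# Hypothetical usage sketch, not part of the commit: call the new helper and the
# load_model(path, device) entry point shown in the diff directly.
import os
import torch

from whisper.inference import check_and_download_model, load_model  # import path is an assumption

print(check_and_download_model())            # streams /tmp/large-v2.pt on first use, returns early afterwards
assert os.path.exists("/tmp/large-v2.pt")

device = "cuda" if torch.cuda.is_available() else "cpu"
whisper = load_model("/tmp/large-v2.pt", device)   # same call the updated __main__ block makes

Note that the helper returns a status string instead of raising, so the __main__ block discards it with _ = check_and_download_model(); a failed download would only surface later, when load_model() cannot open /tmp/large-v2.pt.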
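
The early return also means an interrupted, truncated download left in /tmp would be reused on the next run. One possible hardening, relying on the convention that openai/whisper checkpoint URLs carry the file's SHA-256 as their second-to-last path segment, is a small verification step; sha256_matches below is hypothetical and not part of the commit:

# Hypothetical integrity check (not in the commit): compare the downloaded file
# against the SHA-256 embedded in the checkpoint URL before trusting a cached copy.
import hashlib

def sha256_matches(model_path: str, url: str) -> bool:
    expected = url.split("/")[-2]          # "81f7c96c852e..." for the large-v2.pt URL above
    digest = hashlib.sha256()
    with open(model_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            digest.update(chunk)
    return digest.hexdigest() == expected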