Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -2,6 +2,22 @@ import gradio as gr | |
| 2 | 
             
            from gradio_client import Client
         | 
| 3 | 
             
            import json
         | 
| 4 | 
             
            import re
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 5 |  | 
| 6 | 
             
            def get_caption_from_kosmos(image_in):
         | 
| 7 | 
             
                kosmos2_client = Client("https://ydshieh-kosmos-2.hf.space/")
         | 
| @@ -75,7 +91,7 @@ def get_magnet(prompt): | |
| 75 | 
             
                    api_name="/predict_full"
         | 
| 76 | 
             
                )
         | 
| 77 | 
             
                print(result)
         | 
| 78 | 
            -
                return result[ | 
| 79 |  | 
| 80 | 
             
            def get_audioldm(prompt):
         | 
| 81 | 
             
                client = Client("https://haoheliu-audioldm2-text2audio-text2music.hf.space/")
         | 
| @@ -89,13 +105,29 @@ def get_audioldm(prompt): | |
| 89 | 
             
                    fn_index=1
         | 
| 90 | 
             
                )
         | 
| 91 | 
             
                print(result)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 92 | 
             
                return result
         | 
| 93 |  | 
| 94 | 
            -
            def infer(image_in):
         | 
| 95 | 
             
                caption = get_caption(image_in)
         | 
| 96 | 
            -
                 | 
| 97 | 
            -
             | 
| 98 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 99 |  | 
| 100 | 
             
            css="""
         | 
| 101 | 
             
            #col-container{
         | 
| @@ -117,13 +149,15 @@ with gr.Blocks(css=css) as demo: | |
| 117 |  | 
| 118 | 
             
                    with gr.Column():
         | 
| 119 | 
             
                        image_in = gr.Image(sources=["upload"], type="filepath", label="Image input", value="oiseau.png")
         | 
|  | |
| 120 | 
             
                        submit_btn = gr.Button("Submit")
         | 
| 121 | 
             
                    with gr.Row():
         | 
| 122 | 
            -
                         | 
| 123 | 
            -
             | 
| 124 | 
             
                submit_btn.click(
         | 
| 125 | 
             
                    fn=infer,
         | 
| 126 | 
            -
                    inputs=[image_in],
         | 
| 127 | 
            -
                    outputs=[ | 
| 128 | 
             
                )
         | 
|  | |
| 129 | 
             
            demo.queue(max_size=10).launch(debug=True)
         | 
|  | |
| 2 | 
             
            from gradio_client import Client
         | 
| 3 | 
             
            import json
         | 
| 4 | 
             
            import re
         | 
| 5 | 
            +
            from moviepy.editor import VideoFileClip
         | 
| 6 | 
            +
            from moviepy.audio.AudioClip import AudioClip
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            def extract_audio(video_in):
         | 
| 9 | 
            +
                input_video = video_in
         | 
| 10 | 
            +
                output_audio = 'audio.wav'
         | 
| 11 | 
            +
                
         | 
| 12 | 
            +
                # Open the video file and extract the audio
         | 
| 13 | 
            +
                video_clip = VideoFileClip(input_video)
         | 
| 14 | 
            +
                audio_clip = video_clip.audio
         | 
| 15 | 
            +
                
         | 
| 16 | 
            +
                # Save the audio as a .wav file
         | 
| 17 | 
            +
                audio_clip.write_audiofile(output_audio, fps=44100)  # Use 44100 Hz as the sample rate for .wav files  
         | 
| 18 | 
            +
                print("Audio extraction complete.")
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                return 'audio.wav'
         | 
| 21 |  | 
| 22 | 
             
            def get_caption_from_kosmos(image_in):
         | 
| 23 | 
             
                kosmos2_client = Client("https://ydshieh-kosmos-2.hf.space/")
         | 
|  | |
| 91 | 
             
                    api_name="/predict_full"
         | 
| 92 | 
             
                )
         | 
| 93 | 
             
                print(result)
         | 
| 94 | 
            +
                return result[1]
         | 
| 95 |  | 
| 96 | 
             
            def get_audioldm(prompt):
         | 
| 97 | 
             
                client = Client("https://haoheliu-audioldm2-text2audio-text2music.hf.space/")
         | 
|  | |
| 105 | 
             
                    fn_index=1
         | 
| 106 | 
             
                )
         | 
| 107 | 
             
                print(result)
         | 
| 108 | 
            +
                audio_result = extract_audio(result)
         | 
| 109 | 
            +
                return audio_result
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            def get_audiogen(prompt):
         | 
| 112 | 
            +
                client = Client("https://fffiloni-audiogen.hf.space/")
         | 
| 113 | 
            +
                result = client.predict(
         | 
| 114 | 
            +
                    prompt,
         | 
| 115 | 
            +
                    10,
         | 
| 116 | 
            +
                    api_name="/infer"
         | 
| 117 | 
            +
                )
         | 
| 118 | 
             
                return result
         | 
| 119 |  | 
| 120 | 
            +
            def infer(image_in, chosen_model):
         | 
| 121 | 
             
                caption = get_caption(image_in)
         | 
| 122 | 
            +
                if chosen_model == "MAGNet" :
         | 
| 123 | 
            +
                    magnet_result = get_magnet(caption)
         | 
| 124 | 
            +
                    return magnet_result
         | 
| 125 | 
            +
                elif chosen_model == "AudioLDM-2" : 
         | 
| 126 | 
            +
                    audioldm_result = get_audioldm(caption)
         | 
| 127 | 
            +
                    return audioldm_result
         | 
| 128 | 
            +
                elif chosen_model == "AudioGen" :
         | 
| 129 | 
            +
                    audiogen_result = get_audiogen(caption)
         | 
| 130 | 
            +
                    return audiogen_result
         | 
| 131 |  | 
| 132 | 
             
            css="""
         | 
| 133 | 
             
            #col-container{
         | 
|  | |
| 149 |  | 
| 150 | 
             
                    with gr.Column():
         | 
| 151 | 
             
                        image_in = gr.Image(sources=["upload"], type="filepath", label="Image input", value="oiseau.png")
         | 
| 152 | 
            +
                        chosen_model = gr.Radio(label="Choose a model", choices=["MAGNet", "AudioLDM-2", "AudioGen"], value="AudioLDM-2")
         | 
| 153 | 
             
                        submit_btn = gr.Button("Submit")
         | 
| 154 | 
             
                    with gr.Row():
         | 
| 155 | 
            +
                        audio_o = gr.Audio(label="Audio output")
         | 
| 156 | 
            +
                
         | 
| 157 | 
             
                submit_btn.click(
         | 
| 158 | 
             
                    fn=infer,
         | 
| 159 | 
            +
                    inputs=[image_in, chosen_model],
         | 
| 160 | 
            +
                    outputs=[audio_o]
         | 
| 161 | 
             
                )
         | 
| 162 | 
            +
             | 
| 163 | 
             
            demo.queue(max_size=10).launch(debug=True)
         | 
