"""Gradio demo: ask Phi-4-multimodal-instruct questions about an uploaded image or audio file."""

import mimetypes
import os

import gradio as gr
import soundfile as sf
from PIL import Image
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

# Define model path
model_path = "microsoft/Phi-4-multimodal-instruct"

# Load model and processor
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
# device_map="cuda" already places the weights on the GPU; no extra .cuda() needed.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="cuda",
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation='flash_attention_2',
)
generation_config = GenerationConfig.from_pretrained(model_path)

# Define prompt structure
user_prompt = '<|user|>'
assistant_prompt = '<|assistant|>'
prompt_suffix = '<|end|>'

# Phi-4-multimodal uses a modality-specific placeholder in the prompt to mark
# where the first image / audio clip is inserted.
_MEDIA_PLACEHOLDERS = {"image": "<|image_1|>", "audio": "<|audio_1|>"}

# Fallback extension lists for when mimetypes has no entry for a file
# (e.g. .webp / .flac on some Python versions or minimal containers).
_IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff"})
_AUDIO_EXTENSIONS = frozenset({".wav", ".flac", ".ogg", ".mp3"})


def _resolve_path(input_file):
    """Return the filesystem path of a gr.File value.

    Depending on the Gradio version the component yields either a path string
    or a tempfile-like object exposing the path as ``.name``; handle both.
    """
    return getattr(input_file, "name", input_file)


def _detect_media_kind(path):
    """Classify *path* as "image" or "audio" by MIME type / extension; None if neither."""
    mime, _ = mimetypes.guess_type(path)
    if mime:
        major = mime.split("/", 1)[0]
        if major in _MEDIA_PLACEHOLDERS:
            return major
    ext = os.path.splitext(path)[1].lower()
    if ext in _IMAGE_EXTENSIONS:
        return "image"
    if ext in _AUDIO_EXTENSIONS:
        return "audio"
    return None


@spaces.GPU
def process_multimodal(input_file, query):
    """Answer *query* about an uploaded image or audio file with Phi-4.

    Args:
        input_file: value of the gr.File component; None when nothing was uploaded.
        query: free-text question from the textbox.

    Returns:
        The model's decoded answer, or a user-facing message when the upload is
        missing, has an unsupported type, or cannot be decoded.
    """
    if input_file is None:
        return "Please upload an image or an audio file."

    path = _resolve_path(input_file)
    kind = _detect_media_kind(path)
    if kind is None:
        return "Unsupported file format. Please upload an image or audio file."

    prompt = f"{user_prompt}{_MEDIA_PLACEHOLDERS[kind]}{query}{prompt_suffix}{assistant_prompt}"

    # Only the decoding of the upload is guarded: a corrupt or mislabelled file
    # should produce a message instead of crashing the request.
    # (PIL.UnidentifiedImageError subclasses OSError; soundfile errors are RuntimeErrors.)
    try:
        if kind == "image":
            # convert() returns a fully loaded copy, so the file handle can be
            # closed here; RGB is the standard input for the vision encoder.
            with Image.open(path) as img:
                media_kwargs = {"images": img.convert("RGB")}
        else:
            audio, samplerate = sf.read(path)
            media_kwargs = {"audios": [(audio, samplerate)]}
    except (OSError, RuntimeError) as err:
        return f"Could not read the uploaded file: {err}"

    inputs = processor(text=prompt, return_tensors='pt', **media_kwargs).to('cuda:0')

    generate_ids = model.generate(
        **inputs,
        max_new_tokens=1000,
        generation_config=generation_config,
        num_logits_to_keep=0,
    )
    # Drop the prompt tokens so only the newly generated answer is decoded.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# Phi-4 Multimodal Chat\n"
        "Upload an image or an audio file and ask questions related to it!"
    )
    with gr.Row():
        with gr.Column():
            input_file = gr.File(label="Upload Image or Audio")
            query = gr.Textbox(label="Ask a question")
            submit_btn = gr.Button("Submit")
        with gr.Column():
            output = gr.Textbox(label="Response", interactive=False)
    submit_btn.click(process_multimodal, inputs=[input_file, query], outputs=output)

demo.launch()