phi4-multimodal / app.py
ariG23498's picture
ariG23498 HF staff
flash-attn
2b310a5
raw
history blame
2.4 kB
import mimetypes

import gradio as gr
import soundfile as sf
from PIL import Image
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
# Chat-template markers wrapped around the user turn when building prompts.
user_prompt = '<|user|>'
assistant_prompt = '<|assistant|>'
prompt_suffix = '<|end|>'

# Hugging Face Hub checkpoint served by this app.
model_path = "microsoft/Phi-4-multimodal-instruct"

# trust_remote_code lets the checkpoint's own processing/modelling code run.
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

# Loader options: place weights on the GPU, keep the checkpoint's dtype, and
# use the FlashAttention-2 kernels.
_model_load_kwargs = dict(
    device_map="cuda",
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)
model = AutoModelForCausalLM.from_pretrained(model_path, **_model_load_kwargs).cuda()

# Default decoding settings shipped with the checkpoint.
generation_config = GenerationConfig.from_pretrained(model_path)
def _resolve_path(input_file):
    """Return the on-disk path for an uploaded file.

    Accepts either a plain path string or a file-like wrapper that exposes
    its path as ``.name`` (presumably what ``gr.File`` hands over, depending
    on the Gradio version -- confirm against the installed release).
    """
    return getattr(input_file, "name", input_file)


def _detect_mime_type(input_file, path):
    """Return the upload's MIME type (e.g. ``"image/png"``), or ``None`` if unknown.

    A string ``.type`` attribute on the upload object is trusted first; otherwise
    the type is guessed from the file name's extension.
    """
    declared = getattr(input_file, "type", None)
    if isinstance(declared, str) and declared:
        return declared
    guessed, _encoding = mimetypes.guess_type(str(path))
    return guessed


@spaces.GPU
def process_multimodal(input_file, query):
    """Answer a question about an uploaded image or audio file with Phi-4.

    Args:
        input_file: The upload from the file component: a path string or a
            file-like wrapper exposing the path as ``.name``. ``None`` when
            nothing was uploaded.
        query: The user's question, inserted into the chat prompt.

    Returns:
        The decoded model response, or a user-facing message when no file was
        given or the file is neither an image nor an audio file.
    """
    if input_file is None:
        return "Please upload an image or an audio file."

    # Resolve the path once so both branches read the same file the same way;
    # the previous code accessed ``input_file.type``, which upload objects
    # without that attribute (e.g. path strings) do not provide.
    path = _resolve_path(input_file)
    file_type = _detect_mime_type(input_file, path) or ""

    # ``query or ''`` keeps a literal "None" out of the prompt if no text is sent.
    prompt = f"{user_prompt}<|media_1|>{query or ''}{prompt_suffix}{assistant_prompt}"

    if file_type.startswith("image/"):
        # Context manager closes the file handle once the processor has read it.
        with Image.open(path) as image:
            inputs = processor(text=prompt, images=image, return_tensors='pt').to('cuda:0')
    elif file_type.startswith("audio/"):
        audio, samplerate = sf.read(path)
        inputs = processor(text=prompt, audios=[(audio, samplerate)], return_tensors='pt').to('cuda:0')
    else:
        return "Unsupported file format. Please upload an image or audio file."

    generate_ids = model.generate(
        **inputs,
        max_new_tokens=1000,
        generation_config=generation_config,
        num_logits_to_keep=0,
    )
    # Drop the prompt tokens so only the newly generated continuation is decoded.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return response
# Introductory text rendered at the top of the page.
_INTRO_MARKDOWN = """
# Phi-4 Multimodal Chat
Upload an image or an audio file and ask questions related to it!
"""

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MARKDOWN)
    with gr.Row():
        # Left column: the upload, the question, and the trigger button.
        with gr.Column():
            input_file = gr.File(label="Upload Image or Audio")
            query = gr.Textbox(label="Ask a question")
            submit_btn = gr.Button("Submit")
        # Right column: read-only area for the model's answer.
        with gr.Column():
            output = gr.Textbox(label="Response", interactive=False)
    # Clicking Submit sends (file, question) to the model handler.
    submit_btn.click(
        fn=process_multimodal,
        inputs=[input_file, query],
        outputs=output,
    )

demo.launch()