import json
import os

import cv2
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class MultimodalRiskBehaviorModel(nn.Module):
    """Fuses BERT text features with ResNet50 frame features for binary risk classification."""

    def __init__(self, text_model_name="bert-base-uncased", hidden_dim=512, dropout=0.3):
        super(MultimodalRiskBehaviorModel, self).__init__()

        # Text model using AutoModelForSequenceClassification
        self.text_model_name = text_model_name
        self.text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=2)

        # Visual model (ResNet50) with its classification head replaced by an identity
        self.visual_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        visual_feature_dim = self.visual_model.fc.in_features
        self.visual_model.fc = nn.Identity()

        # Fusion and classification layers
        text_feature_dim = self.text_model.config.hidden_size
        self.fc1 = nn.Linear(text_feature_dim + visual_feature_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, 1)
    def forward(self, encoding, frames):
        input_ids = encoding['input_ids'].squeeze(1).to(device)
        attention_mask = encoding['attention_mask'].squeeze(1).to(device)

        # Extract text and visual features
        text_features = self.text_model(input_ids=input_ids, attention_mask=attention_mask).logits

        frames = frames.to(device)
        batch_size, num_frames, channels, height, width = frames.size()
        frames = frames.view(batch_size * num_frames, channels, height, width)
        visual_features = self.visual_model(frames)
        visual_features = visual_features.view(batch_size, num_frames, -1).mean(dim=1)

        # Combine and classify
        combined_features = torch.cat((text_features, visual_features), dim=1)
        x = self.dropout(torch.relu(self.fc1(combined_features)))
        output = torch.sigmoid(self.fc2(x))

        return output
    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))

        config = {
            "text_model_name": self.text_model_name,
            "hidden_dim": self.fc1.out_features
        }
        with open(os.path.join(save_directory, 'config.json'), 'w') as f:
            json.dump(config, f)
    @classmethod
    def from_pretrained(cls, load_directory, map_location=None):
        # Local directory: rebuild the model from config.json and load the saved weights
        if os.path.exists(load_directory):
            config_path = os.path.join(load_directory, 'config.json')
            state_dict_path = os.path.join(load_directory, 'pytorch_model.bin')

            with open(config_path, 'r') as f:
                config_dict = json.load(f)
            model = cls(text_model_name=config_dict["text_model_name"], hidden_dim=config_dict["hidden_dim"])

            state_dict = torch.load(state_dict_path, map_location=map_location)
            model.load_state_dict(state_dict)
        else:
            # Otherwise treat it as a Hugging Face Hub repo id and wrap its text model
            hf_model = AutoModelForSequenceClassification.from_pretrained(load_directory, num_labels=2)
            model = cls(text_model_name=hf_model.config.name_or_path, hidden_dim=hf_model.config.hidden_size)
            model.text_model = hf_model

        return model

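# Example (sketch, not part of the original app): round-tripping the fusion model
# through save_pretrained / from_pretrained with an illustrative local path.
#
#   model = MultimodalRiskBehaviorModel()
#   model.save_pretrained('checkpoints/risk_model')          # hypothetical directory
#   restored = MultimodalRiskBehaviorModel.from_pretrained(
#       'checkpoints/risk_model', map_location='cpu'
#   )
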
tokenizer = AutoTokenizer.from_pretrained('Souha-BH/BERT_Resnet50')
model = MultimodalRiskBehaviorModel.from_pretrained('Souha-BH/BERT_Resnet50')  # pass map_location='cpu' when running on CPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Function to load frames from a video
def load_frames_from_video(video_path, transform, num_frames=10):
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0

    while frame_count < num_frames:  # Limit to a fixed number of frames for efficiency
        success, frame = cap.read()
        if not success:
            break
        # Convert frame (NumPy array) to PIL image and apply transformations
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame = transform(frame)
        frames.append(frame)
        frame_count += 1

    cap.release()

    # Stack frames and add a batch dimension: (1, num_frames, channels, height, width)
    frames = torch.stack(frames)
    frames = frames.unsqueeze(0)
    return frames

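# Example (sketch, illustrative path): with the `transform` defined below, the loader
# returns a float tensor shaped (1, num_frames, 3, 224, 224).
#
#   clip = load_frames_from_video('sample_video.mp4', transform, num_frames=10)
#   print(clip.shape)  # torch.Size([1, 10, 3, 224, 224]) if the clip has at least 10 frames
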
def predict_video(model, video_path, text_input, tokenizer, transform):
    try:
        # Set model to evaluation mode
        model.eval()

        # Tokenize the text input
        encoding = tokenizer(
            text_input, padding='max_length', truncation=True, max_length=128, return_tensors='pt'
        )
        encoding = {key: val.to(device) for key, val in encoding.items()}

        # Load frames from the video
        frames = load_frames_from_video(video_path, transform)
        frames = frames.to(device)

        # Log input shapes and devices
        print(f"Encoding device: {next(iter(encoding.values())).device}, Frames shape: {frames.shape}")

        # Forward pass through the model
        with torch.no_grad():
            output = model(encoding, frames)

        # The model already applies sigmoid in forward, so just threshold the probability at 0.5
        prediction = (output.squeeze(-1) > 0.5).float()
        return prediction.item()
    except Exception as e:
        print(f"Prediction error: {e}")
        return "Error during prediction"

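# Example (sketch, illustrative values): calling the predictor directly with a local
# file and a caption string; it returns 1.0 (risky), 0.0 (not risky), or the error string.
#
#   result = predict_video(model, 'sample_video.mp4', 'example caption', tokenizer, transform)
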
# Resize, convert to tensor, and normalize with ImageNet statistics (matching the pretrained ResNet50)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Define your video paths and captions
video_paths = [
    'https://drive.google.com/uc?export=download&id=1iWq1q1LM-jmf4iZxOqZTw4FaIBekJowM',
    'https://drive.google.com/uc?export=download&id=1_egBaC1HD2kIZgRRKsnCtsWG94vg1c7n',
    'https://drive.google.com/uc?export=download&id=12cGxBEkfU5Q1Ezg2jRk6zGyn2hoR3JLj'
]

video_captions = [
    "Everytime i start a diet كل مرة أحاول أبدأ ريجيم #dietmemes #funnyvideos #animetiktok",
    "New sandwich from burger king #mukbang #asmr #asmrmukbang #asmrsounds #eat #food #Foodie moe eats #yummy #cheese #chicken #burger #fries #burgerking @Burger King",
    "all workout guides l!nked in bi0 // honestly huge moment I've been so focused on growing my upper body that this feels like it finally shows! shorts from @KEEPTHATPUMP #upperbody #upperbodyworkout #glutegains #glutegrowth #gluteexercise #workout #strengthtraining #gym #trending #fyp"
]

def predict_risk(video_index):
    video_path = video_paths[video_index]
    text_input = video_captions[video_index]

    # Make prediction
    prediction = predict_video(model, video_path, text_input, tokenizer, transform)

    # Return the corresponding label
    if prediction == "Error during prediction":
        return "Error during prediction"
    return "Risky Health Behavior" if prediction == 1 else "Not Risky Health Behavior"

# Interface setup
with gr.Blocks() as interface:
    gr.Markdown("# Risk Behavior Prediction")
    gr.Markdown("Select a video to classify its behavior as risky or not.")

    # Input option selector
    video_selector = gr.Radio(["Video 1", "Video 2", "Video 3"], label="Choose a Video")

    # Return the selected video's URL and caption; the gr.Video component can play a direct URL
    def show_selected_video(choice):
        idx = int(choice.split()[-1]) - 1
        return video_paths[idx], f"**Caption:** {video_captions[idx]}"

    video_player = gr.Video(width=320, height=240)
    caption_box = gr.Markdown()

    video_selector.change(
        fn=show_selected_video,
        inputs=video_selector,
        outputs=[video_player, caption_box]
    )

    # Prediction button and output
    predict_button = gr.Button("Predict Risk")
    output_text = gr.Textbox(label="Prediction")

    predict_button.click(
        fn=lambda idx: predict_risk(int(idx.split()[-1]) - 1),
        inputs=video_selector,
        outputs=output_text
    )

# Launch the app
interface.launch()
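# Note (optional): interface.launch(share=True) would additionally create a temporary
# public share link, which can be useful when testing outside the Hugging Face Space.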