# Spaces:
# Sleeping
# Sleeping
import logging
import math
import tempfile

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
# Route INFO-and-above records (such as the model-load notice and the
# per-window prediction messages) through the root logger's default handler.
logging.basicConfig(level="INFO")
class ShopliftingPrediction:
    """Annotate videos with shoplifting predictions from a Keras sequence model.

    Consecutive frames are reduced to blurred, grayscale, normalized frame
    differences (see ``Pre_Process_Video``). These are collected into
    non-overlapping windows of ``sequence_length`` frames, and each full
    window is classified by the model. The latest verdict is drawn in a white
    banner at the top of every output frame.

    One instance may be reused for several videos. The model is loaded once,
    and the on-screen message is reset at the start of each ``Predict_Video``
    call.
    """

    def __init__(self, model_path, frame_width, frame_height, sequence_length):
        """Store the configuration; no file is opened here.

        Args:
            model_path: Path of the saved Keras model to load.
            frame_width: Width, in pixels, each frame difference is resized to.
            frame_height: Height, in pixels, each frame difference is resized to.
            sequence_length: Number of preprocessed frames per prediction.
        """
        self.frame_width = frame_width
        self.frame_height = frame_height
        self.sequence_length = sequence_length
        self.model_path = model_path
        # Text drawn on output frames; updated by generate_message_content.
        self.message = ''
        # Keras model; populated by load_model (called lazily by Predict_Video).
        self.model = None

    def load_model(self):
        """Load the Keras model from ``self.model_path`` into ``self.model``."""
        # Explicit name -> class mapping for the layers and initializer the
        # saved model refers to, so deserialization resolves them to tf.keras types.
        custom_objects = {
            'Conv2D': tf.keras.layers.Conv2D,
            'MaxPooling2D': tf.keras.layers.MaxPooling2D,
            'TimeDistributed': tf.keras.layers.TimeDistributed,
            'LSTM': tf.keras.layers.LSTM,
            'Dense': tf.keras.layers.Dense,
            'Flatten': tf.keras.layers.Flatten,
            'Dropout': tf.keras.layers.Dropout,
            'Orthogonal': tf.keras.initializers.Orthogonal,
        }
        self.model = tf.keras.models.load_model(self.model_path, custom_objects=custom_objects)
        logging.info("Model loaded successfully.")

    def generate_message_content(self, probability, label):
        """Set ``self.message`` to a human-readable verdict for one prediction.

        Judging by the message texts, label 0 is the theft class and label 1
        the normal-movement class (NOTE(review): confirm against the training
        labels).

        Args:
            probability: Confidence of the predicted class, as an integer
                percentage (0-100).
            label: Predicted class index. For any label other than 0 or 1,
                ``self.message`` is left unchanged.
        """
        if label == 0:
            if probability <= 50:
                self.message = "No theft"
            elif probability <= 75:
                self.message = "There is little chance of theft"
            elif probability <= 85:
                self.message = "High probability of theft"
            else:
                self.message = "Very high probability of theft"
        elif label == 1:
            if probability <= 50:
                self.message = "No theft"
            elif probability <= 75:
                self.message = "The movement is confusing, watch"
            elif probability <= 85:
                self.message = "I think it's normal, but it's better to watch"
            else:
                self.message = "Movement is normal"

    def Pre_Process_Video(self, current_frame, previous_frame):
        """Return the normalized grayscale motion image between two BGR frames.

        Args:
            current_frame: Current 3-channel BGR frame.
            previous_frame: Preceding frame, with the same size and type.

        Returns:
            A float array of shape (frame_height, frame_width) with values
            in [0, 1].
        """
        diff = cv2.absdiff(current_frame, previous_frame)
        diff = cv2.GaussianBlur(diff, (3, 3), 0)
        # cv2.resize takes dsize as (width, height); the previous (height, width)
        # order swapped the axes whenever the two sizes differed.
        resized_frame = cv2.resize(diff, (self.frame_width, self.frame_height))
        gray_frame = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2GRAY)
        return gray_frame / 255

    def Read_Video(self, filePath):
        """Open ``filePath`` and record its frame size and frame rate.

        Sets ``video_reader``, ``original_video_width``,
        ``original_video_height`` and ``fps`` on the instance.

        Raises:
            OSError: If OpenCV cannot open the file.
        """
        self.video_reader = cv2.VideoCapture(filePath)
        if not self.video_reader.isOpened():
            raise OSError("Could not open video file: {}".format(filePath))
        self.original_video_width = int(self.video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.original_video_height = int(self.video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.fps = self.video_reader.get(cv2.CAP_PROP_FPS)

    def Single_Frame_Predict(self, frames_queue):
        """Classify one window of preprocessed frames.

        Args:
            frames_queue: Sequence of preprocessed frames, stacked into a
                single-item batch for the model.

        Returns:
            ``[probability, predicted_label]``: the winning class's probability
            as a floored integer percentage, and that class's index.
        """
        probabilities = self.model.predict(np.expand_dims(frames_queue, axis=0))[0]
        predicted_label = int(np.argmax(probabilities))
        # Take the argmax class's own probability. For two classes this equals
        # the old max(p[0], p[1]), but it stays correct for any class count.
        probability = math.floor(probabilities[predicted_label] * 100)
        return [probability, predicted_label]

    def Predict_Video(self, video_file_path, output_file_path):
        """Write an annotated copy of ``video_file_path`` to ``output_file_path``.

        Every input frame is written to the output with the current verdict
        drawn in a banner. The verdict is refreshed after each full window of
        ``sequence_length`` frame differences.

        Returns:
            ``output_file_path``.

        Raises:
            OSError: If the input cannot be opened or has no readable frame,
                or if the output writer cannot be opened.
        """
        if getattr(self, 'model', None) is None:
            # Load once; the instance is reused across calls.
            self.load_model()
        # Don't carry the previous video's verdict onto this video's frames.
        self.message = ''
        self.Read_Video(video_file_path)
        video_writer = None
        try:
            video_writer = cv2.VideoWriter(output_file_path, cv2.VideoWriter_fourcc('M', 'P', '4', 'V'),
                                           self.fps, (self.original_video_width, self.original_video_height))
            if not video_writer.isOpened():
                raise OSError("Could not open video writer for: {}".format(output_file_path))

            success, frame = self.video_reader.read()
            if not success:
                raise OSError("Could not read any frame from: {}".format(video_file_path))
            # Copy before drawing: the banner is drawn on the frame in place.
            previous = frame.copy()
            self._draw_banner(frame)
            video_writer.write(frame)

            frames_queue = []
            while self.video_reader.isOpened():
                ok, frame = self.video_reader.read()
                if not ok:
                    break
                normalized_frame = self.Pre_Process_Video(frame, previous)
                previous = frame.copy()
                frames_queue.append(normalized_frame)
                if len(frames_queue) == self.sequence_length:
                    probability, predicted_label = self.Single_Frame_Predict(frames_queue)
                    self.generate_message_content(probability, predicted_label)
                    logging.info("%s:%s%%", self.message, probability)
                    # Windows do not overlap: start collecting a fresh one.
                    frames_queue = []
                self._draw_banner(frame)
                video_writer.write(frame)
        finally:
            # Release the reader and writer even when reading or predicting fails.
            self.video_reader.release()
            if video_writer is not None:
                video_writer.release()
        return output_file_path

    def _draw_banner(self, frame):
        """Draw ``self.message`` on a white full-width banner at the top of ``frame``, in place."""
        cv2.rectangle(frame, (0, 0), (frame.shape[1], 40), (255, 255, 255), -1)
        cv2.putText(frame, self.message, (1, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
def inference(model_path, frame_width=90, frame_height=90, sequence_length=160):
    """Create the video-processing callback used by the Gradio interface.

    One ``ShopliftingPrediction`` is built here and shared by every call of
    the returned function.

    Args:
        model_path: Path of the saved Keras model.
        frame_width: Frame width the model expects (default 90, the previously
            hard-coded value).
        frame_height: Frame height the model expects (default 90).
        sequence_length: Frames per prediction window (default 160).

    Returns:
        A function ``process_video(video_path)``. It writes an annotated copy
        of the video to a fresh temporary ``.mp4`` file and returns that
        file's path.
    """
    shoplifting_prediction = ShopliftingPrediction(
        model_path, frame_width, frame_height, sequence_length=sequence_length)

    def process_video(video_path):
        # Use a unique, platform-appropriate temp file instead of the fixed
        # '/tmp/output.mp4'. That path is POSIX-only, and every call overwrote
        # the previous call's result. The file is not deleted here.
        # NOTE(review): confirm Gradio copies/cleans up returned output files.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
            output_file_path = tmp.name
        return shoplifting_prediction.Predict_Video(video_path, output_file_path)

    return process_video
# Saved Keras model consumed by the predictor that `inference` builds.
model_path = 'lrcn_160S_90_90Q.h5'
process_video = inference(model_path)

# Gradio UI: the uploaded video (presumably passed as a file path) goes to
# `process_video`, and the returned path is shown as the output video.
iface = gr.Interface(fn=process_video, inputs=gr.Video(), outputs="video", live=True)
iface.launch()