File size: 8,429 Bytes
9342c6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import logging
import asyncio
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
import deeplabcut as dlc
import os
import requests
from typing import Dict
import threading
# Create the FastAPI application.
app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): a "*" origin combined with allow_credentials=True lets any web
# page send credentialed requests to this service — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# DeepLabCut project directory and the config.yaml inside it.
project_path = "/app/kunin-dlc-240814"
config_path = os.path.join(project_path, "config.yaml")

# Root logger: INFO and above, formatted as "time - level - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Lock intended to let only one analysis job run at a time.
lock = threading.Lock()

# Output file basename -> whether it has been fetched via /post_download.
downloaded_files: Dict[str, bool] = dict()
@app.post("/analyze")
async def analyze_video(request: Request):
    """Download a video, run the DeepLabCut pipeline on it and return download links.

    The JSON body must be an object with ``videoUrl`` (where to fetch the
    video) and ``videoOssId`` (used to name the local copy). The video is
    analysed, its predictions are filtered and a labeled video is rendered.
    The renamed outputs are then exposed through ``/post_download/{filename}``.

    Returns:
        dict with ``videoOssId``, ``labeled_video`` and ``h5_file``; the last
        two are ``/post_download/...`` URL paths.

    Raises:
        HTTPException: 400 if the body is not valid JSON or lacks a required
            key; 500 if the download or any analysis step fails.
    """
    logging.info("Received request for video analysis")

    # Validate the request before taking the lock or touching the working
    # directory. A malformed request then neither queues behind a running job
    # nor wipes the outputs of a previous one.
    try:
        data: Dict = await request.json()
        logging.info(f"Request data: {data}")
    except Exception as e:
        logging.error(f"Failed to parse JSON: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") from e
    # A JSON array/string/number is valid JSON but not a usable request.
    if not isinstance(data, dict) or 'videoUrl' not in data or 'videoOssId' not in data:
        logging.error("Missing videoUrl or videoOssId in request")
        raise HTTPException(status_code=400, detail="videoUrl and videoOssId are required")
    video_url = data['videoUrl']
    video_oss_id = data['videoOssId']

    # Download and DeepLabCut inference are blocking calls, so they run in the
    # default thread pool instead of on the event-loop thread. The threading
    # lock is acquired inside that worker thread. Acquiring it on the event-loop
    # thread freezes the whole server, and deadlocks once a second request
    # arrives while the first one is suspended in an ``await``.
    loop = asyncio.get_running_loop()
    labeled_video_path, h5_file_path = await loop.run_in_executor(
        None, _run_analysis_pipeline, video_url, video_oss_id
    )

    response_data = {
        "videoOssId": video_oss_id,
        "labeled_video": f"/post_download/{os.path.basename(labeled_video_path)}",
        "h5_file": f"/post_download/{os.path.basename(h5_file_path)}"
    }
    logging.info("Returning response data")
    return response_data


def _run_analysis_pipeline(video_url: str, video_oss_id: str):
    """Run one download + analysis job while holding ``lock``.

    Blocking; meant to be called from a worker thread.

    Returns:
        ``(labeled_video_path, h5_file_path)`` after renaming.

    Raises:
        HTTPException: 500 if the download or any analysis step fails.
    """
    with lock:
        # Start every job from an empty working directory.
        clear_working_directory("/app/working")

        try:
            video_path = download_video(video_url, video_oss_id)
            logging.info(f"Downloaded video to: {video_path}")
        except HTTPException:
            # download_video already logs and builds a descriptive error.
            # Re-wrapping it would produce "Error downloading video: 500: ...".
            raise
        except Exception as e:
            logging.error(f"Error downloading video: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error downloading video: {str(e)}") from e

        try:
            # Run inference.
            logging.info(f"Starting analysis for video: {video_path}")
            dlc.analyze_videos(config_path, [video_path], shuffle=0, videotype="mp4", auto_track=True)
            logging.info("Video analysis completed")

            # Filter the predictions.
            logging.info(f"Filtering predictions for video: {video_path}")
            dlc.filterpredictions(config_path, [video_path], shuffle=0, videotype='mp4')
            logging.info("Predictions filtered")

            # Render the labeled video from the filtered predictions.
            logging.info(f"Creating labeled video for: {video_path}")
            dlc.create_labeled_video(
                config_path,
                [video_path],
                videotype='mp4',
                shuffle=0,
                color_by="individual",
                keypoints_only=False,
                draw_skeleton=True,
                filtered=True,
            )
            logging.info("Labeled video created")

            # Locate the outputs and give them stable names.
            labeled_video_path, h5_file_path = find_and_rename_output_files(video_path)
            if not labeled_video_path or not h5_file_path:
                logging.error("Output files missing after analysis")
                raise HTTPException(status_code=500, detail="Analysis completed, but output files are missing.")

            # Mark both outputs as not yet downloaded.
            downloaded_files[os.path.basename(labeled_video_path)] = False
            downloaded_files[os.path.basename(h5_file_path)] = False
        except HTTPException:
            # Keep our own error intact instead of nesting it in the generic message.
            raise
        except Exception as e:
            logging.error(f"Error during video analysis: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error during video analysis: {str(e)}") from e

    return labeled_video_path, h5_file_path
@app.post("/post_download/{filename}")
async def post_download_file(filename: str):
    """Serve an output file from ``/app/working`` and mark it as downloaded.

    Args:
        filename: Bare name of a file in the working directory, as returned
            by ``/analyze``.

    Returns:
        A ``FileResponse`` streaming the file as ``application/octet-stream``.

    Raises:
        HTTPException: 404 if ``filename`` contains a path component or does
            not name an existing regular file.
    """
    file_path = os.path.join("/app/working", filename)
    # Accept only bare names that resolve to regular files. "..", "." or any
    # value with a path component would otherwise point outside the working
    # directory or at a directory, which FileResponse cannot serve.
    if os.path.basename(filename) != filename or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    # Record that the file has been fetched.
    downloaded_files[filename] = True
    logging.info(f"Serving file: {file_path}")
    return FileResponse(path=file_path, media_type='application/octet-stream', filename=filename)
async def wait_for_files_to_be_downloaded(files: list, poll_interval: float = 5, timeout=None):
    """Wait until every file in *files* has been downloaded, then delete them.

    A file counts as downloaded once ``downloaded_files`` maps its basename to
    a true value. Exceptions raised while waiting or deleting are logged
    rather than propagated.

    Args:
        files: Paths of the files to watch and later delete.
        poll_interval: Seconds between checks (default 5).
        timeout: Maximum number of seconds to wait. ``None`` (the default)
            waits indefinitely. When it expires, waiting stops and the files
            that still exist are deleted anyway.
    """
    try:
        loop = asyncio.get_running_loop()
        deadline = None if timeout is None else loop.time() + timeout
        timed_out = False
        while any(not downloaded_files.get(os.path.basename(file), False) for file in files):
            if deadline is not None and loop.time() >= deadline:
                timed_out = True
                break
            logging.info(f"Waiting for files to be downloaded: {files}")
            await asyncio.sleep(poll_interval)

        if timed_out:
            logging.warning(f"Timed out waiting for files to be downloaded, deleting them anyway: {files}")
        else:
            logging.info("All files have been downloaded, deleting them...")

        # Delete each file independently so that one failure does not leave
        # the remaining files behind.
        all_deleted = True
        for file in files:
            try:
                if os.path.exists(file):
                    os.remove(file)
                    logging.info(f"Deleted file: {file}")
            except OSError as e:
                all_deleted = False
                logging.error(f"Failed to delete file {file}: {str(e)}")
        if all_deleted:
            logging.info("All files have been deleted.")
    except Exception as e:
        logging.error(f"Failed to wait for files: {str(e)}")
def find_and_rename_output_files(video_path: str, working_directory: str = "/app/working"):
    """Find the labeled video and filtered H5 file and rename them after the video.

    In *working_directory*:

    - a file ending in ``_id_labeled.mp4`` is renamed to ``<base>_labeled.mp4``;
    - a file ending in ``_filtered.h5`` is renamed to ``<base>.h5``.

    Here ``<base>`` is the file name of *video_path* without its extension.

    Args:
        video_path: Path of the analysed video; only its base name is used.
        working_directory: Directory to search (default ``/app/working``).

    Returns:
        ``(labeled_video, h5_file)``: the new paths, or ``None`` for a file
        that was not found.
    """
    base_name = os.path.splitext(os.path.basename(video_path))[0]
    labeled_video = None
    h5_file = None

    # Sorted so that, if several files match, the result is deterministic
    # rather than dependent on os.listdir order.
    for file in sorted(os.listdir(working_directory)):
        source = os.path.join(working_directory, file)
        if not os.path.isfile(source):
            continue
        if file.endswith("_id_labeled.mp4"):
            labeled_video = os.path.join(working_directory, f"{base_name}_labeled.mp4")
            os.rename(source, labeled_video)
            logging.info(f"Renamed labeled video to: {labeled_video}")
        elif file.endswith("_filtered.h5"):
            h5_file = os.path.join(working_directory, f"{base_name}.h5")
            os.rename(source, h5_file)
            logging.info(f"Renamed H5 file to: {h5_file}")

    logging.info(f"Files in working directory after video processing: {os.listdir(working_directory)}")
    return labeled_video, h5_file
def download_video(url: str, video_oss_id: str) -> str:
    """Download the video at *url* into ``/app/working`` and return its local path.

    The file is saved as ``<video_oss_id>.mp4``, and the directory is created
    if needed. If the download fails, any partially written file is removed.

    Args:
        url: URL of the video.
        video_oss_id: Identifier used as the file name; must not contain a
            path separator.

    Returns:
        Path of the downloaded file.

    Raises:
        HTTPException: 400 if *video_oss_id* contains a path component; 500 if
            the request fails or the file cannot be written.
    """
    working_directory = "/app/working"

    # video_oss_id comes from the request body. A value such as "../../x"
    # would otherwise write outside the working directory.
    file_stem = f"{video_oss_id}"
    if os.path.basename(file_stem) != file_stem:
        logging.error(f"Rejected unsafe videoOssId: {file_stem!r}")
        raise HTTPException(status_code=400, detail="videoOssId must not contain path separators")

    # Name the file after video_oss_id to avoid naming conflicts.
    local_filename = os.path.join(working_directory, f"{file_stem}.mp4")
    try:
        os.makedirs(working_directory, exist_ok=True)

        logging.info(f"Downloading video from URL: {url}")
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(local_filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        logging.info(f"Video downloaded to: {local_filename}")
        return local_filename
    except requests.exceptions.RequestException as e:
        _discard_partial_download(local_filename)
        logging.error(f"Failed to download video: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to download video: {str(e)}") from e
    except Exception as e:
        _discard_partial_download(local_filename)
        logging.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") from e


def _discard_partial_download(path: str) -> None:
    """Best-effort removal of an incomplete download at *path*; errors are only logged."""
    try:
        if os.path.exists(path):
            os.remove(path)
    except OSError as e:
        logging.error(f"Failed to remove partial download {path}: {str(e)}")
def clear_working_directory(directory: str):
    """Delete every entry inside *directory*, keeping the directory itself.

    Files and symlinks are unlinked, and sub-directories are removed
    recursively. Failures are logged per entry and never raised, so one
    undeletable entry does not stop the rest from being cleared. A directory
    that does not exist is treated as already clear.
    """
    import shutil  # local import: only this function needs it

    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        logging.info(f"Directory does not exist, nothing to clear: {directory}")
        return
    except OSError as e:
        logging.error(f"Failed to clear directory {directory}: {str(e)}")
        return

    failed = False
    for filename in entries:
        file_path = os.path.join(directory, filename)
        try:
            # Check links first so a symlink to a directory is unlinked, not followed.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                # os.rmdir would fail on any non-empty sub-directory.
                shutil.rmtree(file_path)
        except OSError as e:
            failed = True
            logging.error(f"Failed to remove {file_path}: {str(e)}")

    if failed:
        logging.error(f"Failed to clear directory {directory}: some entries could not be removed")
    else:
        logging.info(f"Cleared all files in directory: {directory}")
if __name__ == "__main__":
    # Script entry point: serve the API with uvicorn on all interfaces, port 8080.
    # uvicorn is imported here so that importing this module does not need it.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8080)
|