File size: 8,429 Bytes
9342c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import asyncio
import logging
import os
import shutil
import threading
from typing import Dict, Optional

import deeplabcut as dlc
import requests
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse

# Log format shared by every message this service emits.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# The FastAPI application that the endpoints below are registered on.
app = FastAPI()

# CORS: accept any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any website make credentialed cross-origin calls; narrow allow_origins if the
# API is ever protected by cookies or other credentials.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# DeepLabCut project directory and its config file.
project_path = "/app/kunin-dlc-240814"
config_path = os.path.join(project_path, "config.yaml")

# Serialises analysis jobs so only one of them uses the working directory at a time.
lock = threading.Lock()

# Base name of each produced output file -> whether it has been fetched via /post_download.
downloaded_files = {}

@app.post("/analyze")
async def analyze_video(request: Request):
    """Download a video, run the DeepLabCut pipeline on it and return download links.

    Expects a JSON object body containing ``videoUrl`` and ``videoOssId``. The
    blocking work (clearing /app/working, downloading, DLC inference, filtering,
    rendering the labeled video, renaming the outputs) runs in a worker thread
    via ``_run_analysis_pipeline``, so the event loop stays free to serve other
    requests (e.g. /post_download) while an analysis is in progress.

    Returns:
        dict with ``videoOssId`` and the ``/post_download/...`` paths of the
        labeled video and of the H5 file.

    Raises:
        HTTPException: 400 for an unparsable body or missing fields; 500 if the
            download or the analysis fails.
    """
    logging.info("Received request for video analysis")

    # Parse and validate the request *before* touching the working directory, so a
    # malformed request cannot wipe the outputs of a previous analysis.
    try:
        data: Dict = await request.json()
        logging.info(f"Request data: {data}")
    except Exception as e:
        logging.error(f"Failed to parse JSON: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")

    # A JSON array/string/number would make the membership test meaningless.
    if not isinstance(data, dict) or 'videoUrl' not in data or 'videoOssId' not in data:
        logging.error("Missing videoUrl or videoOssId in request")
        raise HTTPException(status_code=400, detail="videoUrl and videoOssId are required")

    video_url = data['videoUrl']
    video_oss_id = data['videoOssId']

    # The pipeline is synchronous and long-running, so run it in the default thread
    # pool. The module-level threading.Lock is acquired inside that worker thread:
    # waiting for it never blocks the event loop. Holding it on the loop thread
    # across an ``await`` could deadlock the whole server.
    loop = asyncio.get_running_loop()
    labeled_video_path, h5_file_path = await loop.run_in_executor(
        None, _run_analysis_pipeline, video_url, video_oss_id
    )

    response_data = {
        "videoOssId": video_oss_id,
        "labeled_video": f"/post_download/{os.path.basename(labeled_video_path)}",
        "h5_file": f"/post_download/{os.path.basename(h5_file_path)}"
    }

    logging.info("Returning response data")
    return response_data


def _run_analysis_pipeline(video_url, video_oss_id):
    """Run the blocking download + DeepLabCut pipeline and return (labeled_video_path, h5_file_path).

    Serialised by the module-level ``lock``, so only one analysis uses
    /app/working at a time. Intended to run in a worker thread.

    Raises:
        HTTPException: 500 if the download or any analysis step fails, or if the
            expected output files were not produced (download_video's own
            HTTPException is propagated unchanged).
    """
    with lock:
        # Start each job from an empty working directory.
        clear_working_directory("/app/working")

        try:
            video_path = download_video(video_url, video_oss_id)
            logging.info(f"Downloaded video to: {video_path}")
        except HTTPException:
            # download_video already chose a meaningful status/detail; don't re-wrap it.
            raise
        except Exception as e:
            logging.error(f"Error downloading video: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error downloading video: {str(e)}")

        try:
            # Run inference.
            logging.info(f"Starting analysis for video: {video_path}")
            dlc.analyze_videos(config_path, [video_path], shuffle=0, videotype="mp4", auto_track=True)
            logging.info("Video analysis completed")

            # Filter the predictions.
            logging.info(f"Filtering predictions for video: {video_path}")
            dlc.filterpredictions(config_path, [video_path], shuffle=0, videotype='mp4')
            logging.info("Predictions filtered")

            # Render the labeled video from the filtered predictions.
            logging.info(f"Creating labeled video for: {video_path}")
            dlc.create_labeled_video(
                config_path,
                [video_path],
                videotype='mp4',
                shuffle=0,
                color_by="individual",
                keypoints_only=False,
                draw_skeleton=True,
                filtered=True,
            )
            logging.info("Labeled video created")

            # Locate and rename the output files.
            labeled_video_path, h5_file_path = find_and_rename_output_files(video_path)
        except Exception as e:
            logging.error(f"Error during video analysis: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error during video analysis: {str(e)}")

        # Checked outside the try above, so this specific error is not re-wrapped
        # as "Error during video analysis: ...".
        if not labeled_video_path or not h5_file_path:
            logging.error("Output files missing after analysis")
            raise HTTPException(status_code=500, detail="Analysis completed, but output files are missing.")

        # Register both outputs as "not yet downloaded".
        downloaded_files[os.path.basename(labeled_video_path)] = False
        downloaded_files[os.path.basename(h5_file_path)] = False

        return labeled_video_path, h5_file_path

@app.post("/post_download/{filename}")
async def post_download_file(filename: str):
    """Serve a file from /app/working as an attachment and mark it as downloaded.

    Args:
        filename: Bare file name inside /app/working (as returned by /analyze).

    Returns:
        A FileResponse streaming the file as application/octet-stream.

    Raises:
        HTTPException: 404 if ``filename`` is not a plain file name or does not
            name a regular file in the working directory.
    """
    file_path = os.path.join("/app/working", filename)
    # Require a bare name that refers to a regular file. "." or ".." resolve to
    # directories, which os.path.exists accepts but FileResponse is not meant to serve.
    if os.path.basename(filename) != filename or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    # Mark the file as downloaded. NOTE: this happens when the response object is
    # created, before the body has actually been streamed to the client.
    downloaded_files[filename] = True
    logging.info(f"Serving file: {file_path}")
    return FileResponse(path=file_path, media_type='application/octet-stream', filename=filename)

async def wait_for_files_to_be_downloaded(
    files: list,
    timeout: Optional[float] = None,
    poll_interval: float = 5.0,
):
    """Wait until every file in ``files`` is marked as downloaded, then delete them.

    Download state is read from the module-level ``downloaded_files`` dict,
    keyed by each file's base name.

    Args:
        files: Paths of the files to wait for and then delete.
        timeout: Maximum number of seconds to wait. ``None`` (the default) waits
            indefinitely. On timeout, a warning is logged and the files are left
            in place.
        poll_interval: Seconds to sleep between checks (default 5).

    Errors are logged, not raised.
    """
    names = [os.path.basename(file) for file in files]
    loop = asyncio.get_running_loop()
    deadline = None if timeout is None else loop.time() + timeout
    try:
        while not all(downloaded_files.get(name, False) for name in names):
            if deadline is not None and loop.time() >= deadline:
                logging.warning(f"Timed out waiting for files to be downloaded: {files}")
                return
            logging.info(f"Waiting for files to be downloaded: {files}")
            await asyncio.sleep(poll_interval)
        logging.info("All files have been downloaded, deleting them...")
        for file, name in zip(files, names):
            if os.path.exists(file):
                os.remove(file)
                logging.info(f"Deleted file: {file}")
            # Forget the tracking entry so downloaded_files does not grow without bound.
            downloaded_files.pop(name, None)
        logging.info("All files have been deleted.")
    except Exception as e:
        logging.error(f"Failed to wait for files: {str(e)}")

def find_and_rename_output_files(video_path: str, working_directory: str = "/app/working"):
    """Find the labeled video and filtered H5 produced for a video and give them stable names.

    Scans ``working_directory`` for a file ending in ``_id_labeled.mp4``, which
    is renamed to ``<base>_labeled.mp4``, and one ending in ``_filtered.h5``,
    which is renamed to ``<base>.h5``. ``<base>`` is the file name of
    ``video_path`` without its extension.

    Args:
        video_path: Path of the analysed video; only its base name is used.
        working_directory: Directory to scan and rename within
            (default "/app/working").

    Returns:
        ``(labeled_video_path, h5_file_path)``; either element is ``None`` if no
        matching file was found.
    """
    base_name = os.path.splitext(os.path.basename(video_path))[0]

    labeled_video = None
    h5_file = None

    # Sorted so the outcome is deterministic when several files match the same
    # suffix (os.listdir order is arbitrary). Every match is renamed onto the same
    # target path, so on POSIX the last match in sorted order ends up there.
    for file in sorted(os.listdir(working_directory)):
        if file.endswith("_id_labeled.mp4"):
            new_labeled_video = os.path.join(working_directory, f"{base_name}_labeled.mp4")
            os.rename(os.path.join(working_directory, file), new_labeled_video)
            labeled_video = new_labeled_video
            logging.info(f"Renamed labeled video to: {labeled_video}")
        elif file.endswith("_filtered.h5"):
            new_h5_file = os.path.join(working_directory, f"{base_name}.h5")
            os.rename(os.path.join(working_directory, file), new_h5_file)
            h5_file = new_h5_file
            logging.info(f"Renamed H5 file to: {h5_file}")

    logging.info(f"Files in working directory after video processing: {os.listdir(working_directory)}")
    return labeled_video, h5_file

def download_video(url: str, video_oss_id: str) -> str:
    """Download ``url`` into /app/working as ``<video_oss_id>.mp4`` and return the local path.

    The response is streamed to disk in 8 KiB chunks with a 60-second timeout.

    Args:
        url: Address of the video to fetch.
        video_oss_id: Identifier used as the local file name. It comes from the
            request body, so it must be a bare file-name component.

    Returns:
        Path of the downloaded file.

    Raises:
        HTTPException: 400 if ``video_oss_id`` is empty or contains a path
            separator or NUL byte; 500 if the download fails (any partially
            written file is removed first).
    """
    working_directory = "/app/working"

    # Refuse ids such as "../../etc/x" or "/tmp/x" that would place the file outside
    # the working directory (os.path.join drops the base for absolute components).
    file_stem = str(video_oss_id)
    if not file_stem or "\x00" in file_stem or os.path.basename(file_stem) != file_stem:
        logging.error(f"Rejected unsafe videoOssId: {video_oss_id!r}")
        raise HTTPException(status_code=400, detail="videoOssId must be a plain file name")

    # Name the file after video_oss_id to avoid collisions between videos.
    local_filename = os.path.join(working_directory, f"{file_stem}.mp4")

    try:
        os.makedirs(working_directory, exist_ok=True)

        logging.info(f"Downloading video from URL: {url}")
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(local_filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        logging.info(f"Video downloaded to: {local_filename}")
        return local_filename

    except requests.exceptions.RequestException as e:
        logging.error(f"Failed to download video: {str(e)}")
        _remove_partial_download(local_filename)
        raise HTTPException(status_code=500, detail=f"Failed to download video: {str(e)}") from e
    except Exception as e:
        logging.error(f"Unexpected error: {str(e)}")
        _remove_partial_download(local_filename)
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") from e


def _remove_partial_download(path: str) -> None:
    """Best-effort removal of a partially written download; failures are ignored."""
    try:
        os.remove(path)
    except (OSError, ValueError):
        pass

def clear_working_directory(directory: str):
    """Delete everything inside ``directory``; the directory itself is kept.

    Best effort: a missing directory counts as already clear, and a failure on
    one entry is logged without stopping the removal of the others. Nothing is
    raised.

    Args:
        directory: Directory whose contents should be removed.
    """
    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        logging.info(f"Directory does not exist, nothing to clear: {directory}")
        return
    except OSError as e:
        logging.error(f"Failed to clear directory {directory}: {str(e)}")
        return

    failures = 0
    for filename in entries:
        file_path = os.path.join(directory, filename)
        try:
            # Check islink first so a symlink to a directory is unlinked, not followed.
            if os.path.islink(file_path) or os.path.isfile(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                # rmtree, unlike os.rmdir, also removes non-empty subdirectories.
                shutil.rmtree(file_path)
        except OSError as e:
            failures += 1
            logging.error(f"Failed to remove {file_path}: {str(e)}")

    if failures:
        logging.error(f"Failed to clear directory {directory}: {failures} entries could not be removed")
    else:
        logging.info(f"Cleared all files in directory: {directory}")

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8080.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8080)