Yohai Rosen
commited on
Commit
·
2a0635e
1
Parent(s):
0f1045d
test
Browse files- config.json +1 -0
- sagemaker_setup.sh +20 -0
- scripts/sagemaker.py +75 -0
config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"placeholder": "This is a placeholder config.json"}
|
sagemaker_setup.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Update package list and install ffmpeg
|
4 |
+
apt-get update && apt-get install -y ffmpeg
|
5 |
+
|
6 |
+
# # Ensure the model directory and config.json file exist
|
7 |
+
# MODEL_DIR="/opt/ml/model"
|
8 |
+
# CONFIG_FILE="${MODEL_DIR}/config.json"
|
9 |
+
|
10 |
+
# # Ensure the model directory exists
|
11 |
+
# mkdir -p ${MODEL_DIR}
|
12 |
+
|
13 |
+
# # Create a placeholder config.json if it does not exist
|
14 |
+
# if [ ! -f ${CONFIG_FILE} ]; then
|
15 |
+
# echo "Creating placeholder config.json in ${MODEL_DIR}"
|
16 |
+
# echo '{"placeholder": "This is a placeholder config.json"}' > ${CONFIG_FILE}
|
17 |
+
# fi
|
18 |
+
|
19 |
+
# echo "Initialization completed. Model directory contents:"
|
20 |
+
# ls -l ${MODEL_DIR}
|
scripts/sagemaker.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import boto3
|
4 |
+
import torch
|
5 |
+
import argparse
|
6 |
+
import time
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
|
9 |
+
from inference import inference_process # Ensure inference.py is in the same directory or update the import path
|
10 |
+
|
11 |
+
def download_from_s3(s3_path, local_path):
|
12 |
+
s3 = boto3.client('s3')
|
13 |
+
bucket, key = s3_path.replace("s3://", "").split("/", 1)
|
14 |
+
s3.download_file(bucket, key, local_path)
|
15 |
+
|
16 |
+
def upload_to_s3(local_path, s3_path):
|
17 |
+
s3 = boto3.client('s3')
|
18 |
+
bucket, key = s3_path.replace("s3://", "").split("/", 1)
|
19 |
+
s3.upload_file(local_path, bucket, key)
|
20 |
+
|
21 |
+
def model_fn(model_dir):
|
22 |
+
# config_path = os.path.join(model_dir, 'config.json')
|
23 |
+
|
24 |
+
# # Create a placeholder config.json if it does not exist
|
25 |
+
# if not os.path.exists(config_path):
|
26 |
+
# print(f"config.json not found in {model_dir}. Creating a placeholder config.json.")
|
27 |
+
# config_content = {
|
28 |
+
# "placeholder": "This is a placeholder config.json"
|
29 |
+
# }
|
30 |
+
# with open(config_path, 'w') as config_file:
|
31 |
+
# json.dump(config_content, config_file)
|
32 |
+
|
33 |
+
return model_dir
|
34 |
+
|
35 |
+
def input_fn(request_body, content_type='application/json'):
|
36 |
+
if content_type == 'application/json':
|
37 |
+
input_data = json.loads(request_body)
|
38 |
+
|
39 |
+
# Download source_image and driving_audio from S3 if necessary
|
40 |
+
source_image_path = input_data['source_image']
|
41 |
+
driving_audio_path = input_data['driving_audio']
|
42 |
+
|
43 |
+
local_source_image = "/opt/ml/input/data/source_image.jpg"
|
44 |
+
local_driving_audio = "/opt/ml/input/data/driving_audio.wav"
|
45 |
+
|
46 |
+
if source_image_path.startswith("s3://"):
|
47 |
+
download_from_s3(source_image_path, local_source_image)
|
48 |
+
input_data['source_image'] = local_source_image
|
49 |
+
if driving_audio_path.startswith("s3://"):
|
50 |
+
download_from_s3(driving_audio_path, local_driving_audio)
|
51 |
+
input_data['driving_audio'] = local_driving_audio
|
52 |
+
|
53 |
+
args = argparse.Namespace(**input_data.get('config', {}))
|
54 |
+
s3_output = input_data.get('output', None)
|
55 |
+
|
56 |
+
return args, s3_output
|
57 |
+
else:
|
58 |
+
raise ValueError(f"Unsupported content type: {content_type}")
|
59 |
+
|
60 |
+
def predict_fn(input_data, model):
|
61 |
+
args, s3_output = input_data
|
62 |
+
|
63 |
+
# Call the inference process
|
64 |
+
inference_process(args)
|
65 |
+
|
66 |
+
return '.cache/output.mp4', s3_output
|
67 |
+
|
68 |
+
def output_fn(prediction, content_type='application/json'):
|
69 |
+
local_output, s3_output = prediction
|
70 |
+
|
71 |
+
# Wait for the output file to be created and upload it to S3
|
72 |
+
while not os.path.exists(local_output):
|
73 |
+
time.sleep(1)
|
74 |
+
|
75 |
+
return json.dumps({'status': 'completed', 's3_output': s3_output})
|