Yohai Rosen committed on
Commit
2a0635e
·
1 Parent(s): 0f1045d
Files changed (3) hide show
  1. config.json +1 -0
  2. sagemaker_setup.sh +20 -0
  3. scripts/sagemaker.py +75 -0
config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"placeholder": "This is a placeholder config.json"}
sagemaker_setup.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# SageMaker container setup: install the system dependencies required by the
# inference code (ffmpeg for audio/video processing).

# Fail fast on any error, unset variable, or failed pipeline segment so a
# broken setup surfaces at image-build time instead of at inference time.
set -euo pipefail

# Update package list and install ffmpeg non-interactively.
apt-get update && apt-get install -y ffmpeg

# NOTE: the previous placeholder-config.json bootstrap that used to live here
# is superseded by the placeholder config.json checked into the repository in
# this same commit.
scripts/sagemaker.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import boto3
4
+ import torch
5
+ import argparse
6
+ import time
7
+ from omegaconf import OmegaConf
8
+
9
+ from inference import inference_process # Ensure inference.py is in the same directory or update the import path
10
+
11
def download_from_s3(s3_path, local_path):
    """Download a single object from S3 to a local file.

    Args:
        s3_path: S3 URI of the form ``s3://bucket/key``.
        local_path: Destination path on the local filesystem.
    """
    s3 = boto3.client('s3')
    # Strip only the leading scheme. The previous str.replace("s3://", "")
    # removed the substring anywhere in the URI, which would corrupt a key
    # that happens to contain "s3://".
    if s3_path.startswith("s3://"):
        s3_path = s3_path[len("s3://"):]
    bucket, key = s3_path.split("/", 1)
    s3.download_file(bucket, key, local_path)
15
+
16
def upload_to_s3(local_path, s3_path):
    """Upload a local file to S3.

    Args:
        local_path: Path of the file to upload.
        s3_path: Destination S3 URI of the form ``s3://bucket/key``.
    """
    s3 = boto3.client('s3')
    # Strip only the leading scheme. The previous str.replace("s3://", "")
    # removed the substring anywhere in the URI, which would corrupt a key
    # that happens to contain "s3://".
    if s3_path.startswith("s3://"):
        s3_path = s3_path[len("s3://"):]
    bucket, key = s3_path.split("/", 1)
    s3.upload_file(local_path, bucket, key)
20
+
21
def model_fn(model_dir):
    """SageMaker model-loading hook.

    No model object is materialized here: the heavy lifting happens inside
    ``inference_process`` at predict time, so this hook simply hands the
    model directory through to ``predict_fn``.

    Args:
        model_dir: Directory where SageMaker extracted the model artifacts
            (typically ``/opt/ml/model``).

    Returns:
        The unchanged ``model_dir`` path.
    """
    # The former commented-out placeholder-config.json bootstrap was removed;
    # a placeholder config.json is now checked into the repository instead.
    return model_dir
34
+
35
def input_fn(request_body, content_type='application/json'):
    """Deserialize an inference request body.

    Expects a JSON document with ``source_image`` and ``driving_audio``
    (local paths or ``s3://`` URIs, which are staged locally), an optional
    ``config`` dict of inference arguments, and an optional ``output`` S3
    destination.

    Returns:
        Tuple of (argparse.Namespace built from ``config``, output S3 URI
        or None).

    Raises:
        ValueError: If ``content_type`` is not ``application/json``.
    """
    if content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {content_type}")

    payload = json.loads(request_body)

    # Stage any S3-hosted inputs into the container's input directory and
    # point the payload at the local copies.
    staging = {
        'source_image': "/opt/ml/input/data/source_image.jpg",
        'driving_audio': "/opt/ml/input/data/driving_audio.wav",
    }
    for field, local_file in staging.items():
        remote = payload[field]
        if remote.startswith("s3://"):
            download_from_s3(remote, local_file)
            payload[field] = local_file

    # NOTE(review): the staged local paths are written back into the payload,
    # but only payload['config'] feeds the Namespace below — confirm that
    # inference_process actually receives the local source/audio paths.
    namespace = argparse.Namespace(**payload.get('config', {}))
    destination = payload.get('output', None)

    return namespace, destination
59
+
60
def predict_fn(input_data, model):
    """Run inference for a deserialized request.

    Args:
        input_data: ``(args, s3_output)`` tuple produced by ``input_fn``.
        model: Model directory returned by ``model_fn`` (unused here; the
            inference process loads what it needs itself).

    Returns:
        Tuple of (local path of the rendered output, requested S3
        destination or None).
    """
    namespace, destination = input_data

    # Delegate the heavy lifting; inference_process writes the rendered
    # video to the fixed location returned below.
    inference_process(namespace)

    return '.cache/output.mp4', destination
67
+
68
def output_fn(prediction, content_type='application/json'):
    """Serialize the inference result and deliver it to S3.

    Args:
        prediction: ``(local_output, s3_output)`` tuple from ``predict_fn``.
        content_type: Desired response content type (only JSON is produced).

    Returns:
        JSON string with a completion status and the S3 output location.

    Raises:
        RuntimeError: If the output file never appears within the timeout.
    """
    local_output, s3_output = prediction

    # Bounded wait for the inference process to finish writing the output
    # file. The previous unbounded loop would hang the endpoint forever if
    # inference failed to produce a file.
    deadline = time.time() + 600
    while not os.path.exists(local_output):
        if time.time() > deadline:
            raise RuntimeError(f"Timed out waiting for output file: {local_output}")
        time.sleep(1)

    # Perform the upload the original comment promised but never executed,
    # when the request asked for an S3 destination.
    if s3_output:
        upload_to_s3(local_output, s3_output)

    return json.dumps({'status': 'completed', 's3_output': s3_output})