File size: 3,522 Bytes
7b1499d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import os

import torch
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
# SageMaker passes the model directory to model_fn (by default /opt/ml/model).
def model_fn(model_dir):
    """Load the tokenizer and seq2seq model from ``model_dir``.

    The architecture is rebuilt from ``config.json`` and the trained weights
    are restored from the ``pytorch_model.pth`` state dict in the same
    directory.

    Args:
        model_dir: Directory holding the tokenizer files, ``config.json`` and
            ``pytorch_model.pth``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is on the CPU and in
        evaluation mode.

    Raises:
        FileNotFoundError: If ``pytorch_model.pth`` or ``config.json`` is
            missing from ``model_dir``.
    """
    print("Loading model and tokenizer...")
    model_path = os.path.join(model_dir, "pytorch_model.pth")
    config_path = os.path.join(model_dir, "config.json")

    # Validate the artifact up front so a broken package fails with a clear
    # message before any expensive loading starts.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Auto* model classes cannot be instantiated directly. from_config builds
    # the concrete architecture described by the config with freshly
    # initialized weights, and the trained weights are loaded into it below.
    config = AutoConfig.from_pretrained(config_path)
    model = AutoModelForSeq2SeqLM.from_config(config)

    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # host. NOTE(review): torch.load unpickles the file, so only load trusted
    # artifacts. Consider weights_only=True if the installed torch supports it.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Unlike from_pretrained, from_config leaves the module in training mode.
    # Switch to eval so dropout does not perturb generation.
    model.eval()

    print("Model and tokenizer loaded successfully.")
    return model, tokenizer
def input_fn(serialized_input_data, content_type="application/json"):
    """Deserialize a JSON request body and extract the PL/SQL source.

    Args:
        serialized_input_data: Request body (``str`` or ``bytes``) that should
            hold a JSON object with a ``plsql_code`` field.
        content_type: Request MIME type. Media-type parameters such as
            ``; charset=utf-8`` and letter case are ignored. ``None`` is
            treated as the default ``application/json``.

    Returns:
        The value of the ``plsql_code`` field.

    Raises:
        ValueError: If the content type is not JSON, the body is not valid
            JSON, the JSON is not an object, or ``plsql_code`` is missing.
    """
    import json

    print("Processing input data...")
    # Compare only the bare media type, because clients commonly send
    # "application/json; charset=utf-8".
    media_type = (content_type or "application/json").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")

    # json.JSONDecodeError subclasses ValueError, so malformed bodies also
    # surface as ValueError.
    input_data = json.loads(serialized_input_data)
    # Without the isinstance check, a JSON string would be searched as a
    # substring and a JSON number would raise TypeError rather than ValueError.
    if not isinstance(input_data, dict) or "plsql_code" not in input_data:
        raise ValueError("Missing 'plsql_code' in the input JSON.")

    print("Input data processed successfully.")
    return input_data["plsql_code"]
def predict_fn(input_data, model_and_tokenizer):
    """Translate PL/SQL code to Hibernate/JPA-based Java code.

    Args:
        input_data: PL/SQL source, as returned by ``input_fn``.
        model_and_tokenizer: The ``(model, tokenizer)`` tuple from
            ``model_fn``.

    Returns:
        The decoded model output (the generated Java code) as a string.
    """
    print("Starting prediction...")
    model, tokenizer = model_and_tokenizer
    max_input_tokens = 1024

    # The PL/SQL source goes last in the prompt, so any input truncation
    # removes part of the code being translated.
    prompt = (
        "\n"
        "Translate this PL/SQL function to a Hibernate/JPA-based Java implementation.\n"
        "Requirements:\n"
        "1. Use @Entity, @Table, and @Column annotations"
        " to map the database table structure.\n"
        "2. Define Java fields corresponding to the database columns"
        " used in the PL/SQL logic.\n"
        "3. Replicate the PL/SQL logic as a @Query in the Repository layer"
        " or as Java logic in the Service layer.\n"
        "4. Use Repository and Service layers, ensuring transactional consistency"
        " with @Transactional annotations.\n"
        "5. Avoid direct bitwise operations in procedural code;"
        " ensure they are part of database entities or queries.\n"
        "Input PL/SQL:\n"
        f"{input_data}\n"
    )

    print("Tokenizing input...")
    # Measure the untruncated length so oversized inputs are reported instead
    # of being translated silently from partial code.
    prompt_tokens = len(tokenizer(prompt)["input_ids"])
    if prompt_tokens > max_input_tokens:
        print(
            f"Warning: prompt is {prompt_tokens} tokens; "
            f"truncating to {max_input_tokens}."
        )
    # BatchEncoding.to places every tensor on the model's device (CPU or GPU).
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=max_input_tokens,
        truncation=True,
    ).to(model.device)

    print("Generating output...")
    # Pass the attention mask explicitly. Without it, generate() warns and has
    # to infer the mask. Only input_ids and attention_mask are forwarded, so
    # extra tokenizer outputs such as token_type_ids cannot reach generate().
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_length=1024,
        num_beams=4,
        early_stopping=True,
    )
    translated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Prediction completed.")
    return translated_code
def output_fn(prediction, accept="application/json"):
    """Serialize the translated code as a JSON response.

    Args:
        prediction: Output of ``predict_fn`` (the generated Java code).
        accept: The request's Accept value. A comma-separated list is
            supported, media-type parameters (``; charset=...``, ``; q=...``)
            are ignored, matching is case-insensitive, and ``None`` means
            ``application/json``.

    Returns:
        A ``(body, content_type)`` tuple, where ``body`` is
        ``{"translated_code": prediction}`` encoded as JSON and
        ``content_type`` is ``"application/json"``.

    Raises:
        ValueError: If none of the requested media types can be satisfied by
            JSON.
    """
    import json

    print("Serializing output...")
    # Reduce each entry of the Accept list to its bare, lower-cased media type.
    requested = {
        part.split(";", 1)[0].strip().lower()
        for part in (accept or "application/json").split(",")
    }
    # Wildcards are satisfiable by JSON as well.
    if requested & {"application/json", "application/*", "*/*"}:
        return json.dumps({"translated_code": prediction}), "application/json"
    raise ValueError(f"Unsupported accept type: {accept}")
|