import json
import os

import torch
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer


# SageMaker extracts the model artifacts to /opt/ml/model by default and passes
# that path to model_fn as model_dir.
def model_fn(model_dir):
"""
Load the model and tokenizer from the specified directory.
"""
print("Loading model and tokenizer...")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir)
# Explicitly load the .pth file
model_path = os.path.join(model_dir, "pytorch_model.pth")
config_path = os.path.join(model_dir, "config.json")
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found: {model_path}")
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config file not found: {config_path}")
# Load the model using the state dictionary
config = AutoConfig.from_pretrained(config_path)
model = AutoModelForSeq2SeqLM(config)
model.load_state_dict(torch.load(model_path))
print("Model and tokenizer loaded successfully.")
return model, tokenizer


def input_fn(serialized_input_data, content_type="application/json"):
    """
    Deserialize the input data from JSON format.
    """
    print("Processing input data...")
    if content_type == "application/json":
        input_data = json.loads(serialized_input_data)
        if "plsql_code" not in input_data:
            raise ValueError("Missing 'plsql_code' in the input JSON.")
        print("Input data processed successfully.")
        return input_data["plsql_code"]
    else:
        raise ValueError(f"Unsupported content type: {content_type}")


def predict_fn(input_data, model_and_tokenizer):
    """
    Translate PL/SQL code to Hibernate/JPA-based Java code using the trained model.
    """
    print("Starting prediction...")
    model, tokenizer = model_and_tokenizer

    # Construct the tailored prompt
    prompt = f"""
Translate this PL/SQL function to a Hibernate/JPA-based Java implementation.
Requirements:
1. Use @Entity, @Table, and @Column annotations to map the database table structure.
2. Define Java fields corresponding to the database columns used in the PL/SQL logic.
3. Replicate the PL/SQL logic as a @Query in the Repository layer or as Java logic in the Service layer.
4. Use Repository and Service layers, ensuring transactional consistency with @Transactional annotations.
5. Avoid direct bitwise operations in procedural code; ensure they are part of database entities or queries.
Input PL/SQL:
{input_data}
"""

    # Tokenize the prompt and generate the translated code
    print("Tokenizing input...")
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

    print("Generating output...")
    outputs = model.generate(
        inputs["input_ids"],
        max_length=1024,
        num_beams=4,
        early_stopping=True,
    )
    translated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print("Prediction completed.")
    return translated_code


def output_fn(prediction, accept="application/json"):
    """
    Serialize the prediction result to JSON format.
    """
    print("Serializing output...")
    if accept == "application/json":
        return json.dumps({"translated_code": prediction}), "application/json"
    else:
        raise ValueError(f"Unsupported accept type: {accept}")