# File size: 3,522 Bytes
# Revision: 7b1499d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os

import torch
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

# SageMaker passes the model artifact directory (/opt/ml/model by default)
# to model_fn as ``model_dir``.

def model_fn(model_dir):
    """
    Load the model and tokenizer from the specified directory.

    ``model_dir`` must contain ``config.json``, a ``pytorch_model.pth`` file
    holding a state dict for the ``AutoModelForSeq2SeqLM`` architecture that
    the config describes, and the files ``AutoTokenizer.from_pretrained`` reads.

    Args:
        model_dir: Path to the model artifacts directory.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is on CPU and in eval mode.

    Raises:
        FileNotFoundError: If ``pytorch_model.pth`` or ``config.json`` is missing.
    """
    print("Loading model and tokenizer...")

    model_path = os.path.join(model_dir, "pytorch_model.pth")
    config_path = os.path.join(model_dir, "config.json")

    # Validate the artifacts before doing any expensive loading.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Auto* classes cannot be instantiated directly; from_config builds the
    # concrete architecture, whose weights are then replaced by the state dict.
    config = AutoConfig.from_pretrained(config_path)
    model = AutoModelForSeq2SeqLM.from_config(config)
    # map_location="cpu" lets weights saved from a GPU load on a CPU-only host.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # from_config (unlike from_pretrained) leaves the module in training mode;
    # switch to eval so dropout is disabled during generation.
    model.eval()
    print("Model and tokenizer loaded successfully.")

    return model, tokenizer

def input_fn(serialized_input_data, content_type="application/json"):
    """
    Deserialize the request body and extract the PL/SQL source.

    Args:
        serialized_input_data: Raw request body (``str`` or ``bytes``) that must
            decode to a JSON object with a ``"plsql_code"`` key.
        content_type: Request content type. Media-type parameters such as
            ``; charset=utf-8`` are ignored, and the comparison is
            case-insensitive.

    Returns:
        The value stored under ``"plsql_code"``.

    Raises:
        ValueError: If the content type is not JSON, the body is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``), the body is not a
            JSON object, or ``"plsql_code"`` is missing.
    """
    import json

    print("Processing input data...")
    # Drop MIME parameters ("application/json; charset=utf-8") before comparing;
    # a missing content type falls through to the error below.
    media_type = (content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")

    input_data = json.loads(serialized_input_data)
    # Without the dict check, a JSON string body would turn the ``in`` test into
    # a substring match, and a JSON array would fail later with a TypeError.
    if not isinstance(input_data, dict) or "plsql_code" not in input_data:
        raise ValueError("Missing 'plsql_code' in the input JSON.")
    print("Input data processed successfully.")
    return input_data["plsql_code"]

def predict_fn(input_data, model_and_tokenizer):
    """
    Translate PL/SQL code to Hibernate/JPA-based Java code using the trained model.

    Args:
        input_data: PL/SQL source to translate; it is interpolated into the prompt.
        model_and_tokenizer: ``(model, tokenizer)`` pair, as returned by ``model_fn``.

    Returns:
        The decoded text of the first (highest-scoring) generated sequence, with
        special tokens removed.
    """
    print("Starting prediction...")
    model, tokenizer = model_and_tokenizer

    # Construct the tailored prompt. The text, including its indentation, is
    # kept verbatim; presumably the model was fine-tuned on this exact template.
    prompt = f"""
    Translate this PL/SQL function to a Hibernate/JPA-based Java implementation.

    Requirements:
    1. Use @Entity, @Table, and @Column annotations to map the database table structure.
    2. Define Java fields corresponding to the database columns used in the PL/SQL logic.
    3. Replicate the PL/SQL logic as a @Query in the Repository layer or as Java logic in the Service layer.
    4. Use Repository and Service layers, ensuring transactional consistency with @Transactional annotations.
    5. Avoid direct bitwise operations in procedural code; ensure they are part of database entities or queries.

    Input PL/SQL:
    {input_data}
    """
    # Tokenize and generate the translated code. Prompts longer than 1024
    # tokens are truncated, which cuts off the end of the PL/SQL input.
    print("Tokenizing input...")
    encoded = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

    # Forward the attention mask as well as the ids, so generate() does not
    # have to infer it. Other tokenizer outputs (e.g. token_type_ids) are
    # skipped because encoder-decoder models may reject them.
    generate_inputs = {
        name: encoded[name]
        for name in ("input_ids", "attention_mask")
        if name in encoded
    }
    # Put the tensors on the model's device (a no-op for a CPU model).
    device = getattr(model, "device", None)
    if device is not None:
        generate_inputs = {
            name: tensor.to(device) for name, tensor in generate_inputs.items()
        }

    print("Generating output...")
    outputs = model.generate(
        **generate_inputs,
        max_length=1024,
        num_beams=4,
        early_stopping=True,
    )
    translated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Prediction completed.")
    return translated_code

def output_fn(prediction, accept="application/json"):
    """
    Serialize the prediction result to JSON format.

    Args:
        prediction: Translated code returned by ``predict_fn``.
        accept: Value of the request's Accept header. A comma-separated list
            of media ranges is allowed. Parameters (``;charset=...``, ``;q=...``)
            are ignored. ``application/json``, ``application/*`` and ``*/*``
            all select JSON, as does a missing/empty value.

    Returns:
        A ``(body, content_type)`` tuple, where ``body`` is the JSON string
        ``{"translated_code": prediction}``.

    Raises:
        ValueError: If no listed media range admits ``application/json``.
    """
    import json

    print("Serializing output...")
    requested = {
        media_range.split(";", 1)[0].strip().lower()
        for media_range in (accept or "application/json").split(",")
    }
    if requested & {"application/json", "application/*", "*/*"}:
        return json.dumps({"translated_code": prediction}), "application/json"
    raise ValueError(f"Unsupported accept type: {accept}")