|
import os

import torch
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
|
def model_fn(model_dir):
    """
    Load the seq2seq model and its tokenizer from ``model_dir``.

    The directory must contain ``config.json`` (the model architecture) and
    ``pytorch_model.pth`` (a ``state_dict`` loaded with ``torch.load``), plus
    whatever files ``AutoTokenizer.from_pretrained`` needs.

    Args:
        model_dir: Path to the directory holding the model artifacts.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is on CPU and in eval mode.

    Raises:
        FileNotFoundError: If the weights file or ``config.json`` is missing.
    """
    print("Loading model and tokenizer...")

    model_path = os.path.join(model_dir, "pytorch_model.pth")
    config_path = os.path.join(model_dir, "config.json")

    # Validate the artifacts before any expensive loading, so a broken
    # deployment fails fast with a clear message.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Auto* classes cannot be instantiated directly. from_config builds the
    # architecture described by config.json, and load_state_dict below then
    # replaces its freshly initialised weights.
    config = AutoConfig.from_pretrained(config_path)
    model = AutoModelForSeq2SeqLM.from_config(config)

    # map_location="cpu" lets weights saved on a GPU load on a CPU-only host.
    # torch.load unpickles the file, so model_dir must come from a trusted source.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Switch off training-only behaviour (e.g. dropout) for inference.
    model.eval()

    print("Model and tokenizer loaded successfully.")
    return model, tokenizer
|
|
|
def input_fn(serialized_input_data, content_type="application/json"):
    """
    Deserialize a JSON request body and extract the PL/SQL source.

    Args:
        serialized_input_data: Request body as ``str`` or ``bytes``. It must
            be a JSON object with a ``"plsql_code"`` key.
        content_type: Request MIME type. Media-type parameters such as
            ``; charset=utf-8`` are ignored, and matching is case-insensitive.

    Returns:
        The value stored under ``"plsql_code"``.

    Raises:
        ValueError: If the content type is unsupported, the body is not valid
            JSON (``json.JSONDecodeError`` is a ``ValueError``), the body is not
            a JSON object, or the ``"plsql_code"`` key is missing.
    """
    import json

    print("Processing input data...")

    # Compare only the bare media type, e.g. "application/json" out of
    # "application/json; charset=utf-8".
    media_type = (content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")

    input_data = json.loads(serialized_input_data)
    # Without this check a JSON string like "plsql_code" would pass the
    # membership test below, then fail with TypeError on indexing.
    if not isinstance(input_data, dict):
        raise ValueError("Input JSON must be an object containing 'plsql_code'.")
    if "plsql_code" not in input_data:
        raise ValueError("Missing 'plsql_code' in the input JSON.")

    print("Input data processed successfully.")
    return input_data["plsql_code"]
|
|
|
# Prompt sent to the model; "{input_data}" is replaced with the PL/SQL source.
# NOTE(review): the line breaks were reconstructed from a whitespace-mangled
# copy of this file. Confirm they match the prompt format the model was
# trained on.
_PROMPT_TEMPLATE = (
    "\n"
    "Translate this PL/SQL function to a Hibernate/JPA-based Java implementation.\n"
    "\n"
    "Requirements:\n"
    "1. Use @Entity, @Table, and @Column annotations to map the database table structure.\n"
    "2. Define Java fields corresponding to the database columns used in the PL/SQL logic.\n"
    "3. Replicate the PL/SQL logic as a @Query in the Repository layer or as Java logic in the Service layer.\n"
    "4. Use Repository and Service layers, ensuring transactional consistency with @Transactional annotations.\n"
    "5. Avoid direct bitwise operations in procedural code; ensure they are part of database entities or queries.\n"
    "\n"
    "Input PL/SQL:\n"
    "{input_data}\n"
)


def predict_fn(input_data, model_and_tokenizer):
    """
    Translate PL/SQL code to Hibernate/JPA-based Java code using the trained model.

    Args:
        input_data: The PL/SQL source, e.g. as returned by ``input_fn``.
        model_and_tokenizer: The ``(model, tokenizer)`` pair from ``model_fn``.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    print("Starting prediction...")
    model, tokenizer = model_and_tokenizer

    # str.format does not re-parse braces inside the argument, so PL/SQL text
    # containing "{" or "}" is inserted verbatim.
    prompt = _PROMPT_TEMPLATE.format(input_data=input_data)

    print("Tokenizing input...")
    # Prompts longer than 1024 tokens are truncated, so the tail of a very long
    # PL/SQL input never reaches the model.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

    # Move tensors to wherever the model lives, so a GPU-hosted model works too.
    device = model.device
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    print("Generating output...")
    # Forward the attention mask as well, rather than input_ids alone, so
    # generate() does not have to infer which positions are padding.
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=1024,
        num_beams=4,
        early_stopping=True,
    )

    translated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Prediction completed.")
    return translated_code
|
|
|
def output_fn(prediction, accept="application/json"):
    """
    Serialize the prediction result to a JSON response.

    Args:
        prediction: Output of ``predict_fn`` (the translated code).
        accept: Value of the request's Accept header. A comma-separated list
            of media ranges is allowed, and parameters such as ``;q=0.8`` or
            ``; charset=utf-8`` are ignored. Any of ``application/json``,
            ``application/*`` or ``*/*`` selects JSON.

    Returns:
        A ``(body, content_type)`` tuple, where body is
        ``{"translated_code": prediction}`` encoded as JSON.

    Raises:
        ValueError: If none of the requested media ranges admits JSON.
    """
    import json

    print("Serializing output...")

    # Reduce each entry of the Accept header to its bare, lower-cased media range.
    media_ranges = {
        part.split(";", 1)[0].strip().lower()
        for part in (accept or "").split(",")
    }
    if media_ranges.isdisjoint({"application/json", "application/*", "*/*"}):
        raise ValueError(f"Unsupported accept type: {accept}")

    return json.dumps({"translated_code": prediction}), "application/json"
|
|
|
|