import os
import json

import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM


def model_fn(model_dir):
    """
    Load the model and tokenizer from the specified directory.
    (SageMaker mounts the model artifacts under /opt/ml/model by default.)
    """
    print("Loading model and tokenizer...")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Explicitly load the .pth file
    model_path = os.path.join(model_dir, "pytorch_model.pth")
    config_path = os.path.join(model_dir, "config.json")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # Build the model from its config, then load the saved state dictionary
    config = AutoConfig.from_pretrained(config_path)
    model = AutoModelForSeq2SeqLM.from_config(config)
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    print("Model and tokenizer loaded successfully.")
    return model, tokenizer


def input_fn(serialized_input_data, content_type="application/json"):
    """
    Deserialize the input data from JSON format.
    """
    print("Processing input data...")
    if content_type == "application/json":
        input_data = json.loads(serialized_input_data)
        if "plsql_code" not in input_data:
            raise ValueError("Missing 'plsql_code' in the input JSON.")
        print("Input data processed successfully.")
        return input_data["plsql_code"]
    else:
        raise ValueError(f"Unsupported content type: {content_type}")


def predict_fn(input_data, model_and_tokenizer):
    """
    Translate PL/SQL code to Hibernate/JPA-based Java code using the trained model.
    """
    print("Starting prediction...")
    model, tokenizer = model_and_tokenizer

    # Construct the tailored prompt
    prompt = f"""
Translate this PL/SQL function to a Hibernate/JPA-based Java implementation.

Requirements:
1. Use @Entity, @Table, and @Column annotations to map the database table structure.
2. Define Java fields corresponding to the database columns used in the PL/SQL logic.
3. Replicate the PL/SQL logic as a @Query in the Repository layer or as Java logic in the Service layer.
4. Use Repository and Service layers, ensuring transactional consistency with @Transactional annotations.
5. Avoid direct bitwise operations in procedural code; ensure they are part of database entities or queries.

Input PL/SQL:
{input_data}
"""

    # Tokenize and generate the translated code
    print("Tokenizing input...")
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

    print("Generating output...")
    outputs = model.generate(
        inputs["input_ids"],
        max_length=1024,
        num_beams=4,
        early_stopping=True,
    )
    translated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print("Prediction completed.")
    return translated_code


def output_fn(prediction, accept="application/json"):
    """
    Serialize the prediction result to JSON format.
    """
    print("Serializing output...")
    if accept == "application/json":
        return json.dumps({"translated_code": prediction}), "application/json"
    else:
        raise ValueError(f"Unsupported accept type: {accept}")
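

# --- Local smoke test (not part of the SageMaker runtime) --------------------
# A minimal sketch of how the four handlers chain together outside SageMaker,
# assuming the model artifacts are unpacked under ./model. The directory path
# and the sample PL/SQL snippet below are illustrative placeholders only.
if __name__ == "__main__":
    model_and_tokenizer = model_fn("./model")  # hypothetical local artifact dir
    payload = json.dumps({"plsql_code": "CREATE OR REPLACE FUNCTION ..."})
    plsql = input_fn(payload, "application/json")
    prediction = predict_fn(plsql, model_and_tokenizer)
    body, mime_type = output_fn(prediction, "application/json")
    print(mime_type, body)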