# Source: ldapca / script.py
# Uploaded by annaferrari02 ("Upload 8 files", commit 5958fad, verified).
"""
Inference script
Version combining baseline structure with enhanced features
"""
import os
import pickle
import cv2
import pandas as pd
import numpy as np
from utils.utils import extract_features_from_image, apply_pca_lda_transform
def run_inference(TEST_IMAGE_PATH, svm_model, pca_params, SUBMISSION_CSV_SAVE_PATH):
    """
    Run inference on every image in a directory and write a submission CSV.

    Images are processed in sorted filename order. Features are extracted per
    image with ``extract_features_from_image``; the stacked feature matrix is
    projected with ``apply_pca_lda_transform`` using the saved parameters, and
    ``svm_model.predict`` produces one label per image.

    Args:
        TEST_IMAGE_PATH: Directory containing the test images
            (e.g. /tmp/data/test_images). Sub-directories are ignored.
        svm_model: Trained SVM model (any object with a ``predict`` method).
        pca_params: Saved transformation parameters, passed unchanged to
            ``apply_pca_lda_transform``.
        SUBMISSION_CSV_SAVE_PATH: Path where the CSV is written, with columns
            ``file_name`` and ``category_id`` and no index column.

    Raises:
        ValueError: If a file in TEST_IMAGE_PATH cannot be decoded by
            ``cv2.imread``.
    """
    # Only regular files can be images; sorting keeps the row order deterministic.
    test_images = sorted(
        name
        for name in os.listdir(TEST_IMAGE_PATH)
        if os.path.isfile(os.path.join(TEST_IMAGE_PATH, name))
    )

    if not test_images:
        # Nothing to predict. An empty 1-D feature array is not a valid input
        # to the transform/predict step, so write a header-only CSV instead.
        pd.DataFrame({"file_name": [], "category_id": []}).to_csv(
            SUBMISSION_CSV_SAVE_PATH, index=False
        )
        return

    # Extract one feature vector per image.
    image_feature_list = []
    for test_image in test_images:
        path_to_image = os.path.join(TEST_IMAGE_PATH, test_image)
        image = cv2.imread(path_to_image)
        # cv2.imread signals failure by returning None rather than raising;
        # fail loudly here instead of passing None to the feature extractor.
        if image is None:
            raise ValueError(f"Could not read image: {path_to_image}")
        image_feature_list.append(extract_features_from_image(image))

    features_array = np.array(image_feature_list)

    # Project with the saved PCA/LDA parameters, then classify.
    features_reduced = apply_pca_lda_transform(features_array, pca_params)
    predictions = svm_model.predict(features_reduced)

    pd.DataFrame({"file_name": test_images, "category_id": predictions}).to_csv(
        SUBMISSION_CSV_SAVE_PATH, index=False
    )
if __name__ == "__main__":
    # The model artifacts and the output CSV live next to this script; the
    # test images are read from a fixed location.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    def _load_pickle(filename):
        """Unpickle and return the object stored in `filename` beside this script."""
        # NOTE: pickle.load can execute arbitrary code; only load trusted artifacts.
        with open(os.path.join(script_dir, filename), 'rb') as handle:
            return pickle.load(handle)

    svm_model = _load_pickle("multiclass_model.pkl")  # trained SVM classifier
    pca_params = _load_pickle("pca_lda_params.pkl")   # saved PCA/LDA parameters

    run_inference(
        "/tmp/data/test_images",
        svm_model,
        pca_params,
        os.path.join(script_dir, "submission.csv"),
    )