File size: 2,456 Bytes
5dbf895
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from PIL import Image
import numpy as np

def _to_divisible_by(img, N):
    """Crop so width and height are divisible by N (top-left anchored)."""
    w, h = img.size
    W = (w // N) * N
    H = (h // N) * N
    if W == 0 or H == 0:
        raise ValueError("N is larger than the image grid (image too small).")
    if W != w or H != h:
        img = img.crop((0, 0, W, H))
    return img, W, H

def spm_augment(
    image,
    num_patches=4,   # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    seed=None
):
    """
    SPM-style augmentation using a global shuffle over an N×N patch grid.

      1) Divide the image into N×N patches (cropping from the right/bottom so
         both sides are divisible by N if needed).
      2) Globally permute the patch indices.
      3) Per patch, with probability `mix_prob`, replace it by the convex blend
         ``alpha * shuffled + (1 - alpha) * original`` where
         alpha~Beta(beta_a, beta_b) is sampled once per image.

    Args:
        image: Either a ``numpy.ndarray`` (passed through ``Image.fromarray``)
            or an object with a ``convert("RGB")`` method such as a PIL image.
        num_patches: Grid size N; converted with ``int()`` and must be >= 1.
        mix_prob: Per-patch probability of being mixed.
        beta_a: First Beta shape parameter.
        beta_b: Second Beta shape parameter. If either shape parameter is not
            strictly positive, alpha is fixed to 1.0, i.e. selected patches
            are replaced outright by their shuffled counterpart.
        seed: Seed forwarded to ``numpy.random.default_rng``.

    Returns:
        A PIL image built from the augmented RGB uint8 array. Its size is the
        input size cropped down to multiples of N.

    Raises:
        ValueError: If ``num_patches`` is smaller than 1, or the image is
            smaller than N pixels along either axis.
    """
    N = int(num_patches)
    if N < 1:
        # Reject early instead of hitting a division by zero (N == 0) or a
        # nonsensical negative grid further down.
        raise ValueError("num_patches must be a positive integer.")

    # Normalize input
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    rng = np.random.default_rng(seed)

    # Ensure divisibility and compute patch size
    img, W, H = _to_divisible_by(img, N)
    arr = np.array(img, dtype=np.uint8)
    ph = H // N
    pw = W // N

    def region(index):
        """Return the (rows, cols) slices of the row-major patch ``index``."""
        row, col = divmod(int(index), N)
        return (
            slice(row * ph, (row + 1) * ph),
            slice(col * pw, (col + 1) * pw),
        )

    total = N * N

    # The three draws below stay in this order (permutation, alpha, mask) so
    # that a given seed keeps consuming the generator in the same sequence.
    perm = rng.permutation(total)

    # Sample one alpha for the whole image
    if beta_a > 0 and beta_b > 0:
        alpha = float(rng.beta(beta_a, beta_b))
    else:
        alpha = 1.0

    mask = rng.random(total) < float(mix_prob)

    # Unselected patches keep their original pixels through this copy.
    out = arr.copy()

    # alpha == 0 puts all the blend weight on the original patch, so there is
    # nothing to change; it must not fall through to a full replacement.
    if alpha <= 0.0:
        return Image.fromarray(out)

    # Patchwise mix; sources are always read from the untouched `arr`.
    for idx in np.flatnonzero(mask):
        dst = region(idx)
        shuffled = arr[region(perm[idx])]
        if alpha >= 1.0:
            # All weight on the shuffled patch: copy it without a float pass.
            out[dst] = shuffled
        else:
            mixed = (
                alpha * shuffled.astype(np.float32)
                + (1.0 - alpha) * arr[dst].astype(np.float32)
            )
            # The uint8 cast drops the fractional part of the blended values.
            out[dst] = np.clip(mixed, 0, 255).astype(np.uint8)

    return Image.fromarray(out)