File size: 6,094 Bytes
5dbf895
 
 
c3d8cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dbf895
 
 
 
 
 
 
 
 
 
 
c3d8cc1
 
 
4e2c1e5
c3d8cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dbf895
 
 
 
 
 
4e2c1e5
5dbf895
 
 
c3d8cc1
 
4e2c1e5
c3d8cc1
 
 
4e2c1e5
 
 
 
5dbf895
c3d8cc1
5dbf895
 
 
 
 
 
 
 
 
c3d8cc1
5dbf895
 
 
4e2c1e5
 
 
 
c3d8cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dbf895
c3d8cc1
5dbf895
 
c3d8cc1
 
 
 
 
 
5dbf895
c3d8cc1
5dbf895
 
c3d8cc1
 
 
 
5dbf895
c3d8cc1
 
 
 
 
4e2c1e5
c3d8cc1
 
 
 
 
 
 
 
4e2c1e5
c3d8cc1
 
 
5dbf895
c3d8cc1
 
 
5dbf895
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from PIL import Image
import numpy as np

def create_feather_mask(height, width, feather_size=4):
    """
    Return an HxW float32 mask that ramps linearly from 0.0 at every edge to
    1.0 in the interior over `feather_size` pixels.

    The four edge ramps are applied multiplicatively, so corners fall off as
    the product of the vertical and horizontal ramps. The outermost row and
    column are exactly 0.0 because `np.linspace` includes its start point.

    Args:
        height: Mask height in pixels.
        width: Mask width in pixels.
        feather_size: Ramp length in pixels. ``<= 0`` returns all ones. If it
            exceeds `height` (or `width`), only the first `height` (or
            `width`) samples of the ramp are applied on that axis, keeping the
            same per-pixel slope, instead of failing with a broadcast error.

    Returns:
        np.ndarray of shape (height, width), dtype float32.
    """
    mask = np.ones((height, width), dtype=np.float32)
    if feather_size <= 0:
        return mask

    ramp = np.linspace(0.0, 1.0, feather_size, dtype=np.float32)
    # Clamp per axis so a ramp longer than the mask cannot mis-broadcast.
    fh = min(feather_size, height)
    fw = min(feather_size, width)

    # Guard against 0: a slice like mask[-0:] would select the whole axis.
    if fh > 0:
        mask[:fh, :] *= ramp[:fh, None]           # top edge
        mask[-fh:, :] *= ramp[:fh][::-1, None]    # bottom edge
    if fw > 0:
        mask[:, :fw] *= ramp[None, :fw]           # left edge
        mask[:, -fw:] *= ramp[:fw][None, ::-1]    # right edge
    return mask

def _to_divisible_by(img, N):
    """
    Crop `img` (top-left anchored) so both dimensions are multiples of N.

    Args:
        img: Image exposing ``.size`` -> (width, height) and
            ``.crop(box)``, e.g. a PIL image.
        N: Grid size; must be a positive integer.

    Returns:
        (img, W, H): the possibly-cropped image and its new width and height.
        The input object itself is returned when no crop is needed.

    Raises:
        ValueError: If N < 1, or if the image is smaller than N pixels along
            either axis.
    """
    # Without this guard N == 0 divides by zero and a negative N yields a
    # crop box larger than the image.
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N!r}.")

    w, h = img.size
    W = (w // N) * N
    H = (h // N) * N
    if W == 0 or H == 0:
        raise ValueError("N is larger than the image grid (image too small).")
    if (W, H) != (w, h):
        img = img.crop((0, 0, W, H))
    return img, W, H

def _edgelogic(i, j, ph, pw, N, overlap):
    """
    Compute the pixel span of grid cell (i, j), widened for overlap.

    The unexpanded cell is rows [i*ph, (i+1)*ph) and columns [j*pw, (j+1)*pw).
    When `overlap` > 0, a cell in an interior row/column grows by `overlap`
    on both sides of that axis. A cell in the first or last row/column grows
    by 2*overlap toward the interior only, so before any clipping every cell
    has the same expanded extent. The result is not clipped to image bounds.

    Returns:
        (start_h, end_h, start_w, end_w)
    """
    def _span(index, size):
        # Half-open [lo, hi) interval of one axis for this cell index.
        lo = index * size
        hi = lo + size
        if overlap <= 0:
            return lo, hi
        if index == 0:
            return lo, hi + 2 * overlap
        if index == N - 1:
            return lo - 2 * overlap, hi
        return lo - overlap, hi + overlap

    row_lo, row_hi = _span(i, ph)
    col_lo, col_hi = _span(j, pw)
    return row_lo, row_hi, col_lo, col_hi

def _spm_grid_mix(arr, N, ph, pw, rng, mix_prob, beta_a, beta_b):
    """Non-overlap SPM: shuffle the N×N grid cells and mix with one alpha."""
    # Row-major list of base cells. These are views into `arr`, which is never
    # written to; all writes go to the `out` copy.
    cells = [
        arr[i * ph:(i + 1) * ph, j * pw:(j + 1) * pw]
        for i in range(N)
        for j in range(N)
    ]
    total = N * N

    # RNG draw order (permutation, alpha, mix mask) is kept fixed so a given
    # seed reproduces the same shuffle.
    perm = rng.permutation(total)
    if beta_a > 0 and beta_b > 0:
        alpha = float(rng.beta(beta_a, beta_b))  # one alpha for the whole image
    else:
        alpha = 1.0  # non-positive Beta params: take the shuffled cell outright
    do_mix = rng.random(total) < float(mix_prob)

    out = arr.copy()
    for idx, src in enumerate(cells):
        if not do_mix[idx]:
            continue  # `out` already holds the original cell
        i, j = divmod(idx, N)
        shf = cells[int(perm[idx])].astype(np.float32)
        # One formula for every alpha in [0, 1]. alpha == 1.0 reproduces the
        # shuffled cell exactly, and alpha == 0.0 now keeps the source cell.
        mixed = alpha * shf + (1.0 - alpha) * src.astype(np.float32)
        # Round rather than truncate so mixing has no downward bias.
        out[i * ph:(i + 1) * ph, j * pw:(j + 1) * pw] = (
            np.clip(np.rint(mixed), 0, 255).astype(np.uint8)
        )
    return out


def _spm_feathered_mix(arr_u8, N, ph, pw, ov, rng, mix_prob, beta_a, beta_b):
    """Overlap SPM: mix expanded cells per patch, then feather-blend them."""
    arr = arr_u8.astype(np.float32)
    H, W = arr.shape[:2]

    patches = []
    coords = []
    for i in range(N):
        for j in range(N):
            sh, eh, sw, ew = _edgelogic(i, j, ph, pw, N, ov)
            # Clamp to image bounds. This is defensive: with ov below half a
            # cell it only triggers for N == 1.
            sh, sw = max(0, sh), max(0, sw)
            eh, ew = min(H, eh), min(W, ew)
            patches.append(arr[sh:eh, sw:ew])
            coords.append((sh, eh, sw, ew))

    perm = rng.permutation(len(patches))

    canvas = np.zeros_like(arr)
    weight = np.zeros((H, W), dtype=np.float32)
    mask_cache = {}

    for k, (sh, eh, sw, ew) in enumerate(coords):
        if rng.random() >= float(mix_prob):
            patch = patches[k]
        else:
            # A fresh alpha per mixed patch (the non-overlap path uses one).
            if beta_a > 0 and beta_b > 0:
                lam = float(rng.beta(beta_a, beta_b))
            else:
                lam = 1.0
            patch = lam * patches[int(perm[k])] + (1.0 - lam) * patches[k]

        hk, wk = eh - sh, ew - sw
        # Feather only seams shared with a neighbouring patch. A side lying on
        # the image border gets `ov` px of virtual padding, so its ramp falls
        # outside the crop. Otherwise every patch touching the border would
        # give the outermost pixels weight 0, and they would come out black.
        pads = (
            ov if sh == 0 else 0,
            ov if eh == H else 0,
            ov if sw == 0 else 0,
            ov if ew == W else 0,
        )
        key = (hk, wk) + pads
        mask2d = mask_cache.get(key)
        if mask2d is None:
            top, bottom, left, right = pads
            full = create_feather_mask(
                hk + top + bottom, wk + left + right, feather_size=ov
            )
            mask2d = full[top:top + hk, left:left + wk]
            mask_cache[key] = mask2d

        # The image is RGB, so a trailing singleton axis broadcasts over channels.
        canvas[sh:eh, sw:ew] += patch * mask2d[..., None]
        weight[sh:eh, sw:ew] += mask2d

    out = canvas / np.maximum(weight, 1e-8)[..., None]
    # Round before casting. In float32, the weighted average of identical
    # values can land just below the integer, and truncation would then
    # shift untouched pixels down by 1.
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


def spm_augment(
    image,
    num_patches=4,   # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    overlap_pct=0.0,   # percentage of patch size (0..49 typically)
    seed=None
):
    """
    SPM-style augmentation with optional overlap and feathered blending.

    The image is converted to RGB and cropped (top-left anchored) so both
    dimensions divide by `num_patches`. It is then split into an N×N grid.
    Each cell is mixed, with probability `mix_prob`, with a cell chosen by a
    random global permutation.

    When the effective overlap is 0:
      - Cells are mixed with a single alpha ~ Beta(beta_a, beta_b) per image.
        alpha is 1.0 (pure swap) if either Beta parameter is <= 0.

    When the effective overlap is > 0:
      - Each cell is expanded by `ov` px on interior sides. Cells in the first
        or last row/column are instead expanded by 2*ov toward the interior.
      - Every mixed patch draws its own alpha.
      - Patches are blended with a linear feather of width `ov` on seams
        shared with neighbouring patches. Image borders are not feathered.

    Args:
        image: PIL image or numpy array accepted by ``Image.fromarray``.
        num_patches: Grid size N (>= 1).
        mix_prob: Probability that a given cell is mixed.
        beta_a, beta_b: Beta-distribution parameters for the mixing weight.
        overlap_pct: Overlap as a percentage of min(cell height, cell width).
            It is clamped to [0, 49], and the resulting pixel overlap is
            capped at ``min(ph, pw) // 2 - 1``.
        seed: Seed for ``np.random.default_rng``.

    Returns:
        PIL.Image.Image built from a uint8 array with the cropped dimensions.

    Raises:
        ValueError: If num_patches < 1 or the image is smaller than the grid.
    """
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    N = int(num_patches)
    if N < 1:
        raise ValueError(f"num_patches must be >= 1, got {num_patches!r}.")
    rng = np.random.default_rng(seed)

    img, W, H = _to_divisible_by(img, N)
    arr_u8 = np.array(img, dtype=np.uint8)
    ph = H // N
    pw = W // N

    # Convert the percentage to a pixel overlap, kept below half a cell so
    # neighbouring expanded cells never reach past each other's centres.
    pct = max(0.0, min(float(overlap_pct), 49.0))
    overlap_px = int(round((pct / 100.0) * min(ph, pw)))
    max_ov = max(0, min(ph, pw) // 2 - 1)
    ov = int(np.clip(overlap_px, 0, max_ov))

    if ov <= 0:
        out = _spm_grid_mix(arr_u8, N, ph, pw, rng, mix_prob, beta_a, beta_b)
    else:
        out = _spm_feathered_mix(
            arr_u8, N, ph, pw, ov, rng, mix_prob, beta_a, beta_b
        )
    return Image.fromarray(out)