File size: 6,436 Bytes
5dbf895
 
 
c3d8cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dbf895
 
 
 
 
 
 
 
 
 
 
c3d8cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dbf895
 
 
 
 
 
c3d8cc1
5dbf895
 
 
c3d8cc1
 
 
 
 
 
 
 
 
 
 
5dbf895
c3d8cc1
5dbf895
 
 
 
 
 
 
 
 
c3d8cc1
5dbf895
 
 
c3d8cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dbf895
c3d8cc1
5dbf895
 
c3d8cc1
 
 
 
 
 
5dbf895
c3d8cc1
5dbf895
 
c3d8cc1
 
 
 
 
5dbf895
c3d8cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dbf895
c3d8cc1
 
 
 
5dbf895
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from PIL import Image
import numpy as np

def create_feather_mask(height, width, feather_size=4):
    """
    Build a 2D float32 mask of shape (height, width) that is 1.0 in the
    interior and ramps linearly down to 0.0 at the outermost row/column of
    every side over `feather_size` pixels.

    The four side ramps are multiplied together, so corners fall off as the
    product of the vertical and horizontal ramps.

    Args:
        height: Number of mask rows.
        width: Number of mask columns.
        feather_size: Ramp length in pixels. Values <= 0 disable feathering
            and an all-ones mask is returned. If the ramp is longer than an
            axis, only the part of the ramp that fits on that axis is applied
            (measured from each edge), instead of raising a broadcasting error.

    Returns:
        np.ndarray of dtype float32 and shape (height, width).
    """
    mask = np.ones((height, width), dtype=np.float32)
    if feather_size <= 0:
        return mask
    ramp = np.linspace(0.0, 1.0, feather_size, dtype=np.float32)

    # Clamp the ramp length per axis so a mask smaller than the feather
    # width still works; when the ramp fits this is the full ramp.
    rows = min(feather_size, height)
    cols = min(feather_size, width)

    # Top / Bottom
    if rows > 0:
        mask[:rows, :] *= ramp[:rows, None]
        # Explicit start index avoids the `-0:` slice selecting everything.
        mask[height - rows:, :] *= ramp[:rows][::-1, None]
    # Left / Right
    if cols > 0:
        mask[:, :cols] *= ramp[None, :cols]
        mask[:, width - cols:] *= ramp[None, :cols][:, ::-1]
    return mask

def _to_divisible_by(img, N):
    """Crop so width and height are divisible by N (top-left anchored)."""
    w, h = img.size
    W = (w // N) * N
    H = (h // N) * N
    if W == 0 or H == 0:
        raise ValueError("N is larger than the image grid (image too small).")
    if W != w or H != h:
        img = img.crop((0, 0, W, H))
    return img, W, H

def _edgelogic(i, j, ph, pw, N, overlap):
    """
    Base (no-overlap) patch is [i*ph:(i+1)*ph, j*pw:(j+1)*pw].
    Extend with overlap, biasing inward.
    Uses 2*overlap for edges to keep patch areas roughly comparable.
    Returns (start_h, end_h, start_w, end_w) BEFORE clamping to image bounds.
    """
    start_h = i * ph
    start_w = j * pw
    end_h   = start_h + ph
    end_w   = start_w + pw

    if overlap <= 0:
        return start_h, end_h, start_w, end_w

    # Vertical
    if i == 0:
        end_h += 2 * overlap
    elif i == N - 1:
        start_h -= 2 * overlap
    else:
        start_h -= overlap
        end_h   += overlap

    # Horizontal
    if j == 0:
        end_w += 2 * overlap
    elif j == N - 1:
        start_w -= 2 * overlap
    else:
        start_w -= overlap
        end_w   += overlap

    return start_h, end_h, start_w, end_w

def spm_augment(
    image,
    num_patches=4,   # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    overlap_px=0,
    seed=None
):
    """
    SPM-style augmentation with optional overlap + feathered blending.

    The image is converted to RGB, cropped (top-left anchored) so both sides
    are divisible by N, and split into an N×N grid. Each grid location is,
    with probability `mix_prob`, blended with the patch that a random
    permutation assigns to it.

    When the effective overlap is 0:
      - Standard global shuffle over N×N patches;
      - Per-patch mixing with a single alpha ~ Beta(a,b) for the image.
        If alpha is not strictly inside (0, 1) the shuffled patch replaces
        the original outright.

    When the effective overlap is > 0:
      - Each base cell expands by +/-overlap (2*overlap at borders), clipped
        to the image. Alpha is sampled per mixed patch.
      - Patches are accumulated into a canvas with a feather mask of size
        `overlap` and normalised by the accumulated weight. Sides of a patch
        that lie on the image boundary are NOT feathered: no other patch
        covers those pixels, so feathering them to zero weight would leave
        a black frame around the result.

    Args:
        image: PIL image or numpy array accepted by `Image.fromarray`.
        num_patches: Grid size N (must be >= 1 after `int()` conversion).
        mix_prob: Probability that a given grid location is mixed.
        beta_a, beta_b: Beta distribution parameters; if either is <= 0,
            alpha is fixed to 1.0 (full replacement by the shuffled patch).
        overlap_px: Requested overlap in pixels; `None` means 0. Clamped to
            [0, min(patch_h, patch_w) // 2 - 1].
        seed: Seed passed to `np.random.default_rng`.

    Returns:
        PIL.Image in RGB mode with the cropped size.

    Raises:
        ValueError: if `num_patches` < 1, or the image is smaller than the
            grid (raised by `_to_divisible_by`).
    """
    # Normalize to PIL and ensure divisibility
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    N = int(num_patches)
    if N < 1:
        raise ValueError("num_patches must be a positive integer.")
    rng = np.random.default_rng(seed)

    img, W, H = _to_divisible_by(img, N)
    arr_u8 = np.array(img, dtype=np.uint8)
    ph = H // N
    pw = W // N

    # Clamp overlap to < half patch size
    if overlap_px is None:
        overlap_px = 0
    max_ov = max(0, min(ph, pw) // 2 - 1)
    ov = int(np.clip(int(overlap_px), 0, max_ov))

    use_beta = beta_a > 0 and beta_b > 0
    p_mix = float(mix_prob)

    if ov <= 0:
        # === Non-overlap path ===
        # Top-left corners of the grid cells, row-major.
        cells = [(i * ph, j * pw) for i in range(N) for j in range(N)]
        patches = [arr_u8[y0:y0 + ph, x0:x0 + pw] for y0, x0 in cells]
        perm = rng.permutation(len(patches))

        # One alpha per image
        alpha = float(rng.beta(beta_a, beta_b)) if use_beta else 1.0

        # Start from a copy so unmixed cells need no further work; `patches`
        # are views of the untouched source array, never of `out`.
        out = arr_u8.copy()
        mix_here = rng.random(len(patches)) < p_mix
        for idx, (y0, x0) in enumerate(cells):
            if not mix_here[idx]:
                continue
            donor = patches[perm[idx]]
            if 0.0 < alpha < 1.0:
                mixed = (alpha * donor.astype(np.float32)
                         + (1.0 - alpha) * patches[idx].astype(np.float32))
                out[y0:y0 + ph, x0:x0 + pw] = np.clip(mixed, 0, 255).astype(np.uint8)
            else:
                out[y0:y0 + ph, x0:x0 + pw] = donor

        return Image.fromarray(out)

    # === Overlap path with feather blending ===
    arr = arr_u8.astype(np.float32)

    coords = []
    for i in range(N):
        for j in range(N):
            sh, eh, sw, ew = _edgelogic(i, j, ph, pw, N, ov)
            # Clamp to image bounds
            coords.append((max(0, sh), min(H, eh), max(0, sw), min(W, ew)))
    patches = [arr[sh:eh, sw:ew] for sh, eh, sw, ew in coords]
    perm = rng.permutation(len(patches))

    def sample_alpha():
        # Alpha is drawn per mixed patch in the overlap path.
        if use_beta:
            return float(rng.beta(beta_a, beta_b))
        return 1.0

    mask_cache = {}

    def feather_for(sh, eh, sw, ew):
        # Build a mask that ramps only on sides facing another patch. A side
        # on the image boundary is padded by `ov` before calling
        # create_feather_mask and the padding (which holds the whole ramp)
        # is sliced away, leaving that side at full weight.
        pad_t = ov if sh == 0 else 0
        pad_b = ov if eh == H else 0
        pad_l = ov if sw == 0 else 0
        pad_r = ov if ew == W else 0
        h_k, w_k = eh - sh, ew - sw
        key = (h_k, w_k, pad_t, pad_b, pad_l, pad_r)
        if key not in mask_cache:
            padded = create_feather_mask(
                h_k + pad_t + pad_b, w_k + pad_l + pad_r, feather_size=ov
            )
            mask_cache[key] = padded[pad_t:pad_t + h_k, pad_l:pad_l + w_k]
        return mask_cache[key]

    canvas = np.zeros_like(arr)
    weight = np.zeros((H, W), dtype=np.float32)

    for k, (sh, eh, sw, ew) in enumerate(coords):
        if rng.random() >= p_mix:
            # keep original content in that region
            patch = patches[k]
        else:
            lam = sample_alpha()
            patch = lam * patches[int(perm[k])] + (1.0 - lam) * patches[k]

        mask2d = feather_for(sh, eh, sw, ew)

        # Accumulate; the trailing axis broadcasts the mask over channels.
        canvas[sh:eh, sw:ew] += patch * mask2d[..., None]
        weight[sh:eh, sw:ew] += mask2d

    # Normalize (the clip only guards against division by zero)
    weight = np.clip(weight, 1e-8, None)
    out = canvas / weight[..., None]
    out = np.clip(out, 0, 255).astype(np.uint8)
    return Image.fromarray(out)