|
import multiprocessing |
|
import pickle |
|
import time |
|
import traceback |
|
|
|
import cv2 |
|
import numpy as np |
|
|
|
from core import mplib |
|
from core.joblib import SubprocessGenerator, ThisThreadGenerator |
|
from facelib import LandmarksProcessor |
|
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, |
|
SampleType) |
|
|
|
|
|
'''
Expected format of the ``output_sample_types`` argument:

output_sample_types = [
    [SampleProcessor.TypeFlags, size, (optional) {} opts ],
    ...
]
'''
|
class SampleGeneratorFaceDebug(SampleGeneratorBase):
    """Batch generator over FACE samples loaded from ``samples_path``.

    Samples are loaded once in the constructor, pickled, and handed to one or
    more worker generators that each run :meth:`batch_func`. In debug mode a
    single ``ThisThreadGenerator`` is used; otherwise ``generators_count``
    ``SubprocessGenerator`` workers are created, worker ``i`` seeded with
    ``rnd_seed + i``. Calling ``next()`` on this object round-robins over the
    workers.

    Each batch is a list of numpy arrays: one array per output returned by
    ``SampleProcessor.process`` for a sample (stacked over the batch), plus,
    when ``add_sample_idx`` is set, a trailing array with the indices of the
    samples that were used.
    """

    def __init__(self, samples_path, debug=False, batch_size=1,
                 random_ct_samples_path=None,
                 sample_process_options=None,
                 output_sample_types=None,
                 add_sample_idx=False,
                 generators_count=4,
                 rnd_seed=None,
                 **kwargs):
        """Load the samples and start the worker generators.

        Args:
            samples_path: location passed to ``SampleLoader.load`` for FACE samples.
            debug: forwarded to the base class; when true a single in-process
                generator is used and ``debug`` is passed to
                ``SampleProcessor.process``.
            batch_size: number of samples per batch (forwarded to the base class).
            random_ct_samples_path: optional location of an additional FACE sample
                set; when given, one of its samples is drawn per processed sample
                and passed to ``SampleProcessor.process`` as ``ct_sample``.
            sample_process_options: options for ``SampleProcessor.process``;
                a fresh ``SampleProcessor.Options()`` is created when omitted.
            output_sample_types: output description for ``SampleProcessor.process``;
                a fresh empty list is used when omitted.
            add_sample_idx: if true, each batch gets a trailing array of sample indices.
            generators_count: number of subprocess workers (clamped to at least 1;
                forced to 1 in debug mode).
            rnd_seed: base random seed; a random one is drawn when omitted.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``samples_path`` yields no samples, or if
                ``random_ct_samples_path`` is given but yields no samples.
        """
        super().__init__(debug, batch_size)

        # Build defaults per instance: a default written in the signature is
        # evaluated once at class-definition time and would be shared by every
        # instance that relies on it.
        if sample_process_options is None:
            sample_process_options = SampleProcessor.Options()
        if output_sample_types is None:
            output_sample_types = []

        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types
        self.add_sample_idx = add_sample_idx

        if rnd_seed is None:
            rnd_seed = np.random.randint(0x80000000)

        if self.debug:
            self.generators_count = 1
        else:
            self.generators_count = max(1, generators_count)

        samples = SampleLoader.load(SampleType.FACE, samples_path)
        self.samples_len = len(samples)

        if self.samples_len == 0:
            raise ValueError('No training data provided.')

        ct_samples = None
        if random_ct_samples_path is not None:
            ct_samples = SampleLoader.load(SampleType.FACE, random_ct_samples_path)
            # An empty pool would otherwise only fail later, inside a worker,
            # with an opaque "pop from empty list" IndexError.
            if ct_samples is not None and len(ct_samples) == 0:
                raise ValueError('No samples found in random_ct_samples_path.')

        # Samples are serialized once here and deserialized in each worker
        # (see batch_func).
        pickled_samples = pickle.dumps(samples, 4)
        ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None

        if self.debug:
            self.generators = [ThisThreadGenerator(self.batch_func, (pickled_samples, ct_pickled_samples, rnd_seed))]
        else:
            # Each worker gets a distinct seed so they do not produce identical batches.
            self.generators = [SubprocessGenerator(self.batch_func, (pickled_samples, ct_pickled_samples, rnd_seed + i), start_now=False)
                               for i in range(self.generators_count)]

            SubprocessGenerator.start_in_parallel(self.generators)

        # Incremented before use, so the first __next__ reads generator 0.
        self.generator_counter = -1

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, taking workers in round-robin order."""
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators)]
        return next(generator)

    @staticmethod
    def _pop_shuffled(pool, all_idxs, rnd_state):
        """Pop the next index from ``pool``, first refilling it with a shuffled
        copy of ``all_idxs`` when it is empty (sampling without replacement,
        reshuffled on every pass)."""
        if not pool:
            pool.extend(all_idxs)
            rnd_state.shuffle(pool)
        return pool.pop()

    def batch_func(self, param):
        """Worker loop: produce batches forever.

        Args:
            param: tuple ``(pickled_samples, ct_pickled_samples, rnd_seed)`` as
                built in ``__init__``; ``ct_pickled_samples`` may be None.

        Yields:
            list of numpy arrays, one per output of ``SampleProcessor.process``
            (stacked over the batch), followed by an array of sample indices
            when ``add_sample_idx`` is set.

        Raises:
            Exception: wrapping any error raised by ``SampleProcessor.process``;
                the message contains the failing sample's filename and the
                formatted traceback.
        """
        pickled_samples, ct_pickled_samples, rnd_seed = param

        rnd_state = np.random.RandomState(rnd_seed)

        samples = pickle.loads(pickled_samples)
        idxs = list(range(len(samples)))
        shuffle_idxs = []

        if ct_pickled_samples is not None:
            ct_samples = pickle.loads(ct_pickled_samples)
            ct_idxs = list(range(len(ct_samples)))
            ct_shuffle_idxs = []
        else:
            ct_samples = None

        bs = self.batch_size
        while True:
            batches = None

            for _ in range(bs):
                # Main sample is drawn before the ct sample so the RNG call
                # order stays fixed for a given seed.
                sample_idx = self._pop_shuffled(shuffle_idxs, idxs, rnd_state)
                sample = samples[sample_idx]

                ct_sample = None
                if ct_samples is not None:
                    ct_sample = ct_samples[self._pop_shuffled(ct_shuffle_idxs, ct_idxs, rnd_state)]

                try:
                    # Assumes process() returns one entry per input sample
                    # (a single sample is passed, and the result is unpacked as such).
                    x, = SampleProcessor.process([sample], self.sample_process_options, self.output_sample_types,
                                                 self.debug, ct_sample=ct_sample, rnd_state=rnd_state)
                except Exception as e:
                    # The traceback text is embedded in the message (not only
                    # chained), presumably so it survives being re-raised across
                    # the worker boundary -- verify against SubprocessGenerator.
                    raise Exception("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc())) from e

                if batches is None:
                    # One list per processor output, plus an optional trailing
                    # list for the sample indices.
                    batches = [[] for _ in range(len(x))]
                    if self.add_sample_idx:
                        batches.append([])

                for i, item in enumerate(x):
                    batches[i].append(item)

                if self.add_sample_idx:
                    batches[-1].append(sample_idx)

            yield [np.array(batch) for batch in batches]
|
|