Source code for diluvian.training

# -*- coding: utf-8 -*-
"""Functions for generating training data and training networks."""


from __future__ import division
from __future__ import print_function

import collections
import copy
import itertools
import logging
import random

import matplotlib as mpl
# Use the 'Agg' backend to allow the generation of plots even if no X server
# is available. The matplotlib backend must be set before importing pyplot.
mpl.use('Agg')  # noqa
import matplotlib.pyplot as plt
import numpy as np
import six
from six.moves import range as xrange
import tensorflow as tf
from tqdm import tqdm

import keras.backend as K
from keras.callbacks import (
        Callback,
        EarlyStopping,
        ModelCheckpoint,
        TensorBoard,
        )

from .config import CONFIG
from .network import compile_network, load_model, make_parallel
from .util import (
        get_color_shader,
        get_function,
        pad_dims,
        Roundrobin,
        WrappedViewer,
        write_keras_history_to_csv,
        )
from .volumes import (
        ClipSubvolumeImageGenerator,
        ContrastAugmentGenerator,
        ErodedMaskGenerator,
        GaussianNoiseAugmentGenerator,
        MaskedArtifactAugmentGenerator,
        MirrorAugmentGenerator,
        MissingDataAugmentGenerator,
        partition_volumes,
        PermuteAxesAugmentGenerator,
        RelabelSeedComponentGenerator,
        )
from .regions import (
        Region,
        )


def plot_history(history):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(history.history['loss'])
    ax.plot(history.history['val_loss'])
    ax.plot(history.history['val_subv_metric'])
    fig.suptitle('model loss')
    ax.set_ylabel('loss')
    ax.set_xlabel('epoch')
    ax.legend(['train', 'validation', 'val subvolumes'], loc='upper right')
    return fig

def patch_prediction_copy(model):
    """Patch a Keras model to copy outputs to a kludge during training.

    This is necessary for mask updates to a region during training.

    Parameters
    ----------
    model : keras.engine.Model
    """
    model.train_function = None
    model.test_function = None

    model._orig_train_on_batch = model.train_on_batch

    def train_on_batch(self, x, y, **kwargs):
        kludge = x.pop('kludge', None)
        outputs = self._orig_train_on_batch(x, y, **kwargs)
        kludge['outputs'] = outputs.pop()
        if len(outputs) == 1:
            return outputs[0]
        return outputs

    model.train_on_batch = six.create_bound_method(train_on_batch, model)

    model._orig_test_on_batch = model.test_on_batch

    def test_on_batch(self, x, y, **kwargs):
        kludge = x.pop('kludge', None)
        outputs = self._orig_test_on_batch(x, y, **kwargs)
        kludge['outputs'] = outputs.pop()
        if len(outputs) == 1:
            return outputs[0]
        return outputs

    model.test_on_batch = six.create_bound_method(test_on_batch, model)

    # Below is copied and modified from Keras Model._make_train_function.
    # The only change is the addition of `self.outputs` to the train function.
    def _make_train_function(self):
        if not hasattr(self, 'train_function'):
            raise RuntimeError('You must compile your model before using it.')
        if self.train_function is None:
            inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
            if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
                inputs += [K.learning_phase()]

            with K.name_scope('training'):
                with K.name_scope(self.optimizer.__class__.__name__):
                    training_updates = self.optimizer.get_updates(
                        params=self._collected_trainable_weights,
                        loss=self.total_loss)
                updates = self.updates + training_updates
                # Gets loss and metrics. Updates weights at each call.
                self.train_function = K.function(
                    inputs,
                    [self.total_loss] + self.metrics_tensors + self.outputs,
                    updates=updates,
                    name='train_function',
                    **self._function_kwargs)

    model._make_train_function = six.create_bound_method(_make_train_function, model)

    def _make_test_function(self):
        if not hasattr(self, 'test_function'):
            raise RuntimeError('You must compile your model before using it.')
        if self.test_function is None:
            inputs = self._feed_inputs + self._feed_targets + self._feed_sample_weights
            if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
                inputs += [K.learning_phase()]
            # Return loss and metrics, no gradient updates.
            # Does update the network states.
            self.test_function = K.function(
                inputs,
                [self.total_loss] + self.metrics_tensors + self.outputs,
                updates=self.state_updates,
                name='test_function',
                **self._function_kwargs)

    model._make_test_function = six.create_bound_method(_make_test_function, model)

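# Illustrative sketch (not part of the original module) of the calling
# convention that ``patch_prediction_copy`` establishes: the training
# generator places a shared dict under the 'kludge' key of the input dict,
# and after ``train_on_batch`` the patched model has copied its mask output
# into ``kludge['outputs']`` so the generator can apply it to the region
# before choosing the next move. The argument names here are hypothetical.
def _example_kludge_round_trip(patched_model, image_input, mask_input, mask_target):
    kludge = {'inputs': None, 'outputs': None}
    inputs = collections.OrderedDict({'image_input': image_input,
                                      'mask_input': mask_input})
    inputs['kludge'] = kludge
    loss = patched_model.train_on_batch(inputs, [mask_target])
    # Filled in by the patched train_on_batch above.
    predicted_mask = kludge['outputs']
    return loss, predicted_mask
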
class GeneratorReset(Callback):
    """Keras epoch end callback to reset prediction copy kludges."""
    def __init__(self, gens):
        self.gens = gens

    def on_epoch_end(self, epoch, logs=None):
        for gen in self.gens:
            gen.reset()

class GeneratorSubvolumeMetric(Callback):
    """Add a data generator's subvolume metric to Keras' metric logs.

    Parameters
    ----------
    gens : iterable of diluvian.training.MovingTrainingGenerator
    metric_name : string
    """
    def __init__(self, gens, metric_name):
        self.gens = gens
        self.metric_name = metric_name

    def on_epoch_end(self, epoch, logs=None):
        if self.metric_name not in self.params['metrics']:
            self.params['metrics'].append(self.metric_name)
        if logs:
            metric = np.mean([np.mean(gen.get_epoch_metric()) for gen in self.gens])
            logs[self.metric_name] = metric

class EarlyAbortException(Exception):
    pass

class EarlyAbort(Callback):
    """Keras epoch end callback that aborts if a metric is above a threshold.

    This is useful when convergence is sensitive to initial conditions and
    models are obviously not useful to continue training after only a few
    epochs. Unlike the early stopping callback, this is considered an
    abnormal termination and throws an exception so that behaviors like
    restarting with a new random seed are possible.
    """
    def __init__(self, monitor='val_loss', threshold_epoch=None, threshold_value=None):
        if threshold_epoch is None or threshold_value is None:
            raise ValueError('Epoch and value to enforce threshold must be provided.')
        self.monitor = monitor
        self.threshold_epoch = threshold_epoch - 1
        self.threshold_value = threshold_value

    def on_epoch_end(self, epoch, logs=None):
        if epoch == self.threshold_epoch:
            current = logs.get(self.monitor)
            if current >= self.threshold_value:
                raise EarlyAbortException('Aborted after epoch {} because {} was {} >= {}'.format(
                    self.threshold_epoch, self.monitor, current, self.threshold_value))

def preprocess_subvolume_generator(subvolume_generator):
    """Apply non-augmentation preprocessing to a subvolume generator.

    Parameters
    ----------
    subvolume_generator : diluvian.volumes.SubvolumeGenerator

    Returns
    -------
    diluvian.volumes.SubvolumeGenerator
    """
    gen = subvolume_generator
    if np.any(CONFIG.training.label_erosion):
        gen = ErodedMaskGenerator(gen, CONFIG.training.label_erosion)
    if CONFIG.training.relabel_seed_component:
        gen = RelabelSeedComponentGenerator(gen)

    return gen

def augment_subvolume_generator(subvolume_generator):
    """Apply data augmentations to a subvolume generator.

    Parameters
    ----------
    subvolume_generator : diluvian.volumes.SubvolumeGenerator

    Returns
    -------
    diluvian.volumes.SubvolumeGenerator
    """
    gen = subvolume_generator
    for axes in CONFIG.training.augment_permute_axes:
        gen = PermuteAxesAugmentGenerator(gen, CONFIG.training.augment_use_both, axes)
    for axis in CONFIG.training.augment_mirrors:
        gen = MirrorAugmentGenerator(gen, CONFIG.training.augment_use_both, axis)
    for v in CONFIG.training.augment_noise:
        gen = GaussianNoiseAugmentGenerator(gen, CONFIG.training.augment_use_both,
                                            v['axis'], v['mul'], v['add'])
    for v in CONFIG.training.augment_artifacts:
        if 'cache' not in v:
            v['cache'] = {}
        gen = MaskedArtifactAugmentGenerator(gen, CONFIG.training.augment_use_both,
                                             v['axis'], v['prob'],
                                             v['volume_file'], v['cache'])
    for v in CONFIG.training.augment_missing_data:
        gen = MissingDataAugmentGenerator(gen, CONFIG.training.augment_use_both,
                                          v['axis'], v['prob'])
    for v in CONFIG.training.augment_contrast:
        gen = ContrastAugmentGenerator(gen, CONFIG.training.augment_use_both,
                                       v['axis'], v['prob'],
                                       v['scaling_mean'], v['scaling_std'],
                                       v['center_mean'], v['center_std'])
    gen = ClipSubvolumeImageGenerator(gen)

    return gen

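# Illustrative sketch (not part of the original module) of the kind of chain
# augment_subvolume_generator builds: here a configuration with a single
# mirrored axis and a single Gaussian-noise entry, using hypothetical values
# (mirror axis 0; noise on axis 0 with multiplicative and additive sigma 0.1).
def _example_augmentation_chain(subvolume_generator):
    gen = MirrorAugmentGenerator(subvolume_generator, True, 0)
    gen = GaussianNoiseAugmentGenerator(gen, True, 0, 0.1, 0.1)
    return ClipSubvolumeImageGenerator(gen)
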
class MovingTrainingGenerator(six.Iterator):
    """Generate Keras moving FOV training tuples from a subvolume generator.

    This generator expects a subvolume generator that will provide subvolumes
    larger than the network FOV, and will allow the output of training at one
    batch to generate moves within these subvolumes to produce training data
    for the subsequent batch.

    Parameters
    ----------
    subvolumes : generator of Subvolume
    batch_size : int
    kludge : dict
        A kludge object to allow this generator to provide inputs and
        receive outputs from the network.
        See ``diluvian.training.patch_prediction_copy``.
    f_a_bins : sequence of float, optional
        Bin boundaries for filling fractions. If provided, sample loss will
        be weighted to increase loss contribution from less-frequent f_a
        bins. Otherwise all samples are weighted equally.
    reset_generators : bool
        Whether to reset subvolume generators when this generator is reset.
        If true, subvolumes will be sampled in the same order each epoch.
    subv_per_epoch : int, optional
        If specified, the generator will only return moves from this many
        subvolumes before being reset. Once this number of subvolumes is
        exceeded, the generator will yield garbage batches (this is
        necessary because Keras currently uses a fixed number of batches
        per epoch). If specified, once each subvolume is complete its total
        loss will be calculated.
    subv_metric_fn : function, optional
        Metric function to run on subvolumes when `subv_per_epoch` is set.
    subv_metric_threshold : bool, optional
        Whether to threshold subvolume masks for metrics.
    subv_metric_args : dict, optional
        Keyword arguments that will be passed to the subvolume metric.
    """
    def __init__(self, subvolumes, batch_size, kludge,
                 f_a_bins=None, reset_generators=True, subv_per_epoch=None,
                 subv_metric_fn=None, subv_metric_threshold=False, subv_metric_args=None):
        self.subvolumes = subvolumes
        self.batch_size = batch_size
        self.kludge = kludge
        self.reset_generators = reset_generators
        self.subv_per_epoch = subv_per_epoch
        self.subv_metric_fn = subv_metric_fn
        self.subv_metric_threshold = subv_metric_threshold
        self.subv_metric_args = subv_metric_args
        if self.subv_metric_args is None:
            self.subv_metric_args = {}

        self.regions = [None] * batch_size
        self.region_pos = [None] * batch_size
        self.move_counts = [0] * batch_size
        self.epoch_move_counts = []
        self.epoch_subv_metrics = []
        self.epoch_subvolumes = 0
        self.batch_image_input = [None] * batch_size

        self.f_a_bins = f_a_bins
        self.f_a_init = False
        if f_a_bins is not None:
            self.f_a_init = True
            self.f_a_counts = np.ones_like(f_a_bins, dtype=np.int64)
        self.f_as = np.zeros(batch_size)

        self.fake_block = None
        self.fake_mask = [False] * batch_size

    def __iter__(self):
        return self

    def reset(self):
        self.f_a_init = False
        if self.reset_generators:
            self.subvolumes.reset()
        self.regions = [None] * self.batch_size
        self.kludge['inputs'] = None
        self.kludge['outputs'] = None
        if len(self.epoch_move_counts):
            logging.info(' Average moves (%s): %s', self.subvolumes.name,
                         sum(self.epoch_move_counts) / float(len(self.epoch_move_counts)))
        self.epoch_move_counts = []
        self.epoch_subvolumes = 0
        self.epoch_subv_metrics = []
        self.fake_mask = [False] * self.batch_size

    def get_epoch_metric(self):
        assert len(self.epoch_subv_metrics) == self.subv_per_epoch, \
            'Not all validation subvs completed: {}/{} (Finished moves: {}, ongoing: {})'.format(
                len(self.epoch_subv_metrics), self.subv_per_epoch,
                self.epoch_move_counts, self.move_counts)
        return self.epoch_subv_metrics

    def __next__(self):
        # If in the fixed-subvolumes-per-epoch mode and completed, yield fake
        # data quickly.
        if all(self.fake_mask):
            inputs = collections.OrderedDict({
                'image_input': np.repeat(pad_dims(self.fake_block['image']),
                                         CONFIG.training.num_gpus, axis=0),
                'mask_input': np.repeat(pad_dims(self.fake_block['mask']),
                                        CONFIG.training.num_gpus, axis=0)})
            inputs['kludge'] = self.kludge
            outputs = np.repeat(pad_dims(self.fake_block['target']),
                                CONFIG.training.num_gpus, axis=0)
            return (inputs, outputs)

        # Before clearing last batches, reuse them to predict mask outputs
        # for move training. Add mask outputs to regions.
        active_regions = [n for n, region in enumerate(self.regions) if region is not None]
        if active_regions and self.kludge['outputs'] is not None and self.kludge['inputs'] is not None:
            for n in active_regions:
                assert np.array_equal(self.kludge['inputs'][n, :],
                                      self.batch_image_input[n, 0, 0, :, 0])
                self.regions[n].add_mask(self.kludge['outputs'][n, :, :, :, 0],
                                         self.region_pos[n])

        self.batch_image_input = [None] * self.batch_size
        batch_mask_input = [None] * self.batch_size
        batch_mask_target = [None] * self.batch_size

        for r, region in enumerate(self.regions):
            block_data = region.get_next_block() if region is not None else None

            if block_data is None:
                if self.subv_per_epoch:
                    if region is not None:
                        metric = region.prediction_metric(
                            self.subv_metric_fn,
                            threshold=self.subv_metric_threshold,
                            **self.subv_metric_args)
                        self.epoch_subv_metrics.append(metric)
                        self.regions[r] = None
                    if self.epoch_subvolumes >= self.subv_per_epoch:
                        block_data = self.fake_block
                        self.fake_mask[r] = True

                while block_data is None:
                    subvolume = six.next(self.subvolumes)
                    self.epoch_subvolumes += 1

                    self.f_as[r] = subvolume.f_a()
                    self.regions[r] = Region.from_subvolume(subvolume)
                    if region is not None:
                        self.epoch_move_counts.append(self.move_counts[r])
                    region = self.regions[r]
                    self.move_counts[r] = 0
                    block_data = region.get_next_block()
            else:
                self.move_counts[r] += 1

            if self.subv_per_epoch and self.fake_block is None:
                assert block_data is not None
                self.fake_block = copy.deepcopy(block_data)

            self.batch_image_input[r] = pad_dims(block_data['image'])
            batch_mask_input[r] = pad_dims(block_data['mask'])
            batch_mask_target[r] = pad_dims(block_data['target'])
            self.region_pos[r] = block_data['position']

        self.batch_image_input = np.concatenate(self.batch_image_input)
        batch_mask_input = np.concatenate(batch_mask_input)
        batch_mask_target = np.concatenate(batch_mask_target)

        inputs = collections.OrderedDict({'image_input': self.batch_image_input,
                                          'mask_input': batch_mask_input})
        inputs['kludge'] = self.kludge
        # These inputs are only necessary for assurance the correct FOV is updated.
        self.kludge['inputs'] = self.batch_image_input[:, 0, 0, :, 0].copy()
        self.kludge['outputs'] = None

        if self.f_a_bins is None:
            return (inputs, [batch_mask_target])
        else:
            f_a_inds = np.digitize(self.f_as, self.f_a_bins) - 1
            inds, counts = np.unique(f_a_inds, return_counts=True)
            if self.f_a_init:
                self.f_a_counts[inds] += counts.astype(np.int64)
                sample_weights = np.ones(self.f_as.size, dtype=np.float64)
            else:
                sample_weights = np.reciprocal(self.f_a_counts[f_a_inds],
                                               dtype=np.float64) * float(self.f_as.size)
            return (inputs, [batch_mask_target], sample_weights)

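# Illustrative sketch (not part of the original module) of the sample-weight
# scheme MovingTrainingGenerator uses when ``f_a_bins`` is given: filling
# fractions are binned, a running count is kept per bin, and each sample is
# weighted by the reciprocal of its bin count so under-represented filling
# fractions contribute more to the loss. All values below are hypothetical.
def _example_f_a_sample_weights():
    f_a_bins = np.array([0.0, 0.25, 0.5, 0.75])
    f_a_counts = np.array([120, 40, 10, 2], dtype=np.int64)  # running bin counts
    f_as = np.array([0.1, 0.6, 0.8, 0.3])  # filling fractions of one batch
    f_a_inds = np.digitize(f_as, f_a_bins) - 1
    return np.reciprocal(f_a_counts[f_a_inds], dtype=np.float64) * float(f_as.size)
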
DataGenerator = collections.namedtuple('DataGenerator', ['data', 'gens', 'callbacks', 'steps_per_epoch'])
def get_output_margin(model_config):
    return np.floor_divide(model_config.input_fov_shape - model_config.output_fov_shape, 2)

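# Worked example (hypothetical FOV shapes, not from the configuration): an
# input FOV of [17, 33, 33] voxels with an output FOV of [9, 17, 17] voxels
# leaves a margin of (17 - 9) // 2 = 4 sections axially and (33 - 17) // 2 = 8
# voxels laterally on each side, which is the label margin that subvolume
# generators must supply around the output.
def _example_output_margin():
    input_fov_shape = np.array([17, 33, 33])
    output_fov_shape = np.array([9, 17, 17])
    return np.floor_divide(input_fov_shape - output_fov_shape, 2)  # -> [4, 8, 8]
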
def build_validation_gen(validation_volumes):
    output_margin = get_output_margin(CONFIG.model)

    # If there is only one volume, duplicate since more than one is needed
    # for Keras queuing.
    if len(validation_volumes) == 1:
        single_vol = six.next(six.itervalues(validation_volumes))
        validation_volumes = {'dupe {}'.format(n): single_vol
                              for n in range(CONFIG.training.num_workers)}

    validation_gens = [
        preprocess_subvolume_generator(
            v.subvolume_generator(shape=CONFIG.model.validation_subv_shape,
                                  label_margin=output_margin))
        for v in six.itervalues(validation_volumes)]
    if CONFIG.training.augment_validation:
        validation_gens = list(map(augment_subvolume_generator, validation_gens))

    # Divide validation generators up for workers.
    validation_worker_gens = [
        validation_gens[i::CONFIG.training.num_workers]
        for i in xrange(CONFIG.training.num_workers)]
    # Some workers may not receive any generators.
    validation_worker_gens = [g for g in validation_worker_gens if len(g) > 0]
    subv_per_worker = CONFIG.training.validation_size // len(validation_worker_gens)
    logging.debug('# of validation workers: %s', len(validation_worker_gens))

    validation_metric = get_function(CONFIG.training.validation_metric['metric'])

    validation_kludges = [{'inputs': None, 'outputs': None}
                          for _ in range(CONFIG.training.num_workers)]
    validation_data = [MovingTrainingGenerator(
        Roundrobin(*gen, name='validation {}'.format(i)),
        CONFIG.training.batch_size,
        kludge,
        f_a_bins=CONFIG.training.fill_factor_bins,
        reset_generators=True,
        subv_per_epoch=subv_per_worker,
        subv_metric_fn=validation_metric,
        subv_metric_threshold=CONFIG.training.validation_metric['threshold'],
        subv_metric_args=CONFIG.training.validation_metric['args'])
        for i, (gen, kludge) in enumerate(zip(validation_worker_gens, validation_kludges))]

    callbacks = []
    callbacks.append(GeneratorSubvolumeMetric(validation_data, 'val_subv_metric'))
    callbacks.append(GeneratorReset(validation_data))

    VALIDATION_STEPS = np.ceil(CONFIG.training.validation_size / CONFIG.training.batch_size)
    # Number of all-move sequences must be a multiple of number of worker gens.
    VALIDATION_STEPS = np.ceil(VALIDATION_STEPS / len(validation_worker_gens)) * len(validation_worker_gens)
    VALIDATION_STEPS = VALIDATION_STEPS * CONFIG.model.validation_subv_moves + len(validation_worker_gens)
    VALIDATION_STEPS = VALIDATION_STEPS.astype(np.int64)

    return DataGenerator(
        data=validation_data,
        gens=validation_worker_gens,
        callbacks=callbacks,
        steps_per_epoch=VALIDATION_STEPS)

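# Worked example (hypothetical values) of the validation step arithmetic
# above: with validation_size=64, batch_size=8, 4 worker generators, and
# validation_subv_moves=6, the base step count is ceil(64 / 8) = 8, which is
# already a multiple of the worker count, so the total is 8 * 6 + 4 = 52
# validation steps per epoch.
def _example_validation_steps(validation_size=64, batch_size=8, num_worker_gens=4, subv_moves=6):
    steps = np.ceil(validation_size / batch_size)
    steps = np.ceil(steps / num_worker_gens) * num_worker_gens
    return int(steps * subv_moves + num_worker_gens)
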
def build_training_gen(training_volumes):
    output_margin = get_output_margin(CONFIG.model)

    # If there is only one volume, duplicate since more than one is needed
    # for Keras queuing.
    if len(training_volumes) == 1:
        single_vol = six.next(six.itervalues(training_volumes))
        training_volumes = {'dupe {}'.format(n): single_vol
                            for n in range(CONFIG.training.num_workers)}

    training_gens = [
        augment_subvolume_generator(
            preprocess_subvolume_generator(
                v.subvolume_generator(shape=CONFIG.model.training_subv_shape,
                                      label_margin=output_margin)))
        for v in six.itervalues(training_volumes)]
    random.shuffle(training_gens)

    # Divide training generators up for workers.
    worker_gens = [
        training_gens[i::CONFIG.training.num_workers]
        for i in xrange(CONFIG.training.num_workers)]
    # Some workers may not receive any generators.
    worker_gens = [g for g in worker_gens if len(g) > 0]
    logging.debug('# of training workers: %s', len(worker_gens))

    kludges = [{'inputs': None, 'outputs': None} for _ in range(CONFIG.training.num_workers)]
    # Create a training data generator for each worker.
    training_data = [MovingTrainingGenerator(
        Roundrobin(*gen, name='training {}'.format(i)),
        CONFIG.training.batch_size,
        kludge,
        f_a_bins=CONFIG.training.fill_factor_bins,
        reset_generators=CONFIG.training.reset_generators)
        for i, (gen, kludge) in enumerate(zip(worker_gens, kludges))]

    training_reset_callback = GeneratorReset(training_data)
    callbacks = [training_reset_callback]

    TRAINING_STEPS_PER_EPOCH = CONFIG.training.training_size // CONFIG.training.batch_size

    return DataGenerator(
        data=training_data,
        gens=worker_gens,
        callbacks=callbacks,
        steps_per_epoch=TRAINING_STEPS_PER_EPOCH)

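# Illustrative sketch (not part of the original module) of the strided-slice
# split used above to distribute generators across workers: with five
# generators and two workers, worker 0 receives generators [0, 2, 4] and
# worker 1 receives [1, 3]; empty worker lists are dropped.
def _example_worker_split(gens, num_workers=2):
    worker_gens = [gens[i::num_workers] for i in range(num_workers)]
    return [g for g in worker_gens if len(g) > 0]
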
def train_network(
        model_file=None,
        volumes=None,
        model_output_filebase=None,
        model_checkpoint_file=None,
        tensorboard=False,
        viewer=False,
        metric_plot=False):
    random.seed(CONFIG.random_seed)

    tf_device = 'cpu:0' if CONFIG.training.num_gpus > 1 else 'gpu:0'
    if model_file is None:
        factory = get_function(CONFIG.network.factory)
        with tf.device(tf_device):
            ffn = factory(CONFIG.model.input_fov_shape,
                          CONFIG.model.output_fov_shape,
                          CONFIG.network)
    else:
        with tf.device(tf_device):
            ffn = load_model(model_file, CONFIG.network)

    # Multi-GPU models are saved as a single-GPU model prior to compilation,
    # so if loading from such a model file it will need to be recompiled.
    if not hasattr(ffn, 'optimizer'):
        if CONFIG.training.num_gpus > 1:
            ffn = make_parallel(ffn, CONFIG.training.num_gpus)
        compile_network(ffn, CONFIG.optimizer)

    patch_prediction_copy(ffn)

    if model_output_filebase is None:
        model_output_filebase = 'model_output'
    if volumes is None:
        raise ValueError('Volumes must be provided.')

    CONFIG.to_toml(model_output_filebase + '.toml')

    training_volumes, validation_volumes = partition_volumes(volumes)

    num_training = len(training_volumes)
    num_validation = len(validation_volumes)

    logging.info('Using {} volumes for training, {} for validation.'.format(
        num_training, num_validation))

    validation = build_validation_gen(validation_volumes)
    training = build_training_gen(training_volumes)

    callbacks = []
    callbacks.extend(validation.callbacks)
    callbacks.extend(training.callbacks)
    validation_mode = CONFIG.training.validation_metric['mode']
    if CONFIG.training.early_abort_epoch is not None and \
       CONFIG.training.early_abort_loss is not None:
        callbacks.append(EarlyAbort(threshold_epoch=CONFIG.training.early_abort_epoch,
                                    threshold_value=CONFIG.training.early_abort_loss))
    callbacks.append(ModelCheckpoint(model_output_filebase + '.hdf5',
                                     monitor='val_subv_metric',
                                     save_best_only=True,
                                     mode=validation_mode))
    if model_checkpoint_file:
        callbacks.append(ModelCheckpoint(model_checkpoint_file))
    callbacks.append(EarlyStopping(monitor='val_subv_metric',
                                   patience=CONFIG.training.patience,
                                   mode=validation_mode))
    # Activation histograms and weight images for TensorBoard will not work
    # because the Keras callback does not currently support validation data
    # generators.
    if tensorboard:
        callbacks.append(TensorBoard())

    history = ffn.fit_generator(
        Roundrobin(*training.data, name='training outer'),
        steps_per_epoch=training.steps_per_epoch,
        epochs=CONFIG.training.total_epochs,
        max_queue_size=len(training.gens) - 1,
        workers=1,
        callbacks=callbacks,
        validation_data=Roundrobin(*validation.data, name='validation outer'),
        validation_steps=validation.steps_per_epoch)

    write_keras_history_to_csv(history, model_output_filebase + '.csv')

    if viewer:
        viz_ex = itertools.islice(validation.data[0], 1)
        for inputs, targets in viz_ex:
            viewer = WrappedViewer(voxel_size=list(np.flipud(CONFIG.volume.resolution)))
            output_offset = np.array(inputs['image_input'].shape[1:4]) - np.array(targets[0].shape[1:4])
            output_offset = np.flipud(output_offset // 2)
            viewer.add(inputs['image_input'][0, :, :, :, 0],
                       name='Image')
            viewer.add(inputs['mask_input'][0, :, :, :, 0],
                       name='Mask Input',
                       shader=get_color_shader(2))
            viewer.add(targets[0][0, :, :, :, 0],
                       name='Mask Target',
                       shader=get_color_shader(0),
                       voxel_offset=output_offset)
            output = ffn.predict_on_batch(inputs)
            viewer.add(output[0, :, :, :, 0],
                       name='Mask Output',
                       shader=get_color_shader(1),
                       voxel_offset=output_offset)
            viewer.print_view_prompt()

    if metric_plot:
        fig = plot_history(history)
        fig.savefig(model_output_filebase + '.png')

    return history

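# Illustrative sketch (not part of the original module): a minimal call to
# train_network, assuming ``volumes`` is a dict of diluvian volumes already
# opened from a volume configuration and CONFIG has been loaded. The output
# file base name is hypothetical.
def _example_train(volumes):
    return train_network(volumes=volumes,
                         model_output_filebase='ffn_run0',
                         metric_plot=True)
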
def validate_model(model_file, volumes):
    from .network import load_model

    _, volumes = partition_volumes(volumes)
    validation = build_validation_gen(volumes)

    tf_device = 'cpu:0' if CONFIG.training.num_gpus > 1 else 'gpu:0'
    with tf.device(tf_device):
        model = load_model(model_file, CONFIG.network)

    # Multi-GPU models are saved as a single-GPU model prior to compilation,
    # so if loading from such a model file it will need to be recompiled.
    if not hasattr(model, 'optimizer'):
        if CONFIG.training.num_gpus > 1:
            model = make_parallel(model, CONFIG.training.num_gpus)
        compile_network(model, CONFIG.optimizer)

    patch_prediction_copy(model)

    pbar = tqdm(desc='Validation batches', total=validation.steps_per_epoch)
    finished = [False] * len(validation.gens)
    for n, data in itertools.cycle(enumerate(validation.data)):
        if all(finished):
            break
        pbar.update(1)
        if all(data.fake_mask):
            finished[n] = True
            continue
        batch = six.next(data)
        model.test_on_batch(*batch)
    pbar.close()

    metrics = []
    for gen in validation.data:
        metrics.extend(gen.get_epoch_metric())

    print('Metric: ', np.mean(metrics))
    print('All: ', metrics)