# Source code for diluvian.network

# -*- coding: utf-8 -*-
"""Flood-fill network creation and compilation using Keras."""


from __future__ import division

import inspect

import numpy as np
import six

from keras.layers import (
        BatchNormalization,
        Conv3D,
        Conv3DTranspose,
        Cropping3D,
        Dropout,
        Input,
        Lambda,
        Permute,
        )
from keras.layers.merge import (
        add,
        concatenate,
        )
from keras.layers.core import Activation
from keras.models import load_model as keras_load_model, Model
from keras.utils import multi_gpu_model
import keras.optimizers


def make_flood_fill_network(input_fov_shape, output_fov_shape, network_config):
    """Construct a stacked convolution module flood filling network.

    Image and mask inputs (each with spatial shape ``input_fov_shape`` and a
    single channel) are concatenated, passed through an initial convolution,
    then through ``network_config.num_modules`` residual convolution modules
    (built by ``add_convolution_module``), and finally through a convolution
    producing the single-channel ``mask_output``. Between modules the volume
    is cropped symmetrically. The total crop per side is
    ``(input_fov_shape - output_fov_shape) // 2``.

    Parameters
    ----------
    input_fov_shape : array-like of int
        Spatial (3D) shape of the image and mask inputs.
    output_fov_shape : array-like of int
        Spatial (3D) shape of the desired output. Must not exceed
        ``input_fov_shape`` in any dimension.
    network_config : object
        Network configuration. This function reads ``convolution_padding``,
        ``rescale_image``, ``convolution_filters``, ``convolution_dim``,
        ``initialization``, ``convolution_activation``,
        ``batch_normalization``, ``num_modules`` and ``output_activation``.
        ``add_convolution_module`` reads further attributes.

    Returns
    -------
    keras.models.Model
        Uncompiled model with inputs ``[image_input, mask_input]`` and output
        ``[mask_output]``.

    Raises
    ------
    ValueError
        If ``convolution_padding`` is not ``'same'``, or if the output FOV is
        larger than the input FOV in any dimension.
    """
    if network_config.convolution_padding != 'same':
        raise ValueError('ResNet implementation only supports same padding.')

    # Validate the FOV shapes before any layers are created. ``asarray``
    # lets callers pass plain sequences as well as numpy arrays.
    contraction = (np.asarray(input_fov_shape) - np.asarray(output_fov_shape)) // 2
    if np.any(np.less(contraction, 0)):
        raise ValueError('Output FOV shape can not be larger than input FOV shape.')

    def symmetric_cropping(amounts):
        # Cropping3D needs a sized sequence of (before, after) pairs. A bare
        # ``zip`` object (Python 3) has no ``len`` and is rejected, so build a
        # concrete tuple of plain ints instead.
        return Cropping3D(tuple((int(a), int(a)) for a in amounts))

    image_input = Input(shape=tuple(input_fov_shape) + (1,), dtype='float32', name='image_input')
    if network_config.rescale_image:
        # Affine rescale mapping 0 -> -1 and 1 -> 1.
        ffn = Lambda(lambda x: (x - 0.5) * 2.0)(image_input)
    else:
        ffn = image_input
    mask_input = Input(shape=tuple(input_fov_shape) + (1,), dtype='float32', name='mask_input')
    ffn = concatenate([ffn, mask_input])

    # Convolve and activate before beginning the skip connection modules,
    # as discussed in the Appendix of He et al 2016.
    ffn = Conv3D(
            network_config.convolution_filters,
            tuple(network_config.convolution_dim),
            kernel_initializer=network_config.initialization,
            activation=network_config.convolution_activation,
            padding='same')(ffn)
    if network_config.batch_normalization:
        ffn = BatchNormalization()(ffn)

    # Spread the crop across modules: after module ``i`` the cumulative
    # per-side crop is raised to ``floor(i * contraction / num_modules)``.
    # ``max(..., 1)`` only avoids a divide-by-zero warning when there are no
    # modules; in that case the loop does not run at all.
    contraction_cumu = np.zeros(3, dtype=np.int32)
    contraction_step = np.divide(contraction, float(max(network_config.num_modules, 1)))
    for i in range(network_config.num_modules):
        ffn = add_convolution_module(ffn, network_config)
        contraction_dims = np.floor(i * contraction_step - contraction_cumu).astype(np.int32)
        if np.count_nonzero(contraction_dims):
            ffn = symmetric_cropping(contraction_dims)(ffn)
            contraction_cumu += contraction_dims

    # Crop whatever the per-module schedule has not yet removed.
    if np.any(np.less(contraction_cumu, contraction)):
        remainder = contraction - contraction_cumu
        ffn = symmetric_cropping(remainder)(ffn)

    mask_output = Conv3D(
            1,
            tuple(network_config.convolution_dim),
            kernel_initializer=network_config.initialization,
            padding='same',
            name='mask_output',
            activation=network_config.output_activation)(ffn)
    return Model(inputs=[image_input, mask_input], outputs=[mask_output])
def add_convolution_module(model, network_config):
    """Append one residual convolution module to ``model``.

    The residual branch consists of ``num_layers_per_module`` same-padded 3D
    convolutions, each optionally followed by batch normalization. The branch
    output is summed with the module input. The sum is then activated,
    optionally batch-normalized, and optionally passed through dropout.

    Parameters
    ----------
    model : tensor
        Input tensor of the module. It is added element-wise to the branch
        output, so its channel count presumably has to equal
        ``network_config.convolution_filters`` (TODO confirm).
    network_config : object
        Reads ``num_layers_per_module``, ``convolution_filters``,
        ``convolution_dim``, ``initialization``, ``convolution_activation``,
        ``batch_normalization`` and ``dropout_probability``.

    Returns
    -------
    tensor
        Output tensor of the module.
    """
    use_batch_norm = network_config.batch_normalization

    branch = model
    for _ in range(network_config.num_layers_per_module):
        conv_layer = Conv3D(
            network_config.convolution_filters,
            tuple(network_config.convolution_dim),
            kernel_initializer=network_config.initialization,
            activation=network_config.convolution_activation,
            padding='same')
        branch = conv_layer(branch)
        if use_batch_norm:
            branch = BatchNormalization()(branch)

    merged = add([model, branch])

    # This activation differs from He et al 2016, where the activation is not
    # on the skip connection path. That is unlikely to matter, see:
    # http://torch.ch/blog/2016/02/04/resnets.html
    # https://github.com/gcr/torch-residual-networks
    merged = Activation(network_config.convolution_activation)(merged)
    if use_batch_norm:
        merged = BatchNormalization()(merged)
    if network_config.dropout_probability > 0.0:
        merged = Dropout(network_config.dropout_probability)(merged)
    return merged
def make_flood_fill_unet(input_fov_shape, output_fov_shape, network_config):
    """Construct a U-net flood filling network.

    Image and mask inputs (each with spatial shape ``input_fov_shape`` and a
    single channel) are concatenated and fed through a U-net built
    recursively by ``add_unet_layer``. A 1x1x1 convolution named
    ``mask_output`` then produces the single-channel output.

    Parameters
    ----------
    input_fov_shape : array-like of int
        Spatial shape of the image and mask inputs.
    output_fov_shape : array-like of int
        Target spatial shape of the output, forwarded to ``add_unet_layer``.
    network_config : object
        Reads ``rescale_image``, ``unet_depth``, ``convolution_filters``,
        ``initialization``, ``convolution_padding`` and
        ``output_activation``. ``add_unet_layer`` reads further attributes.

    Returns
    -------
    keras.models.Model
        Uncompiled model with inputs ``[image_input, mask_input]`` and output
        ``[mask_output]``.
    """
    single_channel_shape = tuple(input_fov_shape) + (1,)

    image_input = Input(shape=single_channel_shape, dtype='float32', name='image_input')
    # Optional affine rescale mapping 0 -> -1 and 1 -> 1.
    image = (Lambda(lambda x: (x - 0.5) * 2.0)(image_input)
             if network_config.rescale_image
             else image_input)
    mask_input = Input(shape=single_channel_shape, dtype='float32', name='mask_input')
    merged = concatenate([image, mask_input])

    # Since the Keras 2 upgrade, TF strangely rejects models with depth > 3.
    unet_out = add_unet_layer(merged,
                              network_config,
                              network_config.unet_depth - 1,
                              output_fov_shape,
                              n_channels=network_config.convolution_filters)

    output_layer = Conv3D(
        1,
        (1, 1, 1),
        kernel_initializer=network_config.initialization,
        padding=network_config.convolution_padding,
        name='mask_output',
        activation=network_config.output_activation)
    mask_output = output_layer(unet_out)
    return Model(inputs=[image_input, mask_input], outputs=[mask_output])
def add_unet_layer(model, network_config, remaining_layers, output_shape, n_channels=None, resolution=None):
    """Recursively build one level of a U-net and every level below it.

    At each level a first convolution module is applied and its output is
    cropped to form the forward (skip) link. Unless this is the terminal
    level, the function then:

    - downsamples with a strided convolution;
    - recurses;
    - upsamples with a transposed convolution and crops the result;
    - concatenates it with the forward link;
    - applies a second convolution module.

    Parameters
    ----------
    model : tensor
        Input tensor for this level. It is treated as channels-last with
        three spatial dims: ``[1:4]`` of its shape is read as the spatial
        shape and ``[-1]`` as the channel count.
    network_config : object
        Reads ``unet_downsample_mode``, ``unet_downsample_rate`` (fixed-rate
        mode), ``resolution`` (otherwise), ``convolution_padding``,
        ``convolution_dim``, ``num_layers_per_module``, ``initialization``,
        ``convolution_activation``, ``batch_normalization`` and
        ``dropout_probability``. ``convolution_dim`` is used in array
        arithmetic (``- 1``, ``// 2``), so it is presumably a numpy array
        (TODO confirm).
    remaining_layers : int
        Number of levels still to build below this one. At ``<= 0`` this is
        the terminal level: only the first module and crop are applied.
    output_shape : array-like of int
        Target spatial shape of this level's output.
    n_channels : int, optional
        Base channel count. If None, it is taken from the last dim of
        ``model``. It is doubled on the last layer of the first module.
    resolution : array-like, optional
        Only used when ``unet_downsample_mode`` is not ``"fixed_rate"``.
        Defaults to ``network_config.resolution``. When recursing it is
        multiplied by the per-axis stride.

    Returns
    -------
    tensor
        Output tensor of this level.
    """
    if n_channels is None:
        n_channels = model.get_shape().as_list()[-1]

    # ``downsample`` is a per-axis boolean array; strides below are
    # ``downsample + 1`` (2 on downsampled axes, 1 elsewhere).
    if network_config.unet_downsample_mode == "fixed_rate":
        # An axis is downsampled when its rate is nonzero and divides the
        # number of remaining layers.
        downsample = np.array([x != 0 and remaining_layers % x == 0 for x in network_config.unet_downsample_rate])
    else:
        resolution = resolution if resolution is not None else network_config.resolution
        min_res = np.min(resolution)
        # x < min_res * sqrt(2) because:
        # if a > sqrt(2)b, then a/b > sqrt(2) and 2b/a < sqrt(2)
        # if sqrt(2)b > a > b, then 2a/2b < sqrt(2) and 2b/a > sqrt(2)
        downsample = np.array([x < min_res * (2 ** .5) for x in resolution])

    # Per-axis shrinkage of a single convolution: none with 'same' padding,
    # otherwise kernel size minus one.
    if network_config.convolution_padding == 'same':
        conv_contract = np.zeros(3, dtype=np.int32)
    else:
        conv_contract = network_config.convolution_dim - 1

    # First U convolution module.
    for i in range(network_config.num_layers_per_module):
        if i == network_config.num_layers_per_module - 1:
            # Increase the number of channels before downsampling to avoid
            # bottleneck (identical to 3D U-Net paper).
            n_channels = 2 * n_channels
        model = Conv3D(
                n_channels,
                tuple(network_config.convolution_dim),
                kernel_initializer=network_config.initialization,
                activation=network_config.convolution_activation,
                padding=network_config.convolution_padding)(model)
        if network_config.batch_normalization:
            model = BatchNormalization()(model)

    # Crop and pass forward to upsampling.
    # On non-terminal levels the forward link is kept larger than
    # ``output_shape`` by the amount the second module will contract.
    if remaining_layers > 0:
        forward_link_shape = output_shape + network_config.num_layers_per_module * conv_contract
    else:
        forward_link_shape = output_shape
    contraction = (np.array(model.get_shape().as_list()[1:4]) - forward_link_shape) // 2
    forward = Cropping3D(list(zip(list(contraction), list(contraction))))(model)
    if network_config.dropout_probability > 0.0:
        forward = Dropout(network_config.dropout_probability)(forward)

    # Terminal layer of the U.
    if remaining_layers <= 0:
        return forward

    # Downsample and recurse.
    model = Conv3D(
            n_channels,
            tuple(network_config.convolution_dim),
            strides=list(downsample + 1),
            kernel_initializer=network_config.initialization,
            activation=network_config.convolution_activation,
            padding='same')(model)
    if network_config.batch_normalization:
        model = BatchNormalization()(model)
    # Target shape of the lower level: forward-link shape divided by the
    # per-axis stride, rounded up.
    next_output_shape = np.ceil(np.divide(forward_link_shape, downsample.astype(np.float32) + 1.0)).astype(np.int32)
    if network_config.unet_downsample_mode == "fixed_rate":
        model = add_unet_layer(model,
                               network_config,
                               remaining_layers - 1,
                               next_output_shape.astype(np.int32))
    else:
        model = add_unet_layer(model,
                               network_config,
                               remaining_layers - 1,
                               next_output_shape.astype(np.int32),
                               None,
                               resolution * (downsample + 1))

    # Upsample output of previous layer and merge with forward link.
    model = Conv3DTranspose(
            n_channels * 2,
            tuple(network_config.convolution_dim),
            strides=list(downsample + 1),
            kernel_initializer=network_config.initialization,
            activation=network_config.convolution_activation,
            padding='same')(model)
    if network_config.batch_normalization:
        model = BatchNormalization()(model)
    # Must crop output because Keras wrongly pads the output shape for odd array sizes.
    stride_pad = (network_config.convolution_dim // 2) * np.array(downsample) + (1 - np.mod(forward_link_shape, 2))
    tf_pad_start = stride_pad // 2  # Tensorflow puts odd padding at end.
    model = Cropping3D(list(zip(list(tf_pad_start), list(stride_pad - tf_pad_start))))(model)

    model = concatenate([forward, model])

    # Second U convolution module.
    for _ in range(network_config.num_layers_per_module):
        model = Conv3D(
                n_channels,
                tuple(network_config.convolution_dim),
                kernel_initializer=network_config.initialization,
                activation=network_config.convolution_activation,
                padding=network_config.convolution_padding)(model)
        if network_config.batch_normalization:
            model = BatchNormalization()(model)

    return model
def compile_network(model, optimizer_config):
    """Compile ``model`` with the optimizer and loss named in ``optimizer_config``.

    The optimizer class is looked up by name (``optimizer_config.klass``) in
    ``keras.optimizers``. Every attribute of ``optimizer_config`` whose name
    matches a parameter of that class's ``__init__`` is passed to it as a
    keyword argument. All other attributes are ignored.

    Parameters
    ----------
    model : keras.models.Model
        Model to compile (``model.compile`` is called on it).
    optimizer_config : object
        Must provide ``klass`` (optimizer class name) and ``loss``. Its other
        instance attributes are forwarded to the optimizer when they match.
    """
    optimizer_klass = getattr(keras.optimizers, optimizer_config.klass)

    # ``inspect.getargspec`` is deprecated, was removed in Python 3.11, and
    # raises on signatures with keyword-only arguments. Prefer
    # ``getfullargspec`` and fall back to ``getargspec`` only on Python 2.
    getfullargspec = getattr(inspect, 'getfullargspec', None)
    if getfullargspec is not None:
        spec = getfullargspec(optimizer_klass.__init__)
        accepted_args = set(spec.args) | set(spec.kwonlyargs)
    else:
        accepted_args = set(inspect.getargspec(optimizer_klass.__init__).args)
    # ``self`` can never be a valid keyword argument for the constructor.
    accepted_args.discard('self')

    optimizer_kwargs = {k: v for k, v in vars(optimizer_config).items()
                        if k in accepted_args}
    optimizer = optimizer_klass(**optimizer_kwargs)
    model.compile(loss=optimizer_config.loss, optimizer=optimizer)
def load_model(model_file, network_config):
    """Load a saved Keras model, optionally wrapping it to transpose axes.

    When ``network_config.transpose`` is true, the loaded model is wrapped in
    a new model. Each input and each output of the wrapper is permuted with
    ``Permute((3, 2, 1, 4))``, which reverses the three leading non-batch
    axes and keeps the fourth in place. The wrapper's ``save`` is rebound so
    that saving writes the underlying loaded model, not the wrapper.

    Parameters
    ----------
    model_file : str
        Path passed to ``keras.models.load_model``.
    network_config : object
        Only ``transpose`` is read.

    Returns
    -------
    keras.models.Model
        The loaded model, or the transposing wrapper around it.
    """
    model = keras_load_model(model_file)

    if not network_config.transpose:
        return model

    inputs = []
    perms = []
    for old_input in model.input_layers:
        # ``input_shape`` index 0 is presumably the batch dimension. It is
        # dropped here while the spatial axes are reversed to match the
        # Permute below.
        input_shape = np.asarray(old_input.input_shape)[[3, 2, 1, 4]]
        new_input = Input(shape=tuple(input_shape),
                          dtype=old_input.input_dtype,
                          name=old_input.name)
        perm = Permute((3, 2, 1, 4), input_shape=tuple(input_shape))(new_input)
        inputs.append(new_input)
        perms.append(perm)

    old_outputs = model(perms)
    if not isinstance(old_outputs, list):
        old_outputs = [old_outputs]

    outputs = [Permute((3, 2, 1, 4))(old_output) for old_output in old_outputs]

    # Keras 2 names these ``inputs``/``outputs``. The legacy ``input``/
    # ``output`` keywords are deprecated, and the rest of this module already
    # uses the Keras 2 spelling.
    new_model = Model(inputs=inputs, outputs=outputs)

    # Monkeypatch the save to save just the underlying model.
    func_type = type(model.save)
    old_model = model

    def new_save(_, *args, **kwargs):
        old_model.save(*args, **kwargs)

    new_model.save = func_type(new_save, new_model)

    return new_model
def make_parallel(model, gpus=None):
    """Replicate ``model`` across GPUs using Keras' ``multi_gpu_model``.

    The returned wrapper's ``save`` method is rebound so that saving it
    writes the original ``model`` passed in, not the multi-GPU wrapper.

    Parameters
    ----------
    model : keras.models.Model
        Model to replicate.
    gpus : optional
        Forwarded unchanged as the second argument of ``multi_gpu_model``.

    Returns
    -------
    keras.models.Model
        The multi-GPU wrapper model.
    """
    parallel_model = multi_gpu_model(model, gpus)
    bound_method_type = type(model.save)

    def save_underlying(_wrapper, *args, **kwargs):
        # Ignore the wrapper instance and delegate to the original model.
        model.save(*args, **kwargs)

    parallel_model.save = bound_method_type(save_underlying, parallel_model)
    return parallel_model