diff options
author | 2018-06-12 14:03:39 -0700 | |
---|---|---|
committer | 2018-06-12 14:08:19 -0700 | |
commit | abfdf45dcdfe366376d859bf29166c0ad16d9993 (patch) | |
tree | f6511da4fb72630f50e4c64b7cc93092c0abbbb7 /tensorflow/python/keras | |
parent | 9c7ba7503402bd02045f2464ef315db69699d6a9 (diff) |
Minor fixes in tf.keras codebase in preparation for Keras 2.2.0 API support.
PiperOrigin-RevId: 200276422
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/activations.py | 71 | ||||
-rw-r--r-- | tensorflow/python/keras/backend.py | 53 | ||||
-rw-r--r-- | tensorflow/python/keras/callbacks.py | 6 | ||||
-rw-r--r-- | tensorflow/python/keras/callbacks_test.py | 22 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_arrays.py | 3 | ||||
-rw-r--r-- | tensorflow/python/keras/layers/convolutional.py | 24 | ||||
-rw-r--r-- | tensorflow/python/keras/layers/merge.py | 4 | ||||
-rw-r--r-- | tensorflow/python/keras/utils/data_utils.py | 4 | ||||
-rw-r--r-- | tensorflow/python/keras/utils/io_utils.py | 5 | ||||
-rw-r--r-- | tensorflow/python/keras/utils/io_utils_test.py | 24 | ||||
-rw-r--r-- | tensorflow/python/keras/utils/multi_gpu_utils.py | 2 | ||||
-rw-r--r-- | tensorflow/python/keras/utils/vis_utils.py | 1 |
12 files changed, 168 insertions, 51 deletions
diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py index a62dadb830..e487f583be 100644 --- a/tensorflow/python/keras/activations.py +++ b/tensorflow/python/keras/activations.py @@ -32,7 +32,7 @@ def softmax(x, axis=-1): """Softmax activation function. Arguments: - x : Tensor. + x : Input tensor. axis: Integer, axis along which the softmax normalization is applied. Returns: @@ -49,23 +49,45 @@ def softmax(x, axis=-1): s = math_ops.reduce_sum(e, axis=axis, keepdims=True) return e / s else: - raise ValueError('Cannot apply softmax to a tensor that is 1D') + raise ValueError('Cannot apply softmax to a tensor that is 1D. ' + 'Received input: %s' % (x,)) @tf_export('keras.activations.elu') def elu(x, alpha=1.0): + """Exponential linear unit. + + Arguments: + x: Input tensor. + alpha: A scalar, slope of negative section. + + Returns: + The exponential linear activation: `x` if `x > 0` and + `alpha * (exp(x)-1)` if `x < 0`. + + Reference: + - [Fast and Accurate Deep Network Learning by Exponential + Linear Units (ELUs)](https://arxiv.org/abs/1511.07289) + """ return K.elu(x, alpha) @tf_export('keras.activations.selu') def selu(x): - """Scaled Exponential Linear Unit. (Klambauer et al., 2017). + """Scaled Exponential Linear Unit (SELU). + + SELU is equal to: `scale * elu(x, alpha)`, where alpha and scale + are pre-defined constants. The values of `alpha` and `scale` are + chosen so that the mean and variance of the inputs are preserved + between two consecutive layers as long as the weights are initialized + correctly (see `lecun_normal` initialization) and the number of inputs + is "large enough" (see references for more information). Arguments: x: A tensor or variable to compute the activation function for. Returns: - Tensor with the same shape and dtype as `x`. + The scaled exponential unit activation: `scale * elu(x, alpha)`. # Note - To be used together with the initialization "lecun_normal". 
@@ -79,16 +101,44 @@ def selu(x): @tf_export('keras.activations.softplus') def softplus(x): + """Softplus activation function. + + Arguments: + x: Input tensor. + + Returns: + The softplus activation: `log(exp(x) + 1)`. + """ return nn.softplus(x) @tf_export('keras.activations.softsign') def softsign(x): + """Softsign activation function. + + Arguments: + x: Input tensor. + + Returns: + The softsign activation: `x / (abs(x) + 1)`. + """ return nn.softsign(x) @tf_export('keras.activations.relu') def relu(x, alpha=0., max_value=None): + """Rectified Linear Unit. + + Arguments: + x: Input tensor. + alpha: Slope of the negative part. Defaults to zero. + max_value: Maximum value for the output. + + Returns: + The (leaky) rectified linear unit activation: `x` if `x > 0`, + `alpha * x` if `x < 0`. If `max_value` is defined, the result + is truncated to this value. + """ return K.relu(x, alpha=alpha, max_value=max_value) @@ -104,6 +154,19 @@ def sigmoid(x): @tf_export('keras.activations.hard_sigmoid') def hard_sigmoid(x): + """Hard sigmoid activation function. + + Faster to compute than sigmoid activation. + + Arguments: + x: Input tensor. + + Returns: + Hard sigmoid activation: + - `0` if `x < -2.5` + - `1` if `x > 2.5` + - `0.2 * x + 0.5` if `-2.5 <= x <= 2.5`. + """ return K.hard_sigmoid(x) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 2a4a1c861c..84821918bf 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -2973,30 +2973,29 @@ def rnn(step_function, Arguments: step_function: RNN step function. - Parameters; - input; tensor with shape `(samples, ...)` (no time dimension), + Args; + input; Tensor with shape `(samples, ...)` (no time dimension), representing input for the batch of samples at a certain time step. - states; list of tensors. + states; List of tensors. 
Returns; - output; tensor with shape `(samples, output_dim)` + output; Tensor with shape `(samples, output_dim)` (no time dimension). - new_states; list of tensors, same length and shapes + new_states; List of tensors, same length and shapes as 'states'. The first state in the list must be the output tensor at the previous timestep. - inputs: tensor of temporal data of shape `(samples, time, ...)` + inputs: Tensor of temporal data of shape `(samples, time, ...)` (at least 3D). - initial_states: tensor with shape (samples, output_dim) + initial_states: Tensor with shape `(samples, output_dim)` (no time dimension), containing the initial values for the states used in the step function. - go_backwards: boolean. If True, do the iteration over the time + go_backwards: Boolean. If True, do the iteration over the time dimension in reverse order and return the reversed sequence. - mask: binary tensor with shape `(samples, time, 1)`, + mask: Binary tensor with shape `(samples, time, 1)`, with a zero for every element that is masked. - constants: a list of constant values passed at each step. - unroll: whether to unroll the RNN or to use a symbolic loop - (`while_loop` or `scan` depending on backend). + constants: List of constant values passed at each step. + unroll: Whether to unroll the RNN or to use a symbolic `while_loop`. input_length: If specified, assume time dimension is of this length. Returns: @@ -3637,12 +3636,12 @@ def _preprocess_conv1d_input(x, data_format): Returns: A tensor. 
""" - tf_data_format = 'NHWC' # to pass TF Conv2dNative operations + tf_data_format = 'NWC' # to pass TF Conv2dNative operations if data_format == 'channels_first': if not _has_nchw_support(): x = array_ops.transpose(x, (0, 2, 1)) # NCW -> NWC else: - tf_data_format = 'NCHW' + tf_data_format = 'NCW' return x, tf_data_format @@ -3741,10 +3740,8 @@ def conv1d(x, x = temporal_padding(x, (left_pad, 0)) padding = 'valid' padding = _preprocess_padding(padding) - if data_format == 'channels_last': - tf_data_format = 'NWC' - else: - tf_data_format = 'NCW' + + x, tf_data_format = _preprocess_conv1d_input(x, data_format) x = nn.convolution( input=x, filter=kernel, @@ -3752,6 +3749,8 @@ def conv1d(x, strides=(strides,), padding=padding, data_format=tf_data_format) + if data_format == 'channels_first' and tf_data_format == 'NWC': + x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW return x @@ -3892,11 +3891,16 @@ def separable_conv1d(x, if data_format not in {'channels_first', 'channels_last'}: raise ValueError('Unknown data_format: ' + str(data_format)) + if isinstance(strides, int): + strides = (strides,) + if isinstance(dilation_rate, int): + dilation_rate = (dilation_rate,) + x, tf_data_format = _preprocess_conv1d_input(x, data_format) padding = _preprocess_padding(padding) if not isinstance(strides, tuple): strides = tuple(strides) - if tf_data_format == 'NHWC': + if tf_data_format == 'NWC': spatial_start_dim = 1 strides = (1,) + strides * 2 + (1,) else: @@ -3918,7 +3922,7 @@ def separable_conv1d(x, x = array_ops.squeeze(x, [spatial_start_dim]) - if data_format == 'channels_first' and tf_data_format == 'NHWC': + if data_format == 'channels_first' and tf_data_format == 'NWC': x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW return x @@ -4717,8 +4721,13 @@ def foldr(fn, elems, initializer=None, name=None): # Load Keras default configuration from config file if present. 
-_keras_base_dir = os.path.expanduser('~') -_keras_dir = os.path.join(_keras_base_dir, '.keras') +# Set Keras base dir path given KERAS_HOME env variable, if applicable. +# Otherwise either ~/.keras or /tmp. +if 'KERAS_HOME' in os.environ: + _keras_dir = os.environ.get('KERAS_HOME') +else: + _keras_base_dir = os.path.expanduser('~') + _keras_dir = os.path.join(_keras_base_dir, '.keras') _config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json')) if os.path.exists(_config_path): try: diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 8061d47295..70b6a8431a 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -635,7 +635,11 @@ class LearningRateScheduler(Callback): def on_epoch_begin(self, epoch, logs=None): if not hasattr(self.model.optimizer, 'lr'): raise ValueError('Optimizer must have a "lr" attribute.') - lr = self.schedule(epoch) + try: # new API + lr = float(K.get_value(self.model.optimizer.lr)) + lr = self.schedule(epoch, lr) + except TypeError: # Support for old API for backward compatibility + lr = self.schedule(epoch) if not isinstance(lr, (float, np.float32, np.float64)): raise ValueError('The output of the "schedule" function ' 'should be float.') diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index ad5f416b22..b355f4a269 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -321,8 +321,26 @@ class KerasCallbacksTest(test.TestCase): callbacks=cbks, epochs=5, verbose=0) - assert (float(keras.backend.get_value(model.optimizer.lr)) - 0.2 - ) < keras.backend.epsilon() + assert ( + float(keras.backend.get_value( + model.optimizer.lr)) - 0.2) < keras.backend.epsilon() + + cbks = [keras.callbacks.LearningRateScheduler(lambda x, lr: lr / 2)] + model.compile( + loss='categorical_crossentropy', + optimizer='sgd', + metrics=['accuracy']) + model.fit( + x_train, + 
y_train, + batch_size=BATCH_SIZE, + validation_data=(x_test, y_test), + callbacks=cbks, + epochs=2, + verbose=0) + assert ( + float(keras.backend.get_value( + model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon() def test_ReduceLROnPlateau(self): with self.test_session(): diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py index 93f4f1bd1d..281ad9bd50 100644 --- a/tensorflow/python/keras/engine/training_arrays.py +++ b/tensorflow/python/keras/engine/training_arrays.py @@ -185,6 +185,7 @@ def fit_loop(model, callbacks.on_epoch_begin(epoch) epoch_logs = {} if steps_per_epoch is not None: + # Step-wise fit loop. for step_index in range(steps_per_epoch): batch_logs = {} batch_logs['batch'] = step_index @@ -215,7 +216,6 @@ def fit_loop(model, val_inputs, val_targets, sample_weights=val_sample_weights, - batch_size=batch_size, steps=validation_steps, verbose=0) if not isinstance(val_outs, list): @@ -224,6 +224,7 @@ def fit_loop(model, for l, o in zip(out_labels, val_outs): epoch_logs['val_' + l] = o else: + # Sample-wise fit loop. if shuffle == 'batch': index_array = training_utils.batch_shuffle(index_array, batch_size) elif shuffle: diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index 9ea341139e..720b386c4d 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -382,11 +382,11 @@ class Conv2D(Conv): filters: Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. + height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. 
+ specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying @@ -613,11 +613,11 @@ class Conv2DTranspose(Conv2D): filters: Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. + height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. + specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying @@ -1452,11 +1452,11 @@ class SeparableConv2D(SeparableConv): filters: Integer, the dimensionality of the output space (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. + height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. + specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying @@ -1596,11 +1596,11 @@ class DepthwiseConv2D(Conv2D): Arguments: kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. + height and width of the 2D convolution window. 
Can be a single integer to specify the same value for all spatial dimensions. strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the width and height. + specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying @@ -2007,7 +2007,7 @@ class ZeroPadding2D(Layer): Arguments: padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints. - If int: the same symmetric padding - is applied to width and height. + is applied to height and width. - If tuple of 2 ints: interpreted as two different symmetric padding values for height and width: @@ -2106,7 +2106,7 @@ class ZeroPadding3D(Layer): Arguments: padding: int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints. - If int: the same symmetric padding - is applied to width and height. + is applied to height and width. - If tuple of 3 ints: interpreted as two different symmetric padding values for height and width: @@ -2266,12 +2266,12 @@ class Cropping1D(Layer): class Cropping2D(Layer): """Cropping layer for 2D input (e.g. picture). - It crops along spatial dimensions, i.e. width and height. + It crops along spatial dimensions, i.e. height and width. Arguments: cropping: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints. - If int: the same symmetric cropping - is applied to width and height. + is applied to height and width. - If tuple of 2 ints: interpreted as two different symmetric cropping values for height and width: diff --git a/tensorflow/python/keras/layers/merge.py b/tensorflow/python/keras/layers/merge.py index 683e3e0ed1..770665c5fb 100644 --- a/tensorflow/python/keras/layers/merge.py +++ b/tensorflow/python/keras/layers/merge.py @@ -446,8 +446,8 @@ class Concatenate(_Merge): class Dot(_Merge): """Layer that computes a dot product between samples in two tensors. - E.g. 
if applied to two tensors `a` and `b` of shape `(batch_size, n)`, - the output will be a tensor of shape `(batch_size, 1)` + E.g. if applied to a list of two tensors `a` and `b` of shape + `(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)` where each entry `i` will be the dot product between `a[i]` and `b[i]`. diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py index a1f89d9d43..c1ee34ae46 100644 --- a/tensorflow/python/keras/utils/data_utils.py +++ b/tensorflow/python/keras/utils/data_utils.py @@ -324,12 +324,12 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535): class Sequence(object): """Base object for fitting to a sequence of data, such as a dataset. - Every `Sequence` must implements the `__getitem__` and the `__len__` methods. + Every `Sequence` must implement the `__getitem__` and the `__len__` methods. If you want to modify your dataset between epochs you may implement `on_epoch_end`. The method `__getitem__` should return a complete batch. - # Notes + Notes: `Sequence` are a safer way to do multiprocessing. 
This structure guarantees that the network will only train once diff --git a/tensorflow/python/keras/utils/io_utils.py b/tensorflow/python/keras/utils/io_utils.py index f82e3277de..62674a9c77 100644 --- a/tensorflow/python/keras/utils/io_utils.py +++ b/tensorflow/python/keras/utils/io_utils.py @@ -102,13 +102,12 @@ class HDF5Matrix(object): idx = (self.start + key).tolist() else: raise IndexError - elif isinstance(key, list): + else: + # Assume list/iterable if max(key) + self.start < self.end: idx = [x + self.start for x in key] else: raise IndexError - else: - raise IndexError if self.normalizer is not None: return self.normalizer(self.data[idx]) else: diff --git a/tensorflow/python/keras/utils/io_utils_test.py b/tensorflow/python/keras/utils/io_utils_test.py index 3895dca68e..81bb661edd 100644 --- a/tensorflow/python/keras/utils/io_utils_test.py +++ b/tensorflow/python/keras/utils/io_utils_test.py @@ -22,6 +22,7 @@ import os import shutil import numpy as np +import six from tensorflow.python import keras from tensorflow.python.platform import test @@ -95,6 +96,29 @@ class TestIOUtils(test.TestCase): self.assertEqual(out_eval.shape, ()) self.assertGreater(out_eval, 0) + # test slicing for shortened array + self.assertEqual(len(x_train[0:]), len(x_train)) + + # test __getitem__ invalid use cases + with self.assertRaises(IndexError): + _ = x_train[1000] + with self.assertRaises(IndexError): + _ = x_train[1000: 1001] + with self.assertRaises(IndexError): + _ = x_train[[1000, 1001]] + with self.assertRaises(IndexError): + _ = x_train[six.moves.range(1000, 1001)] + with self.assertRaises(IndexError): + _ = x_train[np.array([1000])] + with self.assertRaises(TypeError): + _ = x_train[None] + + # test normalizer + normalizer = lambda x: x + 1 + normalized_x_train = keras.utils.io_utils.HDF5Matrix( + h5_path, 'my_data', start=0, end=150, normalizer=normalizer) + self.assertAllClose(normalized_x_train[0][0], x_train[0][0] + 1) + if __name__ == '__main__': test.main() diff 
--git a/tensorflow/python/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py index e5442f04e3..e1c49bc852 100644 --- a/tensorflow/python/keras/utils/multi_gpu_utils.py +++ b/tensorflow/python/keras/utils/multi_gpu_utils.py @@ -196,7 +196,7 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False): batch_size = shape[:1] input_shape = shape[1:] step = batch_size // parts - if i == num_gpus - 1: + if i == parts - 1: size = batch_size - step * i else: size = step diff --git a/tensorflow/python/keras/utils/vis_utils.py b/tensorflow/python/keras/utils/vis_utils.py index 8007df4622..7a454ac831 100644 --- a/tensorflow/python/keras/utils/vis_utils.py +++ b/tensorflow/python/keras/utils/vis_utils.py @@ -77,7 +77,6 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'): if isinstance(model, Sequential): if not model.built: model.build() - model = model.model layers = model.layers # Create graph nodes. |