author    Francois Chollet <fchollet@google.com>    2018-06-12 14:03:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-06-12 14:08:19 -0700
commit    abfdf45dcdfe366376d859bf29166c0ad16d9993 (patch)
tree      f6511da4fb72630f50e4c64b7cc93092c0abbbb7 /tensorflow/python/keras
parent    9c7ba7503402bd02045f2464ef315db69699d6a9 (diff)
Minor fixes in tf.keras codebase in preparation for Keras 2.2.0 API support.
PiperOrigin-RevId: 200276422
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--  tensorflow/python/keras/activations.py             | 71
-rw-r--r--  tensorflow/python/keras/backend.py                 | 53
-rw-r--r--  tensorflow/python/keras/callbacks.py               |  6
-rw-r--r--  tensorflow/python/keras/callbacks_test.py          | 22
-rw-r--r--  tensorflow/python/keras/engine/training_arrays.py  |  3
-rw-r--r--  tensorflow/python/keras/layers/convolutional.py    | 24
-rw-r--r--  tensorflow/python/keras/layers/merge.py            |  4
-rw-r--r--  tensorflow/python/keras/utils/data_utils.py        |  4
-rw-r--r--  tensorflow/python/keras/utils/io_utils.py          |  5
-rw-r--r--  tensorflow/python/keras/utils/io_utils_test.py     | 24
-rw-r--r--  tensorflow/python/keras/utils/multi_gpu_utils.py   |  2
-rw-r--r--  tensorflow/python/keras/utils/vis_utils.py         |  1
12 files changed, 168 insertions, 51 deletions
diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index a62dadb830..e487f583be 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -32,7 +32,7 @@ def softmax(x, axis=-1):
"""Softmax activation function.
Arguments:
- x : Tensor.
+ x : Input tensor.
axis: Integer, axis along which the softmax normalization is applied.
Returns:
@@ -49,23 +49,45 @@ def softmax(x, axis=-1):
s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
return e / s
else:
- raise ValueError('Cannot apply softmax to a tensor that is 1D')
+ raise ValueError('Cannot apply softmax to a tensor that is 1D. '
+ 'Received input: %s' % (x,))
@tf_export('keras.activations.elu')
def elu(x, alpha=1.0):
+ """Exponential linear unit.
+
+ Arguments:
+ x: Input tensor.
+ alpha: A scalar, slope of negative section.
+
+ Returns:
+ The exponential linear activation: `x` if `x > 0` and
+ `alpha * (exp(x)-1)` if `x < 0`.
+
+ Reference:
+ - [Fast and Accurate Deep Network Learning by Exponential
+ Linear Units (ELUs)](https://arxiv.org/abs/1511.07289)
+ """
return K.elu(x, alpha)
@tf_export('keras.activations.selu')
def selu(x):
- """Scaled Exponential Linear Unit. (Klambauer et al., 2017).
+ """Scaled Exponential Linear Unit (SELU).
+
+ SELU is equal to: `scale * elu(x, alpha)`, where alpha and scale
+ are pre-defined constants. The values of `alpha` and `scale` are
+ chosen so that the mean and variance of the inputs are preserved
+ between two consecutive layers as long as the weights are initialized
+ correctly (see `lecun_normal` initialization) and the number of inputs
+ is "large enough" (see references for more information).
Arguments:
x: A tensor or variable to compute the activation function for.
Returns:
- Tensor with the same shape and dtype as `x`.
+ The scaled exponential unit activation: `scale * elu(x, alpha)`.
# Note
- To be used together with the initialization "lecun_normal".
@@ -79,16 +101,44 @@ def selu(x):
@tf_export('keras.activations.softplus')
def softplus(x):
+ """Softplus activation function.
+
+ Arguments:
+ x: Input tensor.
+
+ Returns:
+ The softplus activation: `log(exp(x) + 1)`.
+ """
return nn.softplus(x)
@tf_export('keras.activations.softsign')
def softsign(x):
+ """Softsign activation function.
+
+ Arguments:
+ x: Input tensor.
+
+ Returns:
+ The softsign activation: `x / (abs(x) + 1)`.
+ """
return nn.softsign(x)
@tf_export('keras.activations.relu')
def relu(x, alpha=0., max_value=None):
+ """Rectified Linear Unit.
+
+ Arguments:
+ x: Input tensor.
+ alpha: Slope of the negative part. Defaults to zero.
+ max_value: Maximum value for the output.
+
+ Returns:
+ The (leaky) rectified linear unit activation: `x` if `x > 0`,
+ `alpha * x` if `x < 0`. If `max_value` is defined, the result
+ is truncated to this value.
+ """
return K.relu(x, alpha=alpha, max_value=max_value)
@@ -104,6 +154,19 @@ def sigmoid(x):
@tf_export('keras.activations.hard_sigmoid')
def hard_sigmoid(x):
+ """Hard sigmoid activation function.
+
+ Faster to compute than sigmoid activation.
+
+ Arguments:
+ x: Input tensor.
+
+ Returns:
+ Hard sigmoid activation:
+ - `0` if `x < -2.5`
+ - `1` if `x > 2.5`
+ - `0.2 * x + 0.5` if `-2.5 <= x <= 2.5`.
+ """
return K.hard_sigmoid(x)
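The activation formulas documented above can be sanity-checked with a short NumPy sketch (illustrative only; these are re-implementations of the documented math, not the tf.keras code paths):

import numpy as np

def elu(x, alpha=1.0):
  # x if x > 0, else alpha * (exp(x) - 1), as documented above.
  return np.where(x > 0, x, alpha * (np.exp(x) - 1))

def hard_sigmoid(x):
  # Piecewise-linear sigmoid approximation: 0.2 * x + 0.5 clipped to
  # [0, 1], which yields 0 for x < -2.5 and 1 for x > 2.5.
  return np.clip(0.2 * x + 0.5, 0.0, 1.0)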
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 2a4a1c861c..84821918bf 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -2973,30 +2973,29 @@ def rnn(step_function,
Arguments:
step_function: RNN step function.
- Parameters:
- input: tensor with shape `(samples, ...)` (no time dimension),
+ Args:
+ input: Tensor with shape `(samples, ...)` (no time dimension),
representing input for the batch of samples at a certain
time step.
- states: list of tensors.
+ states: List of tensors.
Returns:
- output: tensor with shape `(samples, output_dim)`
+ output: Tensor with shape `(samples, output_dim)`
(no time dimension).
- new_states: list of tensors, same length and shapes
+ new_states: List of tensors, same length and shapes
as 'states'. The first state in the list must be the
output tensor at the previous timestep.
- inputs: tensor of temporal data of shape `(samples, time, ...)`
+ inputs: Tensor of temporal data of shape `(samples, time, ...)`
(at least 3D).
- initial_states: tensor with shape (samples, output_dim)
+ initial_states: Tensor with shape `(samples, output_dim)`
(no time dimension),
containing the initial values for the states used in
the step function.
- go_backwards: boolean. If True, do the iteration over the time
+ go_backwards: Boolean. If True, do the iteration over the time
dimension in reverse order and return the reversed sequence.
- mask: binary tensor with shape `(samples, time, 1)`,
+ mask: Binary tensor with shape `(samples, time, 1)`,
with a zero for every element that is masked.
- constants: a list of constant values passed at each step.
- unroll: whether to unroll the RNN or to use a symbolic loop
- (`while_loop` or `scan` depending on backend).
+ constants: List of constant values passed at each step.
+ unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
input_length: If specified, assume time dimension is of this length.
Returns:
@@ -3637,12 +3636,12 @@ def _preprocess_conv1d_input(x, data_format):
Returns:
A tensor.
"""
- tf_data_format = 'NHWC' # to pass TF Conv2dNative operations
+ tf_data_format = 'NWC' # to pass TF Conv2dNative operations
if data_format == 'channels_first':
if not _has_nchw_support():
x = array_ops.transpose(x, (0, 2, 1)) # NCW -> NWC
else:
- tf_data_format = 'NCHW'
+ tf_data_format = 'NCW'
return x, tf_data_format
@@ -3741,10 +3740,8 @@ def conv1d(x,
x = temporal_padding(x, (left_pad, 0))
padding = 'valid'
padding = _preprocess_padding(padding)
- if data_format == 'channels_last':
- tf_data_format = 'NWC'
- else:
- tf_data_format = 'NCW'
+
+ x, tf_data_format = _preprocess_conv1d_input(x, data_format)
x = nn.convolution(
input=x,
filter=kernel,
@@ -3752,6 +3749,8 @@ def conv1d(x,
strides=(strides,),
padding=padding,
data_format=tf_data_format)
+ if data_format == 'channels_first' and tf_data_format == 'NWC':
+ x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
return x
@@ -3892,11 +3891,16 @@ def separable_conv1d(x,
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
+ if isinstance(strides, int):
+ strides = (strides,)
+ if isinstance(dilation_rate, int):
+ dilation_rate = (dilation_rate,)
+
x, tf_data_format = _preprocess_conv1d_input(x, data_format)
padding = _preprocess_padding(padding)
if not isinstance(strides, tuple):
strides = tuple(strides)
- if tf_data_format == 'NHWC':
+ if tf_data_format == 'NWC':
spatial_start_dim = 1
strides = (1,) + strides * 2 + (1,)
else:
@@ -3918,7 +3922,7 @@ def separable_conv1d(x,
x = array_ops.squeeze(x, [spatial_start_dim])
- if data_format == 'channels_first' and tf_data_format == 'NHWC':
+ if data_format == 'channels_first' and tf_data_format == 'NWC':
x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
return x
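With the conv1d and separable_conv1d changes above, both data formats are routed through `_preprocess_conv1d_input`, and the result is transposed back to NCW when the backend lacks native channels-first support. A usage sketch (the shapes are the interesting part):

import numpy as np
from tensorflow.python import keras

K = keras.backend

x = K.constant(np.random.random((2, 3, 10)))      # (batch, channels, steps)
kernel = K.constant(np.random.random((4, 3, 5)))  # (width, in_ch, out_ch)
# channels_first input now works even without native NCW support:
# the input is transposed to NWC, convolved, and transposed back.
y = K.conv1d(x, kernel, strides=1, padding='same',
             data_format='channels_first')
# y has shape (2, 5, 10): channels stay first.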
@@ -4717,8 +4721,13 @@ def foldr(fn, elems, initializer=None, name=None):
# Load Keras default configuration from config file if present.
-_keras_base_dir = os.path.expanduser('~')
-_keras_dir = os.path.join(_keras_base_dir, '.keras')
+# Set Keras base dir path given KERAS_HOME env variable, if applicable.
+# Otherwise either ~/.keras or /tmp.
+if 'KERAS_HOME' in os.environ:
+ _keras_dir = os.environ.get('KERAS_HOME')
+else:
+ _keras_base_dir = os.path.expanduser('~')
+ _keras_dir = os.path.join(_keras_base_dir, '.keras')
_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
try:
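The new config-directory lookup is equivalent to the following sketch (the /tmp fallback mentioned in the comment is handled further down, when the home directory is not writable):

import os

keras_dir = os.environ.get('KERAS_HOME',
                           os.path.join(os.path.expanduser('~'), '.keras'))
config_path = os.path.expanduser(os.path.join(keras_dir, 'keras.json'))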
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 8061d47295..70b6a8431a 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -635,7 +635,11 @@ class LearningRateScheduler(Callback):
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'lr'):
raise ValueError('Optimizer must have a "lr" attribute.')
- lr = self.schedule(epoch)
+ try: # new API
+ lr = float(K.get_value(self.model.optimizer.lr))
+ lr = self.schedule(epoch, lr)
+ except TypeError: # Support for old API for backward compatibility
+ lr = self.schedule(epoch)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function '
'should be float.')
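With the try/except above, `schedule` may now take the current learning rate as a second argument, while single-argument schedules keep working via the TypeError fallback. A usage sketch:

from tensorflow.python import keras

# New-style schedule: receives (epoch, current_lr).
halving = keras.callbacks.LearningRateScheduler(lambda epoch, lr: lr * 0.5)

# Old-style schedule: receives only the epoch; still supported.
decay = keras.callbacks.LearningRateScheduler(
    lambda epoch: 0.1 * 0.9 ** epoch)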
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index ad5f416b22..b355f4a269 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -321,8 +321,26 @@ class KerasCallbacksTest(test.TestCase):
callbacks=cbks,
epochs=5,
verbose=0)
- assert (float(keras.backend.get_value(model.optimizer.lr)) - 0.2
- ) < keras.backend.epsilon()
+ assert (
+ float(keras.backend.get_value(
+ model.optimizer.lr)) - 0.2) < keras.backend.epsilon()
+
+ cbks = [keras.callbacks.LearningRateScheduler(lambda x, lr: lr / 2)]
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=2,
+ verbose=0)
+ assert (
+ float(keras.backend.get_value(
+ model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon()
def test_ReduceLROnPlateau(self):
with self.test_session():
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index 93f4f1bd1d..281ad9bd50 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -185,6 +185,7 @@ def fit_loop(model,
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
if steps_per_epoch is not None:
+ # Step-wise fit loop.
for step_index in range(steps_per_epoch):
batch_logs = {}
batch_logs['batch'] = step_index
@@ -215,7 +216,6 @@ def fit_loop(model,
val_inputs,
val_targets,
sample_weights=val_sample_weights,
- batch_size=batch_size,
steps=validation_steps,
verbose=0)
if not isinstance(val_outs, list):
@@ -224,6 +224,7 @@ def fit_loop(model,
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
else:
+ # Sample-wise fit loop.
if shuffle == 'batch':
index_array = training_utils.batch_shuffle(index_array, batch_size)
elif shuffle:
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index 9ea341139e..720b386c4d 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -382,11 +382,11 @@ class Conv2D(Conv):
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
- width and height of the 2D convolution window.
+ height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the width and height.
+ specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -613,11 +613,11 @@ class Conv2DTranspose(Conv2D):
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
- width and height of the 2D convolution window.
+ height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the width and height.
+ specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -1452,11 +1452,11 @@ class SeparableConv2D(SeparableConv):
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
- width and height of the 2D convolution window.
+ height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the width and height.
+ specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -1596,11 +1596,11 @@ class DepthwiseConv2D(Conv2D):
Arguments:
kernel_size: An integer or tuple/list of 2 integers, specifying the
- width and height of the 2D convolution window.
+ height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the width and height.
+ specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -2007,7 +2007,7 @@ class ZeroPadding2D(Layer):
Arguments:
padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
- If int: the same symmetric padding
- is applied to width and height.
+ is applied to height and width.
- If tuple of 2 ints:
interpreted as two different
symmetric padding values for height and width:
@@ -2106,7 +2106,7 @@ class ZeroPadding3D(Layer):
Arguments:
padding: int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
- If int: the same symmetric padding
- is applied to width and height.
+ is applied to height and width.
- If tuple of 3 ints:
interpreted as three different
symmetric padding values for the three spatial dimensions:
@@ -2266,12 +2266,12 @@ class Cropping1D(Layer):
class Cropping2D(Layer):
"""Cropping layer for 2D input (e.g. picture).
- It crops along spatial dimensions, i.e. width and height.
+ It crops along spatial dimensions, i.e. height and width.
Arguments:
cropping: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
- If int: the same symmetric cropping
- is applied to width and height.
+ is applied to height and width.
- If tuple of 2 ints:
interpreted as two different
symmetric cropping values for height and width:
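The docstring fixes above only swap the wording (`height and width` rather than `width and height`), but the order matters once the tuples are asymmetric; a quick sketch of the documented interpretation:

from tensorflow.python import keras

# kernel_size=(3, 5): window 3 rows (height) by 5 columns (width);
# strides=(1, 2): stride 1 along height, 2 along width.
conv = keras.layers.Conv2D(filters=16, kernel_size=(3, 5), strides=(1, 2))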
diff --git a/tensorflow/python/keras/layers/merge.py b/tensorflow/python/keras/layers/merge.py
index 683e3e0ed1..770665c5fb 100644
--- a/tensorflow/python/keras/layers/merge.py
+++ b/tensorflow/python/keras/layers/merge.py
@@ -446,8 +446,8 @@ class Concatenate(_Merge):
class Dot(_Merge):
"""Layer that computes a dot product between samples in two tensors.
- E.g. if applied to two tensors `a` and `b` of shape `(batch_size, n)`,
- the output will be a tensor of shape `(batch_size, 1)`
+ E.g. if applied to a list of two tensors `a` and `b` of shape
+ `(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)`
where each entry `i` will be the dot product between
`a[i]` and `b[i]`.
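A usage sketch matching the clarified docstring (the layer is called on a list of two tensors):

from tensorflow.python import keras

a = keras.layers.Input(shape=(4,))
b = keras.layers.Input(shape=(4,))
# Dot over axis 1 of a list of two (batch_size, 4) tensors yields a
# (batch_size, 1) tensor of per-sample dot products.
dotted = keras.layers.Dot(axes=1)([a, b])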
diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py
index a1f89d9d43..c1ee34ae46 100644
--- a/tensorflow/python/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/utils/data_utils.py
@@ -324,12 +324,12 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
class Sequence(object):
"""Base object for fitting to a sequence of data, such as a dataset.
- Every `Sequence` must implements the `__getitem__` and the `__len__` methods.
+ Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
If you want to modify your dataset between epochs you may implement
`on_epoch_end`.
The method `__getitem__` should return a complete batch.
- # Notes
+ Notes:
`Sequence` is a safer way to do multiprocessing. This structure guarantees
that the network will only train once
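A minimal `Sequence` subclass implementing the two required methods (a sketch; the class and data are illustrative):

import numpy as np
from tensorflow.python import keras

class NumpyBatches(keras.utils.Sequence):

  def __init__(self, x, y, batch_size=32):
    self.x, self.y, self.batch_size = x, y, batch_size

  def __len__(self):
    # Number of batches per epoch.
    return int(np.ceil(len(self.x) / float(self.batch_size)))

  def __getitem__(self, idx):
    # Returns one complete batch, as the docstring requires.
    lo = idx * self.batch_size
    return self.x[lo:lo + self.batch_size], self.y[lo:lo + self.batch_size]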
diff --git a/tensorflow/python/keras/utils/io_utils.py b/tensorflow/python/keras/utils/io_utils.py
index f82e3277de..62674a9c77 100644
--- a/tensorflow/python/keras/utils/io_utils.py
+++ b/tensorflow/python/keras/utils/io_utils.py
@@ -102,13 +102,12 @@ class HDF5Matrix(object):
idx = (self.start + key).tolist()
else:
raise IndexError
- elif isinstance(key, list):
+ else:
+ # Assume list/iterable
if max(key) + self.start < self.end:
idx = [x + self.start for x in key]
else:
raise IndexError
- else:
- raise IndexError
if self.normalizer is not None:
return self.normalizer(self.data[idx])
else:
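With the relaxed else branch above, any list-like index works, not only `list` instances; a sketch (file and dataset names hypothetical):

from tensorflow.python import keras

x_train = keras.utils.io_utils.HDF5Matrix(
    'data.h5', 'my_data', start=0, end=150)
batch = x_train[[0, 5, 10]]  # a list of indices
batch = x_train[range(5)]    # any iterable that supports max()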
diff --git a/tensorflow/python/keras/utils/io_utils_test.py b/tensorflow/python/keras/utils/io_utils_test.py
index 3895dca68e..81bb661edd 100644
--- a/tensorflow/python/keras/utils/io_utils_test.py
+++ b/tensorflow/python/keras/utils/io_utils_test.py
@@ -22,6 +22,7 @@ import os
import shutil
import numpy as np
+import six
from tensorflow.python import keras
from tensorflow.python.platform import test
@@ -95,6 +96,29 @@ class TestIOUtils(test.TestCase):
self.assertEqual(out_eval.shape, ())
self.assertGreater(out_eval, 0)
+ # test slicing for shortened array
+ self.assertEqual(len(x_train[0:]), len(x_train))
+
+ # test __getitem__ invalid use cases
+ with self.assertRaises(IndexError):
+ _ = x_train[1000]
+ with self.assertRaises(IndexError):
+ _ = x_train[1000: 1001]
+ with self.assertRaises(IndexError):
+ _ = x_train[[1000, 1001]]
+ with self.assertRaises(IndexError):
+ _ = x_train[six.moves.range(1000, 1001)]
+ with self.assertRaises(IndexError):
+ _ = x_train[np.array([1000])]
+ with self.assertRaises(TypeError):
+ _ = x_train[None]
+
+ # test normalizer
+ normalizer = lambda x: x + 1
+ normalized_x_train = keras.utils.io_utils.HDF5Matrix(
+ h5_path, 'my_data', start=0, end=150, normalizer=normalizer)
+ self.assertAllClose(normalized_x_train[0][0], x_train[0][0] + 1)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py
index e5442f04e3..e1c49bc852 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils.py
@@ -196,7 +196,7 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
batch_size = shape[:1]
input_shape = shape[1:]
step = batch_size // parts
- if i == num_gpus - 1:
+ if i == parts - 1:
size = batch_size - step * i
else:
size = step
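The fix makes the comparison use the function's own `parts` argument rather than the enclosing `num_gpus`, so the last of `parts` slices absorbs the batch remainder. A sketch of the corrected arithmetic:

def slice_sizes(batch_size, parts):
  step = batch_size // parts
  # The last slice takes whatever remains after (parts - 1) full steps.
  return [batch_size - step * (parts - 1) if i == parts - 1 else step
          for i in range(parts)]

assert slice_sizes(10, 3) == [3, 3, 4]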
diff --git a/tensorflow/python/keras/utils/vis_utils.py b/tensorflow/python/keras/utils/vis_utils.py
index 8007df4622..7a454ac831 100644
--- a/tensorflow/python/keras/utils/vis_utils.py
+++ b/tensorflow/python/keras/utils/vis_utils.py
@@ -77,7 +77,6 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
if isinstance(model, Sequential):
if not model.built:
model.build()
- model = model.model
layers = model.layers
# Create graph nodes.
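After removing the `model = model.model` unwrapping, `model_to_dot` plots the built `Sequential` instance directly; a usage sketch (assumes pydot and graphviz are installed):

from tensorflow.python import keras
from tensorflow.python.keras.utils import vis_utils

model = keras.models.Sequential(
    [keras.layers.Dense(2, input_shape=(4,))])
dot = vis_utils.model_to_dot(model, show_shapes=True)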