author     Francois Chollet <fchollet@google.com>            2017-05-31 10:33:22 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-05-31 10:38:01 -0700
commit     d21bf7d7502f447e5f967a479282b32b5845ba8b
tree       455de8814f480728c77d919a8d2a170e2a47b300 /tensorflow/contrib/keras
parent     43bfc138c9676fb54945fbede977b90a0c0aea79
Backport changes from Github master.
PiperOrigin-RevId: 157603238
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r--  tensorflow/contrib/keras/BUILD | 1
-rw-r--r--  tensorflow/contrib/keras/python/keras/__init__.py | 4
-rw-r--r--  tensorflow/contrib/keras/python/keras/activations.py | 8
-rw-r--r--  tensorflow/contrib/keras/python/keras/applications/inception_v3.py | 15
-rw-r--r--  tensorflow/contrib/keras/python/keras/applications/resnet50.py | 25
-rw-r--r--  tensorflow/contrib/keras/python/keras/applications/vgg16.py | 14
-rw-r--r--  tensorflow/contrib/keras/python/keras/applications/vgg19.py | 14
-rw-r--r--  tensorflow/contrib/keras/python/keras/applications/xception.py | 5
-rw-r--r--  tensorflow/contrib/keras/python/keras/backend.py | 229
-rw-r--r--  tensorflow/contrib/keras/python/keras/callbacks.py | 198
-rw-r--r--  tensorflow/contrib/keras/python/keras/callbacks_test.py | 33
-rw-r--r--  tensorflow/contrib/keras/python/keras/datasets/imdb.py | 2
-rw-r--r--  tensorflow/contrib/keras/python/keras/datasets/reuters.py | 2
-rw-r--r--  tensorflow/contrib/keras/python/keras/engine/topology.py | 207
-rw-r--r--  tensorflow/contrib/keras/python/keras/engine/training.py | 119
-rw-r--r--  tensorflow/contrib/keras/python/keras/integration_test.py | 2
-rw-r--r--  tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py | 30
-rw-r--r--  tensorflow/contrib/keras/python/keras/layers/lstm_test.py | 103
-rw-r--r--  tensorflow/contrib/keras/python/keras/layers/merge.py | 3
-rw-r--r--  tensorflow/contrib/keras/python/keras/layers/recurrent.py | 187
-rw-r--r--  tensorflow/contrib/keras/python/keras/layers/wrappers.py | 16
-rw-r--r--  tensorflow/contrib/keras/python/keras/losses.py | 14
-rw-r--r--  tensorflow/contrib/keras/python/keras/losses_test.py | 12
-rw-r--r--  tensorflow/contrib/keras/python/keras/metrics.py | 1
-rw-r--r--  tensorflow/contrib/keras/python/keras/models.py | 48
-rw-r--r--  tensorflow/contrib/keras/python/keras/preprocessing/image.py | 59
-rw-r--r--  tensorflow/contrib/keras/python/keras/preprocessing/sequence.py | 3
-rw-r--r--  tensorflow/contrib/keras/python/keras/preprocessing/text.py | 22
-rw-r--r--  tensorflow/contrib/keras/python/keras/utils/conv_utils.py | 5
-rw-r--r--  tensorflow/contrib/keras/python/keras/utils/data_utils.py | 28
-rw-r--r--  tensorflow/contrib/keras/python/keras/utils/generic_utils.py | 65
-rw-r--r--  tensorflow/contrib/keras/python/keras/utils/io_utils.py | 9
-rw-r--r--  tensorflow/contrib/keras/python/keras/utils/layer_utils.py | 58
-rw-r--r--  tensorflow/contrib/keras/python/keras/utils/vis_utils.py | 36
34 files changed, 1027 insertions, 550 deletions
diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD
index b38bcd1e8f..f7f56f6fcf 100644
--- a/tensorflow/contrib/keras/BUILD
+++ b/tensorflow/contrib/keras/BUILD
@@ -123,6 +123,7 @@ py_library(
"//tensorflow/python:logging_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
+ "//tensorflow/python:platform",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:state_ops",
"//tensorflow/python:summary",
diff --git a/tensorflow/contrib/keras/python/keras/__init__.py b/tensorflow/contrib/keras/python/keras/__init__.py
index ec316253db..1c1485c0cd 100644
--- a/tensorflow/contrib/keras/python/keras/__init__.py
+++ b/tensorflow/contrib/keras/python/keras/__init__.py
@@ -35,6 +35,6 @@ from tensorflow.contrib.keras.python.keras import preprocessing
from tensorflow.contrib.keras.python.keras import regularizers
from tensorflow.contrib.keras.python.keras import utils
from tensorflow.contrib.keras.python.keras import wrappers
+from tensorflow.contrib.keras.python.keras.layers import Input
-
-__version__ = '2.0.2-tf'
+__version__ = '2.0.4-tf'
diff --git a/tensorflow/contrib/keras/python/keras/activations.py b/tensorflow/contrib/keras/python/keras/activations.py
index 67762c83ba..35d15e74c2 100644
--- a/tensorflow/contrib/keras/python/keras/activations.py
+++ b/tensorflow/contrib/keras/python/keras/activations.py
@@ -21,7 +21,9 @@ from __future__ import print_function
import six
from tensorflow.contrib.keras.python.keras import backend as K
+from tensorflow.contrib.keras.python.keras.engine import Layer
from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.platform import tf_logging as logging
def softmax(x, axis=-1):
@@ -99,6 +101,12 @@ def get(identifier):
identifier = str(identifier)
return deserialize(identifier)
elif callable(identifier):
+ if isinstance(identifier, Layer):
+ logging.warning(
+ 'Do not pass a layer instance (such as {identifier}) as the '
+ 'activation argument of another layer. Instead, advanced '
+ 'activation layers should be used just like any other '
+ 'layer in a model.'.format(identifier=identifier.__class__.__name__))
return identifier
else:
raise ValueError('Could not interpret '
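As a quick illustration of what the new warning above is steering users toward, here is a minimal sketch (not part of the patch, and assuming the `tensorflow.contrib.keras` import path and the `LeakyReLU` export): advanced activations are layers and should be added as layers rather than passed through the `activation` argument.

```python
from tensorflow.contrib.keras.python import keras

model = keras.models.Sequential()
model.add(keras.layers.Dense(64, input_shape=(10,)))
# Recommended: advanced activations are layers and are added as layers.
model.add(keras.layers.LeakyReLU(alpha=0.1))
model.add(keras.layers.Dense(1))

# Passing the layer instance as an activation now logs the warning above:
#   keras.layers.Dense(64, activation=keras.layers.LeakyReLU())
```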
diff --git a/tensorflow/contrib/keras/python/keras/applications/inception_v3.py b/tensorflow/contrib/keras/python/keras/applications/inception_v3.py
index 3fc16c88ca..f77e4a8341 100644
--- a/tensorflow/contrib/keras/python/keras/applications/inception_v3.py
+++ b/tensorflow/contrib/keras/python/keras/applications/inception_v3.py
@@ -29,8 +29,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import warnings
-
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras import layers
from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import _obtain_input_shape
@@ -47,7 +45,6 @@ from tensorflow.contrib.keras.python.keras.layers import Input
from tensorflow.contrib.keras.python.keras.layers import MaxPooling2D
from tensorflow.contrib.keras.python.keras.models import Model
from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
-from tensorflow.contrib.keras.python.keras.utils.layer_utils import convert_all_kernels_in_model
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.5/inception_v3_weights_tf_dim_ordering_tf_kernels.h5'
@@ -371,16 +368,6 @@ def InceptionV3(include_top=True,
# load weights
if weights == 'imagenet':
- if K.image_data_format() == 'channels_first':
- if K.backend() == 'tensorflow':
- warnings.warn('You are using the TensorFlow backend, yet you '
- 'are using the Theano '
- 'image data format convention '
- '(`image_data_format="channels_first"`). '
- 'For best performance, set '
- '`image_data_format="channels_last"` in '
- 'your Keras config '
- 'at ~/.keras/keras.json.')
if include_top:
weights_path = get_file(
'inception_v3_weights_tf_dim_ordering_tf_kernels.h5',
@@ -394,8 +381,6 @@ def InceptionV3(include_top=True,
cache_subdir='models',
md5_hash='bcbd6486424b2319ff4ef7d526e38f63')
model.load_weights(weights_path)
- if K.backend() == 'theano':
- convert_all_kernels_in_model(model)
return model
diff --git a/tensorflow/contrib/keras/python/keras/applications/resnet50.py b/tensorflow/contrib/keras/python/keras/applications/resnet50.py
index 12f7ca424e..ce7d0bb046 100644
--- a/tensorflow/contrib/keras/python/keras/applications/resnet50.py
+++ b/tensorflow/contrib/keras/python/keras/applications/resnet50.py
@@ -26,8 +26,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import warnings
-
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras import layers
from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import _obtain_input_shape
@@ -46,7 +44,6 @@ from tensorflow.contrib.keras.python.keras.layers import Input
from tensorflow.contrib.keras.python.keras.layers import MaxPooling2D
from tensorflow.contrib.keras.python.keras.layers import ZeroPadding2D
from tensorflow.contrib.keras.python.keras.models import Model
-from tensorflow.contrib.keras.python.keras.utils import layer_utils
from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
@@ -172,7 +169,7 @@ def ResNet50(include_top=True,
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(224, 224, 3)` (with `channels_last` data format)
- or `(3, 224, 244)` (with `channels_first` data format).
+ or `(3, 224, 224)` (with `channels_first` data format).
It should have exactly 3 inputs channels,
and width and height should be no smaller than 197.
E.g. `(200, 200, 3)` would be one valid value.
@@ -286,24 +283,4 @@ def ResNet50(include_top=True,
cache_subdir='models',
md5_hash='a268eb855778b3df3c7506639542a6af')
model.load_weights(weights_path)
- if K.backend() == 'theano':
- layer_utils.convert_all_kernels_in_model(model)
-
- if K.image_data_format() == 'channels_first':
- if include_top:
- maxpool = model.get_layer(name='avg_pool')
- shape = maxpool.output_shape[1:]
- dense = model.get_layer(name='fc1000')
- layer_utils.convert_dense_weights_data_format(dense, shape,
- 'channels_first')
-
- if K.backend() == 'tensorflow':
- warnings.warn('You are using the TensorFlow backend, yet you '
- 'are using the Theano '
- 'image data format convention '
- '(`image_data_format="channels_first"`). '
- 'For best performance, set '
- '`image_data_format="channels_last"` in '
- 'your Keras config '
- 'at ~/.keras/keras.json.')
return model
diff --git a/tensorflow/contrib/keras/python/keras/applications/vgg16.py b/tensorflow/contrib/keras/python/keras/applications/vgg16.py
index 7fc393055f..89bbb040e6 100644
--- a/tensorflow/contrib/keras/python/keras/applications/vgg16.py
+++ b/tensorflow/contrib/keras/python/keras/applications/vgg16.py
@@ -25,8 +25,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import warnings
-
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
@@ -77,7 +75,7 @@ def VGG16(include_top=True,
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(224, 224, 3)` (with `channels_last` data format)
- or `(3, 224, 244)` (with `channels_first` data format).
+ or `(3, 224, 224)` (with `channels_first` data format).
It should have exactly 3 inputs channels,
and width and height should be no smaller than 48.
E.g. `(200, 200, 3)` would be one valid value.
@@ -210,14 +208,4 @@ def VGG16(include_top=True,
dense = model.get_layer(name='fc1')
layer_utils.convert_dense_weights_data_format(dense, shape,
'channels_first')
-
- if K.backend() == 'tensorflow':
- warnings.warn('You are using the TensorFlow backend, yet you '
- 'are using the Theano '
- 'image data format convention '
- '(`image_data_format="channels_first"`). '
- 'For best performance, set '
- '`image_data_format="channels_last"` in '
- 'your Keras config '
- 'at ~/.keras/keras.json.')
return model
diff --git a/tensorflow/contrib/keras/python/keras/applications/vgg19.py b/tensorflow/contrib/keras/python/keras/applications/vgg19.py
index f7c2921b5c..522a516ecf 100644
--- a/tensorflow/contrib/keras/python/keras/applications/vgg19.py
+++ b/tensorflow/contrib/keras/python/keras/applications/vgg19.py
@@ -25,8 +25,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import warnings
-
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
@@ -77,7 +75,7 @@ def VGG19(include_top=True,
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(224, 224, 3)` (with `channels_last` data format)
- or `(3, 224, 244)` (with `channels_first` data format).
+ or `(3, 224, 224)` (with `channels_first` data format).
It should have exactly 3 inputs channels,
and width and height should be no smaller than 48.
E.g. `(200, 200, 3)` would be one valid value.
@@ -216,14 +214,4 @@ def VGG19(include_top=True,
dense = model.get_layer(name='fc1')
layer_utils.convert_dense_weights_data_format(dense, shape,
'channels_first')
-
- if K.backend() == 'tensorflow':
- warnings.warn('You are using the TensorFlow backend, yet you '
- 'are using the Theano '
- 'image data format convention '
- '(`image_data_format="channels_first"`). '
- 'For best performance, set '
- '`image_data_format="channels_last"` in '
- 'your Keras config '
- 'at ~/.keras/keras.json.')
return model
diff --git a/tensorflow/contrib/keras/python/keras/applications/xception.py b/tensorflow/contrib/keras/python/keras/applications/xception.py
index 3b08e73514..49fb6008f6 100644
--- a/tensorflow/contrib/keras/python/keras/applications/xception.py
+++ b/tensorflow/contrib/keras/python/keras/applications/xception.py
@@ -36,8 +36,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import warnings
-
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras import layers
from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import _obtain_input_shape
@@ -54,6 +52,7 @@ from tensorflow.contrib.keras.python.keras.layers import MaxPooling2D
from tensorflow.contrib.keras.python.keras.layers import SeparableConv2D
from tensorflow.contrib.keras.python.keras.models import Model
from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels.h5'
@@ -127,7 +126,7 @@ def Xception(include_top=True,
raise RuntimeError('The Xception model is only available with '
'the TensorFlow backend.')
if K.image_data_format() != 'channels_last':
- warnings.warn(
+ logging.warning(
'The Xception model is only available for the '
'input data format "channels_last" '
'(width, height, channels). '
diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py
index ed2b251b31..84d0dacce9 100644
--- a/tensorflow/contrib/keras/python/keras/backend.py
+++ b/tensorflow/contrib/keras/python/keras/backend.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import json
import os
-import warnings
import numpy as np
@@ -52,6 +51,7 @@ from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-im
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.training import moving_averages
+from tensorflow.python.util import tf_inspect
py_all = all
@@ -386,6 +386,17 @@ def set_session(session):
def _convert_string_dtype(dtype):
+ """Get the type from a string.
+
+ Arguments:
+ dtype: A string representation of a type.
+
+ Returns:
+ The type requested.
+
+ Raises:
+ ValueError: if `dtype` is not supported.
+ """
if dtype == 'float16':
return dtypes_module.float16
if dtype == 'float32':
@@ -407,6 +418,15 @@ def _convert_string_dtype(dtype):
def _to_tensor(x, dtype):
+ """Convert the input `x` to a tensor of type `dtype`.
+
+ Arguments:
+ x: An object to be converted (numpy array, list, tensors).
+ dtype: The destination type.
+
+ Returns:
+ A tensor.
+ """
x = ops.convert_to_tensor(x)
if x.dtype != dtype:
x = math_ops.cast(x, dtype)
@@ -521,6 +541,17 @@ def _initialize_variables():
def constant(value, dtype=None, shape=None, name=None):
+ """Creates a constant tensor.
+
+ Arguments:
+ value: A constant value (or list)
+ dtype: The type of the elements of the resulting tensor.
+ shape: Optional dimensions of resulting tensor.
+ name: Optional name for the tensor.
+
+ Returns:
+ A Constant Tensor.
+ """
if dtype is None:
dtype = floatx()
return constant_op.constant(value, dtype=dtype, shape=shape, name=name)
@@ -833,6 +864,18 @@ def ones_like(x, dtype=None, name=None):
return array_ops.ones_like(x, dtype=dtype, name=name)
+def identity(x):
+ """Returns a tensor with the same content as the input tensor.
+
+ Arguments:
+ x: The input tensor.
+
+ Returns:
+ A tensor of the same shape, type and content.
+ """
+ return array_ops.identity(x)
+
+
def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
"""Instantiates a variable with values drawn from a uniform distribution.
@@ -971,14 +1014,42 @@ def update(x, new_x):
def update_add(x, increment):
+ """Update the value of `x` by adding `increment`.
+
+ Arguments:
+ x: A Variable.
+ increment: A tensor of same shape as `x`.
+
+ Returns:
+ The variable `x` updated.
+ """
return state_ops.assign_add(x, increment)
def update_sub(x, decrement):
+ """Update the value of `x` by subtracting `decrement`.
+
+ Arguments:
+ x: A Variable.
+ decrement: A tensor of same shape as `x`.
+
+ Returns:
+ The variable `x` updated.
+ """
return state_ops.assign_sub(x, decrement)
def moving_average_update(x, value, momentum):
+ """Compute the moving average of a variable.
+
+ Arguments:
+ x: A Variable.
+ value: A tensor with the same shape as `variable`.
+ momentum: The moving average momentum.
+
+ Returns:
+ An Operation to update the variable.
+ """
return moving_averages.assign_moving_average(
x, value, momentum, zero_debias=False)
@@ -1110,6 +1181,20 @@ def batch_dot(x, y, axes=None):
"""
if isinstance(axes, int):
axes = (axes, axes)
+ x_ndim = ndim(x)
+ y_ndim = ndim(y)
+ if x_ndim > y_ndim:
+ diff = x_ndim - y_ndim
+ y = array_ops.reshape(y,
+ array_ops.concat(
+ [array_ops.shape(y), [1] * (diff)], axis=0))
+ elif y_ndim > x_ndim:
+ diff = y_ndim - x_ndim
+ x = array_ops.reshape(x,
+ array_ops.concat(
+ [array_ops.shape(x), [1] * (diff)], axis=0))
+ else:
+ diff = 0
if ndim(x) == 2 and ndim(y) == 2:
if axes[0] == axes[1]:
out = math_ops.reduce_sum(math_ops.multiply(x, y), axes[0])
@@ -1124,6 +1209,12 @@ def batch_dot(x, y, axes=None):
adj_x = None
adj_y = None
out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
+ if diff:
+ if x_ndim > y_ndim:
+ idx = x_ndim + y_ndim - 3
+ else:
+ idx = x_ndim - 1
+ out = array_ops.squeeze(out, list(range(idx, idx + diff)))
if ndim(out) == 1:
out = expand_dims(out, 1)
return out
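To make the new rank-broadcasting branch concrete, here is a rough NumPy analogue (illustrative only, not the backend code) of the same idea: pad the lower-rank operand with trailing singleton axes, run the batched matmul, then squeeze away the axes the padding introduced.

```python
import numpy as np

x = np.random.rand(32, 20)       # rank 2: (batch, n)
y = np.random.rand(32, 20, 30)   # rank 3: (batch, n, m)

# Pad the lower-rank tensor with trailing singleton axes until the ranks match.
diff = y.ndim - x.ndim
x_pad = x.reshape(x.shape + (1,) * diff)              # (32, 20, 1)

# Batched matmul contracting the shared axis (axis 1 of x with axis 1 of y).
out = np.matmul(np.transpose(x_pad, (0, 2, 1)), y)    # (32, 1, 30)

# Squeeze away the axes that only exist because of the padding.
out = np.squeeze(out, axis=1)                         # (32, 30)
assert out.shape == (32, 30)
```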
@@ -1485,6 +1576,28 @@ def log(x):
return math_ops.log(x)
+def logsumexp(x, axis=None, keepdims=False):
+ """Computes log(sum(exp(elements across dimensions of a tensor))).
+
+ This function is more numerically stable than log(sum(exp(x))).
+ It avoids overflows caused by taking the exp of large inputs and
+ underflows caused by taking the log of small inputs.
+
+ Arguments:
+ x: A tensor or variable.
+ axis: An integer, the axis to reduce over.
+ keepdims: A boolean, whether to keep the dimensions or not.
+ If `keepdims` is `False`, the rank of the tensor is reduced
+ by 1. If `keepdims` is `True`, the reduced dimension is
+ retained with length 1.
+
+ Returns:
+ The reduced tensor.
+ """
+ axis = _normalize_axis(axis, ndim(x))
+ return math_ops.reduce_logsumexp(x, axis=axis, keep_dims=keepdims)
+
+
def round(x):
"""Element-wise rounding to the closest integer.
@@ -1986,14 +2099,14 @@ def batch_flatten(x):
def expand_dims(x, axis=-1):
- """Adds a 1-sized dimension at index "dim".
+ """Adds a 1-sized dimension at index "axis".
Arguments:
x: A tensor or variable.
axis: Position where to add a new axis.
Returns:
- A tensor with expended dimensions.
+ A tensor with expanded dimensions.
"""
return array_ops.expand_dims(x, axis)
@@ -2247,9 +2360,11 @@ class Function(object):
inputs: Feed placeholders to the computation graph.
outputs: Output tensors to fetch.
updates: Additional update ops to be run at function call.
+ name: a name to help users identify what this function does.
"""
- def __init__(self, inputs, outputs, updates=None):
+ def __init__(self, inputs, outputs, updates=None, name=None,
+ **session_kwargs):
updates = updates or []
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` to a TensorFlow backend function '
@@ -2272,6 +2387,8 @@ class Function(object):
# assumed already an op
updates_ops.append(update)
self.updates_op = control_flow_ops.group(*updates_ops)
+ self.name = name
+ self.session_kwargs = session_kwargs
def __call__(self, inputs):
if not isinstance(inputs, (list, tuple)):
@@ -2285,7 +2402,10 @@ class Function(object):
value = (indices, sparse_coo.data, sparse_coo.shape)
feed_dict[tensor] = value
session = get_session()
- updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict)
+ updated = session.run(
+ self.outputs + [self.updates_op],
+ feed_dict=feed_dict,
+ **self.session_kwargs)
return updated[:len(self.outputs)]
@@ -2296,18 +2416,22 @@ def function(inputs, outputs, updates=None, **kwargs):
inputs: List of placeholder tensors.
outputs: List of output tensors.
updates: List of update ops.
- **kwargs: Not used with TensorFlow.
+ **kwargs: Passed to `tf.Session.run`.
Returns:
Output values as Numpy arrays.
+
+ Raises:
+ ValueError: if invalid kwargs are passed in.
"""
if kwargs:
- msg = [
- 'Expected no kwargs, you passed %s' % len(kwargs),
- 'kwargs passed to function are ignored with Tensorflow backend'
- ]
- warnings.warn('\n'.join(msg))
- return Function(inputs, outputs, updates=updates)
+ for key in kwargs:
+ if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and
+ key not in tf_inspect.getargspec(Function.__init__)[0]):
+ msg = ('Invalid argument "%s" passed to K.function with Tensorflow '
+ 'backend') % key
+ raise ValueError(msg)
+ return Function(inputs, outputs, updates=updates, **kwargs)
def gradients(loss, variables):
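With this change, any keyword accepted by `tf.Session.run` (for example `options` and `run_metadata` in the TF 1.x API, assumed here as the target version) is forwarded on every call, and anything else is rejected up front. A minimal usage sketch:

```python
import numpy as np
import tensorflow as tf
from tensorflow.contrib.keras.python.keras import backend as K

x = K.placeholder(shape=(None, 4))
y = x * 2

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

# `options` and `run_metadata` are valid `Session.run` kwargs, so they pass
# the new validation and are attached to every call of `f`.
f = K.function([x], [y], name='double_fn',
               options=run_options, run_metadata=run_metadata)
out = f([np.ones((2, 4))])

# An unknown keyword now raises a ValueError instead of a silent warning:
#   K.function([x], [y], bogus=True)
```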
@@ -2452,7 +2576,8 @@ def rnn(step_function,
# (see earlier comment for tile explanation)
tiled_mask_t = array_ops.tile(mask_t,
array_ops.stack(
- [1, array_ops.shape(new_state)[1]]))
+ [1,
+ array_ops.shape(new_state)[1]]))
return_states.append(array_ops.where(tiled_mask_t, new_state, state))
states = return_states
successive_outputs.append(output)
@@ -2931,6 +3056,16 @@ def in_top_k(predictions, targets, k):
def _preprocess_deconv_output_shape(x, shape, data_format):
+ """Get the output_shape for the deconvolution.
+
+ Arguments:
+ x: input tensor.
+ shape: output shape.
+ data_format: string, one of 'channels_last', 'channels_first'.
+
+ Returns:
+ The output shape.
+ """
if data_format == 'channels_first':
shape = (shape[0], shape[2], shape[3], shape[1])
@@ -2941,6 +3076,15 @@ def _preprocess_deconv_output_shape(x, shape, data_format):
def _preprocess_conv2d_input(x, data_format):
+ """Transpose and cast the input before the conv2d.
+
+ Arguments:
+ x: input tensor.
+ data_format: string, one of 'channels_last', 'channels_first'.
+
+ Returns:
+ A tensor.
+ """
if dtype(x) == 'float64':
x = math_ops.cast(x, 'float32')
if data_format == 'channels_first':
@@ -2953,6 +3097,15 @@ def _preprocess_conv2d_input(x, data_format):
def _preprocess_conv3d_input(x, data_format):
+ """Transpose and cast the input before the conv3d.
+
+ Arguments:
+ x: input tensor.
+ data_format: string, one of 'channels_last', 'channels_first'.
+
+ Returns:
+ A tensor.
+ """
if dtype(x) == 'float64':
x = math_ops.cast(x, 'float32')
if data_format == 'channels_first':
@@ -2961,6 +3114,15 @@ def _preprocess_conv3d_input(x, data_format):
def _preprocess_conv2d_kernel(kernel, data_format):
+ """Transpose and cast the kernel before the conv2d.
+
+ Arguments:
+ kernel: kernel tensor.
+ data_format: string, one of 'channels_last', 'channels_first'.
+
+ Returns:
+ A tensor.
+ """
if dtype(kernel) == 'float64':
kernel = math_ops.cast(kernel, 'float32')
if data_format == 'channels_first':
@@ -2969,6 +3131,15 @@ def _preprocess_conv2d_kernel(kernel, data_format):
def _preprocess_conv3d_kernel(kernel, data_format):
+ """Transpose and cast the kernel before the conv3d.
+
+ Arguments:
+ kernel: kernel tensor.
+ data_format: string, one of 'channels_last', 'channels_first'.
+
+ Returns:
+ A tensor.
+ """
if dtype(kernel) == 'float64':
kernel = math_ops.cast(kernel, 'float32')
if data_format == 'channels_first':
@@ -2977,16 +3148,37 @@ def _preprocess_conv3d_kernel(kernel, data_format):
def _preprocess_padding(padding):
+ """Convert keras' padding to tensorflow's padding.
+
+ Arguments:
+ padding: string, one of 'same' , 'valid'
+
+ Returns:
+ a string, one of 'SAME', 'VALID'.
+
+ Raises:
+ ValueError: if invalid `padding'`
+ """
if padding == 'same':
padding = 'SAME'
elif padding == 'valid':
padding = 'VALID'
else:
- raise ValueError('Invalid border mode:', padding)
+ raise ValueError('Invalid padding:', padding)
return padding
def _postprocess_conv2d_output(x, data_format):
+ """Transpose and cast the output from conv2d if needed.
+
+ Arguments:
+ x: A tensor.
+ data_format: string, one of "channels_last", "channels_first".
+
+ Returns:
+ A tensor.
+ """
+
if data_format == 'channels_first':
x = array_ops.transpose(x, (0, 3, 1, 2))
@@ -2996,6 +3188,15 @@ def _postprocess_conv2d_output(x, data_format):
def _postprocess_conv3d_output(x, data_format):
+ """Transpose and cast the output from conv3d if needed.
+
+ Arguments:
+ x: A tensor.
+ data_format: string, one of "channels_last", "channels_first".
+
+ Returns:
+ A tensor.
+ """
if data_format == 'channels_first':
x = array_ops.transpose(x, (0, 4, 1, 2, 3))
diff --git a/tensorflow/contrib/keras/python/keras/callbacks.py b/tensorflow/contrib/keras/python/keras/callbacks.py
index a533e0fbda..d0587a549b 100644
--- a/tensorflow/contrib/keras/python/keras/callbacks.py
+++ b/tensorflow/contrib/keras/python/keras/callbacks.py
@@ -25,14 +25,15 @@ import csv
import json
import os
import time
-import warnings
import numpy as np
+import six
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.utils.generic_utils import Progbar
from tensorflow.contrib.tensorboard.plugins import projector
from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as tf_summary
from tensorflow.python.training import saver as saver_lib
@@ -110,7 +111,7 @@ class CallbackList(object):
delta_t_median = np.median(self._delta_ts_batch_begin)
if (self._delta_t_batch > 0. and
delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1):
- warnings.warn(
+ logging.warning(
'Method on_batch_begin() is slow compared '
'to the batch update (%f). Check your callbacks.' % delta_t_median)
self._t_enter_batch = time.time()
@@ -133,7 +134,7 @@ class CallbackList(object):
delta_t_median = np.median(self._delta_ts_batch_end)
if (self._delta_t_batch > 0. and
(delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
- warnings.warn(
+ logging.warning(
'Method on_batch_end() is slow compared '
'to the batch update (%f). Check your callbacks.' % delta_t_median)
@@ -245,6 +246,21 @@ class BaseLogger(Callback):
logs[k] = self.totals[k] / self.seen
+class TerminateOnNaN(Callback):
+ """Callback that terminates training when a NaN loss is encountered."""
+
+ def __init__(self):
+ super(TerminateOnNaN, self).__init__()
+
+ def on_batch_end(self, batch, logs=None):
+ logs = logs or {}
+ loss = logs.get('loss')
+ if loss is not None:
+ if np.isnan(loss) or np.isinf(loss):
+ print('Batch %d: Invalid loss, terminating training' % (batch))
+ self.model.stop_training = True
+
+
class ProgbarLogger(Callback):
"""Callback that prints metrics to stdout.
@@ -380,8 +396,8 @@ class ModelCheckpoint(Callback):
self.epochs_since_last_save = 0
if mode not in ['auto', 'min', 'max']:
- warnings.warn('ModelCheckpoint mode %s is unknown, '
- 'fallback to auto mode.' % (mode), RuntimeWarning)
+ logging.warning('ModelCheckpoint mode %s is unknown, '
+ 'fallback to auto mode.' % (mode))
mode = 'auto'
if mode == 'min':
@@ -407,8 +423,8 @@ class ModelCheckpoint(Callback):
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
- warnings.warn('Can save best model only with %s available, '
- 'skipping.' % (self.monitor), RuntimeWarning)
+ logging.warning('Can save best model only with %s available, '
+ 'skipping.' % (self.monitor))
else:
if self.monitor_op(current, self.best):
if self.verbose > 0:
@@ -469,8 +485,8 @@ class EarlyStopping(Callback):
self.stopped_epoch = 0
if mode not in ['auto', 'min', 'max']:
- warnings.warn('EarlyStopping mode %s is unknown, '
- 'fallback to auto mode.' % (self.mode), RuntimeWarning)
+ logging.warning('EarlyStopping mode %s is unknown, '
+ 'fallback to auto mode.' % (self.mode))
mode = 'auto'
if mode == 'min':
@@ -489,14 +505,15 @@ class EarlyStopping(Callback):
self.min_delta *= -1
def on_train_begin(self, logs=None):
- self.wait = 0 # Allow instances to be re-used
+ # Allow instances to be re-used
+ self.wait = 0
+ self.stopped_epoch = 0
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
def on_epoch_end(self, epoch, logs=None):
current = logs.get(self.monitor)
if current is None:
- warnings.warn('Early stopping requires %s available!' % (self.monitor),
- RuntimeWarning)
+ logging.warning('Early stopping requires %s available!' % (self.monitor))
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
@@ -526,8 +543,7 @@ class RemoteMonitor(Callback):
field: String; JSON field under which the data will be stored.
headers: Dictionary; optional custom HTTP headers.
Defaults to:
- `{'Accept': 'application/json',
- 'Content-Type': 'application/json'}`
+ `{'Accept': 'application/json', 'Content-Type': 'application/json'}`
"""
def __init__(self,
@@ -559,8 +575,8 @@ class RemoteMonitor(Callback):
self.root + self.path, {self.field: json.dumps(send)},
headers=self.headers)
except requests.exceptions.RequestException:
- warnings.warn('Warning: could not reach RemoteMonitor '
- 'root server at ' + str(self.root))
+ logging.warning('Warning: could not reach RemoteMonitor '
+ 'root server at ' + str(self.root))
class LearningRateScheduler(Callback):
@@ -595,17 +611,34 @@ class TensorBoard(Callback):
metrics, as well as activation histograms for the different
layers in your model.
+ TensorBoard is a visualization tool provided with TensorFlow.
+
+ If you have installed TensorFlow with pip, you should be able
+ to launch TensorBoard from the command line:
+
+ ```
+ tensorboard --logdir=/full_path_to_your_logs
+ ```
+
+ You can find more information about TensorBoard
+ [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
+
Arguments:
log_dir: the path of the directory where to save the log
- files to be parsed by Tensorboard.
+ files to be parsed by TensorBoard.
histogram_freq: frequency (in epochs) at which to compute activation
- histograms for the layers of the model. If set to 0,
- histograms won't be computed.
- write_graph: whether to visualize the graph in Tensorboard.
+ and weight histograms for the layers of the model. If set to 0,
+ histograms won't be computed. Validation data (or split) must be
+ specified for histogram visualizations.
+ write_graph: whether to visualize the graph in TensorBoard.
The log file can become quite large when
write_graph is set to True.
+ write_grads: whether to visualize gradient histograms in TensorBoard.
+ `histogram_freq` must be greater than 0.
+ batch_size: size of batch of inputs to feed to the network
+ for histograms computation.
write_images: whether to write model weights to visualize as
- image in Tensorboard.
+ image in TensorBoard.
embeddings_freq: frequency (in epochs) at which selected embedding
layers will be saved.
embeddings_layer_names: a list of names of layers to keep eye on. If
@@ -622,7 +655,9 @@ class TensorBoard(Callback):
def __init__(self,
log_dir='./logs',
histogram_freq=0,
+ batch_size=32,
write_graph=True,
+ write_grads=False,
write_images=False,
embeddings_freq=0,
embeddings_layer_names=None,
@@ -632,27 +667,47 @@ class TensorBoard(Callback):
self.histogram_freq = histogram_freq
self.merged = None
self.write_graph = write_graph
+ self.write_grads = write_grads
self.write_images = write_images
self.embeddings_freq = embeddings_freq
self.embeddings_layer_names = embeddings_layer_names
self.embeddings_metadata = embeddings_metadata or {}
+ self.batch_size = batch_size
def set_model(self, model):
self.model = model
self.sess = K.get_session()
if self.histogram_freq and self.merged is None:
for layer in self.model.layers:
-
for weight in layer.weights:
tf_summary.histogram(weight.name, weight)
+ if self.write_grads:
+ grads = model.optimizer.get_gradients(model.total_loss, weight)
+ tf_summary.histogram('{}_grad'.format(weight.name), grads)
if self.write_images:
w_img = array_ops.squeeze(weight)
- shape = w_img.get_shape()
- if len(shape) > 1 and shape[0] > shape[1]:
- w_img = array_ops.transpose(w_img)
- if len(shape) == 1:
- w_img = array_ops.expand_dims(w_img, 0)
- w_img = array_ops.expand_dims(array_ops.expand_dims(w_img, 0), -1)
+ shape = K.int_shape(w_img)
+ if len(shape) == 2: # dense layer kernel case
+ if shape[0] > shape[1]:
+ w_img = array_ops.transpose(w_img)
+ shape = K.int_shape(w_img)
+ w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
+ elif len(shape) == 3: # convnet case
+ if K.image_data_format() == 'channels_last':
+ # switch to channels_first to display
+ # every kernel as a separate image
+ w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
+ shape = K.int_shape(w_img)
+ w_img = array_ops.reshape(w_img,
+ [shape[0], shape[1], shape[2], 1])
+ elif len(shape) == 1: # bias case
+ w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
+ else:
+ # not possible to handle 3D convnets etc.
+ continue
+
+ shape = K.int_shape(w_img)
+ assert len(shape) == 4 and shape[-1] in [1, 3, 4]
tf_summary.image(weight.name, w_img)
if hasattr(layer, 'output'):
@@ -665,8 +720,6 @@ class TensorBoard(Callback):
self.writer = tf_summary.FileWriter(self.log_dir)
if self.embeddings_freq:
- self.saver = saver_lib.Saver()
-
embeddings_layer_names = self.embeddings_layer_names
if not embeddings_layer_names:
@@ -680,6 +733,8 @@ class TensorBoard(Callback):
for layer in self.model.layers if layer.name in embeddings_layer_names
}
+ self.saver = saver_lib.Saver(list(embeddings.values()))
+
embeddings_metadata = {}
if not isinstance(self.embeddings_metadata, str):
@@ -691,15 +746,13 @@ class TensorBoard(Callback):
}
config = projector.ProjectorConfig()
- self.embeddings_logs = []
+ self.embeddings_ckpt_path = os.path.join(self.log_dir,
+ 'keras_embedding.ckpt')
for layer_name, tensor in embeddings.items():
embedding = config.embeddings.add()
embedding.tensor_name = tensor.name
- self.embeddings_logs.append(
- os.path.join(self.log_dir, layer_name + '.ckpt'))
-
if layer_name in embeddings_metadata:
embedding.metadata_path = embeddings_metadata[layer_name]
@@ -710,24 +763,34 @@ class TensorBoard(Callback):
if self.validation_data and self.histogram_freq:
if epoch % self.histogram_freq == 0:
- # TODO(fchollet): implement batched calls to sess.run
- # (current call will likely go OOM on GPU)
+
+ val_data = self.validation_data
+ tensors = (
+ self.model.inputs + self.model.targets + self.model.sample_weights)
+
if self.model.uses_learning_phase:
- cut_v_data = len(self.model.inputs)
- val_data = self.validation_data[:cut_v_data] + [0]
- tensors = self.model.inputs + [K.learning_phase()]
- else:
- val_data = self.validation_data
- tensors = self.model.inputs
- feed_dict = dict(zip(tensors, val_data))
- result = self.sess.run([self.merged], feed_dict=feed_dict)
- summary_str = result[0]
- self.writer.add_summary(summary_str, epoch)
-
- if self.embeddings_freq and self.embeddings_logs:
+ tensors += [K.learning_phase()]
+
+ assert len(val_data) == len(tensors)
+ val_size = val_data[0].shape[0]
+ i = 0
+ while i < val_size:
+ step = min(self.batch_size, val_size - i)
+ batch_val = []
+ batch_val.append(val_data[0][i:i + step])
+ batch_val.append(val_data[1][i:i + step])
+ batch_val.append(val_data[2][i:i + step])
+ if self.model.uses_learning_phase:
+ batch_val.append(val_data[3])
+ feed_dict = dict(zip(tensors, batch_val))
+ result = self.sess.run([self.merged], feed_dict=feed_dict)
+ summary_str = result[0]
+ self.writer.add_summary(summary_str, epoch)
+ i += self.batch_size
+
+ if self.embeddings_freq and self.embeddings_ckpt_path:
if epoch % self.embeddings_freq == 0:
- for log in self.embeddings_logs:
- self.saver.save(self.sess, log, epoch)
+ self.saver.save(self.sess, self.embeddings_ckpt_path, epoch)
for name, value in logs.items():
if name in ['batch', 'size']:
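Putting the new `write_grads`, `batch_size` and embeddings arguments together, a small end-to-end sketch (model, data and log directory are placeholders chosen for illustration, not taken from the patch):

```python
import numpy as np
from tensorflow.contrib.keras.python import keras

model = keras.models.Sequential()
model.add(keras.layers.Dense(8, input_dim=4, activation='relu', name='dense_1'))
model.add(keras.layers.Dense(1))
model.compile(loss='mse', optimizer='sgd')

x = np.random.random((100, 4))
y = np.random.random((100, 1))

# Weight and gradient histograms are computed over the validation data in
# chunks of `batch_size`; the kernel of `dense_1` is also checkpointed as an
# embedding every `embeddings_freq` epochs.
tb = keras.callbacks.TensorBoard(log_dir='./logs',
                                 histogram_freq=1,
                                 batch_size=32,
                                 write_graph=True,
                                 write_grads=True,
                                 write_images=True,
                                 embeddings_freq=1,
                                 embeddings_layer_names=['dense_1'])
model.fit(x, y, epochs=2, validation_split=0.2, callbacks=[tb])
```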
@@ -752,11 +815,12 @@ class ReduceLROnPlateau(Callback):
of epochs, the learning rate is reduced.
Example:
- ```python
- reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
- patience=5, min_lr=0.001)
- model.fit(X_train, Y_train, callbacks=[reduce_lr])
- ```
+
+ ```python
+ reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
+ patience=5, min_lr=0.001)
+ model.fit(X_train, Y_train, callbacks=[reduce_lr])
+ ```
Arguments:
monitor: quantity to be monitored.
@@ -810,8 +874,8 @@ class ReduceLROnPlateau(Callback):
"""Resets wait counter and cooldown counter.
"""
if self.mode not in ['auto', 'min', 'max']:
- warnings.warn('Learning Rate Plateau Reducing mode %s is unknown, '
- 'fallback to auto mode.' % (self.mode), RuntimeWarning)
+ logging.warning('Learning Rate Plateau Reducing mode %s is unknown, '
+ 'fallback to auto mode.' % (self.mode))
self.mode = 'auto'
if (self.mode == 'min' or
(self.mode == 'auto' and 'acc' not in self.monitor)):
@@ -832,8 +896,8 @@ class ReduceLROnPlateau(Callback):
logs['lr'] = K.get_value(self.model.optimizer.lr)
current = logs.get(self.monitor)
if current is None:
- warnings.warn('Learning Rate Plateau Reducing requires %s available!' %
- self.monitor, RuntimeWarning)
+ logging.warning('Learning Rate Plateau Reducing requires %s available!' %
+ self.monitor)
else:
if self.in_cooldown():
self.cooldown_counter -= 1
@@ -868,8 +932,8 @@ class CSVLogger(Callback):
Example:
```python
- csv_logger = CSVLogger('training.log')
- model.fit(X_train, Y_train, callbacks=[csv_logger])
+ csv_logger = CSVLogger('training.log')
+ model.fit(X_train, Y_train, callbacks=[csv_logger])
```
Arguments:
@@ -886,23 +950,26 @@ class CSVLogger(Callback):
self.writer = None
self.keys = None
self.append_header = True
+ self.file_flags = 'b' if six.PY2 and os.name == 'nt' else ''
super(CSVLogger, self).__init__()
def on_train_begin(self, logs=None):
if self.append:
if os.path.exists(self.filename):
- with open(self.filename) as f:
+ with open(self.filename, 'r' + self.file_flags) as f:
self.append_header = not bool(len(f.readline()))
- self.csv_file = open(self.filename, 'a')
+ self.csv_file = open(self.filename, 'a' + self.file_flags)
else:
- self.csv_file = open(self.filename, 'w')
+ self.csv_file = open(self.filename, 'w' + self.file_flags)
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
def handle_value(k):
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
- if isinstance(k, Iterable) and not is_zero_dim_ndarray:
+ if isinstance(k, six.string_types):
+ return k
+ elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
return '"[%s]"' % (', '.join(map(str, k)))
else:
return k
@@ -936,6 +1003,7 @@ class LambdaCallback(Callback):
This callback is constructed with anonymous functions that will be called
at the appropriate time. Note that the callbacks expects positional
arguments, as:
+
- `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
`epoch`, `logs`
- `on_batch_begin` and `on_batch_end` expect two positional arguments:
diff --git a/tensorflow/contrib/keras/python/keras/callbacks_test.py b/tensorflow/contrib/keras/python/keras/callbacks_test.py
index 412f736e16..15a7304b60 100644
--- a/tensorflow/contrib/keras/python/keras/callbacks_test.py
+++ b/tensorflow/contrib/keras/python/keras/callbacks_test.py
@@ -436,6 +436,35 @@ class KerasCallbacksTest(test.TestCase):
os.remove(filepath)
+ def test_TerminateOnNaN(self):
+ np.random.seed(1337)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=TRAIN_SAMPLES,
+ test_samples=TEST_SAMPLES,
+ input_shape=(INPUT_DIM,),
+ num_classes=NUM_CLASSES)
+
+ y_test = keras.utils.to_categorical(y_test)
+ y_train = keras.utils.to_categorical(y_train)
+ cbks = [keras.callbacks.TerminateOnNaN()]
+ model = keras.models.Sequential()
+ initializer = keras.initializers.Constant(value=1e5)
+ for _ in range(5):
+ model.add(keras.layers.Dense(2,
+ input_dim=INPUT_DIM,
+ activation='relu',
+ kernel_initializer=initializer))
+ model.add(keras.layers.Dense(NUM_CLASSES))
+ model.compile(loss='mean_squared_error',
+ optimizer='rmsprop')
+
+ history = model.fit(x_train, y_train, batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks, epochs=20)
+ loss = history.history['loss']
+ assert len(loss) == 1
+ assert loss[0] == np.inf
+
def test_TensorBoard(self):
np.random.seed(1337)
@@ -479,7 +508,9 @@ class KerasCallbacksTest(test.TestCase):
metrics=['accuracy'])
tsb = keras.callbacks.TensorBoard(
- log_dir=temp_dir, histogram_freq=1, write_images=True)
+ log_dir=temp_dir, histogram_freq=1, write_images=True,
+ write_grads=True, embeddings_freq=1,
+ embeddings_layer_names=['dense_1'], batch_size=5)
cbks = [tsb]
# fit with validation data
diff --git a/tensorflow/contrib/keras/python/keras/datasets/imdb.py b/tensorflow/contrib/keras/python/keras/datasets/imdb.py
index 5c087fe63f..04ab154f9f 100644
--- a/tensorflow/contrib/keras/python/keras/datasets/imdb.py
+++ b/tensorflow/contrib/keras/python/keras/datasets/imdb.py
@@ -118,7 +118,7 @@ def load_data(path='imdb.npz',
for x in xs:
nx = []
for w in x:
- if w >= num_words or w < skip_top:
+ if skip_top <= w < num_words:
nx.append(w)
new_xs.append(nx)
xs = new_xs
diff --git a/tensorflow/contrib/keras/python/keras/datasets/reuters.py b/tensorflow/contrib/keras/python/keras/datasets/reuters.py
index b1c22fee63..2904eb5bf6 100644
--- a/tensorflow/contrib/keras/python/keras/datasets/reuters.py
+++ b/tensorflow/contrib/keras/python/keras/datasets/reuters.py
@@ -104,7 +104,7 @@ def load_data(path='reuters.npz',
for x in xs:
nx = []
for w in x:
- if w >= num_words or w < skip_top:
+ if skip_top <= w < num_words:
nx.append(w)
new_xs.append(nx)
xs = new_xs
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py
index 8bc0c412b5..7561ef78f3 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology.py
+++ b/tensorflow/contrib/keras/python/keras/engine/topology.py
@@ -23,7 +23,6 @@ import copy
import json
import os
import re
-import warnings
import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
@@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summar
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as tf_base_layers
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
@@ -50,43 +50,7 @@ except ImportError:
yaml = None
# pylint: enable=g-import-not-at-top
-
-class InputSpec(object):
- """Specifies the ndim, dtype and shape of every input to a layer.
-
- Every layer should expose (if appropriate) an `input_spec` attribute:
- a list of instances of InputSpec (one per input tensor).
-
- A None entry in a shape is compatible with any dimension,
- a None shape is compatible with any shape.
-
- Arguments:
- dtype: Expected datatype of the input.
- shape: Shape tuple, expected shape of the input
- (may include None for unchecked axes).
- ndim: Integer, expected rank of the input.
- max_ndim: Integer, maximum rank of the input.
- min_ndim: Integer, minimum rank of the input.
- axes: Dictionary mapping integer axes to
- a specific dimension value.
- """
-
- def __init__(self,
- dtype=None,
- shape=None,
- ndim=None,
- max_ndim=None,
- min_ndim=None,
- axes=None):
- self.dtype = dtype
- self.shape = shape
- if shape is not None:
- self.ndim = len(shape)
- else:
- self.ndim = ndim
- self.max_ndim = max_ndim
- self.min_ndim = min_ndim
- self.axes = axes or {}
+InputSpec = tf_base_layers.InputSpec # pylint: disable=invalid-name
class Node(object):
@@ -435,6 +399,20 @@ class Layer(tf_base_layers.Layer):
with K.name_scope(self.name):
output_mask = self.compute_mask(inputs, previous_mask)
+ # If the layer returns tensors from its inputs, unmodified,
+ # we copy them to avoid loss of tensor metadata.
+ output_ls = _to_list(output)
+ inputs_ls = _to_list(inputs)
+ output_ls_copy = []
+ for x in output_ls:
+ if x in inputs_ls:
+ x = K.identity(x)
+ output_ls_copy.append(x)
+ if len(output_ls_copy) == 1:
+ output = output_ls_copy[0]
+ else:
+ output = output_ls_copy
+
# Add an inbound node to the layer, so that it keeps track
# of the call and of all new variables created during the call.
# This also updates the layer history of the output tensor(s).
@@ -1085,7 +1063,7 @@ def Input( # pylint: disable=invalid-name
attributes that allow us to build a Keras model
just by knowing the inputs and outputs of the model.
- For instance, if a, b and c and Keras tensors,
+ For instance, if a, b and c are Keras tensors,
it becomes possible to do:
`model = Model(input=[a, b], output=c)`
@@ -1259,16 +1237,16 @@ class Container(Layer):
if len(layer.inbound_nodes) > 1 or (
layer.inbound_nodes and layer.inbound_nodes[0].inbound_layers):
cls_name = self.__class__.__name__
- warnings.warn(cls_name + ' inputs must come from '
- 'a Keras Input layer, '
- 'they cannot be the output of '
- 'a previous non-Input layer. '
- 'Here, a tensor specified as '
- 'input to "' + self.name + '" was not an Input tensor, '
- 'it was generated by layer ' + layer.name + '.\n'
- 'Note that input tensors are '
- 'instantiated via `tensor = Input(shape)`.\n'
- 'The tensor that caused the issue was: ' + str(x.name))
+ logging.warning(cls_name + ' inputs must come from '
+ 'a Keras Input layer, '
+ 'they cannot be the output of '
+ 'a previous non-Input layer. '
+ 'Here, a tensor specified as '
+ 'input to "' + self.name + '" was not an Input tensor, '
+ 'it was generated by layer ' + layer.name + '.\n'
+ 'Note that input tensors are '
+ 'instantiated via `tensor = Input(shape)`.\n'
+ 'The tensor that caused the issue was: ' + str(x.name))
for x in self.outputs:
if not hasattr(x, '_keras_history'):
cls_name = self.__class__.__name__
@@ -1338,76 +1316,96 @@ class Container(Layer):
nodes_depths = {} # dict {node: depth value}
layers_depths = {} # dict {layer: depth value}
layer_indices = {} # dict {layer: index in traversal}
-
- def make_node_marker(node, depth):
- return str(id(node)) + '-' + str(depth)
+ nodes_in_decreasing_depth = []
def build_map_of_graph(tensor,
- seen_nodes=None,
- depth=0,
+ finished_nodes,
+ nodes_in_progress,
layer=None,
node_index=None,
tensor_index=None):
"""Builds a map of the graph of layers.
- This recursively updates the maps `nodes_depths`,
- `layers_depths` and the set `container_nodes`.
-
- Does not try to detect cycles in the graph.
+ This recursively updates the map `layer_indices`,
+ the list `nodes_in_decreasing_depth` and the set `container_nodes`.
Arguments:
tensor: Some tensor in a graph.
- seen_nodes: Set of node ids ("{layer.name}_ib-{node_index}")
- of nodes seen so far. Useful to prevent infinite loops.
- depth: Current depth in the graph (0 = last output).
+ finished_nodes: Set of nodes whose subgraphs have been traversed
+ completely. Useful to prevent duplicated work.
+ nodes_in_progress: Set of nodes that are currently active on the
+ recursion stack. Useful to detect cycles.
layer: Layer from which `tensor` comes from. If not provided,
will be obtained from `tensor._keras_history`.
node_index: Node index from which `tensor` comes from.
tensor_index: Tensor_index from which `tensor` comes from.
+
+ Raises:
+ RuntimeError: if a cycle is detected.
"""
- seen_nodes = seen_nodes or set()
if not layer or node_index is None or tensor_index is None:
layer, node_index, tensor_index = tensor._keras_history
node = layer.inbound_nodes[node_index]
# Prevent cycles.
- seen_nodes.add(make_node_marker(node, depth))
+ if node in nodes_in_progress:
+ raise RuntimeError('The tensor ' + str(tensor) + ' at layer "' +
+ layer.name + '" is part of a cycle.')
+
+ # Don't repeat work for shared subgraphs
+ if node in finished_nodes:
+ return
node_key = layer.name + '_ib-' + str(node_index)
# Update container_nodes.
container_nodes.add(node_key)
- # Update nodes_depths.
- node_depth = nodes_depths.get(node)
- if node_depth is None:
- nodes_depths[node] = depth
- else:
- nodes_depths[node] = max(depth, node_depth)
- # Update layers_depths.
- previously_seen_depth = layers_depths.get(layer)
- if previously_seen_depth is None:
- current_depth = depth
- else:
- current_depth = max(depth, previously_seen_depth)
- layers_depths[layer] = current_depth
+
+ # Store the traversal order for layer sorting.
if layer not in layer_indices:
layer_indices[layer] = len(layer_indices)
+ nodes_in_progress.add(node)
+
# Propagate to all previous tensors connected to this node.
for i in range(len(node.inbound_layers)):
x = node.input_tensors[i]
layer = node.inbound_layers[i]
node_index = node.node_indices[i]
tensor_index = node.tensor_indices[i]
- next_node = layer.inbound_nodes[node_index]
- # use node_marker to prevent cycles
- node_marker = make_node_marker(next_node, current_depth + 1)
- if node_marker not in seen_nodes:
- build_map_of_graph(x, seen_nodes, current_depth + 1, layer,
- node_index, tensor_index)
+ build_map_of_graph(x, finished_nodes, nodes_in_progress, layer,
+ node_index, tensor_index)
+
+ finished_nodes.add(node)
+ nodes_in_progress.remove(node)
+
+ nodes_in_decreasing_depth.append(node)
+ finished_nodes = set()
+ nodes_in_progress = set()
for x in self.outputs:
- seen_nodes = set()
- build_map_of_graph(x, seen_nodes, depth=0)
+ build_map_of_graph(x, finished_nodes, nodes_in_progress)
+
+ for node in reversed(nodes_in_decreasing_depth):
+ # If the depth is not set, the node has no outbound nodes (depth 0).
+ depth = nodes_depths.setdefault(node, 0)
+
+ # Update the depth of the corresponding layer
+ previous_depth = layers_depths.get(node.outbound_layer, 0)
+ # If we've seen this layer before at a higher depth,
+ # we should use that depth instead of the node depth.
+ # This is necessary for shared layers that have inputs at different
+ # depth levels in the graph.
+ depth = max(depth, previous_depth)
+ layers_depths[node.outbound_layer] = depth
+ nodes_depths[node] = depth
+
+ # Update the depth of inbound nodes.
+ for i in range(len(node.inbound_layers)):
+ inbound_layer = node.inbound_layers[i]
+ node_index = node.node_indices[i]
+ inbound_node = inbound_layer.inbound_nodes[node_index]
+ previous_depth = nodes_depths.get(inbound_node, 0)
+ nodes_depths[inbound_node] = max(depth + 1, previous_depth)
# Build a dict {depth: list of nodes with this depth}
nodes_by_depth = {}
@@ -2043,11 +2041,12 @@ class Container(Layer):
json.dumps(node.arguments)
kwargs = node.arguments
except TypeError:
- warnings.warn('Layer ' + layer.name +
- ' was passed non-serializable keyword arguments: ' +
- str(node.arguments) + '. They will not be included '
- 'in the serialized model (and thus will be missing '
- 'at deserialization time).')
+ logging.warning(
+ 'Layer ' + layer.name +
+ ' was passed non-serializable keyword arguments: ' +
+ str(node.arguments) + '. They will not be included '
+ 'in the serialized model (and thus will be missing '
+ 'at deserialization time).')
kwargs = {}
else:
kwargs = {}
@@ -2527,6 +2526,21 @@ def preprocess_weights_for_loading(layer,
A list of weights values (Numpy arrays).
"""
if original_keras_version == '1':
+ if layer.__class__.__name__ == 'Bidirectional':
+ num_weights_per_layer = len(weights) // 2
+
+ forward_weights = preprocess_weights_for_loading(
+ layer.forward_layer, weights[:num_weights_per_layer],
+ original_keras_version, original_backend)
+ backward_weights = preprocess_weights_for_loading(
+ layer.backward_layer, weights[num_weights_per_layer:],
+ original_keras_version, original_backend)
+ weights = forward_weights + backward_weights
+
+ if layer.__class__.__name__ == 'TimeDistributed':
+ weights = preprocess_weights_for_loading(
+ layer.layer, weights, original_keras_version, original_backend)
+
if layer.__class__.__name__ == 'Conv1D':
shape = weights[0].shape
# Handle Keras 1.1 format
@@ -2595,13 +2609,16 @@ def preprocess_weights_for_loading(layer,
recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0))
weights = [kernel, recurrent_kernel, bias]
- if original_backend and K.backend() != original_backend:
- conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose']
- if layer.__class__.__name__ in conv_layers:
- weights[0] = conv_utils.convert_kernel(weights[0])
- if layer.__class__.__name__ == 'ConvLSTM2D':
+ conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D']
+ if layer.__class__.__name__ in conv_layers:
+ if original_backend and K.backend() != original_backend:
weights[0] = conv_utils.convert_kernel(weights[0])
- weights[1] = conv_utils.convert_kernel(weights[1])
+ if layer.__class__.__name__ == 'ConvLSTM2D':
+ weights[1] = conv_utils.convert_kernel(weights[1])
+ if K.int_shape(layer.weights[0]) != weights[0].shape:
+ weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
+ if layer.__class__.__name__ == 'ConvLSTM2D':
+ weights[1] = np.transpose(weights[1], (3, 2, 0, 1))
return weights
diff --git a/tensorflow/contrib/keras/python/keras/engine/training.py b/tensorflow/contrib/keras/python/keras/engine/training.py
index 96d1c2f262..09459fd713 100644
--- a/tensorflow/contrib/keras/python/keras/engine/training.py
+++ b/tensorflow/contrib/keras/python/keras/engine/training.py
@@ -23,7 +23,6 @@ import copy
import multiprocessing
import threading
import time
-import warnings
import numpy as np
import six
@@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras import metrics as metrics_module
from tensorflow.contrib.keras.python.keras import optimizers
from tensorflow.contrib.keras.python.keras.engine.topology import Container
from tensorflow.contrib.keras.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.platform import tf_logging as logging
# pylint: disable=g-import-not-at-top
@@ -72,6 +72,8 @@ def _standardize_input_data(data,
Raises:
ValueError: in case of improperly formatted user-provided data.
"""
+ if not names:
+ return []
if data is None:
return [None for _ in range(len(names))]
if isinstance(data, dict):
@@ -85,7 +87,7 @@ def _standardize_input_data(data,
if len(data) != len(names):
if data and hasattr(data[0], 'shape'):
raise ValueError(
- 'Error when checking ' + exception_prefix +
+ 'Error when checking model ' + exception_prefix +
': the list of Numpy arrays '
'that you are passing to your model '
'is not the size the model expected. '
@@ -96,7 +98,7 @@ def _standardize_input_data(data,
if len(names) == 1:
data = [np.asarray(data)]
else:
- raise ValueError('Error when checking ' + exception_prefix +
+ raise ValueError('Error when checking model ' + exception_prefix +
': you are passing a list as '
'input to your model, '
'but the model expects '
@@ -106,15 +108,15 @@ def _standardize_input_data(data,
arrays = data
else:
if not hasattr(data, 'shape'):
- raise TypeError('Error when checking ' + exception_prefix +
+ raise TypeError('Error when checking model ' + exception_prefix +
': data should be a Numpy array, '
'or list/dict of Numpy arrays. '
'Found: ' + str(data)[:200] + '...')
- if len(names) != 1:
+ if len(names) > 1:
# Case: model expects multiple inputs but only received
# a single Numpy array.
- raise ValueError('The model expects ' + str(len(names)) +
- ' input arrays, but only received one array. '
+ raise ValueError('The model expects ' + str(len(names)) + exception_prefix
+ + ' arrays, but only received one array. '
'Found: array with shape ' + str(data.shape))
arrays = [data]
@@ -682,7 +684,8 @@ class Model(Container):
loss,
metrics=None,
loss_weights=None,
- sample_weight_mode=None):
+ sample_weight_mode=None,
+ **kwargs):
"""Configures the model for training.
Arguments:
@@ -692,6 +695,8 @@ class Model(Container):
See [losses](/losses).
If the model has multiple outputs, you can use a different loss
on each output by passing a dictionary or a list of losses.
+ The loss value that will be minimized by the model
+ will then be the sum of all individual losses.
metrics: list of metrics to be evaluated by the model
during training and testing.
Typically you will use `metrics=['accuracy']`.
@@ -701,6 +706,9 @@ class Model(Container):
loss_weights: Optional list or dictionary specifying scalar
coefficients (Python floats) to weight the loss contributions
of different model outputs.
+ The loss value that will be minimized by the model
+ will then be the *weighted sum* of all individual losses,
+ weighted by the `loss_weights` coefficients.
If a list, it is expected to have a 1:1 mapping
to the model's outputs. If a tensor, it is expected to map
output names (strings) to scalar coefficients.
@@ -710,6 +718,7 @@ class Model(Container):
If the model has multiple outputs, you can use a different
`sample_weight_mode` on each output by passing a
dictionary or a list of modes.
+ **kwargs: Additional arguments passed to `tf.Session.run`.
Raises:
ValueError: In case of invalid arguments for
@@ -733,7 +742,7 @@ class Model(Container):
loss_functions = []
for name in self.output_names:
if name not in loss:
- warnings.warn(
+ logging.warning(
'Output "' + name + '" missing from loss dictionary. '
'We assume this was done on purpose, '
'and we will not be expecting '
@@ -975,6 +984,7 @@ class Model(Container):
self.train_function = None
self.test_function = None
self.predict_function = None
+ self._function_kwargs = kwargs
# Collected trainable weights and sort them deterministically.
trainable_weights = self.trainable_weights
@@ -997,7 +1007,10 @@ class Model(Container):
updates = self.updates + training_updates
# Gets loss and metrics. Updates weights at each call.
self.train_function = K.function(
- inputs, [self.total_loss] + self.metrics_tensors, updates=updates)
+ inputs, [self.total_loss] + self.metrics_tensors,
+ updates=updates,
+ name='train_function',
+ **self._function_kwargs)
def _make_test_function(self):
if not hasattr(self, 'test_function'):
@@ -1011,11 +1024,14 @@ class Model(Container):
# Does update the network states.
self.test_function = K.function(
inputs, [self.total_loss] + self.metrics_tensors,
- updates=self.state_updates)
+ updates=self.state_updates,
+ name='test_function',
+ **self._function_kwargs)
def _make_predict_function(self):
if not hasattr(self, 'predict_function'):
self.predict_function = None
+ self._function_kwargs = {}
if self.predict_function is None:
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
inputs = self._feed_inputs + [K.learning_phase()]
@@ -1024,7 +1040,11 @@ class Model(Container):
# Gets network outputs. Does not update weights.
# Does update the network states.
self.predict_function = K.function(
- inputs, self.outputs, updates=self.state_updates)
+ inputs,
+ self.outputs,
+ updates=self.state_updates,
+ name='predict_function',
+ **self._function_kwargs)
def _fit_loop(self,
f,
@@ -1124,7 +1144,7 @@ class Model(Container):
batch_ids = index_array[batch_start:batch_end]
try:
if isinstance(ins[-1], float):
- # do not slice the training phase flag
+ # Do not slice the training phase flag.
ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
else:
ins_batch = _slice_arrays(ins, batch_ids)
@@ -1143,16 +1163,16 @@ class Model(Container):
batch_logs[l] = o
callbacks.on_batch_end(batch_index, batch_logs)
+ if callback_model.stop_training:
+ break
- if batch_index == len(batches) - 1: # last batch
- # validation
+ if batch_index == len(batches) - 1: # Last batch.
if do_validation:
- # replace with self._evaluate
val_outs = self._test_loop(
val_f, val_ins, batch_size=batch_size, verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
- # same labels assumed
+ # Same labels assumed.
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
callbacks.on_epoch_end(epoch, epoch_logs)
@@ -1192,7 +1212,7 @@ class Model(Container):
for batch_index, (batch_start, batch_end) in enumerate(batches):
batch_ids = index_array[batch_start:batch_end]
if ins and isinstance(ins[-1], float):
- # do not slice the training phase flag
+ # Do not slice the training phase flag.
ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
else:
ins_batch = _slice_arrays(ins, batch_ids)
@@ -1246,7 +1266,7 @@ class Model(Container):
for batch_index, (batch_start, batch_end) in enumerate(batches):
batch_ids = index_array[batch_start:batch_end]
if isinstance(ins[-1], float):
- # do not slice the training phase flag
+ # Do not slice the training phase flag.
ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
else:
ins_batch = _slice_arrays(ins, batch_ids)
@@ -1297,13 +1317,13 @@ class Model(Container):
self._feed_input_names,
self._feed_input_shapes,
check_batch_axis=False,
- exception_prefix='model input')
+ exception_prefix='input')
y = _standardize_input_data(
y,
self._feed_output_names,
output_shapes,
check_batch_axis=False,
- exception_prefix='model target')
+ exception_prefix='target')
sample_weights = _standardize_sample_weights(sample_weight,
self._feed_output_names)
class_weights = _standardize_class_weights(class_weight,
@@ -1325,6 +1345,20 @@ class Model(Container):
str(x[0].shape[0]) + ' samples')
return x, y, sample_weights
+ def _get_deduped_metrics_names(self):
+ out_labels = self.metrics_names
+
+ # Rename duplicated metrics name
+ # (can happen with an output layer shared among multiple dataflows).
+ deduped_out_labels = []
+ for i, label in enumerate(out_labels):
+ new_label = label
+ if out_labels.count(label) > 1:
+ dup_idx = out_labels[:i].count(label)
+ new_label += '_' + str(dup_idx + 1)
+ deduped_out_labels.append(new_label)
+ return deduped_out_labels
+
def fit(self,
x=None,
y=None,
@@ -1354,7 +1388,7 @@ class Model(Container):
batch_size: integer. Number of samples per gradient update.
epochs: integer, the number of times to iterate
over the training data arrays.
- verbose: 0, 1, or 2. Verbosity mode.
+ verbose: 0, 1, or 2. Verbosity mode.
0 = silent, 1 = verbose, 2 = one log line per epoch.
callbacks: list of callbacks to be called during training.
See [callbacks](/callbacks).
@@ -1396,7 +1430,7 @@ class Model(Container):
ValueError: In case of mismatch between the provided input data
and what the model expects.
"""
- # validate user data
+ # Validate user data.
x, y, sample_weights = self._standardize_user_data(
x,
y,
@@ -1404,7 +1438,7 @@ class Model(Container):
class_weight=class_weight,
check_batch_axis=False,
batch_size=batch_size)
- # prepare validation data
+ # Prepare validation data.
if validation_data:
do_validation = True
if len(validation_data) == 2:
@@ -1450,7 +1484,7 @@ class Model(Container):
val_f = None
val_ins = None
- # prepare input arrays and training function
+ # Prepare input arrays and training function.
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = x + y + sample_weights + [1.]
else:
@@ -1458,19 +1492,8 @@ class Model(Container):
self._make_train_function()
f = self.train_function
- # prepare display labels
- out_labels = self.metrics_names
-
- # rename duplicated metrics name
- # (can happen with an output layer shared among multiple dataflows)
- deduped_out_labels = []
- for i, label in enumerate(out_labels):
- new_label = label
- if out_labels.count(label) > 1:
- dup_idx = out_labels[:i].count(label)
- new_label += '_' + str(dup_idx + 1)
- deduped_out_labels.append(new_label)
- out_labels = deduped_out_labels
+ # Prepare display labels.
+ out_labels = self._get_deduped_metrics_names()
if do_validation:
callback_metrics = copy.copy(out_labels) + [
@@ -1479,7 +1502,7 @@ class Model(Container):
else:
callback_metrics = copy.copy(out_labels)
- # delegate logic to _fit_loop
+ # Delegate logic to `_fit_loop`.
return self._fit_loop(
f,
ins,
@@ -1521,14 +1544,14 @@ class Model(Container):
and/or metrics). The attribute `model.metrics_names` will give you
the display labels for the scalar outputs.
"""
- # validate user data
+ # Validate user data.
x, y, sample_weights = self._standardize_user_data(
x,
y,
sample_weight=sample_weight,
check_batch_axis=False,
batch_size=batch_size)
- # prepare inputs, delegate logic to _test_loop
+ # Prepare inputs, delegate logic to `_test_loop`.
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = x + y + sample_weights + [0.]
else:
@@ -1557,7 +1580,7 @@ class Model(Container):
or in case a stateful model receives a number of samples
that is not a multiple of the batch size.
"""
- # validate user data
+ # Validate user data.
x = _standardize_input_data(
x,
self._feed_input_names,
@@ -1572,7 +1595,7 @@ class Model(Container):
str(x[0].shape[0]) + ' samples. '
'Batch size: ' + str(batch_size) + '.')
- # prepare inputs, delegate logic to _predict_loop
+ # Prepare inputs, delegate logic to `_predict_loop`.
if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = x + [0.]
else:
@@ -1720,7 +1743,7 @@ class Model(Container):
All arrays should contain the same number of samples.
The generator is expected to loop over its data
indefinitely. An epoch finishes when `steps_per_epoch`
- samples have been seen by the model.
+ batches have been seen by the model.
steps_per_epoch: Total number of steps (batches of samples)
to yield from `generator` before declaring one epoch
finished and starting the next epoch. It should typically
@@ -1792,7 +1815,8 @@ class Model(Container):
'you must specify a value for '
'`validation_steps`.')
- out_labels = self.metrics_names
+ # Prepare display labels.
+ out_labels = self._get_deduped_metrics_names()
callback_metrics = out_labels + ['val_' + n for n in out_labels]
# prepare callbacks
@@ -1829,8 +1853,11 @@ class Model(Container):
'or `(val_x, val_y)`. Found: ' + str(validation_data))
val_x, val_y, val_sample_weights = self._standardize_user_data(
val_x, val_y, val_sample_weight)
+ val_data = val_x + val_y + val_sample_weights
+ if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ val_data += [0.]
for cbk in callbacks:
- cbk.validation_data = val_x + [val_y, val_sample_weights]
+ cbk.validation_data = val_data
enqueuer = None
try:
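
The `_get_deduped_metrics_names` helper factored out above is self-contained; a standalone sketch with a hypothetical label list shows the renaming it performs when an output layer is shared across several dataflows.

    # Hypothetical metric labels containing a duplicate, as produced by a
    # shared output layer.
    out_labels = ['loss', 'dense_1_loss', 'dense_1_loss', 'dense_1_acc']
    deduped_out_labels = []
    for i, label in enumerate(out_labels):
        new_label = label
        if out_labels.count(label) > 1:
            dup_idx = out_labels[:i].count(label)
            new_label += '_' + str(dup_idx + 1)
        deduped_out_labels.append(new_label)
    print(deduped_out_labels)
    # ['loss', 'dense_1_loss_1', 'dense_1_loss_2', 'dense_1_acc']
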
diff --git a/tensorflow/contrib/keras/python/keras/integration_test.py b/tensorflow/contrib/keras/python/keras/integration_test.py
index f25e8eeaac..bcd844201c 100644
--- a/tensorflow/contrib/keras/python/keras/integration_test.py
+++ b/tensorflow/contrib/keras/python/keras/integration_test.py
@@ -80,7 +80,7 @@ class KerasIntegrationTest(test.TestCase):
def test_temporal_classification_declarative(self):
with self.test_session():
- np.random.seed(1337)
+ np.random.seed(1336)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=200,
test_samples=100,
diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py
index 30325b7148..9ab2e72bf1 100644
--- a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py
@@ -124,9 +124,12 @@ class ConvRecurrent2D(Recurrent):
self.return_sequences = return_sequences
self.go_backwards = go_backwards
self.stateful = stateful
- self.input_spec = InputSpec(ndim=5)
+ self.input_spec = [InputSpec(ndim=5)]
+ self.state_spec = None
def _compute_output_shape(self, input_shape):
+ if isinstance(input_shape, list):
+ input_shape = input_shape[0]
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_first':
rows = input_shape[3]
@@ -344,11 +347,14 @@ class ConvLSTM2D(ConvRecurrent2D):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
+ self.state_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- # TODO(fchollet): better handling of input spec
- self.input_spec = InputSpec(shape=input_shape)
+ if isinstance(input_shape, list):
+ input_shape = input_shape[0]
+ input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
+ batch_size = input_shape[0] if self.stateful else None
+ self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:])
if self.stateful:
self.reset_states()
@@ -364,6 +370,13 @@ class ConvLSTM2D(ConvRecurrent2D):
raise ValueError('The channel dimension of the inputs '
'should be defined. Found `None`.')
input_dim = input_shape[channel_axis]
+ state_shape = [None] * 4
+ state_shape[channel_axis] = input_dim
+ state_shape = tuple(state_shape)
+ self.state_spec = [
+ InputSpec(shape=state_shape),
+ InputSpec(shape=state_shape)
+ ]
kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
self.kernel_shape = kernel_shape
recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)
@@ -417,7 +430,7 @@ class ConvLSTM2D(ConvRecurrent2D):
self.bias_o = None
self.built = True
- def get_initial_states(self, inputs):
+ def get_initial_state(self, inputs):
# (samples, timesteps, rows, cols, filters)
initial_state = K.zeros_like(inputs)
# (samples, rows, cols, filters)
@@ -433,8 +446,9 @@ class ConvLSTM2D(ConvRecurrent2D):
def reset_states(self):
if not self.stateful:
raise RuntimeError('Layer must be stateful.')
- input_shape = self.input_spec.shape
+ input_shape = self.input_spec[0].shape
output_shape = self._compute_output_shape(input_shape)
+
if not input_shape[0]:
raise ValueError('If a RNN is stateful, a complete '
'input_shape must be provided '
@@ -453,8 +467,8 @@ class ConvLSTM2D(ConvRecurrent2D):
np.zeros((input_shape[0], out_row, out_col, out_filter)))
else:
self.states = [
- K.zeros((input_shape[0], out_row, out_col, out_filter)), K.zeros(
- (input_shape[0], out_row, out_col, out_filter))
+ K.zeros((input_shape[0], out_row, out_col, out_filter)),
+ K.zeros((input_shape[0], out_row, out_col, out_filter))
]
def get_constants(self, inputs, training=None):
diff --git a/tensorflow/contrib/keras/python/keras/layers/lstm_test.py b/tensorflow/contrib/keras/python/keras/layers/lstm_test.py
index 0e1d148bd8..90bf95a781 100644
--- a/tensorflow/contrib/keras/python/keras/layers/lstm_test.py
+++ b/tensorflow/contrib/keras/python/keras/layers/lstm_test.py
@@ -189,6 +189,109 @@ class LSTMLayerTest(test.TestCase):
l2 = layer_class.from_config(l1.get_config())
assert l1.get_config() == l2.get_config()
+ def test_specify_initial_state_keras_tensor(self):
+ num_states = 2
+ timesteps = 3
+ embedding_dim = 4
+ units = 3
+ num_samples = 2
+
+ with self.test_session():
+ # Test with Keras tensor
+ inputs = keras.Input((timesteps, embedding_dim))
+ initial_state = [keras.Input((units,)) for _ in range(num_states)]
+ layer = keras.layers.LSTM(units)
+ if len(initial_state) == 1:
+ output = layer(inputs, initial_state=initial_state[0])
+ else:
+ output = layer(inputs, initial_state=initial_state)
+ assert initial_state[0] in layer.inbound_nodes[0].input_tensors
+
+ model = keras.models.Model([inputs] + initial_state, output)
+ model.compile(loss='categorical_crossentropy', optimizer='adam')
+
+ inputs = np.random.random((num_samples, timesteps, embedding_dim))
+ initial_state = [np.random.random((num_samples, units))
+ for _ in range(num_states)]
+ targets = np.random.random((num_samples, units))
+ model.train_on_batch([inputs] + initial_state, targets)
+
+ def test_specify_initial_state_non_keras_tensor(self):
+ num_states = 2
+ timesteps = 3
+ embedding_dim = 4
+ units = 3
+ num_samples = 2
+
+ with self.test_session():
+ # Test with non-Keras tensor
+ inputs = keras.Input((timesteps, embedding_dim))
+ initial_state = [keras.backend.random_normal_variable(
+ (num_samples, units), 0, 1)
+ for _ in range(num_states)]
+ layer = keras.layers.LSTM(units)
+ output = layer(inputs, initial_state=initial_state)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(loss='categorical_crossentropy', optimizer='adam')
+
+ inputs = np.random.random((num_samples, timesteps, embedding_dim))
+ targets = np.random.random((num_samples, units))
+ model.train_on_batch(inputs, targets)
+
+ def test_reset_states_with_values(self):
+ num_states = 2
+ timesteps = 3
+ embedding_dim = 4
+ units = 3
+ num_samples = 2
+
+ with self.test_session():
+ layer = keras.layers.LSTM(units, stateful=True)
+ layer.build((num_samples, timesteps, embedding_dim))
+ layer.reset_states()
+ assert len(layer.states) == num_states
+ assert layer.states[0] is not None
+ np.testing.assert_allclose(
+ keras.backend.eval(layer.states[0]),
+ np.zeros(keras.backend.int_shape(layer.states[0])),
+ atol=1e-4)
+ state_shapes = [keras.backend.int_shape(state) for state in layer.states]
+ values = [np.ones(shape) for shape in state_shapes]
+ if len(values) == 1:
+ values = values[0]
+ layer.reset_states(values)
+ np.testing.assert_allclose(
+ keras.backend.eval(layer.states[0]),
+ np.ones(keras.backend.int_shape(layer.states[0])),
+ atol=1e-4)
+
+ # Test with invalid data
+ with self.assertRaises(ValueError):
+ layer.reset_states([1] * (len(layer.states) + 1))
+
+ def test_specify_state_with_masking(self):
+ num_states = 2
+ timesteps = 3
+ embedding_dim = 4
+ units = 3
+ num_samples = 2
+
+ with self.test_session():
+ inputs = keras.Input((timesteps, embedding_dim))
+ _ = keras.layers.Masking()(inputs)
+ initial_state = [keras.Input((units,)) for _ in range(num_states)]
+ output = keras.layers.LSTM(units)(inputs, initial_state=initial_state)
+
+ model = keras.models.Model([inputs] + initial_state, output)
+ model.compile(loss='categorical_crossentropy', optimizer='adam')
+
+ inputs = np.random.random((num_samples, timesteps, embedding_dim))
+ initial_state = [np.random.random((num_samples, units))
+ for _ in range(num_states)]
+ targets = np.random.random((num_samples, units))
+ model.train_on_batch([inputs] + initial_state, targets)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/layers/merge.py b/tensorflow/contrib/keras/python/keras/layers/merge.py
index b4bb9935fd..84c03fdebd 100644
--- a/tensorflow/contrib/keras/python/keras/layers/merge.py
+++ b/tensorflow/contrib/keras/python/keras/layers/merge.py
@@ -139,7 +139,8 @@ class _Merge(Layer):
batch_size = x_shape[0]
new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
x_transposed = K.reshape(x,
- K.stack([batch_size, K.prod(x_shape[1:])]))
+ K.stack([batch_size,
+ K.prod(x_shape[1:])]))
x_transposed = K.permute_dimensions(x_transposed, (1, 0))
x_transposed = K.reshape(x_transposed, new_shape)
reshaped_inputs.append(x_transposed)
diff --git a/tensorflow/contrib/keras/python/keras/layers/recurrent.py b/tensorflow/contrib/keras/python/keras/layers/recurrent.py
index e608921add..5e8c23ed3e 100644
--- a/tensorflow/contrib/keras/python/keras/layers/recurrent.py
+++ b/tensorflow/contrib/keras/python/keras/layers/recurrent.py
@@ -197,11 +197,16 @@ class Recurrent(Layer):
To reset the states of your model, call `.reset_states()` on either
a specific layer, or on your entire model.
- # Note on specifying initial states in RNNs
- You can specify the initial state of RNN layers by calling them with
- the keyword argument `initial_state`. The value of `initial_state`
- should be a tensor or list of tensors representing the initial state
- of the RNN layer.
+ # Note on specifying the initial state of RNNs
+ You can specify the initial state of RNN layers symbolically by
+ calling them with the keyword argument `initial_state`. The value of
+ `initial_state` should be a tensor or list of tensors representing
+ the initial state of the RNN layer.
+
+ You can specify the initial state of RNN layers numerically by
+ calling `reset_states` with the keyword argument `states`. The value of
+ `states` should be a numpy array or list of numpy arrays representing
+ the initial state of the RNN layer.
"""
def __init__(self,
@@ -218,7 +223,7 @@ class Recurrent(Layer):
self.unroll = unroll
self.implementation = implementation
self.supports_masking = True
- self.input_spec = InputSpec(ndim=3)
+ self.input_spec = [InputSpec(ndim=3)]
self.state_spec = None
self.dropout = 0
self.recurrent_dropout = 0
@@ -235,6 +240,8 @@ class Recurrent(Layer):
def compute_mask(self, inputs, mask):
if self.return_sequences:
+ if isinstance(mask, list):
+ return mask[0]
return mask
else:
return None
@@ -245,15 +252,15 @@ class Recurrent(Layer):
def get_constants(self, inputs, training=None):
return []
- def get_initial_states(self, inputs):
+ def get_initial_state(self, inputs):
# build an all-zero tensor of shape (samples, output_dim)
initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim)
initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,)
initial_state = K.expand_dims(initial_state) # (samples, 1)
initial_state = K.tile(initial_state, [1,
self.units]) # (samples, output_dim)
- initial_states = [initial_state for _ in range(len(self.states))]
- return initial_states
+ initial_state = [initial_state for _ in range(len(self.states))]
+ return initial_state
def preprocess_input(self, inputs, training=None):
return inputs
@@ -263,50 +270,62 @@ class Recurrent(Layer):
# and if it a Keras tensor,
# then add it to the inputs and temporarily
# modify the input spec to include the state.
- if initial_state is not None:
- if hasattr(initial_state, '_keras_history'):
- # Compute the full input spec, including state
- input_spec = self.input_spec
- state_spec = self.state_spec
- if not isinstance(state_spec, list):
- state_spec = [state_spec]
- self.input_spec = [input_spec] + state_spec
-
- # Compute the full inputs, including state
- if not isinstance(initial_state, (list, tuple)):
- initial_state = [initial_state]
- inputs = [inputs] + list(initial_state)
-
- # Perform the call
- output = super(Recurrent, self).__call__(inputs, **kwargs)
-
- # Restore original input spec
- self.input_spec = input_spec
- return output
- else:
- kwargs['initial_state'] = initial_state
- return super(Recurrent, self).__call__(inputs, **kwargs)
+ if initial_state is None:
+ return super(Recurrent, self).__call__(inputs, **kwargs)
+
+ if not isinstance(initial_state, (list, tuple)):
+ initial_state = [initial_state]
+
+ is_keras_tensor = hasattr(initial_state[0], '_keras_history')
+ for tensor in initial_state:
+ if hasattr(tensor, '_keras_history') != is_keras_tensor:
+ raise ValueError('The initial state of an RNN layer cannot be'
+ ' specified with a mix of Keras tensors and'
+ ' non-Keras tensors')
+
+ if is_keras_tensor:
+ # Compute the full input spec, including state
+ input_spec = self.input_spec
+ state_spec = self.state_spec
+ if not isinstance(input_spec, list):
+ input_spec = [input_spec]
+ if not isinstance(state_spec, list):
+ state_spec = [state_spec]
+ self.input_spec = input_spec + state_spec
+
+ # Compute the full inputs, including state
+ inputs = [inputs] + list(initial_state)
+
+ # Perform the call
+ output = super(Recurrent, self).__call__(inputs, **kwargs)
+
+ # Restore original input spec
+ self.input_spec = input_spec
+ return output
+ else:
+ kwargs['initial_state'] = initial_state
+ return super(Recurrent, self).__call__(inputs, **kwargs)
- def call(self, inputs, mask=None, initial_state=None, training=None):
+ def call(self, inputs, mask=None, training=None, initial_state=None):
# input shape: `(samples, time (padded with zeros), input_dim)`
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
- if initial_state is not None:
- if not isinstance(initial_state, (list, tuple)):
- initial_states = [initial_state]
- else:
- initial_states = list(initial_state)
if isinstance(inputs, list):
- initial_states = inputs[1:]
+ initial_state = inputs[1:]
inputs = inputs[0]
+ elif initial_state is not None:
+ pass
elif self.stateful:
- initial_states = self.states
+ initial_state = self.states
else:
- initial_states = self.get_initial_states(inputs)
+ initial_state = self.get_initial_state(inputs)
+
+ if isinstance(mask, list):
+ mask = mask[0]
- if len(initial_states) != len(self.states):
+ if len(initial_state) != len(self.states):
raise ValueError('Layer has ' + str(len(self.states)) +
- ' states but was passed ' + str(len(initial_states)) +
+ ' states but was passed ' + str(len(initial_state)) +
' initial states.')
input_shape = K.int_shape(inputs)
if self.unroll and input_shape[1] is None:
@@ -326,7 +345,7 @@ class Recurrent(Layer):
last_output, outputs, states = K.rnn(
self.step,
preprocessed_input,
- initial_states,
+ initial_state,
go_backwards=self.go_backwards,
mask=mask,
constants=constants,
@@ -347,13 +366,10 @@ class Recurrent(Layer):
else:
return last_output
- def reset_states(self, states_value=None):
+ def reset_states(self, states=None):
if not self.stateful:
raise AttributeError('Layer must be stateful.')
- if not self.input_spec:
- raise RuntimeError('Layer has never been called '
- 'and thus has no states.')
- batch_size = self.input_spec.shape[0]
+ batch_size = self.input_spec[0].shape[0]
if not batch_size:
raise ValueError('If a RNN is stateful, it needs to know '
'its batch size. Specify the batch size '
@@ -365,28 +381,27 @@ class Recurrent(Layer):
'- If using the functional API, specify '
'the time dimension by passing a '
'`batch_shape` argument to your Input layer.')
- if states_value is not None:
- if not isinstance(states_value, (list, tuple)):
- states_value = [states_value]
- if len(states_value) != len(self.states):
- raise ValueError('The layer has ' + str(len(self.states)) +
- ' states, but the `states_value` '
- 'argument passed '
- 'only has ' + str(len(states_value)) + ' entries')
+ # initialize state if None
if self.states[0] is None:
self.states = [K.zeros((batch_size, self.units)) for _ in self.states]
- if not states_value:
- return
- for i, state in enumerate(self.states):
- if states_value:
- value = states_value[i]
+ elif states is None:
+ for state in self.states:
+ K.set_value(state, np.zeros((batch_size, self.units)))
+ else:
+ if not isinstance(states, (list, tuple)):
+ states = [states]
+ if len(states) != len(self.states):
+ raise ValueError('Layer ' + self.name + ' expects ' +
+ str(len(self.states)) + ' states, '
+ 'but it received ' + str(len(states)) +
+ ' state values. Input received: ' + str(states))
+ for index, (value, state) in enumerate(zip(states, self.states)):
if value.shape != (batch_size, self.units):
- raise ValueError('Expected state #' + str(i) + ' to have shape ' +
- str((batch_size, self.units)) +
- ' but got array with shape ' + str(value.shape))
- else:
- value = np.zeros((batch_size, self.units))
- K.set_value(state, value)
+ raise ValueError('State ' + str(index) +
+ ' is incompatible with layer ' + self.name +
+ ': expected shape=' + str((batch_size, self.units)) +
+ ', found shape=' + str(value.shape))
+ K.set_value(state, value)
def get_config(self):
config = {
@@ -477,6 +492,7 @@ class SimpleRNN(Recurrent):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
+ self.state_spec = InputSpec(shape=(None, self.units))
def build(self, input_shape):
if isinstance(input_shape, list):
@@ -485,8 +501,7 @@ class SimpleRNN(Recurrent):
batch_size = input_shape[0] if self.stateful else None
self.input_dim = input_shape[2]
- self.input_spec = InputSpec(shape=(batch_size, None, self.input_dim))
- self.state_spec = InputSpec(shape=(batch_size, self.units))
+ self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
self.states = [None]
if self.stateful:
@@ -707,16 +722,15 @@ class GRU(Recurrent):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
+ self.state_spec = InputSpec(shape=(None, self.units))
def build(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- self.input_spec = InputSpec(shape=input_shape)
batch_size = input_shape[0] if self.stateful else None
self.input_dim = input_shape[2]
- self.input_spec = InputSpec(shape=(batch_size, None, self.input_dim))
- self.state_spec = InputSpec(shape=(batch_size, self.units))
+ self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
self.states = [None]
if self.stateful:
@@ -1020,19 +1034,18 @@ class LSTM(Recurrent):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
+ self.state_spec = [
+ InputSpec(shape=(None, self.units)),
+ InputSpec(shape=(None, self.units))
+ ]
def build(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- self.input_spec = InputSpec(shape=input_shape)
batch_size = input_shape[0] if self.stateful else None
self.input_dim = input_shape[2]
- self.input_spec = InputSpec(shape=(batch_size, None, self.input_dim))
- self.state_spec = [
- InputSpec(shape=(batch_size, self.units)), InputSpec(
- shape=(batch_size, self.units))
- ]
+ self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
self.states = [None, None]
if self.stateful:
@@ -1052,16 +1065,22 @@ class LSTM(Recurrent):
constraint=self.recurrent_constraint)
if self.use_bias:
+ if self.unit_forget_bias:
+
+ def bias_initializer(_, *args, **kwargs):
+ return K.concatenate([
+ self.bias_initializer((self.units,), *args, **kwargs),
+ initializers.Ones()((self.units,), *args, **kwargs),
+ self.bias_initializer((self.units * 2,), *args, **kwargs),
+ ])
+ else:
+ bias_initializer = self.bias_initializer
self.bias = self.add_weight(
shape=(self.units * 4,),
name='bias',
- initializer=self.bias_initializer,
+ initializer=bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
- if self.unit_forget_bias:
- bias_value = np.zeros((self.units * 4,))
- bias_value[self.units:self.units * 2] = 1.
- K.set_value(self.bias, bias_value)
else:
self.bias = None
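
A short sketch of the two ways of seeding RNN state that the revised docstring and the `reset_states(states=...)` signature above describe, mirroring the new LSTM tests; the sizes are arbitrary and the import path assumes this contrib tree.

    import numpy as np
    from tensorflow.contrib.keras.python import keras

    units, timesteps, dim, batch = 3, 4, 5, 2

    # Symbolic: pass `initial_state` tensors when calling the layer.
    inputs = keras.Input((timesteps, dim))
    initial_state = [keras.Input((units,)), keras.Input((units,))]
    outputs = keras.layers.LSTM(units)(inputs, initial_state=initial_state)

    # Numeric: pass numpy arrays to `reset_states` on a stateful layer
    # (an LSTM carries two states, hence two arrays).
    layer = keras.layers.LSTM(units, stateful=True)
    layer.build((batch, timesteps, dim))
    layer.reset_states()                        # create the two state variables
    layer.reset_states([np.ones((batch, units)),
                        np.ones((batch, units))])  # then seed them numerically
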
diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers.py b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
index 092501cb11..dbc79fb193 100644
--- a/tensorflow/contrib/keras/python/keras/layers/wrappers.py
+++ b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
@@ -132,13 +132,18 @@ class TimeDistributed(Wrapper):
model = Sequential()
model.add(TimeDistributed(Dense(8), input_shape=(10, 16)))
# now model.output_shape == (None, 10, 8)
+ ```
+
+ The output will then have shape `(32, 10, 8)`.
- # subsequent layers: no need for input_shape
+ In subsequent layers, there is no need for the `input_shape`:
+
+ ```python
model.add(TimeDistributed(Dense(32)))
# now model.output_shape == (None, 10, 32)
```
- The output will then have shape `(32, 10, 8)`.
+ The output will then have shape `(32, 10, 32)`.
`TimeDistributed` can be used with arbitrary layers, not just `Dense`,
for instance with a `Conv2D` layer:
@@ -186,12 +191,7 @@ class TimeDistributed(Wrapper):
output = self.layer.call(x)
return output, []
- _, outputs, _ = K.rnn(
- step,
- inputs,
- initial_states=[],
- input_length=input_shape[1],
- unroll=False)
+ _, outputs, _ = K.rnn(step, inputs, initial_states=[], unroll=False)
y = outputs
else:
# No batch size specified, therefore the layer will be able
diff --git a/tensorflow/contrib/keras/python/keras/losses.py b/tensorflow/contrib/keras/python/keras/losses.py
index 54b8fa429d..777ec440ac 100644
--- a/tensorflow/contrib/keras/python/keras/losses.py
+++ b/tensorflow/contrib/keras/python/keras/losses.py
@@ -52,6 +52,20 @@ def hinge(y_true, y_pred):
return K.mean(K.maximum(1. - y_true * y_pred, 0.), axis=-1)
+def categorical_hinge(y_true, y_pred):
+ pos = K.sum(y_true * y_pred, axis=-1)
+ neg = K.max((1. - y_true) * y_pred, axis=-1)
+ return K.maximum(neg - pos + 1., 0.)
+
+
+def logcosh(y_true, y_pred):
+
+ def cosh(x):
+ return (K.exp(x) + K.exp(-x)) / 2
+
+ return K.mean(K.log(cosh(y_pred - y_true)), axis=-1)
+
+
def categorical_crossentropy(y_true, y_pred):
return K.categorical_crossentropy(y_pred, y_true)
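
The two new losses above are plain elementwise formulas; a NumPy restatement with illustrative values (the same ones used in the test below) makes the math easy to check.

    import numpy as np

    y_true = np.array([[0., 1., 0.], [1., 0., 0.]])
    y_pred = np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])

    # categorical_hinge: margin between the true-class score and the best
    # wrong-class score.
    pos = np.sum(y_true * y_pred, axis=-1)         # [0.2, 0.1]
    neg = np.max((1. - y_true) * y_pred, axis=-1)  # [0.3, 0.7]
    print(np.maximum(neg - pos + 1., 0.))          # [1.1, 1.6], mean 1.35

    # logcosh: log(cosh(x)) behaves like x**2 / 2 for small errors and like
    # |x| - log(2) for large ones.
    print(np.mean(np.log(np.cosh(y_pred - y_true)), axis=-1))
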
diff --git a/tensorflow/contrib/keras/python/keras/losses_test.py b/tensorflow/contrib/keras/python/keras/losses_test.py
index fd4458cce2..6bdcc0b5ff 100644
--- a/tensorflow/contrib/keras/python/keras/losses_test.py
+++ b/tensorflow/contrib/keras/python/keras/losses_test.py
@@ -34,7 +34,9 @@ ALL_LOSSES = [keras.losses.mean_squared_error,
keras.losses.binary_crossentropy,
keras.losses.kullback_leibler_divergence,
keras.losses.poisson,
- keras.losses.cosine_proximity]
+ keras.losses.cosine_proximity,
+ keras.losses.logcosh,
+ keras.losses.categorical_hinge]
class KerasLossesTest(test.TestCase):
@@ -73,6 +75,14 @@ class KerasLossesTest(test.TestCase):
new_fn = keras.losses.deserialize(config)
self.assertEqual(fn, new_fn)
+ def test_categorical_hinge(self):
+ y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1],
+ [0.1, 0.2, 0.7]]))
+ y_true = keras.backend.variable(np.array([[0, 1, 0], [1, 0, 0]]))
+ expected_loss = ((0.3 - 0.2 + 1) + (0.7 - 0.1 + 1)) / 2.0
+ loss = keras.backend.eval(keras.losses.categorical_hinge(y_true, y_pred))
+ self.assertAllClose(expected_loss, np.mean(loss))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/metrics.py b/tensorflow/contrib/keras/python/keras/metrics.py
index 59d380f73b..93c8684f91 100644
--- a/tensorflow/contrib/keras/python/keras/metrics.py
+++ b/tensorflow/contrib/keras/python/keras/metrics.py
@@ -27,6 +27,7 @@ from tensorflow.contrib.keras.python.keras.losses import categorical_crossentrop
from tensorflow.contrib.keras.python.keras.losses import cosine_proximity
from tensorflow.contrib.keras.python.keras.losses import hinge
from tensorflow.contrib.keras.python.keras.losses import kullback_leibler_divergence
+from tensorflow.contrib.keras.python.keras.losses import logcosh
from tensorflow.contrib.keras.python.keras.losses import mean_absolute_error
from tensorflow.contrib.keras.python.keras.losses import mean_absolute_percentage_error
from tensorflow.contrib.keras.python.keras.losses import mean_squared_error
diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py
index 1c041091fc..0ae373da3c 100644
--- a/tensorflow/contrib/keras/python/keras/models.py
+++ b/tensorflow/contrib/keras/python/keras/models.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import copy
import json
import os
-import warnings
import numpy as np
@@ -36,6 +35,7 @@ from tensorflow.contrib.keras.python.keras.engine.topology import Layer
from tensorflow.contrib.keras.python.keras.engine.training import Model
from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging as logging
# pylint: disable=g-import-not-at-top
@@ -133,7 +133,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
if include_optimizer and hasattr(model, 'optimizer'):
if isinstance(model.optimizer, optimizers.TFOptimizer):
- warnings.warn(
+ logging.warning(
'TensorFlow optimizers do not '
'make it possible to access '
'optimizer attributes or optimizer state '
@@ -189,7 +189,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
f.close()
-def load_model(filepath, custom_objects=None):
+def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin
"""Loads a model saved via `save_model`.
Arguments:
@@ -197,12 +197,16 @@ def load_model(filepath, custom_objects=None):
custom_objects: Optional dictionary mapping names
(strings) to custom classes or functions to be
considered during deserialization.
+ compile: Boolean, whether to compile the model
+ after loading.
Returns:
A Keras model instance. If an optimizer was found
as part of the saved model, the model is already
compiled. Otherwise, the model is uncompiled and
- a warning will be displayed.
+ a warning will be displayed. When `compile` is set
+ to False, the compilation is omitted without any
+ warning.
Raises:
ImportError: if h5py is not available.
@@ -264,11 +268,16 @@ def load_model(filepath, custom_objects=None):
# set weights
topology.load_weights_from_hdf5_group(f['model_weights'], model.layers)
+ # Early return if compilation is not required.
+ if not compile:
+ f.close()
+ return model
+
# instantiate optimizer
training_config = f.attrs.get('training_config')
if training_config is None:
- warnings.warn('No training configuration found in save file: '
- 'the model was *not* compiled. Compile it manually.')
+ logging.warning('No training configuration found in save file: '
+ 'the model was *not* compiled. Compile it manually.')
f.close()
return model
training_config = json.loads(training_config.decode('utf-8'))
@@ -320,9 +329,12 @@ def model_from_config(config, custom_objects=None):
Returns:
A Keras model instance (uncompiled).
+
+ Raises:
+ TypeError: if `config` is not a dictionary.
"""
if isinstance(config, list):
- raise TypeError('`model_fom_config` expects a dictionary, not a list. '
+ raise TypeError('`model_from_config` expects a dictionary, not a list. '
'Maybe you meant to use '
'`Sequential.from_config(config)`?')
return layer_module.deserialize(config, custom_objects=custom_objects)
@@ -730,7 +742,7 @@ class Sequential(Model):
optimizer: str (name of optimizer) or optimizer object.
See [optimizers](/optimizers).
loss: str (name of objective function) or objective function.
- See [objectives](/objectives).
+ See [losses](/losses).
metrics: list of metrics to be evaluated by the model
during training and testing.
Typically you will use `metrics=['accuracy']`.
@@ -739,7 +751,8 @@ class Sequential(Model):
sample weighting (2D weights), set this to "temporal".
"None" defaults to sample-wise weights (1D).
**kwargs: for Theano backend, these are passed into K.function.
- Ignored for Tensorflow backend.
+ When using the Tensorflow backend, these are passed into
+ `tf.Session.run`.
Example:
```python
@@ -762,11 +775,14 @@ class Sequential(Model):
**kwargs)
self.optimizer = self.model.optimizer
self.loss = self.model.loss
+ self.total_loss = self.model.total_loss
self.loss_weights = self.model.loss_weights
self.metrics = self.model.metrics
self.metrics_tensors = self.model.metrics_tensors
self.metrics_names = self.model.metrics_names
self.sample_weight_mode = self.model.sample_weight_mode
+ self.sample_weights = self.model.sample_weights
+ self.targets = self.model.targets
def fit(self,
x,
@@ -966,10 +982,10 @@ class Sequential(Model):
"""
preds = self.predict(x, batch_size, verbose)
if preds.min() < 0. or preds.max() > 1.:
- warnings.warn('Network returning invalid probability values. '
- 'The last layer might not normalize predictions '
- 'into probabilities '
- '(like softmax or sigmoid would).')
+ logging.warning('Network returning invalid probability values. '
+ 'The last layer might not normalize predictions '
+ 'into probabilities '
+ '(like softmax or sigmoid would).')
return preds
def predict_classes(self, x, batch_size=32, verbose=1):
@@ -1018,8 +1034,8 @@ class Sequential(Model):
- a tuple (inputs, targets, sample_weights).
All arrays should contain the same number of samples.
The generator is expected to loop over its data
- indefinitely. An epoch finishes when `samples_per_epoch`
- samples have been seen by the model.
+ indefinitely. An epoch finishes when `steps_per_epoch`
+ batches have been seen by the model.
steps_per_epoch: Total number of steps (batches of samples)
to yield from `generator` before declaring one epoch
finished and starting the next epoch. It should typically
@@ -1072,7 +1088,7 @@ class Sequential(Model):
f.close()
model.fit_generator(generate_arrays_from_file('/my_file.txt'),
- samples_per_epoch=10000, epochs=10)
+ steps_per_epoch=1000, epochs=10)
```
"""
if self.model is None:
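
A usage sketch of the new `compile` flag on `load_model`; the file path is hypothetical, and the manual `compile` call afterwards is only needed if the model will be trained or evaluated.

    from tensorflow.contrib.keras.python.keras.models import load_model

    # Load architecture and weights only; optimizer/loss reconstruction is
    # skipped, so no "not compiled" warning is logged.
    model = load_model('my_model.h5', compile=False)   # hypothetical path

    # Compile manually before calling fit/evaluate.
    model.compile(optimizer='rmsprop', loss='mse')
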
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/image.py b/tensorflow/contrib/keras/python/keras/preprocessing/image.py
index 8cceb441df..0d69396e8b 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/image.py
+++ b/tensorflow/contrib/keras/python/keras/preprocessing/image.py
@@ -24,12 +24,12 @@ from __future__ import print_function
import os
import re
import threading
-import warnings
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
from tensorflow.contrib.keras.python.keras import backend as K
+from tensorflow.python.platform import tf_logging as logging
# pylint: disable=g-import-not-at-top
@@ -368,9 +368,9 @@ def load_img(path, grayscale=False, target_size=None):
if img.mode != 'RGB':
img = img.convert('RGB')
if target_size:
- wh_tuple = (target_size[1], target_size[0])
- if img.size != wh_tuple:
- img = img.resize(wh_tuple)
+ hw_tuple = (target_size[1], target_size[0])
+ if img.size != hw_tuple:
+ img = img.resize(hw_tuple)
return img
@@ -391,6 +391,7 @@ class ImageDataGenerator(object):
featurewise_std_normalization: divide inputs by std of the dataset.
samplewise_std_normalization: divide each input by its std.
zca_whitening: apply ZCA whitening.
+ zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
rotation_range: degrees (0 to 180).
width_shift_range: fraction of total width.
height_shift_range: fraction of total height.
@@ -428,6 +429,7 @@ class ImageDataGenerator(object):
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=False,
+ zca_epsilon=1e-6,
rotation_range=0.,
width_shift_range=0.,
height_shift_range=0.,
@@ -448,6 +450,7 @@ class ImageDataGenerator(object):
self.featurewise_std_normalization = featurewise_std_normalization
self.samplewise_std_normalization = samplewise_std_normalization
self.zca_whitening = zca_whitening
+ self.zca_epsilon = zca_epsilon
self.rotation_range = rotation_range
self.width_shift_range = width_shift_range
self.height_shift_range = height_shift_range
@@ -497,7 +500,7 @@ class ImageDataGenerator(object):
seed=None,
save_to_dir=None,
save_prefix='',
- save_format='jpeg'):
+ save_format='png'):
return NumpyArrayIterator(
x,
y,
@@ -521,7 +524,7 @@ class ImageDataGenerator(object):
seed=None,
save_to_dir=None,
save_prefix='',
- save_format='jpeg',
+ save_format='png',
follow_links=False):
return DirectoryIterator(
directory,
@@ -563,28 +566,28 @@ class ImageDataGenerator(object):
if self.mean is not None:
x -= self.mean
else:
- warnings.warn('This ImageDataGenerator specifies '
- '`featurewise_center`, but it hasn\'t'
- 'been fit on any training data. Fit it '
- 'first by calling `.fit(numpy_data)`.')
+ logging.warning('This ImageDataGenerator specifies '
+ '`featurewise_center`, but it hasn\'t'
+ 'been fit on any training data. Fit it '
+ 'first by calling `.fit(numpy_data)`.')
if self.featurewise_std_normalization:
if self.std is not None:
x /= (self.std + 1e-7)
else:
- warnings.warn('This ImageDataGenerator specifies '
- '`featurewise_std_normalization`, but it hasn\'t'
- 'been fit on any training data. Fit it '
- 'first by calling `.fit(numpy_data)`.')
+ logging.warning('This ImageDataGenerator specifies '
+ '`featurewise_std_normalization`, but it hasn\'t'
+ 'been fit on any training data. Fit it '
+ 'first by calling `.fit(numpy_data)`.')
if self.zca_whitening:
if self.principal_components is not None:
flatx = np.reshape(x, (x.size))
whitex = np.dot(flatx, self.principal_components)
x = np.reshape(whitex, (x.shape[0], x.shape[1], x.shape[2]))
else:
- warnings.warn('This ImageDataGenerator specifies '
- '`zca_whitening`, but it hasn\'t'
- 'been fit on any training data. Fit it '
- 'first by calling `.fit(numpy_data)`.')
+ logging.warning('This ImageDataGenerator specifies '
+ '`zca_whitening`, but it hasn\'t'
+ 'been fit on any training data. Fit it '
+ 'first by calling `.fit(numpy_data)`.')
return x
def random_transform(self, x):
@@ -640,7 +643,8 @@ class ImageDataGenerator(object):
transform_matrix = None
if theta != 0:
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
- [np.sin(theta), np.cos(theta), 0], [0, 0, 1]])
+ [np.sin(theta),
+ np.cos(theta), 0], [0, 0, 1]])
transform_matrix = rotation_matrix
if tx != 0 or ty != 0:
@@ -748,7 +752,7 @@ class ImageDataGenerator(object):
sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
u, s, _ = linalg.svd(sigma)
self.principal_components = np.dot(
- np.dot(u, np.diag(1. / np.sqrt(s + 10e-7))), u.T)
+ np.dot(u, np.diag(1. / np.sqrt(s + self.zca_epsilon))), u.T)
class Iterator(object):
@@ -836,7 +840,7 @@ class NumpyArrayIterator(Iterator):
data_format=None,
save_to_dir=None,
save_prefix='',
- save_format='jpeg'):
+ save_format='png'):
if y is not None and len(x) != len(y):
raise ValueError('X (images tensor) and y (labels) '
'should have the same length. '
@@ -927,6 +931,8 @@ class DirectoryIterator(Iterator):
`"binary"`: binary targets (if there are only two classes),
`"categorical"`: categorical targets,
`"sparse"`: integer targets,
+ `"input"`: targets are images identical to input images (mainly
+ used to work with autoencoders),
`None`: no targets get yielded (only input images are yielded).
batch_size: Integer, size of a batch.
shuffle: Boolean, whether to shuffle the data between epochs.
@@ -955,7 +961,7 @@ class DirectoryIterator(Iterator):
data_format=None,
save_to_dir=None,
save_prefix='',
- save_format='jpeg',
+ save_format='png',
follow_links=False):
if data_format is None:
data_format = K.image_data_format()
@@ -978,10 +984,11 @@ class DirectoryIterator(Iterator):
else:
self.image_shape = (1,) + self.target_size
self.classes = classes
- if class_mode not in {'categorical', 'binary', 'sparse', None}:
+ if class_mode not in {'categorical', 'binary', 'sparse', 'input', None}:
raise ValueError('Invalid class_mode:', class_mode,
'; expected one of "categorical", '
- '"binary", "sparse", or None.')
+ '"binary", "sparse", "input"'
+ ' or None.')
self.class_mode = class_mode
self.save_to_dir = save_to_dir
self.save_prefix = save_prefix
@@ -1076,7 +1083,9 @@ class DirectoryIterator(Iterator):
format=self.save_format)
img.save(os.path.join(self.save_to_dir, fname))
# build batch of labels
- if self.class_mode == 'sparse':
+ if self.class_mode == 'input':
+ batch_y = batch_x.copy()
+ elif self.class_mode == 'sparse':
batch_y = self.classes[index_array]
elif self.class_mode == 'binary':
batch_y = self.classes[index_array].astype(K.floatx())
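
A short sketch combining the options added above: the explicit `zca_epsilon`, the new `'input'` class mode (targets are the input images themselves, e.g. for autoencoders), and the `save_format` default that is now `'png'`. The directory is hypothetical, and `fit` must still be called on sample data before ZCA whitening takes effect.

    from tensorflow.contrib.keras.python.keras.preprocessing.image import ImageDataGenerator

    datagen = ImageDataGenerator(zca_whitening=True, zca_epsilon=1e-6)
    # datagen.fit(x_sample) would compute the whitening components first.

    flow = datagen.flow_from_directory(
        'data/train',             # hypothetical directory of images
        target_size=(64, 64),
        class_mode='input',       # yields (images, images) batches
        batch_size=32)
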
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/sequence.py b/tensorflow/contrib/keras/python/keras/preprocessing/sequence.py
index 692a359ead..382aa386d4 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/sequence.py
+++ b/tensorflow/contrib/keras/python/keras/preprocessing/sequence.py
@@ -205,7 +205,8 @@ def skipgrams(sequence,
words = [c[0] for c in couples]
random.shuffle(words)
- couples += [[words[i % len(words)], random.randint(1, vocabulary_size - 1)]
+ couples += [[words[i % len(words)],
+ random.randint(1, vocabulary_size - 1)]
for i in range(num_negative_samples)]
if categorical:
labels += [[1, 0]] * num_negative_samples
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/text.py b/tensorflow/contrib/keras/python/keras/preprocessing/text.py
index 5b89c8035c..93e629af17 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/text.py
+++ b/tensorflow/contrib/keras/python/keras/preprocessing/text.py
@@ -20,15 +20,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from collections import OrderedDict
import string
import sys
-import warnings
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
-
if sys.version_info < (3,):
maketrans = string.maketrans
else:
@@ -39,7 +38,7 @@ def text_to_word_sequence(text,
filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
lower=True,
split=' '):
- """Converts a text to a sequence of word indices.
+ """Converts a text to a sequence of words (or tokens).
Arguments:
text: Input text (string).
@@ -48,7 +47,7 @@ def text_to_word_sequence(text,
split: Sentence split marker (string).
Returns:
- A list of integer word indices.
+ A list of words (or tokens).
"""
if lower:
text = text.lower()
@@ -83,7 +82,7 @@ class Tokenizer(object):
tabs and line breaks, minus the `'` character.
lower: boolean. Whether to convert the texts to lowercase.
split: character or string to use for token splitting.
- char_level: if True, every character will be treated as a word.
+ char_level: if True, every character will be treated as a token.
By default, all punctuation is removed, turning the texts into
space-separated sequences of words
@@ -98,17 +97,8 @@ class Tokenizer(object):
filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
lower=True,
split=' ',
- char_level=False,
- **kwargs):
- # Legacy support
- if 'nb_words' in kwargs:
- warnings.warn('The `nb_words` argument in `Tokenizer` '
- 'has been renamed `num_words`.')
- num_words = kwargs.pop('nb_words')
- if kwargs:
- raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
-
- self.word_counts = {}
+ char_level=False):
+ self.word_counts = OrderedDict()
self.word_docs = {}
self.filters = filters
self.split = split
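
The corrected docstring above is easy to confirm: the function returns the tokens themselves, not integer indices.

    from tensorflow.contrib.keras.python.keras.preprocessing.text import text_to_word_sequence

    print(text_to_word_sequence('The quick brown fox.'))
    # ['the', 'quick', 'brown', 'fox']
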
diff --git a/tensorflow/contrib/keras/python/keras/utils/conv_utils.py b/tensorflow/contrib/keras/python/keras/utils/conv_utils.py
index 7d4fdda296..570a63b606 100644
--- a/tensorflow/contrib/keras/python/keras/utils/conv_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/conv_utils.py
@@ -89,7 +89,7 @@ def convert_kernel(kernel):
Also works reciprocally, since the transformation is its own inverse.
Arguments:
- kernel: Numpy array (4D or 5D).
+ kernel: Numpy array (3D, 4D or 5D).
Returns:
The converted kernel.
@@ -97,7 +97,8 @@ def convert_kernel(kernel):
Raises:
ValueError: in case of invalid kernel shape or invalid data_format.
"""
- if not 4 <= kernel.ndim <= 5:
+ kernel = np.asarray(kernel)
+ if not 3 <= kernel.ndim <= 5:
raise ValueError('Invalid kernel shape:', kernel.shape)
slices = [slice(None, None, -1) for _ in range(kernel.ndim)]
no_flip = (slice(None, None), slice(None, None))
diff --git a/tensorflow/contrib/keras/python/keras/utils/data_utils.py b/tensorflow/contrib/keras/python/keras/utils/data_utils.py
index 5a42444308..61a11b95e8 100644
--- a/tensorflow/contrib/keras/python/keras/utils/data_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/data_utils.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
import hashlib
import os
import shutil
@@ -54,8 +53,10 @@ if sys.version_info[0] == 2:
"""
def chunk_read(response, chunk_size=8192, reporthook=None):
- total_size = response.info().get('Content-Length').strip()
- total_size = int(total_size)
+ content_type = response.info().get('Content-Length')
+ total_size = -1
+ if content_type is not None:
+ total_size = int(content_type.strip())
count = 0
while 1:
chunk = response.read(chunk_size)
@@ -204,19 +205,24 @@ def get_file(fname,
if download:
print('Downloading data from', origin)
- progbar = None
- def dl_progress(count, block_size, total_size, progbar=None):
- if progbar is None:
- progbar = Progbar(total_size)
+ class ProgressTracker(object):
+ # Maintain progbar for the lifetime of download.
+ # This design was chosen for Python 2.7 compatibility.
+ progbar = None
+
+ def dl_progress(count, block_size, total_size):
+ if ProgressTracker.progbar is None:
+ if total_size is -1:
+ total_size = None
+ ProgressTracker.progbar = Progbar(total_size)
else:
- progbar.update(count * block_size)
+ ProgressTracker.progbar.update(count * block_size)
error_msg = 'URL fetch failure on {}: {} -- {}'
try:
try:
- urlretrieve(origin, fpath,
- functools.partial(dl_progress, progbar=progbar))
+ urlretrieve(origin, fpath, dl_progress)
except URLError as e:
raise Exception(error_msg.format(origin, e.errno, e.reason))
except HTTPError as e:
@@ -225,7 +231,7 @@ def get_file(fname,
if os.path.exists(fpath):
os.remove(fpath)
raise
- progbar = None
+ ProgressTracker.progbar = None
if untar:
if not os.path.exists(untar_fpath):
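
The `ProgressTracker` class above exists only to give the download hook mutable state that survives across calls, since Python 2.7 has no `nonlocal`. A standalone sketch of the same pattern, with a hypothetical `_Tracker` class and `report_hook` function:

    class _Tracker(object):
        # Class attribute stands in for `nonlocal`: shared across hook calls.
        calls = 0

    def report_hook(count, block_size, total_size):
        _Tracker.calls += 1
        size = '?' if total_size == -1 else str(total_size)
        print('call %d: %d / %s bytes' % (_Tracker.calls, count * block_size, size))

    for i in range(3):
        report_hook(i + 1, 8192, -1)
    _Tracker.calls = 0   # reset afterwards, as the code above resets ProgressTracker.progbar
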
diff --git a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
index 27cc23f232..5cae694d54 100644
--- a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
@@ -45,8 +45,8 @@ class CustomObjectScope(object):
Consider a custom object `MyObject`
```python
- with CustomObjectScope({"MyObject":MyObject}):
- layer = Dense(..., W_regularizer="MyObject")
+ with CustomObjectScope({'MyObject':MyObject}):
+ layer = Dense(..., kernel_regularizer='MyObject')
# save, load, etc. will recognize custom object by name
```
"""
@@ -81,8 +81,8 @@ def custom_object_scope(*args):
Consider a custom object `MyObject`
```python
- with custom_object_scope({"MyObject":MyObject}):
- layer = Dense(..., W_regularizer="MyObject")
+ with custom_object_scope({'MyObject':MyObject}):
+ layer = Dense(..., kernel_regularizer='MyObject')
# save, load, etc. will recognize custom object by name
```
@@ -107,7 +107,7 @@ def get_custom_objects():
```python
get_custom_objects().clear()
- get_custom_objects()["MyObject"] = MyObject
+ get_custom_objects()['MyObject'] = MyObject
```
Returns:
@@ -152,19 +152,23 @@ def deserialize_keras_object(identifier,
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
if hasattr(cls, 'from_config'):
arg_spec = tf_inspect.getargspec(cls.from_config)
+ custom_objects = custom_objects or {}
+
if 'custom_objects' in arg_spec.args:
- custom_objects = custom_objects or {}
return cls.from_config(
config['config'],
custom_objects=dict(
list(_GLOBAL_CUSTOM_OBJECTS.items()) +
list(custom_objects.items())))
- return cls.from_config(config['config'])
+ with CustomObjectScope(custom_objects):
+ return cls.from_config(config['config'])
else:
# Then `cls` may be a function returning a class.
# in this case by convention `config` holds
# the kwargs of the function.
- return cls(**config['config'])
+ custom_objects = custom_objects or {}
+ with CustomObjectScope(custom_objects):
+ return cls(**config['config'])
elif isinstance(identifier, six.string_types):
function_name = identifier
if custom_objects and function_name in custom_objects:
@@ -174,18 +178,14 @@ def deserialize_keras_object(identifier,
else:
fn = module_objects.get(function_name)
if fn is None:
- raise ValueError('Unknown ' + printable_module_name,
- ':' + function_name)
+ raise ValueError('Unknown ' + printable_module_name + ':' +
+ function_name)
return fn
else:
raise ValueError('Could not interpret serialized ' + printable_module_name +
': ' + identifier)
-def make_tuple(*args):
- return args
-
-
def func_dump(func):
"""Serializes a user defined function.
@@ -231,12 +231,14 @@ class Progbar(object):
"""Displays a progress bar.
Arguments:
- target: Total number of steps expected.
+ target: Total number of steps expected, None if unknown.
interval: Minimum visual progress update interval (in seconds).
"""
def __init__(self, target, width=30, verbose=1, interval=0.05):
self.width = width
+ if target is None:
+ target = -1
self.target = target
self.sum_values = {}
self.unique_values = []
@@ -277,21 +279,22 @@ class Progbar(object):
sys.stdout.write('\b' * prev_total_width)
sys.stdout.write('\r')
- numdigits = int(np.floor(np.log10(self.target))) + 1
- barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
- bar = barstr % (current, self.target)
- prog = float(current) / self.target
- prog_width = int(self.width * prog)
- if prog_width > 0:
- bar += ('=' * (prog_width - 1))
- if current < self.target:
- bar += '>'
- else:
- bar += '='
- bar += ('.' * (self.width - prog_width))
- bar += ']'
- sys.stdout.write(bar)
- self.total_width = len(bar)
+ if self.target is not -1:
+ numdigits = int(np.floor(np.log10(self.target))) + 1
+ barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
+ bar = barstr % (current, self.target)
+ prog = float(current) / self.target
+ prog_width = int(self.width * prog)
+ if prog_width > 0:
+ bar += ('=' * (prog_width - 1))
+ if current < self.target:
+ bar += '>'
+ else:
+ bar += '='
+ bar += ('.' * (self.width - prog_width))
+ bar += ']'
+ sys.stdout.write(bar)
+ self.total_width = len(bar)
if current:
time_per_unit = (now - self.start) / current
@@ -299,7 +302,7 @@ class Progbar(object):
time_per_unit = 0
eta = time_per_unit * (self.target - current)
info = ''
- if current < self.target:
+ if current < self.target and self.target is not -1:
info += ' - ETA: %ds' % eta
else:
info += ' - %ds' % (now - self.start)
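
With the change above, `Progbar` accepts an unknown target; a small sketch (assuming this contrib import path) shows the mode where no bar is drawn and only elapsed time and any metric values are reported.

    import time
    from tensorflow.contrib.keras.python.keras.utils.generic_utils import Progbar

    progbar = Progbar(target=None)   # unknown step count, stored internally as -1
    for step in range(5):
        time.sleep(0.01)
        progbar.update(step + 1)     # prints elapsed time, no progress bar
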
diff --git a/tensorflow/contrib/keras/python/keras/utils/io_utils.py b/tensorflow/contrib/keras/python/keras/utils/io_utils.py
index 7cef39b03f..55c135b5eb 100644
--- a/tensorflow/contrib/keras/python/keras/utils/io_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/io_utils.py
@@ -80,8 +80,13 @@ class HDF5Matrix(object):
def __getitem__(self, key):
if isinstance(key, slice):
- if key.stop + self.start <= self.end:
- idx = slice(key.start + self.start, key.stop + self.start)
+ start, stop = key.start, key.stop
+ if start is None:
+ start = 0
+ if stop is None:
+ stop = self.data.shape[0]
+ if stop + self.start <= self.end:
+ idx = slice(start + self.start, stop + self.start)
else:
raise IndexError
elif isinstance(key, int):
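The patched `__getitem__` substitutes the dataset bounds when a slice omits its start or stop, so open-ended slices such as `x[:5]` or `x[:]` no longer fail. A self-contained sketch; the file name and dataset name are illustrative:

    import h5py
    import numpy as np

    from tensorflow.contrib.keras.python.keras.utils.io_utils import HDF5Matrix

    # Illustrative setup: write a small dataset so the example stands alone.
    with h5py.File('example.h5', 'w') as f:
      f.create_dataset('data', data=np.arange(20).reshape(10, 2))

    x = HDF5Matrix('example.h5', 'data')
    print(x[:5])  # implicit start now defaults to the first row
    print(x[:])   # fully open slice returns the whole dataset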
diff --git a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py b/tensorflow/contrib/keras/python/keras/utils/layer_utils.py
index 26878fdd57..154070fb93 100644
--- a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/layer_utils.py
@@ -38,8 +38,11 @@ def print_summary(model, line_length=None, positions=None):
else:
sequential_like = True
for v in model.nodes_by_depth.values():
- if len(v) > 1:
+ if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1):
+ # If the model has multiple nodes or if the nodes have
+ # multiple inbound_layers, the model is no longer sequential.
sequential_like = False
+ break
if sequential_like:
line_length = line_length or 65
@@ -94,12 +97,10 @@ def print_summary(model, line_length=None, positions=None):
except AttributeError:
output_shape = 'multiple'
connections = []
- for node_index, node in enumerate(layer.inbound_nodes):
- if relevant_nodes:
- node_key = layer.name + '_ib-' + str(node_index)
- if node_key not in relevant_nodes:
- # node is node part of the current network
- continue
+ for node in layer.inbound_nodes:
+ if relevant_nodes and node not in relevant_nodes:
+ # node is not part of the current network
+ continue
for i in range(len(node.inbound_layers)):
inbound_layer = node.inbound_layers[i].name
inbound_node_index = node.node_indices[i]
@@ -114,8 +115,8 @@ def print_summary(model, line_length=None, positions=None):
else:
first_connection = connections[0]
fields = [
- name + ' (' + cls_name + ')', output_shape, layer.count_params(),
- first_connection
+ name + ' (' + cls_name + ')', output_shape,
+ layer.count_params(), first_connection
]
print_row(fields, positions)
if len(connections) > 1:
@@ -134,8 +135,10 @@ def print_summary(model, line_length=None, positions=None):
else:
print('_' * line_length)
- trainable_count, non_trainable_count = count_total_params(
- layers, layer_set=None)
+ trainable_count = int(
+ np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
+ non_trainable_count = int(
+ np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))
print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
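`print_summary` now derives the totals from the model's own (deduplicated) weight collections instead of walking layers with the removed `count_total_params` helper. Roughly the same computation can be done standalone; this is a sketch assuming the standard contrib.keras module layout, not the exact library code path:

    import numpy as np

    from tensorflow.contrib.keras.python.keras import backend as K
    from tensorflow.contrib.keras.python.keras import layers, models

    model = models.Sequential()
    model.add(layers.Dense(8, input_shape=(4,)))
    model.add(layers.Dense(1))

    trainable = int(np.sum([K.count_params(w) for w in set(model.trainable_weights)]))
    non_trainable = int(np.sum([K.count_params(w) for w in set(model.non_trainable_weights)]))
    print(trainable, non_trainable)  # 49 0 for this toy model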
@@ -143,37 +146,6 @@ def print_summary(model, line_length=None, positions=None):
print('_' * line_length)
-def count_total_params(layers, layer_set=None):
- """Counts the number of parameters in a list of layers.
-
- Arguments:
- layers: list of layers.
- layer_set: set of layers already seen
- (so that we don't count their weights twice).
-
- Returns:
- A tuple (count of trainable weights, count of non-trainable weights.)
- """
- if layer_set is None:
- layer_set = set()
- trainable_count = 0
- non_trainable_count = 0
- for layer in layers:
- if layer in layer_set:
- continue
- layer_set.add(layer)
- if hasattr(layer, 'layers'):
- t, nt = count_total_params(layer.layers, layer_set)
- trainable_count += t
- non_trainable_count += nt
- else:
- trainable_count += np.sum(
- [K.count_params(p) for p in layer.trainable_weights])
- non_trainable_count += np.sum(
- [K.count_params(p) for p in layer.non_trainable_weights])
- return int(trainable_count), int(non_trainable_count)
-
-
def convert_all_kernels_in_model(model):
"""Converts all convolution kernels in a model from Theano to TensorFlow.
@@ -218,7 +190,7 @@ def convert_dense_weights_data_format(dense,
came before the target `Dense` layer.
target_data_format: One of "channels_last", "channels_first".
Set it "channels_last"
- if converting a "chnnels_first" model to "channels_last",
+ if converting a "channels_first" model to "channels_last",
or reciprocally.
"""
assert target_data_format in {'channels_last', 'channels_first'}
diff --git a/tensorflow/contrib/keras/python/keras/utils/vis_utils.py b/tensorflow/contrib/keras/python/keras/utils/vis_utils.py
index 9e2ee86424..949767299b 100644
--- a/tensorflow/contrib/keras/python/keras/utils/vis_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/vis_utils.py
@@ -39,18 +39,28 @@ except ImportError:
def _check_pydot():
- if not (pydot and pydot.find_graphviz()):
+ try:
+ # Attempt to create an image of a blank graph
+ # to check the pydot/graphviz installation.
+ pydot.Dot.create(pydot.Dot())
+ except Exception:
+ # pydot raises a generic Exception here,
+ # so no specific class can be caught.
raise ImportError('Failed to import pydot. You must install pydot'
' and graphviz for `pydotprint` to work.')
-def model_to_dot(model, show_shapes=False, show_layer_names=True):
- """Converts a Keras model to dot format.
+def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
+ """Convert a Keras model to dot format.
Arguments:
model: A Keras model instance.
show_shapes: whether to display shape information.
show_layer_names: whether to display layer names.
+ rankdir: `rankdir` argument passed to PyDot,
+ a string specifying the format of the plot:
+ 'TB' creates a vertical plot;
+ 'LR' creates a horizontal plot.
Returns:
A `pydot.Dot` instance representing the Keras model.
@@ -60,7 +70,7 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True):
_check_pydot()
dot = pydot.Dot()
- dot.set('rankdir', 'TB')
+ dot.set('rankdir', rankdir)
dot.set('concentrate', True)
dot.set_node_defaults(shape='record')
@@ -102,7 +112,6 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True):
inputlabels = 'multiple'
label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label, inputlabels,
outputlabels)
-
node = pydot.Node(layer_id, label=label)
dot.add_node(node)
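Both `model_to_dot` (above) and `plot_model` (below) now take a `rankdir` argument that is forwarded to pydot, so graphs can be drawn left-to-right instead of top-to-bottom. A usage sketch, assuming pydot and graphviz are installed (otherwise `_check_pydot` raises `ImportError`); the toy model is illustrative:

    from tensorflow.contrib.keras.python.keras import layers, models
    from tensorflow.contrib.keras.python.keras.utils.vis_utils import model_to_dot, plot_model

    model = models.Sequential()
    model.add(layers.Dense(8, input_shape=(4,)))
    model.add(layers.Dense(1))

    plot_model(model, to_file='model.png', show_shapes=True, rankdir='LR')  # horizontal layout
    dot = model_to_dot(model, rankdir='LR')  # or work with the pydot.Dot graph directly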
@@ -122,8 +131,21 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True):
def plot_model(model,
to_file='model.png',
show_shapes=False,
- show_layer_names=True):
- dot = model_to_dot(model, show_shapes, show_layer_names)
+ show_layer_names=True,
+ rankdir='TB'):
+ """Converts a Keras model to dot format and save to a file.
+
+ Arguments:
+      model: A Keras model instance.
+ to_file: File name of the plot image.
+ show_shapes: whether to display shape information.
+ show_layer_names: whether to display layer names.
+ rankdir: `rankdir` argument passed to PyDot,
+ a string specifying the format of the plot:
+ 'TB' creates a vertical plot;
+ 'LR' creates a horizontal plot.
+ """
+ dot = model_to_dot(model, show_shapes, show_layer_names, rankdir)
_, extension = os.path.splitext(to_file)
if not extension:
extension = 'png'