about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/keras
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras')
-rwxr-xr-x tensorflow/python/keras/BUILD 25
-rw-r--r-- tensorflow/python/keras/activations.py 22
-rw-r--r-- tensorflow/python/keras/applications/mobilenet.py 22
-rw-r--r-- tensorflow/python/keras/backend.py 63
-rw-r--r-- tensorflow/python/keras/backend_test.py 61
-rw-r--r-- tensorflow/python/keras/callbacks.py 209
-rw-r--r-- tensorflow/python/keras/callbacks_test.py 144
-rw-r--r-- tensorflow/python/keras/engine/base_layer.py 10
-rw-r--r-- tensorflow/python/keras/engine/network.py 4
-rw-r--r-- tensorflow/python/keras/engine/sequential.py 28
-rw-r--r-- tensorflow/python/keras/engine/sequential_test.py 46
-rw-r--r-- tensorflow/python/keras/engine/training.py 112
-rw-r--r-- tensorflow/python/keras/engine/training_eager.py 473
-rw-r--r-- tensorflow/python/keras/engine/training_gpu_test.py 125
-rw-r--r-- tensorflow/python/keras/engine/training_test.py 39
-rw-r--r-- tensorflow/python/keras/engine/training_utils.py 138
-rw-r--r-- tensorflow/python/keras/engine/training_utils_test.py 150
-rw-r--r-- tensorflow/python/keras/initializers.py 100
-rw-r--r-- tensorflow/python/keras/initializers_test.py 10
-rw-r--r-- tensorflow/python/keras/layers/advanced_activations.py 37
-rw-r--r-- tensorflow/python/keras/layers/advanced_activations_test.py 8
-rw-r--r-- tensorflow/python/keras/layers/convolutional_recurrent.py 2
-rw-r--r-- tensorflow/python/keras/layers/cudnn_recurrent_test.py 4
-rw-r--r-- tensorflow/python/keras/layers/normalization.py 6
-rw-r--r-- tensorflow/python/keras/layers/normalization_test.py 18
-rw-r--r-- tensorflow/python/keras/layers/recurrent.py 3
-rw-r--r-- tensorflow/python/keras/layers/recurrent_test.py 18
-rw-r--r-- tensorflow/python/keras/layers/wrappers.py 7
-rw-r--r-- tensorflow/python/keras/layers/wrappers_test.py 8
-rw-r--r-- tensorflow/python/keras/metrics.py 470
-rw-r--r-- tensorflow/python/keras/metrics_test.py 285
-rw-r--r-- tensorflow/python/keras/model_subclassing_test.py 270
-rw-r--r-- tensorflow/python/keras/testing_utils.py 73
-rw-r--r-- tensorflow/python/keras/utils/np_utils.py 3
34 files changed, 2195 insertions, 798 deletions
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 4056818a95..df409d2aa5 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -704,6 +704,17 @@ cuda_py_test(
],
)
+cuda_py_test(
+ name = "training_gpu_test",
+ size = "small",
+ srcs = ["engine/training_gpu_test.py"],
+ additional_deps = [
+ ":keras",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_test(
name = "imagenet_utils_test",
size = "small",
@@ -721,7 +732,6 @@ py_test(
size = "medium",
srcs = ["preprocessing/image_test.py"],
srcs_version = "PY2AND3",
- tags = ["nomsan"], # TODO(b/110990716) reenable
deps = [
":keras",
"//tensorflow/python:client_testlib",
@@ -793,6 +803,19 @@ py_test(
)
py_test(
+ name = "training_utils_test",
+ size = "medium",
+ srcs = ["engine/training_utils_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["notsan"],
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "model_subclassing_test",
size = "medium",
srcs = ["model_subclassing_test.py"],
diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index f608dea430..99645de736 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -128,20 +128,26 @@ def softsign(x):
@tf_export('keras.activations.relu')
-def relu(x, alpha=0., max_value=None):
+def relu(x, alpha=0., max_value=None, threshold=0):
"""Rectified Linear Unit.
+ With default values, it returns element-wise `max(x, 0)`.
+
+ Otherwise, it follows:
+ `f(x) = max_value` for `x >= max_value`,
+ `f(x) = x` for `threshold < x < max_value`,
+ `f(x) = alpha * (x - threshold)` otherwise.
+
Arguments:
- x: Input tensor.
- alpha: Slope of the negative part. Defaults to zero.
- max_value: Maximum value for the output.
+ x: A tensor or variable.
+ alpha: A scalar, slope of negative section (default=`0.`).
+ max_value: float. Saturation threshold.
+ threshold: float. Threshold value for thresholded activation.
Returns:
- The (leaky) rectified linear unit activation: `x` if `x > 0`,
- `alpha * x` if `x < 0`. If `max_value` is defined, the result
- is truncated to this value.
+ A tensor.
"""
- return K.relu(x, alpha=alpha, max_value=max_value)
+ return K.relu(x, alpha=alpha, max_value=max_value, threshold=threshold)
@tf_export('keras.activations.tanh')
diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py
index e56c695a28..7285e03963 100644
--- a/tensorflow/python/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/applications/mobilenet.py
@@ -72,13 +72,9 @@ from __future__ import print_function
import os
from tensorflow.python.keras import backend as K
-from tensorflow.python.keras import constraints
-from tensorflow.python.keras import initializers
-from tensorflow.python.keras import regularizers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import BatchNormalization
from tensorflow.python.keras.layers import Conv2D
@@ -87,10 +83,10 @@ from tensorflow.python.keras.layers import Dropout
from tensorflow.python.keras.layers import GlobalAveragePooling2D
from tensorflow.python.keras.layers import GlobalMaxPooling2D
from tensorflow.python.keras.layers import Input
+from tensorflow.python.keras.layers import ReLU
from tensorflow.python.keras.layers import Reshape
from tensorflow.python.keras.layers import ZeroPadding2D
from tensorflow.python.keras.models import Model
-from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
@@ -100,10 +96,6 @@ from tensorflow.python.util.tf_export import tf_export
BASE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.6/'
-def relu6(x):
- return K.relu(x, max_value=6)
-
-
@tf_export('keras.applications.mobilenet.preprocess_input')
def preprocess_input(x):
"""Preprocesses a numpy array encoding a batch of images.
@@ -130,12 +122,6 @@ def MobileNet(input_shape=None,
classes=1000):
"""Instantiates the MobileNet architecture.
- To load a MobileNet model via `load_model`, import the custom
- objects `relu6` and pass them to the `custom_objects` parameter.
- E.g.
- model = load_model('mobilenet.h5', custom_objects={
- 'relu6': mobilenet.relu6})
-
Arguments:
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
@@ -412,7 +398,7 @@ def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
strides=strides,
name='conv1')(x)
x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x)
- return Activation(relu6, name='conv1_relu')(x)
+ return ReLU(6, name='conv1_relu')(x)
def _depthwise_conv_block(inputs,
@@ -479,7 +465,7 @@ def _depthwise_conv_block(inputs,
use_bias=False,
name='conv_dw_%d' % block_id)(x)
x = BatchNormalization(axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
- x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)
+ x = ReLU(6, name='conv_dw_%d_relu' % block_id)(x)
x = Conv2D(
pointwise_conv_filters, (1, 1),
@@ -489,4 +475,4 @@ def _depthwise_conv_block(inputs,
name='conv_pw_%d' % block_id)(
x)
x = BatchNormalization(axis=channel_axis, name='conv_pw_%d_bn' % block_id)(x)
- return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x)
+ return ReLU(6, name='conv_pw_%d_relu' % block_id)(x)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index cb3423598b..38794f1612 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -3372,26 +3372,48 @@ def in_test_phase(x, alt, training=None):
@tf_export('keras.backend.relu')
-def relu(x, alpha=0., max_value=None):
+def relu(x, alpha=0., max_value=None, threshold=0):
"""Rectified linear unit.
With default values, it returns element-wise `max(x, 0)`.
+ Otherwise, it follows:
+ `f(x) = max_value` for `x >= max_value`,
`f(x) = x` for `threshold < x < max_value`,
+ `f(x) = alpha * (x - threshold)` otherwise.
+
Arguments:
x: A tensor or variable.
alpha: A scalar, slope of negative section (default=`0.`).
- max_value: Saturation threshold.
+ max_value: float. Saturation threshold.
+ threshold: float. Threshold value for thresholded activation.
Returns:
A tensor.
"""
+ clip_max = max_value is not None
+
if alpha != 0.:
- negative_part = nn.relu(-x)
- x = nn.relu(x)
- if max_value is not None:
+ if threshold != 0:
+ negative_part = nn.relu(-x + threshold)
+ else:
+ negative_part = nn.relu(-x)
+
+ if threshold != 0:
+ # computes x for x > threshold else 0
+ x = x * math_ops.cast(math_ops.greater(x, threshold), floatx())
+ elif max_value == 6:
+ # if no threshold, then can use nn.relu6 native TF op for performance
+ x = nn.relu6(x)
+ clip_max = False
+ else:
+ x = nn.relu(x)
+
+ if clip_max:
max_value = _to_tensor(max_value, x.dtype.base_dtype)
zero = _to_tensor(0., x.dtype.base_dtype)
x = clip_ops.clip_by_value(x, zero, max_value)
+
if alpha != 0.:
alpha = _to_tensor(alpha, x.dtype.base_dtype)
x -= alpha * negative_part
@@ -3458,7 +3480,7 @@ def softsign(x):
@tf_export('keras.backend.categorical_crossentropy')
-def categorical_crossentropy(target, output, from_logits=False):
+def categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy between an output tensor and a target tensor.
Arguments:
@@ -3468,28 +3490,33 @@ def categorical_crossentropy(target, output, from_logits=False):
case `output` is expected to be the logits).
from_logits: Boolean, whether `output` is the
result of a softmax, or is a tensor of logits.
+ axis: Int specifying the channels axis. `axis=-1` corresponds to data
+ format `channels_last`, and `axis=1` corresponds to data format
+ `channels_first`.
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if `axis` is neither -1 nor one of the axes of `output`.
"""
+ rank = len(output.get_shape())
+ axis = axis % rank
# Note: nn.softmax_cross_entropy_with_logits_v2
# expects logits, Keras expects probabilities.
if not from_logits:
# scale preds so that the class probas of each sample sum to 1
- output = output / math_ops.reduce_sum( # pylint: disable=g-no-augmented-assignment
- output, len(output.get_shape()) - 1, True)
+ output = output / math_ops.reduce_sum(output, axis, True)
# manual computation of crossentropy
epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
- return -math_ops.reduce_sum(
- target * math_ops.log(output),
- axis=len(output.get_shape()) - 1)
+ return -math_ops.reduce_sum(target * math_ops.log(output), axis)
else:
return nn.softmax_cross_entropy_with_logits_v2(labels=target, logits=output)
@tf_export('keras.backend.sparse_categorical_crossentropy')
-def sparse_categorical_crossentropy(target, output, from_logits=False):
+def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
"""Categorical crossentropy with integer targets.
Arguments:
@@ -3499,10 +3526,22 @@ def sparse_categorical_crossentropy(target, output, from_logits=False):
case `output` is expected to be the logits).
from_logits: Boolean, whether `output` is the
result of a softmax, or is a tensor of logits.
+ axis: Int specifying the channels axis. `axis=-1` corresponds to data
+ format `channels_last`, and `axis=1` corresponds to data format
+ `channels_first`.
Returns:
Output tensor.
+
+ Raises:
+ ValueError: if `axis` is neither -1 nor one of the axes of `output`.
"""
+ rank = len(output.get_shape())
+ axis = axis % rank
+ if axis != rank - 1:
+ permutation = list(range(axis)) + list(range(axis + 1, rank)) + [axis]
+ output = array_ops.transpose(output, perm=permutation)
+
# Note: nn.sparse_softmax_cross_entropy_with_logits
# expects logits, Keras expects probabilities.
if not from_logits:
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 36478ea089..40e7910061 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -23,6 +23,7 @@ import scipy.sparse
from tensorflow.python import keras
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -490,6 +491,66 @@ class BackendLinearAlgebraTest(test.TestCase):
input_shape_a=(4, 7),
input_shape_b=(4, 7))
+ def test_relu(self):
+ x = ops.convert_to_tensor([[-4, 0], [2, 7]], 'float32')
+ with self.test_session():
+ # standard relu
+ relu_op = keras.backend.relu(x)
+ self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]])
+
+ # alpha
+ relu_op = keras.backend.relu(x, alpha=0.5)
+ self.assertAllClose(keras.backend.eval(relu_op), [[-2, 0], [2, 7]])
+
+ # max_value < some elements
+ relu_op = keras.backend.relu(x, max_value=5)
+ self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 5]])
+
+ # nn.relu6 used
+ relu_op = keras.backend.relu(x, max_value=6)
+ self.assertTrue('Relu6' in relu_op.name) # uses tf.nn.relu6
+ self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 6]])
+
+ # max value > 6
+ relu_op = keras.backend.relu(x, max_value=10)
+ self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]])
+
+ # max value is float
+ relu_op = keras.backend.relu(x, max_value=4.3)
+ self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 4.3]])
+
+ # max value == 0
+ relu_op = keras.backend.relu(x, max_value=0)
+ self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [0, 0]])
+
+ # alpha and max_value
+ relu_op = keras.backend.relu(x, alpha=0.25, max_value=3)
+ self.assertAllClose(keras.backend.eval(relu_op), [[-1, 0], [2, 3]])
+
+ # threshold
+ relu_op = keras.backend.relu(x, threshold=3)
+ self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [0, 7]])
+
+ # threshold is float
+ relu_op = keras.backend.relu(x, threshold=1.5)
+ self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]])
+
+ # threshold is negative
+ relu_op = keras.backend.relu(x, threshold=-5)
+ self.assertAllClose(keras.backend.eval(relu_op), [[-4, 0], [2, 7]])
+
+ # threshold and max_value
+ relu_op = keras.backend.relu(x, threshold=3, max_value=5)
+ self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [0, 5]])
+
+ # threshold and alpha
+ relu_op = keras.backend.relu(x, alpha=0.25, threshold=4)
+ self.assertAllClose(keras.backend.eval(relu_op), [[-2, -1], [-0.5, 7]])
+
+ # threshold, alpha, and max_value
+ relu_op = keras.backend.relu(x, alpha=0.25, threshold=4, max_value=5)
+ self.assertAllClose(keras.backend.eval(relu_op), [[-2, -1], [-0.5, 5]])
+
class BackendShapeOpsTest(test.TestCase):
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 5d66db232a..d1b9dc27bd 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -31,13 +31,16 @@ import time
import numpy as np
import six
+from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
-from tensorflow.python.keras import optimizers
+from tensorflow.python.keras.engine.training_utils import standardize_input_data
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as tf_summary
+from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import tf_export
@@ -644,35 +647,17 @@ class LearningRateScheduler(Callback):
self.verbose = verbose
def on_epoch_begin(self, epoch, logs=None):
- # TODO(yashkatariya): Change the property checking when the learning
- # rate attribute is unified across all TF Optimizers.
- if isinstance(self.model.optimizer, optimizers.TFOptimizer):
- if not hasattr(self.model.optimizer.optimizer, '_lr') and not hasattr(
- self.model.optimizer.optimizer, '_learning_rate'):
- raise ValueError(
- 'TF Optimizer must have a "_lr" or "_learning_rate" attribute.')
- else:
- opt = self.model.optimizer.optimizer
- if hasattr(opt, '_lr'):
- opt_lr = Variable(opt._lr) # pylint: disable=protected-access
- elif hasattr(opt, '_learning_rate'):
- opt_lr = Variable(opt._learning_rate) # pylint: disable=protected-access
- else:
- if not hasattr(self.model.optimizer, 'lr'):
- raise ValueError('Optimizer must have a "lr" attribute.')
- else:
- opt = self.model.optimizer
- opt_lr = opt.lr
-
+ if not hasattr(self.model.optimizer, 'lr'):
+ raise ValueError('Optimizer must have a "lr" attribute.')
try: # new API
- lr = float(K.get_value(opt_lr))
+ lr = float(K.get_value(self.model.optimizer.lr))
lr = self.schedule(epoch, lr)
except TypeError: # Support for old API for backward compatibility
lr = self.schedule(epoch)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function '
'should be float.')
- K.set_value(opt_lr, lr)
+ K.set_value(self.model.optimizer.lr, lr)
if self.verbose > 0:
print('\nEpoch %05d: LearningRateScheduler reducing learning '
'rate to %s.' % (epoch + 1, lr))
@@ -717,7 +702,9 @@ class TensorBoard(Callback):
write_images: whether to write model weights to visualize as
image in TensorBoard.
embeddings_freq: frequency (in epochs) at which selected embedding
- layers will be saved.
+ layers will be saved. If set to 0, embeddings won't be computed.
+ Data to be visualized in TensorBoard's Embedding tab must be passed
+ as `embeddings_data`.
embeddings_layer_names: a list of names of layers to keep eye on. If
None or empty list all the embedding layer will be watched.
embeddings_metadata: a dictionary which maps layer name to a file name
@@ -725,6 +712,10 @@ class TensorBoard(Callback):
[details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
about metadata files format. In case if the same metadata file is
used for all embedding layers, string can be passed.
+ embeddings_data: data to be embedded at layers specified in
+ `embeddings_layer_names`. Numpy array (if the model has a single
+ input) or list of Numpy arrays (if the model has multiple inputs).
+ Learn [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding)
"""
# pylint: enable=line-too-long
@@ -735,7 +726,11 @@ class TensorBoard(Callback):
batch_size=32,
write_graph=True,
write_grads=False,
- write_images=False):
+ write_images=False,
+ embeddings_freq=0,
+ embeddings_layer_names=None,
+ embeddings_metadata=None,
+ embeddings_data=None):
super(TensorBoard, self).__init__()
self.log_dir = log_dir
self.histogram_freq = histogram_freq
@@ -745,8 +740,13 @@ class TensorBoard(Callback):
self.write_images = write_images
self.batch_size = batch_size
self._current_batch = 0
+ self._total_batches_seen = 0
# abstracted writer class to be able to stub for testing
self._writer_class = tf_summary.FileWriter
+ self.embeddings_freq = embeddings_freq
+ self.embeddings_layer_names = embeddings_layer_names
+ self.embeddings_metadata = embeddings_metadata
+ self.embeddings_data = embeddings_data
def set_model(self, model):
"""Sets Keras model and creates summary ops."""
@@ -798,7 +798,11 @@ class TensorBoard(Callback):
tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
if hasattr(layer, 'output'):
- tf_summary.histogram('{}_out'.format(layer.name), layer.output)
+ if isinstance(layer.output, list):
+ for i, output in enumerate(layer.output):
+ tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
+ else:
+ tf_summary.histogram('{}_out'.format(layer.name), layer.output)
self.merged = tf_summary.merge_all()
if self.write_graph:
@@ -806,12 +810,98 @@ class TensorBoard(Callback):
else:
self.writer = self._writer_class(self.log_dir)
+ # If both embeddings_freq and embeddings_data are available, we will
+ # visualize embeddings.
+ if self.embeddings_freq and self.embeddings_data is not None:
+ self.embeddings_data = standardize_input_data(self.embeddings_data,
+ model.input_names)
+
+ # If embeddings_layer_names is not provided, get all of the embedding
+ # layers from the model.
+ embeddings_layer_names = self.embeddings_layer_names
+ if not embeddings_layer_names:
+ embeddings_layer_names = [
+ layer.name
+ for layer in self.model.layers
+ if type(layer).__name__ == 'Embedding'
+ ]
+
+ self.assign_embeddings = []
+ embeddings_vars = {}
+
+ self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
+ self.step = step = array_ops.placeholder(dtypes.int32)
+
+ for layer in self.model.layers:
+ if layer.name in embeddings_layer_names:
+ embedding_input = self.model.get_layer(layer.name).output
+ embedding_size = np.prod(embedding_input.shape[1:])
+ embedding_input = array_ops.reshape(embedding_input,
+ (step, int(embedding_size)))
+ shape = (self.embeddings_data[0].shape[0], int(embedding_size))
+ embedding = variables.Variable(
+ array_ops.zeros(shape), name=layer.name + '_embedding')
+ embeddings_vars[layer.name] = embedding
+ batch = state_ops.assign(embedding[batch_id:batch_id + step],
+ embedding_input)
+ self.assign_embeddings.append(batch)
+
+ self.saver = saver.Saver(list(embeddings_vars.values()))
+
+ # Create embeddings_metadata dictionary
+ if isinstance(self.embeddings_metadata, str):
+ embeddings_metadata = {
+ layer_name: self.embeddings_metadata
+ for layer_name in embeddings_vars.keys()
+ }
+ else:
+ # If embeddings_metadata is already a dictionary
+ embeddings_metadata = self.embeddings_metadata
+
+ try:
+ from tensorboard.plugins import projector
+ except ImportError:
+ raise ImportError('Failed to import TensorBoard. Please make sure that '
+ 'TensorBoard integration is complete."')
+
+ # TODO(psv): Add integration tests to test embedding visualization
+ # with TensorBoard callback. We are unable to write a unit test for this
+ # because TensorBoard dependency assumes TensorFlow package is installed.
+ config = projector.ProjectorConfig()
+ for layer_name, tensor in embeddings_vars.items():
+ embedding = config.embeddings.add()
+ embedding.tensor_name = tensor.name
+
+ if (embeddings_metadata is not None and
+ layer_name in embeddings_metadata):
+ embedding.metadata_path = embeddings_metadata[layer_name]
+
+ projector.visualize_embeddings(self.writer, config)
+
def _fetch_callback(self, summary):
self.writer.add_summary(
summary,
self._epoch + self._current_val_batch / self._validation_batches)
self._current_val_batch += 1
+ def _write_custom_summaries(self, step, logs=None):
+ """Writes metrics out as custom scalar summaries.
+
+ Arguments:
+ step: the global step to use for Tensorboard.
+ logs: dict. Keys are scalar summary names, values are
+ NumPy scalars.
+
+ """
+ logs = logs or {}
+ for name, value in logs.items():
+ summary = tf_summary.Summary()
+ summary_value = summary.value.add()
+ summary_value.simple_value = value.item()
+ summary_value.tag = name
+ self.writer.add_summary(summary, step)
+ self.writer.flush()
+
def on_train_begin(self, logs=None):
"""Checks if histogram summaries can be run."""
@@ -828,6 +918,16 @@ class TensorBoard(Callback):
raise ValueError(
'If printing histograms, validation data must have length > 0.')
+ def on_batch_end(self, batch, logs=None):
+ """Writes scalar summaries for metrics on every training batch."""
+ # Don't output batch_size and batch number as Tensorboard summaries
+ logs = logs or {}
+ batch_logs = {('batch_' + k): v
+ for k, v in logs.items()
+ if k not in ['batch', 'size']}
+ self._write_custom_summaries(self._total_batches_seen, batch_logs)
+ self._total_batches_seen += 1
+
def on_epoch_begin(self, epoch, logs=None):
"""Add histogram op to Model test_function callbacks, reset batch count."""
@@ -844,7 +944,12 @@ class TensorBoard(Callback):
def on_epoch_end(self, epoch, logs=None):
"""Checks if summary ops should run next epoch, logs scalar summaries."""
- logs = logs or {}
+ # don't output batch_size and
+ # batch number as Tensorboard summaries
+ logs = {('epoch_' + k): v
+ for k, v in logs.items()
+ if k not in ['batch', 'size']}
+ self._write_custom_summaries(epoch, logs)
# pop the histogram summary op after each epoch
if self.histogram_freq:
@@ -853,15 +958,45 @@ class TensorBoard(Callback):
if self.merged in self.model.test_function.fetch_callbacks:
self.model.test_function.fetch_callbacks.pop(self.merged)
- for name, value in logs.items():
- if name in ['batch', 'size']:
- continue
- summary = tf_summary.Summary()
- summary_value = summary.value.add()
- summary_value.simple_value = value.item()
- summary_value.tag = name
- self.writer.add_summary(summary, epoch)
- self.writer.flush()
+ if self.embeddings_data is None and self.embeddings_freq:
+ raise ValueError('To visualize embeddings, embeddings_data must '
+ 'be provided.')
+
+ if self.embeddings_freq and self.embeddings_data is not None:
+ if epoch % self.embeddings_freq == 0:
+ # We need a second forward-pass here because we're passing
+ # the `embeddings_data` explicitly. This design allows passing
+ # arbitrary data as `embeddings_data` and results from the fact
+ # that we need to know the size of the `tf.Variable`s which
+ # hold the embeddings in `set_model`. At this point, however,
+ # the `validation_data` is not yet set.
+
+ embeddings_data = self.embeddings_data
+ n_samples = embeddings_data[0].shape[0]
+ i = 0
+ while i < n_samples:
+ step = min(self.batch_size, n_samples - i)
+ batch = slice(i, i + step)
+
+ if isinstance(self.model.input, list):
+ feed_dict = {
+ model_input: embeddings_data[idx][batch]
+ for idx, model_input in enumerate(self.model.input)
+ }
+ else:
+ feed_dict = {self.model.input: embeddings_data[0][batch]}
+
+ feed_dict.update({self.batch_id: i, self.step: step})
+
+ if self.model.uses_learning_phase:
+ feed_dict[K.learning_phase()] = False
+
+ self.sess.run(self.assign_embeddings, feed_dict=feed_dict)
+ self.saver.save(self.sess,
+ os.path.join(self.log_dir, 'keras_embedding.ckpt'),
+ epoch)
+
+ i += self.batch_size
def on_train_end(self, logs=None):
self.writer.close()
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 244d48591c..7d830078ce 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -29,16 +29,10 @@ import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
-from tensorflow.python.eager import context
-from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
-from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training.adam import AdamOptimizer
-from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
-
try:
import h5py # pylint:disable=g-import-not-at-top
@@ -376,76 +370,6 @@ class KerasCallbacksTest(test.TestCase):
float(keras.backend.get_value(
model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon()
- @test_util.run_in_graph_and_eager_modes
- def test_TF_LearningRateScheduler_Adam(self):
- with self.test_session():
- with context.eager_mode():
- np.random.seed(1337)
- (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
- train_samples=TRAIN_SAMPLES,
- test_samples=TEST_SAMPLES,
- input_shape=(INPUT_DIM,),
- num_classes=NUM_CLASSES)
- y_test = keras.utils.to_categorical(y_test)
- y_train = keras.utils.to_categorical(y_train)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
- model.compile(
- loss='categorical_crossentropy',
- optimizer=AdamOptimizer(),
- metrics=['accuracy'])
- cbks = [keras.callbacks.LearningRateScheduler(lambda x: 1. / (1. + x))]
- model.fit(
- x_train,
- y_train,
- batch_size=BATCH_SIZE,
- validation_data=(x_test, y_test),
- callbacks=cbks,
- epochs=5,
- verbose=0)
- opt_lr = model.optimizer.optimizer._lr
- self.assertLess(
- float(keras.backend.get_value(
- Variable(opt_lr))) - 0.2, keras.backend.epsilon())
-
- @test_util.run_in_graph_and_eager_modes
- def test_TF_LearningRateScheduler_GradientDescent(self):
- with self.test_session():
- with context.eager_mode():
- np.random.seed(1337)
- (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
- train_samples=TRAIN_SAMPLES,
- test_samples=TEST_SAMPLES,
- input_shape=(INPUT_DIM,),
- num_classes=NUM_CLASSES)
- y_test = keras.utils.to_categorical(y_test)
- y_train = keras.utils.to_categorical(y_train)
- model = keras.models.Sequential()
- model.add(
- keras.layers.Dense(
- NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
- model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
- model.compile(
- loss='categorical_crossentropy',
- optimizer=GradientDescentOptimizer(1e-3),
- metrics=['accuracy'])
- cbks = [keras.callbacks.LearningRateScheduler(lambda x: 1. / (1. + x))]
- model.fit(
- x_train,
- y_train,
- batch_size=BATCH_SIZE,
- validation_data=(x_test, y_test),
- callbacks=cbks,
- epochs=5,
- verbose=0)
- opt_lr = model.optimizer.optimizer._learning_rate
- self.assertLess(
- float(keras.backend.get_value(
- Variable(opt_lr))) - 0.2, keras.backend.epsilon())
-
def test_ReduceLROnPlateau(self):
with self.test_session():
np.random.seed(1337)
@@ -1172,6 +1096,74 @@ class KerasCallbacksTest(test.TestCase):
assert os.path.exists(temp_dir)
+ def test_Tensorboard_batch_logging(self):
+
+ class FileWriterStub(object):
+
+ def __init__(self, logdir, graph=None):
+ self.logdir = logdir
+ self.graph = graph
+ self.batches_logged = []
+ self.summary_values = []
+ self.summary_tags = []
+
+ def add_summary(self, summary, step):
+ self.summary_values.append(summary.value[0].simple_value)
+ self.summary_tags.append(summary.value[0].tag)
+ self.batches_logged.append(step)
+
+ def flush(self):
+ pass
+
+ def close(self):
+ pass
+
+ logdir = 'fake_dir'
+
+ # log every batch
+ tb_cbk = keras.callbacks.TensorBoard(logdir)
+ tb_cbk.writer = FileWriterStub(logdir)
+
+ for batch in range(5):
+ tb_cbk.on_batch_end(batch, {'acc': np.float32(batch)})
+ self.assertEqual(tb_cbk.writer.batches_logged, [0, 1, 2, 3, 4])
+ self.assertEqual(tb_cbk.writer.summary_values, [0., 1., 2., 3., 4.])
+ self.assertEqual(tb_cbk.writer.summary_tags, ['batch_acc'] * 5)
+
+ def test_Tensorboard_epoch_and_batch_logging(self):
+
+ class FileWriterStub(object):
+
+ def __init__(self, logdir, graph=None):
+ self.logdir = logdir
+ self.graph = graph
+
+ def add_summary(self, summary, step):
+ if 'batch_' in summary.value[0].tag:
+ self.batch_summary = (step, summary)
+ elif 'epoch_' in summary.value[0].tag:
+ self.epoch_summary = (step, summary)
+
+ def flush(self):
+ pass
+
+ def close(self):
+ pass
+
+ logdir = 'fake_dir'
+
+ tb_cbk = keras.callbacks.TensorBoard(logdir)
+ tb_cbk.writer = FileWriterStub(logdir)
+
+ tb_cbk.on_batch_end(0, {'acc': np.float32(5.0)})
+ tb_cbk.on_epoch_end(0, {'acc': np.float32(10.0)})
+ batch_step, batch_summary = tb_cbk.writer.batch_summary
+ self.assertEqual(batch_step, 0)
+ self.assertEqual(batch_summary.value[0].simple_value, 5.0)
+ epoch_step, epoch_summary = tb_cbk.writer.epoch_summary
+ self.assertEqual(epoch_step, 0)
+ self.assertEqual(epoch_summary.value[0].simple_value, 10.0)
+
def test_RemoteMonitorWithJsonPayload(self):
if requests is None:
self.skipTest('`requests` required to run this test')
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index e02792208b..b41f6ee03b 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -723,9 +723,17 @@ class Layer(checkpointable.CheckpointableBase):
self._dtype = input_list[0].dtype.base_dtype.name
except AttributeError:
pass
+
if all(hasattr(x, 'shape') for x in input_list):
input_shapes = nest.map_structure(lambda x: x.shape, inputs)
- self.build(input_shapes)
+
+ if (not hasattr(self, '_is_graph_network') or
+ self.__class__.__name__ == 'Sequential'):
+ # Only if self is a layer or an instance of a sequential model do we
+ # need to build it.
+ self.build(input_shapes)
+ # We must set self.built since user defined build functions are not
+ # constrained to set self.built.
self.built = True
# Check input assumptions set after layer building, e.g. input shape.
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index a4d96de74f..752e9963ca 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -318,8 +318,8 @@ class Network(base_layer.Layer):
else:
self._expects_training_arg = False
self._call_convention = self._determine_call_convention(call_argspec)
- self.outputs = None
- self.inputs = None
+ self.outputs = []
+ self.inputs = []
self.built = False
def _determine_call_convention(self, call_argspec):
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 371504a503..41cdfda660 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -213,13 +213,31 @@ class Sequential(Model):
self.outputs = [self.layers[-1].output]
self.build()
- @checkpointable.no_automatic_dependency_tracking
def build(self, input_shape=None):
- if input_shape and not self.inputs:
- batch_shape = tuple(input_shape)
+ self._set_inputs_and_outputs(input_shape=input_shape)
+
+ def symbolic_set_inputs(self, inputs):
+ """Sets the model's inputs and outputs from an existing tensor.
+
+ Args:
+ inputs: Tensor forwarded as `tensor` to `_set_inputs_and_outputs`.
+ """
+ self._set_inputs_and_outputs(tensor=inputs)
+
+ @checkpointable.no_automatic_dependency_tracking
+ def _set_inputs_and_outputs(self, input_shape=None, tensor=None):
+ """Set model's input and output specs based on the input received.
+
+ If `tensor` is provided, `input_shape` is not required.
+
+ Args:
+ input_shape: Optional shape of input.
+ tensor: Optional existing tensor to wrap into the `Input` layer.
+ """
+ if not self.inputs:
dtype = K.floatx()
- x = Input(
- batch_shape=batch_shape, dtype=dtype, name=self.name + '_input')
+ if tensor is not None:
+ batch_shape = (None,) + tuple(tensor.get_shape().as_list()[1:])
+ x = Input(dtype=dtype, name=self.name + '_input', tensor=tensor)
+ elif input_shape is not None:
+ batch_shape = tuple(input_shape)
+ x = Input(
+ batch_shape=batch_shape, dtype=dtype, name=self.name + '_input')
self.inputs = [x]
for layer in self._layers:
x = layer(x)
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py
index 0f54e29cee..4f4adca333 100644
--- a/tensorflow/python/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/engine/sequential_test.py
@@ -22,7 +22,6 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -104,9 +103,6 @@ class TestSequential(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_sequential_deferred_build_with_dataset_iterators(self):
- if not context.executing_eagerly():
- # TODO(psv/fchollet): Add support for this use case in graph mode.
- return
num_hidden = 5
input_dim = 3
num_classes = 2
@@ -136,6 +132,48 @@ class TestSequential(test.TestCase):
[None, num_classes])
self.assertEqual(len(model.weights), 2 * 2)
+ def test_training_and_eval_methods_on_symbolic_tensors(self):
+ """Runs the training/eval/predict APIs of `Sequential` on backend tensors.
+
+ The model is created without an input shape, and a fresh model is created
+ before every API call. The test makes no assertions; it only requires that
+ none of the calls raises.
+ """
+ with self.test_session():
+
+ def create_model():
+ # No input shape is passed to the first layer.
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(10, activation='relu'))
+ model.add(keras.layers.Dense(4, activation='softmax'))
+
+ model.compile(
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+ return model
+
+ # Tensors created through the backend rather than Numpy arrays
+ # (symbolic, per the test name).
+ inputs = keras.backend.zeros(shape=(10, 3))
+ targets = keras.backend.zeros(shape=(10, 4))
+
+ model = create_model()
+ model.fit(inputs, targets, epochs=10, steps_per_epoch=30)
+
+ model = create_model()
+ model.evaluate(inputs, targets, steps=2, verbose=0)
+
+ model = create_model()
+ model.predict(inputs, steps=2)
+
+ model = create_model()
+ model.train_on_batch(inputs, targets)
+
+ model = create_model()
+ model.test_on_batch(inputs, targets)
+
+ # Fit again, this time with tensor validation data and validation steps.
+ model = create_model()
+ model.fit(
+ inputs,
+ targets,
+ epochs=1,
+ steps_per_epoch=2,
+ verbose=0,
+ validation_data=(inputs, targets),
+ validation_steps=2)
+
@tf_test_util.run_in_graph_and_eager_modes
def test_invalid_use_cases(self):
# Added objects must be layer instances
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index bd03f4871f..4df739254b 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -27,6 +27,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
@@ -43,6 +44,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -217,10 +219,9 @@ class Model(Network):
for name in self.output_names:
if name not in loss:
logging.warning(
- 'Output "' + name + '" missing from loss dictionary. '
- 'We assume this was done on purpose, '
- 'and we will not be expecting '
- 'any data to be passed to "' + name + '" during training.')
+ 'Output "' + name + '" missing from loss dictionary. We assume '
+ 'this was done on purpose. The fit and evaluate APIs will not be '
+ 'expecting any data to be passed to "' + name + '".')
loss_functions.append(losses.get(loss.get(name)))
elif isinstance(loss, list):
if len(loss) != len(self.outputs):
@@ -561,6 +562,95 @@ class Model(Network):
trainable_weights = self.trainable_weights
self._collected_trainable_weights = trainable_weights
+ def build(self, input_shape):
+ """Build the model based on input shapes received.
+
+ This is to be used for subclassed models, which do not know at instantiation
+ time what their inputs look like.
+
+ Args:
+ input_shape: Single tuple, TensorShape, or list of shapes, where shapes
+ are tuples, integers, or TensorShapes.
+
+ Raises:
+ ValueError:
+ 1. In case of invalid user-provided data (not of type tuple,
+ list, or TensorShape).
+ 2. If the model requires call arguments that are agnostic
+ to the input shapes (positional or kwarg in call signature).
+ 3. If not all layers were properly built.
+ 4. If float type inputs are not supported within the layers.
+
+ In each of these cases, the user should build their model by calling it
+ on real tensor data.
+ """
+ # For a graph network this method only marks the model as built.
+ if self._is_graph_network:
+ self.built = True
+ return
+
+ # If subclass network
+ if input_shape is None:
+ raise ValueError('Input shape must be defined when calling build on a '
+ 'model subclass network.')
+ valid_types = (tuple, list, tensor_shape.TensorShape)
+ if not isinstance(input_shape, valid_types):
+ raise ValueError('Specified input shape is not one of the valid types. '
+ 'Please specify a batch input shape of type tuple or '
+ 'list of input shapes. User provided '
+ 'input type: {}'.format(type(input_shape)))
+
+ # Helper: returns a tensor of ones (dtype `K.floatx()`) with the given
+ # shape, where every unknown (None) dimension is replaced by 1.
+ def _generate_dummy_data_from_shape(shape):
+ if isinstance(shape, tensor_shape.TensorShape):
+ shape = shape.as_list()
+
+ # Replace Nones in input shape with dummy `1` value
+ shape = [x.value if isinstance(x, tensor_shape.Dimension) else x
+ for x in shape]
+ shape = [1 if x is None else x for x in shape]
+ return array_ops.ones(shape, dtype=K.floatx())
+
+ # Dummy data is generated when a non-empty shape was given and the model
+ # has no inputs set yet.
+ if input_shape and not self.inputs:
+ if isinstance(input_shape, list):
+ # List of input shapes
+ x = [_generate_dummy_data_from_shape(shape) for shape in input_shape]
+ else:
+ x = _generate_dummy_data_from_shape(input_shape)
+
+ kwargs = {}
+ # NOTE(review): the comments below imply that the argument count includes
+ # `self` -- confirm against `tf_inspect.getargspec`.
+ num_call_args = len(tf_inspect.getargspec(self.call).args)
+ if self._expects_training_arg and num_call_args == 3:
+ # Has call signature of call(self, input, training)
+ kwargs['training'] = False
+ elif num_call_args > 2:
+ # Has invalid call signature of call(self, input, *args, **kwargs)
+ raise ValueError('Currently, you cannot build your model if it has '
+ 'positional or keyword arguments that are not '
+ 'inputs to the model, but are required for its '
+ '`call` method. Instead, in order to instantiate '
+ 'and build your model, `call` your model on real '
+ 'tensor data with all expected call arguments.')
+
+ # Run `call` on the dummy data; the return value is discarded
+ # (presumably the call is made so that the layers get built).
+ try:
+ self.call(x, **kwargs)
+ except (errors.InvalidArgumentError, TypeError):
+ raise ValueError('You cannot build your model by calling `build` '
+ 'if your layers do not support float type inputs. '
+ 'Instead, in order to instantiate and build your '
+ 'model, `call` your model on real tensor data (of '
+ 'the correct dtype).')
+
+ if self._layers:
+ self._track_layers(self._layers)
+ # Any layer that is still unbuilt at this point makes the build fail.
+ if self.layers:
+ for layer in self.layers:
+ if not layer.built:
+ raise ValueError('Layer: {} was not built in your model. Calling '
+ '`build` manually on a subclassed model is only '
+ 'allowed for models with a static topology. '
+ 'In this case, you can build your model by '
+ 'calling it on real tensor data.'.format(layer))
+ self.built = True
+
def _check_trainable_weights_consistency(self):
"""Check trainable weights count consistency.
@@ -897,7 +987,11 @@ class Model(Network):
for output_shape, loss_fn in zip(self._feed_output_shapes,
self._feed_loss_fns):
if loss_fn is losses.sparse_categorical_crossentropy:
- feed_output_shapes.append(output_shape[:-1] + (1,))
+ if K.image_data_format() == 'channels_first':
+ feed_output_shapes.append(
+ (output_shape[0], 1) + output_shape[2:])
+ else:
+ feed_output_shapes.append(output_shape[:-1] + (1,))
elif (not hasattr(loss_fn, '__name__') or
getattr(losses, loss_fn.__name__, None) is None):
# If `loss_fn` is not a function (e.g. callable class)
@@ -988,10 +1082,14 @@ class Model(Network):
inputs = inputs[0]
if tensor_util.is_tensor(inputs):
- input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
+ if context.executing_eagerly():
+ input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
+ self.build(input_shape=input_shape)
+ else:
+ self.symbolic_set_inputs(inputs)
else:
input_shape = (None,) + inputs.shape[1:]
- self.build(input_shape=input_shape)
+ self.build(input_shape=input_shape)
elif context.executing_eagerly():
self._eager_set_inputs(inputs)
else:
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index c78684c9f4..397de42985 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -34,7 +34,6 @@ from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import generic_utils
-from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
@@ -194,7 +193,8 @@ def iterator_fit_loop(model,
callbacks=None,
callback_metrics=None,
validation_steps=None,
- do_validation=False):
+ do_validation=False,
+ batch_size=None):
"""Fit function for eager execution when input is given as dataset iterator.
Updates the given epoch logs.
@@ -224,16 +224,23 @@ def iterator_fit_loop(model,
validation_steps: Number of steps to run validation for (only if doing
validation from data tensors). Ignored with default value of `None`.
do_validation: Boolean value indicating whether we should do validation.
+ batch_size: int, val_inputs and val_targets will be evaluated batch by
+ batch with size batch_size if they are arrays.
Raises:
ValueError: In case of mismatch between given number of inputs and
expectations of the model.
"""
assert isinstance(inputs, iterator_ops.EagerIterator)
+
+ # make sure either x,y or x,y,sample_weights is provided
+ if (not isinstance(inputs.output_shapes, (list, tuple)) or
+ len(inputs.output_shapes) not in (2, 3)):
+ # The two literals below are concatenated, so the first one must end with
+ # a space to keep the words of the message apart.
+ raise ValueError('Please provide either inputs and targets '
+ 'or inputs, targets, and sample_weights')
+
for step_index in range(steps_per_epoch):
- batch_logs = {}
- batch_logs['batch'] = step_index
- batch_logs['size'] = 1
+ batch_logs = {'batch': step_index, 'size': 1}
callbacks.on_batch_begin(step_index, batch_logs)
# Get data from the iterator.
@@ -247,19 +254,21 @@ def iterator_fit_loop(model,
'batches (in this case, %d batches).' % steps_per_epoch * epochs)
break
- if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError('Please provide data as a list or tuple of 2 elements '
- ' - input and target pair. Received %s' % next_element)
- x, y = next_element
+ if len(inputs.output_shapes) == 2:
+ x, y = next_element
+ sample_weights = None
+ else:
+ x, y, sample_weights = next_element
# Validate and standardize data.
x, y, sample_weights = model._standardize_user_data(
- x, y, class_weight=class_weight)
+ x, y, sample_weight=sample_weights, class_weight=class_weight)
x = training_utils.cast_if_floating_dtype(x)
y = training_utils.cast_if_floating_dtype(y)
if sample_weights:
sample_weights = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
+ training_utils.cast_if_floating_dtype(
+ ops.convert_to_tensor(val, dtype=backend.floatx()))
if val is not None else None for val in sample_weights
]
@@ -307,122 +316,8 @@ def iterator_fit_loop(model,
val_targets,
sample_weights=val_sample_weights,
steps=validation_steps,
- verbose=0)
- if not isinstance(val_outs, list):
- val_outs = [val_outs]
- # Same labels assumed.
- for l, o in zip(out_labels, val_outs):
- epoch_logs['val_' + l] = o
-
-
-def batch_fit_loop(model,
- inputs,
- targets,
- epoch_logs,
- index_array,
- out_labels,
- callback_model,
- batch_size,
- sample_weights=None,
- val_inputs=None,
- val_targets=None,
- val_sample_weights=None,
- callbacks=None,
- shuffle=True,
- num_train_samples=None,
- do_validation=False):
- """Fit function for eager execution when input is given as arrays or tensors.
-
- Updates the given epoch logs.
-
- Arguments:
- model: Instance of the `Model`.
- inputs: List of input arrays.
- targets: List of target arrays.
- epoch_logs: Dictionary of logs from every epoch.
- index_array: Index array generated from number of training samples.
- out_labels: Output labels generated from model metric names.
- callback_model: Instance of `Model` to callback.
- batch_size: Integer batch size or None if unknown.
- sample_weights: Optional list of sample weight arrays.
- val_inputs: Input data for validation.
- val_targets: Target data for validation.
- val_sample_weights: Sample weight data for validation.
- callbacks: List of callbacks to be called during training.
- shuffle: Whether to shuffle the data at the beginning of each epoch.
- num_train_samples: Integer number of training samples.
- do_validation: Boolean value indicating whether we should do validation.
- """
- # TODO(psv): Create a dataset iterator instead of manually creating batches
- # here and in batch_test_loop, batch_predict_loop.
- if shuffle == 'batch':
- index_array = model._batch_shuffle(index_array, batch_size)
- elif shuffle:
- np.random.shuffle(index_array)
-
- batches = generic_utils.make_batches(num_train_samples, batch_size)
-
- for batch_index, (batch_start, batch_end) in enumerate(batches):
- batch_ids = index_array[batch_start:batch_end]
- inputs_batch = slice_arrays(inputs, batch_ids, contiguous=not shuffle)
- targets_batch = slice_arrays(targets, batch_ids, contiguous=not shuffle)
- if sample_weights:
- sample_weights_batch = slice_arrays(
- sample_weights, batch_ids, contiguous=not shuffle)
- else:
- sample_weights_batch = None
- batch_logs = {}
- batch_logs['batch'] = batch_index
- batch_logs['size'] = len(batch_ids)
-
- callbacks.on_batch_begin(batch_index, batch_logs)
-
- inputs_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- for val in inputs_batch
- ]
- targets_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- for val in targets_batch
- ]
- if sample_weights:
- sample_weights_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- if val is not None else None for val in sample_weights_batch
- ]
-
- outs, loss, loss_metrics = _process_single_batch(
- model,
- inputs_batch,
- targets_batch,
- sample_weights=sample_weights_batch,
- training=True)
-
- if not isinstance(outs, list):
- outs = [outs]
-
- for l, o in zip(out_labels, outs):
- batch_logs[l] = o
- # Required for eager execution
- metrics_results = _eager_metrics_fn(model, outs, targets_batch)
- batch_logs['loss'] = tensor_util.constant_value(backend.mean(loss))
-
- for k, v in zip(model.metrics_names,
- [backend.mean(loss)] + loss_metrics + metrics_results):
- batch_logs[k] = tensor_util.constant_value(v)
- callbacks.on_batch_end(batch_index, batch_logs)
- if callback_model.stop_training:
- break
-
- if batch_index == len(batches) - 1: # Last batch.
- if do_validation:
- val_outs = test_loop(
- model,
- val_inputs,
- val_targets,
- sample_weights=val_sample_weights,
- batch_size=batch_size,
- verbose=0)
+ verbose=0,
+ batch_size=batch_size)
if not isinstance(val_outs, list):
val_outs = [val_outs]
# Same labels assumed.
@@ -451,6 +346,11 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
expectations of the model.
"""
assert isinstance(inputs, iterator_ops.EagerIterator)
+ # make sure either x,y or x,y,sample_weights is provided
+ if (not isinstance(inputs.output_shapes, (list, tuple)) or
+ len(inputs.output_shapes) < 2 or len(inputs.output_shapes) > 3):
+ # The two literals below are concatenated, so the first one must end with
+ # a space to keep the words of the message apart.
+ raise ValueError('Please provide either inputs and targets '
+ 'or inputs, targets, and sample_weights')
outs = []
num_samples = 0
if verbose == 1:
@@ -466,10 +366,11 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
'(in this case, %d batches).', steps)
break
- if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError('Please provide data as a list or tuple of 2 elements '
- ' - input and target pair. Received %s' % next_element)
- x, y = next_element
+ if len(inputs.output_shapes) == 2:
+ x, y = next_element
+ sample_weights = None
+ else:
+ x, y, sample_weights = next_element
# Validate and standardize data.
x, y, sample_weights = model._standardize_user_data(x, y)
@@ -512,94 +413,6 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
return outs
-def batch_test_loop(model,
- inputs,
- targets,
- batch_size,
- sample_weights=None,
- verbose=0):
- """Test function for eager execution when input is given as arrays or tensors.
-
- Arguments:
- model: Model instance that is being evaluated in Eager mode.
- inputs: List of input arrays.
- targets: List of target arrays.
- batch_size: Integer batch size.
- sample_weights: Optional list of sample weight arrays.
- verbose: Verbosity mode.
-
- Returns:
- Scalar loss (if the model has a single output and no metrics)
- or list of scalars (if the model has multiple outputs
- and/or metrics). The attribute `model.metrics_names` will give you
- the display labels for the scalar outputs.
- """
- outs = []
- feed_data = inputs + targets
- if sample_weights:
- feed_data += sample_weights
- num_samples = training_utils.check_num_samples(
- feed_data, batch_size=batch_size)
- if verbose == 1:
- progbar = generic_utils.Progbar(target=num_samples)
- batches = generic_utils.make_batches(num_samples, batch_size)
- index_array = np.arange(num_samples)
- for batch_index, (batch_start, batch_end) in enumerate(batches):
- batch_ids = index_array[batch_start:batch_end]
- inputs_batch = slice_arrays(inputs, batch_ids)
- targets_batch = slice_arrays(targets, batch_ids)
- if sample_weights:
- sample_weights_batch = slice_arrays(sample_weights, batch_ids)
- else:
- sample_weights_batch = None
-
- inputs_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- for val in inputs_batch
- ]
- targets_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- for val in targets_batch
- ]
- if sample_weights:
- sample_weights_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- if val is not None else None for val in sample_weights_batch
- ]
-
- loss_outs, loss, loss_metrics = _model_loss(
- model,
- inputs_batch,
- targets_batch,
- sample_weights=sample_weights_batch,
- training=False)
- metrics_results = _eager_metrics_fn(model, loss_outs, targets_batch)
- batch_outs = []
- for _, v in zip(model.metrics_names,
- [backend.mean(loss)] + loss_metrics + metrics_results):
- batch_outs.append(tensor_util.constant_value(v))
-
- if isinstance(batch_outs, list):
- if batch_index == 0:
- for _ in enumerate(batch_outs):
- outs.append(0.)
- for i, batch_out in enumerate(batch_outs):
- outs[i] += batch_out * len(batch_ids)
- else:
- if batch_index == 0:
- outs.append(0.)
- outs[0] += batch_outs * len(batch_ids)
-
- if verbose == 1:
- progbar.update(batch_end)
-
- for i in range(len(outs)):
- outs[i] /= num_samples
- if len(outs) == 1:
- return outs[0]
- return outs
-
-
def iterator_predict_loop(model, inputs, steps, verbose=0):
"""Predict function for eager execution when input is dataset iterator.
@@ -619,6 +432,12 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
expectations of the model.
"""
assert isinstance(inputs, iterator_ops.EagerIterator)
+ if not isinstance(inputs.output_shapes,
+ (list, tuple)) or len(inputs.output_shapes) > 2:
+ raise ValueError(
+ 'Please provide data as a list or tuple of 1 or 2 elements '
+ ' - input or input and target pair. Received %s. We do not use the '
+ '`target` value here.' % inputs.output_shapes)
outs = []
if verbose == 1:
progbar = generic_utils.Progbar(target=steps)
@@ -634,12 +453,8 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
'batches (in this case, %d batches).', steps)
break
- if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError(
- 'Please provide data as a list or tuple of 2 elements '
- ' - input and target pair. Received %s. We do not use the '
- '`target` value here.' % next_element)
- x, _ = next_element
+ # expects a tuple, where first element of tuple represents inputs
+ x = next_element[0]
# Validate and standardize data.
x, _, _ = model._standardize_user_data(x)
@@ -670,99 +485,6 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
return outs
-def batch_predict_loop(model, inputs, batch_size, verbose=0):
- """Predict function for eager execution when input is arrays or tensors.
-
- Arguments:
- model: Instance of `Model`.
- inputs: List of input arrays.
- batch_size: Integer batch size.
- verbose: Verbosity mode.
-
- Returns:
- Array of predictions (if the model has a single output)
- or list of arrays of predictions (if the model has multiple outputs).
- """
- outs = []
- num_samples = training_utils.check_num_samples(inputs, batch_size)
- if verbose == 1:
- progbar = generic_utils.Progbar(target=num_samples)
- batches = generic_utils.make_batches(num_samples, batch_size)
- index_array = np.arange(num_samples)
- for batch_index, (batch_start, batch_end) in enumerate(batches):
- batch_ids = index_array[batch_start:batch_end]
- inputs_batch = slice_arrays(inputs, batch_ids)
-
- inputs_batch = [
- ops.convert_to_tensor(val, dtype=backend.floatx())
- for val in inputs_batch
- ]
-
- if len(inputs_batch) == 1:
- if model._expects_training_arg:
- batch_outs = model.call(inputs_batch[0], training=False)
- else:
- batch_outs = model.call(inputs_batch[0])
- else:
- if model._expects_training_arg:
- batch_outs = model.call(inputs_batch, training=False)
- else:
- batch_outs = model.call(inputs_batch)
-
- if not isinstance(batch_outs, list):
- batch_outs = [batch_outs]
- if batch_index == 0:
- # Pre-allocate the results arrays.
- for batch_out in batch_outs:
- dims = batch_out.shape[1:].dims
- dims_list = [d.value for d in dims]
- shape = (num_samples,) + tuple(dims_list)
- outs.append(np.zeros(shape, dtype=batch_out.dtype.as_numpy_dtype))
- for i, batch_out in enumerate(batch_outs):
- outs[i][batch_start:batch_end] = batch_out
- if verbose == 1:
- progbar.update(batch_end)
-
- if len(outs) == 1:
- return outs[0]
- return outs
-
-
-def slice_arrays(arrays, indices, contiguous=True):
- """Slices batches out of provided arrays (workaround for eager tensors).
-
- Unfortunately eager tensors don't have the same slicing behavior as
- Numpy arrays (they follow the same slicing behavior as symbolic TF tensors),
- hence we cannot use `generic_utils.slice_arrays` directly
- and we have to implement this workaround based on `concat`. This has a
- performance cost.
-
- Arguments:
- arrays: Single array or list of arrays.
- indices: List of indices in the array that should be included in the output
- batch.
- contiguous: Boolean flag indicating whether the indices are contiguous.
-
- Returns:
- Slice of data (either single array or list of arrays).
- """
- if any(tensor_util.is_tensor(x) for x in arrays):
- converted_to_list = False
- if not isinstance(arrays, list):
- converted_to_list = True
- arrays = [arrays]
- if not contiguous:
- entries = [[x[i:i + 1] for i in indices] for x in arrays]
- slices = [array_ops.concat(x, axis=0) for x in entries]
- else:
- slices = [x[indices[0]:indices[-1] + 1] for x in arrays]
- if converted_to_list:
- slices = slices[0]
- return slices
- else:
- return generic_utils.slice_arrays(arrays, indices)
-
-
def _process_single_batch(model,
inputs,
targets,
@@ -935,19 +657,24 @@ def fit_loop(model,
Raises:
ValueError: In case of invalid argument values.
"""
+ # Convert training inputs to an EagerIterator
+ inputs, steps_per_epoch = training_utils.convert_to_iterator(
+ x=inputs,
+ y=targets,
+ sample_weights=sample_weights,
+ batch_size=batch_size,
+ steps_per_epoch=steps_per_epoch,
+ epochs=epochs,
+ shuffle=shuffle)
# Required for eager execution
with backend.learning_phase_scope(1):
do_validation = False
if val_inputs:
do_validation = True
- if (steps_per_epoch is None and verbose and inputs and
- hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
- print('Train on %d samples, validate on %d samples' %
- (inputs[0].shape[0], val_inputs[0].shape[0]))
num_train_samples = None
out_labels = None
- if steps_per_epoch is None or model._is_compiled:
+ if model._is_compiled:
out_labels = model.metrics_names
if do_validation:
callback_metrics = copy.copy(out_labels) + [
@@ -956,28 +683,10 @@ def fit_loop(model,
else:
callback_metrics = copy.copy(out_labels)
- if steps_per_epoch is None:
- if sample_weights:
- feed_data = inputs + targets + sample_weights
- else:
- feed_data = inputs + targets
- num_train_samples = training_utils.check_num_samples(
- feed_data,
- batch_size=batch_size,
- steps=steps_per_epoch,
- steps_name='steps_per_epoch')
-
- if num_train_samples is not None:
- index_array = np.arange(num_train_samples)
-
model.history = cbks.History()
callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history]
if verbose:
- if steps_per_epoch is not None:
- count_mode = 'steps'
- else:
- count_mode = 'samples'
- callbacks += [cbks.ProgbarLogger(count_mode)]
+ callbacks += [cbks.ProgbarLogger('steps')]
callbacks = cbks.CallbackList(callbacks)
# it's possible to callback a different model than self
@@ -1019,43 +728,24 @@ def fit_loop(model,
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
-
- if steps_per_epoch is not None:
- iterator_fit_loop(
- model,
- inputs,
- class_weight,
- steps_per_epoch=steps_per_epoch,
- callback_model=callback_model,
- out_labels=out_labels,
- epoch_logs=epoch_logs,
- val_inputs=val_inputs,
- val_targets=val_targets,
- val_sample_weights=val_sample_weights,
- epochs=epochs,
- verbose=verbose,
- callbacks=callbacks,
- callback_metrics=callback_metrics,
- validation_steps=validation_steps,
- do_validation=do_validation)
- else:
- batch_fit_loop(
- model,
- inputs,
- targets,
- epoch_logs=epoch_logs,
- index_array=index_array,
- out_labels=out_labels,
- callback_model=callback_model,
- batch_size=batch_size,
- sample_weights=sample_weights,
- val_inputs=val_inputs,
- val_targets=val_targets,
- val_sample_weights=val_sample_weights,
- callbacks=callbacks,
- shuffle=shuffle,
- num_train_samples=num_train_samples,
- do_validation=do_validation)
+ iterator_fit_loop(
+ model,
+ inputs,
+ class_weight,
+ steps_per_epoch=steps_per_epoch,
+ callback_model=callback_model,
+ out_labels=out_labels,
+ epoch_logs=epoch_logs,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ callback_metrics=callback_metrics,
+ validation_steps=validation_steps,
+ do_validation=do_validation,
+ batch_size=batch_size)
callbacks.on_epoch_end(epoch, epoch_logs)
if callback_model.stop_training:
break
@@ -1087,17 +777,14 @@ def test_loop(model, inputs, targets,
and/or metrics). The attribute `model.metrics_names` will give you
the display labels for the scalar outputs.
"""
+ inputs, steps = training_utils.convert_to_iterator(
+ x=inputs,
+ y=targets,
+ sample_weights=sample_weights,
+ batch_size=batch_size,
+ steps_per_epoch=steps)
with backend.learning_phase_scope(0):
- if steps is not None:
- return iterator_test_loop(model, inputs, steps, verbose=verbose)
- else:
- return batch_test_loop(
- model,
- inputs,
- targets,
- batch_size=batch_size,
- sample_weights=sample_weights,
- verbose=verbose)
+ return iterator_test_loop(model, inputs, steps, verbose=verbose)
def predict_loop(model, inputs,
@@ -1121,8 +808,6 @@ def predict_loop(model, inputs,
(if the model has multiple outputs).
"""
with backend.learning_phase_scope(0):
- if steps is not None:
- return iterator_predict_loop(model, inputs, steps, verbose=verbose)
- else:
- return batch_predict_loop(
- model, inputs, batch_size=batch_size, verbose=verbose)
+ inputs, steps = training_utils.convert_to_iterator(
+ x=inputs, batch_size=batch_size, steps_per_epoch=steps)
+ return iterator_predict_loop(model, inputs, steps, verbose=verbose)
diff --git a/tensorflow/python/keras/engine/training_gpu_test.py b/tensorflow/python/keras/engine/training_gpu_test.py
new file mode 100644
index 0000000000..5825ce814f
--- /dev/null
+++ b/tensorflow/python/keras/engine/training_gpu_test.py
@@ -0,0 +1,125 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for training routines."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.layers.convolutional import Conv2D
+from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
+
+
class TrainingGPUTest(test.TestCase):
  """GPU-only checks that Keras losses agree across image data formats."""

  @test_util.run_in_graph_and_eager_modes
  def test_model_with_crossentropy_losses_channels_first(self):
    """Tests use of all crossentropy losses with `channels_first`.

    Tests `sparse_categorical_crossentropy`, `categorical_crossentropy`,
    and `binary_crossentropy`.
    Verifies that evaluate gives the same result with either `channels_first`
    or `channels_last` image_data_format.

    The body only runs when a CUDA GPU is available; otherwise the test
    passes without asserting anything.
    """
    def prepare_simple_model(input_tensor, loss_name, target):
      """Builds and compiles a 1x1-conv model for the requested loss.

      Arguments:
        input_tensor: Keras input tensor the convolution is applied to.
        loss_name: One of the three crossentropy loss names under test.
        target: Label array; used to derive the number of output channels.

      Returns:
        A compiled `keras.models.Model`.
      """
      # The channel axis depends on the data format active at build time.
      axis = 1 if K.image_data_format() == 'channels_first' else -1
      loss = None
      num_channels = None
      activation = None
      if loss_name == 'sparse_categorical_crossentropy':
        loss = lambda y_true, y_pred: K.sparse_categorical_crossentropy(  # pylint: disable=g-long-lambda
            y_true, y_pred, axis=axis)
        # Labels are stored as float32; the filter count must be an integer.
        num_channels = int(np.amax(target)) + 1
        activation = 'softmax'
      elif loss_name == 'categorical_crossentropy':
        loss = lambda y_true, y_pred: K.categorical_crossentropy(  # pylint: disable=g-long-lambda
            y_true, y_pred, axis=axis)
        num_channels = target.shape[axis]
        activation = 'softmax'
      elif loss_name == 'binary_crossentropy':
        loss = lambda y_true, y_pred: K.binary_crossentropy(y_true, y_pred)  # pylint: disable=unnecessary-lambda
        num_channels = target.shape[axis]
        activation = 'sigmoid'
      predictions = Conv2D(num_channels,
                           1,
                           activation=activation,
                           kernel_initializer='ones',
                           bias_initializer='ones')(input_tensor)
      simple_model = keras.models.Model(inputs=input_tensor,
                                        outputs=predictions)
      simple_model.compile(optimizer=rmsprop.RMSPropOptimizer(1e-3), loss=loss)
      return simple_model

    if test.is_gpu_available(cuda_only=True):
      with self.test_session(use_gpu=True):
        losses_to_test = ['sparse_categorical_crossentropy',
                          'categorical_crossentropy', 'binary_crossentropy']

        data_channels_first = np.array([[[[8., 7.1, 0.], [4.5, 2.6, 0.55],
                                          [0.9, 4.2, 11.2]]]], dtype=np.float32)
        # Labels for testing 4-class sparse_categorical_crossentropy, 4-class
        # categorical_crossentropy, and 2-class binary_crossentropy:
        labels_channels_first = [
            np.array([[[[0, 1, 3], [2, 1, 0], [2, 2, 1]]]], dtype=np.float32),
            np.array([[[[0, 1, 0], [0, 1, 0], [0, 0, 0]],
                       [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
                       [[0, 0, 0], [1, 0, 0], [0, 0, 1]],
                       [[0, 0, 1], [0, 0, 0], [1, 0, 0]]]], dtype=np.float32),
            np.array([[[[0, 1, 0], [0, 1, 0], [0, 0, 1]],
                       [[1, 0, 1], [1, 0, 1], [1, 1, 0]]]], dtype=np.float32)]
        # Compute one loss for each loss function in the list `losses_to_test`:
        loss_channels_last = [0., 0., 0.]
        loss_channels_first = [0., 0., 0.]

        old_data_format = K.image_data_format()
        # The data format is process-global state; restore it even if an
        # evaluation below raises, so later tests are not affected.
        try:
          # Evaluate a simple network with channels last, with all three loss
          # functions:
          K.set_image_data_format('channels_last')
          data = np.moveaxis(data_channels_first, 1, -1)
          for index, loss_function in enumerate(losses_to_test):
            labels = np.moveaxis(labels_channels_first[index], 1, -1)
            inputs = keras.Input(shape=(3, 3, 1))
            model = prepare_simple_model(inputs, loss_function, labels)
            loss_channels_last[index] = model.evaluate(x=data, y=labels,
                                                       batch_size=1, verbose=0)

          # Evaluate the same network with channels first, with all three loss
          # functions:
          K.set_image_data_format('channels_first')
          data = data_channels_first
          for index, loss_function in enumerate(losses_to_test):
            labels = labels_channels_first[index]
            inputs = keras.Input(shape=(1, 3, 3))
            model = prepare_simple_model(inputs, loss_function, labels)
            loss_channels_first[index] = model.evaluate(x=data, y=labels,
                                                        batch_size=1, verbose=0)
        finally:
          K.set_image_data_format(old_data_format)

        np.testing.assert_allclose(loss_channels_first,
                                   loss_channels_last,
                                   err_msg='{}{}'.format(
                                       'Computed different losses for ',
                                       'channels_first and channels_last'))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index d9e548f01f..301a6ca866 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import logging
import os
import unittest
@@ -415,6 +416,28 @@ class TrainingTest(test.TestCase):
x2 = model.predict(val_a)
self.assertAllClose(x1, x2, atol=1e-7)
+ def test_compile_warning_for_loss_missing_output(self):
+ with self.test_session():
+ inp = keras.layers.Input(shape=(16,), name='input_a')
+ out_1 = keras.layers.Dense(8, name='dense_1')(inp)
+ out_2 = keras.layers.Dense(3, activation='softmax', name='dense_2')(out_1)
+ model = keras.models.Model(inputs=[inp], outputs=[out_1, out_2])
+
+ with test.mock.patch.object(logging, 'warning') as mock_log:
+ model.compile(
+ loss={
+ 'dense_2': 'categorical_crossentropy',
+ },
+ optimizer='rmsprop',
+ metrics={
+ 'dense_2': 'categorical_accuracy',
+ 'dense_1': 'categorical_accuracy',
+ })
+ msg = ('Output "dense_1" missing from loss dictionary. We assume this '
+ 'was done on purpose. The fit and evaluate APIs will not be '
+ 'expecting any data to be passed to "dense_1".')
+ self.assertRegexpMatches(str(mock_log.call_args), msg)
+
class LossWeightingTest(test.TestCase):
@@ -744,6 +767,22 @@ class LossMaskingTest(test.TestCase):
keras.backend.variable(weights), keras.backend.variable(mask)))
+class LearningPhaseTest(test.TestCase):
+
+ def test_empty_model_no_learning_phase(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ self.assertFalse(model.uses_learning_phase)
+
+ def test_dropout_has_learning_phase(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_dim=3))
+ model.add(keras.layers.Dropout(0.5))
+ model.add(keras.layers.Dense(2))
+ self.assertTrue(model.uses_learning_phase)
+
+
class TestDynamicTrainability(test.TestCase):
def test_trainable_warning(self):
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 728a2b493b..dbbc87daf9 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -19,9 +19,11 @@ from __future__ import division
from __future__ import print_function
import copy
+import math
import numpy as np
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_util
@@ -31,6 +33,135 @@ from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.ops import math_ops
+def _map_nested(data, func):
+ """Maps each nested element using func."""
+ if isinstance(data, list):
+ return [_map_nested(nested_data, func) for nested_data in data]
+ elif isinstance(data, tuple):
+ return tuple(_map_nested(nested_data, func) for nested_data in data)
+ elif isinstance(data, dict):
+ return {
+ k: _map_nested(nested_data, func) for k, nested_data in data.items()
+ }
+ else:
+ return func(data)
+
+
+def _nested_all(data, cond_func):
+ """Checks if all elements in a nested structure satisfy cond_func."""
+ if isinstance(data, (tuple, list)):
+ return all([_nested_all(nested_data, cond_func) for nested_data in data])
+ elif isinstance(data, dict):
+ return all(
+ [_nested_all(nested_data, cond_func) for nested_data in data.values()])
+ else:
+ return cond_func(data)
+
+
+def _nested_any(data, cond_func):
+ """Checks if any nested_elements in a nested structure satisfy cond_func."""
+ if isinstance(data, (tuple, list)):
+ return any([_nested_any(nested_data, cond_func) for nested_data in data])
+ elif isinstance(data, dict):
+ return any(
+ [_nested_any(nested_data, cond_func) for nested_data in data.values()])
+ else:
+ return cond_func(data)
+
+
+def _convert_lists_to_tuples(data):
+ """Converts all lists to tuples, since Datasets expect tuples."""
+ if isinstance(data, (tuple, list)):
+ return tuple(_convert_lists_to_tuples(nested_data) for nested_data in data)
+ elif isinstance(data, dict):
+ return {
+ k: _convert_lists_to_tuples(nested_data)
+ for k, nested_data in data.items()
+ }
+ else:
+ return data
+
+
+def _get_batch_axis_size(data):
+ """Returns batch axis shape for nested data."""
+ if isinstance(data, (tuple, list)):
+ return _get_batch_axis_size(data[0])
+ elif isinstance(data, dict):
+ return _get_batch_axis_size(list(data.values()))
+ else:
+ return int(data.shape[0])
+
+
def convert_to_iterator(x=None,
                        y=None,
                        sample_weights=None,
                        batch_size=None,
                        steps_per_epoch=None,
                        epochs=1,
                        shuffle=False):
  """Converts NumPy arrays or EagerTensors to an EagerIterator.

  Combines all provided data into a single EagerIterator.

  Arguments:
      x: NumPy array or EagerTensor, or list of Numpy arrays or EagerTensors
        representing inputs to a model. If `x` is already an `EagerIterator`
        it is returned unchanged together with `steps_per_epoch`.
      y: Optional. NumPy array or EagerTensor, or list of Numpy arrays or
        EagerTensors representing targets of a model.
      sample_weights: Optional NumPy array or EagerTensor representing sample
        weights.
      batch_size: Used to batch data and calculate how many steps EagerIterator
        should take per epoch.
      steps_per_epoch: If provided, how many steps EagerIterator should take per
        epoch.
      epochs: Epochs to repeat iterator for.
      shuffle: Whether to shuffle data after each epoch.

  Raises:
      ValueError: if steps_per_epoch cannot be calculated from the data
      provided.

  Returns:
      (Iterator, steps_per_epoch).

  """
  if isinstance(x, iterator_ops.EagerIterator):
    return x, steps_per_epoch

  # Only include `sample_weights` / `y` when no leaf in them is None;
  # sample weights are only kept when fully specified, and are then packed
  # together with `y` regardless of `y`'s contents.
  if not _nested_any(sample_weights, lambda value: value is None):
    data = (x, y, sample_weights)
  elif not _nested_any(y, lambda value: value is None):
    data = (x, y)
  else:
    # always wrap in a tuple, so we know y, sample_weights weren't set
    # even when x has multiple elements
    data = (x,)

  data = _convert_lists_to_tuples(data)
  if steps_per_epoch is None and batch_size is not None:
    num_samples = _get_batch_axis_size(data)
    # True division (module imports `division` from `__future__`), rounded
    # up so a trailing partial batch counts as a step.
    steps_per_epoch = int(math.ceil(num_samples / batch_size))

  if steps_per_epoch is None:
    raise ValueError('Could not determine steps_per_epoch. '
                     'Please provide either batch_size or '
                     'steps_per_epoch.')

  # TODO(omalleyt) for NumPy arrays in graph mode
  # placeholder ops should be used
  # this is only ideal for eager mode
  dataset = dataset_ops.Dataset.from_tensor_slices(data)

  if batch_size is not None:
    dataset = dataset.batch(batch_size)
  if shuffle:
    # NOTE(review): shuffle is applied after batching, so it appears to
    # reorder whole batches rather than individual samples - confirm intent.
    dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.repeat(epochs)
  iterator = dataset.make_one_shot_iterator()

  return iterator, steps_per_epoch
+
+
def check_num_samples(ins,
batch_size=None,
steps=None,
@@ -128,8 +259,8 @@ def standardize_input_data(data,
except KeyError as e:
raise ValueError('No data provided for "' + e.args[0] + '". Need data '
'for each key in: ' + str(names))
- elif isinstance(data, list):
- if isinstance(data[0], list):
+ elif isinstance(data, (list, tuple)):
+ if isinstance(data[0], (list, tuple)):
data = [np.asarray(d) for d in data]
elif len(names) == 1 and isinstance(data[0], (float, int)):
data = [np.asarray(data)]
@@ -482,6 +613,9 @@ def standardize_weights(y,
Raises:
ValueError: In case of invalid user-provided arguments.
"""
+ # Iterator may return sample_weight as 1-tuple
+ if isinstance(sample_weight, tuple):
+ sample_weight = sample_weight[0]
if sample_weight_mode is not None:
if sample_weight_mode != 'temporal':
raise ValueError('"sample_weight_mode '
diff --git a/tensorflow/python/keras/engine/training_utils_test.py b/tensorflow/python/keras/engine/training_utils_test.py
new file mode 100644
index 0000000000..297a1ae494
--- /dev/null
+++ b/tensorflow/python/keras/engine/training_utils_test.py
@@ -0,0 +1,150 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for training utility functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.platform import test
+
+
class TrainingUtilTest(test.TestCase):
  """Tests for `convert_to_iterator` and the nested-structure helpers."""

  @test_util.run_in_graph_and_eager_modes
  def test_convert_to_iterator_single_numpy(self):
    """A single NumPy input yields 1-tuples of batches and 5 steps."""
    batch_size = 2
    a = np.ones([10, 10])
    iterator, steps_per_epoch = training_utils.convert_to_iterator(
        x=a, batch_size=batch_size)
    self.assertEqual(steps_per_epoch, 5)

    expected_batch = a[:batch_size, :]
    actual_batch, = iterator.get_next()
    self.assertAllEqual(expected_batch, actual_batch)

  @test_util.run_in_graph_and_eager_modes
  def test_convert_to_iterator_single_tensor(self):
    """A single tensor input behaves like the NumPy case."""
    batch_size = 2
    a = ops.convert_to_tensor(np.ones([10, 10]))
    iterator, steps_per_epoch = training_utils.convert_to_iterator(
        x=a, batch_size=batch_size)
    self.assertEqual(steps_per_epoch, 5)

    expected_batch = a[:batch_size, :]
    actual_batch, = iterator.get_next()
    self.assertAllEqual(expected_batch, actual_batch)

  @test_util.run_in_graph_and_eager_modes
  def test_convert_to_iterator_y(self):
    """Inputs and targets are yielded together as an (x, y) pair."""
    batch_size = 2
    a = np.ones([10, 100])
    b = np.ones([10, 10])
    iterator, steps_per_epoch = training_utils.convert_to_iterator(
        x=a, y=b, batch_size=batch_size)
    self.assertEqual(steps_per_epoch, 5)

    expected_x = a[:batch_size, :]
    expected_y = b[:batch_size, :]
    actual_x, actual_y = iterator.get_next()
    self.assertAllEqual(expected_x, actual_x)
    self.assertAllEqual(expected_y, actual_y)

  @test_util.run_in_graph_and_eager_modes
  def test_convert_to_iterator_sample_weights(self):
    """Sample weights are yielded as a third element alongside x and y."""
    batch_size = 2
    a = ops.convert_to_tensor(np.ones([10, 100]))
    b = ops.convert_to_tensor(np.ones([10, 10]))
    sw = ops.convert_to_tensor(np.ones([10]))
    iterator, steps_per_epoch = training_utils.convert_to_iterator(
        x=a, y=b, sample_weights=sw, batch_size=batch_size)
    self.assertEqual(steps_per_epoch, 5)

    expected_x = a[:batch_size, :]
    expected_y = b[:batch_size, :]
    expected_sw = sw[:batch_size]
    actual_x, actual_y, actual_sw = iterator.get_next()
    self.assertAllEqual(expected_x, actual_x)
    self.assertAllEqual(expected_y, actual_y)
    self.assertAllEqual(expected_sw, actual_sw)

  @test_util.run_in_graph_and_eager_modes
  def test_convert_to_iterator_nested(self):
    """Nested dict/list inputs keep their structure in each batch."""
    batch_size = 2
    x = {'1': np.ones([10, 100]), '2': [np.zeros([10, 10]), np.ones([10, 20])]}
    iterator, steps_per_epoch = training_utils.convert_to_iterator(
        x=x, batch_size=batch_size)
    self.assertEqual(steps_per_epoch, 5)

    expected_x1 = x['1'][:batch_size, :]
    expected_x2_0 = x['2'][0][:batch_size, :]
    expected_x2_1 = x['2'][1][:batch_size, :]

    actual_x, = iterator.get_next()
    actual_x1 = actual_x['1'][:batch_size, :]
    actual_x2_0 = actual_x['2'][0][:batch_size, :]
    actual_x2_1 = actual_x['2'][1][:batch_size, :]

    self.assertAllEqual(expected_x1, actual_x1)
    self.assertAllEqual(expected_x2_0, actual_x2_0)
    self.assertAllEqual(expected_x2_1, actual_x2_1)

  @test_util.run_in_graph_and_eager_modes
  def test_convert_to_iterator_epochs(self):
    """With epochs=2 the iterator keeps yielding past the first epoch."""
    batch_size = 2
    a = np.ones([10, 10])
    iterator, steps_per_epoch = training_utils.convert_to_iterator(
        x=a, batch_size=batch_size, epochs=2)
    self.assertEqual(steps_per_epoch, 5)

    expected_batch = a[:batch_size, :]
    # Take 6 steps: one more than a full epoch (5 steps), so the last
    # `get_next` only succeeds if the data was repeated for a second epoch.
    for _ in range(6):
      actual_batch, = iterator.get_next()
    self.assertAllEqual(expected_batch, actual_batch)

  @test_util.run_in_graph_and_eager_modes
  def test_convert_to_iterator_insufficient_info(self):
    """Omitting both batch_size and steps_per_epoch raises ValueError."""
    # with batch_size and steps_per_epoch not set
    with self.assertRaises(ValueError):
      a = np.ones([10, 10])
      _ = training_utils.convert_to_iterator(x=a)

  def test_nested_all(self):
    """`_nested_all` is False if any leaf fails, True if all pass."""
    nested_data = {'a': True, 'b': [True, True, (False, True)]}
    all_true = training_utils._nested_all(nested_data, lambda x: x)
    self.assertEqual(all_true, False)

    nested_data = {'a': True, 'b': [True, True, (True, True)]}
    all_true = training_utils._nested_all(nested_data, lambda x: x)
    self.assertEqual(all_true, True)

  def test_nested_any(self):
    """`_nested_any` is True if any leaf passes, False if none do."""
    nested_data = [False, {'a': False, 'b': (False, True)}]
    any_true = training_utils._nested_any(nested_data, lambda x: x)
    self.assertEqual(any_true, True)

    nested_data = [False, {'a': False, 'b': (False, False)}]
    any_true = training_utils._nested_any(nested_data, lambda x: x)
    self.assertEqual(any_true, False)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/initializers.py b/tensorflow/python/keras/initializers.py
index 28beb6760d..b9d856efa8 100644
--- a/tensorflow/python/keras/initializers.py
+++ b/tensorflow/python/keras/initializers.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras initializer classes (soon to be replaced with core TF initializers).
+"""Keras initializer serialization / deserialization.
"""
from __future__ import absolute_import
from __future__ import division
@@ -22,107 +22,27 @@ import six
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+
+# These imports are brought in so that keras.initializers.deserialize
+# has them available in module_objects.
from tensorflow.python.ops.init_ops import Constant
from tensorflow.python.ops.init_ops import glorot_normal_initializer
from tensorflow.python.ops.init_ops import glorot_uniform_initializer
-
+from tensorflow.python.ops.init_ops import he_normal # pylint: disable=unused-import
+from tensorflow.python.ops.init_ops import he_uniform # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Identity
from tensorflow.python.ops.init_ops import Initializer # pylint: disable=unused-import
+from tensorflow.python.ops.init_ops import lecun_normal # pylint: disable=unused-import
+from tensorflow.python.ops.init_ops import lecun_uniform # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Ones
from tensorflow.python.ops.init_ops import Orthogonal
from tensorflow.python.ops.init_ops import RandomNormal
from tensorflow.python.ops.init_ops import RandomUniform
from tensorflow.python.ops.init_ops import TruncatedNormal
-from tensorflow.python.ops.init_ops import VarianceScaling
+from tensorflow.python.ops.init_ops import VarianceScaling # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Zeros
-from tensorflow.python.util.tf_export import tf_export
-
-
-@tf_export('keras.initializers.lecun_normal')
-def lecun_normal(seed=None):
- """LeCun normal initializer.
-
- It draws samples from a truncated normal distribution centered on 0
- with `stddev = sqrt(1 / fan_in)`
- where `fan_in` is the number of input units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
-
- References:
- - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
- - [Efficient
- Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
- """
- return VarianceScaling(
- scale=1., mode='fan_in', distribution='normal', seed=seed)
-
-
-@tf_export('keras.initializers.lecun_uniform')
-def lecun_uniform(seed=None):
- """LeCun uniform initializer.
-
- It draws samples from a uniform distribution within [-limit, limit]
- where `limit` is `sqrt(3 / fan_in)`
- where `fan_in` is the number of input units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
- References:
- LeCun 98, Efficient Backprop,
- http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
- """
- return VarianceScaling(
- scale=1., mode='fan_in', distribution='uniform', seed=seed)
-
-
-@tf_export('keras.initializers.he_normal')
-def he_normal(seed=None):
- """He normal initializer.
-
- It draws samples from a truncated normal distribution centered on 0
- with `stddev = sqrt(2 / fan_in)`
- where `fan_in` is the number of input units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
-
- References:
- He et al., http://arxiv.org/abs/1502.01852
- """
- return VarianceScaling(
- scale=2., mode='fan_in', distribution='normal', seed=seed)
-
-
-@tf_export('keras.initializers.he_uniform')
-def he_uniform(seed=None):
- """He uniform variance scaling initializer.
-
- It draws samples from a uniform distribution within [-limit, limit]
- where `limit` is `sqrt(6 / fan_in)`
- where `fan_in` is the number of input units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
-
- References:
- He et al., http://arxiv.org/abs/1502.01852
- """
- return VarianceScaling(
- scale=2., mode='fan_in', distribution='uniform', seed=seed)
+from tensorflow.python.util.tf_export import tf_export
# Compatibility aliases
diff --git a/tensorflow/python/keras/initializers_test.py b/tensorflow/python/keras/initializers_test.py
index c519e194bd..51725e03f2 100644
--- a/tensorflow/python/keras/initializers_test.py
+++ b/tensorflow/python/keras/initializers_test.py
@@ -31,16 +31,6 @@ class KerasInitializersTest(test.TestCase):
target_max=None, target_min=None):
variable = keras.backend.variable(init(shape))
output = keras.backend.get_value(variable)
- lim = 3e-2
- if target_std is not None:
- self.assertGreater(lim, abs(output.std() - target_std))
- if target_mean is not None:
- self.assertGreater(lim, abs(output.mean() - target_mean))
- if target_max is not None:
- self.assertGreater(lim, abs(output.max() - target_max))
- if target_min is not None:
- self.assertGreater(lim, abs(output.min() - target_min))
-
# Test serialization (assumes deterministic behavior).
config = init.get_config()
reconstructed_init = init.__class__.from_config(config)
diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py
index eba10da6f3..61ab69c16f 100644
--- a/tensorflow/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/layers/advanced_activations.py
@@ -284,6 +284,13 @@ class Softmax(Layer):
class ReLU(Layer):
"""Rectified Linear Unit activation function.
+ With default values, it returns element-wise `max(x, 0)`.
+
+ Otherwise, it follows:
+ `f(x) = max_value` for `x >= max_value`,
+ `f(x) = x` for `threshold <= x < max_value`,
+ `f(x) = negative_slope * (x - threshold)` otherwise.
+
Input shape:
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
@@ -294,21 +301,39 @@ class ReLU(Layer):
Arguments:
max_value: float >= 0. Maximum activation value.
+ negative_slope: float >= 0. Negative slope coefficient.
+ threshold: float. Threshold value for thresholded activation.
"""
- def __init__(self, max_value=None, **kwargs):
+ def __init__(self, max_value=None, negative_slope=0, threshold=0, **kwargs):
super(ReLU, self).__init__(**kwargs)
- self.support_masking = True
- self.max_value = K.cast_to_floatx(max_value)
- if self.max_value < 0.:
+ if max_value is not None and max_value < 0.:
raise ValueError('max_value of Relu layer '
'cannot be negative value: ' + str(max_value))
+ if negative_slope < 0.:
+ raise ValueError('negative_slope of Relu layer '
+ 'cannot be negative value: ' + str(negative_slope))
+
+ self.support_masking = True
+ self.max_value = K.cast_to_floatx(max_value)
+ self.negative_slope = K.cast_to_floatx(negative_slope)
+ self.threshold = K.cast_to_floatx(threshold)
def call(self, inputs):
- return activations.relu(inputs, max_value=self.max_value)
+ # alpha is used for leaky relu slope in activations instead of
+ # negative_slope.
+ return activations.relu(
+ inputs,
+ alpha=self.negative_slope,
+ max_value=self.max_value,
+ threshold=self.threshold)
def get_config(self):
- config = {'max_value': self.max_value}
+ config = {
+ 'max_value': self.max_value,
+ 'negative_slope': self.negative_slope,
+ 'threshold': self.threshold
+ }
base_config = super(ReLU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py
index 9e1f15b1bc..53c1baa2bb 100644
--- a/tensorflow/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/layers/advanced_activations_test.py
@@ -75,6 +75,14 @@ class AdvancedActivationsTest(test.TestCase):
testing_utils.layer_test(keras.layers.ReLU,
kwargs={'max_value': -10},
input_shape=(2, 3, 4))
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'negative_slope of Relu layer cannot be negative value: -2'):
+ with self.test_session():
+ testing_utils.layer_test(
+ keras.layers.ReLU,
+ kwargs={'negative_slope': -2},
+ input_shape=(2, 3, 4))
if __name__ == '__main__':
diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py
index 84d794cada..e61dd3043d 100644
--- a/tensorflow/python/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/layers/convolutional_recurrent.py
@@ -788,7 +788,7 @@ class ConvLSTM2D(ConvRNN2D):
Arguments:
filters: Integer, the dimensionality of the output space
- (i.e. the number output of filters in the convolution).
+ (i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of n integers, specifying the
dimensions of the convolution window.
strides: An integer or tuple/list of n integers,
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
index 8fd970239f..2ed0aa8f26 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
@@ -220,7 +220,7 @@ class CuDNNTest(test.TestCase, parameterized.TestCase):
self.assertNotEqual(out4.max(), out5.max())
@parameterized.named_parameters(
- *testing_utils.generate_combinations_with_testcase_name(
+ *test_util.generate_combinations_with_testcase_name(
rnn_type=['LSTM', 'GRU'], to_cudnn=[True, False],
bidirectional=[True, False], implementation=[1, 2],
model_nest_level=[1, 2], model_type=['seq', 'func']))
@@ -301,7 +301,7 @@ class CuDNNTest(test.TestCase, parameterized.TestCase):
os.remove(fname)
@parameterized.named_parameters(
- *testing_utils.generate_combinations_with_testcase_name(
+ *test_util.generate_combinations_with_testcase_name(
rnn_type=['LSTM', 'GRU'], to_cudnn=[True, False]))
def test_load_weights_between_noncudnn_rnn_time_distributed(self, rnn_type,
to_cudnn):
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 58c8a8a66d..a7835bc0a2 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -370,7 +370,7 @@ class BatchNormalization(Layer):
decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
if decay.dtype != variable.dtype.base_dtype:
decay = math_ops.cast(decay, variable.dtype.base_dtype)
- update_delta = (variable - value) * decay
+ update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
return state_ops.assign_sub(variable, update_delta, name=scope)
def _fused_batch_norm(self, inputs, training):
@@ -619,6 +619,10 @@ class BatchNormalization(Layer):
else:
mean, variance = self.moving_mean, self.moving_variance
+ mean = math_ops.cast(mean, inputs.dtype)
+ variance = math_ops.cast(variance, inputs.dtype)
+ if offset is not None:
+ offset = math_ops.cast(offset, inputs.dtype)
outputs = nn.batch_normalization(inputs,
_broadcast(mean),
_broadcast(variance),
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index b22f3bd152..a97b4cac46 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -95,6 +95,24 @@ class NormalizationLayersTest(test.TestCase):
np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
+ def test_batchnorm_mixed_precision(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
+ model.add(norm)
+ model.compile(loss='mse', optimizer='sgd')
+
+ # centered on 5.0, variance 10.0
+ x = np.random.normal(
+ loc=5.0, scale=10.0, size=(1000, 10)).astype(np.float16)
+ model.fit(x, x, epochs=4, verbose=0)
+ out = model.predict(x)
+ out -= keras.backend.eval(norm.beta)
+ out /= keras.backend.eval(norm.gamma)
+
+ np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
+ np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
+
def test_batchnorm_convnet(self):
if test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 61775da47b..534c0eca08 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -403,6 +404,8 @@ class RNN(Layer):
'one integer per RNN state).')
super(RNN, self).__init__(**kwargs)
self.cell = cell
+ if isinstance(cell, checkpointable.CheckpointableBase):
+ self._track_checkpointable(self.cell, name='cell')
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index 802374d2d2..fefb92826b 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import util as checkpointable_util
class RNNTest(test.TestCase):
@@ -556,5 +557,22 @@ class RNNTest(test.TestCase):
[tuple(o.as_list()) for o in output_shape],
expected_output_shape)
+  def test_checkpointable_dependencies(self):
+    """Every variable of a fitted RNN model must be listed for checkpointing."""
+    with self.test_session():
+      inputs = np.random.random((2, 2, 2))
+      targets = np.random.random((2, 2))
+      model = keras.models.Sequential()
+      model.add(keras.layers.SimpleRNN(2))
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.fit(inputs, targets, epochs=1, batch_size=1)
+
+      # Each model variable has to appear among the objects that the
+      # checkpointable utilities list for the model.
+      tracked_objects = set(checkpointable_util.list_objects(model))
+      for variable in model.variables:
+        self.assertIn(variable, tracked_objects)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index f651e03874..f0c1e76156 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -47,7 +47,6 @@ class Wrapper(Layer):
def __init__(self, layer, **kwargs):
assert isinstance(layer, Layer)
self.layer = layer
- self._track_checkpointable(layer, name='layer')
# Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
# the inner layer has update ops that depend on its inputs (as opposed
# to the inputs to the Wrapper layer).
@@ -168,6 +167,7 @@ class TimeDistributed(Wrapper):
'`Layer` instance. You passed: {input}'.format(input=layer))
super(TimeDistributed, self).__init__(layer, **kwargs)
self.supports_masking = True
+ self._track_checkpointable(layer, name='layer')
def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
"""Finds non-specific dimensions in the static shapes.
@@ -417,6 +417,8 @@ class Bidirectional(Wrapper):
self._num_constants = None
super(Bidirectional, self).__init__(layer, **kwargs)
self.input_spec = layer.input_spec
+ self._track_checkpointable(self.forward_layer, name='forward_layer')
+ self._track_checkpointable(self.backward_layer, name='backward_layer')
@property
def trainable(self):
@@ -526,7 +528,8 @@ class Bidirectional(Wrapper):
else:
return super(Bidirectional, self).__call__(inputs, **kwargs)
- def call(self, inputs,
+ def call(self,
+ inputs,
training=None,
mask=None,
initial_state=None,
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index 3f268acf5c..0cd774ef0f 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -87,6 +87,8 @@ class TimeDistributedTest(test.TestCase):
# test config
model.get_config()
+ # check whether the model variables are present in the
+ # checkpointable list of objects
checkpointed_objects = set(checkpointable_util.list_objects(model))
for v in model.variables:
self.assertIn(v, checkpointed_objects)
@@ -278,6 +280,12 @@ class BidirectionalTest(test.TestCase):
model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
model.fit(x, y, epochs=1, batch_size=1)
+ # check whether the model variables are present in the
+ # checkpointable list of objects
+ checkpointed_objects = set(checkpointable_util.list_objects(model))
+ for v in model.variables:
+ self.assertIn(v, checkpointed_objects)
+
# test compute output shape
ref_shape = model.layers[-1].output.get_shape()
shape = model.layers[-1].compute_output_shape(
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index e03d7dfe93..7d8b1fec45 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -19,9 +19,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from abc import ABCMeta
+from abc import abstractmethod
+
+import types
import six
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.losses import binary_crossentropy
from tensorflow.python.keras.losses import categorical_crossentropy
from tensorflow.python.keras.losses import cosine_proximity
@@ -37,14 +46,471 @@ from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras.losses import squared_hinge
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import confusion_matrix
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
+def check_is_tensor_or_operation(x, name):
+  """Raises a `TypeError` unless `x` is a `Tensor` or an `Operation`.
+
+  Args:
+    x: The object to validate.
+    name: Label for `x` that is used in the error message.
+
+  Raises:
+    TypeError: If `x` is neither an `ops.Tensor` nor an `ops.Operation`.
+  """
+  if isinstance(x, (ops.Tensor, ops.Operation)):
+    return
+  raise TypeError('{0} must be a Tensor or Operation, given: {1}'.format(
+      name, x))
+
+
+def update_state_wrapper(update_state_fn):
+  """Decorator to wrap metric `update_state()` with `defun()`, `add_update()`.
+
+  Args:
+    update_state_fn: function that accumulates metric statistics.
+
+  Returns:
+    The decorated function. When called, it returns None if the wrapped
+    function produced no update op (as happens with eager execution);
+    otherwise it returns the update op. This op should be executed to update
+    the metric state with the given inputs.
+  """
+
+  def decorated(metric_obj, *args, **kwargs):
+    """Decorated function with `defun()` and `add_update()`.
+
+    Raises:
+      TypeError: If the produced update is neither None, a Tensor nor an
+        Operation.
+    """
+
+    # Converting update_state_fn() into a graph function, so that
+    # we can return a single op that performs all of the variable updates.
+    # Assigning to a different method name to avoid reference cycle.
+    defuned_update_state_fn = function.defun(update_state_fn)
+    update_op = defuned_update_state_fn(*args, **kwargs)
+    if update_op is not None:  # update_op will be None in eager execution.
+      # Validate before registering, so that an invalid value is never added
+      # to the metric layer's updates.
+      check_is_tensor_or_operation(
+          update_op, 'Metric {0}\'s update'.format(metric_obj.name))
+      metric_obj.add_update(update_op, inputs=True)
+    return update_op
+
+  return tf_decorator.make_decorator(update_state_fn, decorated)
+
+
+def result_wrapper(result_fn):
+  """Decorator that routes a metric's `result()` through `merge_call()`.
+
+  Computing the result is an idempotent operation that only calculates the
+  metric value from the state variables.
+
+  When the call is made from within a tower context, the computation is
+  handed to the tower context's `merge_call()`, so that metric state
+  variables that are distributed across towers/devices are aggregated. When
+  there is no tower context, `result_fn` is invoked directly.
+
+  Args:
+    result_fn: function that computes the metric result.
+
+  Returns:
+    The decorated function, which returns the metric result tensor.
+  """
+
+  def _unwrap_and_call(distribution, per_device_fn, *fn_args):
+    # `merge_call()` invokes its merge function with the distribution object
+    # as the first parameter and supplies the function as a `PerDevice`
+    # value. All copies are identical, so the first one is taken. This
+    # indirection keeps the distribution parameter out of `result_fn`.
+    return distribution.unwrap(per_device_fn)[0](*fn_args)
+
+  def decorated(metric_obj, *args):
+    """Decorated function with merge_call."""
+    tower_context = distribute_lib.get_tower_context()
+    if tower_context is not None:
+      # TODO(psv): Test distribution of metrics using different distribution
+      # strategies.
+
+      # merge_call is used to leave tower mode and compute the value in
+      # cross tower mode.
+      result_t = tower_context.merge_call(_unwrap_and_call, result_fn, *args)
+    else:  # Already in cross tower context.
+      result_t = result_fn(*args)
+    check_is_tensor_or_operation(result_t,
+                                 'Metric {0}\'s result'.format(metric_obj.name))
+    return result_t
+
+  return tf_decorator.make_decorator(result_fn, decorated)
+
+
+def _safe_div(numerator, denominator):
+  """Element-wise division that yields 0 wherever `denominator` is <= 0.
+
+  Args:
+    numerator: A `Tensor`.
+    denominator: A `Tensor`, with dtype matching `numerator`.
+
+  Returns:
+    `numerator` / `denominator` where `denominator` > 0, and 0 elsewhere.
+  """
+  quotient = math_ops.truediv(numerator, denominator)
+  zeros = array_ops.zeros_like(quotient, dtype=denominator.dtype)
+  is_positive = math_ops.greater(denominator, zeros)
+  # Cast the zeros to the quotient's dtype before selecting between the two.
+  fallback = math_ops.cast(zeros, quotient.dtype)
+  return array_ops.where(is_positive, quotient, fallback)
+
+
+def _squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
+ """Squeeze or expand last dimension if needed.
+
+ 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
+ (using `confusion_matrix.remove_squeezable_dimensions`).
+ 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
+ from the new rank of `y_pred`.
+ If `sample_weight` is scalar, it is kept scalar.
+
+ This will use static shape if available. Otherwise, it will add graph
+ operations, which could result in a performance hit.
+
+ Args:
+ y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
+ y_true: Optional label `Tensor` whose dimensions match `y_pred`.
+ sample_weight: Optional weight scalar or `Tensor` whose dimensions match
+ `y_pred`.
+
+ Returns:
+ Tuple of `y_pred`, `y_true` and `sample_weight`. `y_pred` and `y_true`
+ possibly have the last dimension squeezed; `sample_weight` possibly has
+ the last dimension squeezed or is extended by one dimension. `y_true`
+ and `sample_weight` are returned as None when they were passed as None.
+ """
+ if y_true is not None:
+ # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
+ y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
+ y_true, y_pred)
+ # Fail early if the static shapes cannot describe the same tensor shape.
+ y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
+
+ # Without weights there is nothing further to adjust.
+ if sample_weight is None:
+ return y_pred, y_true, None
+
+ sample_weight = ops.convert_to_tensor(sample_weight)
+ weights_shape = sample_weight.get_shape()
+ weights_rank = weights_shape.ndims
+ if weights_rank == 0: # If weights is scalar, do nothing.
+ return y_pred, y_true, sample_weight
+
+ y_pred_shape = y_pred.get_shape()
+ y_pred_rank = y_pred_shape.ndims
+ if (y_pred_rank is not None) and (weights_rank is not None):
+ # Use static rank.
+ if weights_rank - y_pred_rank == 1:
+ sample_weight = array_ops.squeeze(sample_weight, [-1])
+ elif y_pred_rank - weights_rank == 1:
+ sample_weight = array_ops.expand_dims(sample_weight, [-1])
+ return y_pred, y_true, sample_weight
+
+ # Use dynamic rank.
+ # At least one static rank is unknown here, so the same three-way decision
+ # (squeeze / expand / leave as is) is expressed with graph conditionals.
+ weights_rank_tensor = array_ops.rank(sample_weight)
+ rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
+ maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])
+
+ def _maybe_expand_weights():
+ # Expands only when the weights have one dimension fewer than `y_pred`.
+ return control_flow_ops.cond(
+ math_ops.equal(rank_diff,
+ -1), lambda: array_ops.expand_dims(sample_weight, [-1]),
+ lambda: sample_weight)
+
+ def _maybe_adjust_weights():
+ # Squeezes when the weights have one extra dimension, else tries expanding.
+ return control_flow_ops.cond(
+ math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
+ _maybe_expand_weights)
+
+ # squeeze or expand last dim of `sample_weight` if its rank differs by 1
+ # from the new rank of `y_pred`.
+ sample_weight = control_flow_ops.cond(
+ math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
+ _maybe_adjust_weights)
+ return y_pred, y_true, sample_weight
+
+
+class Metric(Layer):
+  """Encapsulates metric logic and state.
+
+  Usage with eager execution:
+
+  ```python
+  m = SomeMetric(...)
+  for input in ...:
+    m.update_state(input)
+  print('Final result: ', m.result().numpy())
+  ```
+
+  Usage with graph execution:
+
+  ```python
+  m = SomeMetric(...)
+  init_op = tf.global_variables_initializer()  # Initialize variables
+  with tf.Session() as sess:
+    sess.run(init_op)
+    for input in ...:
+      update_op = m.update_state(input)
+      sess.run(update_op)
+    print('Final result: ', sess.run(m.result()))
+  ```
+
+  To be implemented by subclasses:
+  * `__init__()`: All state variables should be created in this method by
+    calling `self.add_weight()` like: `self.var = self.add_weight(...)`
+  * `update_state()`: Has all updates to the state variables like:
+    self.var.assign_add(...).
+  * `result()`: Computes and returns a value for the metric
+    from the state variables.
+
+  Example subclass implementation:
+
+  ```
+  class BinaryTruePositives(Metric):
+    def __init__(self, name='binary-true-positives', dtype=None):
+      super(BinaryTruePositives, self).__init__(name=name, dtype=dtype)
+      self.true_positives = self.add_weight(
+          'true_positives', initializer=init_ops.zeros_initializer)
+
+    def update_state(self, y_true, y_pred, sample_weight=None):
+      y_true = math_ops.cast(y_true, dtypes.bool)
+      y_pred = math_ops.cast(y_pred, dtypes.bool)
+      y_pred, y_true, sample_weight = _squeeze_or_expand_dimensions(
+          y_pred, y_true, sample_weight)
+
+      values = math_ops.logical_and(
+          math_ops.equal(y_true, True), math_ops.equal(y_pred, True))
+      values = math_ops.cast(values, self._dtype)
+      if sample_weight is not None:
+        sample_weight = math_ops.cast(sample_weight, self._dtype)
+        values = math_ops.multiply(values, sample_weight)
+      state_ops.assign_add(self.true_positives, math_ops.reduce_sum(values))
+
+    def result(self):
+      return array_ops.identity(self.true_positives)
+  ```
+  """
+  # NOTE: `__metaclass__` is only honoured by Python 2. The abstract methods
+  # below therefore also raise explicitly when they are not overridden.
+  __metaclass__ = ABCMeta
+
+  def __init__(self, name=None, dtype=None):
+    """Creates a `Metric` instance.
+
+    Args:
+      name: (Optional) string name of the metric instance.
+      dtype: (Optional) data type of the metric result. Defaults to
+        `K.floatx()` when None.
+    """
+    super(Metric, self).__init__(name=name, dtype=dtype)
+    self.stateful = True  # All metric layers are stateful.
+    self.built = True
+    self._dtype = K.floatx() if dtype is None else dtypes.as_dtype(dtype).name
+
+  def __new__(cls, *args, **kwargs):
+    """Creates the instance and wraps its `update_state()` and `result()`."""
+    # The constructor arguments are consumed by `__init__()` only. Forwarding
+    # them to `object.__new__()` raises a TypeError on Python 3 for classes
+    # that override `__new__()`.
+    obj = super(Metric, cls).__new__(cls)
+    # Wrap on the instance, so that the subclass implementations are the ones
+    # that get decorated.
+    obj.update_state = types.MethodType(
+        update_state_wrapper(obj.update_state), obj)
+    obj.result = types.MethodType(result_wrapper(obj.result), obj)
+    return obj
+
+  def __call__(self, *args, **kwargs):
+    """Accumulates statistics and then computes metric result value.
+
+    Args:
+      *args:
+      **kwargs: A mini-batch of inputs to the Metric,
+        passed on to `update_state()`.
+
+    Returns:
+      The metric value tensor.
+    """
+    update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
+    # The result is computed under a control dependency on the update.
+    with ops.control_dependencies([update_op]):
+      return self.result()  # pylint: disable=not-callable
+
+  def reset_states(self):
+    """Resets all of the metric state variables.
+
+    This function is called between epochs/steps,
+    when a metric is evaluated during training.
+    """
+    for v in self.variables:
+      K.set_value(v, 0)
+
+  @abstractmethod
+  def update_state(self, *args, **kwargs):
+    """Accumulates statistics for the metric.
+
+    Note: This function is executed as a graph function in graph mode, and
+    the resulting update op is added to the metric layer.
+    This means:
+      a) Operations on the same resource are executed in textual order.
+         This should make it easier to do things like add the updated
+         value of a variable to another, for example.
+      b) You don't need to worry about collecting the update ops to execute.
+         All update ops added to the graph by this function will be executed.
+      As a result, code should generally work the same way with graph or
+      eager execution.
+
+    Args:
+      *args:
+      **kwargs: A mini-batch of inputs to the Metric.
+
+    Raises:
+      NotImplementedError: Always. Subclasses must override this method.
+    """
+    raise NotImplementedError('Must be implemented in subclasses.')
+
+  @abstractmethod
+  def result(self):
+    """Computes and returns the metric value tensor.
+
+    Result computation is an idempotent operation that simply calculates the
+    metric value using the state variables.
+
+    Raises:
+      NotImplementedError: Always. Subclasses must override this method.
+    """
+    raise NotImplementedError('Must be implemented in subclasses.')
+
+  ### For use by subclasses ###
+  def add_weight(self,
+                 name,
+                 shape=(),
+                 aggregation=vs.VariableAggregation.SUM,
+                 synchronization=vs.VariableSynchronization.ON_READ,
+                 initializer=None):
+    """Adds state variable. Only for use by subclasses.
+
+    The variable is created as non-trainable and with the metric's dtype.
+
+    Args:
+      name: Name of the variable.
+      shape: Shape of the variable. Defaults to a scalar.
+      aggregation: Passed on to `Layer.add_weight()`.
+      synchronization: Passed on to `Layer.add_weight()`.
+      initializer: Initializer of the variable.
+
+    Returns:
+      The value returned by `Layer.add_weight()`.
+    """
+    return super(Metric, self).add_weight(
+        name=name,
+        shape=shape,
+        dtype=self._dtype,
+        trainable=False,
+        initializer=initializer,
+        synchronization=synchronization,
+        aggregation=aggregation)
+
+  ### End: For use by subclasses ###
+
+
+class Mean(Metric):
+  """Computes the (weighted) mean of the given values.
+
+  Two state variables are kept: `total`, the (weighted) sum of all values
+  seen so far, and `count`, the sum of the weights. The mean is obtained by
+  dividing `total` by `count`, which is an idempotent operation.
+
+  If `sample_weight` is `None`, weights default to 1.
+  Use `sample_weight` of 0 to mask values.
+  """
+
+  def __init__(self, name='mean', dtype=None):
+    """Creates a `Mean` instance.
+
+    Args:
+      name: (Optional) string name of the metric instance.
+      dtype: (Optional) data type of the metric result.
+    """
+    super(Mean, self).__init__(name=name, dtype=dtype)
+    # State variables: running sum of values and running sum of weights.
+    self.total = self.add_weight(
+        name='total', initializer=init_ops.zeros_initializer)
+    self.count = self.add_weight(
+        name='count', initializer=init_ops.zeros_initializer)
+
+  def update_state(self, values, sample_weight=None):
+    """Accumulates statistics for computing the mean.
+
+    For example, if `values` is [1, 3, 5, 7] then the mean is 4. If
+    the `sample_weight` is specified as [1, 1, 0, 0] then the mean would be 2.
+
+    Args:
+      values: Per-example value.
+      sample_weight: Optional weighting of each example. Defaults to 1.
+    """
+    values = math_ops.cast(values, self._dtype)
+    if sample_weight is not None:
+      sample_weight = math_ops.cast(sample_weight, self._dtype)
+
+      # Bring the weights to the dimensions of the values, then broadcast.
+      values, _, sample_weight = _squeeze_or_expand_dimensions(
+          values, None, sample_weight)
+      sample_weight = weights_broadcast_ops.broadcast_weights(
+          sample_weight, values)
+      weight_sum = math_ops.reduce_sum(sample_weight)
+      values = math_ops.multiply(values, sample_weight)
+    else:
+      # Unweighted: every element counts once.
+      weight_sum = math_ops.cast(array_ops.size(values), self._dtype)
+    value_sum = math_ops.reduce_sum(values)
+
+    # Accumulate into the state variables.
+    state_ops.assign_add(self.total, value_sum)
+    state_ops.assign_add(self.count, weight_sum)
+
+  def result(self):
+    """Returns `total` divided by `count` (0 where `count` is not positive)."""
+    return _safe_div(self.total, self.count)
+
+
+class MeanMetricWrapper(Mean):
+  """Turns a stateless metric function into a `Mean`-based stateful metric."""
+
+  def __init__(self, fn, name=None, dtype=None, **kwargs):
+    """Creates a `MeanMetricWrapper` instance.
+
+    Args:
+      fn: The metric function to wrap, with signature
+        `fn(y_true, y_pred, **kwargs)`.
+      name: (Optional) string name of the metric instance.
+      dtype: (Optional) data type of the metric result.
+      **kwargs: The keyword arguments that are passed on to `fn`.
+    """
+    super(MeanMetricWrapper, self).__init__(name=name, dtype=dtype)
+    self._fn = fn
+    self._fn_kwargs = kwargs
+
+  def update_state(self, y_true, y_pred, sample_weight=None):
+    """Accumulates metric statistics.
+
+    `y_true` and `y_pred` should have the same shape.
+
+    Args:
+      y_true: The ground truth values.
+      y_pred: The predicted values.
+      sample_weight: Optional weighting of each example. Defaults to 1. Can be
+        a `Tensor` whose rank is either 0, or the same rank as `y_true`,
+        and must be broadcastable to `y_true`.
+    """
+    y_true = math_ops.cast(y_true, self._dtype)
+    y_pred = math_ops.cast(y_pred, self._dtype)
+    y_pred, y_true, sample_weight = _squeeze_or_expand_dimensions(
+        y_pred, y_true, sample_weight)
+
+    # The wrapped function yields the values that get averaged by `Mean`.
+    fn_values = self._fn(y_true, y_pred, **self._fn_kwargs)
+    super(MeanMetricWrapper, self).update_state(
+        fn_values, sample_weight=sample_weight)
+
+  def get_config(self):
+    """Returns the base config extended by the keyword arguments of `fn`."""
+    merged_config = dict(super(MeanMetricWrapper, self).get_config())
+    # Keyword arguments of `fn` take precedence over equally named base keys.
+    merged_config.update(self._fn_kwargs)
+    return merged_config
+
+
+class BinaryAccuracy(MeanMetricWrapper):
+  """Calculates how often predictions match labels.
+
+  The frequency with which `y_pred` matches `y_true` is tracked with the two
+  state variables `total` and `count`. The `binary accuracy` that is
+  ultimately returned is `total` divided by `count`, an idempotent operation.
+
+  If `sample_weight` is `None`, weights default to 1.
+  Use `sample_weight` of 0 to mask values.
+  """
+
+  def __init__(self, name='binary-accuracy', dtype=None, threshold=0.5):
+    """Creates a `BinaryAccuracy` instance.
+
+    Args:
+      name: (Optional) string name of the metric instance.
+      dtype: (Optional) data type of the metric result.
+      threshold: (Optional) Float representing the threshold for deciding
+        whether prediction values are 1 or 0.
+    """
+    super(BinaryAccuracy, self).__init__(
+        fn=binary_accuracy, name=name, dtype=dtype, threshold=threshold)
+
+
@tf_export('keras.metrics.binary_accuracy')
-def binary_accuracy(y_true, y_pred):
- return K.mean(math_ops.equal(y_true, math_ops.round(y_pred)), axis=-1)
+def binary_accuracy(y_true, y_pred, threshold=0.5):
+ """Returns the mean over the last axis of thresholded matches.
+
+ Args:
+ y_true: The ground truth values.
+ y_pred: The predicted values.
+ threshold: Predictions strictly greater than this value count as 1,
+ all others as 0.
+ """
+ threshold = math_ops.cast(threshold, y_pred.dtype)
+ y_pred = math_ops.cast(y_pred > threshold, y_pred.dtype)
+ return K.mean(math_ops.equal(y_true, y_pred), axis=-1)
@tf_export('keras.metrics.categorical_accuracy')
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 15e793f5fc..d583379708 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -18,67 +18,72 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import numpy as np
-from tensorflow.python import keras
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import layers
+from tensorflow.python.keras import metrics
+from tensorflow.python.keras.engine.training import Model
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import util as checkpointable_utils
class KerasMetricsTest(test.TestCase):
def test_metrics(self):
with self.test_session():
- y_a = keras.backend.variable(np.random.random((6, 7)))
- y_b = keras.backend.variable(np.random.random((6, 7)))
- for metric in [keras.metrics.binary_accuracy,
- keras.metrics.categorical_accuracy]:
+ y_a = K.variable(np.random.random((6, 7)))
+ y_b = K.variable(np.random.random((6, 7)))
+ for metric in [metrics.binary_accuracy, metrics.categorical_accuracy]:
output = metric(y_a, y_b)
- self.assertEqual(keras.backend.eval(output).shape, (6,))
+ self.assertEqual(K.eval(output).shape, (6,))
def test_sparse_categorical_accuracy(self):
with self.test_session():
- metric = keras.metrics.sparse_categorical_accuracy
- y_a = keras.backend.variable(np.random.randint(0, 7, (6,)))
- y_b = keras.backend.variable(np.random.random((6, 7)))
- self.assertEqual(keras.backend.eval(metric(y_a, y_b)).shape, (6,))
+ metric = metrics.sparse_categorical_accuracy
+ y_a = K.variable(np.random.randint(0, 7, (6,)))
+ y_b = K.variable(np.random.random((6, 7)))
+ self.assertEqual(K.eval(metric(y_a, y_b)).shape, (6,))
def test_sparse_top_k_categorical_accuracy(self):
with self.test_session():
- y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1],
- [0.1, 0.2, 0.7]]))
- y_true = keras.backend.variable(np.array([[1], [0]]))
- result = keras.backend.eval(
- keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3))
+ y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
+ y_true = K.variable(np.array([[1], [0]]))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3))
self.assertEqual(result, 1)
- result = keras.backend.eval(
- keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2))
self.assertEqual(result, 0.5)
- result = keras.backend.eval(
- keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
self.assertEqual(result, 0.)
def test_top_k_categorical_accuracy(self):
with self.test_session():
- y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1],
- [0.1, 0.2, 0.7]]))
- y_true = keras.backend.variable(np.array([[0, 1, 0], [1, 0, 0]]))
- result = keras.backend.eval(
- keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=3))
+ y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
+ y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]]))
+ result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=3))
self.assertEqual(result, 1)
- result = keras.backend.eval(
- keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=2))
+ result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=2))
self.assertEqual(result, 0.5)
- result = keras.backend.eval(
- keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=1))
+ result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=1))
self.assertEqual(result, 0.)
def test_stateful_metrics(self):
with self.test_session():
np.random.seed(1334)
- class BinaryTruePositives(keras.layers.Layer):
+ class BinaryTruePositives(layers.Layer):
"""Stateful Metric to count the total true positives over all batches.
Assumes predictions and targets of shape `(samples, 1)`.
@@ -91,11 +96,11 @@ class KerasMetricsTest(test.TestCase):
def __init__(self, name='true_positives', **kwargs):
super(BinaryTruePositives, self).__init__(name=name, **kwargs)
- self.true_positives = keras.backend.variable(value=0, dtype='int32')
+ self.true_positives = K.variable(value=0, dtype='int32')
self.stateful = True
def reset_states(self):
- keras.backend.set_value(self.true_positives, 0)
+ K.set_value(self.true_positives, 0)
def __call__(self, y_true, y_pred):
"""Computes the number of true positives in a batch.
@@ -120,14 +125,14 @@ class KerasMetricsTest(test.TestCase):
return current_true_pos + true_pos
metric_fn = BinaryTruePositives()
- config = keras.metrics.serialize(metric_fn)
- metric_fn = keras.metrics.deserialize(
+ config = metrics.serialize(metric_fn)
+ metric_fn = metrics.deserialize(
config, custom_objects={'BinaryTruePositives': BinaryTruePositives})
# Test on simple model
- inputs = keras.Input(shape=(2,))
- outputs = keras.layers.Dense(1, activation='sigmoid')(inputs)
- model = keras.Model(inputs, outputs)
+ inputs = layers.Input(shape=(2,))
+ outputs = layers.Dense(1, activation='sigmoid')(inputs)
+ model = Model(inputs, outputs)
model.compile(optimizer='sgd',
loss='binary_crossentropy',
metrics=['acc', metric_fn])
@@ -184,6 +189,214 @@ class KerasMetricsTest(test.TestCase):
self.assertAllClose(
val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
+ @test_util.run_in_graph_and_eager_modes
+ def test_mean(self):
+ """Checks config, state accumulation and reset of an unweighted `Mean`."""
+ m = metrics.Mean(name='my_mean')
+
+ # check config
+ self.assertEqual(m.name, 'my_mean')
+ self.assertTrue(m.stateful)
+ self.assertEqual(m.dtype, dtypes.float32)
+ self.assertEqual(len(m.variables), 2)
+ self.evaluate(variables.global_variables_initializer())
+
+ # check initial state
+ self.assertEqual(self.evaluate(m.total), 0)
+ self.assertEqual(self.evaluate(m.count), 0)
+
+ # check __call__()
+ self.assertEqual(self.evaluate(m(100)), 100)
+ self.assertEqual(self.evaluate(m.total), 100)
+ self.assertEqual(self.evaluate(m.count), 1)
+
+ # check update_state() and result() + state accumulation + tensor input
+ update_op = m.update_state(ops.convert_n_to_tensor([1, 5]))
+ self.evaluate(update_op)
+ self.assertAlmostEqual(self.evaluate(m.result()), 106 / 3, 2)
+ self.assertEqual(self.evaluate(m.total), 106) # 100 + 1 + 5
+ self.assertEqual(self.evaluate(m.count), 3)
+
+ # check reset_states()
+ m.reset_states()
+ self.assertEqual(self.evaluate(m.total), 0)
+ self.assertEqual(self.evaluate(m.count), 0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_mean_with_sample_weight(self):
+ """Checks `Mean` with scalar, matching, squeezed and expanded weights.
+
+ The state is never reset, so every expected value builds on the
+ previous steps.
+ """
+ m = metrics.Mean(dtype=dtypes.float64)
+ self.assertEqual(m.dtype, dtypes.float64)
+ self.evaluate(variables.global_variables_initializer())
+
+ # check scalar weight
+ result_t = m(100, sample_weight=0.5)
+ self.assertEqual(self.evaluate(result_t), 50 / 0.5)
+ self.assertEqual(self.evaluate(m.total), 50)
+ self.assertEqual(self.evaluate(m.count), 0.5)
+
+ # check weights not scalar and weights rank matches values rank
+ result_t = m([1, 5], sample_weight=[1, 0.2])
+ result = self.evaluate(result_t)
+ self.assertAlmostEqual(result, 52 / 1.7, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 52, 2) # 50 + 1 + 5 * 0.2
+ self.assertAlmostEqual(self.evaluate(m.count), 1.7, 2) # 0.5 + 1.2
+
+ # check weights broadcast
+ result_t = m([1, 2], sample_weight=0.5)
+ self.assertAlmostEqual(self.evaluate(result_t), 53.5 / 2.7, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 53.5, 2) # 52 + 0.5 + 1
+ self.assertAlmostEqual(self.evaluate(m.count), 2.7, 2) # 1.7 + 0.5 + 0.5
+
+ # check weights squeeze
+ result_t = m([1, 5], sample_weight=[[1], [0.2]])
+ self.assertAlmostEqual(self.evaluate(result_t), 55.5 / 3.9, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 55.5, 2) # 53.5 + 1 + 1
+ self.assertAlmostEqual(self.evaluate(m.count), 3.9, 2) # 2.7 + 1.2
+
+ # check weights expand
+ result_t = m([[1], [5]], sample_weight=[1, 0.2])
+ self.assertAlmostEqual(self.evaluate(result_t), 57.5 / 5.1, 2)
+ self.assertAlmostEqual(self.evaluate(m.total), 57.5, 2) # 55.5 + 1 + 1
+ self.assertAlmostEqual(self.evaluate(m.count), 5.1, 2) # 3.9 + 1.2
+
+ def test_mean_graph_with_placeholder(self):
+ """Checks `Mean` in graph mode when values and weights are placeholders."""
+ with context.graph_mode(), self.test_session() as sess:
+ m = metrics.Mean()
+ v = array_ops.placeholder(dtypes.float32)
+ w = array_ops.placeholder(dtypes.float32)
+ sess.run(variables.global_variables_initializer())
+
+ # check __call__()
+ result_t = m(v, sample_weight=w)
+ result = sess.run(result_t, feed_dict=({v: 100, w: 0.5}))
+ self.assertEqual(sess.run(m.total), 50)
+ self.assertEqual(sess.run(m.count), 0.5)
+ self.assertEqual(result, 50 / 0.5)
+
+ # run the same __call__() tensor again with vector feeds and check that
+ # the state accumulates
+ result = sess.run(result_t, feed_dict=({v: [1, 5], w: [1, 0.2]}))
+ self.assertAlmostEqual(sess.run(m.total), 52, 2) # 50 + 1 + 5 * 0.2
+ self.assertAlmostEqual(sess.run(m.count), 1.7, 2) # 0.5 + 1.2
+ self.assertAlmostEqual(result, 52 / 1.7, 2)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_save_restore(self):
+ """Checks that the state of a `Mean` survives a checkpoint round trip."""
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+ m = metrics.Mean()
+ checkpoint = checkpointable_utils.Checkpoint(mean=m)
+ self.evaluate(variables.global_variables_initializer())
+
+ # update state
+ self.evaluate(m(100.))
+ self.evaluate(m(200.))
+
+ # save checkpoint and then add an update
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.evaluate(m(1000.))
+
+ # restore to the same checkpoint mean object
+ checkpoint.restore(save_path).assert_consumed().run_restore_ops()
+ self.evaluate(m(300.))
+ # The 1000. update made after saving is expected to be gone:
+ # (100 + 200 + 300) / 3
+ self.assertEqual(200., self.evaluate(m.result()))
+
+ # restore to a different checkpoint mean object
+ restore_mean = metrics.Mean()
+ restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean)
+ status = restore_checkpoint.restore(save_path)
+ restore_update = restore_mean(300.)
+ status.assert_consumed().run_restore_ops()
+ self.evaluate(restore_update)
+ self.assertEqual(200., self.evaluate(restore_mean.result()))
+ self.assertEqual(3, self.evaluate(restore_mean.count))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_binary_accuracy(self):
+ """Checks `BinaryAccuracy` config, squeezing, weights and shape errors.
+
+ The state is never reset, so every expected value builds on the
+ previous steps.
+ """
+ acc_obj = metrics.BinaryAccuracy(name='my acc')
+
+ # check config
+ self.assertEqual(acc_obj.name, 'my acc')
+ self.assertTrue(acc_obj.stateful)
+ self.assertEqual(len(acc_obj.variables), 2)
+ self.assertEqual(acc_obj.dtype, dtypes.float32)
+ self.evaluate(variables.global_variables_initializer())
+
+ # verify that correct value is returned
+ update_op = acc_obj.update_state([[1], [0]], [[1], [0]])
+ self.evaluate(update_op)
+ result = self.evaluate(acc_obj.result())
+ self.assertEqual(result, 1) # 2/2
+
+ # check y_pred squeeze
+ update_op = acc_obj.update_state([[1], [1]], [[[1]], [[0]]])
+ self.evaluate(update_op)
+ result = self.evaluate(acc_obj.result())
+ self.assertAlmostEqual(result, 0.75, 2) # 3/4
+
+ # check y_true squeeze
+ result_t = acc_obj([[[1]], [[1]]], [[1], [0]])
+ result = self.evaluate(result_t)
+ self.assertAlmostEqual(result, 0.67, 2) # 4/6
+
+ # check with sample_weight
+ result_t = acc_obj([[1], [1]], [[1], [0]], [[0.5], [0.2]])
+ result = self.evaluate(result_t)
+ self.assertAlmostEqual(result, 0.67, 2) # 4.5/6.7
+
+ # check incompatible shapes
+ with self.assertRaisesRegexp(ValueError,
+ r'Shapes \(1,\) and \(2,\) are incompatible'):
+ acc_obj.update_state([1, 1], [1])
+
+ @test_util.run_in_graph_and_eager_modes
+  def test_binary_accuracy_threshold(self):
+    """Checks that `BinaryAccuracy` applies a custom threshold."""
+    acc_obj = metrics.BinaryAccuracy(threshold=0.7)
+    self.evaluate(variables.global_variables_initializer())
+    y_true = [[1], [1], [0], [0]]
+    y_pred = [[0.9], [0.6], [0.4], [0.8]]
+    # With a threshold of 0.7 the predictions become 1, 0, 0, 1, so two of
+    # the four labels are matched.
+    accuracy = self.evaluate(acc_obj(y_true, y_pred))
+    self.assertAlmostEqual(accuracy, 0.5, 2)
+
+ @test_util.run_in_graph_and_eager_modes
+  def test_invalid_result(self):
+    """`result()` must reject a value that is not a Tensor or Operation."""
+
+    class InvalidResult(metrics.Metric):
+      """Metric whose `result()` returns a plain Python int."""
+
+      def __init__(self, name='invalid-result', dtype=dtypes.float64):
+        super(InvalidResult, self).__init__(name=name, dtype=dtype)
+
+      def update_state(self, *args, **kwargs):
+        pass
+
+      def result(self):
+        return 1
+
+    expected_message = (
+        'Metric invalid-result\'s result must be a Tensor or Operation, given:')
+    metric = InvalidResult()
+    with self.assertRaisesRegexp(TypeError, expected_message):
+      metric.result()
+
+ @test_util.run_in_graph_and_eager_modes
+  def test_invalid_update(self):
+    """`update_state()` must reject a value that is not a Tensor or Operation."""
+
+    class InvalidUpdate(metrics.Metric):
+      """Metric whose `update_state()` returns a plain Python list."""
+
+      def __init__(self, name='invalid-update', dtype=dtypes.float64):
+        super(InvalidUpdate, self).__init__(name=name, dtype=dtype)
+
+      def update_state(self, *args, **kwargs):
+        return [1]
+
+      def result(self):
+        pass
+
+    expected_message = (
+        'Metric invalid-update\'s update must be a Tensor or Operation, given:')
+    metric = InvalidUpdate()
+    with self.assertRaisesRegexp(TypeError, expected_message):
+      metric.update_state()
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 3ac4852eff..5fbc191e78 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -29,6 +29,8 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import data_structures
@@ -65,6 +67,22 @@ class SimpleTestModel(keras.Model):
return self.dense2(x)
+class SimpleConvTestModel(keras.Model):
+
+ def __init__(self, num_classes=10):
+ super(SimpleConvTestModel, self).__init__(name='test_model')
+ self.num_classes = num_classes
+
+ self.conv1 = keras.layers.Conv2D(32, (3, 3), activation='relu')
+ self.flatten = keras.layers.Flatten()
+ self.dense1 = keras.layers.Dense(num_classes, activation='softmax')
+
+ def call(self, x):
+ x = self.conv1(x)
+ x = self.flatten(x)
+ return self.dense1(x)
+
+
class MultiIOTestModel(keras.Model):
def __init__(self, use_bn=False, use_dp=False, num_classes=(2, 3)):
@@ -174,6 +192,213 @@ def get_nested_model_3(input_dim, num_classes):
class ModelSubclassingTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
+ def test_invalid_input_shape_build(self):
+ num_classes = 2
+ input_dim = 50
+
+ model = SimpleTestModel(num_classes=num_classes,
+ use_dp=True,
+ use_bn=True)
+
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ with self.assertRaisesRegexp(
+ ValueError, 'input shape is not one of the valid types'):
+ model.build(input_shape=tensor_shape.Dimension(input_dim))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_embed_dtype_with_subclass_build(self):
+ class Embedding(keras.layers.Layer):
+ """An Embedding layer."""
+
+ def __init__(self, vocab_size, embedding_dim, **kwargs):
+ super(Embedding, self).__init__(**kwargs)
+ self.vocab_size = vocab_size
+ self.embedding_dim = embedding_dim
+
+ def build(self, _):
+ self.embedding = self.add_variable(
+ 'embedding_kernel',
+ shape=[self.vocab_size, self.embedding_dim],
+ dtype=np.float32,
+ initializer=init_ops.random_uniform_initializer(-0.1, 0.1),
+ trainable=True)
+
+ def call(self, x):
+ return embedding_ops.embedding_lookup(self.embedding, x)
+
+ class EmbedModel(keras.Model):
+
+ def __init__(self, vocab_size, embed_size):
+ super(EmbedModel, self).__init__()
+ self.embed1 = Embedding(vocab_size, embed_size)
+
+ def call(self, inputs):
+ return self.embed1(inputs)
+
+ model = EmbedModel(100, 20)
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ with self.assertRaisesRegexp(
+ ValueError, 'if your layers do not support float type inputs'):
+ model.build(input_shape=(35, 20))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_single_time_step_rnn_build(self):
+ dim = 4
+ timesteps = 1
+ batch_input_shape = (None, timesteps, dim)
+ units = 3
+
+ class SimpleRNNModel(keras.Model):
+
+ def __init__(self):
+ super(SimpleRNNModel, self).__init__()
+ self.lstm = keras.layers.LSTM(units)
+
+ def call(self, inputs):
+ return self.lstm(inputs)
+
+ model = SimpleRNNModel()
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build(batch_input_shape)
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+ model(array_ops.ones((32, timesteps, dim)))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_single_io_subclass_build(self):
+ num_classes = 2
+ input_dim = 50
+ batch_size = None
+
+ model = SimpleTestModel(num_classes=num_classes,
+ use_dp=True,
+ use_bn=True)
+
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build(input_shape=(batch_size, input_dim))
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+ model(array_ops.ones((32, input_dim)))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_single_io_dimension_subclass_build(self):
+ num_classes = 2
+ input_dim = tensor_shape.Dimension(50)
+ batch_size = tensor_shape.Dimension(None)
+
+ model = SimpleTestModel(num_classes=num_classes,
+ use_dp=True,
+ use_bn=True)
+
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build(input_shape=(batch_size, input_dim))
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+ model(array_ops.ones((32, input_dim)))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_multidim_io_subclass_build(self):
+ num_classes = 10
+ # Input size, e.g. image
+ batch_size = 32
+ input_shape = (32, 32, 3)
+
+ model = SimpleConvTestModel(num_classes)
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ batch_input_shape = (batch_size,) + input_shape
+ model.build(input_shape=batch_input_shape)
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+
+ model(array_ops.ones(batch_input_shape))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_tensorshape_io_subclass_build(self):
+ num_classes = 10
+ # Input size, e.g. image
+ batch_size = None
+ input_shape = (32, 32, 3)
+
+ model = SimpleConvTestModel(num_classes)
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build(
+ input_shape=tensor_shape.TensorShape((batch_size,) + input_shape))
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+
+ model(array_ops.ones((32,) + input_shape))
+
+ def test_subclass_save_model(self):
+ num_classes = 10
+ # Input size, e.g. image
+ batch_size = None
+ input_shape = (32, 32, 3)
+
+ model = SimpleConvTestModel(num_classes)
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build(
+ input_shape=tensor_shape.TensorShape((batch_size,) + input_shape))
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+ weights = model.get_weights()
+
+ tf_format_name = os.path.join(self.get_temp_dir(), 'ckpt')
+ model.save_weights(tf_format_name)
+ if h5py is not None:
+ hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5')
+ model.save_weights(hdf5_format_name)
+
+ model = SimpleConvTestModel(num_classes)
+ model.build(
+ input_shape=tensor_shape.TensorShape((batch_size,) + input_shape))
+ if h5py is not None:
+ model.load_weights(hdf5_format_name)
+ self.assertAllClose(weights, model.get_weights())
+ model.load_weights(tf_format_name)
+ self.assertAllClose(weights, model.get_weights())
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_multi_io_subclass_build(self):
+ batch_size = None
+ num_samples = 1000
+ input_dim = 50
+ model = MultiIOTestModel()
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ batch_input_shape = tensor_shape.TensorShape((batch_size, input_dim))
+ model.build(
+ input_shape=[batch_input_shape, batch_input_shape])
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+ x1 = array_ops.ones((num_samples, input_dim))
+ x2 = array_ops.ones((num_samples, input_dim))
+ model([x1, x2])
+
+ @test_util.run_in_graph_and_eager_modes
def test_single_io_workflow_with_np_arrays(self):
num_classes = 2
num_samples = 100
@@ -750,6 +975,16 @@ class CustomCallModel(keras.Model):
return combined
+class TrainingNoDefaultModel(keras.Model):
+
+ def __init__(self):
+ super(TrainingNoDefaultModel, self).__init__()
+ self.dense1 = keras.layers.Dense(1)
+
+ def call(self, x, training):
+ return self.dense1(x)
+
+
class CustomCallSignatureTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes
@@ -767,6 +1002,32 @@ class CustomCallSignatureTests(test.TestCase):
self.assertAllClose(expected_output, self.evaluate(output))
@test_util.run_in_graph_and_eager_modes
+ def test_training_args_call_build(self):
+ input_dim = 2
+
+ model = TrainingNoDefaultModel()
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build((None, input_dim))
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_custom_call_kwargs_and_build(self):
+ first_input_shape = (2, 3)
+ second_input_shape = (2, 5)
+
+ model = CustomCallModel()
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ with self.assertRaisesRegexp(
+ ValueError, 'cannot build your model if it has positional'):
+ model.build(input_shape=[first_input_shape, second_input_shape])
+
+ @test_util.run_in_graph_and_eager_modes
def test_inputs_in_signature(self):
class HasInputsAndOtherPositional(keras.Model):
@@ -829,14 +1090,9 @@ class CustomCallSignatureTests(test.TestCase):
def test_training_no_default(self):
- class TrainingNoDefault(keras.Model):
-
- def call(self, x, training):
- return x
-
with context.graph_mode():
- model = TrainingNoDefault()
- arg = array_ops.ones([])
+ model = TrainingNoDefaultModel()
+ arg = array_ops.ones([1, 1])
model(arg, True)
six.assertCountEqual(self, [arg], model.inputs)
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 17aba7d86c..6e8ee06ff5 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from collections import OrderedDict
import numpy as np
from tensorflow.python import keras
@@ -185,75 +184,3 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
# for further checks in the caller function
return actual_output
-
-def _combine_named_parameters(**kwargs):
- """Generate combinations based on its keyword arguments.
-
- Two sets of returned combinations can be concatenated using +. Their product
- can be computed using `times()`.
-
- Args:
- **kwargs: keyword arguments of form `option=[possibilities, ...]`
- or `option=the_only_possibility`.
-
- Returns:
- a list of dictionaries for each combination. Keys in the dictionaries are
- the keyword argument names. Each key has one value - one of the
- corresponding keyword argument values.
- """
- if not kwargs:
- return [OrderedDict()]
-
- sort_by_key = lambda k: k[0][0]
- kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key))
- first = list(kwargs.items())[0]
-
- rest = dict(list(kwargs.items())[1:])
- rest_combined = _combine_named_parameters(**rest)
-
- key = first[0]
- values = first[1]
- if not isinstance(values, list):
- values = [values]
-
- combinations = [
- OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key))
- for v in values
- for combined in rest_combined
- ]
- return combinations
-
-
-def generate_combinations_with_testcase_name(**kwargs):
- """Generate combinations based on its keyword arguments using combine().
-
- This function calls combine() and appends a testcase name to the list of
- dictionaries returned. The 'testcase_name' key is a required for named
- parameterized tests.
-
- Args:
- **kwargs: keyword arguments of form `option=[possibilities, ...]`
- or `option=the_only_possibility`.
-
- Returns:
- a list of dictionaries for each combination. Keys in the dictionaries are
- the keyword argument names. Each key has one value - one of the
- corresponding keyword argument values.
- """
- combinations = _combine_named_parameters(**kwargs)
- named_combinations = []
- for combination in combinations:
- assert isinstance(combination, OrderedDict)
- name = ''.join([
- '_{}_{}'.format(
- ''.join(filter(str.isalnum, key)),
- ''.join(filter(str.isalnum, str(value))))
- for key, value in combination.items()
- ])
- named_combinations.append(
- OrderedDict(
- list(combination.items()) + [('testcase_name',
- '_test{}'.format(name))]))
-
- return named_combinations
-
diff --git a/tensorflow/python/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py
index 9d9c72b162..c24e87308b 100644
--- a/tensorflow/python/keras/utils/np_utils.py
+++ b/tensorflow/python/keras/utils/np_utils.py
@@ -33,7 +33,8 @@ def to_categorical(y, num_classes=None):
num_classes: total number of classes.
Returns:
- A binary matrix representation of the input.
+ A binary matrix representation of the input. The classes axis is placed
+ last.
"""
y = np.array(y, dtype='int')
input_shape = y.shape