about summary refs log tree commit diff homepage
path: root/tensorflow/python/keras/engine/training.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/engine/training.py')
-rw-r--r-- tensorflow/python/keras/engine/training.py 124
1 files changed, 116 insertions, 8 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index fce6cbdb7a..4df739254b 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -27,6 +27,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
@@ -42,6 +43,8 @@ from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -115,6 +118,7 @@ class Model(Network):
# Create a cache for dataset - uninitialized iterators
self._dataset_iterator_cache = weakref.WeakKeyDictionary()
+ @checkpointable.no_automatic_dependency_tracking
def compile(self,
optimizer,
loss=None,
@@ -178,6 +182,11 @@ class Model(Network):
raise ValueError('Only TF native optimizers are supported in Eager mode.')
self.optimizer = optimizers.get(optimizer)
+ # We've disabled automatic dependency tracking for this method, but do want
+ # to add a checkpoint dependency on the optimizer if it's checkpointable.
+ if isinstance(self.optimizer, checkpointable.CheckpointableBase):
+ self._track_checkpointable(
+ self.optimizer, name='optimizer', overwrite=True)
self.loss = loss
self.metrics = metrics or []
self.loss_weights = loss_weights
@@ -210,10 +219,9 @@ class Model(Network):
for name in self.output_names:
if name not in loss:
logging.warning(
- 'Output "' + name + '" missing from loss dictionary. '
- 'We assume this was done on purpose, '
- 'and we will not be expecting '
- 'any data to be passed to "' + name + '" during training.')
+ 'Output "' + name + '" missing from loss dictionary. We assume '
+ 'this was done on purpose. The fit and evaluate APIs will not be '
+ 'expecting any data to be passed to "' + name + '".')
loss_functions.append(losses.get(loss.get(name)))
elif isinstance(loss, list):
if len(loss) != len(self.outputs):
@@ -554,6 +562,95 @@ class Model(Network):
trainable_weights = self.trainable_weights
self._collected_trainable_weights = trainable_weights
+ def build(self, input_shape):
+ """Build the model based on input shapes received.
+
+ This is to be used for subclassed models, which do not know at instantiation
+ time what their inputs look like.
+
+ Args:
+ input_shape: Single tuple, TensorShape, or list of shapes, where shapes
+ are tuples, integers, or TensorShapes.
+
+ Raises:
+ ValueError:
+ 1. In case of invalid user-provided data (not of type tuple,
+ list, or TensorShape).
+ 2. If the model requires call arguments that are agnostic
+ to the input shapes (positional or kwarg in call signature).
+ 3. If not all layers were properly built.
+ 4. If float type inputs are not supported within the layers.
+
+ In each of these cases, the user should build their model by calling it
+ on real tensor data.
+ """
+ # Graph networks are simply marked as built; `input_shape` is not used on
+ # this path.
+ if self._is_graph_network:
+ self.built = True
+ return
+
+ # If subclass network
+ if input_shape is None:
+ raise ValueError('Input shape must be defined when calling build on a '
+ 'model subclass network.')
+ # Only these container types are accepted for `input_shape`; anything else
+ # is rejected with a ValueError below.
+ valid_types = (tuple, list, tensor_shape.TensorShape)
+ if not isinstance(input_shape, valid_types):
+ raise ValueError('Specified input shape is not one of the valid types. '
+ 'Please specify a batch input shape of type tuple or '
+ 'list of input shapes. User provided '
+ 'input type: {}'.format(type(input_shape)))
+
+ # Helper: returns a tensor of ones with dtype K.floatx() for `shape`, with
+ # every unknown (None) dimension replaced by 1.
+ def _generate_dummy_data_from_shape(shape):
+ if isinstance(shape, tensor_shape.TensorShape):
+ shape = shape.as_list()
+
+ # Replace Nones in input shape with dummy `1` value
+ shape = [x.value if isinstance(x, tensor_shape.Dimension) else x
+ for x in shape]
+ shape = [1 if x is None else x for x in shape]
+ return array_ops.ones(shape, dtype=K.floatx())
+
+ # The dummy forward pass is only attempted when a non-empty shape was given
+ # and `self.inputs` is falsy.
+ # NOTE(review): indentation is lost in this view. The statements down to the
+ # try/except presumably belong to this branch because they consume `x` --
+ # confirm against the real file.
+ if input_shape and not self.inputs:
+ # NOTE(review): any `list` takes the list-of-shapes branch, so a flat list
+ # such as [None, 10] has each element (None / int) handed to the helper,
+ # which then tries to iterate it. Confirm whether a flat list of
+ # dimensions should be treated as a single shape instead.
+ if isinstance(input_shape, list):
+ # List of input shapes
+ x = [_generate_dummy_data_from_shape(shape) for shape in input_shape]
+ else:
+ x = _generate_dummy_data_from_shape(input_shape)
+
+ kwargs = {}
+ # Number of named parameters in the signature of `call`. The branch
+ # comments below count `self` as one of them.
+ # NOTE(review): assuming tf_inspect.getargspec mirrors inspect.getargspec,
+ # `.args` holds named parameters only, so the `> 2` branch reacts to extra
+ # named arguments rather than to *args/**kwargs -- confirm.
+ num_call_args = len(tf_inspect.getargspec(self.call).args)
+ if self._expects_training_arg and num_call_args == 3:
+ # Has call signature of call(self, input, training)
+ kwargs['training'] = False
+ elif num_call_args > 2:
+ # Has invalid call signature of call(self, input, *args, **kwargs)
+ raise ValueError('Currently, you cannot build your model if it has '
+ 'positional or keyword arguments that are not '
+ 'inputs to the model, but are required for its '
+ '`call` method. Instead, in order to instantiate '
+ 'and build your model, `call` your model on real '
+ 'tensor data with all expected call arguments.')
+
+ # Run `call` on the dummy data. InvalidArgumentError and TypeError are
+ # both re-raised as a ValueError about non-float inputs.
+ # NOTE(review): the original exception is not attached to the ValueError,
+ # and any TypeError raised inside `call` gets this same message.
+ try:
+ self.call(x, **kwargs)
+ except (errors.InvalidArgumentError, TypeError):
+ raise ValueError('You cannot build your model by calling `build` '
+ 'if your layers do not support float type inputs. '
+ 'Instead, in order to instantiate and build your '
+ 'model, `call` your model on real tensor data (of '
+ 'the correct dtype).')
+
+ # Register the layers held in `self._layers` via `_track_layers`.
+ if self._layers:
+ self._track_layers(self._layers)
+ # Every layer must report `built`; otherwise a ValueError is raised and
+ # `self.built` is left unset.
+ if self.layers:
+ for layer in self.layers:
+ if not layer.built:
+ raise ValueError('Layer: {} was not built in your model. Calling '
+ '`build` manually on a subclassed model is only '
+ 'allowed for models with a static topology. '
+ 'In this case, you can build your model by '
+ 'calling it on real tensor data.'.format(layer))
+ self.built = True
+
def _check_trainable_weights_consistency(self):
"""Check trainable weights count consistency.
@@ -592,7 +689,7 @@ class Model(Network):
# Unconditional updates
updates += self.get_updates_for(None)
# Conditional updates relevant to this model
- updates += self.get_updates_for(self._feed_inputs)
+ updates += self.get_updates_for(self.inputs)
# Stateful metrics updates
updates += self.metrics_updates
# Gets loss and metrics. Updates weights at each call.
@@ -890,7 +987,11 @@ class Model(Network):
for output_shape, loss_fn in zip(self._feed_output_shapes,
self._feed_loss_fns):
if loss_fn is losses.sparse_categorical_crossentropy:
- feed_output_shapes.append(output_shape[:-1] + (1,))
+ if K.image_data_format() == 'channels_first':
+ feed_output_shapes.append(
+ (output_shape[0], 1) + output_shape[2:])
+ else:
+ feed_output_shapes.append(output_shape[:-1] + (1,))
elif (not hasattr(loss_fn, '__name__') or
getattr(losses, loss_fn.__name__, None) is None):
# If `loss_fn` is not a function (e.g. callable class)
@@ -941,6 +1042,7 @@ class Model(Network):
str(x[0].shape[0]) + ' samples')
return x, y, sample_weights
+ @checkpointable.no_automatic_dependency_tracking
def _set_inputs(self, inputs, training=None):
"""Set model's input and output specs based on the input data received.
@@ -980,15 +1082,20 @@ class Model(Network):
inputs = inputs[0]
if tensor_util.is_tensor(inputs):
- input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
+ if context.executing_eagerly():
+ input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
+ self.build(input_shape=input_shape)
+ else:
+ self.symbolic_set_inputs(inputs)
else:
input_shape = (None,) + inputs.shape[1:]
- self.build(input_shape=input_shape)
+ self.build(input_shape=input_shape)
elif context.executing_eagerly():
self._eager_set_inputs(inputs)
else:
self._symbolic_set_inputs(inputs, training=training)
+ @checkpointable.no_automatic_dependency_tracking
def _eager_set_inputs(self, inputs):
"""Set model's input and output specs based on the input data received.
@@ -1041,6 +1148,7 @@ class Model(Network):
'output_%d' % (i + 1) for i in range(len(dummy_output_values))]
self.built = True
+ @checkpointable.no_automatic_dependency_tracking
def _symbolic_set_inputs(self, inputs, outputs=None, training=None):
"""Set model's inputs and output specs based.