Diffstat (limited to 'tensorflow/python/keras/engine/training.py')
-rw-r--r--  tensorflow/python/keras/engine/training.py | 112
1 file changed, 105 insertions, 7 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index bd03f4871f..4df739254b 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -27,6 +27,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
@@ -43,6 +44,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -217,10 +219,9 @@ class Model(Network):
      for name in self.output_names:
        if name not in loss:
          logging.warning(
-              'Output "' + name + '" missing from loss dictionary. '
-              'We assume this was done on purpose, '
-              'and we will not be expecting '
-              'any data to be passed to "' + name + '" during training.')
+              'Output "' + name + '" missing from loss dictionary. We assume '
+              'this was done on purpose. The fit and evaluate APIs will not be '
+              'expecting any data to be passed to "' + name + '".')
        loss_functions.append(losses.get(loss.get(name)))
    elif isinstance(loss, list):
      if len(loss) != len(self.outputs):
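As context for the reworded warning, here is a hedged sketch (not part of the patch) of the situation it covers: a two-output functional model compiled with a loss dictionary that deliberately omits one output. The layer names `main_out` and `aux_out` are illustrative.

import tensorflow as tf  # assuming a TF version that ships tf.keras

inputs = tf.keras.Input(shape=(8,))
hidden = tf.keras.layers.Dense(4, activation='relu')(inputs)
main_out = tf.keras.layers.Dense(1, name='main_out')(hidden)
aux_out = tf.keras.layers.Dense(1, name='aux_out')(hidden)
model = tf.keras.Model(inputs, [main_out, aux_out])

# 'aux_out' is intentionally missing from the loss dict: compile() only logs
# the warning above, and fit()/evaluate() will not expect targets for it.
model.compile(optimizer='sgd', loss={'main_out': 'mse'})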
@@ -561,6 +562,95 @@ class Model(Network):
trainable_weights = self.trainable_weights
self._collected_trainable_weights = trainable_weights
+  def build(self, input_shape):
+    """Build the model based on input shapes received.
+
+    This is to be used for subclassed models, which do not know at instantiation
+    time what their inputs look like.
+
+    Args:
+      input_shape: Single tuple, TensorShape, or list of shapes, where shapes
+        are tuples, integers, or TensorShapes.
+
+    Raises:
+      ValueError:
+        1. In case of invalid user-provided data (not of type tuple,
+           list, or TensorShape).
+        2. If the model requires call arguments that are agnostic
+           to the input shapes (positional or kwarg in call signature).
+        3. If not all layers were properly built.
+        4. If float type inputs are not supported within the layers.
+
+      In each of these cases, the user should build their model by calling it
+      on real tensor data.
+    """
+    if self._is_graph_network:
+      self.built = True
+      return
+
+    # If subclass network
+    if input_shape is None:
+      raise ValueError('Input shape must be defined when calling build on a '
+                       'model subclass network.')
+    valid_types = (tuple, list, tensor_shape.TensorShape)
+    if not isinstance(input_shape, valid_types):
+      raise ValueError('Specified input shape is not one of the valid types. '
+                       'Please specify a batch input shape of type tuple or '
+                       'list of input shapes. User provided '
+                       'input type: {}'.format(type(input_shape)))
+
+    def _generate_dummy_data_from_shape(shape):
+      if isinstance(shape, tensor_shape.TensorShape):
+        shape = shape.as_list()
+
+      # Replace Nones in input shape with dummy `1` value
+      shape = [x.value if isinstance(x, tensor_shape.Dimension) else x
+               for x in shape]
+      shape = [1 if x is None else x for x in shape]
+      return array_ops.ones(shape, dtype=K.floatx())
+
+    if input_shape and not self.inputs:
+      if isinstance(input_shape, list):
+        # List of input shapes
+        x = [_generate_dummy_data_from_shape(shape) for shape in input_shape]
+      else:
+        x = _generate_dummy_data_from_shape(input_shape)
+
+      kwargs = {}
+      num_call_args = len(tf_inspect.getargspec(self.call).args)
+      if self._expects_training_arg and num_call_args == 3:
+        # Has call signature of call(self, input, training)
+        kwargs['training'] = False
+      elif num_call_args > 2:
+        # Has invalid call signature of call(self, input, *args, **kwargs)
+        raise ValueError('Currently, you cannot build your model if it has '
+                         'positional or keyword arguments that are not '
+                         'inputs to the model, but are required for its '
+                         '`call` method. Instead, in order to instantiate '
+                         'and build your model, `call` your model on real '
+                         'tensor data with all expected call arguments.')
+
+      try:
+        self.call(x, **kwargs)
+      except (errors.InvalidArgumentError, TypeError):
+        raise ValueError('You cannot build your model by calling `build` '
+                         'if your layers do not support float type inputs. '
+                         'Instead, in order to instantiate and build your '
+                         'model, `call` your model on real tensor data (of '
+                         'the correct dtype).')
+
+    if self._layers:
+      self._track_layers(self._layers)
+    if self.layers:
+      for layer in self.layers:
+        if not layer.built:
+          raise ValueError('Layer: {} was not built in your model. Calling '
+                           '`build` manually on a subclassed model is only '
+                           'allowed for models with a static topology. '
+                           'In this case, you can build your model by '
+                           'calling it on real tensor data.'.format(layer))
+    self.built = True
+
  def _check_trainable_weights_consistency(self):
    """Check trainable weights count consistency.
@@ -897,7 +987,11 @@ class Model(Network):
      for output_shape, loss_fn in zip(self._feed_output_shapes,
                                       self._feed_loss_fns):
        if loss_fn is losses.sparse_categorical_crossentropy:
-          feed_output_shapes.append(output_shape[:-1] + (1,))
+          if K.image_data_format() == 'channels_first':
+            feed_output_shapes.append(
+                (output_shape[0], 1) + output_shape[2:])
+          else:
+            feed_output_shapes.append(output_shape[:-1] + (1,))
        elif (not hasattr(loss_fn, '__name__') or
              getattr(losses, loss_fn.__name__, None) is None):
          # If `loss_fn` is not a function (e.g. callable class)
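A small illustration (not part of the patch) of the two target layouts this branch distinguishes for sparse categorical crossentropy: with channels_last the class axis is last, so the expected label shape drops it and appends a singleton axis; with channels_first the class axis is axis 1, so the singleton label axis goes there instead.

# Hypothetical output shapes for a 10-class per-pixel classifier.
channels_last_output = (None, 32, 32, 10)   # (batch, H, W, classes)
channels_first_output = (None, 10, 32, 32)  # (batch, classes, H, W)

# channels_last: integer labels live in a trailing singleton axis.
assert channels_last_output[:-1] + (1,) == (None, 32, 32, 1)

# channels_first: the singleton label axis replaces the class axis at position 1.
assert (channels_first_output[0], 1) + channels_first_output[2:] == (None, 1, 32, 32)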
@@ -988,10 +1082,14 @@ class Model(Network):
        inputs = inputs[0]
      if tensor_util.is_tensor(inputs):
-        input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
+        if context.executing_eagerly():
+          input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
+          self.build(input_shape=input_shape)
+        else:
+          self.symbolic_set_inputs(inputs)
      else:
        input_shape = (None,) + inputs.shape[1:]
-      self.build(input_shape=input_shape)
+        self.build(input_shape=input_shape)
    elif context.executing_eagerly():
      self._eager_set_inputs(inputs)
    else:
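As a hedged end-to-end sketch of the input-shape inference shown above (illustrative; exact behavior depends on the TF version): a deferred-build Sequential model gets its input shape from the data passed to fit(), which reaches this logic and either builds the model from the inferred shape or takes the symbolic path.

import numpy as np
import tensorflow as tf

# No input shape is given up front, so the model is built lazily when fit()
# sees the first batch of data.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu'),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mse')

data = np.random.random((16, 8)).astype(np.float32)
labels = np.random.random((16, 1)).astype(np.float32)
# Under eager execution the model is built from (None,) + data.shape[1:];
# in graph mode the symbolic branch is taken instead.
model.fit(data, labels, epochs=1, verbose=0)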