diff options
Diffstat (limited to 'tensorflow/python/keras/engine/training.py')
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 112 |
1 files changed, 105 insertions, 7 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index bd03f4871f..4df739254b 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -27,6 +27,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K from tensorflow.python.keras import losses @@ -43,6 +44,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export @@ -217,10 +219,9 @@ class Model(Network): for name in self.output_names: if name not in loss: logging.warning( - 'Output "' + name + '" missing from loss dictionary. ' - 'We assume this was done on purpose, ' - 'and we will not be expecting ' - 'any data to be passed to "' + name + '" during training.') + 'Output "' + name + '" missing from loss dictionary. We assume ' + 'this was done on purpose. The fit and evaluate APIs will not be ' + 'expecting any data to be passed to "' + name + '".') loss_functions.append(losses.get(loss.get(name))) elif isinstance(loss, list): if len(loss) != len(self.outputs): @@ -561,6 +562,95 @@ class Model(Network): trainable_weights = self.trainable_weights self._collected_trainable_weights = trainable_weights + def build(self, input_shape): + """Build the model based on input shapes received. + + This is to be used for subclassed models, which do not know at instantiation + time what their inputs look like. + + Args: + input_shape: Single tuple, TensorShape, or list of shapes, where shapes + are tuples, integers, or TensorShapes. + + Raises: + ValueError: + 1. In case of invalid user-provided data (not of type tuple, + list, or TensorShape). + 2. If the model requires call arguments that are agnostic + to the input shapes (positional or kwarg in call signature). + 3. If not all layers were properly built. + 4. If float type inputs are not supported within the layers. + + In each of these cases, the user should build their model by calling it + on real tensor data. + """ + if self._is_graph_network: + self.built = True + return + + # If subclass network + if input_shape is None: + raise ValueError('Input shape must be defined when calling build on a ' + 'model subclass network.') + valid_types = (tuple, list, tensor_shape.TensorShape) + if not isinstance(input_shape, valid_types): + raise ValueError('Specified input shape is not one of the valid types. ' + 'Please specify a batch input shape of type tuple or ' + 'list of input shapes. User provided ' + 'input type: {}'.format(type(input_shape))) + + def _generate_dummy_data_from_shape(shape): + if isinstance(shape, tensor_shape.TensorShape): + shape = shape.as_list() + + # Replace Nones in input shape with dummy `1` value + shape = [x.value if isinstance(x, tensor_shape.Dimension) else x + for x in shape] + shape = [1 if x is None else x for x in shape] + return array_ops.ones(shape, dtype=K.floatx()) + + if input_shape and not self.inputs: + if isinstance(input_shape, list): + # List of input shapes + x = [_generate_dummy_data_from_shape(shape) for shape in input_shape] + else: + x = _generate_dummy_data_from_shape(input_shape) + + kwargs = {} + num_call_args = len(tf_inspect.getargspec(self.call).args) + if self._expects_training_arg and num_call_args == 3: + # Has call signature of call(self, input, training) + kwargs['training'] = False + elif num_call_args > 2: + # Has invalid call signature of call(self, input, *args, **kwargs) + raise ValueError('Currently, you cannot build your model if it has ' + 'positional or keyword arguments that are not ' + 'inputs to the model, but are required for its ' + '`call` method. Instead, in order to instantiate ' + 'and build your model, `call` your model on real ' + 'tensor data with all expected call arguments.') + + try: + self.call(x, **kwargs) + except (errors.InvalidArgumentError, TypeError): + raise ValueError('You cannot build your model by calling `build` ' + 'if your layers do not support float type inputs. ' + 'Instead, in order to instantiate and build your ' + 'model, `call` your model on real tensor data (of ' + 'the correct dtype).') + + if self._layers: + self._track_layers(self._layers) + if self.layers: + for layer in self.layers: + if not layer.built: + raise ValueError('Layer: {} was not built in your model. Calling ' + '`build` manually on a subclassed model is only ' + 'allowed for models with a static topology. ' + 'In this case, you can build your model by ' + 'calling it on real tensor data.'.format(layer)) + self.built = True + def _check_trainable_weights_consistency(self): """Check trainable weights count consistency. @@ -897,7 +987,11 @@ class Model(Network): for output_shape, loss_fn in zip(self._feed_output_shapes, self._feed_loss_fns): if loss_fn is losses.sparse_categorical_crossentropy: - feed_output_shapes.append(output_shape[:-1] + (1,)) + if K.image_data_format() == 'channels_first': + feed_output_shapes.append( + (output_shape[0], 1) + output_shape[2:]) + else: + feed_output_shapes.append(output_shape[:-1] + (1,)) elif (not hasattr(loss_fn, '__name__') or getattr(losses, loss_fn.__name__, None) is None): # If `loss_fn` is not a function (e.g. callable class) @@ -988,10 +1082,14 @@ class Model(Network): inputs = inputs[0] if tensor_util.is_tensor(inputs): - input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:]) + if context.executing_eagerly(): + input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:]) + self.build(input_shape=input_shape) + else: + self.symbolic_set_inputs(inputs) else: input_shape = (None,) + inputs.shape[1:] - self.build(input_shape=input_shape) + self.build(input_shape=input_shape) elif context.executing_eagerly(): self._eager_set_inputs(inputs) else: |