Diffstat (limited to 'tensorflow/python/keras/engine/training.py')
-rw-r--r-- tensorflow/python/keras/engine/training.py | 148
1 file changed, 67 insertions(+), 81 deletions(-)
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 966b446f22..d224dfffdd 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -45,6 +45,7 @@ from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -862,7 +863,8 @@ class Model(Network):
Fraction of the training data to be used as validation data.
Returns:
- A tuple of 3 lists: input arrays, target arrays, sample-weight arrays.
+ A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict
+ or not), target arrays, sample-weight arrays.
If the model's input and targets are symbolic, these lists are empty
(since the model takes no user-provided data, instead the data comes
from the symbolic inputs/targets).
@@ -928,11 +930,16 @@ class Model(Network):
'Make sure that your dataset can generate '
'required number of samples.')
- if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError('Please provide model inputs as a list or tuple of 2 '
- 'elements: input and target pair. '
- 'Received %s' % next_element)
- x, y = next_element
+ if (not isinstance(next_element, (list, tuple)) or
+ len(next_element) not in [2, 3]):
+ raise ValueError(
'Please provide model inputs as a list or tuple of 2 or 3 '
'elements: (input, target) or (input, target, sample_weights). '
'Received %s' % next_element)
+ if len(next_element) == 2:
+ x, y = next_element
+ else:
+ x, y, sample_weight = next_element
x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
class_weight, batch_size)
return x, y, sample_weights
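
With this hunk, a dataset iterator may yield (input, target, sample_weights)
triples in addition to (input, target) pairs. A minimal usage sketch, assuming
a TF build that includes this change (model and data below are illustrative):

import numpy as np
import tensorflow as tf

x = np.random.random((100, 8)).astype('float32')
y = np.random.randint(0, 2, size=(100, 1)).astype('float32')
w = np.random.random((100,)).astype('float32')  # per-sample weights

# The third tuple element is picked up as sample_weights by fit().
dataset = tf.data.Dataset.from_tensor_slices((x, y, w)).batch(10).repeat()

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='sigmoid')])
model.compile(optimizer='sgd', loss='binary_crossentropy')
model.fit(dataset, steps_per_epoch=10, epochs=2)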
@@ -948,6 +955,7 @@ class Model(Network):
all_inputs = []
is_build_called = False
is_compile_called = False
+ dict_inputs = False
if not self.inputs:
# We need to use `x` to set the model inputs.
# We type-check that `x` and `y` are either single arrays
@@ -959,7 +967,9 @@ class Model(Network):
'array or a list of arrays. You passed: x=' + str(x))
all_inputs += list(x)
elif isinstance(x, dict):
- raise ValueError('Please do not pass a dictionary as model inputs.')
+ dict_inputs = True
+ keys = sorted(x.keys())
+ all_inputs = [x[k] for k in keys]
else:
if not isinstance(x, np.ndarray) and not tensor_util.is_tensor(x):
raise ValueError('Please provide as model inputs either a single '
@@ -972,6 +982,8 @@ class Model(Network):
if not self.inputs:
is_build_called = True
self._set_inputs(x)
+ else:
+ dict_inputs = isinstance(self.inputs, dict)
if y is not None:
if not self.optimizer:
@@ -1124,6 +1136,10 @@ class Model(Network):
'a number of samples that can be '
'divided by the batch size. Found: ' +
str(x[0].shape[0]) + ' samples')
+
+ # If dictionary inputs were provided, we return a dictionary as well.
+ if dict_inputs:
+ x = dict(zip(feed_input_names, x))
return x, y, sample_weights
@checkpointable.no_automatic_dependency_tracking
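
Dict inputs, previously rejected outright, are now accepted: the keys are
sorted to order the arrays, and a dict keyed by the feed input names is
returned. A hedged sketch of the calling pattern this supports, using a model
with named inputs (names and shapes are illustrative):

import numpy as np
import tensorflow as tf

a = tf.keras.layers.Input(shape=(4,), name='a')
b = tf.keras.layers.Input(shape=(4,), name='b')
merged = tf.keras.layers.concatenate([a, b])
model = tf.keras.Model(inputs=[a, b], outputs=tf.keras.layers.Dense(1)(merged))
model.compile(optimizer='sgd', loss='mse')

# Inputs keyed by input name; standardization now hands back a dict too.
model.fit({'a': np.zeros((32, 4)), 'b': np.ones((32, 4))},
          np.zeros((32, 1)), epochs=1)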
@@ -1146,6 +1162,9 @@ class Model(Network):
training: Boolean or None. Only relevant in symbolic mode. Specifies
whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None).
+ Raises:
+ ValueError: If dict inputs are passed to a Sequential Model where the
+ first layer isn't a FeatureLayer.
"""
call_convention = getattr(
self,
@@ -1162,6 +1181,14 @@ class Model(Network):
if tensor_util.is_tensor(inputs):
input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
self.build(input_shape=input_shape)
+ elif isinstance(inputs, dict):
+ # Dict inputs are only supported if the first layer is a FeatureLayer.
+ if not training_utils.is_feature_layer(self.layers[0]):
+ raise ValueError('Passing a dictionary input to a Sequential Model '
'which doesn\'t have FeatureLayer as the first layer '
'is an error.')
+ input_shape = (None,)
+ self.build(input_shape=input_shape)
else:
input_shape = (None,) + inputs.shape[1:]
self.build(input_shape=input_shape)
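
The guard turns a previously opaque failure into an early error: a
deferred-build Sequential model only accepts dict inputs when its first layer
is a feature layer (per `training_utils.is_feature_layer`). A sketch of the
error path, with illustrative data:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])  # no FeatureLayer first
model.compile(optimizer='sgd', loss='mse')

try:
  model.fit({'x': np.zeros((4, 3), dtype='float32')},
            np.zeros((4, 1), dtype='float32'), epochs=1)
except ValueError as e:
  print(e)  # message points at the missing FeatureLayer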
@@ -1189,36 +1216,22 @@ class Model(Network):
assert context.executing_eagerly()
if self.inputs:
raise ValueError('Model inputs are already set.')
+
# On-the-fly setting of model inputs/outputs as DeferredTensors,
# to keep track of number of inputs and outputs and their ndim.
- if isinstance(inputs, (list, tuple)):
- if tensor_util.is_tensor(inputs[0]):
- dummy_output_values = self.call(
- training_utils.cast_if_floating_dtype(inputs))
- else:
- dummy_output_values = self.call(
- [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs])
- dummy_input_values = list(inputs)
- else:
- if tensor_util.is_tensor(inputs):
- dummy_output_values = self.call(
- training_utils.cast_if_floating_dtype(inputs))
- else:
- dummy_output_values = self.call(
- ops.convert_to_tensor(inputs, dtype=K.floatx()))
- dummy_input_values = [inputs]
- if isinstance(dummy_output_values, (list, tuple)):
- dummy_output_values = list(dummy_output_values)
- else:
- dummy_output_values = [dummy_output_values]
+ model_inputs = training_utils.ModelInputs(inputs)
+ dummy_input_values = model_inputs.get_input_values()
+ dummy_output_values = self.call(dummy_input_values)
+
+ self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+ self.input_names = model_inputs.get_input_names()
+
+ dummy_output_values = nest.flatten(dummy_output_values)
self.outputs = [
- base_layer.DeferredTensor(shape=(None for _ in v.shape),
- dtype=v.dtype) for v in dummy_output_values]
- self.inputs = [
- base_layer.DeferredTensor(shape=(None for _ in v.shape),
- dtype=v.dtype) for v in dummy_input_values]
- self.input_names = [
- 'input_%d' % (i + 1) for i in range(len(dummy_input_values))]
+ base_layer.DeferredTensor(shape=(None for _ in v.shape), dtype=v.dtype)
+ for v in dummy_output_values
+ ]
self.output_names = [
'output_%d' % (i + 1) for i in range(len(dummy_output_values))]
self.built = True
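
The deleted list/tuple/tensor branching now lives in
`training_utils.ModelInputs`, introduced alongside this change. The class
below is not that implementation, only a sketch of the contract inferred from
the call sites in this diff (symbolic/placeholder handling omitted):

class ModelInputsSketch(object):
  """Illustrative stand-in for training_utils.ModelInputs."""

  def __init__(self, inputs):
    if isinstance(inputs, dict):
      # Dict inputs: names come from the sorted keys.
      self._names = sorted(inputs.keys())
      self._values = [inputs[k] for k in self._names]
    else:
      # Single values and lists/tuples get positional names.
      self._values = list(inputs) if isinstance(inputs, (list, tuple)) else [inputs]
      self._names = ['input_%d' % (i + 1) for i in range(len(self._values))]

  def get_input_names(self):
    return self._names

  def get_input_values(self):
    # Eager path: concrete values fed straight to self.call().
    return self._values if len(self._values) > 1 else self._values[0]

  def as_dict(self):
    # (name, value) pairs, as consumed by the feed-tracking loop below.
    for name, value in zip(self._names, self._values):
      yield name, value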
@@ -1248,58 +1261,29 @@ class Model(Network):
# On-the-fly setting of symbolic model inputs (either by using the tensor
# provided, or by creating a placeholder if Numpy data was provided).
- self.inputs = []
- self.input_names = []
+ model_inputs = training_utils.ModelInputs(inputs)
+ dummy_input_values = model_inputs.get_symbolic_inputs()
+ self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+ self.input_names = model_inputs.get_input_names()
+
self._feed_inputs = []
self._feed_input_names = []
self._feed_input_shapes = []
- if isinstance(inputs, (list, tuple)):
- inputs = list(inputs)
- else:
- inputs = [inputs]
-
- for i, v in enumerate(inputs):
- name = 'input_%d' % (i + 1)
- self.input_names.append(name)
- if isinstance(v, list):
- v = np.asarray(v)
- if v.ndim == 1:
- v = np.expand_dims(v, 1)
- if isinstance(v, (np.ndarray)):
- # We fix the placeholder shape except the batch size.
- # This is suboptimal, but it is the best we can do with the info
- # we have. The user should call `model._set_inputs(placeholders)`
- # to specify custom placeholders if the need arises.
- shape = (None,) + v.shape[1:]
- placeholder = K.placeholder(shape=shape, name=name)
- self.inputs.append(placeholder)
- self._feed_inputs.append(placeholder)
- self._feed_input_names.append(name)
- self._feed_input_shapes.append(shape)
- else:
- # Assumed tensor - TODO(fchollet) additional type check?
- self.inputs.append(v)
- if K.is_placeholder(v):
- self._feed_inputs.append(v)
- self._feed_input_names.append(name)
- self._feed_input_shapes.append(K.int_shape(v))
+
+ for k, v in model_inputs.as_dict():
+ if K.is_placeholder(v):
+ self._feed_inputs.append(v)
+ self._feed_input_names.append(k)
+ self._feed_input_shapes.append(K.int_shape(v))
if outputs is None:
# Obtain symbolic outputs by calling the model.
- if len(self.inputs) == 1:
- if self._expects_training_arg:
- outputs = self.call(self.inputs[0], training=training)
- else:
- outputs = self.call(self.inputs[0])
+ if self._expects_training_arg:
+ outputs = self.call(dummy_input_values, training=training)
else:
- if self._expects_training_arg:
- outputs = self.call(self.inputs, training=training)
- else:
- outputs = self.call(self.inputs)
- if isinstance(outputs, (list, tuple)):
- outputs = list(outputs)
- else:
- outputs = [outputs]
+ outputs = self.call(dummy_input_values)
+
+ outputs = nest.flatten(outputs)
self.outputs = outputs
self.output_names = [
'output_%d' % (i + 1) for i in range(len(self.outputs))]
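
Output normalization shrinks the same way: both the eager and graph paths now
rely on `nest.flatten` (imported at the top of this diff), which already
covers the single-value vs. list/tuple cases the deleted branches handled:

from tensorflow.python.util import nest

nest.flatten([1, (2, 3)])  # -> [1, 2, 3]
nest.flatten('single')     # -> ['single'] (a non-sequence is wrapped in a list)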
@@ -1331,7 +1315,8 @@ class Model(Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- - A `tf.data` dataset or a dataset iterator.
+ - A `tf.data` dataset or a dataset iterator. Should return a tuple
+ of either (inputs, targets) or (inputs, targets, sample_weights).
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
@@ -1396,7 +1381,8 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset or a dataset iterator.
+ supported when `x` is a dataset or a dataset iterator; instead,
+ provide the sample_weights as the third element of `x`.
initial_epoch: Integer.
Epoch at which to start training
(useful for resuming a previous training run).
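
Together with the `_standardize_user_data` change above, the documented split
is: array inputs keep the `sample_weight=` argument, while datasets carry the
weights in the tuple itself. A short sketch (illustrative data):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')

x = np.random.random((20, 3)).astype('float32')
y = np.random.random((20, 1)).astype('float32')
w = np.linspace(0.1, 1.0, 20).astype('float32')

# Array inputs: sample_weight= is still the supported route.
model.fit(x, y, sample_weight=w, epochs=1)

# Dataset inputs: pass the weights as the third tuple element instead.
ds = tf.data.Dataset.from_tensor_slices((x, y, w)).batch(5).repeat()
model.fit(ds, steps_per_epoch=4, epochs=1)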