aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-06 10:57:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 11:06:53 -0700
commit025277a1598fa227b53ddc4e316a7a953b2006c8 (patch)
tree67ead746e128ddf16712b30e19c617e871d70183
parent6d893ecfb9ba2dfc3948215557d4f8ddaf7cf51b (diff)
Small improvements to handling of Datasets in Keras.
* Allow sparse labels to work with Datasets. * Allow sample_weights to be passed as the third output of a Dataset (like how generator input is treated). PiperOrigin-RevId: 211834259
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py3
-rw-r--r--tensorflow/python/keras/engine/training.py21
-rw-r--r--tensorflow/python/keras/engine/training_eager.py9
-rw-r--r--tensorflow/python/keras/engine/training_test.py43
-rw-r--r--tensorflow/python/keras/engine/training_utils.py18
5 files changed, 72 insertions, 22 deletions
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index d39fd57294..3cee3e37a7 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -446,8 +446,7 @@ class TestWithDistributionStrategy(test.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- with self.assertRaisesRegexp(ValueError,
- 'expected input to have 2 dimensions'):
+ with self.assertRaisesRegexp(ValueError, 'expected input to have shape'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
# Wrong input shape
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 966b446f22..46149bed09 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -928,11 +928,16 @@ class Model(Network):
'Make sure that your dataset can generate '
'required number of samples.')
- if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError('Please provide model inputs as a list or tuple of 2 '
- 'elements: input and target pair. '
- 'Received %s' % next_element)
- x, y = next_element
+ if (not isinstance(next_element, (list, tuple)) or
+ len(next_element) not in [2, 3]):
+ raise ValueError(
+ 'Please provide model inputs as a list or tuple of 2 or 3'
+ 'elements: (input, target) or (input, target, sample_weights)'
+ 'Received %s' % next_element)
+ if len(next_element) == 2:
+ x, y = next_element
+ else:
+ x, y, sample_weight = next_element
x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
class_weight, batch_size)
return x, y, sample_weights
@@ -1331,7 +1336,8 @@ class Model(Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- - A `tf.data` dataset or a dataset iterator.
+ - A `tf.data` dataset or a dataset iterator. Should return a tuple
+ of either (inputs, targets) or (inputs, targets, sample_weights).
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
@@ -1396,7 +1402,8 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset or a dataset iterator.
+ supported when `x` is a dataset or a dataset iterator, instead
+ provide the sample_weights as the third element of `x`.
initial_epoch: Integer.
Epoch at which to start training
(useful for resuming a previous training run).
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index 1e377149b6..f5bf2429d0 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -417,11 +417,12 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
"""
assert isinstance(inputs, iterator_ops.EagerIterator)
if not isinstance(inputs.output_shapes,
- (list, tuple)) or len(inputs.output_shapes) > 2:
+ (list, tuple)) or len(inputs.output_shapes) > 3:
raise ValueError(
- 'Please provide data as a list or tuple of 1 or 2 elements '
- ' - input or input and target pair. Received %s. We do not use the '
- '`target` value here.' % inputs.output_shapes)
+ 'Please provide data as a list or tuple of 1, 2, or 3 elements '
+ ' - `(input)`, or `(input, target)`, or `(input, target,'
+ 'sample_weights)`. Received %s. We do not use the `target` or'
+ '`sample_weights` value here.' % inputs.output_shapes)
outs = []
if verbose == 1:
progbar = generic_utils.Progbar(target=steps)
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index bf5c7fd7f8..d5c9a2ed1a 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -2097,6 +2097,43 @@ class TestTrainingWithDataset(test.TestCase):
'you should specify the `steps` argument'):
model.predict(dataset, verbose=0)
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_dataset_with_sample_weights(self):
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae', metrics_module.CategoricalAccuracy()]
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3), np.float32)
+ targets = np.zeros((10, 4), np.float32)
+ sample_weights = np.ones((10), np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets,
+ sample_weights))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(dataset, steps=2, verbose=1)
+ model.predict(dataset, steps=2)
+ model.train_on_batch(dataset)
+ model.predict_on_batch(dataset)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_dataset_with_sparse_labels(self):
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'sparse_categorical_crossentropy'
+ model.compile(optimizer, loss)
+
+ inputs = np.zeros((10, 3))
+ targets = np.random.randint(0, 4, size=10, dtype=np.int32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+
def test_dataset_input_shape_validation(self):
with self.test_session():
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
@@ -2108,8 +2145,10 @@ class TestTrainingWithDataset(test.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- with self.assertRaisesRegexp(ValueError,
- r'expected (.*?) to have 2 dimensions'):
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'expected (.*?) to have shape \(3,\) but got array with shape \(1,\)'
+ ):
model.train_on_batch(dataset)
# Wrong input shape
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index f94697c913..ae5741d9f7 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -210,10 +210,11 @@ def check_num_samples(ins,
def standardize_single_array(x):
if x is None:
return None
- elif tensor_util.is_tensor(x):
- return x
- elif x.ndim == 1:
- x = np.expand_dims(x, 1)
+ if x.shape is not None and len(x.shape) == 1:
+ if tensor_util.is_tensor(x):
+ return array_ops.expand_dims(x, axis=1)
+ else:
+ return np.expand_dims(x, 1)
return x
@@ -341,7 +342,7 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
Raises:
ValueError: In case of invalid user-provided argument.
"""
- if x_weight is None or len(x_weight) == 0: # pylint: disable=g-explicit-length-test
+ if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0): # pylint: disable=g-explicit-length-test
return [None for _ in output_names]
if len(output_names) == 1:
if isinstance(x_weight, list) and len(x_weight) == 1:
@@ -675,7 +676,8 @@ def standardize_weights(y,
'Expected sample_weight with rank '
'less than or equal to ' + str(len(y.shape)))
- if y.shape[:sample_weight.ndim] != sample_weight.shape:
+ if (not tensor_util.is_tensor(sample_weight) and
+ y.shape[:sample_weight.ndim] != sample_weight.shape):
raise ValueError(
'Found a sample_weight array with shape ' + str(sample_weight.shape) +
' for an input with shape ' + str(y.shape) + '. '
@@ -777,7 +779,9 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None):
'Received: %s' % (x, y))
if sample_weight is not None:
raise ValueError('`sample_weight` argument is not supported when input '
- '`x` is a dataset or a dataset iterator. '
+ '`x` is a dataset or a dataset iterator. Instead, you'
+ 'can provide sample_weight as the third element of your'
+ 'dataset, i.e. (inputs, targets, sample_weight). '
'Received: x=%s, sample_weight=%s' % (x, sample_weight))
if validation_split is not None and validation_split != 0.0:
raise ValueError(