From 3f0155133d668cf6cee1f1fb362d2a75c04836e3 Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Mon, 8 Oct 2018 10:52:15 -0700 Subject: Fix support for a single tensor to be passed to target_tensors PiperOrigin-RevId: 216212953 --- tensorflow/python/keras/engine/training.py | 6 ++++-- tensorflow/python/keras/engine/training_distributed.py | 4 ---- tensorflow/python/keras/engine/training_test.py | 4 ++++ 3 files changed, 8 insertions(+), 6 deletions(-) (limited to 'tensorflow/python') diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 2ebb4cf99f..ff2ae54ad4 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -563,9 +563,11 @@ class Model(Network): for name in self.output_names: tmp_target_tensors.append(target_tensors.get(name, None)) target_tensors = tmp_target_tensors + elif tensor_util.is_tensor(target_tensors): + target_tensors = [target_tensors] else: - raise TypeError('Expected `target_tensors` to be ' - 'a list or dict, but got:', target_tensors) + raise TypeError('Expected `target_tensors` to be a list or tuple or ' + 'dict or a single tensor, but got:', target_tensors) for i in range(len(self.outputs)): if i in skip_target_indices: diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 04e8d079c0..ac759ef3aa 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -820,10 +820,6 @@ def _clone_and_build_model(model, inputs=None, targets=None): optimizer_config = model.optimizer.get_config() optimizer = model.optimizer.__class__.from_config(optimizer_config) - # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a - # single tensor should be OK but it throws an error in that case. - if targets is not None and not isinstance(targets, (list, dict, tuple)): - targets = [targets] if isinstance(targets, tuple): targets = nest.flatten(targets) cloned_model.compile( diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 54ad74c08b..868fd1dc69 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1865,6 +1865,10 @@ class TestTrainingWithDataTensors(test.TestCase): model.compile(optimizer='rmsprop', loss='mse', target_tensors=[target]) model.train_on_batch(input_val, None) + # single-output, as single tensor + model.compile(optimizer='rmsprop', loss='mse', target_tensors=target) + model.train_on_batch(input_val, None) + # single-output, as dict model.compile(optimizer='rmsprop', loss='mse', target_tensors={'dense': target}) -- cgit v1.2.3