diff options
author | Sourabh Bajaj <sourabhbajaj@google.com> | 2018-10-08 10:52:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-08 11:09:18 -0700 |
commit | 3f0155133d668cf6cee1f1fb362d2a75c04836e3 (patch) | |
tree | 8333824da1fa9e56367612c78b4f14aeeb89932a /tensorflow/python | |
parent | 0691d49fb6e15740b8ddf8019fea4edb91bca914 (diff) |
Fix support for a single tensor to be passed to target_tensors
PiperOrigin-RevId: 216212953
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 6 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_distributed.py | 4 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training_test.py | 4 |
3 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 2ebb4cf99f..ff2ae54ad4 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -563,9 +563,11 @@ class Model(Network): for name in self.output_names: tmp_target_tensors.append(target_tensors.get(name, None)) target_tensors = tmp_target_tensors + elif tensor_util.is_tensor(target_tensors): + target_tensors = [target_tensors] else: - raise TypeError('Expected `target_tensors` to be ' - 'a list or dict, but got:', target_tensors) + raise TypeError('Expected `target_tensors` to be a list or tuple or ' + 'dict or a single tensor, but got:', target_tensors) for i in range(len(self.outputs)): if i in skip_target_indices: diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 04e8d079c0..ac759ef3aa 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -820,10 +820,6 @@ def _clone_and_build_model(model, inputs=None, targets=None): optimizer_config = model.optimizer.get_config() optimizer = model.optimizer.__class__.from_config(optimizer_config) - # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a - # single tensor should be OK but it throws an error in that case. - if targets is not None and not isinstance(targets, (list, dict, tuple)): - targets = [targets] if isinstance(targets, tuple): targets = nest.flatten(targets) cloned_model.compile( diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 54ad74c08b..868fd1dc69 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -1865,6 +1865,10 @@ class TestTrainingWithDataTensors(test.TestCase): model.compile(optimizer='rmsprop', loss='mse', target_tensors=[target]) model.train_on_batch(input_val, None) + # single-output, as single tensor + model.compile(optimizer='rmsprop', loss='mse', target_tensors=target) + model.train_on_batch(input_val, None) + # single-output, as dict model.compile(optimizer='rmsprop', loss='mse', target_tensors={'dense': target}) |