aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Sourabh Bajaj <sourabhbajaj@google.com>2018-10-08 10:52:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 11:09:18 -0700
commit3f0155133d668cf6cee1f1fb362d2a75c04836e3 (patch)
tree8333824da1fa9e56367612c78b4f14aeeb89932a /tensorflow/python
parent0691d49fb6e15740b8ddf8019fea4edb91bca914 (diff)
Fix support for a single tensor to be passed to target_tensors
PiperOrigin-RevId: 216212953
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/keras/engine/training.py6
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py4
-rw-r--r--tensorflow/python/keras/engine/training_test.py4
3 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 2ebb4cf99f..ff2ae54ad4 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -563,9 +563,11 @@ class Model(Network):
for name in self.output_names:
tmp_target_tensors.append(target_tensors.get(name, None))
target_tensors = tmp_target_tensors
+ elif tensor_util.is_tensor(target_tensors):
+ target_tensors = [target_tensors]
else:
- raise TypeError('Expected `target_tensors` to be '
- 'a list or dict, but got:', target_tensors)
+ raise TypeError('Expected `target_tensors` to be a list or tuple or '
+ 'dict or a single tensor, but got:', target_tensors)
for i in range(len(self.outputs)):
if i in skip_target_indices:
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 04e8d079c0..ac759ef3aa 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -820,10 +820,6 @@ def _clone_and_build_model(model, inputs=None, targets=None):
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
- # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
- # single tensor should be OK but it throws an error in that case.
- if targets is not None and not isinstance(targets, (list, dict, tuple)):
- targets = [targets]
if isinstance(targets, tuple):
targets = nest.flatten(targets)
cloned_model.compile(
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 54ad74c08b..868fd1dc69 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -1865,6 +1865,10 @@ class TestTrainingWithDataTensors(test.TestCase):
model.compile(optimizer='rmsprop', loss='mse', target_tensors=[target])
model.train_on_batch(input_val, None)
+ # single-output, as single tensor
+ model.compile(optimizer='rmsprop', loss='mse', target_tensors=target)
+ model.train_on_batch(input_val, None)
+
# single-output, as dict
model.compile(optimizer='rmsprop', loss='mse',
target_tensors={'dense': target})