aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/training.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/engine/training.py')
-rw-r--r--tensorflow/python/keras/engine/training.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 2ebb4cf99f..ff2ae54ad4 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -563,9 +563,11 @@ class Model(Network):
for name in self.output_names:
tmp_target_tensors.append(target_tensors.get(name, None))
target_tensors = tmp_target_tensors
+ elif tensor_util.is_tensor(target_tensors):
+ target_tensors = [target_tensors]
else:
- raise TypeError('Expected `target_tensors` to be '
- 'a list or dict, but got:', target_tensors)
+ raise TypeError('Expected `target_tensors` to be a list or tuple or '
+ 'dict or a single tensor, but got:', target_tensors)
for i in range(len(self.outputs)):
if i in skip_target_indices: