diff options
Diffstat (limited to 'tensorflow/python/keras/engine/training.py')
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 2ebb4cf99f..ff2ae54ad4 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -563,9 +563,11 @@ class Model(Network): for name in self.output_names: tmp_target_tensors.append(target_tensors.get(name, None)) target_tensors = tmp_target_tensors + elif tensor_util.is_tensor(target_tensors): + target_tensors = [target_tensors] else: - raise TypeError('Expected `target_tensors` to be ' - 'a list or dict, but got:', target_tensors) + raise TypeError('Expected `target_tensors` to be a list or tuple or ' + 'dict or a single tensor, but got:', target_tensors) for i in range(len(self.outputs)): if i in skip_target_indices: |