diff options
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r-- | tensorflow/python/framework/ops.py | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 659bc394b9..b197e96886 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -49,6 +49,7 @@ from tensorflow.python.framework import versions from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils +from tensorflow.python.util import nest from tensorflow.python.util import tf_contextlib # Temporary global switch determining if we should enable the work-in-progress @@ -1036,12 +1037,19 @@ def internal_convert_to_tensor(value, # tracing gradients, to ensure the same behavior happens with and without # tracing. unwrapped = ag_core.getval(value) - # Fast path for EagerTensors that don't need any conversion. - if isinstance(unwrapped, EagerTensor) and context.in_eager_mode(): - # Note that we don't check that value's dtype matches the dtype - # argument. We exepct that the C runtime will do that checking - # when we execute the kernel. - return value + + if context.in_eager_mode(): + # Fast path for EagerTensors that don't need any conversion. + if isinstance(unwrapped, EagerTensor): + # Note that we don't check that value's dtype matches the dtype + # argument. We exepct that the C runtime will do that checking + # when we execute the kernel. + return value + values = nest.flatten(value) + if (len(values) > 1 and + any(isinstance(ag_core.getval(v), EagerTensor) for v in values)): + raise TypeError("Cannot convert to a eager tensor.") + if dtype is not None: dtype = dtypes.as_dtype(dtype) unwrapped_type = type(unwrapped) |