aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r--tensorflow/python/framework/ops.py20
1 files changed, 14 insertions, 6 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 659bc394b9..b197e96886 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -49,6 +49,7 @@ from tensorflow.python.framework import versions
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
+from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
# Temporary global switch determining if we should enable the work-in-progress
@@ -1036,12 +1037,19 @@ def internal_convert_to_tensor(value,
# tracing gradients, to ensure the same behavior happens with and without
# tracing.
unwrapped = ag_core.getval(value)
- # Fast path for EagerTensors that don't need any conversion.
- if isinstance(unwrapped, EagerTensor) and context.in_eager_mode():
- # Note that we don't check that value's dtype matches the dtype
- # argument. We exepct that the C runtime will do that checking
- # when we execute the kernel.
- return value
+
+ if context.in_eager_mode():
+ # Fast path for EagerTensors that don't need any conversion.
+ if isinstance(unwrapped, EagerTensor):
+ # Note that we don't check that value's dtype matches the dtype
+ # argument. We exepct that the C runtime will do that checking
+ # when we execute the kernel.
+ return value
+ values = nest.flatten(value)
+ if (len(values) > 1 and
+ any(isinstance(ag_core.getval(v), EagerTensor) for v in values)):
+ raise TypeError("Cannot convert to a eager tensor.")
+
if dtype is not None:
dtype = dtypes.as_dtype(dtype)
unwrapped_type = type(unwrapped)