aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-07-20 09:28:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 09:31:51 -0700
commit2ff8c85dfca8afb2a4129e8fa86802bd5f25a1c6 (patch)
treee4c9254ae90c7d2ae5f4ba53322542a3c5a1d1ab /tensorflow
parent8d14663dbe9446ba50a36f64aaecfb5c06ea26d3 (diff)
[eager]: Correctly handle operation arguments of mixed types in the slow path.
Consider the following: import tensorflow as tf tf.enable_eager_execution() x = tf.Variable(1.0) tf.Print(x, ["foo", x]) Prior to this commit, this snippet would fail with an error: ValueError: exceptions.TypeError: object of type 'ResourceVariable' has no len() raised from the call to ops.EagerTensor in convert_to_mixed_eager_tensors. With this commit, the tf.Print call works correctly. Note that convert_to_mixed_eager_tensors is only called in the slow path of operation execution (i.e., when TFE_Py_FastPathExecute fails). Which happens rarely (e.g., when mixing primitive string and EagerTensor/ResourceVariable arguments). PiperOrigin-RevId: 205408407
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/eager/core_test.py8
-rw-r--r--tensorflow/python/eager/execute.py6
2 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index 3fabe7060e..cc765725a4 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -610,6 +610,14 @@ class TFETest(test_util.TensorFlowTestCase):
self.assertEquals(typ, dtypes.float32)
self.assertIsInstance(t, ops.EagerTensor)
+ def testConvertMixedEagerTensorsWithVariables(self):
+ var = resource_variable_ops.ResourceVariable(1.0)
+ types, tensors = execute_lib.convert_to_mixed_eager_tensors(
+ ['foo', var], context.context())
+ self.assertAllEqual([dtypes.string, dtypes.float32], types)
+ for t in tensors:
+ self.assertIsInstance(t, ops.EagerTensor)
+
class SendRecvTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py
index 2ff5b8d8f4..f9b8d2cb5d 100644
--- a/tensorflow/python/eager/execute.py
+++ b/tensorflow/python/eager/execute.py
@@ -198,11 +198,7 @@ def args_to_matching_eager(l, ctx, default_dtype=None):
def convert_to_mixed_eager_tensors(values, ctx):
- v = [
- t if isinstance(t, ops.EagerTensor) else ops.EagerTensor(
- t, context=ctx._handle, device=ctx.device_name) # pylint: disable=protected-access
- for t in values
- ]
+ v = [ops.internal_convert_to_tensor(t, ctx=ctx) for t in values]
types = [t._datatype_enum() for t in v] # pylint: disable=protected-access
return types, v