diff options
author | 2018-07-20 09:28:37 -0700 | |
---|---|---|
committer | 2018-07-20 09:31:51 -0700 | |
commit | 2ff8c85dfca8afb2a4129e8fa86802bd5f25a1c6 (patch) | |
tree | e4c9254ae90c7d2ae5f4ba53322542a3c5a1d1ab /tensorflow | |
parent | 8d14663dbe9446ba50a36f64aaecfb5c06ea26d3 (diff) |
[eager]: Correctly handle operation arguments of mixed types in the slow path.
Consider the following:
import tensorflow as tf
tf.enable_eager_execution()
x = tf.Variable(1.0)
tf.Print(x, ["foo", x])
Prior to this commit, this snippet would fail with an error:
ValueError: exceptions.TypeError: object of type 'ResourceVariable' has no len()
raised from the call to ops.EagerTensor in convert_to_mixed_eager_tensors.
With this commit, the tf.Print call works correctly.
Note that convert_to_mixed_eager_tensors is only called in the slow path of operation execution (i.e., when TFE_Py_FastPathExecute fails). Which happens rarely (e.g., when mixing primitive string and EagerTensor/ResourceVariable arguments).
PiperOrigin-RevId: 205408407
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/eager/core_test.py | 8 | ||||
-rw-r--r-- | tensorflow/python/eager/execute.py | 6 |
2 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 3fabe7060e..cc765725a4 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -610,6 +610,14 @@ class TFETest(test_util.TensorFlowTestCase): self.assertEquals(typ, dtypes.float32) self.assertIsInstance(t, ops.EagerTensor) + def testConvertMixedEagerTensorsWithVariables(self): + var = resource_variable_ops.ResourceVariable(1.0) + types, tensors = execute_lib.convert_to_mixed_eager_tensors( + ['foo', var], context.context()) + self.assertAllEqual([dtypes.string, dtypes.float32], types) + for t in tensors: + self.assertIsInstance(t, ops.EagerTensor) + class SendRecvTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py index 2ff5b8d8f4..f9b8d2cb5d 100644 --- a/tensorflow/python/eager/execute.py +++ b/tensorflow/python/eager/execute.py @@ -198,11 +198,7 @@ def args_to_matching_eager(l, ctx, default_dtype=None): def convert_to_mixed_eager_tensors(values, ctx): - v = [ - t if isinstance(t, ops.EagerTensor) else ops.EagerTensor( - t, context=ctx._handle, device=ctx.device_name) # pylint: disable=protected-access - for t in values - ] + v = [ops.internal_convert_to_tensor(t, ctx=ctx) for t in values] types = [t._datatype_enum() for t in v] # pylint: disable=protected-access return types, v |