diff options
author | 2017-08-14 11:29:54 -0700 | |
---|---|---|
committer | 2017-08-14 11:33:14 -0700 | |
commit | b0641138b866a5ffdc511f4ab055735513c57c92 (patch) | |
tree | d112665643d74d7915ac054396f7dcc2ddbf6b14 /tensorflow/python | |
parent | 538111038e9106a90d371f5db37cb5e9633de386 (diff) |
convert_to_tensor calls eager_convert_to_tensor in eager mode
Temporary hack to make most composite ops work.
PiperOrigin-RevId: 165205218
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 4 | ||||
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 7 |
3 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 8f7ea9e27a..75c6baa2f8 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -967,6 +967,7 @@ py_test( ":variable_scope", ":variables", "//tensorflow/core:protos_all_py", + "//tensorflow/python/eager:context", ], ) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 857e6e65d4..032767c908 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -961,6 +961,8 @@ def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None): RuntimeError: If a registered conversion function returns an invalid value. """ + if context.in_eager_mode(): + return convert_to_eager_tensor(value, dtype=dtype) return internal_convert_to_tensor( value=value, dtype=dtype, @@ -1005,6 +1007,8 @@ def internal_convert_to_tensor(value, RuntimeError: If a registered conversion function returns an invalid value. """ + if context.in_eager_mode(): + return convert_to_eager_tensor(value, dtype=dtype) error_prefix = "" if name is None else "%s: " % name if dtype is not None: dtype = dtypes.as_dtype(dtype) diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 4a8e30f4cb..acb5fa53bf 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -25,6 +25,7 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as pydev @@ -290,6 +291,12 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertAllEqual((4, 1), tensor.get_shape().as_list()) self.assertAllEqual(values, tensor.eval()) + def testConvertToTensorEager(self): + with context.eager_mode(): + t = ops.EagerTensor(1) + converted = ops.convert_to_tensor(t) + self.assertTrue(isinstance(converted, ops.EagerTensor)) + def testConvertToTensorNestedTuple(self): with self.test_session(): values = ((2,), (3,), (5,), (7,)) |