aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-08-14 11:29:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-14 11:33:14 -0700
commitb0641138b866a5ffdc511f4ab055735513c57c92 (patch)
treed112665643d74d7915ac054396f7dcc2ddbf6b14 /tensorflow/python
parent538111038e9106a90d371f5db37cb5e9633de386 (diff)
convert_to_tensor calls eager_convert_to_tensor in eager mode
Temporary hack to make most composite ops work. PiperOrigin-RevId: 165205218
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/framework/ops.py4
-rw-r--r--tensorflow/python/framework/ops_test.py7
3 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 8f7ea9e27a..75c6baa2f8 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -967,6 +967,7 @@ py_test(
":variable_scope",
":variables",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python/eager:context",
],
)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 857e6e65d4..032767c908 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -961,6 +961,8 @@ def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None):
RuntimeError: If a registered conversion function returns an invalid value.
"""
+ if context.in_eager_mode():
+ return convert_to_eager_tensor(value, dtype=dtype)
return internal_convert_to_tensor(
value=value,
dtype=dtype,
@@ -1005,6 +1007,8 @@ def internal_convert_to_tensor(value,
RuntimeError: If a registered conversion function returns an invalid value.
"""
+ if context.in_eager_mode():
+ return convert_to_eager_tensor(value, dtype=dtype)
error_prefix = "" if name is None else "%s: " % name
if dtype is not None:
dtype = dtypes.as_dtype(dtype)
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 4a8e30f4cb..acb5fa53bf 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -25,6 +25,7 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.eager import context
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
@@ -290,6 +291,12 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(values, tensor.eval())
+ def testConvertToTensorEager(self):
+ with context.eager_mode():
+ t = ops.EagerTensor(1)
+ converted = ops.convert_to_tensor(t)
+ self.assertTrue(isinstance(converted, ops.EagerTensor))
+
def testConvertToTensorNestedTuple(self):
with self.test_session():
values = ((2,), (3,), (5,), (7,))