aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/constant_op.py
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-08-18 07:39:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-18 07:43:28 -0700
commita6729325a3534ef4aeb2065be82bb2963b9b03de (patch)
treea7b22490cb6e9af963e94c652f2e2cf197ad7d23 /tensorflow/python/framework/constant_op.py
parent573b303ac8204d626bee266798e1eb3df0fed491 (diff)
Deletes convert_n_to_eager_tensor. Moves convert_to_eager_tensor to constant_op.
PiperOrigin-RevId: 165704074
Diffstat (limited to 'tensorflow/python/framework/constant_op.py')
-rw-r--r--tensorflow/python/framework/constant_op.py23
1 files changed, 20 insertions, 3 deletions
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index af3be7230c..9de63607e1 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -41,6 +41,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from autograd import core as ag_core
import numpy as np
from tensorflow.core.framework import attr_value_pb2
@@ -66,13 +67,29 @@ def _eager_reshape(tensor, shape):
def _eager_fill(dims, value):
"""Eager-only version of Fill op; requires value is an eager Tensor."""
attr_t = value.dtype.as_datatype_enum
- dims = ops.convert_to_eager_tensor(dims, dtypes.int32)
+ dims = convert_to_eager_tensor(dims, dtypes.int32)
inputs_flat = [dims, value]
attrs = ("T", attr_t)
result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs)
return result
+def convert_to_eager_tensor(t, dtype=None):
+ """Converts the given `value` to an `EagerTensor`."""
+ if isinstance(ag_core.getval(t), ops.EagerTensor):
+ if dtype is not None and t.dtype != dtype:
+ raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
+ return t
+ # Handle converting ResourceVariable to Tensor.
+ # TODO(josh11b): get rid of this explicit ugly conversion once we have a more
+ # general scheme in place.
+ try:
+ return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access
+ except AttributeError:
+ pass
+ return ops.EagerTensor(t, dtype=dtype)
+
+
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
"""Creates a constant tensor.
@@ -123,8 +140,8 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
"""
if not context.in_graph_mode():
if shape is None:
- return ops.convert_to_eager_tensor(value, dtype)
- t = ops.convert_to_eager_tensor(value, dtype)
+ return convert_to_eager_tensor(value, dtype)
+ t = convert_to_eager_tensor(value, dtype)
shape = tensor_shape.as_shape(shape)
if shape == t.shape:
return t