diff options
author | 2017-08-18 07:39:41 -0700 | |
---|---|---|
committer | 2017-08-18 07:43:28 -0700 | |
commit | a6729325a3534ef4aeb2065be82bb2963b9b03de (patch) | |
tree | a7b22490cb6e9af963e94c652f2e2cf197ad7d23 /tensorflow/python/framework/constant_op.py | |
parent | 573b303ac8204d626bee266798e1eb3df0fed491 (diff) |
Deletes convert_n_to_eager_tensor. Moves convert_to_eager_tensor to constant_op.
PiperOrigin-RevId: 165704074
Diffstat (limited to 'tensorflow/python/framework/constant_op.py')
-rw-r--r-- | tensorflow/python/framework/constant_op.py | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index af3be7230c..9de63607e1 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -41,6 +41,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from autograd import core as ag_core import numpy as np from tensorflow.core.framework import attr_value_pb2 @@ -66,13 +67,29 @@ def _eager_reshape(tensor, shape): def _eager_fill(dims, value): """Eager-only version of Fill op; requires value is an eager Tensor.""" attr_t = value.dtype.as_datatype_enum - dims = ops.convert_to_eager_tensor(dims, dtypes.int32) + dims = convert_to_eager_tensor(dims, dtypes.int32) inputs_flat = [dims, value] attrs = ("T", attr_t) result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs) return result +def convert_to_eager_tensor(t, dtype=None): + """Converts the given `value` to an `EagerTensor`.""" + if isinstance(ag_core.getval(t), ops.EagerTensor): + if dtype is not None and t.dtype != dtype: + raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype)) + return t + # Handle converting ResourceVariable to Tensor. + # TODO(josh11b): get rid of this explicit ugly conversion once we have a more + # general scheme in place. + try: + return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access + except AttributeError: + pass + return ops.EagerTensor(t, dtype=dtype) + + def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): """Creates a constant tensor. @@ -123,8 +140,8 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): """ if not context.in_graph_mode(): if shape is None: - return ops.convert_to_eager_tensor(value, dtype) - t = ops.convert_to_eager_tensor(value, dtype) + return convert_to_eager_tensor(value, dtype) + t = convert_to_eager_tensor(value, dtype) shape = tensor_shape.as_shape(shape) if shape == t.shape: return t |