aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-08-18 07:39:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-18 07:43:28 -0700
commita6729325a3534ef4aeb2065be82bb2963b9b03de (patch)
treea7b22490cb6e9af963e94c652f2e2cf197ad7d23
parent573b303ac8204d626bee266798e1eb3df0fed491 (diff)
Deletes convert_n_to_eager_tensor. Moves convert_to_eager_tensor to constant_op.
PiperOrigin-RevId: 165704074
-rw-r--r--tensorflow/python/eager/ops_test.py4
-rw-r--r--tensorflow/python/eager/python_eager_op_gen.cc4
-rw-r--r--tensorflow/python/eager/tensor.py2
-rw-r--r--tensorflow/python/framework/constant_op.py23
-rw-r--r--tensorflow/python/framework/ops.py49
5 files changed, 34 insertions, 48 deletions
diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py
index dee339f7f1..78ff2f6777 100644
--- a/tensorflow/python/eager/ops_test.py
+++ b/tensorflow/python/eager/ops_test.py
@@ -272,9 +272,7 @@ class TargetTest(test_util.TensorFlowTestCase):
def testInvalidInputDataType(self):
# Fill requires the first input to be an int32 tensor.
- with self.assertRaisesRegexp(
- TypeError,
- 'Expected tensor with type tf.int32 not tf.int64'):
+ with self.assertRaisesRegexp(ValueError, 'int64'):
array_ops.fill(tensor.Tensor([2], dtype=dtypes.int64), tensor.Tensor(1))
def testOutputOnHostMemory(self):
diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc
index 511ce82eeb..c46a3d8db3 100644
--- a/tensorflow/python/eager/python_eager_op_gen.cc
+++ b/tensorflow/python/eager/python_eager_op_gen.cc
@@ -624,8 +624,8 @@ void GenEagerPythonOp::AddEagerInputCasts() {
const string fn = arg.number_attr().empty() ? "" : "n_";
const string dtype =
python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
- strings::StrAppend(&result_, " ", param, " = _tensor.convert_", fn,
- "to_eager_tensor(", param, ", ", dtype, ")\n");
+ strings::StrAppend(&result_, " ", param, " = _ops.convert_", fn,
+ "to_tensor(", param, ", ", dtype, ")\n");
}
}
diff --git a/tensorflow/python/eager/tensor.py b/tensorflow/python/eager/tensor.py
index 1c2f4d74c7..69269d1975 100644
--- a/tensorflow/python/eager/tensor.py
+++ b/tensorflow/python/eager/tensor.py
@@ -24,8 +24,6 @@ import numpy as np
# ops.py.
# pylint: disable=unused-import
from tensorflow.python.framework.ops import _tensor_from_handle
-from tensorflow.python.framework.ops import convert_n_to_eager_tensor
-from tensorflow.python.framework.ops import convert_to_eager_tensor
from tensorflow.python.framework.ops import EagerTensor as Tensor
# pylint: enable=unused-import
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index af3be7230c..9de63607e1 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -41,6 +41,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from autograd import core as ag_core
import numpy as np
from tensorflow.core.framework import attr_value_pb2
@@ -66,13 +67,29 @@ def _eager_reshape(tensor, shape):
def _eager_fill(dims, value):
"""Eager-only version of Fill op; requires value is an eager Tensor."""
attr_t = value.dtype.as_datatype_enum
- dims = ops.convert_to_eager_tensor(dims, dtypes.int32)
+ dims = convert_to_eager_tensor(dims, dtypes.int32)
inputs_flat = [dims, value]
attrs = ("T", attr_t)
result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs)
return result
+def convert_to_eager_tensor(t, dtype=None):
+ """Converts the given `value` to an `EagerTensor`."""
+ if isinstance(ag_core.getval(t), ops.EagerTensor):
+ if dtype is not None and t.dtype != dtype:
+ raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
+ return t
+ # Handle converting ResourceVariable to Tensor.
+ # TODO(josh11b): get rid of this explicit ugly conversion once we have a more
+ # general scheme in place.
+ try:
+ return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access
+ except AttributeError:
+ pass
+ return ops.EagerTensor(t, dtype=dtype)
+
+
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
"""Creates a constant tensor.
@@ -123,8 +140,8 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
"""
if not context.in_graph_mode():
if shape is None:
- return ops.convert_to_eager_tensor(value, dtype)
- t = ops.convert_to_eager_tensor(value, dtype)
+ return convert_to_eager_tensor(value, dtype)
+ t = convert_to_eager_tensor(value, dtype)
shape = tensor_shape.as_shape(shape)
if shape == t.shape:
return t
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 862dd706f4..6f1954537e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -876,29 +876,6 @@ class EagerTensor(Tensor):
raise NotImplementedError("eval not supported for Eager Tensors.")
-# TODO(josh11b): Support other cases like converting TensorShape, lists/tuples and
-# other custom conversion functions.
-def convert_to_eager_tensor(t, dtype=None):
- """Converts the given `value` to an `EagerTensor`."""
- if isinstance(ag_core.getval(t), EagerTensor):
- if dtype is not None and t.dtype != dtype:
- raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
- return t
- # Handle converting ResourceVariable to Tensor.
- # TODO(josh11b): get rid of this explicit ugly conversion once we have a more
- # general scheme in place.
- try:
- return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access
- except AttributeError:
- pass
- return EagerTensor(t, dtype=dtype)
-
-
-def convert_n_to_eager_tensor(values, dtype):
- """Converts the given `values` to a list of `EagerTensor`."""
- return [convert_to_eager_tensor(t, dtype) for t in values]
-
-
def _tensor_from_handle(handle):
"""'Private' constructor for the Tensor object.
@@ -1112,21 +1089,17 @@ def internal_convert_n_to_tensor(values,
"""
if not isinstance(values, collections.Sequence):
raise TypeError("values must be a list.")
- if context.in_graph_mode():
- ret = []
- for i, value in enumerate(values):
- n = None if name is None else "%s_%d" % (name, i)
- ret.append(
- internal_convert_to_tensor(
- value,
- dtype=dtype,
- name=n,
- as_ref=as_ref,
- preferred_dtype=preferred_dtype))
- return ret
- else:
- # TODO(josh11b): handle preferred_dtype, as_ref
- return convert_n_to_eager_tensor(values, dtype=dtype)
+ ret = []
+ for i, value in enumerate(values):
+ n = None if name is None else "%s_%d" % (name, i)
+ ret.append(
+ internal_convert_to_tensor(
+ value,
+ dtype=dtype,
+ name=n,
+ as_ref=as_ref,
+ preferred_dtype=preferred_dtype))
+ return ret
def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):