diff options
-rw-r--r-- | tensorflow/python/eager/ops_test.py | 4 | ||||
-rw-r--r-- | tensorflow/python/eager/python_eager_op_gen.cc | 4 | ||||
-rw-r--r-- | tensorflow/python/eager/tensor.py | 2 | ||||
-rw-r--r-- | tensorflow/python/framework/constant_op.py | 23 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 49 |
5 files changed, 34 insertions, 48 deletions
diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index dee339f7f1..78ff2f6777 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -272,9 +272,7 @@ class TargetTest(test_util.TensorFlowTestCase): def testInvalidInputDataType(self): # Fill requires the first input to be an int32 tensor. - with self.assertRaisesRegexp( - TypeError, - 'Expected tensor with type tf.int32 not tf.int64'): + with self.assertRaisesRegexp(ValueError, 'int64'): array_ops.fill(tensor.Tensor([2], dtype=dtypes.int64), tensor.Tensor(1)) def testOutputOnHostMemory(self): diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc index 511ce82eeb..c46a3d8db3 100644 --- a/tensorflow/python/eager/python_eager_op_gen.cc +++ b/tensorflow/python/eager/python_eager_op_gen.cc @@ -624,8 +624,8 @@ void GenEagerPythonOp::AddEagerInputCasts() { const string fn = arg.number_attr().empty() ? "" : "n_"; const string dtype = python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes."); - strings::StrAppend(&result_, " ", param, " = _tensor.convert_", fn, - "to_eager_tensor(", param, ", ", dtype, ")\n"); + strings::StrAppend(&result_, " ", param, " = _ops.convert_", fn, + "to_tensor(", param, ", ", dtype, ")\n"); } } diff --git a/tensorflow/python/eager/tensor.py b/tensorflow/python/eager/tensor.py index 1c2f4d74c7..69269d1975 100644 --- a/tensorflow/python/eager/tensor.py +++ b/tensorflow/python/eager/tensor.py @@ -24,8 +24,6 @@ import numpy as np # ops.py. # pylint: disable=unused-import from tensorflow.python.framework.ops import _tensor_from_handle -from tensorflow.python.framework.ops import convert_n_to_eager_tensor -from tensorflow.python.framework.ops import convert_to_eager_tensor from tensorflow.python.framework.ops import EagerTensor as Tensor # pylint: enable=unused-import diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index af3be7230c..9de63607e1 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -41,6 +41,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from autograd import core as ag_core import numpy as np from tensorflow.core.framework import attr_value_pb2 @@ -66,13 +67,29 @@ def _eager_reshape(tensor, shape): def _eager_fill(dims, value): """Eager-only version of Fill op; requires value is an eager Tensor.""" attr_t = value.dtype.as_datatype_enum - dims = ops.convert_to_eager_tensor(dims, dtypes.int32) + dims = convert_to_eager_tensor(dims, dtypes.int32) inputs_flat = [dims, value] attrs = ("T", attr_t) result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs) return result +def convert_to_eager_tensor(t, dtype=None): + """Converts the given `value` to an `EagerTensor`.""" + if isinstance(ag_core.getval(t), ops.EagerTensor): + if dtype is not None and t.dtype != dtype: + raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype)) + return t + # Handle converting ResourceVariable to Tensor. + # TODO(josh11b): get rid of this explicit ugly conversion once we have a more + # general scheme in place. + try: + return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access + except AttributeError: + pass + return ops.EagerTensor(t, dtype=dtype) + + def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): """Creates a constant tensor. @@ -123,8 +140,8 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): """ if not context.in_graph_mode(): if shape is None: - return ops.convert_to_eager_tensor(value, dtype) - t = ops.convert_to_eager_tensor(value, dtype) + return convert_to_eager_tensor(value, dtype) + t = convert_to_eager_tensor(value, dtype) shape = tensor_shape.as_shape(shape) if shape == t.shape: return t diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 862dd706f4..6f1954537e 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -876,29 +876,6 @@ class EagerTensor(Tensor): raise NotImplementedError("eval not supported for Eager Tensors.") -# TODO(josh11b): Support other cases like converting TensorShape, lists/tuples and -# other custom conversion functions. -def convert_to_eager_tensor(t, dtype=None): - """Converts the given `value` to an `EagerTensor`.""" - if isinstance(ag_core.getval(t), EagerTensor): - if dtype is not None and t.dtype != dtype: - raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype)) - return t - # Handle converting ResourceVariable to Tensor. - # TODO(josh11b): get rid of this explicit ugly conversion once we have a more - # general scheme in place. - try: - return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access - except AttributeError: - pass - return EagerTensor(t, dtype=dtype) - - -def convert_n_to_eager_tensor(values, dtype): - """Converts the given `values` to a list of `EagerTensor`.""" - return [convert_to_eager_tensor(t, dtype) for t in values] - - def _tensor_from_handle(handle): """'Private' constructor for the Tensor object. @@ -1112,21 +1089,17 @@ def internal_convert_n_to_tensor(values, """ if not isinstance(values, collections.Sequence): raise TypeError("values must be a list.") - if context.in_graph_mode(): - ret = [] - for i, value in enumerate(values): - n = None if name is None else "%s_%d" % (name, i) - ret.append( - internal_convert_to_tensor( - value, - dtype=dtype, - name=n, - as_ref=as_ref, - preferred_dtype=preferred_dtype)) - return ret - else: - # TODO(josh11b): handle preferred_dtype, as_ref - return convert_n_to_eager_tensor(values, dtype=dtype) + ret = [] + for i, value in enumerate(values): + n = None if name is None else "%s_%d" % (name, i) + ret.append( + internal_convert_to_tensor( + value, + dtype=dtype, + name=n, + as_ref=as_ref, + preferred_dtype=preferred_dtype)) + return ret def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None): |