 tensorflow/python/BUILD                                      |   4
 tensorflow/python/eager/context.py                           |   3
 tensorflow/python/framework/constant_op.py                   |  65
 tensorflow/python/framework/ops.py                           | 116
 tensorflow/python/framework/tensor_util.py                   |   3
 tensorflow/python/kernel_tests/BUILD                         |  20
 tensorflow/python/kernel_tests/array_ops_test.py             |  91
 tensorflow/python/kernel_tests/constant_op_eager_test.py     | 426
 tensorflow/python/kernel_tests/resource_variable_ops_test.py | 159
 tensorflow/python/ops/array_ops.py                           | 218
 tensorflow/python/ops/resource_variable_ops.py               | 299
 tensorflow/python/ops/variables.py                           |  23
 tensorflow/python/pywrap_tfe.i                               |   1
 13 files changed, 1110 insertions(+), 318 deletions(-)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 75c6baa2f8..6af0d2bbe0 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -483,6 +483,7 @@ py_library( ":framework_ops", ":tensor_shape", "//tensorflow/core:protos_all_py", + "//tensorflow/python/eager:execute", ], ) @@ -1723,6 +1724,8 @@ py_library( ":util", ":variables", "//tensorflow/core:protos_all_py", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tensor", ], ) @@ -2203,6 +2206,7 @@ py_library( ":tensor_shape", ":util", "//tensorflow/core:protos_all_py", + "//tensorflow/python/eager:context", ], ) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index bfc9b47302..5094a1def7 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -155,7 +155,7 @@ class Context(object): def device_name(self): """Returns the device name for the current thread.""" index = self._device_index - return None if index < 0 else self._devices[index] + return "CPU:0" if index < 0 else self._devices[index] def devices(self): """List of the names of devices available to execute operations.""" @@ -271,6 +271,7 @@ def eager_mode(): return get_default_context()._mode(EAGER_MODE) # pylint: disable=protected-access +# TODO(agarwal): get rid of this and use ops.name_scope instead. @contextlib.contextmanager def namescope(name): """ContextManager for creating hierarchical name scopes.""" diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index dd0fdaf712..af3be7230c 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Operations that generate constants. See the @{$python/constant_op$constants guide}. @@ -45,12 +44,35 @@ from __future__ import print_function import numpy as np from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.eager import context +from tensorflow.python.eager import execute from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +def _eager_reshape(tensor, shape): + """Eager-only version of Reshape op; requires tensor is an eager Tensor.""" + attr_t = tensor.dtype.as_datatype_enum + attr_tshape, (shape,) = execute.args_to_matching_eager([shape], dtypes.int32) + attr_tshape = attr_tshape.as_datatype_enum + inputs_flat = [tensor, shape] + attrs = ("T", attr_t, "Tshape", attr_tshape) + result, = execute.execute("Reshape", 1, inputs=inputs_flat, attrs=attrs) + return result + + +def _eager_fill(dims, value): + """Eager-only version of Fill op; requires value is an eager Tensor.""" + attr_t = value.dtype.as_datatype_enum + dims = ops.convert_to_eager_tensor(dims, dtypes.int32) + inputs_flat = [dims, value] + attrs = ("T", attr_t) + result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs) + return result + + def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): """Creates a constant tensor. @@ -95,15 +117,40 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False): Returns: A Constant Tensor. + + Raises: + TypeError if shape is incorrectly specified or unsupported. 
""" + if not context.in_graph_mode(): + if shape is None: + return ops.convert_to_eager_tensor(value, dtype) + t = ops.convert_to_eager_tensor(value, dtype) + shape = tensor_shape.as_shape(shape) + if shape == t.shape: + return t + if verify_shape: + raise TypeError("Expected Tensor's shape: %s, got %s." % (tuple(shape), + tuple(t.shape))) + num_t = t.shape.num_elements() + # TODO(josh11b): Implement shape -> eager tensor conversion. + if num_t == shape.num_elements(): + return _eager_reshape(t, shape.as_list()) + if num_t == 1: + return _eager_fill(shape.as_list(), t) + raise TypeError("Eager execution of tf.constant with unsupported shape " + "(value has %d elements, shape is %s with %d elements)." % + (num_t, shape, shape.num_elements())) g = ops.get_default_graph() tensor_value = attr_value_pb2.AttrValue() tensor_value.tensor.CopyFrom( - tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape)) + tensor_util.make_tensor_proto( + value, dtype=dtype, shape=shape, verify_shape=verify_shape)) dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype) const_tensor = g.create_op( "Const", [], [dtype_value.type], - attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0] + attrs={"value": tensor_value, + "dtype": dtype_value}, + name=name).outputs[0] return const_tensor @@ -131,8 +178,11 @@ ops.register_tensor_conversion_function( object, _constant_tensor_conversion_function, 200) -def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None, +def _tensor_shape_tensor_conversion_function(s, + dtype=None, + name=None, as_ref=False): + """Function to convert TensorShape to Tensor.""" _ = as_ref if not s.is_fully_defined(): raise ValueError( @@ -156,12 +206,16 @@ def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None, name = "shape_as_tensor" return constant(s_list, dtype=dtype, name=name) + ops.register_tensor_conversion_function( tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100) -def _dimension_tensor_conversion_function(d, dtype=None, name=None, +def _dimension_tensor_conversion_function(d, + dtype=None, + name=None, as_ref=False): + """Function to convert Dimension to Tensor.""" _ = as_ref if d.value is None: raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d) @@ -174,5 +228,6 @@ def _dimension_tensor_conversion_function(d, dtype=None, name=None, name = "shape_as_tensor" return constant(d.value, dtype=dtype, name=name) + ops.register_tensor_conversion_function( tensor_shape.Dimension, _dimension_tensor_conversion_function, 100) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index e01f6f3ce4..5948e59824 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -80,6 +80,11 @@ def _in_gpu_device(): return context.get_default_context()._device_index > 0 # pylint: disable=protected-access +@tf_contextlib.contextmanager +def _null_contextmanager(): + yield + + def _override_helper(clazz_object, operator, func): """Overrides (string) operator on Tensors to call func. 
@@ -698,8 +703,9 @@ class EagerTensor(Tensor): return numpy_text def __str__(self): - return "tfe.Tensor(shape=%s, dtype=%s, numpy=%s)" % ( - self.shape, self.dtype.name, self._numpy_text()) + return "tfe.Tensor(shape=%s, dtype=%s, numpy=%s)" % (self.shape, + self.dtype.name, + self._numpy_text()) def __repr__(self): return "<tfe.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s)>" % ( @@ -728,10 +734,14 @@ class EagerTensor(Tensor): with errors.raise_exception_on_not_ok_status() as status: return c_api.TFE_Py_TensorHandleToNumpy(cpu._handle, status) # pylint: disable=protected-access - def _copy(self, ctx, device_name): + def _copy(self, ctx=None, device_name=None): """Copies tensor to dest device.""" # pylint: disable=protected-access # Creates a new tensor on the dest device. + if ctx is None: + ctx = context.get_default_context() + if device_name is None: + device_name = ctx.device_name with errors.raise_exception_on_not_ok_status() as status: h = c_api.TFE_TensorHandleCopyToDevice(self._handle, ctx._handle, device_name, status) @@ -863,15 +873,26 @@ class EagerTensor(Tensor): raise NotImplementedError("eval not supported for Eager Tensors.") +# TODO(josh11b): Support other cases like converting TensorShape, lists/tuples and +# other custom conversion functions. def convert_to_eager_tensor(t, dtype=None): + """Converts the given `value` to an `EagerTensor`.""" if isinstance(ag_core.getval(t), EagerTensor): if dtype is not None and t.dtype != dtype: raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype)) return t + # Handle converting ResourceVariable to Tensor. + # TODO(josh11b): get rid of this explicit ugly conversion once we have a more + # general scheme in place. + try: + return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access + except AttributeError: + pass return EagerTensor(t, dtype=dtype) def convert_n_to_eager_tensor(values, dtype): + """Converts the given `values` to a list of `EagerTensor`.""" return [convert_to_eager_tensor(t, dtype) for t in values] @@ -1092,17 +1113,21 @@ def internal_convert_n_to_tensor(values, """ if not isinstance(values, collections.Sequence): raise TypeError("values must be a list.") - ret = [] - for i, value in enumerate(values): - n = None if name is None else "%s_%d" % (name, i) - ret.append( - internal_convert_to_tensor( - value, - dtype=dtype, - name=n, - as_ref=as_ref, - preferred_dtype=preferred_dtype)) - return ret + if context.in_graph_mode(): + ret = [] + for i, value in enumerate(values): + n = None if name is None else "%s_%d" % (name, i) + ret.append( + internal_convert_to_tensor( + value, + dtype=dtype, + name=n, + as_ref=as_ref, + preferred_dtype=preferred_dtype)) + return ret + else: + # TODO(josh11b): handle preferred_dtype, as_ref + return convert_n_to_eager_tensor(values, dtype=dtype) def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None): @@ -2463,7 +2488,7 @@ def _name_from_scope_name(name): Returns: the name of the op (equal to scope name minus any trailing slash). """ - return name[:-1] if name[-1] == "/" else name + return name[:-1] if (name and name[-1] == "/") else name class Graph(object): @@ -4282,6 +4307,8 @@ class Graph(object): return tensor_or_op not in self._unfetchable_ops +# TODO(agarwal): currently device directives in an outer eager scope will not +# apply to inner graph mode code. Fix that. def device(device_name_or_function): """Wrapper for `Graph.device()` using the default graph. 
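
A quick sketch (illustrative, not part of the patch) of the list-conversion branch above: in graph mode each element still becomes its own Const op with an indexed name, while in eager mode the whole list maps directly to EagerTensors and, per the TODO, name and preferred_dtype are ignored. Assumes eager mode is active.

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

with context.eager_mode():
  xs = ops.convert_n_to_tensor([1, 2, 3], dtype=dtypes.float32)
  # Each element is an EagerTensor; no graph nodes are created.
  print([x.numpy() for x in xs])  # [1.0, 2.0, 3.0]
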
@@ -4297,7 +4324,11 @@ def device(device_name_or_function): A context manager that specifies the default device to use for newly created ops. """ - return get_default_graph().device(device_name_or_function) + if context.in_graph_mode(): + return get_default_graph().device(device_name_or_function) + else: + # TODO(agarwal): support device functions in EAGER mode. + return context.device(device_name_or_function) def container(container_name): @@ -4314,7 +4345,15 @@ def container(container_name): def colocate_with(op, ignore_existing=False): - return get_default_graph().colocate_with(op, ignore_existing) + if context.in_graph_mode(): + return get_default_graph().colocate_with(op, ignore_existing) + else: + if not ignore_existing: + raise ValueError("ignore_existing must currently be True in EAGER mode.") + if op: + return device(op.device) + else: + return _null_contextmanager() def control_dependencies(control_inputs): @@ -4333,7 +4372,10 @@ def control_dependencies(control_inputs): A context manager that specifies control dependencies for all operations constructed within the context. """ - return get_default_graph().control_dependencies(control_inputs) + if context.in_graph_mode(): + return get_default_graph().control_dependencies(control_inputs) + else: + return _null_contextmanager() _default_session_stack = context._DefaultStack() # pylint: disable=protected-access @@ -4871,18 +4913,32 @@ def name_scope(name, default_name=None, values=None): but `values` are. """ n = default_name if name is None else name - if n is None and values is not None: - # We only raise an error if values is not None (provided) because currently - # tf.name_scope(None) (values=None then) is sometimes used as an idiom - # to reset to top scope. - raise ValueError( - "At least one of name (%s) and default_name (%s) must be provided." % - (name, default_name)) - if values is None: - values = [] - g = _get_graph_from_inputs(values) - with g.as_default(), g.name_scope(n) as scope: - yield scope + if context.in_eager_mode(): + if n is None: + raise ValueError( + "At least one of name (%s) and default_name (%s) should be provided" % + (name, default_name)) + ctx = context.get_default_context() + old_name = ctx.scope_name + scope_name = "%s/%s" % (old_name, name) if old_name else name + ctx.scope_name = scope_name + try: + yield scope_name + finally: + ctx.scope_name = old_name + else: + if n is None and values is not None: + # We only raise an error if values is not None (provided) because + # currently tf.name_scope(None) (values=None then) is sometimes used as an + # idiom to reset to top scope. + raise ValueError( + "At least one of name (%s) and default_name (%s) must be provided." 
% + (name, default_name)) + if values is None: + values = [] + g = _get_graph_from_inputs(values) + with g.as_default(), g.name_scope(n) as scope: + yield scope # pylint: enable=g-doc-return-or-yield diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 8c9406593c..4d47b2bc56 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -23,6 +23,7 @@ import six from tensorflow.core.framework import tensor_pb2 from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.util import compat @@ -711,6 +712,8 @@ def constant_value(tensor, partial=False): # pylint: disable=invalid-name Raises: TypeError: if tensor is not an ops.Tensor. """ + if isinstance(tensor, ops.EagerTensor): + return tensor.numpy() ret = _ConstantValue(tensor, partial) if ret is not None: # The caller may now depend on the constant value of `tensor`, so we diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index f2be3f134d..22306e08d7 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1049,6 +1049,8 @@ cuda_py_test( "//tensorflow/python:state_ops", "//tensorflow/python:test_ops", "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tensor", ], shard_count = 10, ) @@ -1151,6 +1153,24 @@ cuda_py_test( ) cuda_py_test( + name = "constant_op_eager_test", + size = "small", + srcs = ["constant_op_eager_test.py"], + additional_deps = [ + "//tensorflow/python/eager:core", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + ], +) + +cuda_py_test( name = "control_flow_ops_py_test", size = "small", srcs = ["control_flow_ops_py_test.py"], diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 7b8cd25664..7dc28b6aa9 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -22,10 +22,12 @@ import time import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_ops @@ -168,9 +170,10 @@ class BooleanMaskTest(test_util.TensorFlowTestCase): arr = np.array([[1, 2], [3, 4]]) mask = np.array([False, True]) - masked_tensor = sess.run(array_ops.boolean_mask(ph_tensor, ph_mask), - feed_dict={ph_tensor: arr, - ph_mask: mask}) + masked_tensor = sess.run( + array_ops.boolean_mask(ph_tensor, ph_mask), + feed_dict={ph_tensor: arr, + ph_mask: mask}) np.testing.assert_allclose(masked_tensor, arr[mask]) def testMaskDimensionsSetToNoneRaises(self): @@ -283,14 +286,16 @@ class ReverseV2Test(test_util.TensorFlowTestCase): def testReverse1DimAuto(self): for dtype in [ np.uint8, 
np.int8, np.int32, np.int64, np.bool, np.float16, np.float32, - np.float64, np.complex64, np.complex128, np.array(b"").dtype.type + np.float64, np.complex64, np.complex128, + np.array(b"").dtype.type ]: self._reverse1DimAuto(dtype) def testReverse2DimAuto(self): for dtype in [ np.uint8, np.int8, np.int32, np.int64, np.bool, np.float16, np.float32, - np.float64, np.complex64, np.complex128, np.array(b"").dtype.type + np.float64, np.complex64, np.complex128, + np.array(b"").dtype.type ]: self._reverse2DimAuto(dtype) @@ -318,8 +323,7 @@ class ReverseV2Test(test_util.TensorFlowTestCase): for outer_size in (1, 2): for middle_size in list(range(50)) + [100000]: x_np = np.reshape( - np.arange( - outer_size * middle_size * 3, dtype=np.float32), + np.arange(outer_size * middle_size * 3, dtype=np.float32), newshape=(outer_size, middle_size, 3)) x_tf = reverse_f(x_np, [1]).eval() np_answer = x_np[:, ::-1, :] @@ -331,8 +335,7 @@ class ReverseV2Test(test_util.TensorFlowTestCase): for outer_size in (1, 2): for middle_size in list(range(50)) + [100000]: x_np = np.reshape( - np.arange( - outer_size * middle_size * 4, dtype=np.float32), + np.arange(outer_size * middle_size * 4, dtype=np.float32), newshape=(outer_size, middle_size, 4)) x_tf = reverse_f(x_np, [1]).eval() np_answer = x_np[:, ::-1, :] @@ -344,8 +347,7 @@ class ReverseV2Test(test_util.TensorFlowTestCase): for outer_size in list(range(50)) + [100000]: for middle_size in (1, 2): x_np = np.reshape( - np.arange( - outer_size * middle_size * 3, dtype=np.float32), + np.arange(outer_size * middle_size * 3, dtype=np.float32), newshape=(outer_size, middle_size, 3)) x_tf = reverse_f(x_np, [0]).eval() np_answer = x_np[::-1, :, :] @@ -437,9 +439,10 @@ class StridedSliceChecker(object): return tensor -STRIDED_SLICE_TYPES = [dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8, - dtypes.float32, dtypes.float64, dtypes.complex64, - dtypes.complex128] +STRIDED_SLICE_TYPES = [ + dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8, dtypes.float32, + dtypes.float64, dtypes.complex64, dtypes.complex128 +] class StridedSliceTest(test_util.TensorFlowTestCase): @@ -600,8 +603,7 @@ class StridedSliceShapeTest(test_util.TensorFlowTestCase): a = StridedSliceShapeChecker(uncertain_tensor) self.tensorShapeEqual(a[3:5], tensor_shape.TensorShape([2, None, 7])) self.tensorShapeEqual(a[3:5, :, 4], tensor_shape.TensorShape([2, None])) - self.tensorShapeEqual(a[3:5, 3:4, 4], - tensor_shape.TensorShape([2, None])) + self.tensorShapeEqual(a[3:5, 3:4, 4], tensor_shape.TensorShape([2, None])) self.tensorShapeEqual(a[3:5, :, 5:10], tensor_shape.TensorShape([2, None, 2])) self.tensorShapeEqual(a[3:5, :, 50:3], @@ -651,15 +653,14 @@ class GradSliceChecker(object): analytic_grad2 = 2 * slice_val dy = variables.Variable( - array_ops.ones( - shape=slice_var.get_shape(), dtype=dtypes.int32)) + array_ops.ones(shape=slice_var.get_shape(), dtype=dtypes.int32)) assign = dy.assign(slice_var) slice_val_grad, = gradients_impl.gradients(slice_val, self.var, grad_ys=dy) slice_val_grad2, = gradients_impl.gradients( slice_val_grad, dy, grad_ys=self.var) self.sess.run(assign) - slice_val_grad_evaled, slice_val_grad2_evaled = ( - self.sess.run([slice_val_grad, slice_val_grad2])) + slice_val_grad_evaled, slice_val_grad2_evaled = (self.sess.run( + [slice_val_grad, slice_val_grad2])) analytic_grad2_evaled = analytic_grad2.eval() self.test.assertAllEqual(slice_val_grad2_evaled, analytic_grad2_evaled) @@ -677,8 +678,7 @@ class StridedSliceGradTest(test_util.TensorFlowTestCase): def testGradient(self): 
with self.test_session(use_gpu=True) as sess: var = variables.Variable( - array_ops.reshape( - math_ops.range(1, 97, 1), shape=(6, 4, 4))) + array_ops.reshape(math_ops.range(1, 97, 1), shape=(6, 4, 4))) init = variables.global_variables_initializer() sess.run(init) @@ -853,9 +853,10 @@ class SliceAssignTest(test_util.TensorFlowTestCase): def doTestSliceAssign(self, use_resource): for dtype in STRIDED_SLICE_TYPES: - checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]], - use_resource=use_resource, - tensor_type=dtype) + checker = StridedSliceAssignChecker( + self, [[1, 2, 3], [4, 5, 6]], + use_resource=use_resource, + tensor_type=dtype) # Check if equal checker[:] = [[10, 20, 30], [40, 50, 60]] # Check trivial (1,1) shape tensor @@ -943,18 +944,16 @@ class SequenceMaskTest(test_util.TensorFlowTestCase): res = array_ops.sequence_mask( constant_op.constant([0, 1, 4]), dtype=dtypes.float32) self.assertAllEqual(res.get_shape().as_list(), [3, None]) - self.assertAllEqual(res.eval(), [[0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0], + self.assertAllEqual(res.eval(), [[0.0, 0.0, 0.0, + 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]) def testDtypes(self): def check_dtypes(lengths_dtype, maxlen_dtype): res = array_ops.sequence_mask( - constant_op.constant( - [1, 3, 2], dtype=lengths_dtype), - constant_op.constant( - 5, dtype=maxlen_dtype)) + constant_op.constant([1, 3, 2], dtype=lengths_dtype), + constant_op.constant(5, dtype=maxlen_dtype)) self.assertAllEqual(res.get_shape(), [3, 5]) self.assertAllEqual(res.eval(), [[True, False, False, False, False], [True, True, True, False, False], @@ -979,5 +978,35 @@ class ConcatSliceResourceTest(test_util.TensorFlowTestCase): with self.assertRaises(errors.AlreadyExistsError): test_ops.resource_create_op(r2).run() + +class IdentityTest(test_util.TensorFlowTestCase): + + def testEagerIdentity(self): + with context.eager_mode(): + ctx = context.get_default_context() + if not ctx.num_gpus(): + self.skipTest("No GPUs found") + + def _test(x, y, device): + self.assertIsNot(x, y) + self.assertAllEqual(x.numpy(), y.numpy()) + self.assertTrue(device in y.device.lower()) + + with ops.device("gpu:0"): + a = constant_op.constant([[2], [3]], dtype=dtypes.float32) + with ops.device("gpu:0"): + b = array_ops.identity(a) + _test(a, b, "gpu") + with ops.device("cpu:0"): + c = array_ops.identity(b) + _test(b, c, "cpu") + with ops.device("cpu:0"): + d = array_ops.identity(c) + _test(c, d, "cpu") + with ops.device("gpu:0"): + e = array_ops.identity(d) + _test(d, e, "gpu") + + if __name__ == "__main__": test_lib.main() diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py new file mode 100644 index 0000000000..0e98afbe6e --- /dev/null +++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py @@ -0,0 +1,426 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for ConstantOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops + + +# TODO(josh11b): add tests with string types, lists/tuples, Shape. +class ConstantTest(test.TestCase): + + def _testCpu(self, x): + np_ans = np.array(x) + tf_ans = ops.convert_to_tensor(x).numpy() + if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]: + self.assertAllClose(np_ans, tf_ans) + else: + self.assertAllEqual(np_ans, tf_ans) + + def _testGpu(self, x): + np_ans = np.array(x) + tf_ans = ops.convert_to_tensor(x).numpy() + if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]: + self.assertAllClose(np_ans, tf_ans) + else: + self.assertAllEqual(np_ans, tf_ans) + + def _testAll(self, x): + self._testCpu(x) + self._testGpu(x) + + def testFloat(self): + self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32)) + self._testAll( + np.random.normal(size=30).reshape([2, 3, 5]).astype(np.float32)) + self._testAll(np.empty((2, 0, 5)).astype(np.float32)) + + def testDouble(self): + self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float64)) + self._testAll( + np.random.normal(size=30).reshape([2, 3, 5]).astype(np.float64)) + self._testAll(np.empty((2, 0, 5)).astype(np.float64)) + + def testInt32(self): + self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.int32)) + self._testAll( + (100 * np.random.normal(size=30)).reshape([2, 3, 5]).astype(np.int32)) + self._testAll(np.empty((2, 0, 5)).astype(np.int32)) + + def testInt64(self): + self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.int64)) + self._testAll( + (100 * np.random.normal(size=30)).reshape([2, 3, 5]).astype(np.int64)) + self._testAll(np.empty((2, 0, 5)).astype(np.int64)) + + def testComplex64(self): + self._testAll( + np.complex(1, 2) * np.arange(-15, 15).reshape([2, 3, 5 + ]).astype(np.complex64)) + self._testAll( + np.complex(1, 2) * np.random.normal(size=30).reshape( + [2, 3, 5]).astype(np.complex64)) + self._testAll(np.empty((2, 0, 5)).astype(np.complex64)) + + def testComplex128(self): + self._testAll( + np.complex(1, 2) * np.arange(-15, 15).reshape([2, 3, 5 + ]).astype(np.complex128)) + self._testAll( + np.complex(1, 2) * np.random.normal(size=30).reshape( + [2, 3, 5]).astype(np.complex128)) + self._testAll(np.empty((2, 0, 5)).astype(np.complex128)) + + def testExplicitShapeNumPy(self): + c = constant_op.constant( + np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32), + shape=[2, 3, 5]) + self.assertEqual(c.get_shape(), [2, 3, 5]) + + def testImplicitShapeNumPy(self): + c = constant_op.constant( + np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32)) + self.assertEqual(c.get_shape(), [2, 3, 5]) + + def testExplicitShapeList(self): + c = constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[7]) + self.assertEqual(c.get_shape(), [7]) + + def testExplicitShapeFill(self): + c = constant_op.constant(12, shape=[7]) + self.assertEqual(c.get_shape(), [7]) + self.assertAllEqual([12, 12, 12, 12, 12, 12, 12], c.numpy()) + + def testExplicitShapeReshape(self): + c 
= constant_op.constant( + np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32), + shape=[5, 2, 3]) + self.assertEqual(c.get_shape(), [5, 2, 3]) + + def testImplicitShapeList(self): + c = constant_op.constant([1, 2, 3, 4, 5, 6, 7]) + self.assertEqual(c.get_shape(), [7]) + + def testExplicitShapeNumber(self): + c = constant_op.constant(1, shape=[1]) + self.assertEqual(c.get_shape(), [1]) + + def testImplicitShapeNumber(self): + c = constant_op.constant(1) + self.assertEqual(c.get_shape(), []) + + def testShapeTooBig(self): + with self.assertRaises(TypeError): + constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[10]) + + def testShapeTooSmall(self): + with self.assertRaises(TypeError): + constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[5]) + + def testShapeWrong(self): + with self.assertRaisesRegexp(TypeError, None): + constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[5]) + + def testSparseValuesRaiseErrors(self): + with self.assertRaisesRegexp(ValueError, + "setting an array element with a sequence"): + constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32) + + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, None): + constant_op.constant([[1, 2], [3]]) + + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, None): + constant_op.constant([[1, 2], [3], [4, 5]]) + + +class AsTensorTest(test.TestCase): + + def testAsTensorForTensorInput(self): + t = constant_op.constant(10.0) + x = ops.convert_to_tensor(t) + self.assertIs(t, x) + + def testAsTensorForNonTensorInput(self): + x = ops.convert_to_tensor(10.0) + self.assertTrue(isinstance(x, ops.EagerTensor)) + + +class ZerosTest(test.TestCase): + + def _Zeros(self, shape): + ret = array_ops.zeros(shape) + self.assertEqual(shape, ret.get_shape()) + return ret.numpy() + + def testConst(self): + self.assertTrue( + np.array_equal(self._Zeros([2, 3]), np.array([[0] * 3] * 2))) + + def testScalar(self): + self.assertEqual(0, self._Zeros([])) + self.assertEqual(0, self._Zeros(())) + scalar = array_ops.zeros(constant_op.constant([], dtype=dtypes_lib.int32)) + self.assertEqual(0, scalar.numpy()) + + def testDynamicSizes(self): + np_ans = np.array([[0] * 3] * 2) + # Creates a tensor of 2 x 3. + d = array_ops.fill([2, 3], 12., name="fill") + # Constructs a tensor of zeros of the same dimensions as "d". + z = array_ops.zeros(array_ops.shape(d)) + out = z.numpy() + self.assertAllEqual(np_ans, out) + self.assertShapeEqual(np_ans, d) + self.assertShapeEqual(np_ans, z) + + def testDtype(self): + d = array_ops.fill([2, 3], 12., name="fill") + self.assertEqual(d.get_shape(), [2, 3]) + # Test default type for both constant size and dynamic size + z = array_ops.zeros([2, 3]) + self.assertEqual(z.dtype, dtypes_lib.float32) + self.assertEqual([2, 3], z.get_shape()) + self.assertAllEqual(z.numpy(), np.zeros([2, 3])) + z = array_ops.zeros(array_ops.shape(d)) + self.assertEqual(z.dtype, dtypes_lib.float32) + self.assertEqual([2, 3], z.get_shape()) + self.assertAllEqual(z.numpy(), np.zeros([2, 3])) + # Test explicit type control + for dtype in [ + dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32, + dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, + dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64, + dtypes_lib.bool, + # TODO(josh11b): Support string type here. 
+ # dtypes_lib.string + ]: + z = array_ops.zeros([2, 3], dtype=dtype) + self.assertEqual(z.dtype, dtype) + self.assertEqual([2, 3], z.get_shape()) + z_value = z.numpy() + self.assertFalse(np.any(z_value)) + self.assertEqual((2, 3), z_value.shape) + z = array_ops.zeros(array_ops.shape(d), dtype=dtype) + self.assertEqual(z.dtype, dtype) + self.assertEqual([2, 3], z.get_shape()) + z_value = z.numpy() + self.assertFalse(np.any(z_value)) + self.assertEqual((2, 3), z_value.shape) + + +class ZerosLikeTest(test.TestCase): + + def _compareZeros(self, dtype, use_gpu): + # Creates a tensor of non-zero values with shape 2 x 3. + # NOTE(kearnes): The default numpy dtype associated with tf.string is + # np.object (and can't be changed without breaking a lot things), which + # causes a TypeError in constant_op.constant below. Here we catch the + # special case of tf.string and set the numpy dtype appropriately. + if dtype == dtypes_lib.string: + numpy_dtype = np.string_ + else: + numpy_dtype = dtype.as_numpy_dtype + d = constant_op.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype) + # Constructs a tensor of zeros of the same dimensions and type as "d". + z_var = array_ops.zeros_like(d) + # Test that the type is correct + self.assertEqual(z_var.dtype, dtype) + # Test that the shape is correct + self.assertEqual([2, 3], z_var.get_shape()) + + # Test that the value is correct + z_value = z_var.numpy() + self.assertFalse(np.any(z_value)) + self.assertEqual((2, 3), z_value.shape) + + def testZerosLikeCPU(self): + for dtype in [ + dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32, + dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, + dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64, + # TODO(josh11b): Support string type here. + # dtypes_lib.string + ]: + self._compareZeros(dtype, use_gpu=False) + + def testZerosLikeGPU(self): + for dtype in [ + dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32, + dtypes_lib.bool, dtypes_lib.int64, + # TODO(josh11b): Support string type here. + # dtypes_lib.string + ]: + self._compareZeros(dtype, use_gpu=True) + + def testZerosLikeDtype(self): + # Make sure zeros_like works even for dtypes that cannot be cast between + shape = (3, 5) + dtypes = np.float32, np.complex64 + for in_type in dtypes: + x = np.arange(15).astype(in_type).reshape(*shape) + for out_type in dtypes: + y = array_ops.zeros_like(x, dtype=out_type).numpy() + self.assertEqual(y.dtype, out_type) + self.assertEqual(y.shape, shape) + self.assertAllEqual(y, np.zeros(shape, dtype=out_type)) + + +class OnesTest(test.TestCase): + + def _Ones(self, shape): + ret = array_ops.ones(shape) + self.assertEqual(shape, ret.get_shape()) + return ret.numpy() + + def testConst(self): + self.assertTrue(np.array_equal(self._Ones([2, 3]), np.array([[1] * 3] * 2))) + + def testScalar(self): + self.assertEqual(1, self._Ones([])) + self.assertEqual(1, self._Ones(())) + scalar = array_ops.ones(constant_op.constant([], dtype=dtypes_lib.int32)) + self.assertEqual(1, scalar.numpy()) + + def testDynamicSizes(self): + np_ans = np.array([[1] * 3] * 2) + # Creates a tensor of 2 x 3. + d = array_ops.fill([2, 3], 12., name="fill") + # Constructs a tensor of ones of the same dimensions as "d". 
+ z = array_ops.ones(array_ops.shape(d)) + out = z.numpy() + self.assertAllEqual(np_ans, out) + self.assertShapeEqual(np_ans, d) + self.assertShapeEqual(np_ans, z) + + def testDtype(self): + d = array_ops.fill([2, 3], 12., name="fill") + self.assertEqual(d.get_shape(), [2, 3]) + # Test default type for both constant size and dynamic size + z = array_ops.ones([2, 3]) + self.assertEqual(z.dtype, dtypes_lib.float32) + self.assertEqual([2, 3], z.get_shape()) + self.assertAllEqual(z.numpy(), np.ones([2, 3])) + z = array_ops.ones(array_ops.shape(d)) + self.assertEqual(z.dtype, dtypes_lib.float32) + self.assertEqual([2, 3], z.get_shape()) + self.assertAllEqual(z.numpy(), np.ones([2, 3])) + # Test explicit type control + for dtype in (dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32, + dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, + dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64, + dtypes_lib.bool): + z = array_ops.ones([2, 3], dtype=dtype) + self.assertEqual(z.dtype, dtype) + self.assertEqual([2, 3], z.get_shape()) + self.assertAllEqual(z.numpy(), np.ones([2, 3])) + z = array_ops.ones(array_ops.shape(d), dtype=dtype) + self.assertEqual(z.dtype, dtype) + self.assertEqual([2, 3], z.get_shape()) + self.assertAllEqual(z.numpy(), np.ones([2, 3])) + + +class OnesLikeTest(test.TestCase): + + def testOnesLike(self): + for dtype in [ + dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32, + dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, + dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64 + ]: + numpy_dtype = dtype.as_numpy_dtype + # Creates a tensor of non-zero values with shape 2 x 3. + d = constant_op.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype) + # Constructs a tensor of zeros of the same dimensions and type as "d". 
+ z_var = array_ops.ones_like(d) + # Test that the type is correct + self.assertEqual(z_var.dtype, dtype) + z_value = z_var.numpy() + + # Test that the value is correct + self.assertTrue(np.array_equal(z_value, np.array([[1] * 3] * 2))) + self.assertEqual([2, 3], z_var.get_shape()) + + +class FillTest(test.TestCase): + + def _compare(self, dims, val, np_ans, use_gpu): + ctx = context.get_default_context() + device = "GPU:0" if (use_gpu and ctx.num_gpus()) else "CPU:0" + with ops.device(device): + tf_ans = array_ops.fill(dims, val, name="fill") + out = tf_ans.numpy() + self.assertAllClose(np_ans, out) + + def _compareAll(self, dims, val, np_ans): + self._compare(dims, val, np_ans, False) + self._compare(dims, val, np_ans, True) + + def testFillFloat(self): + np_ans = np.array([[3.1415] * 3] * 2).astype(np.float32) + self._compareAll([2, 3], np_ans[0][0], np_ans) + + def testFillDouble(self): + np_ans = np.array([[3.1415] * 3] * 2).astype(np.float64) + self._compareAll([2, 3], np_ans[0][0], np_ans) + + def testFillInt32(self): + np_ans = np.array([[42] * 3] * 2).astype(np.int32) + self._compareAll([2, 3], np_ans[0][0], np_ans) + + def testFillInt64(self): + np_ans = np.array([[-42] * 3] * 2).astype(np.int64) + self._compareAll([2, 3], np_ans[0][0], np_ans) + + def testFillComplex64(self): + np_ans = np.array([[0.15] * 3] * 2).astype(np.complex64) + self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False) + + def testFillComplex128(self): + np_ans = np.array([[0.15] * 3] * 2).astype(np.complex128) + self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False) + + def testFillString(self): + np_ans = np.array([[b"yolo"] * 3] * 2) + tf_ans = array_ops.fill([2, 3], np_ans[0][0], name="fill").numpy() + self.assertAllEqual(np_ans, tf_ans) + + def testFillNegative(self): + for shape in (-1,), (2, -1), (-1, 2), (-2), (-3): + with self.assertRaises(errors_impl.InvalidArgumentError): + array_ops.fill(shape, 7) + + def testShapeFunctionEdgeCases(self): + # Non-vector dimensions. + with self.assertRaises(errors_impl.InvalidArgumentError): + array_ops.fill([[0, 1], [2, 3]], 1.0) + + # Non-scalar value. 
+ with self.assertRaises(errors_impl.InvalidArgumentError): + array_ops.fill([3, 2], [1.0, 2.0]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 20c21eca4c..e31bd8d0ee 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -43,23 +44,31 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.assign_variable_op( handle, constant_op.constant(0.0, dtype=dtypes.float32)).run() with self.assertRaises(ValueError): - resource_variable_ops.assign_variable_op( - handle, constant_op.constant([0], dtype=dtypes.int32)).run() - resource_variable_ops.assign_variable_op( - handle, constant_op.constant(0, dtype=dtypes.int32)).run() + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [0], + dtype=dtypes.int32)).run() + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + 0, + dtype=dtypes.int32)).run() def testDtypeSurvivesIdentity(self): with self.test_session(): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) id_handle = array_ops.identity(handle) - resource_variable_ops.assign_variable_op( - id_handle, constant_op.constant(0, dtype=dtypes.int32)).run() + resource_variable_ops.assign_variable_op(id_handle, + constant_op.constant( + 0, + dtype=dtypes.int32)).run() def testCreateRead(self): with self.test_session(): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) - resource_variable_ops.assign_variable_op( - handle, constant_op.constant(1, dtype=dtypes.int32)).run() + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + 1, + dtype=dtypes.int32)).run() value = resource_variable_ops.read_variable_op( handle, dtype=dtypes.int32).eval() self.assertAllEqual(1, value) @@ -67,8 +76,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): def testManyAssigns(self): with self.test_session() as session: handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) - create = resource_variable_ops.assign_variable_op( - handle, constant_op.constant(1, dtype=dtypes.int32)) + create = resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + 1, + dtype=dtypes.int32)) with ops.control_dependencies([create]): first_read = resource_variable_ops.read_variable_op( handle, dtype=dtypes.int32) @@ -85,8 +96,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): def testAssignAdd(self): with self.test_session(): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) - resource_variable_ops.assign_variable_op( - handle, constant_op.constant(1, dtype=dtypes.int32)).run() + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + 1, + dtype=dtypes.int32)).run() resource_variable_ops.assign_add_variable_op( handle, constant_op.constant(1, dtype=dtypes.int32)).run() read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) @@ -96,10 +109,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True): handle = resource_variable_ops.var_handle_op( 
dtype=dtypes.int32, shape=[1, 1]) - resource_variable_ops.assign_variable_op( - handle, constant_op.constant([[1]], dtype=dtypes.int32)).run() - resource_variable_ops.resource_scatter_add( - handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)).run() + resource_variable_ops.assign_variable_op(handle, + constant_op.constant( + [[1]], + dtype=dtypes.int32)).run() + resource_variable_ops.resource_scatter_add(handle, [0], + constant_op.constant( + [[2]], + dtype=dtypes.int32)).run() read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(read.eval(), [[3]]) @@ -113,32 +130,31 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): sess.run(variables.global_variables_initializer()) self.assertEqual( - resource_variable_ops.var_is_initialized_op(abc.handle).eval(), - True) + resource_variable_ops.var_is_initialized_op(abc.handle).eval(), True) print(sess.run(abc)) def testConstraintArg(self): constraint = lambda x: x - v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1, - constraint=constraint) + v = resource_variable_ops.ResourceVariable( + initial_value=lambda: 1, constraint=constraint) self.assertEqual(v.constraint, constraint) constraint = 0 with self.assertRaises(ValueError): - v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1, - constraint=constraint) + v = resource_variable_ops.ResourceVariable( + initial_value=lambda: 1, constraint=constraint) def testInitFn(self): with self.test_session(): - v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1, - dtype=dtypes.float32) + v = resource_variable_ops.ResourceVariable( + initial_value=lambda: 1, dtype=dtypes.float32) self.assertEqual(v.handle.op.colocation_groups(), v.initializer.inputs[1].op.colocation_groups()) def testInitFnDtype(self): with self.test_session(): - v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1, - dtype=dtypes.float32) + v = resource_variable_ops.ResourceVariable( + initial_value=lambda: 1, dtype=dtypes.float32) self.assertEqual(dtypes.float32, v.value().dtype) def testInitFnNoDtype(self): @@ -158,7 +174,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.test_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() - self.assertEqual(2.0, (v+v).eval()) + self.assertEqual(2.0, (v + v).eval()) def testAssignMethod(self): with self.test_session(): @@ -217,8 +233,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): # Handle to a resource not actually created. 
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) # Should raise no exception - sess.run(resource_variable_ops.destroy_resource_op( - handle, ignore_lookup_error=True)) + sess.run( + resource_variable_ops.destroy_resource_op( + handle, ignore_lookup_error=True)) def testAssignDifferentShapes(self): with self.test_session() as sess, variable_scope.variable_scope( @@ -226,9 +243,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): var = variable_scope.get_variable("x", shape=[1, 1], dtype=dtypes.float32) placeholder = array_ops.placeholder(dtypes.float32) assign = var.assign(placeholder) - sess.run([assign], - feed_dict={placeholder: np.zeros(shape=[2, 2], - dtype=np.float32)}) + sess.run( + [assign], + feed_dict={placeholder: np.zeros(shape=[2, 2], dtype=np.float32)}) def testDtypeAfterFromProto(self): v = resource_variable_ops.ResourceVariable(2.0) @@ -256,19 +273,28 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): v = resource_variable_ops.ResourceVariable(300.0, name="var1") v.initializer.run() - w = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype, - shape=v.get_shape(), - shared_name="var1") + w = resource_variable_ops.var_handle_op( + dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var1") w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) self.assertEqual(300.0, w_read.eval()) - x = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype, - shape=v.get_shape(), - shared_name="var1/") + x = resource_variable_ops.var_handle_op( + dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var1/") x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype) with self.assertRaisesOpError("Resource .*/var1//.* does not exist"): _ = x_read.eval() + def testSharedNameWithNamescope(self): + with self.test_session(): + with ops.name_scope("foo"): + v = resource_variable_ops.ResourceVariable(300.0, name="var1") + v.initializer.run() + + w = resource_variable_ops.var_handle_op( + dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var1") + w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype) + self.assertEqual(300.0, w_read.eval()) + def testShape(self): with self.test_session(): v = resource_variable_ops.ResourceVariable( @@ -291,6 +317,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): def testControlFlowInitialization(self): """Expects an error if an initializer is in a control-flow scope.""" + def cond(i, _): return i < 10 @@ -303,5 +330,61 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): control_flow_ops.while_loop(cond, body, [0, 0]) +# TODO(agarwal,apassos): Add more comprehensive tests and/or translate the above +# tests to work in both GRAPH and EAGER modes. 
+# TODO(agarwal): Add tests for sparse_read, scatter_sub +class ResourceVariableOpsEagerTest(test_util.TensorFlowTestCase): + + def testVariable(self): + with context.eager_mode(): + init = array_ops.ones(shape=[10, 20, 35], dtype=dtypes.int32) + constraint = lambda x: x + with ops.name_scope("foo"): + v = resource_variable_ops.ResourceVariable( + name="var1", + initial_value=init, + caching_device="cpu:0", + constraint=constraint) + # Test properties + self.assertEqual(dtypes.int32, v.dtype) + self.assertEqual("foo/var1:0", v.name) + self.assertAllEqual([10, 20, 35], v.shape.as_list()) + self.assertAllEqual(init.device, v.device) + self.assertTrue(isinstance(v.handle, ops.EagerTensor)) + self.assertEqual(constraint, v.constraint) + self.assertAllEqual(init.numpy(), v.read_value().numpy()) + self.assertAllEqual(init.numpy(), v.value().numpy()) + + # Callable init. + callable_init = lambda: init * 2 + v2 = resource_variable_ops.ResourceVariable( + initial_value=callable_init, name="v2") + self.assertEqual("v2:0", v2.name) + self.assertAllEqual(2 * init.numpy(), v2.read_value().numpy()) + + # Test assign_add. + new_v2_val = v2.assign_add(v.read_value()) + self.assertAllEqual(v.read_value().numpy() * 3, new_v2_val.numpy()) + + # Test assign_sub. + new_v2_val = v2.assign_sub(v.read_value()) + self.assertAllEqual(v.read_value().numpy() * 2, new_v2_val.numpy()) + + # Test assign. + v2.assign(v.read_value()) + self.assertAllEqual(v.read_value().numpy(), v2.read_value().numpy()) + + # Test load + v2.load(2 * v.read_value()) + self.assertAllEqual(2 * v.read_value().numpy(), v2.read_value().numpy()) + + # Test convert_to_tensor + t = ops.convert_to_tensor(v) + self.assertAllEqual(t.numpy(), v.read_value().numpy()) + + # Test operations + self.assertAllEqual((v * 2).numpy(), (v + v).numpy()) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 7fbc12d1e0..9c9b852b76 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Support for manipulating tensors. See the @{$python/array_ops} guide. @@ -85,6 +84,7 @@ from __future__ import print_function import sys import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -103,7 +103,6 @@ from tensorflow.python.util import deprecation from tensorflow.python.util.deprecation import deprecated # pylint: enable=wildcard-import - # Used for slicing to specify a new 1 size dimension newaxis = None @@ -112,6 +111,23 @@ newaxis = None _baseslice = slice +def identity(input, name=None): # pylint: disable=redefined-builtin + r"""Return a tensor with the same shape and contents as input. + + Args: + input: A `Tensor`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `input`. + """ + if context.in_graph_mode(): + return gen_array_ops.identity(input, name=name) + else: + # TODO(apassos): make sure this works ok with gradients. + return input._copy() # pylint: disable=protected-access + + # pylint: disable=redefined-builtin,protected-access def expand_dims(input, axis=None, name=None, dim=None): """Inserts a dimension of 1 into a tensor's shape. 
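
For reference, a hedged sketch of what the eager identity() defined above does in practice: instead of adding an Identity node, it calls EagerTensor._copy(), which materializes the result on the current device, matching IdentityTest.testEagerIdentity earlier in this diff. Assumes a visible GPU; the device strings follow the tests in this patch.

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

with context.eager_mode():
  with ops.device("gpu:0"):
    a = constant_op.constant([[2.0], [3.0]])
  with ops.device("cpu:0"):
    b = array_ops.identity(a)  # copies a to CPU:0; a itself is untouched
  assert "cpu" in b.device.lower()
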
@@ -168,25 +184,31 @@ def expand_dims(input, axis=None, name=None, dim=None): raise ValueError("can't specify both 'dim' and 'axis'") axis = dim return gen_array_ops._expand_dims(input, axis, name) + + # pylint: enable=redefined-builtin,protected-access # Aliases for some automatically-generated names. # pylint: disable=protected-access -@deprecated( - "2016-11-30", - "This op will be removed after the deprecation date. " - "Please switch to tf.setdiff1d().") +@deprecated("2016-11-30", "This op will be removed after the deprecation date. " + "Please switch to tf.setdiff1d().") def listdiff(x, y, out_idx=None, name=None): return gen_array_ops._list_diff(x, y, out_idx, name) + + listdiff.__doc__ = gen_array_ops._list_diff.__doc__ + "\n" + listdiff.__doc__ + # pylint: enable=protected-access # pylint: disable=undefined-variable,protected-access def setdiff1d(x, y, index_dtype=dtypes.int32, name=None): return gen_array_ops._list_diff(x, y, index_dtype, name) + + setdiff1d.__doc__ = gen_array_ops._list_diff.__doc__ + # pylint: enable=protected-access @@ -262,8 +284,8 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32): """ with ops.name_scope(name, "Shape", [input]) as name: - if isinstance( - input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): + if isinstance(input, (sparse_tensor.SparseTensor, + sparse_tensor.SparseTensorValue)): return gen_math_ops.cast(input.dense_shape, out_type) else: input_tensor = ops.convert_to_tensor(input) @@ -314,8 +336,8 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32): A `Tensor` of type `out_type`. """ with ops.name_scope(name, "Size", [input]) as name: - if isinstance( - input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): + if isinstance(input, (sparse_tensor.SparseTensor, + sparse_tensor.SparseTensorValue)): return gen_math_ops._prod( gen_math_ops.cast(input.dense_shape, out_type), 0, name=name) else: @@ -371,8 +393,8 @@ def rank_internal(input, name=None, optimize=True): A `Tensor` of type `int32`. 
""" with ops.name_scope(name, "Rank", [input]) as name: - if isinstance( - input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): + if isinstance(input, (sparse_tensor.SparseTensor, + sparse_tensor.SparseTensorValue)): return gen_array_ops.size(input.dense_shape, name=name) else: input_tensor = ops.convert_to_tensor(input) @@ -412,7 +434,8 @@ def _SliceHelper(tensor, slice_spec, var=None): foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]]) print(foo[tf.newaxis, :, :].eval()) # => [[[1,2,3], [4,5,6], [7,8,9]]] print(foo[:, tf.newaxis, :].eval()) # => [[[1,2,3]], [[4,5,6]], [[7,8,9]]] - print(foo[:, :, tf.newaxis].eval()) # => [[[1],[2],[3]], [[4],[5],[6]], [[7],[8],[9]]] + print(foo[:, :, tf.newaxis].eval()) # => [[[1],[2],[3]], [[4],[5],[6]], + [[7],[8],[9]]] # Ellipses (3 equivalent operations) foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]]) @@ -491,8 +514,8 @@ def _SliceHelper(tensor, slice_spec, var=None): with ops.name_scope(None, "strided_slice", [tensor] + begin + end + strides) as name: if begin: - packed_begin, packed_end, packed_strides = ( - stack(begin), stack(end), stack(strides)) + packed_begin, packed_end, packed_strides = (stack(begin), stack(end), + stack(strides)) else: var_empty = constant([], dtype=dtypes.int32) packed_begin = packed_end = packed_strides = var_empty @@ -748,6 +771,7 @@ def _SliceHelperVar(var, slice_spec): return _SliceHelper(var._AsTensor(), slice_spec, var) + ops.Tensor._override_operator("__getitem__", _SliceHelper) @@ -851,8 +875,8 @@ def stack(values, axis=0, name="stack"): if value_shape.ndims is not None: expanded_num_dims = value_shape.ndims + 1 if axis < -expanded_num_dims or axis >= expanded_num_dims: - raise ValueError("axis = %d not in [%d, %d)" % - (axis, -expanded_num_dims, expanded_num_dims)) + raise ValueError("axis = %d not in [%d, %d)" % (axis, -expanded_num_dims, + expanded_num_dims)) return gen_array_ops._pack(values, axis=axis, name=name) @@ -875,9 +899,9 @@ def _autopacking_helper(list_or_tuple, dtype, name): for i, elem in enumerate(list_or_tuple): if ops.is_dense_tensor_like(elem): if dtype is not None and elem.dtype.base_dtype != dtype: - raise TypeError( - "Cannot convert a list containing a tensor of dtype " - "%s to %s (Tensor is: %r)" % (elem.dtype, dtype, elem)) + raise TypeError("Cannot convert a list containing a tensor of dtype " + "%s to %s (Tensor is: %r)" % (elem.dtype, dtype, + elem)) converted_elems.append(elem) must_pack = True elif isinstance(elem, (list, tuple)): @@ -936,13 +960,14 @@ def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False): if dtype is not None and dtype != inferred_dtype: return NotImplemented return _autopacking_helper(v, inferred_dtype, name or "packed") -# pylint: enable=invalid-name +# pylint: enable=invalid-name + # NOTE: Register this conversion function to run *before* one that # assumes every element is a value. -ops.register_tensor_conversion_function( - (list, tuple), _autopacking_conversion_function, 99) +ops.register_tensor_conversion_function((list, tuple), + _autopacking_conversion_function, 99) def unstack(value, num=None, axis=0, name="unstack"): @@ -1058,14 +1083,12 @@ def concat(values, axis, name="concat"): # the returned tensor is a scalar. # TODO(keveman): Implement a standalone type and shape checker. 
with ops.name_scope(name) as scope: - ops.convert_to_tensor(axis, - name="concat_dim", - dtype=dtypes.int32).get_shape( - ).assert_is_compatible_with(tensor_shape.scalar()) + ops.convert_to_tensor( + axis, name="concat_dim", + dtype=dtypes.int32).get_shape().assert_is_compatible_with( + tensor_shape.scalar()) return identity(values[0], name=scope) - return gen_array_ops._concat_v2(values=values, - axis=axis, - name=name) + return gen_array_ops._concat_v2(values=values, axis=axis, name=name) def boolean_mask(tensor, mask, name="boolean_mask"): @@ -1104,6 +1127,7 @@ def boolean_mask(tensor, mask, name="boolean_mask"): boolean_mask(tensor, mask) ==> [[1, 2], [5, 6]] ``` """ + def _apply_mask_1d(reshaped_tensor, mask): """Mask tensor along dimension 0 with a 1-D mask.""" indices = squeeze(where(mask), squeeze_dims=[1]) @@ -1125,9 +1149,9 @@ def boolean_mask(tensor, mask, name="boolean_mask"): shape_tensor[:ndims_mask].assert_is_compatible_with(shape_mask) leading_size = gen_math_ops._prod(shape(tensor)[:ndims_mask], [0]) - tensor = reshape( - tensor, - concat([[leading_size], shape(tensor)[ndims_mask:]], 0)) + tensor = reshape(tensor, + concat([[leading_size], + shape(tensor)[ndims_mask:]], 0)) first_dim = shape_tensor[:ndims_mask].num_elements() tensor.set_shape( tensor_shape.as_shape([first_dim]) @@ -1363,10 +1387,12 @@ def matrix_transpose(a, name="matrix_transpose"): perm = list(range(ndims - 2)) + [ndims - 1] + [ndims - 2] else: a_rank = rank(a) - perm = concat( - (gen_math_ops._range(0, a_rank - 2, 1), [a_rank - 1, a_rank - 2]), 0) + perm = concat((gen_math_ops._range(0, a_rank - 2, 1), + [a_rank - 1, a_rank - 2]), 0) return transpose(a, perm=perm) + + # pylint: enable=invalid-name @@ -1383,7 +1409,8 @@ def zeros(shape, dtype=dtypes.float32, name=None): ``` Args: - shape: A list of integers, a tuple of integers, or a 1-D `Tensor` of type `int32`. + shape: A list of integers, a tuple of integers, or a 1-D `Tensor` of type + `int32`. dtype: The type of an element in the resulting `Tensor`. name: A name for the operation (optional). @@ -1442,8 +1469,8 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True): return zeros(tensor.shape, dtype=dtype or tensor.dtype, name=name) if dtype is not None and dtype != tensor.dtype: - return zeros(shape_internal(tensor, optimize=optimize), dtype=dtype, - name=name) + return zeros( + shape_internal(tensor, optimize=optimize), dtype=dtype, name=name) else: return gen_array_ops._zeros_like(tensor, name=name) @@ -1480,7 +1507,8 @@ def ones_like(tensor, dtype=None, name=None, optimize=True): if dtype is None: dtype = tensor.dtype ret = ones(ones_shape, dtype=dtype, name=name) - ret.set_shape(tensor.get_shape()) + if context.in_graph_mode(): + ret.set_shape(tensor.get_shape()) return ret @@ -1497,7 +1525,8 @@ def ones(shape, dtype=dtypes.float32, name=None): ``` Args: - shape: A list of integers, a tuple of integers, or a 1-D `Tensor` of type `int32`. + shape: A list of integers, a tuple of integers, or a 1-D `Tensor` of type + `int32`. dtype: The type of an element in the resulting `Tensor`. name: A name for the operation (optional). 
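
A design note on the ones_like() guard above, with a minimal sketch (illustrative, not part of the patch): set_shape() is a static-shape refinement that only makes sense for graph tensors, whereas an EagerTensor's shape is already fully concrete, so the call is skipped outside graph mode.

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops

with context.eager_mode():
  d = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  z = array_ops.ones_like(d)
  # The shape is fully defined without any set_shape() call.
  assert z.get_shape().as_list() == [2, 3]
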
@@ -1553,7 +1582,8 @@ def placeholder(dtype, shape=None, name=None): # pylint: disable=redefined-outer-name def _normalize_sparse_shape(shape, name): """Returns a tuple of (Tensor or None, rank or None).""" - if shape is None: return (None, None) + if shape is None: + return (None, None) rank = shape.get_shape()[0] if isinstance(shape, ops.Tensor) else len(shape) if not isinstance(shape, ops.Tensor) and None in shape: return (None, rank) @@ -1605,12 +1635,16 @@ def sparse_placeholder(dtype, shape=None, name=None): shape = placeholder(dtypes.int64, shape=[rank], name=shape_name) return sparse_tensor.SparseTensor( values=placeholder( - dtype, shape=[None], + dtype, + shape=[None], name=(name + "/values") if name is not None else None), indices=placeholder( - dtypes.int64, shape=[None, None], + dtypes.int64, + shape=[None, None], name=(name + "/indices") if name is not None else None), dense_shape=shape) + + # pylint: enable=redefined-outer-name @@ -1764,15 +1798,15 @@ def meshgrid(*args, **kwargs): # Prepare reshape by inserting dimensions with size 1 where needed output = [] for i, x in enumerate(args): - output.append(reshape(stack(x), (s0[:i] + (-1,) + s0[i + 1::])) ) + output.append(reshape(stack(x), (s0[:i] + (-1,) + s0[i + 1::]))) # Create parameters for broadcasting each tensor to the full size shapes = [size(x) for x in args] output_dtype = ops.convert_to_tensor(args[0]).dtype.base_dtype if indexing == "xy" and ndim > 1: - output[0] = reshape(output[0], (1, -1) + (1,)*(ndim - 2)) - output[1] = reshape(output[1], (-1, 1) + (1,)*(ndim - 2)) + output[0] = reshape(output[0], (1, -1) + (1,) * (ndim - 2)) + output[1] = reshape(output[1], (-1, 1) + (1,) * (ndim - 2)) shapes[0], shapes[1] = shapes[1], shapes[0] # TODO: improve performance with a broadcast @@ -1903,23 +1937,22 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"): Raises: TypeError: If either `hypothesis` or `truth` are not a `SparseTensor`. 
""" - if not isinstance( - hypothesis, (sparse_tensor.SparseTensor, - sparse_tensor.SparseTensorValue)): + if not isinstance(hypothesis, (sparse_tensor.SparseTensor, + sparse_tensor.SparseTensorValue)): raise TypeError("Hypothesis must be a SparseTensor.") - if not isinstance( - truth, (sparse_tensor.SparseTensor, - sparse_tensor.SparseTensorValue)): + if not isinstance(truth, (sparse_tensor.SparseTensor, + sparse_tensor.SparseTensorValue)): raise TypeError("Truth must be a SparseTensor.") - return gen_array_ops._edit_distance(hypothesis.indices, - hypothesis.values, - hypothesis.dense_shape, - truth.indices, - truth.values, - truth.dense_shape, - normalize=normalize, - name=name) + return gen_array_ops._edit_distance( + hypothesis.indices, + hypothesis.values, + hypothesis.dense_shape, + truth.indices, + truth.values, + truth.dense_shape, + normalize=normalize, + name=name) @ops.RegisterGradient("FakeQuantWithMinMaxArgs") @@ -1992,12 +2025,10 @@ def required_space_to_batch_paddings(input_shape, """ with ops.name_scope(name, "required_space_to_batch_paddings", [input_shape, block_shape]): - input_shape = ops.convert_to_tensor(input_shape, - dtype=dtypes.int32, - name="input_shape") - block_shape = ops.convert_to_tensor(block_shape, - dtype=dtypes.int32, - name="block_shape") + input_shape = ops.convert_to_tensor( + input_shape, dtype=dtypes.int32, name="input_shape") + block_shape = ops.convert_to_tensor( + block_shape, dtype=dtypes.int32, name="block_shape") block_shape.get_shape().assert_is_fully_defined() block_shape.get_shape().assert_has_rank(1) @@ -2008,9 +2039,8 @@ def required_space_to_batch_paddings(input_shape, input_shape.get_shape().assert_is_compatible_with([num_block_dims]) if base_paddings is not None: - base_paddings = ops.convert_to_tensor(base_paddings, - dtype=dtypes.int32, - name="base_paddings") + base_paddings = ops.convert_to_tensor( + base_paddings, dtype=dtypes.int32, name="base_paddings") base_paddings.get_shape().assert_is_compatible_with([num_block_dims, 2]) else: base_paddings = zeros([num_block_dims, 2], dtypes.int32) @@ -2040,11 +2070,11 @@ def required_space_to_batch_paddings(input_shape, def space_to_batch(input, paddings, block_size, name=None): # pylint: disable=redefined-builtin - result = space_to_batch_nd(input, - paddings=paddings, - block_shape=np.array([block_size, block_size], - dtype=np.int64), - name=name) + result = space_to_batch_nd( + input, + paddings=paddings, + block_shape=np.array([block_size, block_size], dtype=np.int64), + name=name) result.set_shape(result.get_shape().with_rank(4)) return result @@ -2053,11 +2083,11 @@ space_to_batch.__doc__ = gen_array_ops._space_to_batch.__doc__ def batch_to_space(input, crops, block_size, name=None): # pylint: disable=redefined-builtin - result = batch_to_space_nd(input, - crops=crops, - block_shape=np.array([block_size, block_size], - dtype=np.int64), - name=name) + result = batch_to_space_nd( + input, + crops=crops, + block_shape=np.array([block_size, block_size], dtype=np.int64), + name=name) result.set_shape(result.get_shape().with_rank(4)) return result @@ -2065,8 +2095,13 @@ def batch_to_space(input, crops, block_size, name=None): # pylint: disable=rede batch_to_space.__doc__ = gen_array_ops._batch_to_space.__doc__ -def one_hot(indices, depth, on_value=None, off_value=None, - axis=None, dtype=None, name=None): +def one_hot(indices, + depth, + on_value=None, + off_value=None, + axis=None, + dtype=None, + name=None): """Returns a one-hot tensor. 
The locations represented by indices in `indices` take value `on_value`, @@ -2190,8 +2225,9 @@ def one_hot(indices, depth, on_value=None, off_value=None, TypeError: If the dtype of either `on_value` or `off_value` doesn't match `dtype` TypeError: If the dtypes of `on_value` and `off_value` don't match one another """ - with ops.name_scope(name, "one_hot", [indices, depth, on_value, off_value, - axis, dtype]) as name: + with ops.name_scope(name, "one_hot", + [indices, depth, on_value, off_value, axis, + dtype]) as name: on_exists = on_value is not None off_exists = off_value is not None @@ -2274,9 +2310,8 @@ def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None): # to length as a matrix with 1 column: [[1], [3], [2]]. # Because of broadcasting on both arguments this comparison results # in a matrix of size (len(lengths), maxlen) - row_vector = gen_math_ops._range(constant(0, maxlen.dtype), - maxlen, - constant(1, maxlen.dtype)) + row_vector = gen_math_ops._range( + constant(0, maxlen.dtype), maxlen, constant(1, maxlen.dtype)) # Since maxlen >= max(lengths), it is safe to use maxlen as the # authoritative cast type. Whenever maxlen fits into tf.int32, so do the lengths. matrix = gen_math_ops.cast(expand_dims(lengths, 1), maxlen.dtype) @@ -2388,6 +2423,8 @@ def where(condition, x=None, y=None, name=None): def reverse(tensor, axis, name=None): return gen_array_ops.reverse_v2(tensor, axis, name) + + reverse.__doc__ = gen_array_ops.reverse_v2.__doc__ @@ -2409,9 +2446,10 @@ def reverse_sequence(input, seq_dim=seq_axis, batch_dim=batch_axis, name=name) -# pylint: enable=redefined-builtin +# pylint: enable=redefined-builtin + reverse_sequence.__doc__ = deprecation.rewrite_argument_docstring( deprecation.rewrite_argument_docstring( gen_array_ops.reverse_sequence.__doc__, "batch_dim", "batch_axis"), @@ -2422,8 +2460,10 @@ def gather(params, indices, validate_indices=None, name=None, axis=0): # TODO(rjryan): Remove "Gather" creation in favor of GatherV2 once the forward # compatibility 3 week period has passed.
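The `gather` wrapper whose body follows dispatches on `axis`: the default `axis=0` keeps emitting the legacy Gather kernel during the compatibility window, while any other axis goes through GatherV2. A usage sketch of both paths:

```
import tensorflow as tf

params = tf.constant([[1, 2], [3, 4], [5, 6]])
# axis=0 (default) uses the legacy Gather kernel.
rows = tf.gather(params, [2, 0])       # => [[5, 6], [1, 2]]
# Any other axis dispatches to GatherV2.
cols = tf.gather(params, [1], axis=1)  # => [[2], [4], [6]]
```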
if axis == 0: - return gen_array_ops.gather(params, indices, - validate_indices=validate_indices, name=name) + return gen_array_ops.gather( + params, indices, validate_indices=validate_indices, name=name) else: return gen_array_ops.gather_v2(params, indices, axis, name=name) + + gather.__doc__ = gen_array_ops.gather_v2.__doc__ diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index e3a0a45f29..c31121b3eb 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import variable_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -101,16 +102,20 @@ class ResourceVariable(variables.Variable): if initial_value: raise ValueError("variable_def and initial_value are mutually " "exclusive.") + if not context.in_graph_mode(): + raise ValueError("Creating ResourceVariable from variable_def" + " only supported in GRAPH mode.") self._init_from_proto(variable_def, import_scope=import_scope) else: - self._init_from_args(initial_value=initial_value, - trainable=trainable, - collections=collections, - validate_shape=validate_shape, - caching_device=caching_device, - name=name, - dtype=dtype, - constraint=constraint) + self._init_from_args( + initial_value=initial_value, + trainable=trainable, + collections=collections, + validate_shape=validate_shape, + caching_device=caching_device, + name=name, + dtype=dtype, + constraint=constraint) # pylint: disable=unused-argument def _init_from_args(self, @@ -122,7 +127,6 @@ class ResourceVariable(variables.Variable): name=None, dtype=None, constraint=None): - """Creates a variable. Args: @@ -177,34 +181,48 @@ class ResourceVariable(variables.Variable): if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] self._save_slice_info = None + in_graph_mode = context.in_graph_mode() with ops.control_dependencies(None): - with ops.name_scope(name, "Variable", [] if init_from_fn else - [initial_value]) as name: + with ops.name_scope(name, "Variable", [] + if init_from_fn else [initial_value]) as name: # pylint: disable=protected-access - true_name = ops._name_from_scope_name(name) + handle_name = ops._name_from_scope_name(name) if init_from_fn: # Use attr_scope and device(None) to simulate the behavior of # colocate_with when the variable we want to colocate with doesn't # yet exist. 
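With the guard added above, `variable_def`-based construction stays graph-only while `_init_from_args` gains an eager path. A hedged sketch of the eager usage this enables, assuming this build exposes `context.eager_mode()` as a context manager (an assumption, not confirmed by this hunk):

```
from tensorflow.python.eager import context
from tensorflow.python.ops import resource_variable_ops

# Assumes context.eager_mode() switches the default context to eager
# execution; graph-only paths such as variable_def construction are
# rejected with the ValueError added above.
with context.eager_mode():
  v = resource_variable_ops.ResourceVariable([1.0, 2.0])
  print(v.read_value())
```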
- attr = attr_value_pb2.AttrValue( - list=attr_value_pb2.AttrValue.ListValue( - s=[compat.as_bytes("loc:@%s" % true_name)])) - with ops.get_default_graph()._attr_scope({"_class": attr}): - with ops.name_scope("Initializer"), ops.device(None): - self._initial_value = ops.convert_to_tensor( - initial_value(), name="initial_value", dtype=dtype) + if in_graph_mode: + attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + s=[compat.as_bytes("loc:@%s" % handle_name)])) + with ops.get_default_graph()._attr_scope({"_class": attr}): + with ops.name_scope("Initializer"), ops.device(None): + initial_value = ops.convert_to_tensor( + initial_value(), name="initial_value", dtype=dtype) + self._handle = gen_resource_variable_ops.var_handle_op( + shape=initial_value.get_shape(), + dtype=initial_value.dtype.base_dtype, + shared_name=handle_name, + name=name) + else: + initial_value = initial_value() self._handle = gen_resource_variable_ops.var_handle_op( - shape=self._initial_value.get_shape(), - dtype=self._initial_value.dtype.base_dtype, - shared_name=true_name, name=name) + shape=initial_value.get_shape(), + dtype=initial_value.dtype.base_dtype, + shared_name=handle_name, + name=name, + container="") # pylint: enable=protected-access # Or get the initial value from a Tensor or Python object. else: - self._initial_value = ops.convert_to_tensor( - initial_value, name="initial_value", dtype=dtype) + with ops.name_scope("Initializer"): + initial_value = ops.convert_to_tensor( + initial_value, name="initial_value", dtype=dtype) # pylint: disable=protected-access - if self._initial_value.op._get_control_flow_context() is not None: + if (in_graph_mode and + initial_value is not None and + initial_value.op._get_control_flow_context() is not None): raise ValueError( "Initializer for variable %s is from inside a control-flow " "construct, such as a loop or conditional. When creating a " @@ -212,45 +230,65 @@ class ResourceVariable(variables.Variable): "initializer." % name) # pylint: enable=protected-access self._handle = gen_resource_variable_ops.var_handle_op( - shape=self._initial_value.get_shape(), - dtype=self._initial_value.dtype.base_dtype, - shared_name=true_name, name=name) - - self._dtype = self._initial_value.dtype.base_dtype + shape=initial_value.get_shape(), + dtype=initial_value.dtype.base_dtype, + shared_name=handle_name, + name=name, + container="") + + self._initial_value = initial_value if in_graph_mode else None + self._handle_name = handle_name + ":0" + self._dtype = initial_value.dtype.base_dtype self._constraint = constraint - with ops.name_scope("IsInitialized"): - self._is_initialized_op = ( - gen_resource_variable_ops.var_is_initialized_op(self._handle)) - if initial_value is not None: - with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): - self._initializer_op = gen_resource_variable_ops.assign_variable_op( - self._handle, - self._build_initializer_expr(self._initial_value), - name=n) - with ops.name_scope("Read"), ops.colocate_with(self._handle): - # Manually assign reads to the handle's device to avoid log messages. - with ops.device(self._handle.device): - value = gen_resource_variable_ops.read_variable_op( - self._handle, dtype=self._dtype) - self._graph_element = value - if caching_device is not None: - # Variables may be created in a tf.device() or ops.colocate_with() - # context. 
At the same time, users would expect caching device to be - # independent of this context, and/or would not expect the current - # device context to be merged with the caching device spec. - # Therefore we reset the colocation stack before creating the cached - # value. Note that resetting the colocation stack will also reset - # the device stack. - with ops.colocate_with(None, ignore_existing=True): - with ops.device(caching_device): - self._cached_value = array_ops.identity(value) + if in_graph_mode: + with ops.name_scope("IsInitialized"): + self._is_initialized_op = ( + gen_resource_variable_ops.var_is_initialized_op(self._handle)) + if initial_value is not None: + with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): + self._initializer_op = ( + gen_resource_variable_ops.assign_variable_op( + self._handle, + self._build_initializer_expr(initial_value), + name=n)) + with ops.name_scope("Read"), ops.colocate_with(self._handle): + # Manually assign reads to the handle's device to avoid log + # messages. + with ops.device(self._handle.device): + value = gen_resource_variable_ops.read_variable_op( + self._handle, dtype=self._dtype) + self._graph_element = value + if caching_device is not None: + # Variables may be created in a tf.device() or ops.colocate_with() + # context. At the same time, users would expect caching device to + # be independent of this context, and/or would not expect the + # current device context to be merged with the caching device + # spec. Therefore we reset the colocation stack before creating + # the cached value. Note that resetting the colocation stack will + # also reset the device stack. + with ops.colocate_with(None, ignore_existing=True): + with ops.device(caching_device): + self._cached_value = array_ops.identity(value) + else: + self._cached_value = None + else: + gen_resource_variable_ops.assign_variable_op(self._handle, + initial_value) + self._is_initialized_op = None + self._initializer_op = None + self._graph_element = None + if caching_device: + with ops.device(caching_device): + self._cached_value = gen_resource_variable_ops.read_variable_op( + self._handle, dtype=self._dtype) else: self._cached_value = None - ops.add_to_collections(collections, self) + ops.add_to_collections(collections, self) def _init_from_proto(self, variable_def, import_scope=None): """Initializes from `VariableDef` proto.""" + assert context.in_graph_mode() assert isinstance(variable_def, variable_pb2.VariableDef) if not variable_def.is_resource: raise ValueError("Trying to restore Variable as ResourceVariable.") @@ -258,15 +296,16 @@ class ResourceVariable(variables.Variable): # Create from variable_def. 
g = ops.get_default_graph() self._handle = g.as_graph_element( - ops.prepend_name_scope(variable_def.variable_name, - import_scope=import_scope)) + ops.prepend_name_scope( + variable_def.variable_name, import_scope=import_scope)) + self._handle_name = self._handle.name self._initializer_op = g.as_graph_element( - ops.prepend_name_scope(variable_def.initializer_name, - import_scope=import_scope)) + ops.prepend_name_scope( + variable_def.initializer_name, import_scope=import_scope)) if variable_def.snapshot_name: self._cached_value = g.as_graph_element( - ops.prepend_name_scope(variable_def.snapshot_name, - import_scope=import_scope)) + ops.prepend_name_scope( + variable_def.snapshot_name, import_scope=import_scope)) else: self._cached_value = None if variable_def.HasField("save_slice_info_def"): @@ -297,16 +336,22 @@ class ResourceVariable(variables.Variable): @property def name(self): """The name of the handle for this variable.""" - return self._handle.name + return self._handle_name @property def shape(self): """The shape of this variable.""" - return tensor_shape.TensorShape(self._handle.op.get_attr("shape")) + if context.in_graph_mode(): + return tensor_shape.TensorShape(self._handle.op.get_attr("shape")) + else: + # TODO(agarwal): avoid a call to value here. + return self.value().shape @property def create(self): """The op responsible for initializing this variable.""" + if not context.in_graph_mode(): + raise RuntimeError("Calling create in EAGER mode not supported.") return self._initializer_op @property @@ -335,6 +380,8 @@ class ResourceVariable(variables.Variable): @property def initial_value(self): """Returns the Tensor used as the initial value for the variable.""" + if context.in_eager_mode(): + raise RuntimeError("initial_value not supported in EAGER mode.") return self._initial_value @property @@ -354,6 +401,8 @@ class ResourceVariable(variables.Variable): def eval(self, session=None): """Evaluates and returns the value of this variable.""" + if context.in_eager_mode(): + raise RuntimeError("Trying to eval in EAGER mode") return self._graph_element.eval(session=session) def _set_save_slice_info(self, save_slice_info): @@ -397,32 +446,38 @@ class ResourceVariable(variables.Variable): Args: export_scope: Optional `string`. Name scope to remove. + Raises: + RuntimeError: If run in EAGER mode. + Returns: A `VariableDef` protocol buffer, or `None` if the `Variable` is not in the specified name scope.
""" - if (export_scope is None or - self.handle.name.startswith(export_scope)): + if context.in_eager_mode(): + raise RuntimeError("to_proto not supported in EAGER mode.") + if export_scope is None or self.handle.name.startswith(export_scope): var_def = variable_pb2.VariableDef() - var_def.variable_name = ops.strip_name_scope( - self.handle.name, export_scope) - var_def.initializer_name = ops.strip_name_scope( - self.initializer.name, export_scope) + var_def.variable_name = ops.strip_name_scope(self.handle.name, + export_scope) + var_def.initializer_name = ops.strip_name_scope(self.initializer.name, + export_scope) if self._cached_value is not None: - var_def.snapshot_name = ops.strip_name_scope( - self._cached_value.name, export_scope) + var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name, + export_scope) var_def.is_resource = True if self._save_slice_info: - var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto( - export_scope=export_scope)) + var_def.save_slice_info_def.MergeFrom( + self._save_slice_info.to_proto(export_scope=export_scope)) return var_def else: return None @staticmethod def from_proto(variable_def, import_scope=None): - return ResourceVariable(variable_def=variable_def, - import_scope=import_scope) + if context.in_eager_mode(): + raise RuntimeError("from_proto not supported in EAGER mode.") + return ResourceVariable( + variable_def=variable_def, import_scope=import_scope) @staticmethod def _OverloadAllOperators(): # pylint: disable=invalid-name @@ -454,6 +509,7 @@ class ResourceVariable(variables.Variable): def _run_op(a, *args): # pylint: disable=protected-access return getattr(ops.Tensor, operator)(a._AsTensor(), *args) + # Propagate __doc__ to wrapper try: _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__ @@ -468,61 +524,66 @@ class ResourceVariable(variables.Variable): # TODO(apassos): this here and below is not atomic. Consider making it # atomic if there's a way to do so without a performance cost for those who # don't need it. 
- with ops.control_dependencies( - [gen_resource_variable_ops.assign_sub_variable_op( + with ops.control_dependencies([ + gen_resource_variable_ops.assign_sub_variable_op( self.handle, - ops.convert_to_tensor(delta, dtype=self.dtype), name=name)]): + ops.convert_to_tensor(delta, dtype=self.dtype), + name=name) + ]): return self.read_value() def assign_add(self, delta, use_locking=None, name=None): - with ops.control_dependencies( - [gen_resource_variable_ops.assign_add_variable_op( + with ops.control_dependencies([ + gen_resource_variable_ops.assign_add_variable_op( self.handle, - ops.convert_to_tensor(delta, dtype=self.dtype), name=name)]): + ops.convert_to_tensor(delta, dtype=self.dtype), + name=name) + ]): return self.read_value() def assign(self, value, use_locking=None, name=None): - with ops.control_dependencies( - [gen_resource_variable_ops.assign_variable_op( + with ops.control_dependencies([ + gen_resource_variable_ops.assign_variable_op( self.handle, - ops.convert_to_tensor(value, dtype=self.dtype), name=name)]): + ops.convert_to_tensor(value, dtype=self.dtype), + name=name) + ]): return self.read_value() - def _strided_slice_assign(self, - begin, - end, - strides, - value, - name, - begin_mask, - end_mask, - ellipsis_mask, - new_axis_mask, + def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask, + end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask): - with ops.control_dependencies([gen_array_ops.resource_strided_slice_assign( - ref=self.handle, - begin=begin, - end=end, - strides=strides, - value=value, - name=name, - begin_mask=begin_mask, - end_mask=end_mask, - ellipsis_mask=ellipsis_mask, - new_axis_mask=new_axis_mask, - shrink_axis_mask=shrink_axis_mask)]): + with ops.control_dependencies([ + gen_array_ops.resource_strided_slice_assign( + ref=self.handle, + begin=begin, + end=end, + strides=strides, + value=value, + name=name, + begin_mask=begin_mask, + end_mask=end_mask, + ellipsis_mask=ellipsis_mask, + new_axis_mask=new_axis_mask, + shrink_axis_mask=shrink_axis_mask) + ]): + return self.value() + + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): + del name + if dtype is not None and dtype != self.value().dtype: + print("trying to switch the dtype to ", dtype, " from ", + self.value().dtype) + return NotImplemented + if as_ref: + return self.read_value().op.inputs[0] + else: return self.value() -# pylint: disable=unused-argument,protected-access def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False): - if dtype is not None and dtype != var.value().dtype: - print("trying to switch the dtype to ", dtype, " from ", var.value().dtype) - return NotImplemented - if as_ref: - return var.read_value().op.inputs[0] - return var.value() -# pylint: enable=unused-argument,protected-access + return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access + # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. @@ -531,8 +592,7 @@ def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False): # otherwise lead to the wrong behavior. 
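The registration just below is what lets a `ResourceVariable` be passed anywhere a `Tensor` is accepted: `convert_to_tensor` routes through `_dense_var_to_tensor`, which reads the current value (or, for `as_ref`, recovers the underlying handle input). A hedged sketch of the resulting behavior:

```
import tensorflow as tf

w = tf.get_variable("w", initializer=[1.0, 2.0], use_resource=True)
# tf.add converts `w` via the registered conversion function, inserting
# a ReadVariableOp rather than consuming the raw resource handle.
y = tf.add(w, [10.0, 20.0])
```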
ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor) ops.register_tensor_conversion_function( - variables.Variable, - variables.Variable._TensorConversionFunction) # pylint: disable=protected-access + variables.Variable, variables.Variable._TensorConversionFunction) # pylint: disable=protected-access # pylint: disable=protected-access ResourceVariable._OverloadAllOperators() @@ -551,6 +611,9 @@ def _GatherGrad(op, grad): # Build appropriately shaped IndexedSlices # Walk graph back until the original handle is found. # TODO(apassos): more robust way of getting the shape. + # TODO(apassos): implement this for EAGER mode. + if context.in_eager_mode(): + raise NotImplementedError("_GatherGrad not implemented for EAGER mode") handle = op.inputs[0] while handle.op.type != "VarHandleOp": handle = handle.op.inputs[0] diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 7c12020263..a706193a40 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import variable_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -187,7 +188,11 @@ class Variable(object): ValueError: If both `variable_def` and initial_value are specified. ValueError: If the initial value is not specified, or does not have a shape and `validate_shape` is `True`. + RuntimeError: If created in EAGER mode. """ + if not context.in_graph_mode(): + raise RuntimeError("Variable not supported in Eager mode. " + "Please use ResourceVariable instead") if variable_def: # If variable_def is provided, recreates the variable from its fields. if initial_value: @@ -692,12 +697,15 @@ class Variable(object): Raises: ValueError: Session is not passed and no default session """ - session = session or ops.get_default_session() - if session is None: - raise ValueError( - "Either session argument should be provided or default session " - "should be established") - session.run(self._initializer_op, {self._initializer_op.inputs[1]: value}) + if context.in_graph_mode(): + session = session or ops.get_default_session() + if session is None: + raise ValueError( + "Either session argument should be provided or default session " + "should be established") + session.run(self._initializer_op, {self._initializer_op.inputs[1]: value}) + else: + self.assign(value) # Conversion to tensor. @staticmethod @@ -1057,7 +1065,10 @@ class PartitionedVariable(object): `partitions` is not a list. ValueError: If `variable_list` is empty, or the `Variable` shape information does not match `shape`, or `partitions` has invalid values. + RuntimeError: If created in EAGER mode. """ + if not context.in_graph_mode(): + raise RuntimeError("PartitionedVariable not supported in Eager mode.") if not isinstance(variable_list, (list, tuple)): raise TypeError( "variable_list is not a list or tuple: %s" % variable_list) diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index e93c3c9c38..1f91259e7c 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -25,6 +25,7 @@ limitations under the License. 
%rename("%s") TFE_Py_Execute; %rename("%s") TFE_ContextAddFunctionDef; %rename("%s") TFE_TensorHandleDim; +%rename("%s") TFE_TensorHandleDeviceName; %rename("%s") TFE_TensorHandleCopyToDevice; %rename("%s") TFE_NewOp; %rename("%s") TFE_Py_TensorHandleToNumpy; |