-rw-r--r--  tensorflow/python/BUILD                                          4
-rw-r--r--  tensorflow/python/eager/context.py                               3
-rw-r--r--  tensorflow/python/framework/constant_op.py                      65
-rw-r--r--  tensorflow/python/framework/ops.py                             116
-rw-r--r--  tensorflow/python/framework/tensor_util.py                       3
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                            20
-rw-r--r--  tensorflow/python/kernel_tests/array_ops_test.py                91
-rw-r--r--  tensorflow/python/kernel_tests/constant_op_eager_test.py       426
-rw-r--r--  tensorflow/python/kernel_tests/resource_variable_ops_test.py   159
-rw-r--r--  tensorflow/python/ops/array_ops.py                             218
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py                 299
-rw-r--r--  tensorflow/python/ops/variables.py                              23
-rw-r--r--  tensorflow/python/pywrap_tfe.i                                   1
13 files changed, 1110 insertions, 318 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 75c6baa2f8..6af0d2bbe0 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -483,6 +483,7 @@ py_library(
":framework_ops",
":tensor_shape",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python/eager:execute",
],
)
@@ -1723,6 +1724,8 @@ py_library(
":util",
":variables",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:tensor",
],
)
@@ -2203,6 +2206,7 @@ py_library(
":tensor_shape",
":util",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python/eager:context",
],
)
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index bfc9b47302..5094a1def7 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -155,7 +155,7 @@ class Context(object):
def device_name(self):
"""Returns the device name for the current thread."""
index = self._device_index
- return None if index < 0 else self._devices[index]
+ return "CPU:0" if index < 0 else self._devices[index]
def devices(self):
"""List of the names of devices available to execute operations."""
@@ -271,6 +271,7 @@ def eager_mode():
return get_default_context()._mode(EAGER_MODE) # pylint: disable=protected-access
+# TODO(agarwal): get rid of this and use ops.name_scope instead.
@contextlib.contextmanager
def namescope(name):
"""ContextManager for creating hierarchical name scopes."""
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index dd0fdaf712..af3be7230c 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Operations that generate constants.
See the @{$python/constant_op$constants guide}.
@@ -45,12 +44,35 @@ from __future__ import print_function
import numpy as np
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import context
+from tensorflow.python.eager import execute
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
+def _eager_reshape(tensor, shape):
+ """Eager-only version of Reshape op; requires tensor is an eager Tensor."""
+ attr_t = tensor.dtype.as_datatype_enum
+ attr_tshape, (shape,) = execute.args_to_matching_eager([shape], dtypes.int32)
+ attr_tshape = attr_tshape.as_datatype_enum
+ inputs_flat = [tensor, shape]
+ attrs = ("T", attr_t, "Tshape", attr_tshape)
+ result, = execute.execute("Reshape", 1, inputs=inputs_flat, attrs=attrs)
+ return result
+
+
+def _eager_fill(dims, value):
+ """Eager-only version of Fill op; requires value is an eager Tensor."""
+ attr_t = value.dtype.as_datatype_enum
+ dims = ops.convert_to_eager_tensor(dims, dtypes.int32)
+ inputs_flat = [dims, value]
+ attrs = ("T", attr_t)
+ result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs)
+ return result
+
+
def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
"""Creates a constant tensor.
@@ -95,15 +117,40 @@ def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
Returns:
A Constant Tensor.
+
+ Raises:
+ TypeError if shape is incorrectly specified or unsupported.
"""
+ if not context.in_graph_mode():
+ if shape is None:
+ return ops.convert_to_eager_tensor(value, dtype)
+ t = ops.convert_to_eager_tensor(value, dtype)
+ shape = tensor_shape.as_shape(shape)
+ if shape == t.shape:
+ return t
+ if verify_shape:
+ raise TypeError("Expected Tensor's shape: %s, got %s." % (tuple(shape),
+ tuple(t.shape)))
+ num_t = t.shape.num_elements()
+ # TODO(josh11b): Implement shape -> eager tensor conversion.
+ if num_t == shape.num_elements():
+ return _eager_reshape(t, shape.as_list())
+ if num_t == 1:
+ return _eager_fill(shape.as_list(), t)
+ raise TypeError("Eager execution of tf.constant with unsupported shape "
+ "(value has %d elements, shape is %s with %d elements)." %
+ (num_t, shape, shape.num_elements()))
g = ops.get_default_graph()
tensor_value = attr_value_pb2.AttrValue()
tensor_value.tensor.CopyFrom(
- tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))
+ tensor_util.make_tensor_proto(
+ value, dtype=dtype, shape=shape, verify_shape=verify_shape))
dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
const_tensor = g.create_op(
"Const", [], [dtype_value.type],
- attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0]
+ attrs={"value": tensor_value,
+ "dtype": dtype_value},
+ name=name).outputs[0]
return const_tensor
@@ -131,8 +178,11 @@ ops.register_tensor_conversion_function(
object, _constant_tensor_conversion_function, 200)
-def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None,
+def _tensor_shape_tensor_conversion_function(s,
+ dtype=None,
+ name=None,
as_ref=False):
+ """Function to convert TensorShape to Tensor."""
_ = as_ref
if not s.is_fully_defined():
raise ValueError(
@@ -156,12 +206,16 @@ def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None,
name = "shape_as_tensor"
return constant(s_list, dtype=dtype, name=name)
+
ops.register_tensor_conversion_function(
tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)
-def _dimension_tensor_conversion_function(d, dtype=None, name=None,
+def _dimension_tensor_conversion_function(d,
+ dtype=None,
+ name=None,
as_ref=False):
+ """Function to convert Dimension to Tensor."""
_ = as_ref
if d.value is None:
raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d)
@@ -174,5 +228,6 @@ def _dimension_tensor_conversion_function(d, dtype=None, name=None,
name = "shape_as_tensor"
return constant(d.value, dtype=dtype, name=name)
+
ops.register_tensor_conversion_function(
tensor_shape.Dimension, _dimension_tensor_conversion_function, 100)
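
Note: the constant_op.py change above gives tf.constant an eager path that never builds a graph node. When an explicit shape is passed, the converted value is either reshaped (element counts match, via _eager_reshape) or broadcast from a scalar (via _eager_fill); any other mismatch raises TypeError. The following is a minimal sketch, not part of the diff, of that user-visible behavior, written against the same internal context.eager_mode() manager the new tests use:

    # Sketch only: mirrors the eager shape handling added to constant() above.
    from tensorflow.python.eager import context
    from tensorflow.python.framework import constant_op

    with context.eager_mode():
      # No shape argument: the value is converted directly to an EagerTensor.
      a = constant_op.constant([1, 2, 3, 4, 5, 6])

      # Matching element count: routed through the eager Reshape path.
      b = constant_op.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
      print(b.numpy())  # [[1 2 3] [4 5 6]]

      # Scalar value with a larger shape: routed through the eager Fill path.
      c = constant_op.constant(7, shape=[2, 3])
      print(c.numpy())  # [[7 7 7] [7 7 7]]

      # Element-count mismatch raises TypeError, as documented above.
      try:
        constant_op.constant([1, 2, 3], shape=[2, 3])
      except TypeError as e:
        print(e)
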
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index e01f6f3ce4..5948e59824 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -80,6 +80,11 @@ def _in_gpu_device():
return context.get_default_context()._device_index > 0 # pylint: disable=protected-access
+@tf_contextlib.contextmanager
+def _null_contextmanager():
+ yield
+
+
def _override_helper(clazz_object, operator, func):
"""Overrides (string) operator on Tensors to call func.
@@ -698,8 +703,9 @@ class EagerTensor(Tensor):
return numpy_text
def __str__(self):
- return "tfe.Tensor(shape=%s, dtype=%s, numpy=%s)" % (
- self.shape, self.dtype.name, self._numpy_text())
+ return "tfe.Tensor(shape=%s, dtype=%s, numpy=%s)" % (self.shape,
+ self.dtype.name,
+ self._numpy_text())
def __repr__(self):
return "<tfe.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s)>" % (
@@ -728,10 +734,14 @@ class EagerTensor(Tensor):
with errors.raise_exception_on_not_ok_status() as status:
return c_api.TFE_Py_TensorHandleToNumpy(cpu._handle, status) # pylint: disable=protected-access
- def _copy(self, ctx, device_name):
+ def _copy(self, ctx=None, device_name=None):
"""Copies tensor to dest device."""
# pylint: disable=protected-access
# Creates a new tensor on the dest device.
+ if ctx is None:
+ ctx = context.get_default_context()
+ if device_name is None:
+ device_name = ctx.device_name
with errors.raise_exception_on_not_ok_status() as status:
h = c_api.TFE_TensorHandleCopyToDevice(self._handle, ctx._handle,
device_name, status)
@@ -863,15 +873,26 @@ class EagerTensor(Tensor):
raise NotImplementedError("eval not supported for Eager Tensors.")
+# TODO(josh11b): Support other cases like converting TensorShape, lists/tuples and
+# other custom conversion functions.
def convert_to_eager_tensor(t, dtype=None):
+ """Converts the given `value` to an `EagerTensor`."""
if isinstance(ag_core.getval(t), EagerTensor):
if dtype is not None and t.dtype != dtype:
raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
return t
+ # Handle converting ResourceVariable to Tensor.
+ # TODO(josh11b): get rid of this explicit ugly conversion once we have a more
+ # general scheme in place.
+ try:
+ return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access
+ except AttributeError:
+ pass
return EagerTensor(t, dtype=dtype)
def convert_n_to_eager_tensor(values, dtype):
+ """Converts the given `values` to a list of `EagerTensor`."""
return [convert_to_eager_tensor(t, dtype) for t in values]
@@ -1092,17 +1113,21 @@ def internal_convert_n_to_tensor(values,
"""
if not isinstance(values, collections.Sequence):
raise TypeError("values must be a list.")
- ret = []
- for i, value in enumerate(values):
- n = None if name is None else "%s_%d" % (name, i)
- ret.append(
- internal_convert_to_tensor(
- value,
- dtype=dtype,
- name=n,
- as_ref=as_ref,
- preferred_dtype=preferred_dtype))
- return ret
+ if context.in_graph_mode():
+ ret = []
+ for i, value in enumerate(values):
+ n = None if name is None else "%s_%d" % (name, i)
+ ret.append(
+ internal_convert_to_tensor(
+ value,
+ dtype=dtype,
+ name=n,
+ as_ref=as_ref,
+ preferred_dtype=preferred_dtype))
+ return ret
+ else:
+ # TODO(josh11b): handle preferred_dtype, as_ref
+ return convert_n_to_eager_tensor(values, dtype=dtype)
def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
@@ -2463,7 +2488,7 @@ def _name_from_scope_name(name):
Returns:
the name of the op (equal to scope name minus any trailing slash).
"""
- return name[:-1] if name[-1] == "/" else name
+ return name[:-1] if (name and name[-1] == "/") else name
class Graph(object):
@@ -4282,6 +4307,8 @@ class Graph(object):
return tensor_or_op not in self._unfetchable_ops
+# TODO(agarwal): currently device directives in an outer eager scope will not
+# apply to inner graph mode code. Fix that.
def device(device_name_or_function):
"""Wrapper for `Graph.device()` using the default graph.
@@ -4297,7 +4324,11 @@ def device(device_name_or_function):
A context manager that specifies the default device to use for newly
created ops.
"""
- return get_default_graph().device(device_name_or_function)
+ if context.in_graph_mode():
+ return get_default_graph().device(device_name_or_function)
+ else:
+ # TODO(agarwal): support device functions in EAGER mode.
+ return context.device(device_name_or_function)
def container(container_name):
@@ -4314,7 +4345,15 @@ def container(container_name):
def colocate_with(op, ignore_existing=False):
- return get_default_graph().colocate_with(op, ignore_existing)
+ if context.in_graph_mode():
+ return get_default_graph().colocate_with(op, ignore_existing)
+ else:
+ if not ignore_existing:
+ raise ValueError("ignore_existing must currently be True in EAGER mode.")
+ if op:
+ return device(op.device)
+ else:
+ return _null_contextmanager()
def control_dependencies(control_inputs):
@@ -4333,7 +4372,10 @@ def control_dependencies(control_inputs):
A context manager that specifies control dependencies for all
operations constructed within the context.
"""
- return get_default_graph().control_dependencies(control_inputs)
+ if context.in_graph_mode():
+ return get_default_graph().control_dependencies(control_inputs)
+ else:
+ return _null_contextmanager()
_default_session_stack = context._DefaultStack() # pylint: disable=protected-access
@@ -4871,18 +4913,32 @@ def name_scope(name, default_name=None, values=None):
but `values` are.
"""
n = default_name if name is None else name
- if n is None and values is not None:
- # We only raise an error if values is not None (provided) because currently
- # tf.name_scope(None) (values=None then) is sometimes used as an idiom
- # to reset to top scope.
- raise ValueError(
- "At least one of name (%s) and default_name (%s) must be provided." %
- (name, default_name))
- if values is None:
- values = []
- g = _get_graph_from_inputs(values)
- with g.as_default(), g.name_scope(n) as scope:
- yield scope
+ if context.in_eager_mode():
+ if n is None:
+ raise ValueError(
+ "At least one of name (%s) and default_name (%s) should be provided" %
+ (name, default_name))
+ ctx = context.get_default_context()
+ old_name = ctx.scope_name
+ scope_name = "%s/%s" % (old_name, name) if old_name else name
+ ctx.scope_name = scope_name
+ try:
+ yield scope_name
+ finally:
+ ctx.scope_name = old_name
+ else:
+ if n is None and values is not None:
+ # We only raise an error if values is not None (provided) because
+ # currently tf.name_scope(None) (values=None then) is sometimes used as an
+ # idiom to reset to top scope.
+ raise ValueError(
+ "At least one of name (%s) and default_name (%s) must be provided." %
+ (name, default_name))
+ if values is None:
+ values = []
+ g = _get_graph_from_inputs(values)
+ with g.as_default(), g.name_scope(n) as scope:
+ yield scope
# pylint: enable=g-doc-return-or-yield
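
Note: the ops.py hunks above make the usual graph-construction context managers meaningful in eager mode: device() forwards to context.device(), control_dependencies() becomes a null context manager, colocate_with() degenerates to a device scope (and requires ignore_existing=True), and name_scope() tracks a hierarchical scope string on the eager context instead of on a graph. A small sketch of that behavior, assuming the eager context starts with an empty scope name:

    # Sketch only: exercises the eager-mode branches added to ops.py above.
    from tensorflow.python.eager import context
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import ops

    with context.eager_mode():
      # device() dispatches to context.device() instead of the default graph.
      with ops.device("cpu:0"):
        x = constant_op.constant(1.0)

      # control_dependencies() is a no-op (null context manager) in eager mode.
      with ops.control_dependencies([x]):
        y = constant_op.constant(2.0)

      # name_scope() maintains a hierarchical scope name on the eager context.
      with ops.name_scope("outer"):
        with ops.name_scope("inner") as scope:
          print(scope)  # outer/inner
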
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 8c9406593c..4d47b2bc56 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -23,6 +23,7 @@ import six
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat
@@ -711,6 +712,8 @@ def constant_value(tensor, partial=False): # pylint: disable=invalid-name
Raises:
TypeError: if tensor is not an ops.Tensor.
"""
+ if isinstance(tensor, ops.EagerTensor):
+ return tensor.numpy()
ret = _ConstantValue(tensor, partial)
if ret is not None:
# The caller may now depend on the constant value of `tensor`, so we
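
Note: with the tensor_util.py change above, constant_value() short-circuits for eager tensors and simply returns their concrete numpy value instead of attempting graph-level constant folding. A tiny sketch under the same assumptions as the previous examples:

    # Sketch only: constant_value() on an EagerTensor now returns .numpy().
    from tensorflow.python.eager import context
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import tensor_util

    with context.eager_mode():
      t = constant_op.constant([[1, 2], [3, 4]])
      value = tensor_util.constant_value(t)  # a numpy array, equal to t.numpy()
      print(value)
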
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index f2be3f134d..22306e08d7 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1049,6 +1049,8 @@ cuda_py_test(
"//tensorflow/python:state_ops",
"//tensorflow/python:test_ops",
"//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:tensor",
],
shard_count = 10,
)
@@ -1151,6 +1153,24 @@ cuda_py_test(
)
cuda_py_test(
+ name = "constant_op_eager_test",
+ size = "small",
+ srcs = ["constant_op_eager_test.py"],
+ additional_deps = [
+ "//tensorflow/python/eager:core",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ ],
+)
+
+cuda_py_test(
name = "control_flow_ops_py_test",
size = "small",
srcs = ["control_flow_ops_py_test.py"],
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 7b8cd25664..7dc28b6aa9 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -22,10 +22,12 @@ import time
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_ops
@@ -168,9 +170,10 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
arr = np.array([[1, 2], [3, 4]])
mask = np.array([False, True])
- masked_tensor = sess.run(array_ops.boolean_mask(ph_tensor, ph_mask),
- feed_dict={ph_tensor: arr,
- ph_mask: mask})
+ masked_tensor = sess.run(
+ array_ops.boolean_mask(ph_tensor, ph_mask),
+ feed_dict={ph_tensor: arr,
+ ph_mask: mask})
np.testing.assert_allclose(masked_tensor, arr[mask])
def testMaskDimensionsSetToNoneRaises(self):
@@ -283,14 +286,16 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
def testReverse1DimAuto(self):
for dtype in [
np.uint8, np.int8, np.int32, np.int64, np.bool, np.float16, np.float32,
- np.float64, np.complex64, np.complex128, np.array(b"").dtype.type
+ np.float64, np.complex64, np.complex128,
+ np.array(b"").dtype.type
]:
self._reverse1DimAuto(dtype)
def testReverse2DimAuto(self):
for dtype in [
np.uint8, np.int8, np.int32, np.int64, np.bool, np.float16, np.float32,
- np.float64, np.complex64, np.complex128, np.array(b"").dtype.type
+ np.float64, np.complex64, np.complex128,
+ np.array(b"").dtype.type
]:
self._reverse2DimAuto(dtype)
@@ -318,8 +323,7 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
for outer_size in (1, 2):
for middle_size in list(range(50)) + [100000]:
x_np = np.reshape(
- np.arange(
- outer_size * middle_size * 3, dtype=np.float32),
+ np.arange(outer_size * middle_size * 3, dtype=np.float32),
newshape=(outer_size, middle_size, 3))
x_tf = reverse_f(x_np, [1]).eval()
np_answer = x_np[:, ::-1, :]
@@ -331,8 +335,7 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
for outer_size in (1, 2):
for middle_size in list(range(50)) + [100000]:
x_np = np.reshape(
- np.arange(
- outer_size * middle_size * 4, dtype=np.float32),
+ np.arange(outer_size * middle_size * 4, dtype=np.float32),
newshape=(outer_size, middle_size, 4))
x_tf = reverse_f(x_np, [1]).eval()
np_answer = x_np[:, ::-1, :]
@@ -344,8 +347,7 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
for outer_size in list(range(50)) + [100000]:
for middle_size in (1, 2):
x_np = np.reshape(
- np.arange(
- outer_size * middle_size * 3, dtype=np.float32),
+ np.arange(outer_size * middle_size * 3, dtype=np.float32),
newshape=(outer_size, middle_size, 3))
x_tf = reverse_f(x_np, [0]).eval()
np_answer = x_np[::-1, :, :]
@@ -437,9 +439,10 @@ class StridedSliceChecker(object):
return tensor
-STRIDED_SLICE_TYPES = [dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8,
- dtypes.float32, dtypes.float64, dtypes.complex64,
- dtypes.complex128]
+STRIDED_SLICE_TYPES = [
+ dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8, dtypes.float32,
+ dtypes.float64, dtypes.complex64, dtypes.complex128
+]
class StridedSliceTest(test_util.TensorFlowTestCase):
@@ -600,8 +603,7 @@ class StridedSliceShapeTest(test_util.TensorFlowTestCase):
a = StridedSliceShapeChecker(uncertain_tensor)
self.tensorShapeEqual(a[3:5], tensor_shape.TensorShape([2, None, 7]))
self.tensorShapeEqual(a[3:5, :, 4], tensor_shape.TensorShape([2, None]))
- self.tensorShapeEqual(a[3:5, 3:4, 4],
- tensor_shape.TensorShape([2, None]))
+ self.tensorShapeEqual(a[3:5, 3:4, 4], tensor_shape.TensorShape([2, None]))
self.tensorShapeEqual(a[3:5, :, 5:10],
tensor_shape.TensorShape([2, None, 2]))
self.tensorShapeEqual(a[3:5, :, 50:3],
@@ -651,15 +653,14 @@ class GradSliceChecker(object):
analytic_grad2 = 2 * slice_val
dy = variables.Variable(
- array_ops.ones(
- shape=slice_var.get_shape(), dtype=dtypes.int32))
+ array_ops.ones(shape=slice_var.get_shape(), dtype=dtypes.int32))
assign = dy.assign(slice_var)
slice_val_grad, = gradients_impl.gradients(slice_val, self.var, grad_ys=dy)
slice_val_grad2, = gradients_impl.gradients(
slice_val_grad, dy, grad_ys=self.var)
self.sess.run(assign)
- slice_val_grad_evaled, slice_val_grad2_evaled = (
- self.sess.run([slice_val_grad, slice_val_grad2]))
+ slice_val_grad_evaled, slice_val_grad2_evaled = (self.sess.run(
+ [slice_val_grad, slice_val_grad2]))
analytic_grad2_evaled = analytic_grad2.eval()
self.test.assertAllEqual(slice_val_grad2_evaled, analytic_grad2_evaled)
@@ -677,8 +678,7 @@ class StridedSliceGradTest(test_util.TensorFlowTestCase):
def testGradient(self):
with self.test_session(use_gpu=True) as sess:
var = variables.Variable(
- array_ops.reshape(
- math_ops.range(1, 97, 1), shape=(6, 4, 4)))
+ array_ops.reshape(math_ops.range(1, 97, 1), shape=(6, 4, 4)))
init = variables.global_variables_initializer()
sess.run(init)
@@ -853,9 +853,10 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
def doTestSliceAssign(self, use_resource):
for dtype in STRIDED_SLICE_TYPES:
- checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]],
- use_resource=use_resource,
- tensor_type=dtype)
+ checker = StridedSliceAssignChecker(
+ self, [[1, 2, 3], [4, 5, 6]],
+ use_resource=use_resource,
+ tensor_type=dtype)
# Check if equal
checker[:] = [[10, 20, 30], [40, 50, 60]]
# Check trivial (1,1) shape tensor
@@ -943,18 +944,16 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
res = array_ops.sequence_mask(
constant_op.constant([0, 1, 4]), dtype=dtypes.float32)
self.assertAllEqual(res.get_shape().as_list(), [3, None])
- self.assertAllEqual(res.eval(), [[0.0, 0.0, 0.0, 0.0],
- [1.0, 0.0, 0.0, 0.0],
+ self.assertAllEqual(res.eval(), [[0.0, 0.0, 0.0,
+ 0.0], [1.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0]])
def testDtypes(self):
def check_dtypes(lengths_dtype, maxlen_dtype):
res = array_ops.sequence_mask(
- constant_op.constant(
- [1, 3, 2], dtype=lengths_dtype),
- constant_op.constant(
- 5, dtype=maxlen_dtype))
+ constant_op.constant([1, 3, 2], dtype=lengths_dtype),
+ constant_op.constant(5, dtype=maxlen_dtype))
self.assertAllEqual(res.get_shape(), [3, 5])
self.assertAllEqual(res.eval(), [[True, False, False, False, False],
[True, True, True, False, False],
@@ -979,5 +978,35 @@ class ConcatSliceResourceTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.AlreadyExistsError):
test_ops.resource_create_op(r2).run()
+
+class IdentityTest(test_util.TensorFlowTestCase):
+
+ def testEagerIdentity(self):
+ with context.eager_mode():
+ ctx = context.get_default_context()
+ if not ctx.num_gpus():
+ self.skipTest("No GPUs found")
+
+ def _test(x, y, device):
+ self.assertIsNot(x, y)
+ self.assertAllEqual(x.numpy(), y.numpy())
+ self.assertTrue(device in y.device.lower())
+
+ with ops.device("gpu:0"):
+ a = constant_op.constant([[2], [3]], dtype=dtypes.float32)
+ with ops.device("gpu:0"):
+ b = array_ops.identity(a)
+ _test(a, b, "gpu")
+ with ops.device("cpu:0"):
+ c = array_ops.identity(b)
+ _test(b, c, "cpu")
+ with ops.device("cpu:0"):
+ d = array_ops.identity(c)
+ _test(c, d, "cpu")
+ with ops.device("gpu:0"):
+ e = array_ops.identity(d)
+ _test(d, e, "gpu")
+
+
if __name__ == "__main__":
test_lib.main()
diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py
new file mode 100644
index 0000000000..0e98afbe6e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py
@@ -0,0 +1,426 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ConstantOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+
+
+# TODO(josh11b): add tests with string types, lists/tuples, Shape.
+class ConstantTest(test.TestCase):
+
+ def _testCpu(self, x):
+ np_ans = np.array(x)
+ tf_ans = ops.convert_to_tensor(x).numpy()
+ if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]:
+ self.assertAllClose(np_ans, tf_ans)
+ else:
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def _testGpu(self, x):
+ np_ans = np.array(x)
+ tf_ans = ops.convert_to_tensor(x).numpy()
+ if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]:
+ self.assertAllClose(np_ans, tf_ans)
+ else:
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def _testAll(self, x):
+ self._testCpu(x)
+ self._testGpu(x)
+
+ def testFloat(self):
+ self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32))
+ self._testAll(
+ np.random.normal(size=30).reshape([2, 3, 5]).astype(np.float32))
+ self._testAll(np.empty((2, 0, 5)).astype(np.float32))
+
+ def testDouble(self):
+ self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float64))
+ self._testAll(
+ np.random.normal(size=30).reshape([2, 3, 5]).astype(np.float64))
+ self._testAll(np.empty((2, 0, 5)).astype(np.float64))
+
+ def testInt32(self):
+ self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.int32))
+ self._testAll(
+ (100 * np.random.normal(size=30)).reshape([2, 3, 5]).astype(np.int32))
+ self._testAll(np.empty((2, 0, 5)).astype(np.int32))
+
+ def testInt64(self):
+ self._testAll(np.arange(-15, 15).reshape([2, 3, 5]).astype(np.int64))
+ self._testAll(
+ (100 * np.random.normal(size=30)).reshape([2, 3, 5]).astype(np.int64))
+ self._testAll(np.empty((2, 0, 5)).astype(np.int64))
+
+ def testComplex64(self):
+ self._testAll(
+ np.complex(1, 2) * np.arange(-15, 15).reshape([2, 3, 5
+ ]).astype(np.complex64))
+ self._testAll(
+ np.complex(1, 2) * np.random.normal(size=30).reshape(
+ [2, 3, 5]).astype(np.complex64))
+ self._testAll(np.empty((2, 0, 5)).astype(np.complex64))
+
+ def testComplex128(self):
+ self._testAll(
+ np.complex(1, 2) * np.arange(-15, 15).reshape([2, 3, 5
+ ]).astype(np.complex128))
+ self._testAll(
+ np.complex(1, 2) * np.random.normal(size=30).reshape(
+ [2, 3, 5]).astype(np.complex128))
+ self._testAll(np.empty((2, 0, 5)).astype(np.complex128))
+
+ def testExplicitShapeNumPy(self):
+ c = constant_op.constant(
+ np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32),
+ shape=[2, 3, 5])
+ self.assertEqual(c.get_shape(), [2, 3, 5])
+
+ def testImplicitShapeNumPy(self):
+ c = constant_op.constant(
+ np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32))
+ self.assertEqual(c.get_shape(), [2, 3, 5])
+
+ def testExplicitShapeList(self):
+ c = constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[7])
+ self.assertEqual(c.get_shape(), [7])
+
+ def testExplicitShapeFill(self):
+ c = constant_op.constant(12, shape=[7])
+ self.assertEqual(c.get_shape(), [7])
+ self.assertAllEqual([12, 12, 12, 12, 12, 12, 12], c.numpy())
+
+ def testExplicitShapeReshape(self):
+ c = constant_op.constant(
+ np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32),
+ shape=[5, 2, 3])
+ self.assertEqual(c.get_shape(), [5, 2, 3])
+
+ def testImplicitShapeList(self):
+ c = constant_op.constant([1, 2, 3, 4, 5, 6, 7])
+ self.assertEqual(c.get_shape(), [7])
+
+ def testExplicitShapeNumber(self):
+ c = constant_op.constant(1, shape=[1])
+ self.assertEqual(c.get_shape(), [1])
+
+ def testImplicitShapeNumber(self):
+ c = constant_op.constant(1)
+ self.assertEqual(c.get_shape(), [])
+
+ def testShapeTooBig(self):
+ with self.assertRaises(TypeError):
+ constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[10])
+
+ def testShapeTooSmall(self):
+ with self.assertRaises(TypeError):
+ constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[5])
+
+ def testShapeWrong(self):
+ with self.assertRaisesRegexp(TypeError, None):
+ constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[5])
+
+ def testSparseValuesRaiseErrors(self):
+ with self.assertRaisesRegexp(ValueError,
+ "setting an array element with a sequence"):
+ constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32)
+
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, None):
+ constant_op.constant([[1, 2], [3]])
+
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, None):
+ constant_op.constant([[1, 2], [3], [4, 5]])
+
+
+class AsTensorTest(test.TestCase):
+
+ def testAsTensorForTensorInput(self):
+ t = constant_op.constant(10.0)
+ x = ops.convert_to_tensor(t)
+ self.assertIs(t, x)
+
+ def testAsTensorForNonTensorInput(self):
+ x = ops.convert_to_tensor(10.0)
+ self.assertTrue(isinstance(x, ops.EagerTensor))
+
+
+class ZerosTest(test.TestCase):
+
+ def _Zeros(self, shape):
+ ret = array_ops.zeros(shape)
+ self.assertEqual(shape, ret.get_shape())
+ return ret.numpy()
+
+ def testConst(self):
+ self.assertTrue(
+ np.array_equal(self._Zeros([2, 3]), np.array([[0] * 3] * 2)))
+
+ def testScalar(self):
+ self.assertEqual(0, self._Zeros([]))
+ self.assertEqual(0, self._Zeros(()))
+ scalar = array_ops.zeros(constant_op.constant([], dtype=dtypes_lib.int32))
+ self.assertEqual(0, scalar.numpy())
+
+ def testDynamicSizes(self):
+ np_ans = np.array([[0] * 3] * 2)
+ # Creates a tensor of 2 x 3.
+ d = array_ops.fill([2, 3], 12., name="fill")
+ # Constructs a tensor of zeros of the same dimensions as "d".
+ z = array_ops.zeros(array_ops.shape(d))
+ out = z.numpy()
+ self.assertAllEqual(np_ans, out)
+ self.assertShapeEqual(np_ans, d)
+ self.assertShapeEqual(np_ans, z)
+
+ def testDtype(self):
+ d = array_ops.fill([2, 3], 12., name="fill")
+ self.assertEqual(d.get_shape(), [2, 3])
+ # Test default type for both constant size and dynamic size
+ z = array_ops.zeros([2, 3])
+ self.assertEqual(z.dtype, dtypes_lib.float32)
+ self.assertEqual([2, 3], z.get_shape())
+ self.assertAllEqual(z.numpy(), np.zeros([2, 3]))
+ z = array_ops.zeros(array_ops.shape(d))
+ self.assertEqual(z.dtype, dtypes_lib.float32)
+ self.assertEqual([2, 3], z.get_shape())
+ self.assertAllEqual(z.numpy(), np.zeros([2, 3]))
+ # Test explicit type control
+ for dtype in [
+ dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
+ dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8,
+ dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64,
+ dtypes_lib.bool,
+ # TODO(josh11b): Support string type here.
+ # dtypes_lib.string
+ ]:
+ z = array_ops.zeros([2, 3], dtype=dtype)
+ self.assertEqual(z.dtype, dtype)
+ self.assertEqual([2, 3], z.get_shape())
+ z_value = z.numpy()
+ self.assertFalse(np.any(z_value))
+ self.assertEqual((2, 3), z_value.shape)
+ z = array_ops.zeros(array_ops.shape(d), dtype=dtype)
+ self.assertEqual(z.dtype, dtype)
+ self.assertEqual([2, 3], z.get_shape())
+ z_value = z.numpy()
+ self.assertFalse(np.any(z_value))
+ self.assertEqual((2, 3), z_value.shape)
+
+
+class ZerosLikeTest(test.TestCase):
+
+ def _compareZeros(self, dtype, use_gpu):
+ # Creates a tensor of non-zero values with shape 2 x 3.
+ # NOTE(kearnes): The default numpy dtype associated with tf.string is
+ # np.object (and can't be changed without breaking a lot things), which
+ # causes a TypeError in constant_op.constant below. Here we catch the
+ # special case of tf.string and set the numpy dtype appropriately.
+ if dtype == dtypes_lib.string:
+ numpy_dtype = np.string_
+ else:
+ numpy_dtype = dtype.as_numpy_dtype
+ d = constant_op.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype)
+ # Constructs a tensor of zeros of the same dimensions and type as "d".
+ z_var = array_ops.zeros_like(d)
+ # Test that the type is correct
+ self.assertEqual(z_var.dtype, dtype)
+ # Test that the shape is correct
+ self.assertEqual([2, 3], z_var.get_shape())
+
+ # Test that the value is correct
+ z_value = z_var.numpy()
+ self.assertFalse(np.any(z_value))
+ self.assertEqual((2, 3), z_value.shape)
+
+ def testZerosLikeCPU(self):
+ for dtype in [
+ dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
+ dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8,
+ dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64,
+ # TODO(josh11b): Support string type here.
+ # dtypes_lib.string
+ ]:
+ self._compareZeros(dtype, use_gpu=False)
+
+ def testZerosLikeGPU(self):
+ for dtype in [
+ dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
+ dtypes_lib.bool, dtypes_lib.int64,
+ # TODO(josh11b): Support string type here.
+ # dtypes_lib.string
+ ]:
+ self._compareZeros(dtype, use_gpu=True)
+
+ def testZerosLikeDtype(self):
+ # Make sure zeros_like works even for dtypes that cannot be cast between
+ shape = (3, 5)
+ dtypes = np.float32, np.complex64
+ for in_type in dtypes:
+ x = np.arange(15).astype(in_type).reshape(*shape)
+ for out_type in dtypes:
+ y = array_ops.zeros_like(x, dtype=out_type).numpy()
+ self.assertEqual(y.dtype, out_type)
+ self.assertEqual(y.shape, shape)
+ self.assertAllEqual(y, np.zeros(shape, dtype=out_type))
+
+
+class OnesTest(test.TestCase):
+
+ def _Ones(self, shape):
+ ret = array_ops.ones(shape)
+ self.assertEqual(shape, ret.get_shape())
+ return ret.numpy()
+
+ def testConst(self):
+ self.assertTrue(np.array_equal(self._Ones([2, 3]), np.array([[1] * 3] * 2)))
+
+ def testScalar(self):
+ self.assertEqual(1, self._Ones([]))
+ self.assertEqual(1, self._Ones(()))
+ scalar = array_ops.ones(constant_op.constant([], dtype=dtypes_lib.int32))
+ self.assertEqual(1, scalar.numpy())
+
+ def testDynamicSizes(self):
+ np_ans = np.array([[1] * 3] * 2)
+ # Creates a tensor of 2 x 3.
+ d = array_ops.fill([2, 3], 12., name="fill")
+ # Constructs a tensor of ones of the same dimensions as "d".
+ z = array_ops.ones(array_ops.shape(d))
+ out = z.numpy()
+ self.assertAllEqual(np_ans, out)
+ self.assertShapeEqual(np_ans, d)
+ self.assertShapeEqual(np_ans, z)
+
+ def testDtype(self):
+ d = array_ops.fill([2, 3], 12., name="fill")
+ self.assertEqual(d.get_shape(), [2, 3])
+ # Test default type for both constant size and dynamic size
+ z = array_ops.ones([2, 3])
+ self.assertEqual(z.dtype, dtypes_lib.float32)
+ self.assertEqual([2, 3], z.get_shape())
+ self.assertAllEqual(z.numpy(), np.ones([2, 3]))
+ z = array_ops.ones(array_ops.shape(d))
+ self.assertEqual(z.dtype, dtypes_lib.float32)
+ self.assertEqual([2, 3], z.get_shape())
+ self.assertAllEqual(z.numpy(), np.ones([2, 3]))
+ # Test explicit type control
+ for dtype in (dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
+ dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8,
+ dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64,
+ dtypes_lib.bool):
+ z = array_ops.ones([2, 3], dtype=dtype)
+ self.assertEqual(z.dtype, dtype)
+ self.assertEqual([2, 3], z.get_shape())
+ self.assertAllEqual(z.numpy(), np.ones([2, 3]))
+ z = array_ops.ones(array_ops.shape(d), dtype=dtype)
+ self.assertEqual(z.dtype, dtype)
+ self.assertEqual([2, 3], z.get_shape())
+ self.assertAllEqual(z.numpy(), np.ones([2, 3]))
+
+
+class OnesLikeTest(test.TestCase):
+
+ def testOnesLike(self):
+ for dtype in [
+ dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
+ dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8,
+ dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.int64
+ ]:
+ numpy_dtype = dtype.as_numpy_dtype
+ # Creates a tensor of non-zero values with shape 2 x 3.
+ d = constant_op.constant(np.ones((2, 3), dtype=numpy_dtype), dtype=dtype)
+ # Constructs a tensor of zeros of the same dimensions and type as "d".
+ z_var = array_ops.ones_like(d)
+ # Test that the type is correct
+ self.assertEqual(z_var.dtype, dtype)
+ z_value = z_var.numpy()
+
+ # Test that the value is correct
+ self.assertTrue(np.array_equal(z_value, np.array([[1] * 3] * 2)))
+ self.assertEqual([2, 3], z_var.get_shape())
+
+
+class FillTest(test.TestCase):
+
+ def _compare(self, dims, val, np_ans, use_gpu):
+ ctx = context.get_default_context()
+ device = "GPU:0" if (use_gpu and ctx.num_gpus()) else "CPU:0"
+ with ops.device(device):
+ tf_ans = array_ops.fill(dims, val, name="fill")
+ out = tf_ans.numpy()
+ self.assertAllClose(np_ans, out)
+
+ def _compareAll(self, dims, val, np_ans):
+ self._compare(dims, val, np_ans, False)
+ self._compare(dims, val, np_ans, True)
+
+ def testFillFloat(self):
+ np_ans = np.array([[3.1415] * 3] * 2).astype(np.float32)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
+
+ def testFillDouble(self):
+ np_ans = np.array([[3.1415] * 3] * 2).astype(np.float64)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
+
+ def testFillInt32(self):
+ np_ans = np.array([[42] * 3] * 2).astype(np.int32)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
+
+ def testFillInt64(self):
+ np_ans = np.array([[-42] * 3] * 2).astype(np.int64)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
+
+ def testFillComplex64(self):
+ np_ans = np.array([[0.15] * 3] * 2).astype(np.complex64)
+ self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False)
+
+ def testFillComplex128(self):
+ np_ans = np.array([[0.15] * 3] * 2).astype(np.complex128)
+ self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False)
+
+ def testFillString(self):
+ np_ans = np.array([[b"yolo"] * 3] * 2)
+ tf_ans = array_ops.fill([2, 3], np_ans[0][0], name="fill").numpy()
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def testFillNegative(self):
+ for shape in (-1,), (2, -1), (-1, 2), (-2), (-3):
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ array_ops.fill(shape, 7)
+
+ def testShapeFunctionEdgeCases(self):
+ # Non-vector dimensions.
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ array_ops.fill([[0, 1], [2, 3]], 1.0)
+
+ # Non-scalar value.
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ array_ops.fill([3, 2], [1.0, 2.0])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 20c21eca4c..e31bd8d0ee 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -43,23 +44,31 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
resource_variable_ops.assign_variable_op(
handle, constant_op.constant(0.0, dtype=dtypes.float32)).run()
with self.assertRaises(ValueError):
- resource_variable_ops.assign_variable_op(
- handle, constant_op.constant([0], dtype=dtypes.int32)).run()
- resource_variable_ops.assign_variable_op(
- handle, constant_op.constant(0, dtype=dtypes.int32)).run()
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [0],
+ dtype=dtypes.int32)).run()
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ 0,
+ dtype=dtypes.int32)).run()
def testDtypeSurvivesIdentity(self):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
id_handle = array_ops.identity(handle)
- resource_variable_ops.assign_variable_op(
- id_handle, constant_op.constant(0, dtype=dtypes.int32)).run()
+ resource_variable_ops.assign_variable_op(id_handle,
+ constant_op.constant(
+ 0,
+ dtype=dtypes.int32)).run()
def testCreateRead(self):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
- resource_variable_ops.assign_variable_op(
- handle, constant_op.constant(1, dtype=dtypes.int32)).run()
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ 1,
+ dtype=dtypes.int32)).run()
value = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32).eval()
self.assertAllEqual(1, value)
@@ -67,8 +76,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
def testManyAssigns(self):
with self.test_session() as session:
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
- create = resource_variable_ops.assign_variable_op(
- handle, constant_op.constant(1, dtype=dtypes.int32))
+ create = resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ 1,
+ dtype=dtypes.int32))
with ops.control_dependencies([create]):
first_read = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32)
@@ -85,8 +96,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
def testAssignAdd(self):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
- resource_variable_ops.assign_variable_op(
- handle, constant_op.constant(1, dtype=dtypes.int32)).run()
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ 1,
+ dtype=dtypes.int32)).run()
resource_variable_ops.assign_add_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32)).run()
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
@@ -96,10 +109,14 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.test_session(use_gpu=True):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
- resource_variable_ops.assign_variable_op(
- handle, constant_op.constant([[1]], dtype=dtypes.int32)).run()
- resource_variable_ops.resource_scatter_add(
- handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)).run()
+ resource_variable_ops.assign_variable_op(handle,
+ constant_op.constant(
+ [[1]],
+ dtype=dtypes.int32)).run()
+ resource_variable_ops.resource_scatter_add(handle, [0],
+ constant_op.constant(
+ [[2]],
+ dtype=dtypes.int32)).run()
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(read.eval(), [[3]])
@@ -113,32 +130,31 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
sess.run(variables.global_variables_initializer())
self.assertEqual(
- resource_variable_ops.var_is_initialized_op(abc.handle).eval(),
- True)
+ resource_variable_ops.var_is_initialized_op(abc.handle).eval(), True)
print(sess.run(abc))
def testConstraintArg(self):
constraint = lambda x: x
- v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
- constraint=constraint)
+ v = resource_variable_ops.ResourceVariable(
+ initial_value=lambda: 1, constraint=constraint)
self.assertEqual(v.constraint, constraint)
constraint = 0
with self.assertRaises(ValueError):
- v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
- constraint=constraint)
+ v = resource_variable_ops.ResourceVariable(
+ initial_value=lambda: 1, constraint=constraint)
def testInitFn(self):
with self.test_session():
- v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
- dtype=dtypes.float32)
+ v = resource_variable_ops.ResourceVariable(
+ initial_value=lambda: 1, dtype=dtypes.float32)
self.assertEqual(v.handle.op.colocation_groups(),
v.initializer.inputs[1].op.colocation_groups())
def testInitFnDtype(self):
with self.test_session():
- v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
- dtype=dtypes.float32)
+ v = resource_variable_ops.ResourceVariable(
+ initial_value=lambda: 1, dtype=dtypes.float32)
self.assertEqual(dtypes.float32, v.value().dtype)
def testInitFnNoDtype(self):
@@ -158,7 +174,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.test_session():
v = resource_variable_ops.ResourceVariable(1.0)
variables.global_variables_initializer().run()
- self.assertEqual(2.0, (v+v).eval())
+ self.assertEqual(2.0, (v + v).eval())
def testAssignMethod(self):
with self.test_session():
@@ -217,8 +233,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
# Handle to a resource not actually created.
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
# Should raise no exception
- sess.run(resource_variable_ops.destroy_resource_op(
- handle, ignore_lookup_error=True))
+ sess.run(
+ resource_variable_ops.destroy_resource_op(
+ handle, ignore_lookup_error=True))
def testAssignDifferentShapes(self):
with self.test_session() as sess, variable_scope.variable_scope(
@@ -226,9 +243,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
var = variable_scope.get_variable("x", shape=[1, 1], dtype=dtypes.float32)
placeholder = array_ops.placeholder(dtypes.float32)
assign = var.assign(placeholder)
- sess.run([assign],
- feed_dict={placeholder: np.zeros(shape=[2, 2],
- dtype=np.float32)})
+ sess.run(
+ [assign],
+ feed_dict={placeholder: np.zeros(shape=[2, 2], dtype=np.float32)})
def testDtypeAfterFromProto(self):
v = resource_variable_ops.ResourceVariable(2.0)
@@ -256,19 +273,28 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
v.initializer.run()
- w = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype,
- shape=v.get_shape(),
- shared_name="var1")
+ w = resource_variable_ops.var_handle_op(
+ dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var1")
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, w_read.eval())
- x = resource_variable_ops.var_handle_op(dtype=v.dtype.base_dtype,
- shape=v.get_shape(),
- shared_name="var1/")
+ x = resource_variable_ops.var_handle_op(
+ dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var1/")
x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
with self.assertRaisesOpError("Resource .*/var1//.* does not exist"):
_ = x_read.eval()
+ def testSharedNameWithNamescope(self):
+ with self.test_session():
+ with ops.name_scope("foo"):
+ v = resource_variable_ops.ResourceVariable(300.0, name="var1")
+ v.initializer.run()
+
+ w = resource_variable_ops.var_handle_op(
+ dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var1")
+ w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
+ self.assertEqual(300.0, w_read.eval())
+
def testShape(self):
with self.test_session():
v = resource_variable_ops.ResourceVariable(
@@ -291,6 +317,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
def testControlFlowInitialization(self):
"""Expects an error if an initializer is in a control-flow scope."""
+
def cond(i, _):
return i < 10
@@ -303,5 +330,61 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
control_flow_ops.while_loop(cond, body, [0, 0])
+# TODO(agarwal,apassos): Add more comprehensive tests and/or translate the above
+# tests to work in both GRAPH and EAGER modes.
+# TODO(agarwal): Add tests for sparse_read, scatter_sub
+class ResourceVariableOpsEagerTest(test_util.TensorFlowTestCase):
+
+ def testVariable(self):
+ with context.eager_mode():
+ init = array_ops.ones(shape=[10, 20, 35], dtype=dtypes.int32)
+ constraint = lambda x: x
+ with ops.name_scope("foo"):
+ v = resource_variable_ops.ResourceVariable(
+ name="var1",
+ initial_value=init,
+ caching_device="cpu:0",
+ constraint=constraint)
+ # Test properties
+ self.assertEqual(dtypes.int32, v.dtype)
+ self.assertEqual("foo/var1:0", v.name)
+ self.assertAllEqual([10, 20, 35], v.shape.as_list())
+ self.assertAllEqual(init.device, v.device)
+ self.assertTrue(isinstance(v.handle, ops.EagerTensor))
+ self.assertEqual(constraint, v.constraint)
+ self.assertAllEqual(init.numpy(), v.read_value().numpy())
+ self.assertAllEqual(init.numpy(), v.value().numpy())
+
+ # Callable init.
+ callable_init = lambda: init * 2
+ v2 = resource_variable_ops.ResourceVariable(
+ initial_value=callable_init, name="v2")
+ self.assertEqual("v2:0", v2.name)
+ self.assertAllEqual(2 * init.numpy(), v2.read_value().numpy())
+
+ # Test assign_add.
+ new_v2_val = v2.assign_add(v.read_value())
+ self.assertAllEqual(v.read_value().numpy() * 3, new_v2_val.numpy())
+
+ # Test assign_sub.
+ new_v2_val = v2.assign_sub(v.read_value())
+ self.assertAllEqual(v.read_value().numpy() * 2, new_v2_val.numpy())
+
+ # Test assign.
+ v2.assign(v.read_value())
+ self.assertAllEqual(v.read_value().numpy(), v2.read_value().numpy())
+
+ # Test load
+ v2.load(2 * v.read_value())
+ self.assertAllEqual(2 * v.read_value().numpy(), v2.read_value().numpy())
+
+ # Test convert_to_tensor
+ t = ops.convert_to_tensor(v)
+ self.assertAllEqual(t.numpy(), v.read_value().numpy())
+
+ # Test operations
+ self.assertAllEqual((v * 2).numpy(), (v + v).numpy())
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 7fbc12d1e0..9c9b852b76 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Support for manipulating tensors.
See the @{$python/array_ops} guide.
@@ -85,6 +84,7 @@ from __future__ import print_function
import sys
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -103,7 +103,6 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util.deprecation import deprecated
# pylint: enable=wildcard-import
-
# Used for slicing to specify a new 1 size dimension
newaxis = None
@@ -112,6 +111,23 @@ newaxis = None
_baseslice = slice
+def identity(input, name=None): # pylint: disable=redefined-builtin
+ r"""Return a tensor with the same shape and contents as input.
+
+ Args:
+ input: A `Tensor`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type as `input`.
+ """
+ if context.in_graph_mode():
+ return gen_array_ops.identity(input, name=name)
+ else:
+ # TODO(apassos): make sure this works ok with gradients.
+ return input._copy() # pylint: disable=protected-access
+
+
# pylint: disable=redefined-builtin,protected-access
def expand_dims(input, axis=None, name=None, dim=None):
"""Inserts a dimension of 1 into a tensor's shape.
@@ -168,25 +184,31 @@ def expand_dims(input, axis=None, name=None, dim=None):
raise ValueError("can't specify both 'dim' and 'axis'")
axis = dim
return gen_array_ops._expand_dims(input, axis, name)
+
+
# pylint: enable=redefined-builtin,protected-access
# Aliases for some automatically-generated names.
# pylint: disable=protected-access
-@deprecated(
- "2016-11-30",
- "This op will be removed after the deprecation date. "
- "Please switch to tf.setdiff1d().")
+@deprecated("2016-11-30", "This op will be removed after the deprecation date. "
+ "Please switch to tf.setdiff1d().")
def listdiff(x, y, out_idx=None, name=None):
return gen_array_ops._list_diff(x, y, out_idx, name)
+
+
listdiff.__doc__ = gen_array_ops._list_diff.__doc__ + "\n" + listdiff.__doc__
+
# pylint: enable=protected-access
# pylint: disable=undefined-variable,protected-access
def setdiff1d(x, y, index_dtype=dtypes.int32, name=None):
return gen_array_ops._list_diff(x, y, index_dtype, name)
+
+
setdiff1d.__doc__ = gen_array_ops._list_diff.__doc__
+
# pylint: enable=protected-access
@@ -262,8 +284,8 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
"""
with ops.name_scope(name, "Shape", [input]) as name:
- if isinstance(
- input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
+ if isinstance(input, (sparse_tensor.SparseTensor,
+ sparse_tensor.SparseTensorValue)):
return gen_math_ops.cast(input.dense_shape, out_type)
else:
input_tensor = ops.convert_to_tensor(input)
@@ -314,8 +336,8 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
A `Tensor` of type `out_type`.
"""
with ops.name_scope(name, "Size", [input]) as name:
- if isinstance(
- input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
+ if isinstance(input, (sparse_tensor.SparseTensor,
+ sparse_tensor.SparseTensorValue)):
return gen_math_ops._prod(
gen_math_ops.cast(input.dense_shape, out_type), 0, name=name)
else:
@@ -371,8 +393,8 @@ def rank_internal(input, name=None, optimize=True):
A `Tensor` of type `int32`.
"""
with ops.name_scope(name, "Rank", [input]) as name:
- if isinstance(
- input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
+ if isinstance(input, (sparse_tensor.SparseTensor,
+ sparse_tensor.SparseTensorValue)):
return gen_array_ops.size(input.dense_shape, name=name)
else:
input_tensor = ops.convert_to_tensor(input)
@@ -412,7 +434,8 @@ def _SliceHelper(tensor, slice_spec, var=None):
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo[tf.newaxis, :, :].eval()) # => [[[1,2,3], [4,5,6], [7,8,9]]]
print(foo[:, tf.newaxis, :].eval()) # => [[[1,2,3]], [[4,5,6]], [[7,8,9]]]
- print(foo[:, :, tf.newaxis].eval()) # => [[[1],[2],[3]], [[4],[5],[6]], [[7],[8],[9]]]
+ print(foo[:, :, tf.newaxis].eval()) # => [[[1],[2],[3]], [[4],[5],[6]],
+ [[7],[8],[9]]]
# Ellipses (3 equivalent operations)
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
@@ -491,8 +514,8 @@ def _SliceHelper(tensor, slice_spec, var=None):
with ops.name_scope(None, "strided_slice",
[tensor] + begin + end + strides) as name:
if begin:
- packed_begin, packed_end, packed_strides = (
- stack(begin), stack(end), stack(strides))
+ packed_begin, packed_end, packed_strides = (stack(begin), stack(end),
+ stack(strides))
else:
var_empty = constant([], dtype=dtypes.int32)
packed_begin = packed_end = packed_strides = var_empty
@@ -748,6 +771,7 @@ def _SliceHelperVar(var, slice_spec):
return _SliceHelper(var._AsTensor(), slice_spec, var)
+
ops.Tensor._override_operator("__getitem__", _SliceHelper)
@@ -851,8 +875,8 @@ def stack(values, axis=0, name="stack"):
if value_shape.ndims is not None:
expanded_num_dims = value_shape.ndims + 1
if axis < -expanded_num_dims or axis >= expanded_num_dims:
- raise ValueError("axis = %d not in [%d, %d)" %
- (axis, -expanded_num_dims, expanded_num_dims))
+ raise ValueError("axis = %d not in [%d, %d)" % (axis, -expanded_num_dims,
+ expanded_num_dims))
return gen_array_ops._pack(values, axis=axis, name=name)
@@ -875,9 +899,9 @@ def _autopacking_helper(list_or_tuple, dtype, name):
for i, elem in enumerate(list_or_tuple):
if ops.is_dense_tensor_like(elem):
if dtype is not None and elem.dtype.base_dtype != dtype:
- raise TypeError(
- "Cannot convert a list containing a tensor of dtype "
- "%s to %s (Tensor is: %r)" % (elem.dtype, dtype, elem))
+ raise TypeError("Cannot convert a list containing a tensor of dtype "
+ "%s to %s (Tensor is: %r)" % (elem.dtype, dtype,
+ elem))
converted_elems.append(elem)
must_pack = True
elif isinstance(elem, (list, tuple)):
@@ -936,13 +960,14 @@ def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False):
if dtype is not None and dtype != inferred_dtype:
return NotImplemented
return _autopacking_helper(v, inferred_dtype, name or "packed")
-# pylint: enable=invalid-name
+# pylint: enable=invalid-name
+
# NOTE: Register this conversion function to run *before* one that
# assumes every element is a value.
-ops.register_tensor_conversion_function(
- (list, tuple), _autopacking_conversion_function, 99)
+ops.register_tensor_conversion_function((list, tuple),
+ _autopacking_conversion_function, 99)
def unstack(value, num=None, axis=0, name="unstack"):
@@ -1058,14 +1083,12 @@ def concat(values, axis, name="concat"):
# the returned tensor is a scalar.
# TODO(keveman): Implement a standalone type and shape checker.
with ops.name_scope(name) as scope:
- ops.convert_to_tensor(axis,
- name="concat_dim",
- dtype=dtypes.int32).get_shape(
- ).assert_is_compatible_with(tensor_shape.scalar())
+ ops.convert_to_tensor(
+ axis, name="concat_dim",
+ dtype=dtypes.int32).get_shape().assert_is_compatible_with(
+ tensor_shape.scalar())
return identity(values[0], name=scope)
- return gen_array_ops._concat_v2(values=values,
- axis=axis,
- name=name)
+ return gen_array_ops._concat_v2(values=values, axis=axis, name=name)
def boolean_mask(tensor, mask, name="boolean_mask"):
@@ -1104,6 +1127,7 @@ def boolean_mask(tensor, mask, name="boolean_mask"):
boolean_mask(tensor, mask) ==> [[1, 2], [5, 6]]
```
"""
+
def _apply_mask_1d(reshaped_tensor, mask):
"""Mask tensor along dimension 0 with a 1-D mask."""
indices = squeeze(where(mask), squeeze_dims=[1])
@@ -1125,9 +1149,9 @@ def boolean_mask(tensor, mask, name="boolean_mask"):
shape_tensor[:ndims_mask].assert_is_compatible_with(shape_mask)
leading_size = gen_math_ops._prod(shape(tensor)[:ndims_mask], [0])
- tensor = reshape(
- tensor,
- concat([[leading_size], shape(tensor)[ndims_mask:]], 0))
+ tensor = reshape(tensor,
+ concat([[leading_size],
+ shape(tensor)[ndims_mask:]], 0))
first_dim = shape_tensor[:ndims_mask].num_elements()
tensor.set_shape(
tensor_shape.as_shape([first_dim])
@@ -1363,10 +1387,12 @@ def matrix_transpose(a, name="matrix_transpose"):
perm = list(range(ndims - 2)) + [ndims - 1] + [ndims - 2]
else:
a_rank = rank(a)
- perm = concat(
- (gen_math_ops._range(0, a_rank - 2, 1), [a_rank - 1, a_rank - 2]), 0)
+ perm = concat((gen_math_ops._range(0, a_rank - 2, 1),
+ [a_rank - 1, a_rank - 2]), 0)
return transpose(a, perm=perm)
+
+
# pylint: enable=invalid-name
@@ -1383,7 +1409,8 @@ def zeros(shape, dtype=dtypes.float32, name=None):
```
Args:
- shape: A list of integers, a tuple of integers, or a 1-D `Tensor` of type `int32`.
+ shape: A list of integers, a tuple of integers, or a 1-D `Tensor` of type
+ `int32`.
dtype: The type of an element in the resulting `Tensor`.
name: A name for the operation (optional).
@@ -1442,8 +1469,8 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
return zeros(tensor.shape, dtype=dtype or tensor.dtype, name=name)
if dtype is not None and dtype != tensor.dtype:
- return zeros(shape_internal(tensor, optimize=optimize), dtype=dtype,
- name=name)
+ return zeros(
+ shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
else:
return gen_array_ops._zeros_like(tensor, name=name)
@@ -1480,7 +1507,8 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
if dtype is None:
dtype = tensor.dtype
ret = ones(ones_shape, dtype=dtype, name=name)
- ret.set_shape(tensor.get_shape())
+ if context.in_graph_mode():
+ ret.set_shape(tensor.get_shape())
return ret
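The guard above skips `set_shape` under eager execution, where the result already carries a concrete shape; an illustrative sketch (not part of the diff), assuming the public `tf.ones_like`:

```python
import tensorflow as tf

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
o = tf.ones_like(t, dtype=tf.int32)  # => [[1, 1], [1, 1]] as int32
# In graph mode the static shape of `t` is copied onto `o`; in eager mode
# the shape is already known from the computed value.
```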
@@ -1497,7 +1525,8 @@ def ones(shape, dtype=dtypes.float32, name=None):
```
Args:
- shape: A list of integers, a tuple of integers, or a 1-D `Tensor` of type `int32`.
+ shape: A list of integers, a tuple of integers, or a 1-D `Tensor` of type
+ `int32`.
dtype: The type of an element in the resulting `Tensor`.
name: A name for the operation (optional).
@@ -1553,7 +1582,8 @@ def placeholder(dtype, shape=None, name=None):
# pylint: disable=redefined-outer-name
def _normalize_sparse_shape(shape, name):
"""Returns a tuple of (Tensor or None, rank or None)."""
- if shape is None: return (None, None)
+ if shape is None:
+ return (None, None)
rank = shape.get_shape()[0] if isinstance(shape, ops.Tensor) else len(shape)
if not isinstance(shape, ops.Tensor) and None in shape:
return (None, rank)
@@ -1605,12 +1635,16 @@ def sparse_placeholder(dtype, shape=None, name=None):
shape = placeholder(dtypes.int64, shape=[rank], name=shape_name)
return sparse_tensor.SparseTensor(
values=placeholder(
- dtype, shape=[None],
+ dtype,
+ shape=[None],
name=(name + "/values") if name is not None else None),
indices=placeholder(
- dtypes.int64, shape=[None, None],
+ dtypes.int64,
+ shape=[None, None],
name=(name + "/indices") if name is not None else None),
dense_shape=shape)
+
+
# pylint: enable=redefined-outer-name
@@ -1764,15 +1798,15 @@ def meshgrid(*args, **kwargs):
# Prepare reshape by inserting dimensions with size 1 where needed
output = []
for i, x in enumerate(args):
- output.append(reshape(stack(x), (s0[:i] + (-1,) + s0[i + 1::])) )
+ output.append(reshape(stack(x), (s0[:i] + (-1,) + s0[i + 1::])))
# Create parameters for broadcasting each tensor to the full size
shapes = [size(x) for x in args]
output_dtype = ops.convert_to_tensor(args[0]).dtype.base_dtype
if indexing == "xy" and ndim > 1:
- output[0] = reshape(output[0], (1, -1) + (1,)*(ndim - 2))
- output[1] = reshape(output[1], (-1, 1) + (1,)*(ndim - 2))
+ output[0] = reshape(output[0], (1, -1) + (1,) * (ndim - 2))
+ output[1] = reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
shapes[0], shapes[1] = shapes[1], shapes[0]
# TODO: improve performance with a broadcast
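The `xy` reshapes above swap the first two output dimensions; a small sketch (not part of the diff), assuming the public `tf.meshgrid`:

```python
import tensorflow as tf

x = tf.constant([1, 2, 3])
y = tf.constant([4, 5])
X, Y = tf.meshgrid(x, y)  # default indexing="xy"
# X and Y both have shape (2, 3): X repeats x along rows, Y repeats y along
# columns; with indexing="ij" the shapes would be (3, 2) instead.
```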
@@ -1903,23 +1937,22 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
Raises:
TypeError: If either `hypothesis` or `truth` are not a `SparseTensor`.
"""
- if not isinstance(
- hypothesis, (sparse_tensor.SparseTensor,
- sparse_tensor.SparseTensorValue)):
+ if not isinstance(hypothesis, (sparse_tensor.SparseTensor,
+ sparse_tensor.SparseTensorValue)):
raise TypeError("Hypothesis must be a SparseTensor.")
- if not isinstance(
- truth, (sparse_tensor.SparseTensor,
- sparse_tensor.SparseTensorValue)):
+ if not isinstance(truth, (sparse_tensor.SparseTensor,
+ sparse_tensor.SparseTensorValue)):
raise TypeError("Truth must be a SparseTensor.")
- return gen_array_ops._edit_distance(hypothesis.indices,
- hypothesis.values,
- hypothesis.dense_shape,
- truth.indices,
- truth.values,
- truth.dense_shape,
- normalize=normalize,
- name=name)
+ return gen_array_ops._edit_distance(
+ hypothesis.indices,
+ hypothesis.values,
+ hypothesis.dense_shape,
+ truth.indices,
+ truth.values,
+ truth.dense_shape,
+ normalize=normalize,
+ name=name)
@ops.RegisterGradient("FakeQuantWithMinMaxArgs")
@@ -1992,12 +2025,10 @@ def required_space_to_batch_paddings(input_shape,
"""
with ops.name_scope(name, "required_space_to_batch_paddings",
[input_shape, block_shape]):
- input_shape = ops.convert_to_tensor(input_shape,
- dtype=dtypes.int32,
- name="input_shape")
- block_shape = ops.convert_to_tensor(block_shape,
- dtype=dtypes.int32,
- name="block_shape")
+ input_shape = ops.convert_to_tensor(
+ input_shape, dtype=dtypes.int32, name="input_shape")
+ block_shape = ops.convert_to_tensor(
+ block_shape, dtype=dtypes.int32, name="block_shape")
block_shape.get_shape().assert_is_fully_defined()
block_shape.get_shape().assert_has_rank(1)
@@ -2008,9 +2039,8 @@ def required_space_to_batch_paddings(input_shape,
input_shape.get_shape().assert_is_compatible_with([num_block_dims])
if base_paddings is not None:
- base_paddings = ops.convert_to_tensor(base_paddings,
- dtype=dtypes.int32,
- name="base_paddings")
+ base_paddings = ops.convert_to_tensor(
+ base_paddings, dtype=dtypes.int32, name="base_paddings")
base_paddings.get_shape().assert_is_compatible_with([num_block_dims, 2])
else:
base_paddings = zeros([num_block_dims, 2], dtypes.int32)
@@ -2040,11 +2070,11 @@ def required_space_to_batch_paddings(input_shape,
def space_to_batch(input, paddings, block_size, name=None): # pylint: disable=redefined-builtin
- result = space_to_batch_nd(input,
- paddings=paddings,
- block_shape=np.array([block_size, block_size],
- dtype=np.int64),
- name=name)
+ result = space_to_batch_nd(
+ input,
+ paddings=paddings,
+ block_shape=np.array([block_size, block_size], dtype=np.int64),
+ name=name)
result.set_shape(result.get_shape().with_rank(4))
return result
@@ -2053,11 +2083,11 @@ space_to_batch.__doc__ = gen_array_ops._space_to_batch.__doc__
def batch_to_space(input, crops, block_size, name=None): # pylint: disable=redefined-builtin
- result = batch_to_space_nd(input,
- crops=crops,
- block_shape=np.array([block_size, block_size],
- dtype=np.int64),
- name=name)
+ result = batch_to_space_nd(
+ input,
+ crops=crops,
+ block_shape=np.array([block_size, block_size], dtype=np.int64),
+ name=name)
result.set_shape(result.get_shape().with_rank(4))
return result
@@ -2065,8 +2095,13 @@ def batch_to_space(input, crops, block_size, name=None): # pylint: disable=rede
batch_to_space.__doc__ = gen_array_ops._batch_to_space.__doc__
-def one_hot(indices, depth, on_value=None, off_value=None,
- axis=None, dtype=None, name=None):
+def one_hot(indices,
+ depth,
+ on_value=None,
+ off_value=None,
+ axis=None,
+ dtype=None,
+ name=None):
"""Returns a one-hot tensor.
The locations represented by indices in `indices` take value `on_value`,
@@ -2190,8 +2225,9 @@ def one_hot(indices, depth, on_value=None, off_value=None,
TypeError: If dtype of either `on_value` or `off_value` don't match `dtype`
TypeError: If dtype of `on_value` and `off_value` don't match one another
"""
- with ops.name_scope(name, "one_hot", [indices, depth, on_value, off_value,
- axis, dtype]) as name:
+ with ops.name_scope(name, "one_hot",
+ [indices, depth, on_value, off_value, axis,
+ dtype]) as name:
on_exists = on_value is not None
off_exists = off_value is not None
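For the `one_hot` hunks above, a short illustrative sketch (not part of the diff) of the public `tf.one_hot`:

```python
import tensorflow as tf

indices = tf.constant([0, 2, -1])
encoded = tf.one_hot(indices, depth=3, on_value=1.0, off_value=0.0)
# => [[1., 0., 0.],
#     [0., 0., 1.],
#     [0., 0., 0.]]  (an out-of-range index yields an all-off_value row)
```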
@@ -2274,9 +2310,8 @@ def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
# to length as a matrix with 1 column: [[1], [3], [2]].
# Because of broadcasting on both arguments this comparison results
# in a matrix of size (len(lengths), maxlen)
- row_vector = gen_math_ops._range(constant(0, maxlen.dtype),
- maxlen,
- constant(1, maxlen.dtype))
+ row_vector = gen_math_ops._range(
+ constant(0, maxlen.dtype), maxlen, constant(1, maxlen.dtype))
# Since maxlen >= max(lengths), it is safe to use maxlen as a cast
# authoritative type. Whenever maxlen fits into tf.int32, so do the lengths.
matrix = gen_math_ops.cast(expand_dims(lengths, 1), maxlen.dtype)
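The broadcast comparison described in the comments above produces the mask row by row; a small sketch (not part of the diff), assuming the public `tf.sequence_mask`:

```python
import tensorflow as tf

mask = tf.sequence_mask([1, 3, 2], maxlen=5)
# => [[True, False, False, False, False],
#     [True, True,  True,  False, False],
#     [True, True,  False, False, False]]
```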
@@ -2388,6 +2423,8 @@ def where(condition, x=None, y=None, name=None):
def reverse(tensor, axis, name=None):
return gen_array_ops.reverse_v2(tensor, axis, name)
+
+
reverse.__doc__ = gen_array_ops.reverse_v2.__doc__
@@ -2409,9 +2446,10 @@ def reverse_sequence(input,
seq_dim=seq_axis,
batch_dim=batch_axis,
name=name)
-# pylint: enable=redefined-builtin
+# pylint: enable=redefined-builtin
+
reverse_sequence.__doc__ = deprecation.rewrite_argument_docstring(
deprecation.rewrite_argument_docstring(
gen_array_ops.reverse_sequence.__doc__, "batch_dim", "batch_axis"),
@@ -2422,8 +2460,10 @@ def gather(params, indices, validate_indices=None, name=None, axis=0):
# TODO(rjryan): Remove "Gather" creation in favor of GatherV2 once the forward
# compatibility 3 week period has passed.
if axis == 0:
- return gen_array_ops.gather(params, indices,
- validate_indices=validate_indices, name=name)
+ return gen_array_ops.gather(
+ params, indices, validate_indices=validate_indices, name=name)
else:
return gen_array_ops.gather_v2(params, indices, axis, name=name)
+
+
gather.__doc__ = gen_array_ops.gather_v2.__doc__
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index e3a0a45f29..c31121b3eb 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -101,16 +102,20 @@ class ResourceVariable(variables.Variable):
if initial_value:
raise ValueError("variable_def and initial_value are mutually "
"exclusive.")
+ if not context.in_graph_mode():
+ raise ValueError("Creating ResourceVariable from variable_def"
+ " only supported in GRAPH mode.")
self._init_from_proto(variable_def, import_scope=import_scope)
else:
- self._init_from_args(initial_value=initial_value,
- trainable=trainable,
- collections=collections,
- validate_shape=validate_shape,
- caching_device=caching_device,
- name=name,
- dtype=dtype,
- constraint=constraint)
+ self._init_from_args(
+ initial_value=initial_value,
+ trainable=trainable,
+ collections=collections,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name,
+ dtype=dtype,
+ constraint=constraint)
# pylint: disable=unused-argument
def _init_from_args(self,
@@ -122,7 +127,6 @@ class ResourceVariable(variables.Variable):
name=None,
dtype=None,
constraint=None):
-
"""Creates a variable.
Args:
@@ -177,34 +181,48 @@ class ResourceVariable(variables.Variable):
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
self._save_slice_info = None
+ in_graph_mode = context.in_graph_mode()
with ops.control_dependencies(None):
- with ops.name_scope(name, "Variable", [] if init_from_fn else
- [initial_value]) as name:
+ with ops.name_scope(name, "Variable", []
+ if init_from_fn else [initial_value]) as name:
# pylint: disable=protected-access
- true_name = ops._name_from_scope_name(name)
+ handle_name = ops._name_from_scope_name(name)
if init_from_fn:
# Use attr_scope and device(None) to simulate the behavior of
# colocate_with when the variable we want to colocate with doesn't
# yet exist.
- attr = attr_value_pb2.AttrValue(
- list=attr_value_pb2.AttrValue.ListValue(
- s=[compat.as_bytes("loc:@%s" % true_name)]))
- with ops.get_default_graph()._attr_scope({"_class": attr}):
- with ops.name_scope("Initializer"), ops.device(None):
- self._initial_value = ops.convert_to_tensor(
- initial_value(), name="initial_value", dtype=dtype)
+ if in_graph_mode:
+ attr = attr_value_pb2.AttrValue(
+ list=attr_value_pb2.AttrValue.ListValue(
+ s=[compat.as_bytes("loc:@%s" % handle_name)]))
+ with ops.get_default_graph()._attr_scope({"_class": attr}):
+ with ops.name_scope("Initializer"), ops.device(None):
+ initial_value = ops.convert_to_tensor(
+ initial_value(), name="initial_value", dtype=dtype)
+ self._handle = gen_resource_variable_ops.var_handle_op(
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=handle_name,
+ name=name)
+ else:
+ initial_value = initial_value()
self._handle = gen_resource_variable_ops.var_handle_op(
- shape=self._initial_value.get_shape(),
- dtype=self._initial_value.dtype.base_dtype,
- shared_name=true_name, name=name)
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=handle_name,
+ name=name,
+ container="")
# pylint: enable=protected-access
# Or get the initial value from a Tensor or Python object.
else:
- self._initial_value = ops.convert_to_tensor(
- initial_value, name="initial_value", dtype=dtype)
+ with ops.name_scope("Initializer"):
+ initial_value = ops.convert_to_tensor(
+ initial_value, name="initial_value", dtype=dtype)
# pylint: disable=protected-access
- if self._initial_value.op._get_control_flow_context() is not None:
+ if (in_graph_mode and
+ initial_value is not None and
+ initial_value.op._get_control_flow_context() is not None):
raise ValueError(
"Initializer for variable %s is from inside a control-flow "
"construct, such as a loop or conditional. When creating a "
@@ -212,45 +230,65 @@ class ResourceVariable(variables.Variable):
"initializer." % name)
# pylint: enable=protected-access
self._handle = gen_resource_variable_ops.var_handle_op(
- shape=self._initial_value.get_shape(),
- dtype=self._initial_value.dtype.base_dtype,
- shared_name=true_name, name=name)
-
- self._dtype = self._initial_value.dtype.base_dtype
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=handle_name,
+ name=name,
+ container="")
+
+ self._initial_value = initial_value if in_graph_mode else None
+ self._handle_name = handle_name + ":0"
+ self._dtype = initial_value.dtype.base_dtype
self._constraint = constraint
- with ops.name_scope("IsInitialized"):
- self._is_initialized_op = (
- gen_resource_variable_ops.var_is_initialized_op(self._handle))
- if initial_value is not None:
- with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
- self._initializer_op = gen_resource_variable_ops.assign_variable_op(
- self._handle,
- self._build_initializer_expr(self._initial_value),
- name=n)
- with ops.name_scope("Read"), ops.colocate_with(self._handle):
- # Manually assign reads to the handle's device to avoid log messages.
- with ops.device(self._handle.device):
- value = gen_resource_variable_ops.read_variable_op(
- self._handle, dtype=self._dtype)
- self._graph_element = value
- if caching_device is not None:
- # Variables may be created in a tf.device() or ops.colocate_with()
- # context. At the same time, users would expect caching device to be
- # independent of this context, and/or would not expect the current
- # device context to be merged with the caching device spec.
- # Therefore we reset the colocation stack before creating the cached
- # value. Note that resetting the colocation stack will also reset
- # the device stack.
- with ops.colocate_with(None, ignore_existing=True):
- with ops.device(caching_device):
- self._cached_value = array_ops.identity(value)
+ if in_graph_mode:
+ with ops.name_scope("IsInitialized"):
+ self._is_initialized_op = (
+ gen_resource_variable_ops.var_is_initialized_op(self._handle))
+ if initial_value is not None:
+ with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
+ self._initializer_op = (
+ gen_resource_variable_ops.assign_variable_op(
+ self._handle,
+ self._build_initializer_expr(initial_value),
+ name=n))
+ with ops.name_scope("Read"), ops.colocate_with(self._handle):
+ # Manually assign reads to the handle's device to avoid log
+ # messages.
+ with ops.device(self._handle.device):
+ value = gen_resource_variable_ops.read_variable_op(
+ self._handle, dtype=self._dtype)
+ self._graph_element = value
+ if caching_device is not None:
+ # Variables may be created in a tf.device() or ops.colocate_with()
+ # context. At the same time, users would expect caching device to
+ # be independent of this context, and/or would not expect the
+ # current device context to be merged with the caching device
+ # spec. Therefore we reset the colocation stack before creating
+ # the cached value. Note that resetting the colocation stack will
+ # also reset the device stack.
+ with ops.colocate_with(None, ignore_existing=True):
+ with ops.device(caching_device):
+ self._cached_value = array_ops.identity(value)
+ else:
+ self._cached_value = None
+ else:
+ gen_resource_variable_ops.assign_variable_op(self._handle,
+ initial_value)
+ self._is_initialized_op = None
+ self._initializer_op = None
+ self._graph_element = None
+ if caching_device:
+ with ops.device(caching_device):
+ self._cached_value = gen_resource_variable_ops.read_variable_op(
+ self._handle, dtype=self._dtype)
else:
self._cached_value = None
- ops.add_to_collections(collections, self)
+ ops.add_to_collections(collections, self)
def _init_from_proto(self, variable_def, import_scope=None):
"""Initializes from `VariableDef` proto."""
+ assert context.in_graph_mode()
assert isinstance(variable_def, variable_pb2.VariableDef)
if not variable_def.is_resource:
raise ValueError("Trying to restore Variable as ResourceVariable.")
@@ -258,15 +296,16 @@ class ResourceVariable(variables.Variable):
# Create from variable_def.
g = ops.get_default_graph()
self._handle = g.as_graph_element(
- ops.prepend_name_scope(variable_def.variable_name,
- import_scope=import_scope))
+ ops.prepend_name_scope(
+ variable_def.variable_name, import_scope=import_scope))
+ self._handle_name = self._handle.name
self._initializer_op = g.as_graph_element(
- ops.prepend_name_scope(variable_def.initializer_name,
- import_scope=import_scope))
+ ops.prepend_name_scope(
+ variable_def.initializer_name, import_scope=import_scope))
if variable_def.snapshot_name:
self._cached_value = g.as_graph_element(
- ops.prepend_name_scope(variable_def.snapshot_name,
- import_scope=import_scope))
+ ops.prepend_name_scope(
+ variable_def.snapshot_name, import_scope=import_scope))
else:
self._cached_value = None
if variable_def.HasField("save_slice_info_def"):
@@ -297,16 +336,22 @@ class ResourceVariable(variables.Variable):
@property
def name(self):
"""The name of the handle for this variable."""
- return self._handle.name
+ return self._handle_name
@property
def shape(self):
"""The shape of this variable."""
- return tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
+ if context.in_graph_mode():
+ return tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
+ else:
+ # TODO(agarwal): avoid a call to value here.
+ return self.value().shape
@property
def create(self):
"""The op responsible for initializing this variable."""
+ if not context.in_graph_mode():
+ raise RuntimeError("Calling create in EAGER mode not supported.")
return self._initializer_op
@property
@@ -335,6 +380,8 @@ class ResourceVariable(variables.Variable):
@property
def initial_value(self):
"""Returns the Tensor used as the initial value for the variable."""
+ if context.in_eager_mode():
+ raise RuntimeError("initial_value not supported in EAGER mode.""")
return self._initial_value
@property
@@ -354,6 +401,8 @@ class ResourceVariable(variables.Variable):
def eval(self, session=None):
"""Evaluates and returns the value of this variable."""
+ if context.in_eager_mode():
+ raise RuntimeError("Trying to eval in EAGER mode")
return self._graph_element.eval(session=session)
def _set_save_slice_info(self, save_slice_info):
@@ -397,32 +446,38 @@ class ResourceVariable(variables.Variable):
Args:
export_scope: Optional `string`. Name scope to remove.
+ Raises:
+ RuntimeError: If run in EAGER mode.
+
Returns:
A `VariableDef` protocol buffer, or `None` if the `Variable` is not
in the specified name scope.
"""
- if (export_scope is None or
- self.handle.name.startswith(export_scope)):
+ if context.in_eager_mode():
+ raise RuntimeError("to_proto not supported in EAGER mode.")
+ if export_scope is None or self.handle.name.startswith(export_scope):
var_def = variable_pb2.VariableDef()
- var_def.variable_name = ops.strip_name_scope(
- self.handle.name, export_scope)
- var_def.initializer_name = ops.strip_name_scope(
- self.initializer.name, export_scope)
+ var_def.variable_name = ops.strip_name_scope(self.handle.name,
+ export_scope)
+ var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
+ export_scope)
if self._cached_value is not None:
- var_def.snapshot_name = ops.strip_name_scope(
- self._cached_value.name, export_scope)
+ var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
+ export_scope)
var_def.is_resource = True
if self._save_slice_info:
- var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto(
- export_scope=export_scope))
+ var_def.save_slice_info_def.MergeFrom(
+ self._save_slice_info.to_proto(export_scope=export_scope))
return var_def
else:
return None
@staticmethod
def from_proto(variable_def, import_scope=None):
- return ResourceVariable(variable_def=variable_def,
- import_scope=import_scope)
+ if context.in_eager_mode():
+ raise RuntimeError("from_proto not supported in EAGER mode.")
+ return ResourceVariable(
+ variable_def=variable_def, import_scope=import_scope)
@staticmethod
def _OverloadAllOperators(): # pylint: disable=invalid-name
@@ -454,6 +509,7 @@ class ResourceVariable(variables.Variable):
def _run_op(a, *args):
# pylint: disable=protected-access
return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
+
# Propagate __doc__ to wrapper
try:
_run_op.__doc__ = getattr(ops.Tensor, operator).__doc__
@@ -468,61 +524,66 @@ class ResourceVariable(variables.Variable):
# TODO(apassos): this here and below is not atomic. Consider making it
# atomic if there's a way to do so without a performance cost for those who
# don't need it.
- with ops.control_dependencies(
- [gen_resource_variable_ops.assign_sub_variable_op(
+ with ops.control_dependencies([
+ gen_resource_variable_ops.assign_sub_variable_op(
self.handle,
- ops.convert_to_tensor(delta, dtype=self.dtype), name=name)]):
+ ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ ]):
return self.read_value()
def assign_add(self, delta, use_locking=None, name=None):
- with ops.control_dependencies(
- [gen_resource_variable_ops.assign_add_variable_op(
+ with ops.control_dependencies([
+ gen_resource_variable_ops.assign_add_variable_op(
self.handle,
- ops.convert_to_tensor(delta, dtype=self.dtype), name=name)]):
+ ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ ]):
return self.read_value()
def assign(self, value, use_locking=None, name=None):
- with ops.control_dependencies(
- [gen_resource_variable_ops.assign_variable_op(
+ with ops.control_dependencies([
+ gen_resource_variable_ops.assign_variable_op(
self.handle,
- ops.convert_to_tensor(value, dtype=self.dtype), name=name)]):
+ ops.convert_to_tensor(value, dtype=self.dtype),
+ name=name)
+ ]):
return self.read_value()
- def _strided_slice_assign(self,
- begin,
- end,
- strides,
- value,
- name,
- begin_mask,
- end_mask,
- ellipsis_mask,
- new_axis_mask,
+ def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
+ end_mask, ellipsis_mask, new_axis_mask,
shrink_axis_mask):
- with ops.control_dependencies([gen_array_ops.resource_strided_slice_assign(
- ref=self.handle,
- begin=begin,
- end=end,
- strides=strides,
- value=value,
- name=name,
- begin_mask=begin_mask,
- end_mask=end_mask,
- ellipsis_mask=ellipsis_mask,
- new_axis_mask=new_axis_mask,
- shrink_axis_mask=shrink_axis_mask)]):
+ with ops.control_dependencies([
+ gen_array_ops.resource_strided_slice_assign(
+ ref=self.handle,
+ begin=begin,
+ end=end,
+ strides=strides,
+ value=value,
+ name=name,
+ begin_mask=begin_mask,
+ end_mask=end_mask,
+ ellipsis_mask=ellipsis_mask,
+ new_axis_mask=new_axis_mask,
+ shrink_axis_mask=shrink_axis_mask)
+ ]):
+ return self.value()
+
+ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+ del name
+ if dtype is not None and dtype != self.value().dtype:
+ print("trying to switch the dtype to ", dtype, " from ",
+ self.value().dtype)
+ return NotImplemented
+ if as_ref:
+ return self.read_value().op.inputs[0]
+ else:
return self.value()
-# pylint: disable=unused-argument,protected-access
def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
- if dtype is not None and dtype != var.value().dtype:
- print("trying to switch the dtype to ", dtype, " from ", var.value().dtype)
- return NotImplemented
- if as_ref:
- return var.read_value().op.inputs[0]
- return var.value()
-# pylint: enable=unused-argument,protected-access
+ return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
+
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
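A brief sketch (not part of the diff) of what the per-instance conversion hook enables, assuming the public `tf.convert_to_tensor`:

```python
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="v")
# The registered conversion function delegates to v._dense_var_to_tensor;
# a mismatched dtype returns NotImplemented so other converters can be tried.
t = tf.convert_to_tensor(v)  # reads the variable's value
y = t * 2.0                  # usable like any other Tensor
```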
@@ -531,8 +592,7 @@ def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
# otherwise lead to the wrong behavior.
ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
ops.register_tensor_conversion_function(
- variables.Variable,
- variables.Variable._TensorConversionFunction) # pylint: disable=protected-access
+ variables.Variable, variables.Variable._TensorConversionFunction) # pylint: disable=protected-access
# pylint: disable=protected-access
ResourceVariable._OverloadAllOperators()
@@ -551,6 +611,9 @@ def _GatherGrad(op, grad):
# Build appropriately shaped IndexedSlices
# Walk graph back until the original handle is found.
# TODO(apassos): more robust way of getting the shape.
+ # TODO(apassos): implement this for EAGER mode.
+ if context.in_eager_mode():
+ raise NotImplementedError("_GatherGrad not implemented for EAGER mode")
handle = op.inputs[0]
while handle.op.type != "VarHandleOp":
handle = handle.op.inputs[0]
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 7c12020263..a706193a40 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -187,7 +188,11 @@ class Variable(object):
ValueError: If both `variable_def` and initial_value are specified.
ValueError: If the initial value is not specified, or does not have a
shape and `validate_shape` is `True`.
+ RuntimeError: If created in EAGER mode.
"""
+ if not context.in_graph_mode():
+ raise RuntimeError("Variable not supported in Eager mode. "
+ "Please use ResourceVariable instead")
if variable_def:
# If variable_def is provided, recreates the variable from its fields.
if initial_value:
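An illustrative sketch (hypothetical, not part of the diff) of the new guard, assuming the `context.eager_mode()` context manager from this change set:

```python
from tensorflow.python.eager import context
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables

with context.eager_mode():
  try:
    variables.Variable([1.0])  # ref-type variables are graph-only
  except RuntimeError:
    v = resource_variable_ops.ResourceVariable([1.0])  # supported alternative
```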
@@ -692,12 +697,15 @@ class Variable(object):
Raises:
ValueError: Session is not passed and no default session
"""
- session = session or ops.get_default_session()
- if session is None:
- raise ValueError(
- "Either session argument should be provided or default session "
- "should be established")
- session.run(self._initializer_op, {self._initializer_op.inputs[1]: value})
+ if context.in_graph_mode():
+ session = session or ops.get_default_session()
+ if session is None:
+ raise ValueError(
+ "Either session argument should be provided or default session "
+ "should be established")
+ session.run(self._initializer_op, {self._initializer_op.inputs[1]: value})
+ else:
+ self.assign(value)
# Conversion to tensor.
@staticmethod
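For completeness, a sketch (not part of the diff) of `Variable.load` in graph mode, with the eager branch noted in a comment; assumes the public TF 1.x API:

```python
import tensorflow as tf

v = tf.Variable([1.0, 2.0])
with tf.Session() as sess:
  sess.run(v.initializer)
  v.load([3.0, 4.0], session=sess)  # graph mode: feeds the initializer op
  print(sess.run(v))                # => [3.0, 4.0]
# Under eager execution the same call simply dispatches to v.assign(value).
```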
@@ -1057,7 +1065,10 @@ class PartitionedVariable(object):
`partitions` is not a list.
ValueError: If `variable_list` is empty, or the `Variable` shape
information does not match `shape`, or `partitions` has invalid values.
+ RuntimeError: If created in EAGER mode.
"""
+ if not context.in_graph_mode():
+ raise RuntimeError("PartitionedVariable not supported in Eager mode.")
if not isinstance(variable_list, (list, tuple)):
raise TypeError(
"variable_list is not a list or tuple: %s" % variable_list)
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index e93c3c9c38..1f91259e7c 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -25,6 +25,7 @@ limitations under the License.
%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_ContextAddFunctionDef;
%rename("%s") TFE_TensorHandleDim;
+%rename("%s") TFE_TensorHandleDeviceName;
%rename("%s") TFE_TensorHandleCopyToDevice;
%rename("%s") TFE_NewOp;
%rename("%s") TFE_Py_TensorHandleToNumpy;