aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/framework/BUILD1
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/framework/meta_graph_test.py14
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py45
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py82
-rw-r--r--tensorflow/python/training/checkpoint_utils.py9
-rw-r--r--tensorflow/python/training/checkpoint_utils_test.py26
-rw-r--r--tensorflow/python/training/saver.py10
-rw-r--r--tensorflow/python/training/saver_test.py18
9 files changed, 171 insertions, 35 deletions
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD
index 50868c6d6c..ac043fda06 100644
--- a/tensorflow/contrib/framework/BUILD
+++ b/tensorflow/contrib/framework/BUILD
@@ -62,6 +62,7 @@ tf_custom_op_py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:smart_cond",
"//tensorflow/python:sparse_tensor",
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index fbdf15a69f..cb54cebf0f 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3954,6 +3954,7 @@ py_test(
":partitioned_variables",
":platform",
":pywrap_tensorflow",
+ ":resource_variable_ops",
":state_ops",
":training",
":variable_scope",
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index 19dcd6a1b3..21963d0bee 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -905,20 +905,6 @@ class ExportImportAcrossScopesTest(test.TestCase):
with variable_scope.variable_scope("importA/keepA"):
graph_fn(use_resource=use_resource)
- if use_resource:
- # Bringing in collections that contain ResourceVariables will adds ops
- # to the graph the first time a variable is encountered, so mimic the
- # same behavior.
- seen_variables = set()
- for collection_key in sorted([
- ops.GraphKeys.GLOBAL_VARIABLES,
- ops.GraphKeys.TRAINABLE_VARIABLES,
- ]):
- for var in expected_graph.get_collection(collection_key):
- if var not in seen_variables:
- var._read_variable_op()
- seen_variables.add(var)
-
result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 8503f3e031..71699fe0ad 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -277,6 +277,20 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.evaluate(v.assign(2.0))
self.assertEqual(2.0, self.evaluate(v.value()))
+ # Tests for the 'read_value' argument:
+ assign_with_read = v.assign(3.0, read_value=True)
+ if context.in_graph_mode():
+ self.assertEqual(3.0, assign_with_read.eval())
+ else:
+ self.assertEqual(3.0, self.evaluate(assign_with_read))
+ assign_without_read = v.assign(4.0, read_value=False)
+ if context.in_graph_mode():
+ self.assertIsInstance(assign_without_read, ops.Operation)
+ else:
+ self.assertIsNone(assign_without_read)
+ self.evaluate(assign_without_read)
+ self.assertEqual(4.0, self.evaluate(v.value()))
+
@test_util.run_in_graph_and_eager_modes()
def testLoad(self):
v = resource_variable_ops.ResourceVariable(1.0, name="var0")
@@ -329,6 +343,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
w = resource_variable_ops.ResourceVariable.from_proto(v.to_proto())
self.assertEquals(2, math_ops.add(w, 1).eval())
+ self.assertEquals(v._handle, w._handle)
+ self.assertEquals(v._graph_element, w._graph_element)
+
@test_util.run_in_graph_and_eager_modes()
def testAssignAddMethod(self):
v = resource_variable_ops.ResourceVariable(1.0, name="var0")
@@ -336,6 +353,20 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.evaluate(v.assign_add(1.0))
self.assertEqual(2.0, self.evaluate(v.value()))
+ # Tests for the 'read_value' argument:
+ assign_with_read = v.assign_add(1.0, read_value=True)
+ if context.in_graph_mode():
+ self.assertEqual(3.0, assign_with_read.eval())
+ else:
+ self.assertEqual(3.0, self.evaluate(assign_with_read))
+ assign_without_read = v.assign_add(1.0, read_value=False)
+ if context.in_graph_mode():
+ self.assertIsInstance(assign_without_read, ops.Operation)
+ else:
+ self.assertIsNone(assign_without_read)
+ self.evaluate(assign_without_read)
+ self.assertEqual(4.0, self.evaluate(v.value()))
+
@test_util.run_in_graph_and_eager_modes()
def testAssignSubMethod(self):
v = resource_variable_ops.ResourceVariable(3.0, name="var0")
@@ -343,6 +374,20 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.evaluate(v.assign_sub(1.0))
self.assertEqual(2.0, self.evaluate(v.value()))
+ # Tests for the 'read_value' argument:
+ assign_with_read = v.assign_sub(1.0, read_value=True)
+ if context.in_graph_mode():
+ self.assertEqual(1.0, assign_with_read.eval())
+ else:
+ self.assertEqual(1.0, self.evaluate(assign_with_read))
+ assign_without_read = v.assign_sub(1.0, read_value=False)
+ if context.in_graph_mode():
+ self.assertIsInstance(assign_without_read, ops.Operation)
+ else:
+ self.assertIsNone(assign_without_read)
+ self.evaluate(assign_without_read)
+ self.assertEqual(0.0, self.evaluate(v.value()))
+
@test_util.run_in_graph_and_eager_modes()
def testDestroyResource(self):
v = resource_variable_ops.ResourceVariable(3.0, name="var0")
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 2d6d0672e0..bf186f1734 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -534,7 +534,8 @@ class ResourceVariable(variables.Variable):
self._save_slice_info = None
self._caching_device = None
self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
- self._graph_element = self.value()
+ self._graph_element = g.get_tensor_by_name(
+ self._handle.op.name + "/Read/ReadVariableOp:0")
self._constraint = None
def __nonzero__(self):
@@ -788,20 +789,52 @@ class ResourceVariable(variables.Variable):
__array_priority__ = 100
- def assign_sub(self, delta, use_locking=None, name=None):
+ def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
+ """Subtracts a value from this variable.
+
+ Args:
+ delta: A `Tensor`. The value to subtract from this variable.
+ use_locking: If `True`, use locking during the operation.
+ name: The name to use for the operation.
+ read_value: A `bool`. Whether to read and return the new value of the
+ variable or not.
+
+ Returns:
+ If `read_value` is `True`, this method will return the new value of the
+ variable after the assignment has completed. Otherwise, when in graph mode
+ it will return the `Operation` that does the assignment, and when in eager
+ mode it will return `None`.
+ """
# TODO(apassos): this here and below is not atomic. Consider making it
# atomic if there's a way to do so without a performance cost for those who
# don't need it.
- return self._lazy_read(gen_resource_variable_ops.assign_sub_variable_op(
- self.handle,
- ops.convert_to_tensor(delta, dtype=self.dtype),
- name=name))
+ assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
+ self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), name=name)
+ if read_value:
+ return self._lazy_read(assign_sub_op)
+ return assign_sub_op
+
+ def assign_add(self, delta, use_locking=None, name=None, read_value=True):
+ """Adds a value to this variable.
+
+ Args:
+ delta: A `Tensor`. The value to add to this variable.
+ use_locking: If `True`, use locking during the operation.
+ name: The name to use for the operation.
+ read_value: A `bool`. Whether to read and return the new value of the
+ variable or not.
- def assign_add(self, delta, use_locking=None, name=None):
- return self._lazy_read(gen_resource_variable_ops.assign_add_variable_op(
- self.handle,
- ops.convert_to_tensor(delta, dtype=self.dtype),
- name=name))
+ Returns:
+ If `read_value` is `True`, this method will return the new value of the
+ variable after the assignment has completed. Otherwise, when in graph mode
+ it will return the `Operation` that does the assignment, and when in eager
+ mode it will return `None`.
+ """
+ assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
+ self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), name=name)
+ if read_value:
+ return self._lazy_read(assign_add_op)
+ return assign_add_op
def _lazy_read(self, op):
if hasattr(self, "_trainable") and self._trainable:
@@ -811,14 +844,29 @@ class ResourceVariable(variables.Variable):
self._in_graph_mode,
self._handle_deleter if not self._in_graph_mode else None, op)
- def assign(self, value, use_locking=None, name=None):
+ def assign(self, value, use_locking=None, name=None, read_value=True):
+ """Assigns a new value to this variable.
+
+ Args:
+ value: A `Tensor`. The new value for this variable.
+ use_locking: If `True`, use locking during the assignment.
+ name: The name to use for the assignment.
+ read_value: A `bool`. Whether to read and return the new value of the
+ variable or not.
+
+ Returns:
+ If `read_value` is `True`, this method will return the new value of the
+ variable after the assignment has completed. Otherwise, when in graph mode
+ it will return the `Operation` that does the assignment, and when in eager
+ mode it will return `None`.
+ """
value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
self._shape.assert_is_compatible_with(value_tensor.shape)
- return self._lazy_read(
- gen_resource_variable_ops.assign_variable_op(
- self.handle,
- value_tensor,
- name=name))
+ assign_op = gen_resource_variable_ops.assign_variable_op(
+ self.handle, value_tensor, name=name)
+ if read_value:
+ return self._lazy_read(assign_op)
+ return assign_op
def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
end_mask, ellipsis_mask, new_axis_mask,
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 0af1cdecfa..52d092bc22 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -23,6 +23,7 @@ import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
@@ -289,10 +290,14 @@ def _set_checkpoint_initializer(variable,
name: Name of the operation.
"""
base_type = variable.dtype.base_dtype
- with ops.colocate_with(variable):
+ with ops.colocate_with(variable.op):
restore_op = io_ops.restore_v2(
ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
- variable._initializer_op = state_ops.assign(variable, restore_op) # pylint:disable=protected-access
+ if isinstance(variable, resource_variable_ops.ResourceVariable):
+ init_op = variable.assign(restore_op, read_value=False)
+ else:
+ init_op = state_ops.assign(variable, restore_op)
+ variable._initializer_op = init_op # pylint:disable=protected-access
restore_op.set_shape(variable.shape)
variable._initial_value = restore_op # pylint:disable=protected-access
diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py
index a461b24cbb..640bd665cb 100644
--- a/tensorflow/python/training/checkpoint_utils_test.py
+++ b/tensorflow/python/training/checkpoint_utils_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -362,6 +363,31 @@ class CheckpointsTest(test.TestCase):
checkpoint_utils.init_from_checkpoint(checkpoint_dir,
{"useful_scope": "some_scope/"})
+ def testNoAdditionalReadOpsForResourceVariables(self):
+ checkpoint_dir = self.get_temp_dir()
+ with self.test_session() as session:
+ v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)
+
+ # New graph and session.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as session:
+ my1 = resource_variable_ops.ResourceVariable([[0.0] * 10], name="my1")
+
+ with ops.name_scope("init_from_checkpoint"):
+ checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1})
+
+ # Basic sanity checks:
+ session.run(variables.global_variables_initializer())
+ self.assertAllEqual(session.run(my1), v1)
+
+ ops_in_init_from_checkpoint_scope = [
+ op for op in g.get_operations()
+ if (op.name.startswith("init_from_checkpoint/") and
+ not op.name.startswith("init_from_checkpoint/checkpoint_initializer"
+ ) and op.type != "AssignVariableOp")
+ ]
+ self.assertEqual(ops_in_init_from_checkpoint_scope, [])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index e8ea5abfbd..6c80562968 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -584,7 +584,10 @@ class BaseSaverBuilder(object):
else:
if context.in_graph_mode():
if convert_variable_to_tensor:
- var = ops.internal_convert_to_tensor(var, as_ref=True)
+ if isinstance(var, resource_variable_ops.ResourceVariable):
+ var = var._graph_element # pylint: disable=protected-access
+ else:
+ var = ops.internal_convert_to_tensor(var, as_ref=True)
if not BaseSaverBuilder._IsVariable(var):
raise TypeError("Variable to save is not a Variable: %s" % var)
if var.op.type == "ReadVariableOp":
@@ -674,7 +677,10 @@ class BaseSaverBuilder(object):
"mode is enabled, type: %s." % type(op))
saveable = BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
else:
- variable = ops.internal_convert_to_tensor(op, as_ref=True)
+ if isinstance(op, resource_variable_ops.ResourceVariable):
+ variable = op._graph_element # pylint: disable=protected-access
+ else:
+ variable = ops.internal_convert_to_tensor(op, as_ref=True)
if not BaseSaverBuilder._IsVariable(variable):
raise TypeError("names_to_saveables must be a dict mapping string "
"names to Tensors/Variables. Not a variable: %s" %
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index b758ceaab0..7947765449 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -262,6 +262,24 @@ class SaverTest(test.TestCase):
save2.restore(sess, save_path)
self.assertEquals(self.evaluate(v), [1])
+ def testNoAdditionalOpsAddedBySaverForResourceVariablesOutsideSaveScope(self):
+ with ops_lib.Graph().as_default() as g:
+ v = resource_variable_ops.ResourceVariable(1.0, name="v")
+ with ops_lib.name_scope("saver1"):
+ saver_module.Saver()
+ with ops_lib.name_scope("saver2"):
+ saver_module.Saver({"name": v})
+ ops_in_saver1_scope_but_not_save_scope = [
+ op for op in g.get_operations()
+ if (op.name.startswith("saver1/") and
+ not op.name.startswith("saver1/save/"))]
+ self.assertEqual(ops_in_saver1_scope_but_not_save_scope, [])
+ ops_in_saver2_scope_but_not_save_scope = [
+ op for op in g.get_operations()
+ if (op.name.startswith("saver2/") and
+ not op.name.startswith("saver2/save/"))]
+ self.assertEqual(ops_in_saver2_scope_but_not_save_scope, [])
+
def testSaveCopyRestoreWithSaveRelativePaths(self):
"""Save, copy checkpoint dir and restore from copied dir.