author     A. Unique TensorFlower <gardener@tensorflow.org>  2016-11-23 11:40:55 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2016-11-23 11:47:13 -0800
commit     aa4aadef0492005f60c03aa40789efad719752f0 (patch)
tree       2e09b1c91ac3b05a699be4138d59988b2ae50478
parent     18af08feafad32b44fd9f1a20e143b716c82f21b (diff)
Makes resource variables saveable/restorable.
Change: 140055398
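
As a rough usage sketch of what this change enables (mirroring the new basicSaveRestore test added to saver_test.py below; the checkpoint path is illustrative, not part of the commit):

    import tensorflow as tf
    from tensorflow.python.ops import resource_variable_ops

    # Save a resource variable with an ordinary tf.train.Saver.
    with tf.Graph().as_default(), tf.Session() as sess:
      v0 = resource_variable_ops.ResourceVariable(10.0, name="v0")
      sess.run(v0.initializer)
      saver = tf.train.Saver({"v0": v0})
      path = saver.save(sess, "/tmp/resource_ckpt")

    # Restore it into a fresh graph without running its initializer.
    with tf.Graph().as_default(), tf.Session() as sess:
      v0 = resource_variable_ops.ResourceVariable(-1.0, name="v0")
      saver = tf.train.Saver({"v0": v0})
      saver.restore(sess, path)
      print(sess.run(v0.value))  # 10.0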
-rw-r--r--  tensorflow/core/framework/resource_mgr.h                      |  13
-rw-r--r--  tensorflow/core/kernels/resource_variable_ops.cc              | 100
-rw-r--r--  tensorflow/core/ops/resource_variable_ops.cc                  |  64
-rw-r--r--  tensorflow/python/BUILD                                       |   1
-rw-r--r--  tensorflow/python/kernel_tests/resource_variable_ops_test.py  |  16
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py                |  47
-rw-r--r--  tensorflow/python/ops/state_ops.py                            |  22
-rw-r--r--  tensorflow/python/training/saver.py                           |  35
-rw-r--r--  tensorflow/python/training/saver_test.py                      |  37
9 files changed, 182 insertions(+), 153 deletions(-)
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index 2e22f3cdb3..ae4186ee71 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -177,6 +177,11 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
+// Looks up or creates a resource.
+template <typename T>
+Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
+ T** value, std::function<Status(T**)> creator);
+
// Destroys a resource pointed by a given resource handle.
template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
@@ -414,6 +419,14 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
}
template <typename T>
+Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
+ T** value, std::function<Status(T**)> creator) {
+ TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
+ return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value,
+ creator);
+}
+
+template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
return ctx->resource_manager()->Delete<T>(p.container(), p.name());
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 0c4f2d2f84..602457de10 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -35,74 +35,6 @@ namespace tensorflow {
REGISTER_RESOURCE_HANDLE_KERNEL(Var);
template <typename Device, typename T>
-class CreateVariableOp : public OpKernel {
- public:
- CreateVariableOp(OpKernelConstruction* c) : OpKernel(c) {
- OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
- OP_REQUIRES(c, DataTypeToEnum<T>::value == dtype_,
- errors::InvalidArgument(
- "Dtypes don't match; expected ", DataTypeString(dtype_),
- " got ", DataTypeString(DataTypeToEnum<T>::value)));
- }
-
- void Compute(OpKernelContext* context) override {
- Var* var = new Var(dtype_);
- AllocatorAttributes attr;
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
- PersistentTensor copy;
- const Tensor& value = context->input(1);
-
- // TODO(apassos): allocating and copying is unnecessary if we are the last
- // user of the value tensor. This should essentially always be the case, yet
- // the refcount is usually 2 instead of 1. Figure out what needs to change
- // in the code to make this not be the case, so we can safely take
- // ownership.
- Tensor* tmp_copy = nullptr;
- OP_REQUIRES_OK(context, context->allocate_persistent(
- dtype_, value.shape(), &copy, &tmp_copy, attr));
- *var->tensor() = *tmp_copy;
- functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
- copy_functor(context->eigen_device<Device>(), var->tensor()->flat<T>(),
- value.flat<T>());
- Status s = CreateResource<Var>(context, HandleFromInput(context, 0), var);
- OP_REQUIRES(context, s.ok() || errors::IsAlreadyExists(s), s);
- }
-
- private:
- DataType dtype_;
-};
-
-// TODO(apassos) register for the GPU as well.
-#define REGISTER_KERNELS(type) \
- REGISTER_KERNEL_BUILDER(Name("CreateVariableOp") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("dtype"), \
- CreateVariableOp<Eigen::ThreadPoolDevice, type>);
-
-TF_CALL_ALL_TYPES(REGISTER_KERNELS);
-TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
-#undef REGISTER_KERNELS
-
-#if GOOGLE_CUDA
-#define REGISTER_GPU_KERNELS(type) \
- namespace functor { \
- template <> \
- void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
- const GPUDevice& d, typename TTypes<type>::Flat lhs, \
- typename TTypes<type>::ConstFlat rhs); \
- extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
- } \
- REGISTER_KERNEL_BUILDER(Name("CreateVariableOp") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<type>("dtype"), \
- CreateVariableOp<GPUDevice, type>);
-
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
-#undef REGISTER_GPU_KERNELS
-#endif // GOOGLE_CUDA
-
-template <typename Device, typename T>
class ReadVariableOp : public OpKernel {
public:
ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
@@ -137,6 +69,13 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
#if GOOGLE_CUDA
#define REGISTER_GPU_KERNELS(type) \
+ namespace functor { \
+ template <> \
+ void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
+ const GPUDevice& d, typename TTypes<type>::Flat lhs, \
+ typename TTypes<type>::ConstFlat rhs); \
+ extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
+ } \
REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE_GPU).TypeConstraint<type>("dtype"), \
ReadVariableOp<GPUDevice, type>);
@@ -148,12 +87,28 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
template <typename Device, typename T>
class AssignVariableOp : public OpKernel {
public:
- AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
+ AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
+ }
void Compute(OpKernelContext* context) override {
Var* variable = nullptr;
- OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
- &variable));
+ OP_REQUIRES_OK(
+ context,
+ LookupOrCreateResource<Var>(
+ context, HandleFromInput(context, 0), &variable,
+ [this, context](Var** ptr) {
+ *ptr = new Var(dtype_);
+ PersistentTensor unused;
+ Tensor* tmp;
+ AllocatorAttributes attr;
+ attr.set_gpu_compatible(true);
+ attr.set_nic_compatible(true);
+ TF_RETURN_IF_ERROR(context->allocate_persistent(
+ dtype_, context->input(1).shape(), &unused, &tmp, attr));
+ *(*ptr)->tensor() = *tmp;
+ return Status::OK();
+ }));
core::ScopedUnref s(variable);
// TODO(apassos): holding a lock and copying is unnecessary if we are the
@@ -167,6 +122,9 @@ class AssignVariableOp : public OpKernel {
copy_functor(context->eigen_device<Device>(), variable->tensor()->flat<T>(),
value.flat<T>());
}
+
+ private:
+ DataType dtype_;
};
// TODO(apassos) register for the GPU as well.
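
With CreateVariableOp gone, the first AssignVariableOp on a handle now creates the variable's backing tensor through LookupOrCreateResource, and later assigns simply overwrite it. A minimal sketch using the generated op wrappers, as the updated tests below do:

    import tensorflow as tf
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import resource_variable_ops

    with tf.Session():
      handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
      # The first assign both allocates and initializes the variable...
      resource_variable_ops.assign_variable_op(
          handle, constant_op.constant(1, dtype=dtypes.int32)).run()
      # ...subsequent assigns reuse the existing resource.
      resource_variable_ops.assign_variable_op(
          handle, constant_op.constant(2, dtype=dtypes.int32)).run()
      value = resource_variable_ops.read_variable_op(
          handle, dtype=dtypes.int32).eval()  # 2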
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 205f888f6f..4b02790f7a 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -52,36 +52,6 @@ dtype: the type of this variable. Must agree with the dtypes
shape: The (possibly partially specified) shape of this variable.
)");
-Status CreateAssignShapeFn(InferenceContext* c) {
- DataType handle_dtype = c->input_handle_dtype(0);
- DataType value_dtype;
- c->GetAttr("dtype", &value_dtype);
- if (handle_dtype != value_dtype) {
- return errors::InvalidArgument(
- "Trying to initialize handle for variable with wrong dtype. "
- "Expected ",
- handle_dtype, " got ", value_dtype);
- }
- ShapeHandle s = c->input_handle_shape(0);
- ShapeHandle value_shape = c->input(1);
- ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused));
- return Status::OK();
-}
-
-REGISTER_OP("CreateVariableOp")
- .Input("resource: resource")
- .Input("value: dtype")
- .Attr("dtype: type")
- .SetShapeFn(CreateAssignShapeFn)
- .Doc(R"(
-Creates a variable resource.
-
-resource: handle to the resource in which to store the variable.
-value: the value to set the new tensor to use.
-dtype: the dtype of the value.
-)");
-
REGISTER_OP("ReadVariableOp")
.Input("resource: resource")
.Output("value: dtype")
@@ -113,6 +83,23 @@ resource: handle to the resource in which to store the variable.
dtype: the dtype of the value.
)");
+Status CreateAssignShapeFn(InferenceContext* c) {
+ DataType handle_dtype = c->input_handle_dtype(0);
+ DataType value_dtype;
+ c->GetAttr("dtype", &value_dtype);
+ if (handle_dtype != value_dtype) {
+ return errors::InvalidArgument(
+ "Trying to initialize handle for variable with wrong dtype. "
+ "Expected ",
+ handle_dtype, " got ", value_dtype);
+ }
+ ShapeHandle s = c->input_handle_shape(0);
+ ShapeHandle value_shape = c->input(1);
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused));
+ return Status::OK();
+}
+
REGISTER_OP("AssignVariableOp")
.Input("resource: resource")
.Input("value: dtype")
@@ -133,22 +120,7 @@ REGISTER_OP("AssignAddVariableOp")
.Input("resource: resource")
.Input("value: dtype")
.Attr("dtype: type")
- .SetShapeFn([](InferenceContext* c) {
- DataType handle_dtype = c->input_handle_dtype(0);
- DataType value_dtype;
- c->GetAttr("dtype", &value_dtype);
- if (handle_dtype != value_dtype) {
- return errors::InvalidArgument(
- "Trying to initialize handle for variable with wrong dtype. "
- "Expected ",
- handle_dtype, " got ", value_dtype);
- }
- ShapeHandle s = c->input_handle_shape(0);
- ShapeHandle value_shape = c->input(1);
- ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused));
- return Status::OK();
- })
+ .SetShapeFn(CreateAssignShapeFn)
.Doc(R"(
Adds a value to the current value of a variable.
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 0626895a35..9d0a1bd26e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1343,6 +1343,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":framework",
+ ":resource_variable_ops_gen",
":state_ops_gen",
],
)
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index f8bf46be9f..b426719912 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -32,25 +32,25 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
with self.assertRaises(ValueError):
- resource_variable_ops.create_variable_op(
+ resource_variable_ops.assign_variable_op(
handle, constant_op.constant(0.0, dtype=dtypes.float32)).run()
with self.assertRaises(ValueError):
- resource_variable_ops.create_variable_op(
+ resource_variable_ops.assign_variable_op(
handle, constant_op.constant([0], dtype=dtypes.int32)).run()
- resource_variable_ops.create_variable_op(
+ resource_variable_ops.assign_variable_op(
handle, constant_op.constant(0, dtype=dtypes.int32)).run()
def testDtypeSurvivesIdentity(self):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
id_handle = array_ops.identity(handle)
- resource_variable_ops.create_variable_op(
+ resource_variable_ops.assign_variable_op(
id_handle, constant_op.constant(0, dtype=dtypes.int32)).run()
def testCreateRead(self):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
- resource_variable_ops.create_variable_op(
+ resource_variable_ops.assign_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32)).run()
value = resource_variable_ops.read_variable_op(
handle, dtype=dtypes.int32).eval()
@@ -59,7 +59,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
def testManyAssigns(self):
with self.test_session() as session:
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
- create = resource_variable_ops.create_variable_op(
+ create = resource_variable_ops.assign_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32))
with ops.control_dependencies([create]):
first_read = resource_variable_ops.read_variable_op(
@@ -77,7 +77,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
def testAssignAdd(self):
with self.test_session():
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
- resource_variable_ops.create_variable_op(
+ resource_variable_ops.assign_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32)).run()
resource_variable_ops.assign_add_variable_op(
handle, constant_op.constant(1, dtype=dtypes.int32)).run()
@@ -88,7 +88,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.test_session():
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
- resource_variable_ops.create_variable_op(
+ resource_variable_ops.assign_variable_op(
handle, constant_op.constant([[1]], dtype=dtypes.int32)).run()
resource_variable_ops.resource_scatter_add(
handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)).run()
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 8962fe7e4a..bdb777fddf 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -33,11 +33,10 @@ from tensorflow.python.ops.gen_resource_variable_ops import *
def _register_variable_read(read, collections, trainable):
"""Helper function to put a read from a variable in the collections."""
if collections is None:
- collections = []
- if (trainable and
- ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES not in collections):
- collections = (list(collections) +
- [ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES])
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ if (trainable and ops.GraphKeys.TRAINABLE_VARIABLES
+ not in collections):
+ collections = (list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES])
ops.add_to_collections(collections, read)
@@ -49,19 +48,23 @@ class ResourceVariable(object):
"""
+ # pylint: disable=unused-argument
def __init__(self,
initial_value=None,
name=None,
+ caching_device=None,
trainable=True,
collections=None,
dtype=None,
shape=None):
+
"""Creates a variable.
Args:
initial_value: A `Tensor` or Python object convertible to a `Tensor`
representing the initial value of this variable.
name: The name of this variable. Automatically uniquified.
+      caching_device: the device on which the variable's value is read by default.
trainable: Whether the global read of this variable will be used for
training.
collections: Additional collections to which the `read` operation for
@@ -73,6 +76,8 @@ class ResourceVariable(object):
value but shape inference is desired.
"""
if initial_value is not None:
+ if callable(initial_value):
+ initial_value = initial_value()
initial_value = ops.convert_to_tensor(initial_value)
if dtype is None:
assert initial_value is not None, ("Trying to create a resource variable "
@@ -101,15 +106,22 @@ class ResourceVariable(object):
gen_resource_variable_ops.var_is_initialized_op(self._handle))
if initial_value is not None:
with ops.name_scope("Create"):
- self._initialize_op = gen_resource_variable_ops.create_variable_op(
+ self._initialize_op = gen_resource_variable_ops.assign_variable_op(
self._handle, initial_value)
resources.register_resource(self._handle,
self._initialize_op,
self._is_initialized_op)
with ops.name_scope("Read"):
- self._value = gen_resource_variable_ops.read_variable_op(
- self._handle, dtype=self._dtype)
+ if caching_device is not None:
+ with ops.device(caching_device):
+ self._value = gen_resource_variable_ops.read_variable_op(
+ self._handle, dtype=self._dtype)
+ else:
+ self._value = gen_resource_variable_ops.read_variable_op(
+ self._handle, dtype=self._dtype)
+ # TODO(apassos) this is terrible
+ self._value.initializer = self._initialize_op
_register_variable_read(
self._value, trainable=trainable, collections=collections)
@@ -119,6 +131,15 @@ class ResourceVariable(object):
return self._dtype
@property
+ def name(self):
+ """The name of the handle for this variable."""
+ return self._handle.name
+
+ def get_shape(self):
+ """The shape of this variable."""
+ return self._value.get_shape()
+
+ @property
def create(self):
"""The op responsible for initializing this variable."""
return self._initialize_op
@@ -133,6 +154,15 @@ class ResourceVariable(object):
"""A cached operation which reads the value of this variable."""
return self._value
+ def _as_graph_element(self):
+ """Conversion function for Graph.as_graph_element()."""
+ return self._value
+
+ @property
+ def initializer(self):
+ """The op responsible for initializing this variable."""
+ return self._initialize_op
+
@property
def op(self):
"""The op which reads the value of this variable."""
@@ -162,6 +192,7 @@ class ResourceVariable(object):
return value
def sparse_read(self, indices, collections=None, trainable=True, name=None):
+ """Reads the value of this variable sparsely, using `gather`."""
with ops.name_scope("Gather" if name is None else name):
value = gen_resource_variable_ops.resource_gather(
self._handle, indices, dtype=self._dtype)
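
A hedged sketch of the two constructor additions above: a callable initial_value is invoked once at construction time, and caching_device pins the default read to a device (the "/cpu:0" string here is illustrative):

    from tensorflow.python.ops import resource_variable_ops

    v = resource_variable_ops.ResourceVariable(
        initial_value=lambda: [[1.0, 2.0]],  # called, then converted to a Tensor
        caching_device="/cpu:0",             # the cached read lives here
        name="weights")
    print(v.name)         # name of the underlying handle
    print(v.get_shape())  # shape of the cached read, (1, 2)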
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index f2a201a609..4858453b52 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -127,6 +127,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
@@ -201,3 +202,24 @@ def init_variable(v, init, name="init"):
else:
init = ops.convert_to_tensor(init, name="init")
return gen_state_ops.assign(v, init, name=scope)
+
+
+def is_variable_initialized(ref, name=None):
+ """Checks whether a tensor has been initialized.
+
+ Outputs boolean scalar indicating whether the tensor has been initialized.
+
+ Args:
+ ref: A mutable `Tensor`.
+ Should be from a `Variable` node. May be uninitialized.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `bool`.
+ """
+ if ref.dtype._is_ref_dtype:
+ return gen_state_ops.is_variable_initialized(ref=ref, name=name)
+ # Handle resource variables.
+ if ref.op.type == "ReadVariableOp":
+ return gen_resource_variable_ops.var_is_initialized_op(ref.op.inputs[0],
+ name=name)
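
The new helper dispatches on what it is handed: ref-typed variables go to the existing IsVariableInitialized state op, while a resource variable's read tensor is traced back to its handle for VarIsInitializedOp. A small sketch (note the resource case expects the output of a ReadVariableOp, e.g. the variable's cached read):

    import tensorflow as tf
    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.ops import state_ops

    ref_var = tf.Variable(1.0)                             # ref-dtype variable
    res_var = resource_variable_ops.ResourceVariable(1.0)  # resource-backed

    is_init_ref = state_ops.is_variable_initialized(ref_var)
    is_init_res = state_ops.is_variable_initialized(res_var.value)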
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index f020c8cd0f..94a3813699 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
@@ -54,6 +55,13 @@ from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
+# Op names which identify variable reads which should be saved.
+_VARIABLE_OPS = set(["Variable",
+ "AutoReloadVariable",
+ "ReadVariableOp",
+ "ResourceGather"])
+
+
class BaseSaverBuilder(object):
"""Base class for Savers.
@@ -129,6 +137,23 @@ class BaseSaverBuilder(object):
validate_shape=restored_shapes is None and
self.op.get_shape().is_fully_defined())
+ class ResourceVariableSaveable(SaveableObject):
+ """SaveableObject implementation that handles ResourceVariables."""
+
+ def __init__(self, var, slice_spec, name):
+ self.read_op = var
+ spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name)
+ super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__(
+ var, [spec], name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ restored_tensor = restored_tensors[0]
+ if restored_shapes is not None:
+ restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
+ return resource_variable_ops.assign_variable_op(
+ self.read_op.op.inputs[0],
+ restored_tensor)
+
def __init__(self, write_version=saver_pb2.SaverDef.V2):
self._write_version = write_version
@@ -406,8 +431,7 @@ class BaseSaverBuilder(object):
@staticmethod
def _IsVariable(v):
- return isinstance(v, ops.Tensor) and (v.op.type == "Variable" or
- v.op.type == "AutoReloadVariable")
+ return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS
def _GroupByDevices(self, saveables):
"""Group Variable tensor slices per device.
@@ -537,7 +561,12 @@ class BaseSaverBuilder(object):
raise TypeError("names_to_saveables must be a dict mapping string "
"names to Tensors/Variables. Not a variable: %s" %
variable)
- saveable = BaseSaverBuilder.VariableSaveable(variable, "", name)
+ if variable.op.type == "Variable":
+ saveable = BaseSaverBuilder.VariableSaveable(variable, "", name)
+ else:
+ # TODO(apassos): this assumes all non-variables are ResourceVariables.
+ saveable = BaseSaverBuilder.ResourceVariableSaveable(
+ variable, "", name)
self._AddSaveable(saveables, seen_ops, saveable)
return saveables
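
For reference, the restore path for resource variables works backwards from the saved read tensor to its handle (input 0 of the ReadVariableOp) and issues an assign. A hypothetical standalone helper mirroring ResourceVariableSaveable.restore, not part of this commit's API:

    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import resource_variable_ops

    def restore_resource_variable(read_tensor, restored_tensor, shape=None):
      # Hypothetical helper; mirrors ResourceVariableSaveable.restore.
      if shape is not None:
        restored_tensor = array_ops.reshape(restored_tensor, shape)
      # The variable handle is the first input of the ReadVariableOp that
      # produced the saved tensor.
      return resource_variable_ops.assign_variable_op(
          read_tensor.op.inputs[0], restored_tensor)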
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index d27975b9d1..464446bb82 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -18,17 +18,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import contextlib
import math
import os.path
-import time
-import contextlib
import random
import shutil
import tempfile
+import time
-import tensorflow as tf
import numpy as np
import six
+import tensorflow as tf
from google.protobuf.any_pb2 import Any
@@ -38,8 +38,9 @@ from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import meta_graph
-from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_module
from tensorflow.python.util import compat
@@ -110,13 +111,13 @@ class CheckpointedOp(object):
class SaverTest(tf.test.TestCase):
- def testBasics(self):
+ def basicSaveRestore(self, variable_op):
save_path = os.path.join(self.get_temp_dir(), "basics")
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = tf.Variable(10.0, name="v0")
- v1 = tf.Variable(20.0, name="v1")
+ v0 = variable_op(10.0, name="v0")
+ v1 = variable_op(20.0, name="v1")
v2 = CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = tf.train.Saver(
@@ -143,17 +144,13 @@ class SaverTest(tf.test.TestCase):
# Start a second session. In that session the parameter nodes
# have not been initialized either.
with self.test_session() as sess:
- v0 = tf.Variable(-1.0, name="v0")
- v1 = tf.Variable(-1.0, name="v1")
+ v0 = variable_op(-1.0, name="v0")
+ v1 = variable_op(-1.0, name="v1")
v2 = CheckpointedOp(name="v2")
save = tf.train.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
- with self.assertRaisesWithPredicateMatch(
- tf.OpError, lambda e: "uninitialized value v0" in e.message):
- sess.run(v0)
- with self.assertRaisesWithPredicateMatch(
- tf.OpError, lambda e: "uninitialized value v1" in e.message):
- sess.run(v1)
+ # Assert that the variables are not initialized.
+ self.assertEqual(len(tf.report_uninitialized_variables().eval()), 2)
self.assertEqual(0, len(v2.keys().eval()))
self.assertEqual(0, len(v2.values().eval()))
@@ -168,8 +165,8 @@ class SaverTest(tf.test.TestCase):
# Build another graph with 2 nodes, initialized
# differently, and a Restore node for them.
with self.test_session() as sess:
- v0_2 = tf.Variable(1000.0, name="v0")
- v1_2 = tf.Variable(2000.0, name="v1")
+ v0_2 = variable_op(1000.0, name="v0")
+ v1_2 = variable_op(2000.0, name="v1")
v2_2 = CheckpointedOp(name="v2")
save2 = tf.train.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})
v2_2.insert("k1000", 3000.0).run()
@@ -188,6 +185,12 @@ class SaverTest(tf.test.TestCase):
self.assertEqual(b"k1", v2_2.keys().eval())
self.assertEqual(30.0, v2_2.values().eval())
+ def testBasic(self):
+ self.basicSaveRestore(tf.Variable)
+
+ def testResourceBasic(self):
+ self.basicSaveRestore(resource_variable_ops.ResourceVariable)
+
def testInvalidPath(self):
v0 = tf.Variable(0, name="v0")
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):