diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-11-23 11:40:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-23 11:47:13 -0800 |
commit | aa4aadef0492005f60c03aa40789efad719752f0 (patch) | |
tree | 2e09b1c91ac3b05a699be4138d59988b2ae50478 | |
parent | 18af08feafad32b44fd9f1a20e143b716c82f21b (diff) |
Makes resource variables saveable/restorable.
Change: 140055398
-rw-r--r-- | tensorflow/core/framework/resource_mgr.h | 13 | ||||
-rw-r--r-- | tensorflow/core/kernels/resource_variable_ops.cc | 100 | ||||
-rw-r--r-- | tensorflow/core/ops/resource_variable_ops.cc | 64 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/resource_variable_ops_test.py | 16 | ||||
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 47 | ||||
-rw-r--r-- | tensorflow/python/ops/state_ops.py | 22 | ||||
-rw-r--r-- | tensorflow/python/training/saver.py | 35 | ||||
-rw-r--r-- | tensorflow/python/training/saver_test.py | 37 |
9 files changed, 182 insertions, 153 deletions
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index 2e22f3cdb3..ae4186ee71 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -177,6 +177,11 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value); template <typename T> Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value); +// Looks up or creates a resource. +template <typename T> +Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, + T** value, std::function<Status(T**)> creator); + // Destroys a resource pointed by a given resource handle. template <typename T> Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); @@ -414,6 +419,14 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, } template <typename T> +Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, + T** value, std::function<Status(T**)> creator) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); + return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value, + creator); +} + +template <typename T> Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); return ctx->resource_manager()->Delete<T>(p.container(), p.name()); diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 0c4f2d2f84..602457de10 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -35,74 +35,6 @@ namespace tensorflow { REGISTER_RESOURCE_HANDLE_KERNEL(Var); template <typename Device, typename T> -class CreateVariableOp : public OpKernel { - public: - CreateVariableOp(OpKernelConstruction* c) : OpKernel(c) { - OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); - OP_REQUIRES(c, DataTypeToEnum<T>::value == dtype_, 
- errors::InvalidArgument( - "Dtypes don't match; expected ", DataTypeString(dtype_), - " got ", DataTypeString(DataTypeToEnum<T>::value))); - } - - void Compute(OpKernelContext* context) override { - Var* var = new Var(dtype_); - AllocatorAttributes attr; - attr.set_gpu_compatible(true); - attr.set_nic_compatible(true); - PersistentTensor copy; - const Tensor& value = context->input(1); - - // TODO(apassos): allocating and copying is unnecessary if we are the last - // user of the value tensor. This should essentially always be the case, yet - // the refcount is usually 2 instead of 1. Figure out what needs to change - // in the code to make this not be the case, so we can safely take - // ownership. - Tensor* tmp_copy = nullptr; - OP_REQUIRES_OK(context, context->allocate_persistent( - dtype_, value.shape(), &copy, &tmp_copy, attr)); - *var->tensor() = *tmp_copy; - functor::DenseUpdate<Device, T, ASSIGN> copy_functor; - copy_functor(context->eigen_device<Device>(), var->tensor()->flat<T>(), - value.flat<T>()); - Status s = CreateResource<Var>(context, HandleFromInput(context, 0), var); - OP_REQUIRES(context, s.ok() || errors::IsAlreadyExists(s), s); - } - - private: - DataType dtype_; -}; - -// TODO(apassos) register for the GPU as well. 
-#define REGISTER_KERNELS(type) \ - REGISTER_KERNEL_BUILDER(Name("CreateVariableOp") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("dtype"), \ - CreateVariableOp<Eigen::ThreadPoolDevice, type>); - -TF_CALL_ALL_TYPES(REGISTER_KERNELS); -TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); -#undef REGISTER_KERNELS - -#if GOOGLE_CUDA -#define REGISTER_GPU_KERNELS(type) \ - namespace functor { \ - template <> \ - void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \ - const GPUDevice& d, typename TTypes<type>::Flat lhs, \ - typename TTypes<type>::ConstFlat rhs); \ - extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \ - } \ - REGISTER_KERNEL_BUILDER(Name("CreateVariableOp") \ - .Device(DEVICE_GPU) \ - .TypeConstraint<type>("dtype"), \ - CreateVariableOp<GPUDevice, type>); - -TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); -#undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA - -template <typename Device, typename T> class ReadVariableOp : public OpKernel { public: ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {} @@ -137,6 +69,13 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); #if GOOGLE_CUDA #define REGISTER_GPU_KERNELS(type) \ + namespace functor { \ + template <> \ + void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \ + const GPUDevice& d, typename TTypes<type>::Flat lhs, \ + typename TTypes<type>::ConstFlat rhs); \ + extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \ + } \ REGISTER_KERNEL_BUILDER( \ Name("ReadVariableOp").Device(DEVICE_GPU).TypeConstraint<type>("dtype"), \ ReadVariableOp<GPUDevice, type>); @@ -148,12 +87,28 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); template <typename Device, typename T> class AssignVariableOp : public OpKernel { public: - AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {} + AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); + } void Compute(OpKernelContext* context) override { Var* variable = nullptr; - OP_REQUIRES_OK(context, 
LookupResource(context, HandleFromInput(context, 0), - &variable)); + OP_REQUIRES_OK( + context, + LookupOrCreateResource<Var>( + context, HandleFromInput(context, 0), &variable, + [this, context](Var** ptr) { + *ptr = new Var(dtype_); + PersistentTensor unused; + Tensor* tmp; + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + attr.set_nic_compatible(true); + TF_RETURN_IF_ERROR(context->allocate_persistent( + dtype_, context->input(1).shape(), &unused, &tmp, attr)); + *(*ptr)->tensor() = *tmp; + return Status::OK(); + })); core::ScopedUnref s(variable); // TODO(apassos): holding a lock and copying is unnecessary if we are the @@ -167,6 +122,9 @@ class AssignVariableOp : public OpKernel { copy_functor(context->eigen_device<Device>(), variable->tensor()->flat<T>(), value.flat<T>()); } + + private: + DataType dtype_; }; // TODO(apassos) register for the GPU as well. diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index 205f888f6f..4b02790f7a 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -52,36 +52,6 @@ dtype: the type of this variable. Must agree with the dtypes shape: The (possibly partially specified) shape of this variable. )"); -Status CreateAssignShapeFn(InferenceContext* c) { - DataType handle_dtype = c->input_handle_dtype(0); - DataType value_dtype; - c->GetAttr("dtype", &value_dtype); - if (handle_dtype != value_dtype) { - return errors::InvalidArgument( - "Trying to initialize handle for variable with wrong dtype. 
" - "Expected ", - handle_dtype, " got ", value_dtype); - } - ShapeHandle s = c->input_handle_shape(0); - ShapeHandle value_shape = c->input(1); - ShapeHandle unused; - TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused)); - return Status::OK(); -} - -REGISTER_OP("CreateVariableOp") - .Input("resource: resource") - .Input("value: dtype") - .Attr("dtype: type") - .SetShapeFn(CreateAssignShapeFn) - .Doc(R"( -Creates a variable resource. - -resource: handle to the resource in which to store the variable. -value: the value to set the new tensor to use. -dtype: the dtype of the value. -)"); - REGISTER_OP("ReadVariableOp") .Input("resource: resource") .Output("value: dtype") @@ -113,6 +83,23 @@ resource: handle to the resource in which to store the variable. dtype: the dtype of the value. )"); +Status CreateAssignShapeFn(InferenceContext* c) { + DataType handle_dtype = c->input_handle_dtype(0); + DataType value_dtype; + c->GetAttr("dtype", &value_dtype); + if (handle_dtype != value_dtype) { + return errors::InvalidArgument( + "Trying to initialize handle for variable with wrong dtype. " + "Expected ", + handle_dtype, " got ", value_dtype); + } + ShapeHandle s = c->input_handle_shape(0); + ShapeHandle value_shape = c->input(1); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused)); + return Status::OK(); +} + REGISTER_OP("AssignVariableOp") .Input("resource: resource") .Input("value: dtype") @@ -133,22 +120,7 @@ REGISTER_OP("AssignAddVariableOp") .Input("resource: resource") .Input("value: dtype") .Attr("dtype: type") - .SetShapeFn([](InferenceContext* c) { - DataType handle_dtype = c->input_handle_dtype(0); - DataType value_dtype; - c->GetAttr("dtype", &value_dtype); - if (handle_dtype != value_dtype) { - return errors::InvalidArgument( - "Trying to initialize handle for variable with wrong dtype. 
" - "Expected ", - handle_dtype, " got ", value_dtype); - } - ShapeHandle s = c->input_handle_shape(0); - ShapeHandle value_shape = c->input(1); - ShapeHandle unused; - TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused)); - return Status::OK(); - }) + .SetShapeFn(CreateAssignShapeFn) .Doc(R"( Adds a value to the current value of a variable. diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 0626895a35..9d0a1bd26e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1343,6 +1343,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":framework", + ":resource_variable_ops_gen", ":state_ops_gen", ], ) diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index f8bf46be9f..b426719912 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -32,25 +32,25 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.test_session(): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) with self.assertRaises(ValueError): - resource_variable_ops.create_variable_op( + resource_variable_ops.assign_variable_op( handle, constant_op.constant(0.0, dtype=dtypes.float32)).run() with self.assertRaises(ValueError): - resource_variable_ops.create_variable_op( + resource_variable_ops.assign_variable_op( handle, constant_op.constant([0], dtype=dtypes.int32)).run() - resource_variable_ops.create_variable_op( + resource_variable_ops.assign_variable_op( handle, constant_op.constant(0, dtype=dtypes.int32)).run() def testDtypeSurvivesIdentity(self): with self.test_session(): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) id_handle = array_ops.identity(handle) - resource_variable_ops.create_variable_op( + resource_variable_ops.assign_variable_op( id_handle, constant_op.constant(0, dtype=dtypes.int32)).run() def 
testCreateRead(self): with self.test_session(): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) - resource_variable_ops.create_variable_op( + resource_variable_ops.assign_variable_op( handle, constant_op.constant(1, dtype=dtypes.int32)).run() value = resource_variable_ops.read_variable_op( handle, dtype=dtypes.int32).eval() @@ -59,7 +59,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): def testManyAssigns(self): with self.test_session() as session: handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) - create = resource_variable_ops.create_variable_op( + create = resource_variable_ops.assign_variable_op( handle, constant_op.constant(1, dtype=dtypes.int32)) with ops.control_dependencies([create]): first_read = resource_variable_ops.read_variable_op( @@ -77,7 +77,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): def testAssignAdd(self): with self.test_session(): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) - resource_variable_ops.create_variable_op( + resource_variable_ops.assign_variable_op( handle, constant_op.constant(1, dtype=dtypes.int32)).run() resource_variable_ops.assign_add_variable_op( handle, constant_op.constant(1, dtype=dtypes.int32)).run() @@ -88,7 +88,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.test_session(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1, 1]) - resource_variable_ops.create_variable_op( + resource_variable_ops.assign_variable_op( handle, constant_op.constant([[1]], dtype=dtypes.int32)).run() resource_variable_ops.resource_scatter_add( handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)).run() diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 8962fe7e4a..bdb777fddf 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -33,11 +33,10 @@ 
from tensorflow.python.ops.gen_resource_variable_ops import * def _register_variable_read(read, collections, trainable): """Helper function to put a read from a variable in the collections.""" if collections is None: - collections = [] - if (trainable and - ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES not in collections): - collections = (list(collections) + - [ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES]) + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + if (trainable and ops.GraphKeys.TRAINABLE_VARIABLES + not in collections): + collections = (list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]) ops.add_to_collections(collections, read) @@ -49,19 +48,23 @@ class ResourceVariable(object): """ + # pylint: disable=unused-argument def __init__(self, initial_value=None, name=None, + caching_device=None, trainable=True, collections=None, dtype=None, shape=None): + """Creates a variable. Args: initial_value: A `Tensor` or Python object convertible to a `Tensor` representing the initial value of this variable. name: The name of this variable. Automatically uniquified. + caching_device: device where the variable value's read by default. trainable: Whether the global read of this variable will be used for training. collections: Additional collections to which the `read` operation for @@ -73,6 +76,8 @@ class ResourceVariable(object): value but shape inference is desired. 
""" if initial_value is not None: + if callable(initial_value): + initial_value = initial_value() initial_value = ops.convert_to_tensor(initial_value) if dtype is None: assert initial_value is not None, ("Trying to create a resource variable " @@ -101,15 +106,22 @@ class ResourceVariable(object): gen_resource_variable_ops.var_is_initialized_op(self._handle)) if initial_value is not None: with ops.name_scope("Create"): - self._initialize_op = gen_resource_variable_ops.create_variable_op( + self._initialize_op = gen_resource_variable_ops.assign_variable_op( self._handle, initial_value) resources.register_resource(self._handle, self._initialize_op, self._is_initialized_op) with ops.name_scope("Read"): - self._value = gen_resource_variable_ops.read_variable_op( - self._handle, dtype=self._dtype) + if caching_device is not None: + with ops.device(caching_device): + self._value = gen_resource_variable_ops.read_variable_op( + self._handle, dtype=self._dtype) + else: + self._value = gen_resource_variable_ops.read_variable_op( + self._handle, dtype=self._dtype) + # TODO(apassos) this is terrible + self._value.initializer = self._initialize_op _register_variable_read( self._value, trainable=trainable, collections=collections) @@ -119,6 +131,15 @@ class ResourceVariable(object): return self._dtype @property + def name(self): + """The name of the handle for this variable.""" + return self._handle.name + + def get_shape(self): + """The shape of this variable.""" + return self._value.get_shape() + + @property def create(self): """The op responsible for initializing this variable.""" return self._initialize_op @@ -133,6 +154,15 @@ class ResourceVariable(object): """A cached operation which reads the value of this variable.""" return self._value + def _as_graph_element(self): + """Conversion function for Graph.as_graph_element().""" + return self._value + + @property + def initializer(self): + """The op responsible for initializing this variable.""" + return self._initialize_op + 
@property def op(self): """The op which reads the value of this variable.""" @@ -162,6 +192,7 @@ class ResourceVariable(object): return value def sparse_read(self, indices, collections=None, trainable=True, name=None): + """Reads the value of this variable sparsely, using `gather`.""" with ops.name_scope("Gather" if name is None else name): value = gen_resource_variable_ops.resource_gather( self._handle, indices, dtype=self._dtype) diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index f2a201a609..4858453b52 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -127,6 +127,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import gen_state_ops # go/tf-wildcard-import # pylint: disable=wildcard-import @@ -201,3 +202,24 @@ def init_variable(v, init, name="init"): else: init = ops.convert_to_tensor(init, name="init") return gen_state_ops.assign(v, init, name=scope) + + +def is_variable_initialized(ref, name=None): + """Checks whether a tensor has been initialized. + + Outputs boolean scalar indicating whether the tensor has been initialized. + + Args: + ref: A mutable `Tensor`. + Should be from a `Variable` node. May be uninitialized. + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `bool`. + """ + if ref.dtype._is_ref_dtype: + return gen_state_ops.is_variable_initialized(ref=ref, name=name) + # Handle resource variables. 
+ if ref.op.type == "ReadVariableOp": + return gen_resource_variable_ops.var_is_initialized_op(ref.op.inputs[0], + name=name) diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index f020c8cd0f..94a3813699 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import io_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables @@ -54,6 +55,13 @@ from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorflow.python.util import compat +# Op names which identify variable reads which should be saved. +_VARIABLE_OPS = set(["Variable", + "AutoReloadVariable", + "ReadVariableOp", + "ResourceGather"]) + + class BaseSaverBuilder(object): """Base class for Savers. 
@@ -129,6 +137,23 @@ class BaseSaverBuilder(object): validate_shape=restored_shapes is None and self.op.get_shape().is_fully_defined()) + class ResourceVariableSaveable(SaveableObject): + """SaveableObject implementation that handles ResourceVariables.""" + + def __init__(self, var, slice_spec, name): + self.read_op = var + spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name) + super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__( + var, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + restored_tensor = restored_tensors[0] + if restored_shapes is not None: + restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0]) + return resource_variable_ops.assign_variable_op( + self.read_op.op.inputs[0], + restored_tensor) + def __init__(self, write_version=saver_pb2.SaverDef.V2): self._write_version = write_version @@ -406,8 +431,7 @@ class BaseSaverBuilder(object): @staticmethod def _IsVariable(v): - return isinstance(v, ops.Tensor) and (v.op.type == "Variable" or - v.op.type == "AutoReloadVariable") + return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS def _GroupByDevices(self, saveables): """Group Variable tensor slices per device. @@ -537,7 +561,12 @@ class BaseSaverBuilder(object): raise TypeError("names_to_saveables must be a dict mapping string " "names to Tensors/Variables. Not a variable: %s" % variable) - saveable = BaseSaverBuilder.VariableSaveable(variable, "", name) + if variable.op.type == "Variable": + saveable = BaseSaverBuilder.VariableSaveable(variable, "", name) + else: + # TODO(apassos): this assumes all non-variables are ResourceVariables. 
+ saveable = BaseSaverBuilder.ResourceVariableSaveable( + variable, "", name) self._AddSaveable(saveables, seen_ops, saveable) return saveables diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index d27975b9d1..464446bb82 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -18,17 +18,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import contextlib import math import os.path -import time -import contextlib import random import shutil import tempfile +import time -import tensorflow as tf import numpy as np import six +import tensorflow as tf from google.protobuf.any_pb2 import Any @@ -38,8 +38,9 @@ from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import meta_graph -from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import gfile from tensorflow.python.training import saver as saver_module from tensorflow.python.util import compat @@ -110,13 +111,13 @@ class CheckpointedOp(object): class SaverTest(tf.test.TestCase): - def testBasics(self): + def basicSaveRestore(self, variable_op): save_path = os.path.join(self.get_temp_dir(), "basics") # Build a graph with 2 parameter nodes, and Save and # Restore nodes for them. - v0 = tf.Variable(10.0, name="v0") - v1 = tf.Variable(20.0, name="v1") + v0 = variable_op(10.0, name="v0") + v1 = variable_op(20.0, name="v1") v2 = CheckpointedOp(name="v2") v2_init = v2.insert("k1", 30.0) save = tf.train.Saver( @@ -143,17 +144,13 @@ class SaverTest(tf.test.TestCase): # Start a second session. 
In that session the parameter nodes # have not been initialized either. with self.test_session() as sess: - v0 = tf.Variable(-1.0, name="v0") - v1 = tf.Variable(-1.0, name="v1") + v0 = variable_op(-1.0, name="v0") + v1 = variable_op(-1.0, name="v1") v2 = CheckpointedOp(name="v2") save = tf.train.Saver({"v0": v0, "v1": v1, "v2": v2.saveable}) - with self.assertRaisesWithPredicateMatch( - tf.OpError, lambda e: "uninitialized value v0" in e.message): - sess.run(v0) - with self.assertRaisesWithPredicateMatch( - tf.OpError, lambda e: "uninitialized value v1" in e.message): - sess.run(v1) + # Assert that the variables are not initialized. + self.assertEqual(len(tf.report_uninitialized_variables().eval()), 2) self.assertEqual(0, len(v2.keys().eval())) self.assertEqual(0, len(v2.values().eval())) @@ -168,8 +165,8 @@ class SaverTest(tf.test.TestCase): # Build another graph with 2 nodes, initialized # differently, and a Restore node for them. with self.test_session() as sess: - v0_2 = tf.Variable(1000.0, name="v0") - v1_2 = tf.Variable(2000.0, name="v1") + v0_2 = variable_op(1000.0, name="v0") + v1_2 = variable_op(2000.0, name="v1") v2_2 = CheckpointedOp(name="v2") save2 = tf.train.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable}) v2_2.insert("k1000", 3000.0).run() @@ -188,6 +185,12 @@ class SaverTest(tf.test.TestCase): self.assertEqual(b"k1", v2_2.keys().eval()) self.assertEqual(30.0, v2_2.values().eval()) + def testBasic(self): + self.basicSaveRestore(tf.Variable) + + def testResourceBasic(self): + self.basicSaveRestore(resource_variable_ops.ResourceVariable) + def testInvalidPath(self): v0 = tf.Variable(0, name="v0") for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2): |