aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-24 06:24:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-24 06:27:14 -0700
commitb9e12bc69df65eca279a90045d045e661fdb8108 (patch)
treef3530c1fa75a3a805f3a655f5109408d2fbb8764 /tensorflow/contrib/framework
parentf62c472c470aee64147df58de584f0b8450b29ad (diff)
Make tf.contrib.framework.zero_initializer work with ResourceVariable
PiperOrigin-RevId: 194077027
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r--tensorflow/contrib/framework/BUILD1
-rw-r--r--tensorflow/contrib/framework/kernels/zero_initializer_op.cc71
-rw-r--r--tensorflow/contrib/framework/ops/variable_ops.cc29
-rw-r--r--tensorflow/contrib/framework/python/ops/variables.py8
-rw-r--r--tensorflow/contrib/framework/python/ops/variables_test.py26
5 files changed, 134 insertions, 1 deletions
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD
index b1c8ad49ea..f675cc0cf0 100644
--- a/tensorflow/contrib/framework/BUILD
+++ b/tensorflow/contrib/framework/BUILD
@@ -93,6 +93,7 @@ tf_kernel_library(
],
deps = [
"//tensorflow/core:framework",
+ "//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
],
alwayslink = 1,
diff --git a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc
index 5bf6b67529..6ab3f460b3 100644
--- a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc
+++ b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/resource_var.h"
namespace tensorflow {
@@ -85,4 +86,74 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_KERNELS
+// Kernel that produces a zero-filled resource variable.
+//
+// Looks up the resource named by input 0, passing LookupOrCreateResource a
+// creator that builds a Var whose tensor has the dtype/shape given by the
+// op's attrs and is filled with zeros. The variable is then flagged as
+// initialized and the input handle is forwarded as output 0.
+//
+// Template parameters:
+//   Device: Eigen device type used to run functor::TensorSetZero.
+//   T:      element type of the variable's tensor.
+template <typename Device, typename T>
+class ZeroVarInitializer : public OpKernel {
+ public:
+  // Reads the "dtype" and "shape" attrs describing the tensor to allocate.
+  explicit ZeroVarInitializer(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
+  }
+
+  // Fails with InvalidArgument("input is already initialized") when the
+  // looked-up variable already has is_initialized set.
+  void Compute(OpKernelContext* ctx) override {
+    Var* variable = nullptr;
+    OP_REQUIRES_OK(
+        ctx, LookupOrCreateResource<Var>(
+                 ctx, HandleFromInput(ctx, 0), &variable,
+                 [this, ctx](Var** var_ptr) {
+                   // Allocate the backing tensor *before* constructing the
+                   // Var: if allocation fails we return early, and a Var
+                   // created with `new` beforehand would be left behind in
+                   // *var_ptr with nothing in this kernel releasing it.
+                   PersistentTensor unused;
+                   Tensor* var_tensor = nullptr;
+                   AllocatorAttributes attr;
+                   attr.set_gpu_compatible(true);
+                   attr.set_nic_compatible(true);
+                   TF_RETURN_IF_ERROR(ctx->allocate_persistent(
+                       dtype_, shape_, &unused, &var_tensor, attr));
+
+                   functor::TensorSetZero<Device, T>()(
+                       ctx->eigen_device<Device>(), var_tensor->flat<T>());
+
+                   // Nothing below can fail, so the new Var is always handed
+                   // back to the caller through *var_ptr.
+                   *var_ptr = new Var(dtype_);
+                   *(*var_ptr)->tensor() = *var_tensor;
+
+                   return Status::OK();
+                 }));
+
+    // Drop the reference obtained from the lookup when leaving Compute.
+    core::ScopedUnref scoped(variable);
+    mutex_lock ml(*variable->mu());
+
+    OP_REQUIRES(ctx, !variable->is_initialized,
+                errors::InvalidArgument("input is already initialized"));
+
+    variable->is_initialized = true;
+
+    // Forward the input resource handle as the scalar output.
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
+    output->scalar<ResourceHandle>()() = HandleFromInput(ctx, 0);
+  }
+
+ private:
+  DataType dtype_;     // Value of the "dtype" attr.
+  TensorShape shape_;  // Value of the "shape" attr.
+};
+
+// Registers the CPU ZeroVarInitializer kernel for a single element type T,
+// constrained on the op's "dtype" attr. Expanded once per type enumerated by
+// TF_CALL_REAL_NUMBER_TYPES and undefined immediately afterwards.
+#define REGISTER_ZERO_VAR_CPU_KERNEL(T)                       \
+  REGISTER_KERNEL_BUILDER(Name("ZeroVarInitializer")          \
+                              .Device(DEVICE_CPU)             \
+                              .TypeConstraint<T>("dtype"),    \
+                          ZeroVarInitializer<Eigen::ThreadPoolDevice, T>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_ZERO_VAR_CPU_KERNEL);
+#undef REGISTER_ZERO_VAR_CPU_KERNEL
+
+#if GOOGLE_CUDA
+// Registers the GPU ZeroVarInitializer kernel for a single element type T,
+// constrained on the op's "dtype" attr. The "var" input is declared with
+// HostMemory. Expanded once per type enumerated by TF_CALL_GPU_NUMBER_TYPES.
+#define REGISTER_ZERO_VAR_GPU_KERNEL(T)                       \
+  REGISTER_KERNEL_BUILDER(Name("ZeroVarInitializer")          \
+                              .Device(DEVICE_GPU)             \
+                              .TypeConstraint<T>("dtype")     \
+                              .HostMemory("var"),             \
+                          ZeroVarInitializer<GPUDevice, T>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_ZERO_VAR_GPU_KERNEL);
+#undef REGISTER_ZERO_VAR_GPU_KERNEL
+#endif  // GOOGLE_CUDA
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/framework/ops/variable_ops.cc b/tensorflow/contrib/framework/ops/variable_ops.cc
index 706134ba9a..f6ee6cdb57 100644
--- a/tensorflow/contrib/framework/ops/variable_ops.cc
+++ b/tensorflow/contrib/framework/ops/variable_ops.cc
@@ -39,4 +39,33 @@ ref: Should be from a `Variable` node.
output_ref:= Same as "ref".
)doc");
+// Op definition: takes a resource handle and returns a resource handle, with
+// "dtype" and "shape" attrs describing the variable's tensor.
+REGISTER_OP("ZeroVarInitializer")
+    .Input("var: resource")
+    .Output("output_var: resource")
+    .Attr("dtype: type")
+    .Attr("shape: shape")
+    .SetAllowsUninitializedInput()
+    .SetShapeFn([](InferenceContext* c) {
+      // The output itself is a scalar (a single resource handle).
+      c->set_output(0, c->Scalar());
+
+      DataType handle_dtype;
+      TF_RETURN_IF_ERROR(c->GetAttr("dtype", &handle_dtype));
+
+      PartialTensorShape attr_shape;
+      TF_RETURN_IF_ERROR(c->GetAttr("shape", &attr_shape));
+
+      shape_inference::ShapeHandle handle_shape;
+      TF_RETURN_IF_ERROR(
+          c->MakeShapeFromPartialTensorShape(attr_shape, &handle_shape));
+
+      // Record the dtype/shape from the attrs as the shape-and-type
+      // information attached to output handle 0.
+      const std::vector<shape_inference::ShapeAndType> handle_info = {
+          {handle_shape, handle_dtype}};
+      c->set_output_handle_shapes_and_types(0, handle_info);
+
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Initialize 'var' with all zeros. This op requires that the resource var is not
+initialized. The var will first be allocated memory, then be filled with all
+zeros. This op is intended to save memory during initialization,
+if you use this op, you should not run initializer of the var.
+
+var: Should be a ResourceVariable.
+output_var:= Same as "var".
+)doc");
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index 0754c3e0e3..40ae01bfcc 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import tf_logging as logging
@@ -82,7 +83,12 @@ def zero_initializer(ref, use_locking=True, name="zero_initializer"):
"""
loader.load_op_library(
resource_loader.get_path_to_datafile("_variable_ops.so"))
- return gen_variable_ops.zero_initializer(ref, name=name)
+ if resource_variable_ops.is_resource_variable(ref):
+ return gen_variable_ops.zero_var_initializer(
+ ref.handle, shape=ref.shape, dtype=ref.dtype, name=name)
+ else:
+ return gen_variable_ops.zero_initializer(ref, name=name)
+
@deprecated(None, "Please switch to tf.train.assert_global_step")
def assert_global_step(global_step_tensor):
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index 2f06df93ac..37ea6eb12a 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -1284,6 +1284,32 @@ class ZeroInitializerOpTest(test.TestCase):
[10, 20], dtype=dtype), use_init)
+class ZeroVarInitializerOpTest(test.TestCase):
+  """Exercises zero_initializer on a ResourceVariable."""
+
+  def _testZeroVarInitializer(self, shape, initializer, use_init):
+    """Checks zero_initializer with or without the regular initializer.
+
+    Args:
+      shape: Shape used to build the expected numpy value.
+      initializer: Initial value handed to the ResourceVariable; the
+        `use_init` branch expects it to evaluate to all ones.
+      use_init: If True, run the variable's own initializer first and expect
+        the zero-initializer op to fail; otherwise run only the
+        zero-initializer op and expect zeros.
+    """
+    resource_var = resource_variable_ops.ResourceVariable(initializer)
+    zero_op = variables_lib2.zero_initializer(resource_var)
+
+    with self.test_session() as sess:
+      # Reading before anything has been initialized must fail.
+      with self.assertRaisesOpError('Error while reading resource variable'):
+        resource_var.eval()
+
+      if not use_init:
+        zero_op.eval()
+        self.assertAllClose(np.zeros(shape), resource_var.eval())
+        return
+
+      # Once the regular initializer has run, the zero-initializer op must
+      # be rejected and the variable keeps its initial value.
+      sess.run(resource_var.initializer)
+      with self.assertRaisesOpError('input is already initialized'):
+        zero_op.eval()
+      self.assertAllClose(np.ones(shape), resource_var.eval())
+
+  def testZeroVarInitializer(self):
+    """Runs the check for several dtypes, with and without the initializer."""
+    shape = [10, 20]
+    for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64):
+      for use_init in (False, True):
+        ones = array_ops.ones(shape, dtype=dtype)
+        self._testZeroVarInitializer(shape, ones, use_init)
+
+
class FilterVariablesTest(test.TestCase):
def setUp(self):