aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework/python/ops/variables.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/framework/python/ops/variables.py')
-rw-r--r--tensorflow/contrib/framework/python/ops/variables.py30
1 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index e006add43a..2475a2fb21 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope
+from tensorflow.contrib.framework.python.ops import gen_variable_ops
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
@@ -29,8 +30,11 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops import gen_state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
+from tensorflow.python.framework.load_library import load_op_library
+from tensorflow.python.platform import resource_loader
__all__ = ['add_model_variable',
@@ -53,9 +57,33 @@ __all__ = ['add_model_variable',
'local_variable',
'model_variable',
'variable',
- 'VariableDeviceChooser']
+ 'VariableDeviceChooser',
+ 'zero_initializer']
+def zero_initializer(ref, use_locking=True, name="zero_initializer"):
+ """Initialize 'ref' with all zeros, ref tensor should be uninitialized.
+ If already initialized, you will get ValueError. This op is intended to
+ save memory during initialization.
+ Args:
+ ref: ref of the tensor need to be zero initialized.
+ name: optional name for this operation.
+ Returns:
+ ref that initialized.
+ Raises:
+ ValueError: If ref tensor is initialized.
+ """
+ _variable_ops = load_op_library(resource_loader.get_path_to_datafile(
+ "_variable_ops.so"))
+ assert _variable_ops, "Could not load _variable_ops.so"
+ return gen_variable_ops.zero_initializer(ref, name=name)
+
+# shape function for _ZeroInitializerOp
+@ops.RegisterShape("ZeroInitializer")
+def _ZeroInitializerShape(op):
+ var_shape = op.inputs[0].get_shape()
+ return [var_shape]
+
def assert_global_step(global_step_tensor):
"""Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.