diff options
Diffstat (limited to 'tensorflow/contrib/framework/python/ops/variables.py')
-rw-r--r-- | tensorflow/contrib/framework/python/ops/variables.py | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index e006add43a..2475a2fb21 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope +from tensorflow.contrib.framework.python.ops import gen_variable_ops from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import dtypes @@ -29,8 +30,11 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops import gen_state_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver +from tensorflow.python.framework.load_library import load_op_library +from tensorflow.python.platform import resource_loader __all__ = ['add_model_variable', @@ -53,9 +57,33 @@ __all__ = ['add_model_variable', 'local_variable', 'model_variable', 'variable', - 'VariableDeviceChooser'] + 'VariableDeviceChooser', + 'zero_initializer'] +def zero_initializer(ref, use_locking=True, name="zero_initializer"): + """Initialize 'ref' with all zeros, ref tensor should be uninitialized. + If already initialized, you will get ValueError. This op is intended to + save memory during initialization. + Args: + ref: ref of the tensor need to be zero initialized. + name: optional name for this operation. + Returns: + ref that initialized. + Raises: + ValueError: If ref tensor is initialized. + """ + _variable_ops = load_op_library(resource_loader.get_path_to_datafile( + "_variable_ops.so")) + assert _variable_ops, "Could not load _variable_ops.so" + return gen_variable_ops.zero_initializer(ref, name=name) + +# shape function for _ZeroInitializerOp +@ops.RegisterShape("ZeroInitializer") +def _ZeroInitializerShape(op): + var_shape = op.inputs[0].get_shape() + return [var_shape] + def assert_global_step(global_step_tensor): """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`. |