Diffstat (limited to 'tensorflow/python/ops/variable_scope.py')
-rw-r--r--  tensorflow/python/ops/variable_scope.py  47
1 file changed, 43 insertions(+), 4 deletions(-)
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 6f46fd633e..fa7ac4eaef 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -526,9 +526,14 @@ class _VariableStore(object):
var_full_name = "%s/part_%d" % (name, i)
with ops.name_scope(var_full_name + "/PartitionedInitializer"):
+ # Create the tensor to initialize the variable with a default value.
if initializer is None:
- init = init_ops.uniform_unit_scaling_initializer()
- init_shape = var_shape
+ init, initializing_from_value = self._get_default_initializer(
+ name=name, shape=shape, dtype=dtype)
+ if initializing_from_value:
+ init_shape = None
+ else:
+ init_shape = var_shape
elif callable(initializer):
init = initializer
init_shape = var_shape
@@ -653,9 +658,10 @@ class _VariableStore(object):
raise ValueError("Shape of a new variable (%s) must be fully defined, "
"but instead was %s." % (name, shape))
- # Create the tensor to initialize the variable.
+ # Create the tensor to initialize the variable with a default value.
if initializer is None:
- initializer = init_ops.uniform_unit_scaling_initializer()
+ initializer, initializing_from_value = self._get_default_initializer(
+ name=name, shape=shape, dtype=dtype)
# Clear control dependencies while creating the initializer.
with ops.control_dependencies(None):
if initializing_from_value:
@@ -692,6 +698,39 @@ class _VariableStore(object):
return v
+ # Initialize the variable when no initializer is provided
+ def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
+ """Provide a default initializer and a corresponding value.
+
+ Args:
+ name: see get_variable.
+ shape: see get_variable.
+ dtype: see get_variable.
+
+ Returns:
+ initializer and initializing_from_value. See get_variable above.
+
+ Raises:
+ ValueError: When the given dtype is not supported.
+ """
+ # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
+ if dtype.is_floating:
+ initializer = init_ops.uniform_unit_scaling_initializer()
+ initializing_from_value = False
+ # If dtype is DT_INT/DT_UINT, provide a default value of zero
+ # If dtype is DT_BOOL, provide a default value of `False`
+ elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
+ initializer = init_ops.zeros_initializer()(
+ shape=shape, dtype=dtype.base_dtype)
+ initializing_from_value = True
+ # NOTE: Do we need to support DT_STRING and DT_COMPLEX here?
+ else:
+ raise ValueError("An initializer for variable %s of %s is required"
+ % (name, dtype.base_dtype))
+
+ return initializer, initializing_from_value
+
+
# To stop regularization, use this regularizer
def no_regularizer(_):
"""Use this function to prevent regularization of variables."""