diff options
Diffstat (limited to 'tensorflow/python/ops/variable_scope.py')
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 47 |
1 files changed, 43 insertions, 4 deletions
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 6f46fd633e..fa7ac4eaef 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -526,9 +526,14 @@ class _VariableStore(object): var_full_name = "%s/part_%d" % (name, i) with ops.name_scope(var_full_name + "/PartitionedInitializer"): + # Create the tensor to initialize the variable with default value. if initializer is None: - init = init_ops.uniform_unit_scaling_initializer() - init_shape = var_shape + init, initializing_from_value = self._get_default_initializer( + name=name, shape=shape, dtype=dtype) + if initializing_from_value: + init_shape = None + else: + init_shape = var_shape elif callable(initializer): init = initializer init_shape = var_shape @@ -653,9 +658,10 @@ class _VariableStore(object): raise ValueError("Shape of a new variable (%s) must be fully defined, " "but instead was %s." % (name, shape)) - # Create the tensor to initialize the variable. + # Create the tensor to initialize the variable with default value. if initializer is None: - initializer = init_ops.uniform_unit_scaling_initializer() + initializer, initializing_from_value = self._get_default_initializer( + name=name, shape=shape, dtype=dtype) # Clear control dependencies while creating the initializer. with ops.control_dependencies(None): if initializing_from_value: @@ -692,6 +698,39 @@ class _VariableStore(object): return v + # Initialize variable when no initializer provided + def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32): + """Provide a default initializer and a corresponding value. + + Args: + name: see get_variable. + shape: see get_variable. + dtype: see get_variable. + + Returns: + initializer and initializing_from_value. See get_variable above. + + Raises: + ValueError: When giving unsupported dtype. + """ + # If dtype is DT_FLOAT, provide a uniform unit scaling initializer + if dtype.is_floating: + initializer = init_ops.uniform_unit_scaling_initializer() + initializing_from_value = False + # If dtype is DT_INT/DT_UINT, provide a default value `zero` + # If dtype is DT_BOOL, provide a default value `FALSE` + elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool: + initializer = init_ops.zeros_initializer()( + shape=shape, dtype=dtype.base_dtype) + initializing_from_value = True + # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here? + else: + raise ValueError("An initializer for variable %s of %s is required" + % (name, dtype.base_dtype)) + + return initializer, initializing_from_value + + # To stop regularization, use this regularizer def no_regularizer(_): """Use this function to prevent regularization of variables.""" |