diff options
-rw-r--r-- | tensorflow/python/ops/partitioned_variables.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/python/ops/partitioned_variables.py b/tensorflow/python/ops/partitioned_variables.py index fb2f51dcde..9d4d19668a 100644 --- a/tensorflow/python/ops/partitioned_variables.py +++ b/tensorflow/python/ops/partitioned_variables.py @@ -95,7 +95,7 @@ def _compute_slice_dim_and_shape(full_shape, slicing): def create_partitioned_variables( shape, slicing, initializer, dtype=dtypes.float32, - trainable=True, collections=None, name=None): + trainable=True, collections=None, name=None, reuse=None): """Create a list of partitioned variables according to the given `slicing`. Currently only one dimension of the full variable can be sliced, and the @@ -127,6 +127,9 @@ def create_partitioned_variables( Defaults to `[GraphKeys.VARIABLES]`. name: Optional name for the full variable. Defaults to `"PartitionedVariable"` and gets uniquified automatically. + reuse: Boolean or `None`; if `True` and name is set, it would reuse + previously created variables. if `False` it will create new variables. + if `None`, it would inherit the parent scope reuse. Returns: A list of Variables corresponding to the slicing. @@ -152,7 +155,8 @@ def create_partitioned_variables( num_slices_with_excess = full_shape[slice_dim] % num_slices with variable_scope.variable_op_scope([], name, - "PartitionedVariable") as scope: + "PartitionedVariable", + reuse=reuse) as scope: full_name = scope.name slice_offset = [0] * len(full_shape) for i in xrange(num_slices): |