 tensorflow/python/ops/partitioned_variables.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/tensorflow/python/ops/partitioned_variables.py b/tensorflow/python/ops/partitioned_variables.py
index fb2f51dcde..9d4d19668a 100644
--- a/tensorflow/python/ops/partitioned_variables.py
+++ b/tensorflow/python/ops/partitioned_variables.py
@@ -95,7 +95,7 @@ def _compute_slice_dim_and_shape(full_shape, slicing):
 
 def create_partitioned_variables(
     shape, slicing, initializer, dtype=dtypes.float32,
-    trainable=True, collections=None, name=None):
+    trainable=True, collections=None, name=None, reuse=None):
   """Create a list of partitioned variables according to the given `slicing`.
 
   Currently only one dimension of the full variable can be sliced, and the
@@ -127,6 +127,9 @@ def create_partitioned_variables(
       Defaults to `[GraphKeys.VARIABLES]`.
     name: Optional name for the full variable. Defaults to
       `"PartitionedVariable"` and gets uniquified automatically.
+    reuse: Boolean or `None`; if `True` and `name` is set, reuse previously
+      created variables. If `False`, create new variables. If `None`,
+      inherit the parent scope's reuse setting.
 
   Returns:
     A list of Variables corresponding to the slicing.
@@ -152,7 +155,8 @@ def create_partitioned_variables(
   num_slices_with_excess = full_shape[slice_dim] % num_slices
 
   with variable_scope.variable_op_scope([], name,
-                                        "PartitionedVariable") as scope:
+                                        "PartitionedVariable",
+                                        reuse=reuse) as scope:
     full_name = scope.name
     slice_offset = [0] * len(full_shape)
     for i in xrange(num_slices):
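
For context, a minimal sketch of how the new `reuse` flag might be exercised (not part of the commit). It assumes a TensorFlow build containing this change; the shape, slicing, and scope name are illustrative only.

```python
import tensorflow as tf
from tensorflow.python.ops import partitioned_variables

# First call creates the slice variables under the scope "pv".
slices = partitioned_variables.create_partitioned_variables(
    shape=[20, 10],        # full shape of the logical variable
    slicing=[4, 1],        # split dimension 0 into 4 slices
    initializer=tf.random_uniform_initializer(),
    name="pv")

# Second call with reuse=True and the same name returns the
# previously created slices instead of uniquifying the scope
# into a fresh set of variables.
same_slices = partitioned_variables.create_partitioned_variables(
    shape=[20, 10],
    slicing=[4, 1],
    initializer=tf.random_uniform_initializer(),
    name="pv",
    reuse=True)
```

Since `reuse` defaults to `None`, which inherits the surrounding scope's reuse setting, existing callers are unaffected by the new parameter.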