aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-04-08 10:29:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-08 11:32:33 -0700
commitbbe01cbdfab2a588f030e6d1178ff67fd7b1d0ec (patch)
tree015c4dd97de7afcc983b7cf2e18df91a378d953f
parentf140e0942ba39e3923e9597ffe92ec4c2c403b2e (diff)
Allow reuse of partitioned_variables.
Change: 119392087
-rw-r--r--tensorflow/python/ops/partitioned_variables.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/python/ops/partitioned_variables.py b/tensorflow/python/ops/partitioned_variables.py
index fb2f51dcde..9d4d19668a 100644
--- a/tensorflow/python/ops/partitioned_variables.py
+++ b/tensorflow/python/ops/partitioned_variables.py
@@ -95,7 +95,7 @@ def _compute_slice_dim_and_shape(full_shape, slicing):
def create_partitioned_variables(
shape, slicing, initializer, dtype=dtypes.float32,
- trainable=True, collections=None, name=None):
+ trainable=True, collections=None, name=None, reuse=None):
"""Create a list of partitioned variables according to the given `slicing`.
Currently only one dimension of the full variable can be sliced, and the
@@ -127,6 +127,9 @@ def create_partitioned_variables(
Defaults to `[GraphKeys.VARIABLES]`.
name: Optional name for the full variable. Defaults to
`"PartitionedVariable"` and gets uniquified automatically.
+ reuse: Boolean or `None`; if `True` and name is set, it would reuse
+ previously created variables. if `False` it will create new variables.
+ if `None`, it would inherit the parent scope reuse.
Returns:
A list of Variables corresponding to the slicing.
@@ -152,7 +155,8 @@ def create_partitioned_variables(
num_slices_with_excess = full_shape[slice_dim] % num_slices
with variable_scope.variable_op_scope([], name,
- "PartitionedVariable") as scope:
+ "PartitionedVariable",
+ reuse=reuse) as scope:
full_name = scope.name
slice_offset = [0] * len(full_shape)
for i in xrange(num_slices):