aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-07-22 13:41:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-22 13:45:07 -0700
commit012f97121441f936b5262b98e2ca488c0c92422f (patch)
tree82d866e4c71a3c0b7a0659b33adf306a67cb4a31 /tensorflow/contrib/layers
parent162304f9da4114f5ed3f0e4c27929413e7abc965 (diff)
Add synchronization and aggregation arguments to variable creation methods in contrib/layers.
PiperOrigin-RevId: 205588849
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py33
1 files changed, 19 insertions, 14 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index beeabd6b65..dd602cf3a9 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -1702,19 +1702,22 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None):
return utils.collect_named_outputs(output_collections, sc, flattened)
-def _model_variable_getter(getter,
- name,
- shape=None,
- dtype=None,
- initializer=None,
- regularizer=None,
- trainable=True,
- collections=None,
- caching_device=None,
- partitioner=None,
- rename=None,
- use_resource=None,
- **_):
+def _model_variable_getter(
+ getter,
+ name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ rename=None,
+ use_resource=None,
+ synchronization=tf_variables.VariableSynchronization.AUTO,
+ aggregation=tf_variables.VariableAggregation.NONE,
+ **_):
"""Getter that uses model_variable for compatibility with core layers."""
short_name = name.split('/')[-1]
if rename and short_name in rename:
@@ -1732,7 +1735,9 @@ def _model_variable_getter(getter,
caching_device=caching_device,
partitioner=partitioner,
custom_getter=getter,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
def _build_variable_getter(rename=None):