diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-07-22 13:41:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-22 13:45:07 -0700 |
commit | 012f97121441f936b5262b98e2ca488c0c92422f (patch) | |
tree | 82d866e4c71a3c0b7a0659b33adf306a67cb4a31 /tensorflow/contrib/layers | |
parent | 162304f9da4114f5ed3f0e4c27929413e7abc965 (diff) |
Add synchronization and aggregation arguments to variable creation methods in contrib/layers.
PiperOrigin-RevId: 205588849
Diffstat (limited to 'tensorflow/contrib/layers')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 33 |
1 files changed, 19 insertions, 14 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index beeabd6b65..dd602cf3a9 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1702,19 +1702,22 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None): return utils.collect_named_outputs(output_collections, sc, flattened) -def _model_variable_getter(getter, - name, - shape=None, - dtype=None, - initializer=None, - regularizer=None, - trainable=True, - collections=None, - caching_device=None, - partitioner=None, - rename=None, - use_resource=None, - **_): +def _model_variable_getter( + getter, + name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + collections=None, + caching_device=None, + partitioner=None, + rename=None, + use_resource=None, + synchronization=tf_variables.VariableSynchronization.AUTO, + aggregation=tf_variables.VariableAggregation.NONE, + **_): """Getter that uses model_variable for compatibility with core layers.""" short_name = name.split('/')[-1] if rename and short_name in rename: @@ -1732,7 +1735,9 @@ def _model_variable_getter(getter, caching_device=caching_device, partitioner=partitioner, custom_getter=getter, - use_resource=use_resource) + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) def _build_variable_getter(rename=None): |