diff options
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers.py')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 0d25a09852..62bc1ab15d 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2674,16 +2674,25 @@ def spatial_softmax(features, indexing='ij') pos_x = array_ops.reshape(pos_x, [height * width]) pos_y = array_ops.reshape(pos_y, [height * width]) + if temperature is None: - temperature_collections = utils.get_variable_collections( - variables_collections, 'temperature') - temperature = variables.model_variable( - 'temperature', - shape=(), - dtype=dtypes.float32, - initializer=init_ops.ones_initializer(), - collections=temperature_collections, - trainable=trainable) + temp_initializer = init_ops.ones_initializer() + else: + temp_initializer = init_ops.constant_initializer(temperature) + + if not trainable: + temp_collections = None + else: + temp_collections = utils.get_variable_collections( + variables_collections, 'temperature') + + temperature = variables.model_variable( + 'temperature', + shape=(), + dtype=dtypes.float32, + initializer=temp_initializer, + collections=temp_collections, + trainable=trainable) if data_format == 'NCHW': features = array_ops.reshape(features, [-1, height * width]) else: |