aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers/python/layers/layers.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers.py')
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py27
1 files changed, 18 insertions, 9 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 0d25a09852..62bc1ab15d 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -2674,16 +2674,25 @@ def spatial_softmax(features,
indexing='ij')
pos_x = array_ops.reshape(pos_x, [height * width])
pos_y = array_ops.reshape(pos_y, [height * width])
+
if temperature is None:
- temperature_collections = utils.get_variable_collections(
- variables_collections, 'temperature')
- temperature = variables.model_variable(
- 'temperature',
- shape=(),
- dtype=dtypes.float32,
- initializer=init_ops.ones_initializer(),
- collections=temperature_collections,
- trainable=trainable)
+ temp_initializer = init_ops.ones_initializer()
+ else:
+ temp_initializer = init_ops.constant_initializer(temperature)
+
+ if not trainable:
+ temp_collections = None
+ else:
+ temp_collections = utils.get_variable_collections(
+ variables_collections, 'temperature')
+
+ temperature = variables.model_variable(
+ 'temperature',
+ shape=(),
+ dtype=dtypes.float32,
+ initializer=temp_initializer,
+ collections=temp_collections,
+ trainable=trainable)
if data_format == 'NCHW':
features = array_ops.reshape(features, [-1, height * width])
else: