aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-04 12:16:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-04 13:29:29 -0700
commit19629df86ef4f57ff869e53eae211dbd905fde1f (patch)
tree266304d0d8cb1cc43eea5064bed00edcb5be40f7
parenta88279b76fe1e1ed7b0223f1c4a5b554b2567049 (diff)
Passes trainable flag to separable_conv2d biases.
Change: 152170239
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py1
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py14
2 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 65dcf8577f..0140f6d0d3 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -1942,6 +1942,7 @@ def separable_convolution2d(
dtype=dtype,
initializer=biases_initializer,
regularizer=biases_regularizer,
+ trainable=trainable,
collections=biases_collections)
outputs = nn.bias_add(outputs, biases)
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 3bc31a2624..2b170e92ba 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -2979,6 +2979,20 @@ class SeparableConv2dTest(test.TestCase):
sess.run(init_op)
sess.run(net, feed_dict={images_placeholder: images})
+ def testTrainableFlagIsPassedOn(self):
+ for trainable in [True, False]:
+ for num_filters in [None, 8]:
+ with ops.Graph().as_default():
+ input_size = [5, 10, 12, 3]
+
+ images = random_ops.random_uniform(input_size, seed=1)
+ layers_lib.separable_conv2d(
+ images, num_filters, [3, 3], 1, trainable=trainable)
+ model_variables = variables.get_model_variables()
+ trainable_variables = variables_lib.trainable_variables()
+ for model_variable in model_variables:
+ self.assertEqual(trainable, model_variable in trainable_variables)
+
class ScaleGradientTests(test.TestCase):
"""Simple tests of the scale_gradient function."""