diff options
author | 2017-04-04 12:16:35 -0800 | |
---|---|---|
committer | 2017-04-04 13:29:29 -0700 | |
commit | 19629df86ef4f57ff869e53eae211dbd905fde1f (patch) | |
tree | 266304d0d8cb1cc43eea5064bed00edcb5be40f7 | |
parent | a88279b76fe1e1ed7b0223f1c4a5b554b2567049 (diff) |
Passes trainable flag to separable_conv2d biases.
Change: 152170239
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers_test.py | 14 |
2 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 65dcf8577f..0140f6d0d3 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1942,6 +1942,7 @@ def separable_convolution2d( dtype=dtype, initializer=biases_initializer, regularizer=biases_regularizer, + trainable=trainable, collections=biases_collections) outputs = nn.bias_add(outputs, biases) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 3bc31a2624..2b170e92ba 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -2979,6 +2979,20 @@ class SeparableConv2dTest(test.TestCase): sess.run(init_op) sess.run(net, feed_dict={images_placeholder: images}) + def testTrainableFlagIsPassedOn(self): + for trainable in [True, False]: + for num_filters in [None, 8]: + with ops.Graph().as_default(): + input_size = [5, 10, 12, 3] + + images = random_ops.random_uniform(input_size, seed=1) + layers_lib.separable_conv2d( + images, num_filters, [3, 3], 1, trainable=trainable) + model_variables = variables.get_model_variables() + trainable_variables = variables_lib.trainable_variables() + for model_variable in model_variables: + self.assertEqual(trainable, model_variable in trainable_variables) + class ScaleGradientTests(test.TestCase): """Simple tests of the scale_gradient function.""" |