aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-05 16:09:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-05 16:13:49 -0800
commit3e7401af39b6b94f779da193c477afe05fc9856f (patch)
tree826cbde630e691a365a0cadf4a20432243cd08b6 /tensorflow/contrib/model_pruning
parent35a068e7c29202f298575e51320c469f91f22f95 (diff)
Make compute_output_shape public in masked core layers
PiperOrigin-RevId: 180988293
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/python/layers/core_layers.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py
index 95dfd8f421..764ab620bc 100644
--- a/tensorflow/contrib/model_pruning/python/layers/core_layers.py
+++ b/tensorflow/contrib/model_pruning/python/layers/core_layers.py
@@ -210,7 +210,7 @@ class _MaskedConv(base.Layer):
return self.activation(outputs)
return outputs
- def _compute_output_shape(self, input_shape):
+ def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_last':
space = input_shape[1:-1]
@@ -467,7 +467,7 @@ class MaskedFullyConnected(base.Layer):
return self.activation(outputs) # pylint: disable=not-callable
return outputs
- def _compute_output_shape(self, input_shape):
+ def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
input_shape = input_shape.with_rank_at_least(2)
if input_shape[-1].value is None: