diff options
author | 2018-01-05 16:09:47 -0800 | |
---|---|---|
committer | 2018-01-05 16:13:49 -0800 | |
commit | 3e7401af39b6b94f779da193c477afe05fc9856f (patch) | |
tree | 826cbde630e691a365a0cadf4a20432243cd08b6 /tensorflow/contrib/model_pruning | |
parent | 35a068e7c29202f298575e51320c469f91f22f95 (diff) |
Make compute_output_shape public in masked core layers
PiperOrigin-RevId: 180988293
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r-- | tensorflow/contrib/model_pruning/python/layers/core_layers.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py index 95dfd8f421..764ab620bc 100644 --- a/tensorflow/contrib/model_pruning/python/layers/core_layers.py +++ b/tensorflow/contrib/model_pruning/python/layers/core_layers.py @@ -210,7 +210,7 @@ class _MaskedConv(base.Layer): return self.activation(outputs) return outputs - def _compute_output_shape(self, input_shape): + def compute_output_shape(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape).as_list() if self.data_format == 'channels_last': space = input_shape[1:-1] @@ -467,7 +467,7 @@ class MaskedFullyConnected(base.Layer): return self.activation(outputs) # pylint: disable=not-callable return outputs - def _compute_output_shape(self, input_shape): + def compute_output_shape(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) input_shape = input_shape.with_rank_at_least(2) if input_shape[-1].value is None: |