aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-01 17:11:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-01 17:14:57 -0700
commit83621c7ec59a400d83de0dd3e7b45ec670c02893 (patch)
treeec565daec06e599f3e384f4d3acf698e931308dd /tensorflow/contrib/model_pruning
parent88b8f4b5382aaf3a6ff39f48d8c518ba8927aefe (diff)
Bug fix: Expose get_pruning_hparams function
PiperOrigin-RevId: 174260120
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/__init__.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/model_pruning/__init__.py b/tensorflow/contrib/model_pruning/__init__.py
index aaeb2238a4..d32bedbcd6 100644
--- a/tensorflow/contrib/model_pruning/__init__.py
+++ b/tensorflow/contrib/model_pruning/__init__.py
@@ -28,6 +28,7 @@ from tensorflow.contrib.model_pruning.python.learning import train
from tensorflow.contrib.model_pruning.python.pruning import apply_mask
from tensorflow.contrib.model_pruning.python.pruning import get_masked_weights
from tensorflow.contrib.model_pruning.python.pruning import get_masks
+from tensorflow.contrib.model_pruning.python.pruning import get_pruning_hparams
from tensorflow.contrib.model_pruning.python.pruning import get_thresholds
from tensorflow.contrib.model_pruning.python.pruning import get_weight_sparsity
from tensorflow.contrib.model_pruning.python.pruning import get_weights
@@ -39,8 +40,8 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'masked_convolution', 'masked_conv2d', 'masked_fully_connected',
'MaskedBasicLSTMCell', 'MaskedLSTMCell', 'train', 'apply_mask',
- 'get_masked_weights', 'get_masks', 'get_thresholds', 'get_weights',
- 'get_weight_sparsity', 'Pruning'
+ 'get_masked_weights', 'get_masks', 'get_pruning_hparams', 'get_thresholds',
+ 'get_weights', 'get_weight_sparsity', 'Pruning'
]
remove_undocumented(__name__, _allowed_symbols)