aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-11-01 19:31:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-01 19:35:18 -0700
commit16fa134cfb576bfa690d7006864e555dc42c6b62 (patch)
tree90a98df106d3c806113660d3d2366987a74741d2 /tensorflow/contrib/model_pruning
parent53a4fcbdbad571e659203733f6a07ba82651d40b (diff)
Convert BasicRNNCell and GRUCell to proper layers.
PiperOrigin-RevId: 174272860
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/python/layers/rnn_cells.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py
index 18ba3d1327..a5b050d25d 100644
--- a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py
+++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py
@@ -92,6 +92,8 @@ class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell):
# Call the build method of the parent class.
super(MaskedBasicLSTMCell, self).build(inputs_shape)
+ self.built = False
+
input_depth = inputs_shape[1].value
h_depth = self._num_units
self._mask = self.add_variable(
@@ -117,6 +119,8 @@ class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell):
ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)
+ self.built = True
+
def call(self, inputs, state):
"""Long short-term memory cell (LSTM) with masks for pruning.
@@ -237,6 +241,8 @@ class MaskedLSTMCell(tf_rnn.LSTMCell):
# Call the build method of the parent class.
super(MaskedLSTMCell, self).build(inputs_shape)
+ self.built = False
+
input_depth = inputs_shape[1].value
h_depth = self._num_units
self._mask = self.add_variable(
@@ -262,6 +268,8 @@ class MaskedLSTMCell(tf_rnn.LSTMCell):
ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)
+ self.built = True
+
def call(self, inputs, state):
"""Run one step of LSTM.