aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/python/layers/rnn_cells.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py
index 18ba3d1327..a5b050d25d 100644
--- a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py
+++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py
@@ -92,6 +92,8 @@ class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell):
# Call the build method of the parent class.
super(MaskedBasicLSTMCell, self).build(inputs_shape)
+ self.built = False
+
input_depth = inputs_shape[1].value
h_depth = self._num_units
self._mask = self.add_variable(
@@ -117,6 +119,8 @@ class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell):
ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)
+ self.built = True
+
def call(self, inputs, state):
"""Long short-term memory cell (LSTM) with masks for pruning.
@@ -237,6 +241,8 @@ class MaskedLSTMCell(tf_rnn.LSTMCell):
# Call the build method of the parent class.
super(MaskedLSTMCell, self).build(inputs_shape)
+ self.built = False
+
input_depth = inputs_shape[1].value
h_depth = self._num_units
self._mask = self.add_variable(
@@ -262,6 +268,8 @@ class MaskedLSTMCell(tf_rnn.LSTMCell):
ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold)
ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel)
+ self.built = True
+
def call(self, inputs, state):
"""Run one step of LSTM.