diff options
author | 2017-11-01 19:31:24 -0700 | |
---|---|---|
committer | 2017-11-01 19:35:18 -0700 | |
commit | 16fa134cfb576bfa690d7006864e555dc42c6b62 (patch) | |
tree | 90a98df106d3c806113660d3d2366987a74741d2 /tensorflow/contrib/model_pruning | |
parent | 53a4fcbdbad571e659203733f6a07ba82651d40b (diff) |
Convert BasicRNNCell and GRUCell to proper layers.
PiperOrigin-RevId: 174272860
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r-- | tensorflow/contrib/model_pruning/python/layers/rnn_cells.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py index 18ba3d1327..a5b050d25d 100644 --- a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py +++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py @@ -92,6 +92,8 @@ class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell): # Call the build method of the parent class. super(MaskedBasicLSTMCell, self).build(inputs_shape) + self.built = False + input_depth = inputs_shape[1].value h_depth = self._num_units self._mask = self.add_variable( @@ -117,6 +119,8 @@ class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell): ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold) ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel) + self.built = True + def call(self, inputs, state): """Long short-term memory cell (LSTM) with masks for pruning. @@ -237,6 +241,8 @@ class MaskedLSTMCell(tf_rnn.LSTMCell): # Call the build method of the parent class. super(MaskedLSTMCell, self).build(inputs_shape) + self.built = False + input_depth = inputs_shape[1].value h_depth = self._num_units self._mask = self.add_variable( @@ -262,6 +268,8 @@ class MaskedLSTMCell(tf_rnn.LSTMCell): ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold) ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel) + self.built = True + def call(self, inputs, state): """Run one step of LSTM. |