diff options
4 files changed, 42 insertions, 207 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index 05ad9f6336..2f6963f6b8 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -535,6 +535,45 @@ class RNNTest(test.TestCase): self.assertAllClose(tf_out, k_out) self.assertAllClose(tf_state, k_state) + def testSimpleRNNCellAndBasicRNNCellComparison(self): + input_shape = 10 + output_shape = 5 + timestep = 4 + batch = 20 + (x_train, _), _ = testing_utils.get_test_data( + train_samples=batch, + test_samples=0, + input_shape=(timestep, input_shape), + num_classes=output_shape) + fix_weights_generator = keras.layers.SimpleRNNCell(output_shape) + fix_weights_generator.build((None, input_shape)) + # The SimpleRNNCell contains 3 weights: kernel, recurrent_kernel, and bias + # The BasicRNNCell contains 2 weight: kernel and bias, where kernel is + # zipped [kernel, recurrent_kernel] in SimpleRNNCell. + keras_weights = fix_weights_generator.get_weights() + kernel, recurrent_kernel, bias = keras_weights + tf_weights = [np.concatenate((kernel, recurrent_kernel)), bias] + + with self.test_session(graph=ops_lib.Graph()) as sess: + inputs = array_ops.placeholder( + dtypes.float32, shape=(None, timestep, input_shape)) + cell = keras.layers.SimpleRNNCell(output_shape) + k_out, k_state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32) + cell.set_weights(keras_weights) + [k_out, k_state] = sess.run([k_out, k_state], {inputs: x_train}) + with self.test_session(graph=ops_lib.Graph()) as sess: + inputs = array_ops.placeholder( + dtypes.float32, shape=(None, timestep, input_shape)) + cell = rnn_cell_impl.BasicRNNCell(output_shape) + tf_out, tf_state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32) + cell.set_weights(tf_weights) + [tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train}) + + self.assertAllClose(tf_out, k_out) + self.assertAllClose(tf_state, k_state) + def testBasicLSTMCellInterchangeWithLSTMCell(self): with self.session(graph=ops_lib.Graph()) as sess: basic_cell = rnn_cell_impl.BasicLSTMCell(1) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index c2751e529a..dd4f3d7a99 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -370,7 +370,7 @@ class LayerRNNCell(RNNCell): *args, **kwargs) -@tf_export("nn.rnn_cell.BasicRNNCell") +@tf_export(v1=["nn.rnn_cell.BasicRNNCell"]) class BasicRNNCell(LayerRNNCell): """The most basic RNN cell. @@ -393,6 +393,8 @@ class BasicRNNCell(LayerRNNCell): `trainable` etc when constructing the cell from configs of get_config(). """ + @deprecated(None, "This class is equivalent as tf.keras.layers.SimpleRNNCell," + " and will be replaced by that in Tensorflow 2.0.") def __init__(self, num_units, activation=None, diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt deleted file mode 100644 index a4483fefa2..0000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt +++ /dev/null @@ -1,202 +0,0 @@ -path: "tensorflow.nn.rnn_cell.BasicRNNCell" -tf_class { - is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.BasicRNNCell\'>" - is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LayerRNNCell\'>" - is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>" - is_instance: "<class \'tensorflow.python.layers.base.Layer\'>" - is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" - is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>" - is_instance: "<type \'object\'>" - member { - name: "activity_regularizer" - mtype: "<type \'property\'>" - } - member { - name: "dtype" - mtype: "<type \'property\'>" - } - member { - name: "graph" - mtype: "<type \'property\'>" - } - member { - name: "inbound_nodes" - mtype: "<type \'property\'>" - } - member { - name: "input" - mtype: "<type \'property\'>" - } - member { - name: "input_mask" - mtype: "<type \'property\'>" - } - member { - name: "input_shape" - mtype: "<type \'property\'>" - } - member { - name: "losses" - mtype: "<type \'property\'>" - } - member { - name: "name" - mtype: "<type \'property\'>" - } - member { - name: "non_trainable_variables" - mtype: "<type \'property\'>" - } - member { - name: "non_trainable_weights" - mtype: "<type \'property\'>" - } - member { - name: "outbound_nodes" - mtype: "<type \'property\'>" - } - member { - name: "output" - mtype: "<type \'property\'>" - } - member { - name: "output_mask" - mtype: "<type \'property\'>" - } - member { - name: "output_shape" - mtype: "<type \'property\'>" - } - member { - name: "output_size" - mtype: "<type \'property\'>" - } - member { - name: "scope_name" - mtype: "<type \'property\'>" - } - member { - name: "state_size" - mtype: "<type \'property\'>" - } - member { - name: "trainable_variables" - mtype: "<type \'property\'>" - } - member { - name: "trainable_weights" - mtype: "<type \'property\'>" - } - member { - name: "updates" - mtype: "<type \'property\'>" - } - member { - name: "variables" - mtype: "<type \'property\'>" - } - member { - name: "weights" - mtype: "<type \'property\'>" - } - member_method { - name: "__init__" - argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\'], " - } - member_method { - name: "add_loss" - argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "add_update" - argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "add_variable" - argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" - } - member_method { - name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], " - } - member_method { - name: "apply" - argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" - } - member_method { - name: "build" - argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "call" - argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "compute_mask" - argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "compute_output_shape" - argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "count_params" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "from_config" - argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_config" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_initial_state" - argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " - } - member_method { - name: "get_input_at" - argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_input_mask_at" - argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_input_shape_at" - argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_losses_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_output_at" - argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_output_mask_at" - argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_output_shape_at" - argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_updates_for" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_weights" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "set_weights" - argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "zero_state" - argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt index 64697e8a02..24767e250f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt @@ -5,10 +5,6 @@ tf_module { mtype: "<type \'type\'>" } member { - name: "BasicRNNCell" - mtype: "<type \'type\'>" - } - member { name: "DeviceWrapper" mtype: "<type \'type\'>" } |