aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Scott Zhu <scottzhu@google.com>2018-10-01 11:29:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 11:39:18 -0700
commitf0f301f05fb1f1965c966ef57cc390e48d966f12 (patch)
tree29be99009e738828ee13cbd6d72ebee0b8e772c9 /tensorflow
parenta9b01e8a31a02188bc81349c103f136095f322ac (diff)
Add deprecation notice for BasicRNNCell, which will be replaced by keras.SimpleRNNCell.
PiperOrigin-RevId: 215249611
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py39
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt202
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt4
4 files changed, 42 insertions, 207 deletions
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 05ad9f6336..2f6963f6b8 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -535,6 +535,45 @@ class RNNTest(test.TestCase):
self.assertAllClose(tf_out, k_out)
self.assertAllClose(tf_state, k_state)
+ def testSimpleRNNCellAndBasicRNNCellComparison(self):
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 20
+ (x_train, _), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ fix_weights_generator = keras.layers.SimpleRNNCell(output_shape)
+ fix_weights_generator.build((None, input_shape))
+ # The SimpleRNNCell contains 3 weights: kernel, recurrent_kernel, and bias
+ # The BasicRNNCell contains 2 weight: kernel and bias, where kernel is
+ # zipped [kernel, recurrent_kernel] in SimpleRNNCell.
+ keras_weights = fix_weights_generator.get_weights()
+ kernel, recurrent_kernel, bias = keras_weights
+ tf_weights = [np.concatenate((kernel, recurrent_kernel)), bias]
+
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ cell = keras.layers.SimpleRNNCell(output_shape)
+ k_out, k_state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ cell.set_weights(keras_weights)
+ [k_out, k_state] = sess.run([k_out, k_state], {inputs: x_train})
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ cell = rnn_cell_impl.BasicRNNCell(output_shape)
+ tf_out, tf_state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ cell.set_weights(tf_weights)
+ [tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})
+
+ self.assertAllClose(tf_out, k_out)
+ self.assertAllClose(tf_state, k_state)
+
def testBasicLSTMCellInterchangeWithLSTMCell(self):
with self.session(graph=ops_lib.Graph()) as sess:
basic_cell = rnn_cell_impl.BasicLSTMCell(1)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index c2751e529a..dd4f3d7a99 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -370,7 +370,7 @@ class LayerRNNCell(RNNCell):
*args, **kwargs)
-@tf_export("nn.rnn_cell.BasicRNNCell")
+@tf_export(v1=["nn.rnn_cell.BasicRNNCell"])
class BasicRNNCell(LayerRNNCell):
"""The most basic RNN cell.
@@ -393,6 +393,8 @@ class BasicRNNCell(LayerRNNCell):
`trainable` etc when constructing the cell from configs of get_config().
"""
+ @deprecated(None, "This class is equivalent as tf.keras.layers.SimpleRNNCell,"
+ " and will be replaced by that in Tensorflow 2.0.")
def __init__(self,
num_units,
activation=None,
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
deleted file mode 100644
index a4483fefa2..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
+++ /dev/null
@@ -1,202 +0,0 @@
-path: "tensorflow.nn.rnn_cell.BasicRNNCell"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.BasicRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LayerRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
index 64697e8a02..24767e250f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
@@ -5,10 +5,6 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
- name: "BasicRNNCell"
- mtype: "<type \'type\'>"
- }
- member {
name: "DeviceWrapper"
mtype: "<type \'type\'>"
}