From 1e982bb8330c0c5a571559e8653c5e8b948a621e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 14 Feb 2017 18:43:30 -0800
Subject: Added RNN cell from Neural Architecture Search with Reinforcement
 Learning into TensorFlow contrib.

Change: 147549897
---
 tensorflow/contrib/rnn/__init__.py            |   4 +
 .../rnn/python/kernel_tests/rnn_cell_test.py  | 113 ++++++++++++++++
 tensorflow/contrib/rnn/python/ops/rnn_cell.py | 146 +++++++++++++++++++++
 3 files changed, 263 insertions(+)

diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py
index 86dee63dbd..4411590244 100644
--- a/tensorflow/contrib/rnn/__init__.py
+++ b/tensorflow/contrib/rnn/__init__.py
@@ -38,6 +38,9 @@
 @@CoupledInputForgetGateLSTMCell
 @@TimeFreqLSTMCell
 @@GridLSTMCell
+@@NASCell
+
+### RNNCell wrappers
 @@AttentionCellWrapper
 @@CompiledWrapper
 @@static_rnn
@@ -78,3 +81,4 @@ from tensorflow.contrib.rnn.python.ops.rnn_cell import *
 from tensorflow.python.util.all_util import remove_undocumented
 
 remove_undocumented(__name__, ['core_rnn_cell'])
+
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 19b5788f2d..7df0f4cb14 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -636,6 +636,119 @@ class RNNCellTest(test.TestCase):
       self.assertAllClose(sess.run(output), expected_output)
       self.assertAllClose(sess.run(state), expected_state)
 
+  def testNASCell(self):
+    num_units = 6
+    batch_size = 3
+    expected_output = np.array([[0.576751, 0.576751, 0.576751, 0.576751,
+                                 0.576751, 0.576751],
+                                [0.618936, 0.618936, 0.618936, 0.618936,
+                                 0.618936, 0.618936],
+                                [0.627393, 0.627393, 0.627393, 0.627393,
+                                 0.627393, 0.627393]])
+    expected_state = np.array([[0.71579772, 0.71579772, 0.71579772, 0.71579772,
+                                0.71579772, 0.71579772, 0.57675087, 0.57675087,
+                                0.57675087, 0.57675087, 0.57675087, 0.57675087],
+                               [0.78041625, 0.78041625, 0.78041625, 0.78041625,
+                                0.78041625, 0.78041625, 0.6189357, 0.6189357,
+                                0.61893570, 0.6189357, 0.6189357, 0.6189357],
+                               [0.79457647, 0.79457647, 0.79457647, 0.79457647,
+                                0.79457653, 0.79457653, 0.62739348, 0.62739348,
+                                0.62739348, 0.62739348, 0.62739348, 0.62739348]
+                              ])
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "nas_test",
+          initializer=init_ops.constant_initializer(0.5)):
+        cell = rnn_cell.NASCell(
+            num_units=num_units)
+        inputs = constant_op.constant(
+            np.array([[1., 1., 1., 1.],
+                      [2., 2., 2., 2.],
+                      [3., 3., 3., 3.]],
+                     dtype=np.float32),
+            dtype=dtypes.float32)
+        state_value = constant_op.constant(
+            0.1 * np.ones(
+                (batch_size, num_units), dtype=np.float32),
+            dtype=dtypes.float32)
+        init_state = core_rnn_cell_impl.LSTMStateTuple(state_value,
+                                                       state_value)
+        output, state = cell(inputs, init_state)
+        sess.run([variables.global_variables_initializer()])
+        res = sess.run([output, state])
+
+        # This is a smoke test: only making sure expected values don't change.
+        self.assertEqual(len(res), 2)
+        self.assertAllClose(res[0], expected_output)
+        # There should be 2 states in the tuple.
+        self.assertEqual(len(res[1]), 2)
+        # Checking that each state has shape batch_size x num_units.
+        new_c, new_h = res[1]
+        self.assertEqual(new_c.shape[0], batch_size)
+        self.assertEqual(new_c.shape[1], num_units)
+        self.assertEqual(new_h.shape[0], batch_size)
+        self.assertEqual(new_h.shape[1], num_units)
+        self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
+
+  def testNASCellProj(self):
+    num_units = 6
+    batch_size = 3
+    num_proj = 5
+    expected_output = np.array([[1.697418, 1.697418, 1.697418, 1.697418,
+                                 1.697418],
+                                [1.840037, 1.840037, 1.840037, 1.840037,
+                                 1.840037],
+                                [1.873985, 1.873985, 1.873985, 1.873985,
+                                 1.873985]])
+    expected_state = np.array([[0.69855207, 0.69855207, 0.69855207, 0.69855207,
+                                0.69855207, 0.69855207, 1.69741797, 1.69741797,
+                                1.69741797, 1.69741797, 1.69741797],
+                               [0.77073824, 0.77073824, 0.77073824, 0.77073824,
+                                0.77073824, 0.77073824, 1.84003687, 1.84003687,
+                                1.84003687, 1.84003687, 1.84003687],
+                               [0.78973997, 0.78973997, 0.78973997, 0.78973997,
+                                0.78973997, 0.78973997, 1.87398517, 1.87398517,
+                                1.87398517, 1.87398517, 1.87398517]])
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "nas_proj_test",
+          initializer=init_ops.constant_initializer(0.5)):
+        cell = rnn_cell.NASCell(
+            num_units=num_units,
+            num_proj=num_proj)
+        inputs = constant_op.constant(
+            np.array([[1., 1., 1., 1.],
+                      [2., 2., 2., 2.],
+                      [3., 3., 3., 3.]],
+                     dtype=np.float32),
+            dtype=dtypes.float32)
+        state_value_c = constant_op.constant(
+            0.1 * np.ones(
+                (batch_size, num_units), dtype=np.float32),
+            dtype=dtypes.float32)
+        state_value_h = constant_op.constant(
+            0.1 * np.ones(
+                (batch_size, num_proj), dtype=np.float32),
+            dtype=dtypes.float32)
+        init_state = core_rnn_cell_impl.LSTMStateTuple(state_value_c,
+                                                       state_value_h)
+        output, state = cell(inputs, init_state)
+        sess.run([variables.global_variables_initializer()])
+        res = sess.run([output, state])
+
+        # This is a smoke test: only making sure expected values don't change.
+        self.assertEqual(len(res), 2)
+        self.assertAllClose(res[0], expected_output)
+        # There should be 2 states in the tuple.
+        self.assertEqual(len(res[1]), 2)
+        # Checking shapes: c is batch x num_units, h is batch x num_proj.
+        new_c, new_h = res[1]
+        self.assertEqual(new_c.shape[0], batch_size)
+        self.assertEqual(new_c.shape[1], num_units)
+        self.assertEqual(new_h.shape[0], batch_size)
+        self.assertEqual(new_h.shape[1], num_proj)
+        self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
+
 
 class LayerNormBasicLSTMCellTest(test.TestCase):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 0191ce8302..0ce3e8c897 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -1254,6 +1254,152 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
     return new_h, new_state
 
 
+class NASCell(core_rnn_cell.RNNCell):
+  """Neural Architecture Search (NAS) recurrent network cell.
+
+  This implements the recurrent cell from the paper:
+
+    https://arxiv.org/abs/1611.01578
+
+  Barret Zoph and Quoc V. Le.
+  "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.
+
+  The class uses an optional projection layer.
+  """
+
+  def __init__(self, num_units, num_proj=None,
+               use_biases=False):
+    """Initialize the parameters for a NAS cell.
+
+    Args:
+      num_units: int, The number of units in the NAS cell.
+      num_proj: (optional) int, The output dimensionality for the projection
+        matrices. If None, no projection is performed.
+      use_biases: (optional) bool, If True, biases are used within the
+        cell. Defaults to False.
+    """
+    self._num_units = num_units
+    self._num_proj = num_proj
+    self._use_biases = use_biases
+
+    if num_proj is not None:
+      self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
+      self._output_size = num_proj
+    else:
+      self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
+      self._output_size = num_units
+
+  @property
+  def state_size(self):
+    return self._state_size
+
+  @property
+  def output_size(self):
+    return self._output_size
+
+  def __call__(self, inputs, state, scope=None):
+    """Run one step of the NAS cell.
+
+    Args:
+      inputs: input Tensor, 2D, batch x input_size.
+      state: This must be a tuple of state Tensors, both `2-D`, with column
+        sizes `c_state` and `m_state`.
+      scope: VariableScope for the created subgraph; defaults to "nas_rnn".
+
+    Returns:
+      A tuple containing:
+      - A 2-D, `[batch x output_dim]`, Tensor representing the output of the
+        NAS cell after reading `inputs` when the previous state was `state`.
+        Here output_dim is num_proj if num_proj was set, num_units otherwise.
+      - Tensor(s) representing the new state of the NAS cell after reading
+        `inputs` when the previous state was `state`. Same type and shape(s)
+        as `state`.
+
+    Raises:
+      ValueError: If the input size cannot be inferred from `inputs` via
+        static shape inference.
+    """
+    sigmoid = math_ops.sigmoid
+    tanh = math_ops.tanh
+    relu = nn_ops.relu
+
+    num_proj = self._num_units if self._num_proj is None else self._num_proj
+
+    (c_prev, m_prev) = state
+
+    dtype = inputs.dtype
+    input_size = inputs.get_shape().with_rank(2)[1]
+    if input_size.value is None:
+      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+    with vs.variable_scope(scope or "nas_rnn"):
+
+      # Variables for the NAS cell. W_m is all matrices multiplying the
+      # hidden state and W_inputs is all matrices multiplying the inputs.
+      concat_w_m = vs.get_variable(
+          "W_m", [num_proj, 8 * self._num_units],
+          dtype)
+      concat_w_inputs = vs.get_variable(
+          "W_inputs", [input_size.value, 8 * self._num_units],
+          dtype)
+
+      m_matrix = math_ops.matmul(m_prev, concat_w_m)
+      inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)
+
+      if self._use_biases:
+        b = vs.get_variable(
+            "B",
+            shape=[8 * self._num_units],
+            initializer=init_ops.zeros_initializer(),
+            dtype=dtype)
+        m_matrix = nn_ops.bias_add(m_matrix, b)
+
+      # The NAS cell branches into 8 different splits for both the hidden
+      # state and the input.
+      m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
+                                        value=m_matrix)
+      inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
+                                             value=inputs_matrix)
+
+      # First layer
+      layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
+      layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
+      layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
+      layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
+      layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
+      layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
+      layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
+      layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])
+
+      # Second layer
+      l2_0 = tanh(layer1_0 * layer1_1)
+      l2_1 = tanh(layer1_2 + layer1_3)
+      l2_2 = tanh(layer1_4 * layer1_5)
+      l2_3 = sigmoid(layer1_6 + layer1_7)
+
+      # Inject the cell state from the previous step.
+      l2_0 = tanh(l2_0 + c_prev)
+
+      # Third layer
+      l3_0_pre = l2_0 * l2_1
+      new_c = l3_0_pre  # this becomes the new cell state
+      l3_0 = l3_0_pre
+      l3_1 = tanh(l2_2 + l2_3)
+
+      # Final layer
+      new_m = tanh(l3_0 * l3_1)
+
+      # Projection layer if specified
+      if self._num_proj is not None:
+        concat_w_proj = vs.get_variable(
+            "W_P", [self._num_units, self._num_proj],
+            dtype)
+        new_m = math_ops.matmul(new_m, concat_w_proj)
+
+      new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
+      return new_m, new_state
+
+
 _REGISTERED_OPS = None
-- 
cgit v1.2.3
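
For readers trying out the new cell, a minimal usage sketch follows, assuming
the TF 1.x contrib API that this patch targets. The placeholder name and the
batch/time/depth sizes below are illustrative assumptions, not part of the
patch:

  # Minimal usage sketch for NASCell (assumes TF 1.x; sizes are arbitrary).
  import numpy as np
  import tensorflow as tf

  batch_size, max_time, input_depth = 3, 5, 4
  num_units, num_proj = 6, 5

  inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_depth])

  # With num_proj set, the output and the h part of the state have width
  # num_proj, while the c part keeps width num_units.
  cell = tf.contrib.rnn.NASCell(num_units=num_units, num_proj=num_proj)

  # outputs: [batch_size, max_time, num_proj]
  # final_state: LSTMStateTuple(c=[batch_size, num_units],
  #                             h=[batch_size, num_proj])
  outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(outputs, feed_dict={
        inputs: np.ones([batch_size, max_time, input_depth], np.float32)})
    print(out.shape)  # (3, 5, 5)

Unlike an LSTM, the cell has no explicit forget/input gates; the 8 splits of
the input and recurrent projections are combined through the fixed two-layer
tree of sigmoid/tanh/relu ops found by the architecture search, with c_prev
injected additively before the third layer.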