aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-14 18:43:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-14 19:18:07 -0800
commit1e982bb8330c0c5a571559e8653c5e8b948a621e (patch)
treed96bdf9f112b9e52bf6e24d682dcd307e1382ec6
parentdd51f989b8ca738da8a04970857597ed68fa1a15 (diff)
Added rnn cell from Neural Architecture Search with Reinforcement Learning into
Tensorflow contrib. Change: 147549897
-rw-r--r--tensorflow/contrib/rnn/__init__.py4
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py113
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py146
3 files changed, 263 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py
index 86dee63dbd..4411590244 100644
--- a/tensorflow/contrib/rnn/__init__.py
+++ b/tensorflow/contrib/rnn/__init__.py
@@ -38,6 +38,9 @@
@@CoupledInputForgetGateLSTMCell
@@TimeFreqLSTMCell
@@GridLSTMCell
+@@NASCell
+
+### RNNCell wrappers
@@AttentionCellWrapper
@@CompiledWrapper
@@static_rnn
@@ -78,3 +81,4 @@ from tensorflow.contrib.rnn.python.ops.rnn_cell import *
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__, ['core_rnn_cell'])
+
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 19b5788f2d..7df0f4cb14 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -636,6 +636,119 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(sess.run(output), expected_output)
self.assertAllClose(sess.run(state), expected_state)
+ def testNASCell(self):
+ num_units = 6
+ batch_size = 3
+ expected_output = np.array([[0.576751, 0.576751, 0.576751, 0.576751,
+ 0.576751, 0.576751],
+ [0.618936, 0.618936, 0.618936, 0.618936,
+ 0.618936, 0.618936],
+ [0.627393, 0.627393, 0.627393, 0.627393,
+ 0.627393, 0.627393]])
+ expected_state = np.array([[0.71579772, 0.71579772, 0.71579772, 0.71579772,
+ 0.71579772, 0.71579772, 0.57675087, 0.57675087,
+ 0.57675087, 0.57675087, 0.57675087, 0.57675087],
+ [0.78041625, 0.78041625, 0.78041625, 0.78041625,
+ 0.78041625, 0.78041625, 0.6189357, 0.6189357,
+ 0.61893570, 0.6189357, 0.6189357, 0.6189357],
+ [0.79457647, 0.79457647, 0.79457647, 0.79457647,
+ 0.79457653, 0.79457653, 0.62739348, 0.62739348,
+ 0.62739348, 0.62739348, 0.62739348, 0.62739348]
+ ])
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "nas_test",
+ initializer=init_ops.constant_initializer(0.5)):
+ cell = rnn_cell.NASCell(
+ num_units=num_units)
+ inputs = constant_op.constant(
+ np.array([[1., 1., 1., 1.],
+ [2., 2., 2., 2.],
+ [3., 3., 3., 3.]],
+ dtype=np.float32),
+ dtype=dtypes.float32)
+ state_value = constant_op.constant(
+ 0.1 * np.ones(
+ (batch_size, num_units), dtype=np.float32),
+ dtype=dtypes.float32)
+ init_state = core_rnn_cell_impl.LSTMStateTuple(state_value,
+ state_value)
+ output, state = cell(inputs, init_state)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([output, state])
+
+ # This is a smoke test: only making sure the expected values do not change.
+ self.assertEqual(len(res), 2)
+ self.assertAllClose(res[0], expected_output)
+ # There should be 2 states in the tuple.
+ self.assertEqual(len(res[1]), 2)
+ # Check that each state has shape batch_size x num_units.
+ new_c, new_h = res[1]
+ self.assertEqual(new_c.shape[0], batch_size)
+ self.assertEqual(new_c.shape[1], num_units)
+ self.assertEqual(new_h.shape[0], batch_size)
+ self.assertEqual(new_h.shape[1], num_units)
+ self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
+
+ def testNASCellProj(self):
+ num_units = 6
+ batch_size = 3
+ num_proj = 5
+ expected_output = np.array([[1.697418, 1.697418, 1.697418, 1.697418,
+ 1.697418],
+ [1.840037, 1.840037, 1.840037, 1.840037,
+ 1.840037],
+ [1.873985, 1.873985, 1.873985, 1.873985,
+ 1.873985]])
+ expected_state = np.array([[0.69855207, 0.69855207, 0.69855207, 0.69855207,
+ 0.69855207, 0.69855207, 1.69741797, 1.69741797,
+ 1.69741797, 1.69741797, 1.69741797],
+ [0.77073824, 0.77073824, 0.77073824, 0.77073824,
+ 0.77073824, 0.77073824, 1.84003687, 1.84003687,
+ 1.84003687, 1.84003687, 1.84003687],
+ [0.78973997, 0.78973997, 0.78973997, 0.78973997,
+ 0.78973997, 0.78973997, 1.87398517, 1.87398517,
+ 1.87398517, 1.87398517, 1.87398517]])
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "nas_proj_test",
+ initializer=init_ops.constant_initializer(0.5)):
+ cell = rnn_cell.NASCell(
+ num_units=num_units,
+ num_proj=num_proj)
+ inputs = constant_op.constant(
+ np.array([[1., 1., 1., 1.],
+ [2., 2., 2., 2.],
+ [3., 3., 3., 3.]],
+ dtype=np.float32),
+ dtype=dtypes.float32)
+ state_value_c = constant_op.constant(
+ 0.1 * np.ones(
+ (batch_size, num_units), dtype=np.float32),
+ dtype=dtypes.float32)
+ state_value_h = constant_op.constant(
+ 0.1 * np.ones(
+ (batch_size, num_proj), dtype=np.float32),
+ dtype=dtypes.float32)
+ init_state = core_rnn_cell_impl.LSTMStateTuple(state_value_c,
+ state_value_h)
+ output, state = cell(inputs, init_state)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([output, state])
+
+ # This is a smoke test: only making sure the expected values do not change.
+ self.assertEqual(len(res), 2)
+ self.assertAllClose(res[0], expected_output)
+ # There should be 2 states in the tuple.
+ self.assertEqual(len(res[1]), 2)
+ # Check the state shapes: c is batch_size x num_units, h is batch_size x num_proj.
+ new_c, new_h = res[1]
+ self.assertEqual(new_c.shape[0], batch_size)
+ self.assertEqual(new_c.shape[1], num_units)
+ self.assertEqual(new_h.shape[0], batch_size)
+ self.assertEqual(new_h.shape[1], num_proj)
+ self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
+
class LayerNormBasicLSTMCellTest(test.TestCase):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 0191ce8302..0ce3e8c897 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -1254,6 +1254,152 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
return new_h, new_state
+class NASCell(core_rnn_cell.RNNCell):
+ """Neural Architecture Search (NAS) recurrent network cell.
+
+ This implements the recurrent cell from the paper:
+
+ https://arxiv.org/abs/1611.01578
+
+ Barret Zoph and Quoc V. Le.
+ "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.
+
+ The class uses an optional projection layer.
+ """
+
+ def __init__(self, num_units, num_proj=None,
+ use_biases=False):
+ """Initialize the parameters for a NAS cell.
+
+ Args:
+ num_units: int, The number of units in the NAS cell
+ num_proj: (optional) int, The output dimensionality for the projection
+ matrices. If None, no projection is performed.
+ use_biases: (optional) bool, If True then use biases within the cell. This
+ is False by default.
+ """
+ self._num_units = num_units
+ self._num_proj = num_proj
+ self._use_biases = use_biases
+
+ if num_proj is not None:
+ self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
+ self._output_size = num_proj
+ else:
+ self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
+ self._output_size = num_units
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run one step of NAS Cell.
+
+ Args:
+ inputs: input Tensor, 2D, batch x input_size.
+ state: This must be a tuple of state Tensors `(c, m)`, both `2-D`, with column
+ sizes `num_units` for `c` and `num_proj` (or `num_units` if unset) for `m`.
+ scope: VariableScope for the created subgraph; defaults to "nas_rnn".
+
+ Returns:
+ A tuple containing:
+ - A `2-D, [batch x output_dim]`, Tensor representing the output of the
+ NAS Cell after reading `inputs` when previous state was `state`.
+ Here output_dim is:
+ num_proj if num_proj was set,
+ num_units otherwise.
+ - Tensor(s) representing the new state of NAS Cell after reading `inputs`
+ when the previous state was `state`. Same type and shape(s) as `state`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ sigmoid = math_ops.sigmoid
+ tanh = math_ops.tanh
+ relu = nn_ops.relu
+
+ num_proj = self._num_units if self._num_proj is None else self._num_proj
+
+ (c_prev, m_prev) = state
+
+ dtype = inputs.dtype
+ input_size = inputs.get_shape().with_rank(2)[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+ with vs.variable_scope(scope or "nas_rnn"):
+
+ # Variables for the NAS cell. W_m concatenates all matrices multiplying the
+ # hidden state and W_inputs concatenates all matrices multiplying the inputs.
+ concat_w_m = vs.get_variable(
+ "W_m", [num_proj, 8 * self._num_units],
+ dtype)
+ concat_w_inputs = vs.get_variable(
+ "W_inputs", [input_size.value, 8 * self._num_units],
+ dtype)
+
+ m_matrix = math_ops.matmul(m_prev, concat_w_m)
+ inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)
+
+ if self._use_biases:
+ b = vs.get_variable(
+ "B",
+ shape=[8 * self._num_units],
+ initializer=init_ops.zeros_initializer(),
+ dtype=dtype)
+ m_matrix = nn_ops.bias_add(m_matrix, b)
+
+ # The NAS cell branches into 8 different splits for both the hidden state
+ # and the input.
+ m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
+ value=m_matrix)
+ inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
+ value=inputs_matrix)
+
+ # First layer
+ layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
+ layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
+ layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
+ layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
+ layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
+ layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
+ layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
+ layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])
+
+ # Second layer
+ l2_0 = tanh(layer1_0 * layer1_1)
+ l2_1 = tanh(layer1_2 + layer1_3)
+ l2_2 = tanh(layer1_4 * layer1_5)
+ l2_3 = sigmoid(layer1_6 + layer1_7)
+
+ # Inject the previous cell state (c_prev)
+ l2_0 = tanh(l2_0 + c_prev)
+
+ # Third layer
+ l3_0_pre = l2_0 * l2_1
+ new_c = l3_0_pre # create new cell
+ l3_0 = l3_0_pre
+ l3_1 = tanh(l2_2 + l2_3)
+
+ # Final layer
+ new_m = tanh(l3_0 * l3_1)
+
+ # Projection layer if specified
+ if self._num_proj is not None:
+ concat_w_proj = vs.get_variable(
+ "W_P", [self._num_units, self._num_proj],
+ dtype)
+ new_m = math_ops.matmul(new_m, concat_w_proj)
+
+ new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
+ return new_m, new_state
+
+
_REGISTERED_OPS = None