diff options
author | 2016-09-09 09:30:10 -0800 | |
---|---|---|
committer | 2016-09-09 10:33:18 -0700 | |
commit | b0ce8deae4f8b0b24c8d8e18c4f62c3b1927f9d8 (patch) | |
tree | 6f9808b7d47a26783386ad9eda3d28a1ae096f05 | |
parent | bd548874a53c5f934e480991e9ca730b8d73657a (diff) |
Some cleanup of dynamic_rnn_estimator:
- Changed `logits` to `activations` to include regression as well as classification.
- Added name_scope to several functions.
Change: 132689120
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py | 321 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py | 127 |
2 files changed, 240 insertions, 208 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 05cd63dbf6..f6e9d71263 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -54,60 +54,69 @@ def _padding_mask(sequence_lengths, padded_length): array_ops.expand_dims(sequence_lengths, 1)) -def _mask_logits_and_targets(logits, targets, sequence_length): - """Remove entries outside `sequence_length` and returned flattened results. +def _mask_activations_and_targets(activations, targets, sequence_lengths): + """Remove entries outside `sequence_lengths` and returned flattened results. Args: - logits: output of the RNN, shape `[batch_size, padded_length, k]`. + activations: output of the RNN, shape `[batch_size, padded_length, k]`. targets: target values, shape `[batch_size, padded_length]`. - sequence_length: a `Tensor` of shape `[batch_size]` with the unpadded length - of each sequence. If `None`, then each sequence is unpadded. + sequence_lengths: a `Tensor` of shape `[batch_size]` with the unpadded + length of each sequence. If `None`, then each sequence is unpadded. Returns: - logits_masked: `logit` values with those beyond `sequence_length` removed - for each batch. Batches are then concatenated. Shape - `[tf.sum(sequence_length), k]` if `sequence_length` is not `None` and + activations_masked: `logit` values with those beyond `sequence_lengths` + removed for each batch. Batches are then concatenated. Shape + `[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and shape `[batch_size * padded_length, k]` otherwise. targets_masked: target values after removing unneeded entries. Shape - `[tf.sum(sequence_length)]` if `sequence_length` is not `None` and shape + `[tf.sum(sequence_lengths)]` if `sequence_lengths` is not `None` and shape `[batch_size * padded_length]` otherwise. """ - targets_shape = array_ops.shape(targets) - batch_size = targets_shape[0] - padded_length = targets_shape[1] - if sequence_length is None: - flattened_dimension = padded_length * batch_size - logits_masked = array_ops.reshape(logits, [flattened_dimension, -1]) - targets_masked = array_ops.reshape(targets, [flattened_dimension]) - else: - mask = _padding_mask(sequence_length, padded_length) - logits_masked = array_ops.boolean_mask(logits, mask) - targets_masked = array_ops.boolean_mask(targets, mask) - return logits_masked, targets_masked - - -def _select_last_logits(logits, sequence_lengths): - """Selects the nth set of logits for each n in `sequence_length`. + with ops.name_scope('mask_activations_and_targets', + values=[activations, targets, sequence_lengths]): + targets_shape = array_ops.shape(targets) + batch_size = targets_shape[0] + padded_length = targets_shape[1] + if sequence_lengths is None: + flattened_dimension = padded_length * batch_size + activations_masked = array_ops.reshape(activations, + [flattened_dimension, -1]) + targets_masked = array_ops.reshape(targets, [flattened_dimension]) + else: + mask = _padding_mask(sequence_lengths, padded_length) + activations_masked = array_ops.boolean_mask(activations, mask) + targets_masked = array_ops.boolean_mask(targets, mask) + return activations_masked, targets_masked + + +def _select_last_activations(activations, sequence_lengths): + """Selects the nth set of activations for each n in `sequence_length`. Reuturns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not - `None`, then `output[i, :] = logits[i, sequence_length[i], :]`. If - `sequence_length` is `None`, then `output[i, :] = logits[i, -1, :]`. + `None`, then `output[i, :] = activations[i, sequence_length[i], :]`. If + `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`. Args: - logits: a `Tensor` with shape `[batch_size, padded_length, k]`. + activations: a `Tensor` with shape `[batch_size, padded_length, k]`. sequence_lengths: a `Tensor` with shape `[batch_size]` or `None`. Returns: A `Tensor` of shape `[batch_size, k]`. """ - logits_shape = array_ops.shape(logits) - batch_size = logits_shape[0] - padded_length = logits_shape[1] - num_classes = logits_shape[2] - if sequence_lengths is None: - sequence_lengths = padded_length - reshaped_logits = array_ops.reshape(logits, [-1, num_classes]) - indices = math_ops.range(batch_size) * padded_length + sequence_lengths - 1 - return array_ops.gather(reshaped_logits, indices) + with ops.name_scope('select_last_activations', + values=[activations, sequence_lengths]): + activations_shape = array_ops.shape(activations) + batch_size = activations_shape[0] + padded_length = activations_shape[1] + num_label_columns = activations_shape[2] + if sequence_lengths is None: + sequence_lengths = padded_length + reshaped_activations = array_ops.reshape(activations, + [-1, num_label_columns]) + indices = math_ops.range(batch_size) * padded_length + sequence_lengths - 1 + last_activations = array_ops.gather(reshaped_activations, indices) + last_activations.set_shape( + [activations.get_shape()[0], activations.get_shape()[2]]) + return last_activations @six.add_metaclass(abc.ABCMeta) @@ -126,7 +135,8 @@ class _DynamicRNNEstimator(estimator.BaseEstimator): initial_state_key='initial_state', dtype=None, parallel_iterations=None, - swap_memory=False): + swap_memory=False, + name=None): """Initialize `DynamicRNNEstimator`. Args: @@ -152,6 +162,7 @@ class _DynamicRNNEstimator(estimator.BaseEstimator): swap_memory: Parameter passed ot `dynamic_rnn`. Transparently swap the tensors produced in forward inference but needed for back prop from GPU to CPU. + name: Optional name for the `Estimator`. """ super(_DynamicRNNEstimator, self).__init__( model_dir=model_dir, config=config) @@ -165,6 +176,7 @@ class _DynamicRNNEstimator(estimator.BaseEstimator): self._dtype = dtype or dtypes.float32 self._parallel_iterations = parallel_iterations self._swap_memory = swap_memory + self._name = name or 'DynamicRnnEstimator' def _construct_rnn(self, features): """Apply an RNN to `features`. @@ -173,7 +185,7 @@ class _DynamicRNNEstimator(estimator.BaseEstimator): input should be a `Tensor` of shape `[batch_size, padded_length, k]` where `k` is the dimension of the input for each element of a sequence. - `logits` has shape `[batch_size, sequence_length, n]` where `n` is + `activations` has shape `[batch_size, sequence_length, n]` where `n` is `self._target_column.num_label_columns`. In the case of a multiclass classifier, `n` is the number of classes. @@ -185,40 +197,41 @@ class _DynamicRNNEstimator(estimator.BaseEstimator): initial state and information about sequence lengths. Returns: - logits: the output of the RNN, projected to the appropriate number of + activations: the output of the RNN, projected to the appropriate number of dimensions. final_state: the final state output by the RNN. Raises: KeyError: if `features` does not contain `self._inputs_key`. """ - inputs = features.get(self._inputs_key) - if inputs is None: - raise KeyError('features must contain the key {}'.format( - self._inputs_key)) - if inputs.dtype != self._dtype: - inputs = math_ops.cast(inputs, self._dtype) - initial_state = features.get(self._initial_state_key) - rnn_outputs, final_state = rnn.dynamic_rnn( - cell=self._cell, - inputs=inputs, - initial_state=initial_state, - dtype=self._dtype, - parallel_iterations=self._parallel_iterations, - swap_memory=self._swap_memory, - time_major=False) - logits = layers.fully_connected( - inputs=rnn_outputs, - num_outputs=self._target_column.num_label_columns, - activation_fn=None, - trainable=False) - return logits, final_state + with ops.name_scope('RNN'): + inputs = features.get(self._inputs_key) + if inputs is None: + raise KeyError('features must contain the key {}'.format( + self._inputs_key)) + if inputs.dtype != self._dtype: + inputs = math_ops.cast(inputs, self._dtype) + initial_state = features.get(self._initial_state_key) + rnn_outputs, final_state = rnn.dynamic_rnn( + cell=self._cell, + inputs=inputs, + initial_state=initial_state, + dtype=self._dtype, + parallel_iterations=self._parallel_iterations, + swap_memory=self._swap_memory, + time_major=False) + activations = layers.fully_connected( + inputs=rnn_outputs, + num_outputs=self._target_column.num_label_columns, + activation_fn=None, + trainable=False) + return activations, final_state @abc.abstractmethod - def _logits_to_loss(self, features, logits, targets): - """Map `logits` and `targets` to a loss `Tensor`. + def _activations_to_loss(self, features, activations, targets): + """Map `activations` and `targets` to a loss `Tensor`. - `logits` has shape `[batch_size, padded_length, + `activations` has shape `[batch_size, padded_length, self._target_column.num_label_columns]`. It is the output of `_construct_rnn`. @@ -229,7 +242,7 @@ class _DynamicRNNEstimator(estimator.BaseEstimator): features: a `dict` containing the input and (optionally) sequence length information and initial state. This is the same `features` passed to `_construct_rnn`. - logits: a `Tensor` of logits representing the output of the RNN. + activations: a `Tensor` of activations representing the output of the RNN. targets: a `Tensor` of target values. Returns: @@ -238,17 +251,17 @@ class _DynamicRNNEstimator(estimator.BaseEstimator): raise NotImplementedError() @abc.abstractmethod - def _logits_to_predictions(self, features, logits): - """Map `logits` to predictions. + def _activations_to_predictions(self, features, activations): + """Map `activations` to predictions. - `logits` has shape [batch_size, time, num_labels]. `TargetColumn`s require - shape [n, num_labels]. `logits` is flattened before being converted to - labels. Afterwards, its shape is reconstituted. + `activations` has shape [batch_size, time, num_labels]. `TargetColumn`s + require shape [n, num_labels]. `activations` is flattened before being + converted to labels. Afterwards, its shape is reconstituted. Args: features: a `dict` containing the input and (optionally) sequence length information and initial state. - logits: logit values returned by `_construct_rnn`. + activations: logit values returned by `_construct_rnn`. Returns: A set of predictions. The type of prediction is dependent on @@ -258,35 +271,37 @@ class _DynamicRNNEstimator(estimator.BaseEstimator): def _process_gradients(self, gradients_vars): """Process gradients (e.g. clipping) before applying them to weights.""" - gradients, variables = zip(*gradients_vars) - if self._gradient_clipping_norm is not None: - gradients, _ = clip_ops.clip_by_global_norm( - gradients, self._gradient_clipping_norm) - return zip(gradients, variables) + with ops.name_scope('process_gradients'): + gradients, variables = zip(*gradients_vars) + if self._gradient_clipping_norm is not None: + gradients, _ = clip_ops.clip_by_global_norm( + gradients, self._gradient_clipping_norm) + return zip(gradients, variables) def _loss_to_train_op(self, loss): """Map `loss` to a training op.""" - trainable_variables = ops.get_default_graph().get_collection( - ops.GraphKeys.TRAINABLE_VARIABLES) - global_step = contrib_framework.get_global_step() - gradients = self._optimizer.compute_gradients( - loss=loss, var_list=trainable_variables) - processed_gradients = self._process_gradients(gradients) - return self._optimizer.apply_gradients( - processed_gradients, global_step=global_step) + with ops.name_scope('loss_to_train_op'): + trainable_variables = ops.get_default_graph().get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES) + global_step = contrib_framework.get_global_step() + gradients = self._optimizer.compute_gradients( + loss=loss, var_list=trainable_variables) + processed_gradients = self._process_gradients(gradients) + return self._optimizer.apply_gradients( + processed_gradients, global_step=global_step) @abc.abstractmethod - def _logits_to_eval_ops(self, features, logits, targets, metrics): - """Map `logits` to eval operations. + def _activations_to_eval_ops(self, features, activations, targets, metrics): + """Map `activations` to eval operations. - `logits` has shape [batch_size, time, num_labels]. `TargetColumn`s require - shape [n, num_labels]. `logits` is flattened before being converted to - labels. Afterwards, its shape is reconstituted. + `activations` has shape [batch_size, time, num_labels]. `TargetColumn`s + require shape [n, num_labels]. `activations` is flattened before being + converted to labels. Afterwards, its shape is reconstituted. Args: features: a `dict` containing the input and (optionally) sequence length information and initial state. - logits: logit values returned by `_construct_rnn`. + activations: logit values returned by `_construct_rnn`. targets: a `Tensor` of target values. metrics: a list of `Metric`s to evaluate. Possibly `None`. @@ -296,78 +311,90 @@ class _DynamicRNNEstimator(estimator.BaseEstimator): raise NotImplementedError() def _get_train_ops(self, features, targets): - if isinstance(features, ops.Tensor): - features = {self._inputs_key: features} - logits, _ = self._construct_rnn(features) - loss = self._logits_to_loss(features, logits, targets) - train_op = self._loss_to_train_op(loss) - return train_op, loss + with ops.name_scope(self._name): + if isinstance(features, ops.Tensor): + features = {self._inputs_key: features} + activations, _ = self._construct_rnn(features) + loss = self._activations_to_loss(features, activations, targets) + train_op = self._loss_to_train_op(loss) + return train_op, loss def _get_eval_ops(self, features, targets, metrics): - if isinstance(features, ops.Tensor): - features = {self._inputs_key: features} - logits, _ = self._construct_rnn(features) - return self._logits_to_eval_ops(features, logits, targets, metrics) + with ops.name_scope(self._name): + if isinstance(features, ops.Tensor): + features = {self._inputs_key: features} + activations, _ = self._construct_rnn(features) + return self._activations_to_eval_ops(features, activations, targets, + metrics) def _get_predict_ops(self, features): - if isinstance(features, ops.Tensor): - features = {self._inputs_key: features} - logits, state = self._construct_rnn(features) - predictions = self._logits_to_predictions(features, logits) - return {'predictions': predictions, 'state': state} + with ops.name_scope(self._name): + if isinstance(features, ops.Tensor): + features = {self._inputs_key: features} + activations, state = self._construct_rnn(features) + predictions = self._activations_to_predictions(features, activations) + return {'predictions': predictions, 'state': state} class _MultiValueRNNEstimator(_DynamicRNNEstimator): """An `Estimator` that maps sequences of inputs to sequences of outputs.""" - def _logits_to_loss(self, features, logits, targets): + def _activations_to_loss(self, features, activations, targets): sequence_length = features.get(self._sequence_length_key) - # Mask the logits and targets past `sequence_length`. Note that the - # `Tensor`s returned by `_mask_logits_and_targets` are flattened. - logits_masked, targets_masked = _mask_logits_and_targets(logits, targets, - sequence_length) - return self._target_column.loss(logits_masked, targets_masked, features) - - def _logits_to_predictions(self, unused_features, logits): - logit_shape = array_ops.shape(logits) - flattened_logits = array_ops.reshape(logits, [-1, logit_shape[2]]) - predictions = self._target_column.logits_to_predictions( - flattened_logits, proba=False) - reshaped_predictions = array_ops.reshape( - predictions, [logit_shape[0], logit_shape[1], -1]) - return array_ops.squeeze(reshaped_predictions, [2]) - - def _logits_to_eval_ops(self, features, logits, targets, metrics): - logits_masked, targets_masked = _mask_logits_and_targets( - logits, targets, features.get(self._sequence_length_key)) - - return self._target_column.get_eval_ops(features=features, - logits=logits_masked, - targets=targets_masked, - metrics=metrics) + # Mask the activations and targets past `sequence_length`. Note that the + # `Tensor`s returned by `_mask_activations_and_targets` are flattened. + with ops.name_scope('activations_to_loss'): + activations_masked, targets_masked = _mask_activations_and_targets( + activations, targets, sequence_length) + return self._target_column.loss(activations_masked, targets_masked, + features) + + def _activations_to_predictions(self, unused_features, activations): + with ops.name_scope('activations_to_predictions'): + activations_shape = array_ops.shape(activations) + flattened_activations = array_ops.reshape(activations, + [-1, activations_shape[2]]) + predictions = self._target_column.activations_to_predictions( + flattened_activations, proba=False) + reshaped_predictions = array_ops.reshape( + predictions, [activations_shape[0], activations_shape[1], -1]) + return array_ops.squeeze(reshaped_predictions, [2]) + + def _activations_to_eval_ops(self, features, activations, targets, metrics): + with ops.name_scope('activations_to_eval_ops'): + activations_masked, targets_masked = _mask_activations_and_targets( + activations, targets, features.get(self._sequence_length_key)) + + return self._target_column.get_eval_ops(features=features, + logits=activations_masked, + targets=targets_masked, + metrics=metrics) class _SingleValueRNNEstimator(_DynamicRNNEstimator): """An `Estimator` that maps sequences of inputs to single outputs.""" - def _logits_to_loss(self, features, logits, targets): - sequence_lengths = features.get(self._sequence_length_key) - last_logits = _select_last_logits(logits, sequence_lengths) - return self._target_column.loss(last_logits, targets, features) - - def _logits_to_predictions(self, features, logits): - sequence_lengths = features.get(self._sequence_length_key) - last_logits = _select_last_logits(logits, sequence_lengths) - return self._target_column.logits_to_predictions( - last_logits, proba=False) - - def _logits_to_eval_ops(self, features, logits, targets, metrics): - sequence_lengths = features.get(self._sequence_length_key) - last_logits = _select_last_logits(logits, sequence_lengths) - return self._target_column.get_eval_ops(features=features, - logits=last_logits, - targets=targets, - metrics=metrics) + def _activations_to_loss(self, features, activations, targets): + with ops.name_scope('activations_to_loss'): + sequence_lengths = features.get(self._sequence_length_key) + last_activations = _select_last_activations(activations, sequence_lengths) + return self._target_column.loss(last_activations, targets, features) + + def _activations_to_predictions(self, features, activations): + with ops.name_scope('activations_to_predictions'): + sequence_lengths = features.get(self._sequence_length_key) + last_activations = _select_last_activations(activations, sequence_lengths) + return self._target_column.activations_to_predictions( + last_activations, proba=False) + + def _activations_to_eval_ops(self, features, activations, targets, metrics): + with ops.name_scope('activations_to_eval_ops'): + sequence_lengths = features.get(self._sequence_length_key) + last_activations = _select_last_activations(activations, sequence_lengths) + return self._target_column.get_eval_ops(features=features, + logits=last_activations, + targets=targets, + metrics=metrics) def _get_optimizer(optimizer_type, learning_rate, momentum): diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index 37e829cf12..1ee3a8dd60 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -47,15 +47,15 @@ class MockTargetColumn(object): def __init__(self): self._num_label_columns = None - def get_eval_ops(self, features, logits, targets, metrics): + def get_eval_ops(self, features, activations, targets, metrics): raise NotImplementedError( 'MockTargetColumn.get_eval_ops called unexpectedly.') - def logits_to_predictions(self, flattened_logits, proba=False): + def activations_to_predictions(self, flattened_activations, proba=False): raise NotImplementedError( - 'MockTargetColumn.logits_to_predictions called unexpectedly.') + 'MockTargetColumn.activations_to_predictions called unexpectedly.') - def loss(self, logits, targets, features): + def loss(self, activations, targets, features): raise NotImplementedError('MockTargetColumn.loss called unexpectedly.') @property @@ -120,26 +120,26 @@ class DynamicRnnEstimatorTest(tf.test.TestCase): 'sequence_length': tf.constant( sequence_length, dtype=tf.int32)} - # Map feature to logits with mocked linear layer. + # Map feature to activations with mocked linear layer. with tf.test.mock.patch.object(dynamic_rnn_estimator, 'layers') as mock_layers: mock_layers.fully_connected.return_value = tf.constant( mock_linear_layer_output, dtype=tf.float32) - logits_t, final_state_t = self._rnn_estimator._construct_rnn( + activations_t, final_state_t = self._rnn_estimator._construct_rnn( features) _, fully_connected_kwargs = mock_layers.fully_connected.call_args linear_layer_inputs_t = fully_connected_kwargs['inputs'] linear_layer_output_dim = fully_connected_kwargs['num_outputs'] - # Obtain values of linear layer input, logits and final state. + # Obtain values of linear layer input, activations and final state. with tf.Session() as sess: sess.run(tf.initialize_all_variables()) - linear_layer_inputs, logits, final_state = sess.run( - [linear_layer_inputs_t, logits_t, final_state_t]) + linear_layer_inputs, activations, final_state = sess.run( + [linear_layer_inputs_t, activations_t, final_state_t]) np.testing.assert_equal(num_classes, linear_layer_output_dim) np.testing.assert_almost_equal(inputs, linear_layer_inputs) - np.testing.assert_almost_equal(mock_linear_layer_output, logits) + np.testing.assert_almost_equal(mock_linear_layer_output, activations) np.testing.assert_almost_equal( np.zeros([batch_size, self._rnn_cell.state_size], dtype=float), final_state) @@ -183,32 +183,33 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase): 'Mismatch on row {}. Got {}; expected {}.'.format( i, actual_mask, expected_mask)) - def testMaskLogitsAndTargets(self): - """Test `_mask_logits_and_targets`.""" + def testMaskActivationsAndTargets(self): + """Test `_mask_activations_and_targets`.""" batch_size = 4 padded_length = 6 num_classes = 4 np.random.seed(1234) sequence_length = np.random.randint(0, padded_length + 1, batch_size) - logits = np.random.rand(batch_size, padded_length, num_classes) + activations = np.random.rand(batch_size, padded_length, num_classes) targets = np.random.randint(0, num_classes, [batch_size, padded_length]) - (logits_masked_t, - targets_masked_t) = dynamic_rnn_estimator._mask_logits_and_targets( + (activations_masked_t, + targets_masked_t) = dynamic_rnn_estimator._mask_activations_and_targets( tf.constant( - logits, dtype=tf.float32), + activations, dtype=tf.float32), tf.constant( targets, dtype=tf.int32), tf.constant( sequence_length, dtype=tf.int32)) with tf.Session() as sess: - logits_masked, targets_masked = sess.run( - [logits_masked_t, targets_masked_t]) + activations_masked, targets_masked = sess.run( + [activations_masked_t, targets_masked_t]) - expected_logits_shape = [sum(sequence_length), num_classes] - np.testing.assert_equal(expected_logits_shape, logits_masked.shape, - 'Wrong logits shape. Expected {}; got {}.'.format( - expected_logits_shape, logits_masked.shape)) + expected_activations_shape = [sum(sequence_length), num_classes] + np.testing.assert_equal( + expected_activations_shape, activations_masked.shape, + 'Wrong activations shape. Expected {}; got {}.'.format( + expected_activations_shape, activations_masked.shape)) expected_targets_shape = [sum(sequence_length)] np.testing.assert_equal(expected_targets_shape, targets_masked.shape, @@ -217,14 +218,14 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase): masked_index = 0 for i in range(batch_size): for j in range(sequence_length[i]): - actual_logits = logits_masked[masked_index] - expected_logits = logits[i, j, :] + actual_activations = activations_masked[masked_index] + expected_activations = activations[i, j, :] np.testing.assert_almost_equal( - expected_logits, - actual_logits, + expected_activations, + actual_activations, err_msg='Unexpected logit value at index [{}, {}, :].' - ' Expected {}; got {}.'.format(i, j, expected_logits, - actual_logits)) + ' Expected {}; got {}.'.format(i, j, expected_activations, + actual_activations)) actual_targets = targets_masked[masked_index] expected_targets = targets[i, j] @@ -236,31 +237,34 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase): actual_targets)) masked_index += 1 - def testLogitsToPredictions(self): - """Test `DynamicRNNEstimator._logits_to_predictions`.""" + def testActivationsToPredictions(self): + """Test `DynamicRNNEstimator._activations_to_predictions`.""" batch_size = 8 sequence_length = 16 num_classes = 3 np.random.seed(10101) - logits = np.random.rand(batch_size, sequence_length, num_classes) - flattened_logits = np.reshape(logits, [-1, num_classes]) - flattened_argmax = np.argmax(flattened_logits, axis=1) - expected_predictions = np.argmax(logits, axis=2) - - with tf.test.mock.patch.object(self._mock_target_column, - 'logits_to_predictions', - return_value=flattened_argmax, - autospec=True) as mock_logits_to_predictions: - predictions_t = self._seq_estimator._logits_to_predictions( - None, tf.constant(logits, dtype=tf.float32)) - (target_column_input_logits_t,), _ = mock_logits_to_predictions.call_args + activations = np.random.rand(batch_size, sequence_length, num_classes) + flattened_activations = np.reshape(activations, [-1, num_classes]) + flattened_argmax = np.argmax(flattened_activations, axis=1) + expected_predictions = np.argmax(activations, axis=2) + + with tf.test.mock.patch.object( + self._mock_target_column, + 'activations_to_predictions', + return_value=flattened_argmax, + autospec=True) as mock_activations_to_predictions: + predictions_t = self._seq_estimator._activations_to_predictions( + None, tf.constant(activations, dtype=tf.float32)) + (target_column_input_activations_t, + ), _ = mock_activations_to_predictions.call_args with tf.Session() as sess: - target_column_input_logits, predictions = sess.run( - [target_column_input_logits_t, predictions_t]) + target_column_input_activations, predictions = sess.run( + [target_column_input_activations_t, predictions_t]) - np.testing.assert_almost_equal(flattened_logits, target_column_input_logits) + np.testing.assert_almost_equal(flattened_activations, + target_column_input_activations) np.testing.assert_equal(expected_predictions, predictions) def testLearnSineFunction(self): @@ -353,35 +357,36 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase): class SingleValueRNNEstimatorTest(tf.test.TestCase): - def testSelectLastLogits(self): - """Test `_select_last_logits`.""" + def testSelectLastactivations(self): + """Test `_select_last_activations`.""" batch_size = 4 padded_length = 6 num_classes = 4 np.random.seed(4444) sequence_length = np.random.randint(0, padded_length + 1, batch_size) - logits = np.random.rand(batch_size, padded_length, num_classes) - last_logits_t = dynamic_rnn_estimator._select_last_logits( - tf.constant(logits, dtype=tf.float32), + activations = np.random.rand(batch_size, padded_length, num_classes) + last_activations_t = dynamic_rnn_estimator._select_last_activations( + tf.constant(activations, dtype=tf.float32), tf.constant(sequence_length, dtype=tf.int32)) with tf.Session() as sess: - last_logits = sess.run(last_logits_t) + last_activations = sess.run(last_activations_t) - expected_logits_shape = [batch_size, num_classes] - np.testing.assert_equal(expected_logits_shape, last_logits.shape, - 'Wrong logits shape. Expected {}; got {}.'.format( - expected_logits_shape, last_logits.shape)) + expected_activations_shape = [batch_size, num_classes] + np.testing.assert_equal( + expected_activations_shape, last_activations.shape, + 'Wrong activations shape. Expected {}; got {}.'.format( + expected_activations_shape, last_activations.shape)) for i in range(batch_size): - actual_logits = last_logits[i, :] - expected_logits = logits[i, sequence_length[i] - 1, :] + actual_activations = last_activations[i, :] + expected_activations = activations[i, sequence_length[i] - 1, :] np.testing.assert_almost_equal( - expected_logits, - actual_logits, + expected_activations, + actual_activations, err_msg='Unexpected logit value at index [{}, :].' - ' Expected {}; got {}.'.format(i, expected_logits, - actual_logits)) + ' Expected {}; got {}.'.format(i, expected_activations, + actual_activations)) def testLearnMean(self): """Test that `_SequenceRegressor` can learn to calculate a mean.""" |