author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-09-09 09:30:10 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-09-09 10:33:18 -0700
commit  b0ce8deae4f8b0b24c8d8e18c4f62c3b1927f9d8 (patch)
tree    6f9808b7d47a26783386ad9eda3d28a1ae096f05
parent  bd548874a53c5f934e480991e9ca730b8d73657a (diff)
Some cleanup of dynamic_rnn_estimator:
- Changed `logits` to `activations` to include regression as well as classification.
- Added name_scope to several functions.

Change: 132689120
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py       | 321
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py  | 127
2 files changed, 240 insertions(+), 208 deletions(-)
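
The commit message's two changes are mechanical but worth unpacking. The projected RNN outputs are logits only for classification; for regression they are raw predicted values, so `activations` covers both. The added `name_scope` calls group each stage's ops under a common prefix in the graph. A minimal sketch of the effect (illustrative names, not the estimator's API):

import tensorflow as tf

with tf.name_scope('activations_to_loss'):
  # Every op created inside this block is prefixed 'activations_to_loss/',
  # which groups the estimator's stages when the graph is visualized.
  activations = tf.constant([[0.1, 0.9], [0.8, 0.2]])
  loss = tf.reduce_mean(tf.square(activations))

print(loss.op.name)  # activations_to_loss/Mean
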
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
index 05cd63dbf6..f6e9d71263 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
@@ -54,60 +54,69 @@ def _padding_mask(sequence_lengths, padded_length):
array_ops.expand_dims(sequence_lengths, 1))
-def _mask_logits_and_targets(logits, targets, sequence_length):
- """Remove entries outside `sequence_length` and returned flattened results.
+def _mask_activations_and_targets(activations, targets, sequence_lengths):
+ """Remove entries outside `sequence_lengths` and returned flattened results.
Args:
- logits: output of the RNN, shape `[batch_size, padded_length, k]`.
+ activations: output of the RNN, shape `[batch_size, padded_length, k]`.
targets: target values, shape `[batch_size, padded_length]`.
- sequence_length: a `Tensor` of shape `[batch_size]` with the unpadded length
- of each sequence. If `None`, then each sequence is unpadded.
+ sequence_lengths: a `Tensor` of shape `[batch_size]` with the unpadded
+ length of each sequence. If `None`, then each sequence is unpadded.
Returns:
- logits_masked: `logit` values with those beyond `sequence_length` removed
- for each batch. Batches are then concatenated. Shape
- `[tf.sum(sequence_length), k]` if `sequence_length` is not `None` and
+ activations_masked: activation values with those beyond `sequence_lengths`
+ removed for each batch. Batches are then concatenated. Shape
+ `[tf.sum(sequence_lengths), k]` if `sequence_lengths` is not `None` and
shape `[batch_size * padded_length, k]` otherwise.
targets_masked: target values after removing unneeded entries. Shape
- `[tf.sum(sequence_length)]` if `sequence_length` is not `None` and shape
+ `[tf.sum(sequence_lengths)]` if `sequence_lengths` is not `None` and shape
`[batch_size * padded_length]` otherwise.
"""
- targets_shape = array_ops.shape(targets)
- batch_size = targets_shape[0]
- padded_length = targets_shape[1]
- if sequence_length is None:
- flattened_dimension = padded_length * batch_size
- logits_masked = array_ops.reshape(logits, [flattened_dimension, -1])
- targets_masked = array_ops.reshape(targets, [flattened_dimension])
- else:
- mask = _padding_mask(sequence_length, padded_length)
- logits_masked = array_ops.boolean_mask(logits, mask)
- targets_masked = array_ops.boolean_mask(targets, mask)
- return logits_masked, targets_masked
-
-
-def _select_last_logits(logits, sequence_lengths):
- """Selects the nth set of logits for each n in `sequence_length`.
+ with ops.name_scope('mask_activations_and_targets',
+ values=[activations, targets, sequence_lengths]):
+ targets_shape = array_ops.shape(targets)
+ batch_size = targets_shape[0]
+ padded_length = targets_shape[1]
+ if sequence_lengths is None:
+ flattened_dimension = padded_length * batch_size
+ activations_masked = array_ops.reshape(activations,
+ [flattened_dimension, -1])
+ targets_masked = array_ops.reshape(targets, [flattened_dimension])
+ else:
+ mask = _padding_mask(sequence_lengths, padded_length)
+ activations_masked = array_ops.boolean_mask(activations, mask)
+ targets_masked = array_ops.boolean_mask(targets, mask)
+ return activations_masked, targets_masked
+
+
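
For illustration, the masking above can be reproduced standalone with the public `tf.boolean_mask` alias of the same op; the mask construction mirrors `_padding_mask`, and the output shapes match the docstring (a sketch, not the estimator's code path):

import numpy as np
import tensorflow as tf

batch_size, padded_length, k = 2, 3, 4
activations = np.random.rand(batch_size, padded_length, k).astype(np.float32)
targets = np.random.randint(0, 2, [batch_size, padded_length])
sequence_lengths = np.array([1, 3])

# mask[i, j] is True iff j < sequence_lengths[i], as in _padding_mask.
mask = np.arange(padded_length) < sequence_lengths[:, None]

activations_masked = tf.boolean_mask(tf.constant(activations), tf.constant(mask))
targets_masked = tf.boolean_mask(tf.constant(targets), tf.constant(mask))

with tf.Session() as sess:
  a_masked, t_masked = sess.run([activations_masked, targets_masked])
print(a_masked.shape)  # (4, 4): [sum(sequence_lengths), k]
print(t_masked.shape)  # (4,):  [sum(sequence_lengths)]
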
+def _select_last_activations(activations, sequence_lengths):
+ """Selects the nth set of activations for each n in `sequence_length`.
Returns a `Tensor` of shape `[batch_size, k]`. If `sequence_lengths` is not
- `None`, then `output[i, :] = logits[i, sequence_length[i], :]`. If
- `sequence_length` is `None`, then `output[i, :] = logits[i, -1, :]`.
+ `None`, then `output[i, :] = activations[i, sequence_lengths[i] - 1, :]`. If
+ `sequence_lengths` is `None`, then `output[i, :] = activations[i, -1, :]`.
Args:
- logits: a `Tensor` with shape `[batch_size, padded_length, k]`.
+ activations: a `Tensor` with shape `[batch_size, padded_length, k]`.
sequence_lengths: a `Tensor` with shape `[batch_size]` or `None`.
Returns:
A `Tensor` of shape `[batch_size, k]`.
"""
- logits_shape = array_ops.shape(logits)
- batch_size = logits_shape[0]
- padded_length = logits_shape[1]
- num_classes = logits_shape[2]
- if sequence_lengths is None:
- sequence_lengths = padded_length
- reshaped_logits = array_ops.reshape(logits, [-1, num_classes])
- indices = math_ops.range(batch_size) * padded_length + sequence_lengths - 1
- return array_ops.gather(reshaped_logits, indices)
+ with ops.name_scope('select_last_activations',
+ values=[activations, sequence_lengths]):
+ activations_shape = array_ops.shape(activations)
+ batch_size = activations_shape[0]
+ padded_length = activations_shape[1]
+ num_label_columns = activations_shape[2]
+ if sequence_lengths is None:
+ sequence_lengths = padded_length
+ reshaped_activations = array_ops.reshape(activations,
+ [-1, num_label_columns])
+ indices = math_ops.range(batch_size) * padded_length + sequence_lengths - 1
+ last_activations = array_ops.gather(reshaped_activations, indices)
+ last_activations.set_shape(
+ [activations.get_shape()[0], activations.get_shape()[2]])
+ return last_activations
@six.add_metaclass(abc.ABCMeta)
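
The gather trick in `_select_last_activations` flattens time into the batch dimension, so the last valid step of sequence `i` lives at flat row `i * padded_length + sequence_lengths[i] - 1`. A hedged standalone sketch checking that against NumPy:

import numpy as np
import tensorflow as tf

batch_size, padded_length, k = 3, 5, 2
activations = np.random.rand(batch_size, padded_length, k).astype(np.float32)
sequence_lengths = np.array([2, 5, 1], dtype=np.int32)

# Flatten [batch, time, k] to [batch * time, k], then gather one row per batch.
reshaped = tf.reshape(tf.constant(activations), [-1, k])
indices = tf.range(batch_size) * padded_length + tf.constant(sequence_lengths) - 1
last = tf.gather(reshaped, indices)

with tf.Session() as sess:
  last_np = sess.run(last)

# Row i equals activations[i, sequence_lengths[i] - 1, :].
expected = activations[np.arange(batch_size), sequence_lengths - 1, :]
assert np.allclose(last_np, expected)
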
@@ -126,7 +135,8 @@ class _DynamicRNNEstimator(estimator.BaseEstimator):
initial_state_key='initial_state',
dtype=None,
parallel_iterations=None,
- swap_memory=False):
+ swap_memory=False,
+ name=None):
"""Initialize `DynamicRNNEstimator`.
Args:
@@ -152,6 +162,7 @@ class _DynamicRNNEstimator(estimator.BaseEstimator):
swap_memory: Parameter passed to `dynamic_rnn`. Transparently swap the
tensors produced in forward inference but needed for back prop from GPU
to CPU.
+ name: Optional name for the `Estimator`.
"""
super(_DynamicRNNEstimator, self).__init__(
model_dir=model_dir, config=config)
@@ -165,6 +176,7 @@ class _DynamicRNNEstimator(estimator.BaseEstimator):
self._dtype = dtype or dtypes.float32
self._parallel_iterations = parallel_iterations
self._swap_memory = swap_memory
+ self._name = name or 'DynamicRnnEstimator'
def _construct_rnn(self, features):
"""Apply an RNN to `features`.
@@ -173,7 +185,7 @@ class _DynamicRNNEstimator(estimator.BaseEstimator):
input should be a `Tensor` of shape `[batch_size, padded_length, k]`
where `k` is the dimension of the input for each element of a sequence.
- `logits` has shape `[batch_size, sequence_length, n]` where `n` is
+ `activations` has shape `[batch_size, sequence_length, n]` where `n` is
`self._target_column.num_label_columns`. In the case of a multiclass
classifier, `n` is the number of classes.
@@ -185,40 +197,41 @@ class _DynamicRNNEstimator(estimator.BaseEstimator):
initial state and information about sequence lengths.
Returns:
- logits: the output of the RNN, projected to the appropriate number of
+ activations: the output of the RNN, projected to the appropriate number of
dimensions.
final_state: the final state output by the RNN.
Raises:
KeyError: if `features` does not contain `self._inputs_key`.
"""
- inputs = features.get(self._inputs_key)
- if inputs is None:
- raise KeyError('features must contain the key {}'.format(
- self._inputs_key))
- if inputs.dtype != self._dtype:
- inputs = math_ops.cast(inputs, self._dtype)
- initial_state = features.get(self._initial_state_key)
- rnn_outputs, final_state = rnn.dynamic_rnn(
- cell=self._cell,
- inputs=inputs,
- initial_state=initial_state,
- dtype=self._dtype,
- parallel_iterations=self._parallel_iterations,
- swap_memory=self._swap_memory,
- time_major=False)
- logits = layers.fully_connected(
- inputs=rnn_outputs,
- num_outputs=self._target_column.num_label_columns,
- activation_fn=None,
- trainable=False)
- return logits, final_state
+ with ops.name_scope('RNN'):
+ inputs = features.get(self._inputs_key)
+ if inputs is None:
+ raise KeyError('features must contain the key {}'.format(
+ self._inputs_key))
+ if inputs.dtype != self._dtype:
+ inputs = math_ops.cast(inputs, self._dtype)
+ initial_state = features.get(self._initial_state_key)
+ rnn_outputs, final_state = rnn.dynamic_rnn(
+ cell=self._cell,
+ inputs=inputs,
+ initial_state=initial_state,
+ dtype=self._dtype,
+ parallel_iterations=self._parallel_iterations,
+ swap_memory=self._swap_memory,
+ time_major=False)
+ activations = layers.fully_connected(
+ inputs=rnn_outputs,
+ num_outputs=self._target_column.num_label_columns,
+ activation_fn=None,
+ trainable=False)
+ return activations, final_state
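
A minimal sketch of the `_construct_rnn` pattern outside the estimator, assuming a `GRUCell` and the era's `tf.nn.dynamic_rnn` / `tf.contrib.layers.fully_connected` aliases (variable names here are illustrative):

import tensorflow as tf

batch_size, padded_length, input_dim, num_label_columns = 4, 7, 3, 2
inputs = tf.placeholder(tf.float32, [batch_size, padded_length, input_dim])

cell = tf.nn.rnn_cell.GRUCell(num_units=8)
rnn_outputs, final_state = tf.nn.dynamic_rnn(
    cell=cell, inputs=inputs, dtype=tf.float32, time_major=False)

# Project every time step down to num_label_columns dimensions. With
# activation_fn=None these are raw activations: logits for classification,
# predicted values for regression.
activations = tf.contrib.layers.fully_connected(
    inputs=rnn_outputs, num_outputs=num_label_columns, activation_fn=None)
# activations: [batch_size, padded_length, num_label_columns]
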
@abc.abstractmethod
- def _logits_to_loss(self, features, logits, targets):
- """Map `logits` and `targets` to a loss `Tensor`.
+ def _activations_to_loss(self, features, activations, targets):
+ """Map `activations` and `targets` to a loss `Tensor`.
- `logits` has shape `[batch_size, padded_length,
+ `activations` has shape `[batch_size, padded_length,
self._target_column.num_label_columns]`. It is the output of
`_construct_rnn`.
@@ -229,7 +242,7 @@ class _DynamicRNNEstimator(estimator.BaseEstimator):
features: a `dict` containing the input and (optionally) sequence length
information and initial state. This is the same `features` passed to
`_construct_rnn`.
- logits: a `Tensor` of logits representing the output of the RNN.
+ activations: a `Tensor` of activations representing the output of the RNN.
targets: a `Tensor` of target values.
Returns:
@@ -238,17 +251,17 @@ class _DynamicRNNEstimator(estimator.BaseEstimator):
raise NotImplementedError()
@abc.abstractmethod
- def _logits_to_predictions(self, features, logits):
- """Map `logits` to predictions.
+ def _activations_to_predictions(self, features, activations):
+ """Map `activations` to predictions.
- `logits` has shape [batch_size, time, num_labels]. `TargetColumn`s require
- shape [n, num_labels]. `logits` is flattened before being converted to
- labels. Afterwards, its shape is reconstituted.
+ `activations` has shape [batch_size, time, num_labels]. `TargetColumn`s
+ require shape [n, num_labels]. `activations` is flattened before being
+ converted to labels. Afterwards, its shape is reconstituted.
Args:
features: a `dict` containing the input and (optionally) sequence length
information and initial state.
- logits: logit values returned by `_construct_rnn`.
+ activations: activation values returned by `_construct_rnn`.
Returns:
A set of predictions. The type of prediction is dependent on
@@ -258,35 +271,37 @@ class _DynamicRNNEstimator(estimator.BaseEstimator):
def _process_gradients(self, gradients_vars):
"""Process gradients (e.g. clipping) before applying them to weights."""
- gradients, variables = zip(*gradients_vars)
- if self._gradient_clipping_norm is not None:
- gradients, _ = clip_ops.clip_by_global_norm(
- gradients, self._gradient_clipping_norm)
- return zip(gradients, variables)
+ with ops.name_scope('process_gradients'):
+ gradients, variables = zip(*gradients_vars)
+ if self._gradient_clipping_norm is not None:
+ gradients, _ = clip_ops.clip_by_global_norm(
+ gradients, self._gradient_clipping_norm)
+ return zip(gradients, variables)
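
`clip_by_global_norm` rescales all gradients jointly so their combined norm does not exceed the threshold, preserving their relative directions. A small sketch, assuming a clipping norm of 1.0:

import tensorflow as tf

gradients = [tf.constant([3.0, 4.0]), tf.constant([0.0, 12.0])]
# The global norm is sqrt(3^2 + 4^2 + 12^2) = 13, so every gradient is
# scaled by 1/13 and the clipped list has global norm exactly 1.0.
clipped, global_norm = tf.clip_by_global_norm(gradients, clip_norm=1.0)

with tf.Session() as sess:
  print(sess.run(global_norm))  # 13.0
  print(sess.run(clipped))      # [[3/13, 4/13], [0, 12/13]]
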
def _loss_to_train_op(self, loss):
"""Map `loss` to a training op."""
- trainable_variables = ops.get_default_graph().get_collection(
- ops.GraphKeys.TRAINABLE_VARIABLES)
- global_step = contrib_framework.get_global_step()
- gradients = self._optimizer.compute_gradients(
- loss=loss, var_list=trainable_variables)
- processed_gradients = self._process_gradients(gradients)
- return self._optimizer.apply_gradients(
- processed_gradients, global_step=global_step)
+ with ops.name_scope('loss_to_train_op'):
+ trainable_variables = ops.get_default_graph().get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES)
+ global_step = contrib_framework.get_global_step()
+ gradients = self._optimizer.compute_gradients(
+ loss=loss, var_list=trainable_variables)
+ processed_gradients = self._process_gradients(gradients)
+ return self._optimizer.apply_gradients(
+ processed_gradients, global_step=global_step)
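
The train op above is the standard two-phase optimizer pattern: compute gradients, optionally process them, then apply them while incrementing the global step. A hedged sketch outside the estimator:

import tensorflow as tf

x = tf.Variable(2.0)
loss = tf.square(x)
global_step = tf.Variable(0, trainable=False, name='global_step')

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = optimizer.compute_gradients(loss)
# Gradient processing (e.g. the clipping above) slots between these two calls.
train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  sess.run(train_op)
  print(sess.run([x, global_step]))  # [1.6, 1]: d(x^2)/dx = 4, step is 0.1 * 4
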
@abc.abstractmethod
- def _logits_to_eval_ops(self, features, logits, targets, metrics):
- """Map `logits` to eval operations.
+ def _activations_to_eval_ops(self, features, activations, targets, metrics):
+ """Map `activations` to eval operations.
- `logits` has shape [batch_size, time, num_labels]. `TargetColumn`s require
- shape [n, num_labels]. `logits` is flattened before being converted to
- labels. Afterwards, its shape is reconstituted.
+ `activations` has shape [batch_size, time, num_labels]. `TargetColumn`s
+ require shape [n, num_labels]. `activations` is flattened before being
+ converted to labels. Afterwards, its shape is reconstituted.
Args:
features: a `dict` containing the input and (optionally) sequence length
information and initial state.
- logits: logit values returned by `_construct_rnn`.
+ activations: activation values returned by `_construct_rnn`.
targets: a `Tensor` of target values.
metrics: a list of `Metric`s to evaluate. Possibly `None`.
@@ -296,78 +311,90 @@ class _DynamicRNNEstimator(estimator.BaseEstimator):
raise NotImplementedError()
def _get_train_ops(self, features, targets):
- if isinstance(features, ops.Tensor):
- features = {self._inputs_key: features}
- logits, _ = self._construct_rnn(features)
- loss = self._logits_to_loss(features, logits, targets)
- train_op = self._loss_to_train_op(loss)
- return train_op, loss
+ with ops.name_scope(self._name):
+ if isinstance(features, ops.Tensor):
+ features = {self._inputs_key: features}
+ activations, _ = self._construct_rnn(features)
+ loss = self._activations_to_loss(features, activations, targets)
+ train_op = self._loss_to_train_op(loss)
+ return train_op, loss
def _get_eval_ops(self, features, targets, metrics):
- if isinstance(features, ops.Tensor):
- features = {self._inputs_key: features}
- logits, _ = self._construct_rnn(features)
- return self._logits_to_eval_ops(features, logits, targets, metrics)
+ with ops.name_scope(self._name):
+ if isinstance(features, ops.Tensor):
+ features = {self._inputs_key: features}
+ activations, _ = self._construct_rnn(features)
+ return self._activations_to_eval_ops(features, activations, targets,
+ metrics)
def _get_predict_ops(self, features):
- if isinstance(features, ops.Tensor):
- features = {self._inputs_key: features}
- logits, state = self._construct_rnn(features)
- predictions = self._logits_to_predictions(features, logits)
- return {'predictions': predictions, 'state': state}
+ with ops.name_scope(self._name):
+ if isinstance(features, ops.Tensor):
+ features = {self._inputs_key: features}
+ activations, state = self._construct_rnn(features)
+ predictions = self._activations_to_predictions(features, activations)
+ return {'predictions': predictions, 'state': state}
class _MultiValueRNNEstimator(_DynamicRNNEstimator):
"""An `Estimator` that maps sequences of inputs to sequences of outputs."""
- def _logits_to_loss(self, features, logits, targets):
+ def _activations_to_loss(self, features, activations, targets):
sequence_length = features.get(self._sequence_length_key)
- # Mask the logits and targets past `sequence_length`. Note that the
- # `Tensor`s returned by `_mask_logits_and_targets` are flattened.
- logits_masked, targets_masked = _mask_logits_and_targets(logits, targets,
- sequence_length)
- return self._target_column.loss(logits_masked, targets_masked, features)
-
- def _logits_to_predictions(self, unused_features, logits):
- logit_shape = array_ops.shape(logits)
- flattened_logits = array_ops.reshape(logits, [-1, logit_shape[2]])
- predictions = self._target_column.logits_to_predictions(
- flattened_logits, proba=False)
- reshaped_predictions = array_ops.reshape(
- predictions, [logit_shape[0], logit_shape[1], -1])
- return array_ops.squeeze(reshaped_predictions, [2])
-
- def _logits_to_eval_ops(self, features, logits, targets, metrics):
- logits_masked, targets_masked = _mask_logits_and_targets(
- logits, targets, features.get(self._sequence_length_key))
-
- return self._target_column.get_eval_ops(features=features,
- logits=logits_masked,
- targets=targets_masked,
- metrics=metrics)
+ # Mask the activations and targets past `sequence_length`. Note that the
+ # `Tensor`s returned by `_mask_activations_and_targets` are flattened.
+ with ops.name_scope('activations_to_loss'):
+ activations_masked, targets_masked = _mask_activations_and_targets(
+ activations, targets, sequence_length)
+ return self._target_column.loss(activations_masked, targets_masked,
+ features)
+
+ def _activations_to_predictions(self, unused_features, activations):
+ with ops.name_scope('activations_to_predictions'):
+ activations_shape = array_ops.shape(activations)
+ flattened_activations = array_ops.reshape(activations,
+ [-1, activations_shape[2]])
+ predictions = self._target_column.activations_to_predictions(
+ flattened_activations, proba=False)
+ reshaped_predictions = array_ops.reshape(
+ predictions, [activations_shape[0], activations_shape[1], -1])
+ return array_ops.squeeze(reshaped_predictions, [2])
+
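
The flatten-predict-reshape round trip above can be exercised in isolation. In this sketch an argmax stands in for `target_column.activations_to_predictions` (an assumption for illustration only):

import numpy as np
import tensorflow as tf

batch_size, padded_length, num_classes = 2, 3, 4
activations = tf.constant(
    np.random.rand(batch_size, padded_length, num_classes), dtype=tf.float32)

shape = tf.shape(activations)
flattened = tf.reshape(activations, [-1, shape[2]])
# Stand-in for target_column.activations_to_predictions: per-row argmax.
predictions = tf.argmax(flattened, 1)
reshaped = tf.reshape(predictions, [shape[0], shape[1], -1])
squeezed = tf.squeeze(reshaped, [2])  # back to [batch_size, padded_length]

with tf.Session() as sess:
  print(sess.run(squeezed).shape)  # (2, 3)
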
+ def _activations_to_eval_ops(self, features, activations, targets, metrics):
+ with ops.name_scope('activations_to_eval_ops'):
+ activations_masked, targets_masked = _mask_activations_and_targets(
+ activations, targets, features.get(self._sequence_length_key))
+
+ return self._target_column.get_eval_ops(features=features,
+ logits=activations_masked,
+ targets=targets_masked,
+ metrics=metrics)
class _SingleValueRNNEstimator(_DynamicRNNEstimator):
"""An `Estimator` that maps sequences of inputs to single outputs."""
- def _logits_to_loss(self, features, logits, targets):
- sequence_lengths = features.get(self._sequence_length_key)
- last_logits = _select_last_logits(logits, sequence_lengths)
- return self._target_column.loss(last_logits, targets, features)
-
- def _logits_to_predictions(self, features, logits):
- sequence_lengths = features.get(self._sequence_length_key)
- last_logits = _select_last_logits(logits, sequence_lengths)
- return self._target_column.logits_to_predictions(
- last_logits, proba=False)
-
- def _logits_to_eval_ops(self, features, logits, targets, metrics):
- sequence_lengths = features.get(self._sequence_length_key)
- last_logits = _select_last_logits(logits, sequence_lengths)
- return self._target_column.get_eval_ops(features=features,
- logits=last_logits,
- targets=targets,
- metrics=metrics)
+ def _activations_to_loss(self, features, activations, targets):
+ with ops.name_scope('activations_to_loss'):
+ sequence_lengths = features.get(self._sequence_length_key)
+ last_activations = _select_last_activations(activations, sequence_lengths)
+ return self._target_column.loss(last_activations, targets, features)
+
+ def _activations_to_predictions(self, features, activations):
+ with ops.name_scope('activations_to_predictions'):
+ sequence_lengths = features.get(self._sequence_length_key)
+ last_activations = _select_last_activations(activations, sequence_lengths)
+ return self._target_column.activations_to_predictions(
+ last_activations, proba=False)
+
+ def _activations_to_eval_ops(self, features, activations, targets, metrics):
+ with ops.name_scope('activations_to_eval_ops'):
+ sequence_lengths = features.get(self._sequence_length_key)
+ last_activations = _select_last_activations(activations, sequence_lengths)
+ return self._target_column.get_eval_ops(features=features,
+ logits=last_activations,
+ targets=targets,
+ metrics=metrics)
def _get_optimizer(optimizer_type, learning_rate, momentum):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index 37e829cf12..1ee3a8dd60 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -47,15 +47,15 @@ class MockTargetColumn(object):
def __init__(self):
self._num_label_columns = None
- def get_eval_ops(self, features, logits, targets, metrics):
+ def get_eval_ops(self, features, activations, targets, metrics):
raise NotImplementedError(
'MockTargetColumn.get_eval_ops called unexpectedly.')
- def logits_to_predictions(self, flattened_logits, proba=False):
+ def activations_to_predictions(self, flattened_activations, proba=False):
raise NotImplementedError(
- 'MockTargetColumn.logits_to_predictions called unexpectedly.')
+ 'MockTargetColumn.activations_to_predictions called unexpectedly.')
- def loss(self, logits, targets, features):
+ def loss(self, activations, targets, features):
raise NotImplementedError('MockTargetColumn.loss called unexpectedly.')
@property
@@ -120,26 +120,26 @@ class DynamicRnnEstimatorTest(tf.test.TestCase):
'sequence_length': tf.constant(
sequence_length, dtype=tf.int32)}
- # Map feature to logits with mocked linear layer.
+ # Map feature to activations with mocked linear layer.
with tf.test.mock.patch.object(dynamic_rnn_estimator,
'layers') as mock_layers:
mock_layers.fully_connected.return_value = tf.constant(
mock_linear_layer_output, dtype=tf.float32)
- logits_t, final_state_t = self._rnn_estimator._construct_rnn(
+ activations_t, final_state_t = self._rnn_estimator._construct_rnn(
features)
_, fully_connected_kwargs = mock_layers.fully_connected.call_args
linear_layer_inputs_t = fully_connected_kwargs['inputs']
linear_layer_output_dim = fully_connected_kwargs['num_outputs']
- # Obtain values of linear layer input, logits and final state.
+ # Obtain values of linear layer input, activations and final state.
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
- linear_layer_inputs, logits, final_state = sess.run(
- [linear_layer_inputs_t, logits_t, final_state_t])
+ linear_layer_inputs, activations, final_state = sess.run(
+ [linear_layer_inputs_t, activations_t, final_state_t])
np.testing.assert_equal(num_classes, linear_layer_output_dim)
np.testing.assert_almost_equal(inputs, linear_layer_inputs)
- np.testing.assert_almost_equal(mock_linear_layer_output, logits)
+ np.testing.assert_almost_equal(mock_linear_layer_output, activations)
np.testing.assert_almost_equal(
np.zeros([batch_size, self._rnn_cell.state_size], dtype=float),
final_state)
@@ -183,32 +183,33 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase):
'Mismatch on row {}. Got {}; expected {}.'.format(
i, actual_mask, expected_mask))
- def testMaskLogitsAndTargets(self):
- """Test `_mask_logits_and_targets`."""
+ def testMaskActivationsAndTargets(self):
+ """Test `_mask_activations_and_targets`."""
batch_size = 4
padded_length = 6
num_classes = 4
np.random.seed(1234)
sequence_length = np.random.randint(0, padded_length + 1, batch_size)
- logits = np.random.rand(batch_size, padded_length, num_classes)
+ activations = np.random.rand(batch_size, padded_length, num_classes)
targets = np.random.randint(0, num_classes, [batch_size, padded_length])
- (logits_masked_t,
- targets_masked_t) = dynamic_rnn_estimator._mask_logits_and_targets(
+ (activations_masked_t,
+ targets_masked_t) = dynamic_rnn_estimator._mask_activations_and_targets(
tf.constant(
- logits, dtype=tf.float32),
+ activations, dtype=tf.float32),
tf.constant(
targets, dtype=tf.int32),
tf.constant(
sequence_length, dtype=tf.int32))
with tf.Session() as sess:
- logits_masked, targets_masked = sess.run(
- [logits_masked_t, targets_masked_t])
+ activations_masked, targets_masked = sess.run(
+ [activations_masked_t, targets_masked_t])
- expected_logits_shape = [sum(sequence_length), num_classes]
- np.testing.assert_equal(expected_logits_shape, logits_masked.shape,
- 'Wrong logits shape. Expected {}; got {}.'.format(
- expected_logits_shape, logits_masked.shape))
+ expected_activations_shape = [sum(sequence_length), num_classes]
+ np.testing.assert_equal(
+ expected_activations_shape, activations_masked.shape,
+ 'Wrong activations shape. Expected {}; got {}.'.format(
+ expected_activations_shape, activations_masked.shape))
expected_targets_shape = [sum(sequence_length)]
np.testing.assert_equal(expected_targets_shape, targets_masked.shape,
@@ -217,14 +218,14 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase):
masked_index = 0
for i in range(batch_size):
for j in range(sequence_length[i]):
- actual_logits = logits_masked[masked_index]
- expected_logits = logits[i, j, :]
+ actual_activations = activations_masked[masked_index]
+ expected_activations = activations[i, j, :]
np.testing.assert_almost_equal(
- expected_logits,
- actual_logits,
+ expected_activations,
+ actual_activations,
err_msg='Unexpected activation value at index [{}, {}, :].'
- ' Expected {}; got {}.'.format(i, j, expected_logits,
- actual_logits))
+ ' Expected {}; got {}.'.format(i, j, expected_activations,
+ actual_activations))
actual_targets = targets_masked[masked_index]
expected_targets = targets[i, j]
@@ -236,31 +237,34 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase):
actual_targets))
masked_index += 1
- def testLogitsToPredictions(self):
- """Test `DynamicRNNEstimator._logits_to_predictions`."""
+ def testActivationsToPredictions(self):
+ """Test `DynamicRNNEstimator._activations_to_predictions`."""
batch_size = 8
sequence_length = 16
num_classes = 3
np.random.seed(10101)
- logits = np.random.rand(batch_size, sequence_length, num_classes)
- flattened_logits = np.reshape(logits, [-1, num_classes])
- flattened_argmax = np.argmax(flattened_logits, axis=1)
- expected_predictions = np.argmax(logits, axis=2)
-
- with tf.test.mock.patch.object(self._mock_target_column,
- 'logits_to_predictions',
- return_value=flattened_argmax,
- autospec=True) as mock_logits_to_predictions:
- predictions_t = self._seq_estimator._logits_to_predictions(
- None, tf.constant(logits, dtype=tf.float32))
- (target_column_input_logits_t,), _ = mock_logits_to_predictions.call_args
+ activations = np.random.rand(batch_size, sequence_length, num_classes)
+ flattened_activations = np.reshape(activations, [-1, num_classes])
+ flattened_argmax = np.argmax(flattened_activations, axis=1)
+ expected_predictions = np.argmax(activations, axis=2)
+
+ with tf.test.mock.patch.object(
+ self._mock_target_column,
+ 'activations_to_predictions',
+ return_value=flattened_argmax,
+ autospec=True) as mock_activations_to_predictions:
+ predictions_t = self._seq_estimator._activations_to_predictions(
+ None, tf.constant(activations, dtype=tf.float32))
+ (target_column_input_activations_t,
+ ), _ = mock_activations_to_predictions.call_args
with tf.Session() as sess:
- target_column_input_logits, predictions = sess.run(
- [target_column_input_logits_t, predictions_t])
+ target_column_input_activations, predictions = sess.run(
+ [target_column_input_activations_t, predictions_t])
- np.testing.assert_almost_equal(flattened_logits, target_column_input_logits)
+ np.testing.assert_almost_equal(flattened_activations,
+ target_column_input_activations)
np.testing.assert_equal(expected_predictions, predictions)
def testLearnSineFunction(self):
@@ -353,35 +357,36 @@ class MultiValueRNNEstimatorTest(tf.test.TestCase):
class SingleValueRNNEstimatorTest(tf.test.TestCase):
- def testSelectLastLogits(self):
- """Test `_select_last_logits`."""
+ def testSelectLastActivations(self):
+ """Test `_select_last_activations`."""
batch_size = 4
padded_length = 6
num_classes = 4
np.random.seed(4444)
sequence_length = np.random.randint(0, padded_length + 1, batch_size)
- logits = np.random.rand(batch_size, padded_length, num_classes)
- last_logits_t = dynamic_rnn_estimator._select_last_logits(
- tf.constant(logits, dtype=tf.float32),
+ activations = np.random.rand(batch_size, padded_length, num_classes)
+ last_activations_t = dynamic_rnn_estimator._select_last_activations(
+ tf.constant(activations, dtype=tf.float32),
tf.constant(sequence_length, dtype=tf.int32))
with tf.Session() as sess:
- last_logits = sess.run(last_logits_t)
+ last_activations = sess.run(last_activations_t)
- expected_logits_shape = [batch_size, num_classes]
- np.testing.assert_equal(expected_logits_shape, last_logits.shape,
- 'Wrong logits shape. Expected {}; got {}.'.format(
- expected_logits_shape, last_logits.shape))
+ expected_activations_shape = [batch_size, num_classes]
+ np.testing.assert_equal(
+ expected_activations_shape, last_activations.shape,
+ 'Wrong activations shape. Expected {}; got {}.'.format(
+ expected_activations_shape, last_activations.shape))
for i in range(batch_size):
- actual_logits = last_logits[i, :]
- expected_logits = logits[i, sequence_length[i] - 1, :]
+ actual_activations = last_activations[i, :]
+ expected_activations = activations[i, sequence_length[i] - 1, :]
np.testing.assert_almost_equal(
- expected_logits,
- actual_logits,
+ expected_activations,
+ actual_activations,
err_msg='Unexpected activation value at index [{}, :].'
- ' Expected {}; got {}.'.format(i, expected_logits,
- actual_logits))
+ ' Expected {}; got {}.'.format(i, expected_activations,
+ actual_activations))
def testLearnMean(self):
"""Test that `_SequenceRegressor` can learn to calculate a mean."""