aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-06 21:00:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 21:05:58 -0700
commitdebd66dae1c9a49d36ea006c97facf06b4ac25cb (patch)
treedcf98ad097d74a0404c4280fe20a7d65c2fb00f4 /tensorflow/contrib/estimator
parent38214447164a80f0c4a2ac42817ca383bd615fc0 (diff)
Added functionality of passing loss reduction as argument for RNNClassifier with default changed to SUM_OVER_BATCH_SIZE
This would involve making changes to all existing uses of RNNClassifier to set the loss reduction argument explicitly to SUM (previous default was SUM) PiperOrigin-RevId: 211917502
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--tensorflow/contrib/estimator/BUILD1
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn.py14
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn_test.py41
3 files changed, 39 insertions, 17 deletions
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 77f62df99d..437b3d965d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -446,6 +446,7 @@ py_library(
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:optimizers",
+ "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
index 7c49cd00d1..98660bb731 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import training_util
@@ -405,6 +406,7 @@ class RNNClassifier(estimator.Estimator):
weight_column=None,
label_vocabulary=None,
optimizer='Adagrad',
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
input_layer_partitioner=None,
config=None):
"""Initializes a `RNNClassifier` instance.
@@ -454,6 +456,8 @@ class RNNClassifier(estimator.Estimator):
string.
optimizer: An instance of `tf.Optimizer` or string specifying optimizer
type. Defaults to Adagrad optimizer.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.
input_layer_partitioner: Optional. Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
@@ -467,11 +471,15 @@ class RNNClassifier(estimator.Estimator):
if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
- n_classes, weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ n_classes,
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
+
def _model_fn(features, labels, mode, config):
return _rnn_model_fn(
features=features,
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
index 959b40371a..1aebed348d 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
@@ -713,7 +713,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testBinaryClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=1.119661)
+ mock_optimizer = self._mock_optimizer(expected_loss=0.559831)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -748,7 +748,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testMultiClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=2.662932)
+ mock_optimizer = self._mock_optimizer(expected_loss=1.331465)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -812,20 +812,32 @@ class RNNClassifierEvaluationTest(test.TestCase):
# probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]]
# loss = -label * ln(p) - (1 - label) * ln(1 - p)
# = [[0.436326], [0.683335]]
+ # sum_over_batch_size = (0.436326 + 0.683335)/2
expected_metrics = {
- ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 1.119661,
- metric_keys.MetricKeys.LOSS_MEAN: 0.559831,
- metric_keys.MetricKeys.ACCURACY: 1.0,
- metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262,
- metric_keys.MetricKeys.LABEL_MEAN: 0.5,
- metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+ ops.GraphKeys.GLOBAL_STEP:
+ global_step,
+ metric_keys.MetricKeys.LOSS:
+ 0.559831,
+ metric_keys.MetricKeys.LOSS_MEAN:
+ 0.559831,
+ metric_keys.MetricKeys.ACCURACY:
+ 1.0,
+ metric_keys.MetricKeys.PREDICTION_MEAN:
+ 0.429262,
+ metric_keys.MetricKeys.LABEL_MEAN:
+ 0.5,
+ metric_keys.MetricKeys.ACCURACY_BASELINE:
+ 0.5,
# With default threshold of 0.5, the model is a perfect classifier.
- metric_keys.MetricKeys.RECALL: 1.0,
- metric_keys.MetricKeys.PRECISION: 1.0,
+ metric_keys.MetricKeys.RECALL:
+ 1.0,
+ metric_keys.MetricKeys.PRECISION:
+ 1.0,
# Positive example is scored above negative, so AUC = 1.0.
- metric_keys.MetricKeys.AUC: 1.0,
- metric_keys.MetricKeys.AUC_PR: 1.0,
+ metric_keys.MetricKeys.AUC:
+ 1.0,
+ metric_keys.MetricKeys.AUC_PR:
+ 1.0,
}
self.assertAllClose(
sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))
@@ -871,9 +883,10 @@ class RNNClassifierEvaluationTest(test.TestCase):
# [0.059494, 0.572639, 0.367866]]
# loss = -1. * log(softmax[label])
# = [[2.105432], [0.557500]]
+ # sum_over_batch_size = (2.105432 + 0.557500)/2
expected_metrics = {
ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 2.662932,
+ metric_keys.MetricKeys.LOSS: 1.331465,
metric_keys.MetricKeys.LOSS_MEAN: 1.331466,
metric_keys.MetricKeys.ACCURACY: 0.5,
}