diff options
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/rnn.py')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/rnn.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index 7c49cd00d1..98660bb731 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import training_util @@ -405,6 +406,7 @@ class RNNClassifier(estimator.Estimator): weight_column=None, label_vocabulary=None, optimizer='Adagrad', + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, input_layer_partitioner=None, config=None): """Initializes a `RNNClassifier` instance. @@ -454,6 +456,8 @@ class RNNClassifier(estimator.Estimator): string. optimizer: An instance of `tf.Optimizer` or string specifying optimizer type. Defaults to Adagrad optimizer. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`. input_layer_partitioner: Optional. Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. @@ -467,11 +471,15 @@ class RNNClassifier(estimator.Estimator): if n_classes == 2: head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) else: head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access - n_classes, weight_column=weight_column, - label_vocabulary=label_vocabulary) + n_classes, + weight_column=weight_column, + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) + def _model_fn(features, labels, mode, config): return _rnn_model_fn( features=features, |