author    Yan Facai (颜发才) <facai.yan@gmail.com>  2018-09-13 13:22:37 +0800
committer Yan Facai (颜发才) <facai.yan@gmail.com>  2018-09-13 13:22:37 +0800
commit    04ddc2daf4c76bb4c99fdc6b582025e9a4ffba52 (patch)
tree      4dd8424588dc21f2e4d23a591325bde7d3b63a66 /tensorflow/contrib/estimator
parent    fd41d2c959372d7a068cb4474391362ef6a92fca (diff)
parent    845aaec5ec2191f2708247a09d9bad37f012f536 (diff)
Merge branch 'master' into ENH/feature_importances_for_boosted_tree
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--  tensorflow/contrib/estimator/BUILD                                               |  57
-rw-r--r--  tensorflow/contrib/estimator/__init__.py                                         |   3
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py      | 434
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py | 611
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/rnn.py                             |  14
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/rnn_test.py                        |  41
6 files changed, 1143 insertions(+), 17 deletions(-)
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 77f62df99d..6db311d52d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -18,6 +18,7 @@ py_library(
":boosted_trees",
":dnn",
":dnn_linear_combined",
+ ":dnn_with_layer_annotations",
":early_stopping",
":export",
":exporter",
@@ -127,6 +128,61 @@ py_test(
)
py_library(
+ name = "dnn_with_layer_annotations",
+ srcs = ["python/estimator/dnn_with_layer_annotations.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:partitioned_variables",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:head",
+ "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/estimator:optimizers",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python/ops/losses",
+ "//tensorflow/python/saved_model:utils",
+ ],
+)
+
+py_test(
+ name = "dnn_with_layer_annotations_test",
+ size = "medium",
+ srcs = ["python/estimator/dnn_with_layer_annotations_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "notsan", # b/67510291
+ ],
+ deps = [
+ ":dnn_with_layer_annotations",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python/estimator:dnn",
+ "//tensorflow/python/estimator:dnn_testing_utils",
+ "//tensorflow/python/estimator:export_export",
+ "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/estimator:pandas_io",
+ "//tensorflow/python/estimator:prediction_keys",
+ "//tensorflow/python/feature_column",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "dnn_linear_combined",
srcs = ["python/estimator/dnn_linear_combined.py"],
srcs_version = "PY2AND3",
@@ -446,6 +502,7 @@ py_library(
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:optimizers",
+ "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 258860f263..78914ecaca 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.contrib.estimator.python.estimator.baseline import *
from tensorflow.contrib.estimator.python.estimator.boosted_trees import *
from tensorflow.contrib.estimator.python.estimator.dnn import *
+from tensorflow.contrib.estimator.python.estimator.dnn_with_layer_annotations import *
from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import *
from tensorflow.contrib.estimator.python.estimator.early_stopping import *
from tensorflow.contrib.estimator.python.estimator.export import *
@@ -76,6 +77,8 @@ _allowed_symbols = [
'build_raw_supervised_input_receiver_fn',
'build_supervised_input_receiver_fn_from_input_fn',
'SavedModelEstimator',
+ 'DNNClassifierWithLayerAnnotations',
+ 'DNNRegressorWithLayerAnnotations',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
new file mode 100644
index 0000000000..152431d1b2
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -0,0 +1,434 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Deep Neural Network estimators with layer annotations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import pickle
+
+from google.protobuf.any_pb2 import Any
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator.canned import dnn
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.saved_model import utils as saved_model_utils
+
+
+class LayerAnnotationsCollectionNames(object):
+ """Names for the collections containing the annotations."""
+
+ UNPROCESSED_FEATURES = 'layer_annotations/unprocessed_features'
+ PROCESSED_FEATURES = 'layer_annotations/processed_features'
+ FEATURE_COLUMNS = 'layer_annotations/feature_columns'
+
+ @classmethod
+ def keys(cls, collection_name):
+ return '%s/keys' % collection_name
+
+ @classmethod
+ def values(cls, collection_name):
+ return '%s/values' % collection_name
+
+
+def serialize_feature_column(feature_column):
+ if isinstance(feature_column, feature_column_lib._EmbeddingColumn): # pylint: disable=protected-access
+ # We can't pickle nested functions, and we don't need the value of
+ # layer_creator in most cases anyway, so just discard its value.
+ args = feature_column._asdict()
+ args['layer_creator'] = None
+ temp = type(feature_column)(**args)
+ return pickle.dumps(temp)
+ return pickle.dumps(feature_column)
+
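A hedged round-trip sketch (not part of the commit; `numeric_column` here is just an illustrative stand-in): because only the nested `layer_creator` is discarded, reading a serialized column back is a plain `pickle.loads`.

    # Hypothetical round trip, assuming contrib-era feature_column internals.
    column = feature_column_lib.numeric_column('price')
    blob = serialize_feature_column(column)  # -> bytes
    restored = pickle.loads(blob)            # equal-valued _NumericColumn
    assert restored.name == column.name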
+
+def _to_any_wrapped_tensor_info(tensor):
+ """Converts a `Tensor` to a `TensorInfo` wrapped in a proto `Any`."""
+ any_buf = Any()
+ tensor_info = saved_model_utils.build_tensor_info(tensor)
+ any_buf.Pack(tensor_info)
+ return any_buf
+
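Going the other way is a matter of the standard protobuf `Any.Unpack`; a minimal sketch, assuming `TensorInfo` from `meta_graph_pb2` (this helper is not part of the commit):

    from tensorflow.core.protobuf import meta_graph_pb2

    def _from_any_wrapped_tensor_info(any_buf):
      """Hypothetical inverse: unpacks a proto `Any` back into a `TensorInfo`."""
      tensor_info = meta_graph_pb2.TensorInfo()
      if not any_buf.Unpack(tensor_info):
        raise ValueError('Any does not wrap a TensorInfo')
      return tensor_info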
+
+def make_input_layer_with_layer_annotations(original_input_layer, mode):
+ """Make an input_layer replacement function that adds layer annotations."""
+
+ def input_layer_with_layer_annotations(features,
+ feature_columns,
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None,
+ cols_to_output_tensors=None):
+ """Returns a dense `Tensor` as input layer based on given `feature_columns`.
+
+ Generally a single example in training data is described with
+ FeatureColumns. At the first layer of the model, this column-oriented
+ data should be converted to a single `Tensor`.
+
+ This is like tf.feature_column.input_layer, except with added
+ Integrated-Gradient annotations.
+
+ Args:
+ features: A mapping from key to tensors. `_FeatureColumn`s look up via
+ these keys. For example `numeric_column('price')` will look at 'price'
+ key in this dict. Values can be a `SparseTensor` or a `Tensor` depending
+ on the corresponding `_FeatureColumn`.
+ feature_columns: An iterable containing the FeatureColumns to use as
+ inputs to your model. All items should be instances of classes derived
+ from `_DenseColumn` such as `numeric_column`, `embedding_column`,
+ `bucketized_column`, `indicator_column`. If you have categorical
+ features, you can wrap them with an `embedding_column` or
+ `indicator_column`.
+ weight_collections: A list of collection names to which the Variable will
+ be added. Note that variables will also be added to collections
+ `tf.GraphKeys.GLOBAL_VARIABLES` and `tf.GraphKeys.MODEL_VARIABLES`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ cols_to_vars: If not `None`, must be a dictionary that will be filled with
+ a mapping from `_FeatureColumn` to list of `Variable`s. For example, after
+ the call we might have
+ cols_to_vars = {
+ _EmbeddingColumn(
+ categorical_column=_HashedCategoricalColumn(
+ key='sparse_feature', hash_bucket_size=5, dtype=tf.string),
+ dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10)>,
+ <tf.Variable 'some_variable:1' shape=(5, 10)>]}.
+ If a column creates no variables, its value will be an empty list.
+ cols_to_output_tensors: If not `None`, must be a dictionary that will be
+ filled with a mapping from '_FeatureColumn' to the associated output
+ `Tensor`s.
+
+ Returns:
+ A `Tensor` which represents input layer of a model. Its shape
+ is (batch_size, first_layer_dimension) and its dtype is `float32`.
+ first_layer_dimension is determined based on given `feature_columns`.
+
+ Raises:
+ ValueError: features and feature_columns have different lengths.
+ """
+
+ local_cols_to_output_tensors = {}
+ input_layer = original_input_layer(
+ features=features,
+ feature_columns=feature_columns,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ cols_to_vars=cols_to_vars,
+ cols_to_output_tensors=local_cols_to_output_tensors)
+
+ if cols_to_output_tensors is not None:
+   # Fill the caller's dictionary in place; rebinding the parameter would
+   # be a no-op outside this function.
+   cols_to_output_tensors.update(local_cols_to_output_tensors)
+
+ if mode and mode == model_fn.ModeKeys.PREDICT:
+ # Only annotate in PREDICT mode.
+
+ # Annotate features.
+ # These are the parsed Tensors, before embedding.
+
+ # Only annotate features used by FeatureColumns.
+ # We figure which ones are used by FeatureColumns by creating a parsing
+ # spec and looking at the keys.
+ spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ for key in spec.keys():
+ tensor = features[key]
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
+
+ # Annotate feature columns.
+ for column in feature_columns:
+ # TODO(cyfoo): Find a better way to serialize and deserialize
+ # _FeatureColumn.
+ ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS,
+ serialize_feature_column(column))
+
+ for column, tensor in local_cols_to_output_tensors.items():
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
+ column.name)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
+
+ return input_layer
+
+ return input_layer_with_layer_annotations
+
+
+@contextlib.contextmanager
+def _monkey_patch(module, function, replacement):
+ old_function = getattr(module, function)
+ setattr(module, function, replacement)
+ try:
+   yield
+ finally:
+   # Restore the original attribute even if the wrapped block raises.
+   setattr(module, function, old_function)
+
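A toy sketch of how the context manager behaves (`math.sqrt` standing in for `feature_column_lib.input_layer`; not part of the commit):

    import math

    with _monkey_patch(math, 'sqrt', lambda x: 42):
      assert math.sqrt(9) == 42  # replacement is active inside the block
    assert math.sqrt(9) == 3.0   # original is restored on exit, even on error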
+
+def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name
+ hidden_units,
+ feature_columns,
+ model_dir=None,
+ n_classes=2,
+ weight_column=None,
+ label_vocabulary=None,
+ optimizer='Adagrad',
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=None,
+ config=None,
+ warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM):
+ """A classifier for TensorFlow DNN models with layer annotations.
+
+ This classifier is functionally identical to estimator.DNNClassifier as far as
+ training and evaluating models is concerned. The key difference is that this
+ classifier adds additional layer annotations, which can be used for computing
+ Integrated Gradients.
+
+ Integrated Gradients is a method for attributing a classifier's predictions
+ to its input features (https://arxiv.org/pdf/1703.01365.pdf). Given an input
+ instance, the method assigns attribution scores to individual features in
+ proportion to the feature's importance to the classifier's prediction.
+
+ See estimator.DNNClassifier for example code for training and evaluating models
+ using this classifier.
+
+ This classifier is checkpoint-compatible with estimator.DNNClassifier and
+ therefore the following should work seamlessly:
+
+ # Instantiate ordinary estimator as usual.
+ estimator = tf.estimator.DNNClassifier(
+ config, feature_columns, hidden_units, ...)
+
+ # Train estimator, export checkpoint.
+ tf.estimator.train_and_evaluate(estimator, ...)
+
+ # Instantiate estimator with annotations with the same configuration as the
+ # ordinary estimator.
+ estimator_with_annotations = (
+ tf.contrib.estimator.DNNClassifierWithLayerAnnotations(
+ config, feature_columns, hidden_units, ...))
+
+ # Call export_savedmodel with the same arguments as the ordinary estimator,
+ # using the checkpoint produced for the ordinary estimator.
+ estimator_with_annotations.export_savedmodel(
+ export_dir_base, serving_input_receiver_fn, ...,
+ checkpoint_path='/path/to/ordinary/estimator/checkpoint/model.ckpt-1234')
+
+ Args:
+ hidden_units: Iterable of number of hidden units per layer. All layers are
+ fully connected. Ex. `[64, 32]` means first layer has 64 nodes and second
+ one has 32.
+ feature_columns: An iterable containing all the feature columns used by the
+ model. All items in the set should be instances of classes derived from
+ `_FeatureColumn`.
+ model_dir: Directory to save model parameters, graph, etc. This can also
+ be used to load checkpoints from the directory into an estimator to
+ continue training a previously saved model.
+ n_classes: Number of label classes. Defaults to 2, namely binary
+ classification. Must be > 1.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to down weight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then
+ weight_column.normalizer_fn is applied on it to get weight tensor.
+ label_vocabulary: A list of strings representing possible label values. If
+ given, labels must be of string type and take values in
+ `label_vocabulary`. If it is not given, labels must already be encoded as
+ an integer or float within [0, 1] for `n_classes=2`, or as integer values
+ in {0, 1, ..., n_classes-1} for `n_classes` > 2. An error is raised if a
+ vocabulary is not provided and the labels are strings.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
+ to Adagrad optimizer.
+ activation_fn: Activation function applied to each layer. If `None`, will
+ use `tf.nn.relu`.
+ dropout: When not `None`, the probability we will drop out a given
+ coordinate.
+ input_layer_partitioner: Optional. Partitioner for input layer. Defaults to
+ `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+ config: `RunConfig` object to configure the runtime settings.
+ warm_start_from: A string filepath to a checkpoint to warm-start from, or a
+ `WarmStartSettings` object to fully configure warm-starting. If the
+ string filepath is provided instead of a `WarmStartSettings`, then all
+ weights are warm-started, and it is assumed that vocabularies and Tensor
+ names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over the batch. Defaults to `SUM`.
+
+ Returns:
+ An `Estimator` that behaves like a `DNNClassifier`, with added layer
+ annotations.
+ """
+
+ original = dnn.DNNClassifier(
+ hidden_units=hidden_units,
+ feature_columns=feature_columns,
+ model_dir=model_dir,
+ n_classes=n_classes,
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ optimizer=optimizer,
+ activation_fn=activation_fn,
+ dropout=dropout,
+ input_layer_partitioner=input_layer_partitioner,
+ config=config,
+ warm_start_from=warm_start_from,
+ loss_reduction=loss_reduction)
+
+ def _model_fn(features, labels, mode, config):
+ with _monkey_patch(
+ feature_column_lib, 'input_layer',
+ make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
+ mode)):
+ return original.model_fn(features, labels, mode, config)
+
+ return estimator.Estimator(
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ warm_start_from=warm_start_from)
+
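A consumer of the annotations might read them back from a PREDICT-mode graph along these lines — a sketch under the collection layout above, not an API this commit ships:

    graph = ops.get_default_graph()
    names = LayerAnnotationsCollectionNames
    keys = graph.get_collection(names.keys(names.PROCESSED_FEATURES))
    values = graph.get_collection(names.values(names.PROCESSED_FEATURES))
    # Maps feature-column name -> Any-wrapped TensorInfo of its output tensor.
    processed = dict(zip(keys, values))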
+
+def DNNRegressorWithLayerAnnotations( # pylint: disable=invalid-name
+ hidden_units,
+ feature_columns,
+ model_dir=None,
+ label_dimension=1,
+ weight_column=None,
+ optimizer='Adagrad',
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=None,
+ config=None,
+ warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM,
+):
+ """A regressor for TensorFlow DNN models with layer annotations.
+
+ This regressor is functionally identical to estimator.DNNRegressor as far as
+ training and evaluating models is concerned. The key difference is that this
+ regressor adds additional layer annotations, which can be used for computing
+ Integrated Gradients.
+
+ Integrated Gradients is a method for attributing a model's predictions
+ to its input features (https://arxiv.org/pdf/1703.01365.pdf). Given an input
+ instance, the method assigns attribution scores to individual features in
+ proportion to the feature's importance to the model's prediction.
+
+ See estimator.DNNRegressor for example code for training and evaluating models
+ using this regressor.
+
+ This regressor is checkpoint-compatible with estimator.DNNRegressor and
+ therefore the following should work seamlessly:
+
+ # Instantiate ordinary estimator as usual.
+ estimator = tf.estimator.DNNRegressor(
+ config, feature_columns, hidden_units, ...)
+
+ # Train estimator, export checkpoint.
+ tf.estimator.train_and_evaluate(estimator, ...)
+
+ # Instantiate estimator with annotations with the same configuration as the
+ # ordinary estimator.
+ estimator_with_annotations = (
+ tf.contrib.estimator.DNNRegressorWithLayerAnnotations(
+ config, feature_columns, hidden_units, ...))
+
+ # Call export_savedmodel with the same arguments as the ordinary estimator,
+ # using the checkpoint produced for the ordinary estimator.
+ estimator_with_annotations.export_savedmodel(
+ export_dir_base, serving_input_receiver_fn, ...,
+ checkpoint_path='/path/to/ordinary/estimator/checkpoint/model.ckpt-1234')
+
+ Args:
+ hidden_units: Iterable of number of hidden units per layer. All layers are
+ fully connected. Ex. `[64, 32]` means first layer has 64 nodes and second
+ one has 32.
+ feature_columns: An iterable containing all the feature columns used by the
+ model. All items in the set should be instances of classes derived from
+ `_FeatureColumn`.
+ model_dir: Directory to save model parameters, graph, etc. This can also
+ be used to load checkpoints from the directory into an estimator to
+ continue training a previously saved model.
+ label_dimension: Number of regression targets per example. This is the size
+ of the last dimension of the labels and logits `Tensor` objects
+ (typically, these have shape `[batch_size, label_dimension]`).
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to down weight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then
+ weight_column.normalizer_fn is applied on it to get weight tensor.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
+ to Adagrad optimizer.
+ activation_fn: Activation function applied to each layer. If `None`, will
+ use `tf.nn.relu`.
+ dropout: When not `None`, the probability we will drop out a given
+ coordinate.
+ input_layer_partitioner: Optional. Partitioner for input layer. Defaults to
+ `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+ config: `RunConfig` object to configure the runtime settings.
+ warm_start_from: A string filepath to a checkpoint to warm-start from, or a
+ `WarmStartSettings` object to fully configure warm-starting. If the
+ string filepath is provided instead of a `WarmStartSettings`, then all
+ weights are warm-started, and it is assumed that vocabularies and Tensor
+ names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over the batch. Defaults to `SUM`.
+
+ Returns:
+ An `Estimator` that behaves like a `DNNRegressor`, with added layer
+ annotations.
+ """
+
+ original = dnn.DNNRegressor(
+ hidden_units=hidden_units,
+ feature_columns=feature_columns,
+ model_dir=model_dir,
+ label_dimension=label_dimension,
+ weight_column=weight_column,
+ optimizer=optimizer,
+ activation_fn=activation_fn,
+ dropout=dropout,
+ input_layer_partitioner=input_layer_partitioner,
+ config=config,
+ warm_start_from=warm_start_from,
+ loss_reduction=loss_reduction,
+ )
+
+ def _model_fn(features, labels, mode, config):
+ with _monkey_patch(
+ feature_column_lib, 'input_layer',
+ make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
+ mode)):
+ return original.model_fn(features, labels, mode, config)
+
+ return estimator.Estimator(
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ warm_start_from=warm_start_from)
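Since the collections are saved with the graph, the annotations also survive export. A hedged recovery sketch, assuming TF 1.x `loader`/`tag_constants` APIs, a hypothetical `export_dir`, and that the pickled columns come back as bytes:

    from tensorflow.python.client import session
    from tensorflow.python.saved_model import loader
    from tensorflow.python.saved_model import tag_constants

    graph = ops.Graph()
    with graph.as_default(), session.Session() as sess:
      loader.load(sess, [tag_constants.SERVING], export_dir)
      blobs = graph.get_collection(
          LayerAnnotationsCollectionNames.FEATURE_COLUMNS)
      columns = [pickle.loads(b) for b in blobs]  # deserialized feature columns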
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py
new file mode 100644
index 0000000000..2fe3d4c72e
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py
@@ -0,0 +1,611 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dnn_with_layer_annotations.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+
+import numpy as np
+import six
+
+from tensorflow.contrib.estimator.python.estimator import dnn_with_layer_annotations
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.canned import dnn
+from tensorflow.python.estimator.canned import dnn_testing_utils
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.estimator.inputs import pandas_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import input as input_lib
+from tensorflow.python.training import queue_runner
+
+try:
+ # pylint: disable=g-import-not-at-top
+ import pandas as pd
+ HAS_PANDAS = True
+except (IOError, ImportError):
+  # Pandas writes a temporary file during import. If that fails, or pandas is
+  # not installed, don't use pandas.
+  HAS_PANDAS = False
+
+
+def _dnn_classifier_fn(*args, **kwargs):
+ return dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations(
+ *args, **kwargs)
+
+
+class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNWarmStartingTest.__init__(self, _dnn_classifier_fn,
+ _dnn_regressor_fn)
+
+
+class DNNWithLayerAnnotationsClassifierEvaluateTest(
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
+ self, _dnn_classifier_fn)
+
+
+class DNNClassifierWithLayerAnnotationsPredictTest(
+ dnn_testing_utils.BaseDNNClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
+ self, _dnn_classifier_fn)
+
+
+class DNNClassifierWithLayerAnnotationsTrainTest(
+ dnn_testing_utils.BaseDNNClassifierTrainTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
+ self, _dnn_classifier_fn)
+
+
+def _dnn_regressor_fn(*args, **kwargs):
+ return dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations(
+ *args, **kwargs)
+
+
+class DNNWithLayerAnnotationsTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def _getLayerAnnotationCollection(self, graph, collection_name):
+ keys = graph.get_collection(
+ dnn_with_layer_annotations.LayerAnnotationsCollectionNames.keys(
+ collection_name))
+ values = graph.get_collection(
+ dnn_with_layer_annotations.LayerAnnotationsCollectionNames.values(
+ collection_name))
+ if len(keys) != len(values):
+ raise ValueError('keys and values should have the same length. '
+ 'Lengths were: %d and %d; elements were %s and %s' %
+ (len(keys), len(values), keys, values))
+ return dict(zip(keys, values))
+
+ def _testAnnotationsPresentForEstimator(self, estimator_class):
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(1,)),
+ feature_column.embedding_column(
+ feature_column.categorical_column_with_vocabulary_list(
+ 'y', vocabulary_list=['a', 'b', 'c']),
+ dimension=3)
+ ]
+ estimator = estimator_class(
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ model_dir=self._model_dir)
+ model_fn = estimator.model_fn
+
+ graph = ops.Graph()
+ with graph.as_default():
+ model_fn({
+ 'x': array_ops.constant([1.0]),
+ 'y': array_ops.constant(['a'])
+ }, {},
+ model_fn_lib.ModeKeys.PREDICT,
+ config=None)
+
+ unprocessed_features = self._getLayerAnnotationCollection(
+ graph, dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+ .UNPROCESSED_FEATURES)
+ processed_features = self._getLayerAnnotationCollection(
+ graph, dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+ .PROCESSED_FEATURES)
+ feature_columns = graph.get_collection(
+ dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+ .FEATURE_COLUMNS)
+
+ self.assertItemsEqual(unprocessed_features.keys(), ['x', 'y'])
+ self.assertEqual(2, len(processed_features.keys()))
+ self.assertEqual(2, len(feature_columns))
+
+ def testAnnotationsPresentForClassifier(self):
+ self._testAnnotationsPresentForEstimator(
+ dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations)
+
+ def testAnnotationsPresentForRegressor(self):
+ self._testAnnotationsPresentForEstimator(
+ dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations)
+
+ def _testCheckpointCompatibleWithNonAnnotatedEstimator(
+ self, train_input_fn, predict_input_fn, non_annotated_class,
+ annotated_class, prediction_key, estimator_args):
+ input_dimension = 2
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ estimator = non_annotated_class(
+ model_dir=self._model_dir,
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ **estimator_args)
+
+ estimator.train(train_input_fn, steps=10)
+
+ predictions = np.array(
+ [x[prediction_key] for x in estimator.predict(predict_input_fn)])
+
+ annotated_estimator = annotated_class(
+ model_dir=self._model_dir,
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ warm_start_from=self._model_dir,
+ **estimator_args)
+
+ annotated_predictions = np.array([
+ x[prediction_key] for x in annotated_estimator.predict(predict_input_fn)
+ ])
+
+ self.assertAllEqual(predictions.shape, annotated_predictions.shape)
+ for i, (a, b) in enumerate(
+ zip(predictions.flatten(), annotated_predictions.flatten())):
+ self.assertAlmostEqual(a, b, msg='index=%d' % i)
+
+ def testCheckpointCompatibleForClassifier(self):
+ n_classes = 2
+ input_dimension = 2
+ batch_size = 10
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+ x_data = data.reshape(batch_size, input_dimension)
+ y_data = np.reshape(
+ np.rint(data[:batch_size]).astype(np.int64), (batch_size, 1))
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data},
+ y=y_data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data}, batch_size=batch_size, shuffle=False)
+
+ self._testCheckpointCompatibleWithNonAnnotatedEstimator(
+ train_input_fn,
+ predict_input_fn,
+ dnn.DNNClassifier,
+ dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations,
+ prediction_key=prediction_keys.PredictionKeys.PROBABILITIES,
+ estimator_args={'n_classes': n_classes})
+
+ def testCheckpointCompatibleForRegressor(self):
+ label_dimension = 2
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data}, batch_size=batch_size, shuffle=False)
+
+ self._testCheckpointCompatibleWithNonAnnotatedEstimator(
+ train_input_fn,
+ predict_input_fn,
+ dnn.DNNRegressor,
+ dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations,
+ prediction_key=prediction_keys.PredictionKeys.PREDICTIONS,
+ estimator_args={'label_dimension': label_dimension})
+
+
+class DNNRegressorWithLayerAnnotationsEvaluateTest(
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+ self, _dnn_regressor_fn)
+
+
+class DNNRegressorWithLayerAnnotationsPredictTest(
+ dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+ self, _dnn_regressor_fn)
+
+
+class DNNRegressorWithLayerAnnotationsTrainTest(
+ dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+ self, _dnn_regressor_fn)
+
+
+def _queue_parsed_features(feature_map):
+ tensors_to_enqueue = []
+ keys = []
+ for key, tensor in six.iteritems(feature_map):
+ keys.append(key)
+ tensors_to_enqueue.append(tensor)
+ queue_dtypes = [x.dtype for x in tensors_to_enqueue]
+ input_queue = data_flow_ops.FIFOQueue(capacity=100, dtypes=queue_dtypes)
+ queue_runner.add_queue_runner(
+ queue_runner.QueueRunner(input_queue,
+ [input_queue.enqueue(tensors_to_enqueue)]))
+ dequeued_tensors = input_queue.dequeue()
+ return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
+
+
+class DNNRegressorWithLayerAnnotationsIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, batch_size):
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ est = dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations(
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ label_dimension=label_dimension,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ num_steps = 10
+ est.train(train_input_fn, steps=num_steps)
+
+ # EVALUATE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+
+ # PREDICT
+ predictions = np.array([
+ x[prediction_keys.PredictionKeys.PREDICTIONS]
+ for x in est.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((batch_size, label_dimension), predictions.shape)
+
+ # EXPORT
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ label_dimension = 2
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data}, y=data, batch_size=batch_size, shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data}, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=label_dimension,
+ label_dimension=label_dimension,
+ batch_size=batch_size)
+
+ def test_pandas_input_fn(self):
+ """Tests complete flow with pandas_input_fn."""
+ if not HAS_PANDAS:
+ return
+ label_dimension = 1
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size, dtype=np.float32)
+ x = pd.DataFrame({'x': data})
+ y = pd.Series(data)
+ train_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+ eval_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, shuffle=False)
+ predict_input_fn = pandas_io.pandas_input_fn(
+ x=x, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=label_dimension,
+ label_dimension=label_dimension,
+ batch_size=batch_size)
+
+ def test_input_fn_from_parse_example(self):
+ """Tests complete flow with input_fn constructed from parse_example."""
+ label_dimension = 2
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+
+ serialized_examples = []
+ for datum in data:
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'x':
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=datum)),
+ 'y':
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=datum)),
+ }))
+ serialized_examples.append(example.SerializeToString())
+
+ feature_spec = {
+ 'x': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
+ 'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
+ }
+
+ def _train_input_fn():
+ feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _eval_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _predict_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ features.pop('y')
+ return features, None
+
+ self._test_complete_flow(
+ train_input_fn=_train_input_fn,
+ eval_input_fn=_eval_input_fn,
+ predict_input_fn=_predict_input_fn,
+ input_dimension=label_dimension,
+ label_dimension=label_dimension,
+ batch_size=batch_size)
+
+
+class DNNClassifierWithLayerAnnotationsIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _as_label(self, data_in_float):
+ return np.rint(data_in_float).astype(np.int64)
+
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, n_classes, batch_size):
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ est = dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations(
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ num_steps = 10
+ est.train(train_input_fn, steps=num_steps)
+
+ # EVALUATE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+
+ # PREDICT
+ predicted_proba = np.array([
+ x[prediction_keys.PredictionKeys.PROBABILITIES]
+ for x in est.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
+
+ # EXPORT
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ n_classes = 3
+ input_dimension = 2
+ batch_size = 10
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+ x_data = data.reshape(batch_size, input_dimension)
+ y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data},
+ y=y_data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data}, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=input_dimension,
+ n_classes=n_classes,
+ batch_size=batch_size)
+
+ def test_pandas_input_fn(self):
+ """Tests complete flow with pandas_input_fn."""
+ if not HAS_PANDAS:
+ return
+ input_dimension = 1
+ n_classes = 3
+ batch_size = 10
+ data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
+ x = pd.DataFrame({'x': data})
+ y = pd.Series(self._as_label(data))
+ train_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+ eval_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, shuffle=False)
+ predict_input_fn = pandas_io.pandas_input_fn(
+ x=x, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=input_dimension,
+ n_classes=n_classes,
+ batch_size=batch_size)
+
+ def test_input_fn_from_parse_example(self):
+ """Tests complete flow with input_fn constructed from parse_example."""
+ input_dimension = 2
+ n_classes = 3
+ batch_size = 10
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, input_dimension)
+
+ serialized_examples = []
+ for datum in data:
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'x':
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=datum)),
+ 'y':
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(
+ value=self._as_label(datum[:1]))),
+ }))
+ serialized_examples.append(example.SerializeToString())
+
+ feature_spec = {
+ 'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
+ 'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
+ }
+
+ def _train_input_fn():
+ feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _eval_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _predict_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ features.pop('y')
+ return features, None
+
+ self._test_complete_flow(
+ train_input_fn=_train_input_fn,
+ eval_input_fn=_eval_input_fn,
+ predict_input_fn=_predict_input_fn,
+ input_dimension=input_dimension,
+ n_classes=n_classes,
+ batch_size=batch_size)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
index 7c49cd00d1..98660bb731 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import training_util
@@ -405,6 +406,7 @@ class RNNClassifier(estimator.Estimator):
weight_column=None,
label_vocabulary=None,
optimizer='Adagrad',
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
input_layer_partitioner=None,
config=None):
"""Initializes a `RNNClassifier` instance.
@@ -454,6 +456,8 @@ class RNNClassifier(estimator.Estimator):
string.
optimizer: An instance of `tf.Optimizer` or string specifying optimizer
type. Defaults to Adagrad optimizer.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.
input_layer_partitioner: Optional. Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
@@ -467,11 +471,15 @@ class RNNClassifier(estimator.Estimator):
if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
- n_classes, weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ n_classes,
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
+
def _model_fn(features, labels, mode, config):
return _rnn_model_fn(
features=features,
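The halved expected losses in rnn_test.py below follow directly from the new `SUM_OVER_BATCH_SIZE` default; a worked check of the binary-classification case (numbers taken from the evaluation comments in the test):

    per_example_losses = [0.436326, 0.683335]
    sum_loss = sum(per_example_losses)              # 1.119661 under SUM
    mean_loss = sum_loss / len(per_example_losses)  # 0.559831 under SUM_OVER_BATCH_SIZE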
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
index 959b40371a..1aebed348d 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
@@ -713,7 +713,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testBinaryClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=1.119661)
+ mock_optimizer = self._mock_optimizer(expected_loss=0.559831)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -748,7 +748,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testMultiClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=2.662932)
+ mock_optimizer = self._mock_optimizer(expected_loss=1.331465)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -812,20 +812,32 @@ class RNNClassifierEvaluationTest(test.TestCase):
# probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]]
# loss = -label * ln(p) - (1 - label) * ln(1 - p)
# = [[0.436326], [0.683335]]
+ # sum_over_batch_size = (0.436326 + 0.683335)/2
expected_metrics = {
- ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 1.119661,
- metric_keys.MetricKeys.LOSS_MEAN: 0.559831,
- metric_keys.MetricKeys.ACCURACY: 1.0,
- metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262,
- metric_keys.MetricKeys.LABEL_MEAN: 0.5,
- metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+ ops.GraphKeys.GLOBAL_STEP:
+ global_step,
+ metric_keys.MetricKeys.LOSS:
+ 0.559831,
+ metric_keys.MetricKeys.LOSS_MEAN:
+ 0.559831,
+ metric_keys.MetricKeys.ACCURACY:
+ 1.0,
+ metric_keys.MetricKeys.PREDICTION_MEAN:
+ 0.429262,
+ metric_keys.MetricKeys.LABEL_MEAN:
+ 0.5,
+ metric_keys.MetricKeys.ACCURACY_BASELINE:
+ 0.5,
# With default threshold of 0.5, the model is a perfect classifier.
- metric_keys.MetricKeys.RECALL: 1.0,
- metric_keys.MetricKeys.PRECISION: 1.0,
+ metric_keys.MetricKeys.RECALL:
+ 1.0,
+ metric_keys.MetricKeys.PRECISION:
+ 1.0,
# Positive example is scored above negative, so AUC = 1.0.
- metric_keys.MetricKeys.AUC: 1.0,
- metric_keys.MetricKeys.AUC_PR: 1.0,
+ metric_keys.MetricKeys.AUC:
+ 1.0,
+ metric_keys.MetricKeys.AUC_PR:
+ 1.0,
}
self.assertAllClose(
sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))
@@ -871,9 +883,10 @@ class RNNClassifierEvaluationTest(test.TestCase):
# [0.059494, 0.572639, 0.367866]]
# loss = -1. * log(softmax[label])
# = [[2.105432], [0.557500]]
+ # sum_over_batch_size = (2.105432 + 0.557500)/2
expected_metrics = {
ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 2.662932,
+ metric_keys.MetricKeys.LOSS: 1.331465,
metric_keys.MetricKeys.LOSS_MEAN: 1.331466,
metric_keys.MetricKeys.ACCURACY: 0.5,
}