From 53b57715f5604a5d09a9ddc73bbbf54f1d1142ed Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Sep 2018 12:16:53 -0700 Subject: Create experimental DNN Estimators with support for Path-Integrated Gradients annotations. PiperOrigin-RevId: 212682657 --- tensorflow/contrib/estimator/BUILD | 56 ++ tensorflow/contrib/estimator/__init__.py | 3 + .../python/estimator/dnn_with_layer_annotations.py | 434 +++++++++++++++ .../estimator/dnn_with_layer_annotations_test.py | 611 +++++++++++++++++++++ 4 files changed, 1104 insertions(+) create mode 100644 tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py create mode 100644 tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 437b3d965d..6db311d52d 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -18,6 +18,7 @@ py_library( ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":dnn_with_layer_annotations", ":early_stopping", ":export", ":exporter", @@ -126,6 +127,61 @@ py_test( ], ) +py_library( + name = "dnn_with_layer_annotations", + srcs = ["python/estimator/dnn_with_layer_annotations.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:nn", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:summary", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model:utils", + ], +) + +py_test( + name = "dnn_with_layer_annotations_test", + size = "medium", + srcs = ["python/estimator/dnn_with_layer_annotations_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", # b/67510291 + ], + deps = [ + ":dnn_with_layer_annotations", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python/estimator:dnn", + "//tensorflow/python/estimator:dnn_testing_utils", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:pandas_io", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "@six_archive//:six", + ], +) + py_library( name = "dnn_linear_combined", srcs = ["python/estimator/dnn_linear_combined.py"], diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 258860f263..78914ecaca 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.contrib.estimator.python.estimator.baseline import * from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * +from tensorflow.contrib.estimator.python.estimator.dnn_with_layer_annotations import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * from tensorflow.contrib.estimator.python.estimator.early_stopping import * from tensorflow.contrib.estimator.python.estimator.export import * @@ -76,6 +77,8 @@ _allowed_symbols = [ 'build_raw_supervised_input_receiver_fn', 'build_supervised_input_receiver_fn_from_input_fn', 'SavedModelEstimator' + 'DNNClassifierWithLayerAnnotations', + 'DNNRegressorWithLayerAnnotations', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py new file mode 100644 index 0000000000..152431d1b2 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py @@ -0,0 +1,434 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Deep Neural Network estimators with layer annotations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import pickle + +from google.protobuf.any_pb2 import Any + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator.canned import dnn +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn +from tensorflow.python.ops.losses import losses +from tensorflow.python.saved_model import utils as saved_model_utils + + +class LayerAnnotationsCollectionNames(object): + """Names for the collections containing the annotations.""" + + UNPROCESSED_FEATURES = 'layer_annotations/unprocessed_features' + PROCESSED_FEATURES = 'layer_annotatons/processed_features' + FEATURE_COLUMNS = 'layer_annotations/feature_columns' + + @classmethod + def keys(cls, collection_name): + return '%s/keys' % collection_name + + @classmethod + def values(cls, collection_name): + return '%s/values' % collection_name + + +def serialize_feature_column(feature_column): + if isinstance(feature_column, feature_column_lib._EmbeddingColumn): # pylint: disable=protected-access + # We can't pickle nested functions, and we don't need the value of + # layer_creator in most cases anyway, so just discard its value. + args = feature_column._asdict() + args['layer_creator'] = None + temp = type(feature_column)(**args) + return pickle.dumps(temp) + return pickle.dumps(feature_column) + + +def _to_any_wrapped_tensor_info(tensor): + """Converts a `Tensor` to a `TensorInfo` wrapped in a proto `Any`.""" + any_buf = Any() + tensor_info = saved_model_utils.build_tensor_info(tensor) + any_buf.Pack(tensor_info) + return any_buf + + +def make_input_layer_with_layer_annotations(original_input_layer, mode): + """Make an input_layer replacement function that adds layer annotations.""" + + def input_layer_with_layer_annotations(features, + feature_columns, + weight_collections=None, + trainable=True, + cols_to_vars=None, + cols_to_output_tensors=None): + """Returns a dense `Tensor` as input layer based on given `feature_columns`. + + Generally a single example in training data is described with + FeatureColumns. + At the first layer of the model, this column oriented data should be + converted + to a single `Tensor`. + + This is like tf.feature_column.input_layer, except with added + Integrated-Gradient annotations. + + Args: + features: A mapping from key to tensors. `_FeatureColumn`s look up via + these keys. For example `numeric_column('price')` will look at 'price' + key in this dict. Values can be a `SparseTensor` or a `Tensor` depends + on corresponding `_FeatureColumn`. + feature_columns: An iterable containing the FeatureColumns to use as + inputs to your model. All items should be instances of classes derived + from `_DenseColumn` such as `numeric_column`, `embedding_column`, + `bucketized_column`, `indicator_column`. If you have categorical + features, you can wrap them with an `embedding_column` or + `indicator_column`. + weight_collections: A list of collection names to which the Variable will + be added. Note that variables will also be added to collections + `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. + trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + cols_to_vars: If not `None`, must be a dictionary that will be filled with + a mapping from `_FeatureColumn` to list of `Variable`s. For example, + after the call, we might have cols_to_vars = {_EmbeddingColumn( + categorical_column=_HashedCategoricalColumn( key='sparse_feature', + hash_bucket_size=5, dtype=tf.string), dimension=10): [