aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-21 18:03:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 18:07:16 -0700
commitd125fb8a39bb4fca1be5421130ed66d673ee590f (patch)
treea061732b06825319596a2cb30225e14abb8e9d3c /tensorflow/contrib/estimator
parent8469e314dae2c177c116bd17e38991c9a32bf418 (diff)
Always add layer annotations, regardless of mode.
PiperOrigin-RevId: 214073179
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py79
1 files changed, 37 insertions, 42 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
index 152431d1b2..3fd9f12c61 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -24,7 +24,6 @@ import pickle
from google.protobuf.any_pb2 import Any
from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import dnn
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import ops
@@ -68,7 +67,7 @@ def _to_any_wrapped_tensor_info(tensor):
return any_buf
-def make_input_layer_with_layer_annotations(original_input_layer, mode):
+def make_input_layer_with_layer_annotations(original_input_layer):
"""Make an input_layer replacement function that adds layer annotations."""
def input_layer_with_layer_annotations(features,
@@ -137,42 +136,38 @@ def make_input_layer_with_layer_annotations(original_input_layer, mode):
if cols_to_output_tensors is not None:
cols_to_output_tensors = local_cols_to_output_tensors
- if mode and mode == model_fn.ModeKeys.PREDICT:
- # Only annotate in PREDICT mode.
-
- # Annotate features.
- # These are the parsed Tensors, before embedding.
-
- # Only annotate features used by FeatureColumns.
- # We figure which ones are used by FeatureColumns by creating a parsing
- # spec and looking at the keys.
- spec = feature_column_lib.make_parse_example_spec(feature_columns)
- for key in spec.keys():
- tensor = features[key]
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.keys(
- LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.values(
- LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES),
- _to_any_wrapped_tensor_info(tensor))
-
- # Annotate feature columns.
- for column in feature_columns:
- # TODO(cyfoo): Find a better way to serialize and deserialize
- # _FeatureColumn.
- ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS,
- serialize_feature_column(column))
-
- for column, tensor in local_cols_to_output_tensors.items():
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.keys(
- LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
- column.name)
- ops.add_to_collection(
- LayerAnnotationsCollectionNames.values(
- LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
- _to_any_wrapped_tensor_info(tensor))
+ # Annotate features.
+ # These are the parsed Tensors, before embedding.
+
+ # Only annotate features used by FeatureColumns.
+ # We figure which ones are used by FeatureColumns by creating a parsing
+ # spec and looking at the keys.
+ spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ for key in spec.keys():
+ tensor = ops.convert_to_tensor(features[key])
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
+
+ # Annotate feature columns.
+ for column in feature_columns:
+ # TODO(cyfoo): Find a better way to serialize and deserialize
+ # _FeatureColumn.
+ ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS,
+ serialize_feature_column(column))
+
+ for column, tensor in local_cols_to_output_tensors.items():
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES), column.name)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
return input_layer
@@ -302,8 +297,8 @@ def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name
def _model_fn(features, labels, mode, config):
with _monkey_patch(
feature_column_lib, 'input_layer',
- make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
- mode)):
+ make_input_layer_with_layer_annotations(
+ feature_column_lib.input_layer)):
return original.model_fn(features, labels, mode, config)
return estimator.Estimator(
@@ -423,8 +418,8 @@ def DNNRegressorWithLayerAnnotations( # pylint: disable=invalid-name
def _model_fn(features, labels, mode, config):
with _monkey_patch(
feature_column_lib, 'input_layer',
- make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
- mode)):
+ make_input_layer_with_layer_annotations(
+ feature_column_lib.input_layer)):
return original.model_fn(features, labels, mode, config)
return estimator.Estimator(