diff options
author | 2018-09-21 18:03:48 -0700 | |
---|---|---|
committer | 2018-09-21 18:07:16 -0700 | |
commit | d125fb8a39bb4fca1be5421130ed66d673ee590f (patch) | |
tree | a061732b06825319596a2cb30225e14abb8e9d3c /tensorflow/contrib/estimator | |
parent | 8469e314dae2c177c116bd17e38991c9a32bf418 (diff) |
Always add layer annotations, regardless of mode.
PiperOrigin-RevId: 214073179
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py | 79 |
1 files changed, 37 insertions, 42 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py index 152431d1b2..3fd9f12c61 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py @@ -24,7 +24,6 @@ import pickle from google.protobuf.any_pb2 import Any from tensorflow.python.estimator import estimator -from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import dnn from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops @@ -68,7 +67,7 @@ def _to_any_wrapped_tensor_info(tensor): return any_buf -def make_input_layer_with_layer_annotations(original_input_layer, mode): +def make_input_layer_with_layer_annotations(original_input_layer): """Make an input_layer replacement function that adds layer annotations.""" def input_layer_with_layer_annotations(features, @@ -137,42 +136,38 @@ def make_input_layer_with_layer_annotations(original_input_layer, mode): if cols_to_output_tensors is not None: cols_to_output_tensors = local_cols_to_output_tensors - if mode and mode == model_fn.ModeKeys.PREDICT: - # Only annotate in PREDICT mode. - - # Annotate features. - # These are the parsed Tensors, before embedding. - - # Only annotate features used by FeatureColumns. - # We figure which ones are used by FeatureColumns by creating a parsing - # spec and looking at the keys. - spec = feature_column_lib.make_parse_example_spec(feature_columns) - for key in spec.keys(): - tensor = features[key] - ops.add_to_collection( - LayerAnnotationsCollectionNames.keys( - LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key) - ops.add_to_collection( - LayerAnnotationsCollectionNames.values( - LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), - _to_any_wrapped_tensor_info(tensor)) - - # Annotate feature columns. - for column in feature_columns: - # TODO(cyfoo): Find a better way to serialize and deserialize - # _FeatureColumn. - ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS, - serialize_feature_column(column)) - - for column, tensor in local_cols_to_output_tensors.items(): - ops.add_to_collection( - LayerAnnotationsCollectionNames.keys( - LayerAnnotationsCollectionNames.PROCESSED_FEATURES), - column.name) - ops.add_to_collection( - LayerAnnotationsCollectionNames.values( - LayerAnnotationsCollectionNames.PROCESSED_FEATURES), - _to_any_wrapped_tensor_info(tensor)) + # Annotate features. + # These are the parsed Tensors, before embedding. + + # Only annotate features used by FeatureColumns. + # We figure which ones are used by FeatureColumns by creating a parsing + # spec and looking at the keys. + spec = feature_column_lib.make_parse_example_spec(feature_columns) + for key in spec.keys(): + tensor = ops.convert_to_tensor(features[key]) + ops.add_to_collection( + LayerAnnotationsCollectionNames.keys( + LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key) + ops.add_to_collection( + LayerAnnotationsCollectionNames.values( + LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), + _to_any_wrapped_tensor_info(tensor)) + + # Annotate feature columns. + for column in feature_columns: + # TODO(cyfoo): Find a better way to serialize and deserialize + # _FeatureColumn. + ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS, + serialize_feature_column(column)) + + for column, tensor in local_cols_to_output_tensors.items(): + ops.add_to_collection( + LayerAnnotationsCollectionNames.keys( + LayerAnnotationsCollectionNames.PROCESSED_FEATURES), column.name) + ops.add_to_collection( + LayerAnnotationsCollectionNames.values( + LayerAnnotationsCollectionNames.PROCESSED_FEATURES), + _to_any_wrapped_tensor_info(tensor)) return input_layer @@ -302,8 +297,8 @@ def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name def _model_fn(features, labels, mode, config): with _monkey_patch( feature_column_lib, 'input_layer', - make_input_layer_with_layer_annotations(feature_column_lib.input_layer, - mode)): + make_input_layer_with_layer_annotations( + feature_column_lib.input_layer)): return original.model_fn(features, labels, mode, config) return estimator.Estimator( @@ -423,8 +418,8 @@ def DNNRegressorWithLayerAnnotations( # pylint: disable=invalid-name def _model_fn(features, labels, mode, config): with _monkey_patch( feature_column_lib, 'input_layer', - make_input_layer_with_layer_annotations(feature_column_lib.input_layer, - mode)): + make_input_layer_with_layer_annotations( + feature_column_lib.input_layer)): return original.model_fn(features, labels, mode, config) return estimator.Estimator( |