diff options
-rw-r--r-- | tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py | 26 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py | 20 |
2 files changed, 45 insertions, 1 deletions
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 0ac3e3286b..42a4c1df2f 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import time +from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn import export_strategy from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import prediction_key @@ -345,7 +346,7 @@ def make_export_strategy(serving_input_fn, collection. Returns: - an ExportStrategy that can be passed to the Experiment constructor. + An ExportStrategy that can be passed to the Experiment constructor. """ def export_fn(estimator, export_dir_base): @@ -370,3 +371,26 @@ def make_export_strategy(serving_input_fn, return export_result return export_strategy.ExportStrategy('Servo', export_fn) + + +def make_parsing_export_strategy(feature_columns, exports_to_keep=5): + """Create an ExportStrategy for use with Experiment, using `FeatureColumn`s. + + Creates a SavedModel export that expects to be fed with a single string + Tensor containing serialized tf.Examples. At serving time, incoming + tf.Examples will be parsed according to the provided `FeatureColumn`s. + + Args: + feature_columns: An iterable of `FeatureColumn`s representing the features + that must be provided at serving time (excluding labels!). + exports_to_keep: Number of exports to keep. Older exports will be + garbage-collected. Defaults to 5. Set to None to disable garbage + collection. + + Returns: + An ExportStrategy that can be passed to the Experiment constructor. + """ + feature_spec = feature_column.create_feature_spec_for_parsing(feature_columns) + serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec) + return make_export_strategy(serving_input_fn, exports_to_keep=exports_to_keep) + diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index a78a7a453e..7eebb76475 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -29,6 +29,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"): import ctypes sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL) +from tensorflow.contrib.layers.python.layers import feature_column as fc from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import model_fn @@ -651,6 +652,25 @@ class SavedModelExportUtilsTest(test.TestCase): self.assertTrue( isinstance(export_strategy, export_strategy_lib.ExportStrategy)) + def test_make_parsing_export_strategy(self): + """Only tests that an ExportStrategy instance is created.""" + sparse_col = fc.sparse_column_with_hash_bucket( + "sparse_column", hash_bucket_size=100) + embedding_col = fc.embedding_column( + fc.sparse_column_with_hash_bucket( + "sparse_column_for_embedding", hash_bucket_size=10), + dimension=4) + real_valued_col1 = fc.real_valued_column("real_valued_column1") + bucketized_col1 = fc.bucketized_column( + fc.real_valued_column("real_valued_column_for_bucketization1"), [0, 4]) + feature_columns = [sparse_col, embedding_col, real_valued_col1, + bucketized_col1] + + export_strategy = saved_model_export_utils.make_parsing_export_strategy( + feature_columns=feature_columns) + self.assertTrue( + isinstance(export_strategy, export_strategy_lib.ExportStrategy)) + def _create_test_export_dir(export_dir_base): export_dir = saved_model_export_utils.get_timestamped_export_dir( |