aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py26
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py20
2 files changed, 45 insertions, 1 deletions
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
index 0ac3e3286b..42a4c1df2f 100644
--- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import os
import time
+from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn import export_strategy
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
@@ -345,7 +346,7 @@ def make_export_strategy(serving_input_fn,
collection.
Returns:
- an ExportStrategy that can be passed to the Experiment constructor.
+ An ExportStrategy that can be passed to the Experiment constructor.
"""
def export_fn(estimator, export_dir_base):
@@ -370,3 +371,26 @@ def make_export_strategy(serving_input_fn,
return export_result
return export_strategy.ExportStrategy('Servo', export_fn)
+
+
+def make_parsing_export_strategy(feature_columns, exports_to_keep=5):
+ """Create an ExportStrategy for use with Experiment, using `FeatureColumn`s.
+
+ Creates a SavedModel export that expects to be fed with a single string
+ Tensor containing serialized tf.Examples. At serving time, incoming
+ tf.Examples will be parsed according to the provided `FeatureColumn`s.
+
+ Args:
+ feature_columns: An iterable of `FeatureColumn`s representing the features
+ that must be provided at serving time (excluding labels!).
+ exports_to_keep: Number of exports to keep. Older exports will be
+ garbage-collected. Defaults to 5. Set to None to disable garbage
+ collection.
+
+ Returns:
+ An ExportStrategy that can be passed to the Experiment constructor.
+ """
+ feature_spec = feature_column.create_feature_spec_for_parsing(feature_columns)
+ serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
+ return make_export_strategy(serving_input_fn, exports_to_keep=exports_to_keep)
+
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
index a78a7a453e..7eebb76475 100644
--- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
@@ -29,6 +29,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
import ctypes
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)
+from tensorflow.contrib.layers.python.layers import feature_column as fc
from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
@@ -651,6 +652,25 @@ class SavedModelExportUtilsTest(test.TestCase):
self.assertTrue(
isinstance(export_strategy, export_strategy_lib.ExportStrategy))
+ def test_make_parsing_export_strategy(self):
+ """Only tests that an ExportStrategy instance is created."""
+ sparse_col = fc.sparse_column_with_hash_bucket(
+ "sparse_column", hash_bucket_size=100)
+ embedding_col = fc.embedding_column(
+ fc.sparse_column_with_hash_bucket(
+ "sparse_column_for_embedding", hash_bucket_size=10),
+ dimension=4)
+ real_valued_col1 = fc.real_valued_column("real_valued_column1")
+ bucketized_col1 = fc.bucketized_column(
+ fc.real_valued_column("real_valued_column_for_bucketization1"), [0, 4])
+ feature_columns = [sparse_col, embedding_col, real_valued_col1,
+ bucketized_col1]
+
+ export_strategy = saved_model_export_utils.make_parsing_export_strategy(
+ feature_columns=feature_columns)
+ self.assertTrue(
+ isinstance(export_strategy, export_strategy_lib.ExportStrategy))
+
def _create_test_export_dir(export_dir_base):
export_dir = saved_model_export_utils.get_timestamped_export_dir(