aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Karmel Allison <karmel@google.com>2018-05-30 16:25:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-30 16:27:30 -0700
commitc9297e34f0ceef4afd970ee117aea9110bf8ae62 (patch)
tree3d5a9f8bd9b21ea42ec83f58753f5b44db6bec5e
parentdff3875cdca6a8cf49ee5ce4c0c970eda550157f (diff)
Add a convenience function, build_supervised_input_receiver_fn_from_input_fn,
that takes an Estimator input_fn and returns an input receiver function. PiperOrigin-RevId: 198638593
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py4
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/estimator/BUILD20
-rw-r--r--tensorflow/python/estimator/estimator.py55
-rw-r--r--tensorflow/python/estimator/export/export.py36
-rw-r--r--tensorflow/python/estimator/export/export_test.py35
-rw-r--r--tensorflow/python/estimator/util.py57
-rw-r--r--tensorflow/python/estimator/util_test.py102
8 files changed, 267 insertions, 43 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index aeb7ba536f..4465833f88 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -46,6 +46,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -2748,7 +2749,8 @@ class _Inputs(object):
"""
iterator = self._dataset.make_initializable_iterator()
# pylint: disable=protected-access
- hook = estimator_lib._DatasetInitializerHook(iterator)
+ hook = estimator_util._DatasetInitializerHook(iterator)
+ # pylint: enable=protected-access
self._iterator = iterator
return hook
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 679ef93229..0542c2fc91 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2699,7 +2699,6 @@ py_library(
":util",
":variables",
"//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:util",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 0754041f9e..9c4d58b177 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -446,7 +446,26 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/data",
+ ],
+)
+
+py_test(
+ name = "util_test",
+ srcs = ["util_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["notsan"], # b/67510291
+ deps = [
+ ":util",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:training",
+ "//tensorflow/python/data",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
],
)
@@ -598,6 +617,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":util",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 331ee7490e..cfbf7e2ce5 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -32,10 +32,10 @@ from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session as tf_session
-from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
+from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export as export_helpers
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import errors
@@ -964,17 +964,9 @@ class Estimator(object):
def _get_features_from_input_fn(self, input_fn, mode):
"""Extracts the `features` from return values of `input_fn`."""
result = self._call_input_fn(input_fn, mode)
- input_hooks = []
- if isinstance(result, dataset_ops.Dataset):
- iterator = result.make_initializable_iterator()
- input_hooks.append(_DatasetInitializerHook(iterator))
- result = iterator.get_next()
- if isinstance(result, (list, tuple)):
- # Unconditionally drop the label (the second element of result).
- result = result[0]
-
+ result, _, hooks = estimator_util.parse_input_fn_result(result)
self._validate_features_in_predict_input(result)
- return result, input_hooks
+ return result, hooks
def _validate_features_in_predict_input(self, result):
if not _has_dataset_or_queue_runner(result):
@@ -984,25 +976,13 @@ class Estimator(object):
def _get_features_and_labels_from_input_fn(self, input_fn, mode):
"""Extracts the `features` and labels from return values of `input_fn`."""
- input_hooks = []
if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
result = self._distribution.distribute_dataset(
lambda: self._call_input_fn(input_fn, mode))
- iterator = result.make_initializable_iterator()
- input_hooks.append(_DatasetInitializerHook(iterator))
- result = iterator.get_next()
else:
result = self._call_input_fn(input_fn, mode)
- if isinstance(result, dataset_ops.Dataset):
- iterator = result.make_initializable_iterator()
- input_hooks.append(_DatasetInitializerHook(iterator))
- result = iterator.get_next()
- if isinstance(result, (list, tuple)):
- if len(result) != 2:
- raise ValueError(
- 'input_fn should return (features, labels) as a len 2 tuple.')
- return result[0], result[1], input_hooks
- return result, None, input_hooks
+
+ return estimator_util.parse_input_fn_result(result)
def _extract_batch_length(self, preds_evaluated):
"""Extracts batch length of predictions."""
@@ -1067,9 +1047,15 @@ class Estimator(object):
mode: ModeKeys
Returns:
- Either features or (features, labels) where features and labels are:
- features - `Tensor` or dictionary of string feature name to `Tensor`.
- labels - `Tensor` or dictionary of `Tensor` with labels.
+ The return value of the passed input_fn, which should be one of:
+
+ * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
+ tuple (features, labels) with same constraints as below.
+ * A tuple (features, labels): Where `features` is a `Tensor` or a
+ dictionary of string feature name to `Tensor` and `labels` is a
+ `Tensor` or a dictionary of string label name to `Tensor`. Both
+ `features` and `labels` are consumed by `model_fn`. They should
+ satisfy the expectation of `model_fn` from inputs.
Raises:
ValueError: if input_fn takes invalid arguments.
@@ -1610,19 +1596,6 @@ def _has_dataset_or_queue_runner(maybe_tensor):
# Now, check queue.
return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)
-
-class _DatasetInitializerHook(training.SessionRunHook):
-
- def __init__(self, iterator):
- self._iterator = iterator
-
- def begin(self):
- self._initializer = self._iterator.initializer
-
- def after_create_session(self, session, coord):
- del coord
- session.run(self._initializer)
-
VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name
tf_export('estimator.VocabInfo', allow_multiple_exports=True)(VocabInfo)
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index 48ae8cd497..ff19a0a7f4 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -404,6 +404,42 @@ def build_raw_supervised_input_receiver_fn(features,
return supervised_input_receiver_fn
+def build_supervised_input_receiver_fn_from_input_fn(input_fn, **input_fn_args):
+ """Get a function that returns a SupervisedInputReceiver matching an input_fn.
+
+ Note that this function calls the input_fn in a local graph in order to
+ extract features and labels. Placeholders are then created from those
+ features and labels in the default graph.
+
+ Args:
+ input_fn: An Estimator input_fn, which is a function that returns one of:
+
+ * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
+ tuple (features, labels) with same constraints as below.
+ * A tuple (features, labels): Where `features` is a `Tensor` or a
+ dictionary of string feature name to `Tensor` and `labels` is a
+ `Tensor` or a dictionary of string label name to `Tensor`. Both
+ `features` and `labels` are consumed by `model_fn`. They should
+ satisfy the expectation of `model_fn` from inputs.
+
+ **input_fn_args: set of kwargs to be passed to the input_fn. Note that
+ these will not be checked or validated here, and any errors raised by
+ the input_fn will be thrown to the top.
+
+ Returns:
+ A function taking no arguments that, when called, returns a
+ SupervisedInputReceiver. This function can be passed in as part of the
+ input_receiver_map when exporting SavedModels from Estimator with multiple
+ modes.
+ """
+ # Wrap the input_fn call in a graph to prevent sullying the default namespace
+ with ops.Graph().as_default():
+ result = input_fn(**input_fn_args)
+ features, labels, _ = util.parse_input_fn_result(result)
+ # Placeholders are created back in the default graph.
+ return build_raw_supervised_input_receiver_fn(features, labels)
+
+
### Below utilities are specific to SavedModel exports.
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index 0af587f2a8..a7074712c2 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -459,6 +459,41 @@ class ExportTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
export.build_raw_supervised_input_receiver_fn(features, labels)
+ def test_build_supervised_input_receiver_fn_from_input_fn(self):
+ def dummy_input_fn():
+ return ({"x": constant_op.constant([[1], [1]]),
+ "y": constant_op.constant(["hello", "goodbye"])},
+ constant_op.constant([[1], [1]]))
+
+ input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
+ dummy_input_fn)
+
+ with ops.Graph().as_default():
+ input_receiver = input_receiver_fn()
+ self.assertEqual(set(["x", "y"]),
+ set(input_receiver.features.keys()))
+ self.assertIsInstance(input_receiver.labels, ops.Tensor)
+ self.assertEqual(set(["x", "y", "label"]),
+ set(input_receiver.receiver_tensors.keys()))
+
+ def test_build_supervised_input_receiver_fn_from_input_fn_args(self):
+ def dummy_input_fn(feature_key="x"):
+ return ({feature_key: constant_op.constant([[1], [1]]),
+ "y": constant_op.constant(["hello", "goodbye"])},
+ {"my_label": constant_op.constant([[1], [1]])})
+
+ input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
+ dummy_input_fn, feature_key="z")
+
+ with ops.Graph().as_default():
+ input_receiver = input_receiver_fn()
+ self.assertEqual(set(["z", "y"]),
+ set(input_receiver.features.keys()))
+ self.assertEqual(set(["my_label"]),
+ set(input_receiver.labels.keys()))
+ self.assertEqual(set(["z", "y", "my_label"]),
+ set(input_receiver.receiver_tensors.keys()))
+
def test_build_all_signature_defs_without_receiver_alternatives(self):
receiver_tensor = array_ops.placeholder(dtypes.string)
output_1 = constant_op.constant([1.])
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index e4e1d37f74..924ca309ff 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -24,6 +24,7 @@ import time
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import training
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
@@ -72,3 +73,59 @@ def get_timestamped_dir(dir_base):
result_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS))
raise RuntimeError('Failed to obtain a unique export directory name after '
'{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
+
+
+def parse_input_fn_result(result):
+ """Gets features, labels, and hooks from the result of an Estimator input_fn.
+
+ Args:
+ result: output of an input_fn to an estimator, which should be one of:
+
+ * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
+ tuple (features, labels) with same constraints as below.
+ * A tuple (features, labels): Where `features` is a `Tensor` or a
+ dictionary of string feature name to `Tensor` and `labels` is a
+ `Tensor` or a dictionary of string label name to `Tensor`. Both
+ `features` and `labels` are consumed by `model_fn`. They should
+ satisfy the expectation of `model_fn` from inputs.
+
+ Returns:
+ Tuple of features, labels, and input_hooks, where features are as described
+ above, labels are as described above or None, and input_hooks are a list
+ of SessionRunHooks to be included when running.
+
+ Raises:
+ ValueError: if the result is a list or tuple of length != 2.
+ """
+ input_hooks = []
+ try:
+ # We can't just check whether this is a tf.data.Dataset instance here,
+ # as this is plausibly a PerDeviceDataset. Try treating as a dataset first.
+ iterator = result.make_initializable_iterator()
+ except AttributeError:
+ # Not a dataset or dataset-like-object. Move along.
+ pass
+ else:
+ input_hooks.append(_DatasetInitializerHook(iterator))
+ result = iterator.get_next()
+
+ if isinstance(result, (list, tuple)):
+ if len(result) != 2:
+ raise ValueError(
+ 'input_fn should return (features, labels) as a len 2 tuple.')
+ return result[0], result[1], input_hooks
+ return result, None, input_hooks
+
+
+class _DatasetInitializerHook(training.SessionRunHook):
+ """Creates a SessionRunHook that initializes the passed iterator."""
+
+ def __init__(self, iterator):
+ self._iterator = iterator
+
+ def begin(self):
+ self._initializer = self._iterator.initializer
+
+ def after_create_session(self, session, coord):
+ del coord
+ session.run(self._initializer)
diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/estimator/util_test.py
new file mode 100644
index 0000000000..d7e0610779
--- /dev/null
+++ b/tensorflow/python/estimator/util_test.py
@@ -0,0 +1,102 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for util.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import util
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+from tensorflow.python.training import training
+
+
+class UtilTest(test.TestCase):
+ """Tests for miscellaneous Estimator utils."""
+
+ def test_parse_input_fn_result_tuple(self):
+ def _input_fn():
+ features = constant_op.constant(np.arange(100))
+ labels = constant_op.constant(np.arange(100, 200))
+ return features, labels
+
+ features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+ with self.test_session() as sess:
+ vals = sess.run([features, labels])
+
+ self.assertAllEqual(vals[0], np.arange(100))
+ self.assertAllEqual(vals[1], np.arange(100, 200))
+ self.assertEqual(hooks, [])
+
+ def test_parse_input_fn_result_dataset(self):
+ def _input_fn():
+ features = np.expand_dims(np.arange(100), 0)
+ labels = np.expand_dims(np.arange(100, 200), 0)
+ return dataset_ops.Dataset.from_tensor_slices((features, labels))
+
+ features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+ with training.MonitoredSession(hooks=hooks) as sess:
+ vals = sess.run([features, labels])
+
+ self.assertAllEqual(vals[0], np.arange(100))
+ self.assertAllEqual(vals[1], np.arange(100, 200))
+ self.assertIsInstance(hooks[0], util._DatasetInitializerHook)
+
+ def test_parse_input_fn_result_features_only(self):
+ def _input_fn():
+ return constant_op.constant(np.arange(100))
+
+ features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+ with self.test_session() as sess:
+ vals = sess.run([features])
+
+ self.assertAllEqual(vals[0], np.arange(100))
+ self.assertEqual(labels, None)
+ self.assertEqual(hooks, [])
+
+ def test_parse_input_fn_result_features_only_dataset(self):
+ def _input_fn():
+ features = np.expand_dims(np.arange(100), 0)
+ return dataset_ops.Dataset.from_tensor_slices(features)
+
+ features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+ with training.MonitoredSession(hooks=hooks) as sess:
+ vals = sess.run([features])
+
+ self.assertAllEqual(vals[0], np.arange(100))
+ self.assertEqual(labels, None)
+ self.assertIsInstance(hooks[0], util._DatasetInitializerHook)
+
+ def test_parse_input_fn_result_invalid(self):
+ def _input_fn():
+ features = np.expand_dims(np.arange(100), 0)
+ labels = np.expand_dims(np.arange(100, 200), 0)
+ return dataset_ops.Dataset.from_tensor_slices((features, labels, labels))
+
+ with self.assertRaisesRegexp(ValueError, 'input_fn should return'):
+ util.parse_input_fn_result(_input_fn())
+
+
+if __name__ == '__main__':
+ test.main()