diff options
author | 2018-05-30 16:25:00 -0700 | |
---|---|---|
committer | 2018-05-30 16:27:30 -0700 | |
commit | c9297e34f0ceef4afd970ee117aea9110bf8ae62 (patch) | |
tree | 3d5a9f8bd9b21ea42ec83f58753f5b44db6bec5e | |
parent | dff3875cdca6a8cf49ee5ce4c0c970eda550157f (diff) |
Add a convenience function, build_supervised_input_receiver_fn_from_input_fn,
that takes an Estimator input_fn and returns an input receiver function.
PiperOrigin-RevId: 198638593
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 4 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/estimator/BUILD | 20 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 55 | ||||
-rw-r--r-- | tensorflow/python/estimator/export/export.py | 36 | ||||
-rw-r--r-- | tensorflow/python/estimator/export/export_test.py | 35 | ||||
-rw-r--r-- | tensorflow/python/estimator/util.py | 57 | ||||
-rw-r--r-- | tensorflow/python/estimator/util_test.py | 102 |
8 files changed, 267 insertions, 43 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index aeb7ba536f..4465833f88 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -46,6 +46,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -2748,7 +2749,8 @@ class _Inputs(object): """ iterator = self._dataset.make_initializable_iterator() # pylint: disable=protected-access - hook = estimator_lib._DatasetInitializerHook(iterator) + hook = estimator_util._DatasetInitializerHook(iterator) + # pylint: enable=protected-access self._iterator = iterator return hook diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 679ef93229..0542c2fc91 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2699,7 +2699,6 @@ py_library( ":util", ":variables", "//tensorflow/python/eager:context", - "//tensorflow/python/estimator:util", "@six_archive//:six", ], ) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 0754041f9e..9c4d58b177 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -446,7 +446,26 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:platform", + "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/data", + ], +) + +py_test( + name = "util_test", + srcs = ["util_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/67510291 + deps = [ + ":util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:training", + "//tensorflow/python/data", + "//third_party/py/numpy", + "@six_archive//:six", ], ) @@ -598,6 +617,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":util", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 331ee7490e..cfbf7e2ce5 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -32,10 +32,10 @@ from tensorflow.core.framework import summary_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as tf_session -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config +from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export import export as export_helpers from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import errors @@ -964,17 +964,9 @@ class Estimator(object): def _get_features_from_input_fn(self, input_fn, mode): """Extracts the `features` from return values of `input_fn`.""" result = self._call_input_fn(input_fn, mode) - input_hooks = [] - if isinstance(result, dataset_ops.Dataset): - iterator = result.make_initializable_iterator() - input_hooks.append(_DatasetInitializerHook(iterator)) - result = iterator.get_next() - if isinstance(result, (list, tuple)): - # Unconditionally drop the label (the second element of result). - result = result[0] - + result, _, hooks = estimator_util.parse_input_fn_result(result) self._validate_features_in_predict_input(result) - return result, input_hooks + return result, hooks def _validate_features_in_predict_input(self, result): if not _has_dataset_or_queue_runner(result): @@ -984,25 +976,13 @@ class Estimator(object): def _get_features_and_labels_from_input_fn(self, input_fn, mode): """Extracts the `features` and labels from return values of `input_fn`.""" - input_hooks = [] if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN: result = self._distribution.distribute_dataset( lambda: self._call_input_fn(input_fn, mode)) - iterator = result.make_initializable_iterator() - input_hooks.append(_DatasetInitializerHook(iterator)) - result = iterator.get_next() else: result = self._call_input_fn(input_fn, mode) - if isinstance(result, dataset_ops.Dataset): - iterator = result.make_initializable_iterator() - input_hooks.append(_DatasetInitializerHook(iterator)) - result = iterator.get_next() - if isinstance(result, (list, tuple)): - if len(result) != 2: - raise ValueError( - 'input_fn should return (features, labels) as a len 2 tuple.') - return result[0], result[1], input_hooks - return result, None, input_hooks + + return estimator_util.parse_input_fn_result(result) def _extract_batch_length(self, preds_evaluated): """Extracts batch length of predictions.""" @@ -1067,9 +1047,15 @@ class Estimator(object): mode: ModeKeys Returns: - Either features or (features, labels) where features and labels are: - features - `Tensor` or dictionary of string feature name to `Tensor`. - labels - `Tensor` or dictionary of `Tensor` with labels. + The return value of the passed input_fn, which should be one of: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. Raises: ValueError: if input_fn takes invalid arguments. @@ -1610,19 +1596,6 @@ def _has_dataset_or_queue_runner(maybe_tensor): # Now, check queue. return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS) - -class _DatasetInitializerHook(training.SessionRunHook): - - def __init__(self, iterator): - self._iterator = iterator - - def begin(self): - self._initializer = self._iterator.initializer - - def after_create_session(self, session, coord): - del coord - session.run(self._initializer) - VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name tf_export('estimator.VocabInfo', allow_multiple_exports=True)(VocabInfo) diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 48ae8cd497..ff19a0a7f4 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -404,6 +404,42 @@ def build_raw_supervised_input_receiver_fn(features, return supervised_input_receiver_fn +def build_supervised_input_receiver_fn_from_input_fn(input_fn, **input_fn_args): + """Get a function that returns a SupervisedInputReceiver matching an input_fn. + + Note that this function calls the input_fn in a local graph in order to + extract features and labels. Placeholders are then created from those + features and labels in the default graph. + + Args: + input_fn: An Estimator input_fn, which is a function that returns one of: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + **input_fn_args: set of kwargs to be passed to the input_fn. Note that + these will not be checked or validated here, and any errors raised by + the input_fn will be thrown to the top. + + Returns: + A function taking no arguments that, when called, returns a + SupervisedInputReceiver. This function can be passed in as part of the + input_receiver_map when exporting SavedModels from Estimator with multiple + modes. + """ + # Wrap the input_fn call in a graph to prevent sullying the default namespace + with ops.Graph().as_default(): + result = input_fn(**input_fn_args) + features, labels, _ = util.parse_input_fn_result(result) + # Placeholders are created back in the default graph. + return build_raw_supervised_input_receiver_fn(features, labels) + + ### Below utilities are specific to SavedModel exports. diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index 0af587f2a8..a7074712c2 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -459,6 +459,41 @@ class ExportTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): export.build_raw_supervised_input_receiver_fn(features, labels) + def test_build_supervised_input_receiver_fn_from_input_fn(self): + def dummy_input_fn(): + return ({"x": constant_op.constant([[1], [1]]), + "y": constant_op.constant(["hello", "goodbye"])}, + constant_op.constant([[1], [1]])) + + input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn( + dummy_input_fn) + + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual(set(["x", "y"]), + set(input_receiver.features.keys())) + self.assertIsInstance(input_receiver.labels, ops.Tensor) + self.assertEqual(set(["x", "y", "label"]), + set(input_receiver.receiver_tensors.keys())) + + def test_build_supervised_input_receiver_fn_from_input_fn_args(self): + def dummy_input_fn(feature_key="x"): + return ({feature_key: constant_op.constant([[1], [1]]), + "y": constant_op.constant(["hello", "goodbye"])}, + {"my_label": constant_op.constant([[1], [1]])}) + + input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn( + dummy_input_fn, feature_key="z") + + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual(set(["z", "y"]), + set(input_receiver.features.keys())) + self.assertEqual(set(["my_label"]), + set(input_receiver.labels.keys())) + self.assertEqual(set(["z", "y", "my_label"]), + set(input_receiver.receiver_tensors.keys())) + def test_build_all_signature_defs_without_receiver_alternatives(self): receiver_tensor = array_ops.placeholder(dtypes.string) output_1 = constant_op.constant([1.]) diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py index e4e1d37f74..924ca309ff 100644 --- a/tensorflow/python/estimator/util.py +++ b/tensorflow/python/estimator/util.py @@ -24,6 +24,7 @@ import time from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import training from tensorflow.python.util import compat from tensorflow.python.util import function_utils @@ -72,3 +73,59 @@ def get_timestamped_dir(dir_base): result_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)) raise RuntimeError('Failed to obtain a unique export directory name after ' '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) + + +def parse_input_fn_result(result): + """Gets features, labels, and hooks from the result of an Estimator input_fn. + + Args: + result: output of an input_fn to an estimator, which should be one of: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + Returns: + Tuple of features, labels, and input_hooks, where features are as described + above, labels are as described above or None, and input_hooks are a list + of SessionRunHooks to be included when running. + + Raises: + ValueError: if the result is a list or tuple of length != 2. + """ + input_hooks = [] + try: + # We can't just check whether this is a tf.data.Dataset instance here, + # as this is plausibly a PerDeviceDataset. Try treating as a dataset first. + iterator = result.make_initializable_iterator() + except AttributeError: + # Not a dataset or dataset-like-object. Move along. + pass + else: + input_hooks.append(_DatasetInitializerHook(iterator)) + result = iterator.get_next() + + if isinstance(result, (list, tuple)): + if len(result) != 2: + raise ValueError( + 'input_fn should return (features, labels) as a len 2 tuple.') + return result[0], result[1], input_hooks + return result, None, input_hooks + + +class _DatasetInitializerHook(training.SessionRunHook): + """Creates a SessionRunHook that initializes the passed iterator.""" + + def __init__(self, iterator): + self._iterator = iterator + + def begin(self): + self._initializer = self._iterator.initializer + + def after_create_session(self, session, coord): + del coord + session.run(self._initializer) diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/estimator/util_test.py new file mode 100644 index 0000000000..d7e0610779 --- /dev/null +++ b/tensorflow/python/estimator/util_test.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for util.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import util +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test +from tensorflow.python.training import training + + +class UtilTest(test.TestCase): + """Tests for miscellaneous Estimator utils.""" + + def test_parse_input_fn_result_tuple(self): + def _input_fn(): + features = constant_op.constant(np.arange(100)) + labels = constant_op.constant(np.arange(100, 200)) + return features, labels + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with self.test_session() as sess: + vals = sess.run([features, labels]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertAllEqual(vals[1], np.arange(100, 200)) + self.assertEqual(hooks, []) + + def test_parse_input_fn_result_dataset(self): + def _input_fn(): + features = np.expand_dims(np.arange(100), 0) + labels = np.expand_dims(np.arange(100, 200), 0) + return dataset_ops.Dataset.from_tensor_slices((features, labels)) + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with training.MonitoredSession(hooks=hooks) as sess: + vals = sess.run([features, labels]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertAllEqual(vals[1], np.arange(100, 200)) + self.assertIsInstance(hooks[0], util._DatasetInitializerHook) + + def test_parse_input_fn_result_features_only(self): + def _input_fn(): + return constant_op.constant(np.arange(100)) + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with self.test_session() as sess: + vals = sess.run([features]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertEqual(labels, None) + self.assertEqual(hooks, []) + + def test_parse_input_fn_result_features_only_dataset(self): + def _input_fn(): + features = np.expand_dims(np.arange(100), 0) + return dataset_ops.Dataset.from_tensor_slices(features) + + features, labels, hooks = util.parse_input_fn_result(_input_fn()) + + with training.MonitoredSession(hooks=hooks) as sess: + vals = sess.run([features]) + + self.assertAllEqual(vals[0], np.arange(100)) + self.assertEqual(labels, None) + self.assertIsInstance(hooks[0], util._DatasetInitializerHook) + + def test_parse_input_fn_result_invalid(self): + def _input_fn(): + features = np.expand_dims(np.arange(100), 0) + labels = np.expand_dims(np.arange(100, 200), 0) + return dataset_ops.Dataset.from_tensor_slices((features, labels, labels)) + + with self.assertRaisesRegexp(ValueError, 'input_fn should return'): + util.parse_input_fn_result(_input_fn()) + + +if __name__ == '__main__': + test.main() |