author	Katherine Wu <kathywu@google.com>	2018-07-20 13:59:59 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-07-20 14:04:59 -0700
commit	8257891f378027a1a7c0403ba6ba0aeb313496a0 (patch)
tree	dea706c76084ce75011938bbedec95738e524e46 /tensorflow/contrib/estimator
parent	e542062aa1613dc01b82b6378675563160fe0abf (diff)
Add estimator in contrib that loads its model function from a SavedModel.
PiperOrigin-RevId: 205449314
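
In brief, the new `SavedModelEstimator` is used as follows (a minimal sketch distilled from the class docstring below; `export_dir` is assumed to point at a SavedModel exported with `tf.contrib.estimator.export_all_saved_models`, and `input_fn` must match the exported input signatures):
```python
est = tf.contrib.estimator.SavedModelEstimator(export_dir)
est.train(input_fn=input_fn, steps=20)                    # resume training
eval_results = est.evaluate(input_fn=input_fn, steps=1)   # evaluate
predictions = est.predict(input_fn=input_fn)              # predict
```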
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--  tensorflow/contrib/estimator/BUILD                                           |  41
-rw-r--r--  tensorflow/contrib/estimator/__init__.py                                     |   5
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py      | 445
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py | 369
4 files changed, 860 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 1aa3df8d8d..349f48f7f7 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -28,6 +28,7 @@ py_library(
":multi_head",
":replicate_model_fn",
":rnn",
+ ":saved_model_estimator",
"//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -465,3 +466,43 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
+
+py_library(
+ name = "saved_model_estimator",
+ srcs = ["python/estimator/saved_model_estimator.py"],
+ deps = [
+ ":export",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:export",
+ "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/saved_model",
+ ],
+)
+
+py_test(
+ name = "saved_model_estimator_test",
+ size = "medium",
+ srcs = ["python/estimator/saved_model_estimator_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":export",
+ ":saved_model_estimator",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:export_export",
+ "//tensorflow/python/estimator:export_output",
+ "//tensorflow/python/estimator:model_fn",
+ ],
+)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 09fcfd66a1..e1453ae1d0 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -33,6 +33,8 @@ from tensorflow.contrib.estimator.python.estimator.logit_fns import *
from tensorflow.contrib.estimator.python.estimator.multi_head import *
from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import *
from tensorflow.contrib.estimator.python.estimator.rnn import *
+from tensorflow.contrib.estimator.python.estimator.saved_model_estimator import *
+from tensorflow.python.estimator.export.export import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
@@ -70,6 +72,9 @@ _allowed_symbols = [
'stop_if_higher_hook',
'stop_if_no_increase_hook',
'stop_if_no_decrease_hook',
+ 'build_raw_supervised_input_receiver_fn',
+ 'build_supervised_input_receiver_fn_from_input_fn',
+ 'SavedModelEstimator',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py
new file mode 100644
index 0000000000..22188fe663
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py
@@ -0,0 +1,445 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class that creates an Estimator from a SavedModel."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.python.estimator import estimator as estimator_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export as export_lib
+from tensorflow.python.estimator.export import export_output
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import monitored_session
+from tensorflow.python.training import training_util
+
+
+class SavedModelEstimator(estimator_lib.Estimator):
+ """Create an Estimator from a SavedModel.
+
+ Only SavedModels exported with
+ `tf.contrib.estimator.export_all_saved_models()` or
+ `tf.estimator.Estimator.export_savedmodel()` are supported for this class.
+
+ Example with `tf.estimator.DNNClassifier`:
+
+ **Step 1: Create and train DNNClassifier.**
+ ```python
+ feature1 = tf.feature_column.embedding_column(
+ tf.feature_column.categorical_column_with_vocabulary_list(
+ key='feature1', vocabulary_list=('green', 'yellow')), dimension=1)
+ feature2 = tf.feature_column.numeric_column(key='feature2', default_value=0.0)
+
+ classifier = tf.estimator.DNNClassifier(
+ hidden_units=[4, 2], feature_columns=[feature1, feature2])
+
+ def input_fn():
+ features = {'feature1': tf.constant(['green', 'green', 'yellow']),
+ 'feature2': tf.constant([3.5, 4.2, 6.1])}
+ label = tf.constant([1., 0., 0.])
+ return tf.data.Dataset.from_tensors((features, label)).repeat()
+
+ classifier.train(input_fn=input_fn, steps=10)
+ ```
+
+ **Step 2: Export classifier.**
+ First, build functions that specify the expected inputs.
+ ```python
+ # During train and evaluation, both the features and labels should be defined.
+ supervised_input_receiver_fn = (
+ tf.contrib.estimator.build_raw_supervised_input_receiver_fn(
+ {'feature1': tf.placeholder(dtype=tf.string, shape=[None]),
+ 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])},
+ tf.placeholder(dtype=tf.float32, shape=[None])))
+
+ # During predict mode, expect to receive a `tf.Example` proto, so a parsing
+ # function is used.
+ serving_input_receiver_fn = (
+ tf.estimator.export.build_parsing_serving_input_receiver_fn(
+ tf.feature_column.make_parse_example_spec([feature1, feature2])))
+ ```
+
+ Next, export the model as a SavedModel. A timestamped directory will be
+ created (for example `/tmp/export_all/1234567890`).
+ ```python
+ # Option 1: Save all modes (train, eval, predict)
+ export_dir = tf.contrib.estimator.export_all_saved_models(
+ classifier, '/tmp/export_all',
+ {tf.estimator.ModeKeys.TRAIN: supervised_input_receiver_fn,
+ tf.estimator.ModeKeys.EVAL: supervised_input_receiver_fn,
+ tf.estimator.ModeKeys.PREDICT: serving_input_receiver_fn})
+
+ # Option 2: Only export predict mode
+ export_dir = classifier.export_savedmodel(
+ '/tmp/export_predict', serving_input_receiver_fn)
+ ```
+
+ **Step 3: Create a SavedModelEstimator from the exported SavedModel.**
+ ```python
+ est = tf.contrib.estimator.SavedModelEstimator(export_dir)
+
+ # If all modes were exported, you can immediately evaluate and predict, or
+ # continue training. Otherwise only predict is available.
+ eval_results = est.evaluate(input_fn=input_fn, steps=1)
+ print(eval_results)
+
+ est.train(input_fn=input_fn, steps=20)
+
+ def predict_input_fn():
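+ # Note: this example assumes `example_pb2` has been imported, e.g.
+ # `from tensorflow.core.example import example_pb2`.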
+ example = example_pb2.Example()
+ example.features.feature['feature1'].bytes_list.value.extend(['yellow'])
+ example.features.feature['feature2'].float_list.value.extend([1.])
+ return {'inputs': tf.constant([example.SerializeToString()])}
+
+ predictions = est.predict(predict_input_fn)
+ print(next(predictions))
+ ```
+ """
+
+ def __init__(self, saved_model_dir, model_dir=None):
+ """Initialize a SavedModelEstimator.
+
+ The SavedModelEstimator loads its model function and variable values from
+ the graphs defined in the SavedModel. There is no option to pass in
+ `RunConfig` or `params` arguments, because the model function graph is
+ defined statically in the SavedModel.
+
+ Args:
+ saved_model_dir: Directory containing SavedModel protobuf and subfolders.
+ model_dir: Directory to save new checkpoints during training.
+
+ Raises:
+ NotImplementedError: If a DistributionStrategy is defined in the config.
+ Unless the SavedModelEstimator is subclassed, this shouldn't happen.
+ """
+ checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir) # pylint: disable=protected-access
+ vars_to_warm_start = [name for name, _ in
+ checkpoint_utils.list_variables(checkpoint)]
+ warm_start_settings = estimator_lib.WarmStartSettings(
+ ckpt_to_initialize_from=checkpoint,
+ vars_to_warm_start=vars_to_warm_start)
+
+ super(SavedModelEstimator, self).__init__(
+ model_fn=self._model_fn_from_saved_model, model_dir=model_dir,
+ warm_start_from=warm_start_settings)
+ if self._distribution is not None:
+ raise NotImplementedError(
+ 'SavedModelEstimator currently does not support '
+ 'DistributionStrategy.')
+ self.saved_model_dir = saved_model_dir
+ self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir)
+ self._available_modes = self._extract_available_modes()
+
+ def _extract_available_modes(self):
+ """Return list of modes found in SavedModel."""
+ available_modes = []
+ logging.info('Checking available modes for SavedModelEstimator.')
+ for mode in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
+ model_fn_lib.ModeKeys.PREDICT]:
+ try:
+ self._get_meta_graph_def_for_mode(mode)
+ except RuntimeError:
+ logging.warning('%s mode not found in SavedModel.' % mode)
+ continue
+
+ if self._get_signature_def_for_mode(mode) is not None:
+ available_modes.append(mode)
+
+ logging.info('Available modes for Estimator: %s' % available_modes)
+ return available_modes
+
+ def _validate_mode(self, mode):
+ """Make sure that mode can be run using the SavedModel."""
+ if mode not in self._available_modes:
+ raise RuntimeError('%s mode is not available in the SavedModel. Use '
+ 'saved_model_cli to check that the Metagraph for this '
+ 'mode has been exported.' % mode)
+
+ def _get_meta_graph_def_for_mode(self, mode):
+ tags = model_fn_lib.EXPORT_TAG_MAP[mode]
+ return self.saved_model_loader.get_meta_graph_def_from_tags(tags)
+
+ def _get_signature_def_for_mode(self, mode):
+ meta_graph_def = self._get_meta_graph_def_for_mode(mode)
+ sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ if mode == model_fn_lib.ModeKeys.PREDICT else mode)
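+ # Train and eval signatures are keyed by the mode string itself; predict
+ # signatures use the default serving key ('serving_default').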
+ if sig_def_key not in meta_graph_def.signature_def:
+ logging.warning('Metagraph for mode %s was found, but SignatureDef with'
+ ' key \"%s\" is missing.' % (mode, sig_def_key))
+ return None
+ return meta_graph_def.signature_def[sig_def_key]
+
+ def _create_and_assert_global_step(self, graph):
+ # Do nothing here. The global step variable will be created/loaded from the
+ # SavedModel. If a global step variable were created here, the result
+ # would be two duplicate global step variables, causing issues during
+ # the warm-start phase.
+ # Because the global step variable is created in the model function, this
+ # may cause issues when running with DistributionStrategy. Thus,
+ # DistributionStrategy is not yet supported with SavedModelEstimator.
+ pass
+
+ def _model_fn_from_saved_model(self, features, labels, mode):
+ """Load a SavedModel graph and return an EstimatorSpec."""
+ # TODO(kathywu): Model function loads placeholders from the graph. Calling
+ # export_all_saved_models creates another placeholder for the inputs, on top
+ # of the original placeholders. There should be a way to avoid this.
+ self._validate_mode(mode)
+
+ g = ops.get_default_graph()
+ if training_util.get_global_step(g) is not None:
+ raise RuntimeError(
+ 'Graph must not contain a global step tensor before the SavedModel is'
+ ' loaded. Please make sure that the input function does not create a '
+ 'global step.')
+
+ # Extract SignatureDef for information about the input and output tensors.
+ signature_def = self._get_signature_def_for_mode(mode)
+
+ # Generate input map for replacing the inputs in the SavedModel graph with
+ # the provided features and labels.
+ input_map = _generate_input_map(signature_def, features, labels)
+
+ # Create a list of the names of output tensors. When the graph is loaded,
+ # names of the output tensors may be remapped. This ensures that the correct
+ # tensors are returned in the EstimatorSpec.
+ output_tensor_names = [
+ value.name for value in six.itervalues(signature_def.outputs)]
+
+ # Load the graph. `output_tensors` contains output `Tensors` in the same
+ # order as the `output_tensor_names` list.
+ tags = model_fn_lib.EXPORT_TAG_MAP[mode]
+ _, output_tensors = self.saved_model_loader.load_graph(
+ g, tags, input_map=input_map, return_elements=output_tensor_names)
+
+ # Create a scaffold from the MetaGraphDef that contains ops to initialize
+ # the graph. This should mirror the steps from _add_meta_graph_for_mode(),
+ # which creates a MetaGraphDef from the EstimatorSpec's scaffold.
+ scaffold = monitored_session.Scaffold(
+ local_init_op=loader_impl._get_legacy_init_op_tensor( # pylint: disable=protected-access
+ self._get_meta_graph_def_for_mode(mode)))
+
+ # Ensure that a global step tensor has been created.
+ global_step_tensor = training_util.get_global_step(g)
+ training_util.assert_global_step(global_step_tensor)
+
+ # Extract values to return in the EstimatorSpec.
+ output_map = dict(zip(output_tensor_names, output_tensors))
+ outputs = {key: output_map[value.name]
+ for key, value in six.iteritems(signature_def.outputs)}
+
+ loss, predictions, metrics = _validate_and_extract_outputs(
+ mode, outputs, signature_def.method_name)
+
+ train_op = ops.get_collection(constants.TRAIN_OP_KEY)
+ if len(train_op) > 1:
+ raise RuntimeError('Multiple ops found in the train_op collection.')
+ train_op = None if not train_op else train_op[0]
+
+ _clear_saved_model_collections()
+ return model_fn_lib.EstimatorSpec(
+ scaffold=scaffold,
+ mode=mode,
+ loss=loss,
+ train_op=train_op,
+ predictions=predictions,
+ eval_metric_ops=metrics)
+
+
+def _clear_saved_model_collections():
+ """Clear collections that are expected empty when exporting a SavedModel.
+
+ The SavedModel builder uses these collections to track ops necessary to
+ restore the graph state. These collections are expected to be empty before
+ MetaGraphs are added to the builder.
+ """
+ del ops.get_collection_ref(constants.ASSETS_KEY)[:]
+ del ops.get_collection_ref(constants.LEGACY_INIT_OP_KEY)[:]
+ del ops.get_collection_ref(constants.MAIN_OP_KEY)[:]
+ del ops.get_collection_ref(constants.TRAIN_OP_KEY)[:]
+
+
+def _generate_input_map(signature_def, features, labels):
+ """Return dict mapping an input tensor name to a feature or label tensor.
+
+ Args:
+ signature_def: SignatureDef loaded from SavedModel
+ features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
+ `SparseTensor`, specifying the features to be passed to the model.
+ labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
+ `SparseTensor`, specifying the labels to be passed to the model. May be
+ `None`.
+
+ Returns:
+ A dict mapping string names of inputs to feature or label tensors.
+
+ Raises:
+ ValueError: if SignatureDef inputs are not completely mapped by the input
+ features and labels.
+ """
+ # pylint: disable=protected-access
+ if not isinstance(features, dict):
+ features = {export_lib._SINGLE_FEATURE_DEFAULT_NAME: features}
+ if labels is not None and not isinstance(labels, dict):
+ labels = {export_lib._SINGLE_LABEL_DEFAULT_NAME: labels}
+ # pylint: enable=protected-access
+
+ inputs = signature_def.inputs
+ input_map = {}
+ for key, tensor_info in six.iteritems(inputs):
+ input_name = tensor_info.name
+ if ':' in input_name:
+ input_name = input_name[:input_name.find(':')]
+
+ # When tensors are used as control inputs for operations, their names are
+ # prepended with a '^' character in the GraphDef. To handle possible control
+ # flow edge cases, control input names must be included in the input map.
+ control_dependency_name = '^' + input_name
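+ # For example, an input tensor named 'feature_x:0' is mapped under both
+ # 'feature_x' and '^feature_x'.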
+
+ if key in features:
+ _check_same_dtype_and_shape(features[key], tensor_info, key)
+ input_map[input_name] = input_map[control_dependency_name] = features[key]
+ elif labels is not None and key in labels:
+ _check_same_dtype_and_shape(labels[key], tensor_info, key)
+ input_map[input_name] = input_map[control_dependency_name] = labels[key]
+ else:
+ raise ValueError(
+ 'Key \"%s\" not found in features or labels passed in to the model '
+ 'function. All required keys: %s' % (key, inputs.keys()))
+
+ return input_map
+
+
+def _check_same_dtype_and_shape(tensor, tensor_info, name):
+ """Validate that tensor has the same properties as the TensorInfo proto.
+
+ Args:
+ tensor: a `Tensor` object.
+ tensor_info: a `TensorInfo` proto.
+ name: Name of the input (to identify Tensor if an error is raised).
+
+ Raises:
+ ValueError: If the tensor shape or dtype doesn't match the TensorInfo.
+ """
+ dtype_error = (tensor.dtype != dtypes.DType(tensor_info.dtype))
+ shape_error = not tensor.shape.is_compatible_with(tensor_info.tensor_shape)
+
+ if dtype_error or shape_error:
+ msg = 'Tensor shape and/or dtype validation failed for input %s:' % name
+ if dtype_error:
+ msg += ('\n\tExpected dtype: %s, Got: %s'
+ % (dtypes.DType(tensor_info.dtype), tensor.dtype))
+ if shape_error:
+ msg += ('\n\tExpected shape: %s, Got: %s'
+ % (tensor_shape.TensorShape(tensor_info.tensor_shape),
+ tensor.shape))
+
+ raise ValueError(msg)
+
+
+def _extract_eval_metrics(output_dict):
+ """Return a eval metric dict extracted from the output_dict.
+
+ Eval metrics consist of a value tensor and an update op. Both must be in the
+ passed-in tensor dictionary for an eval metric to be added to the returned
+ dictionary.
+
+ Args:
+ output_dict: a dict that maps strings to tensors.
+
+ Returns:
+ dict mapping strings to (value, update_op) tuples.
+ """
+ # pylint: disable=protected-access
+ metric_ops = {}
+ separator_char = export_output._SupervisedOutput._SEPARATOR_CHAR
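+ # Keys produced by `_SupervisedOutput` take the form
+ # 'metrics/<metric name>/value' and 'metrics/<metric name>/update_op',
+ # where '/' is the separator character.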
+
+ for key, tensor in six.iteritems(output_dict):
+ split_key = key.split(separator_char)
+
+ # The metric name may itself contain the separator character, so rejoin all
+ # but the last component to recover the name.
+ metric_name = separator_char.join(split_key[:-1])
+
+ if split_key[0] == export_output._SupervisedOutput.METRICS_NAME:
+ # If the key ends with the value suffix, and there is a corresponding
+ # key ending with the update_op suffix, then add tensors to metrics dict.
+ if split_key[-1] == export_output._SupervisedOutput.METRIC_VALUE_SUFFIX:
+ update_op = ''.join(
+ [metric_name, separator_char,
+ export_output._SupervisedOutput.METRIC_UPDATE_SUFFIX])
+ if update_op in output_dict:
+ update_op_tensor = output_dict[update_op]
+ metric_ops[metric_name] = (tensor, update_op_tensor)
+
+ # pylint: enable=protected-access
+ return metric_ops
+
+
+def _validate_and_extract_outputs(mode, output_dict, method_name):
+ """Extract values from SignatureDef output dictionary.
+
+ Args:
+ mode: One of the modes enumerated in `tf.estimator.ModeKeys`.
+ output_dict: dict of string SignatureDef keys to `Tensor`.
+ method_name: Method name of the SignatureDef as a string.
+
+ Returns:
+ Tuple of (
+ loss: `Tensor` object,
+ predictions: dictionary mapping string keys to `Tensor` objects,
+ metrics: dictionary mapping string keys to a tuple of two `Tensor` objects
+ )
+
+ Raises:
+ RuntimeError: raised if SignatureDef has an invalid method name for the mode
+ """
+ # pylint: disable=protected-access
+ loss, predictions, metrics = None, None, None
+
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ predictions = output_dict
+ else:
+ # Validate that the SignatureDef's method name matches the expected name for
+ # the given mode.
+ expected_method_name = signature_constants.SUPERVISED_TRAIN_METHOD_NAME
+ if mode == model_fn_lib.ModeKeys.EVAL:
+ expected_method_name = signature_constants.SUPERVISED_EVAL_METHOD_NAME
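+ # (These constants are 'tensorflow/supervised/training' and
+ # 'tensorflow/supervised/eval', respectively.)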
+ if method_name != expected_method_name:
+ raise RuntimeError(
+ 'Invalid SignatureDef method name for mode %s.\n\tExpected: %s\n\t'
+ 'Got: %s\nPlease ensure that the SavedModel was exported with '
+ '`tf.contrib.estimator.export_all_saved_models()`.' %
+ (mode, expected_method_name, method_name))
+
+ # Extract loss, metrics and predictions from the output dict.
+ loss = output_dict[export_output._SupervisedOutput.LOSS_NAME]
+ metrics = _extract_eval_metrics(output_dict)
+ predictions = {
+ key: value for key, value in six.iteritems(output_dict)
+ if key.split(export_output._SupervisedOutput._SEPARATOR_CHAR)[0] == (
+ export_output._SupervisedOutput.PREDICTIONS_NAME)}
+
+ # pylint: enable=protected-access
+ return loss, predictions, metrics
diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py
new file mode 100644
index 0000000000..718da1367c
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py
@@ -0,0 +1,369 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SavedModelEstimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+
+from tensorflow.contrib.estimator.python.estimator import export as contrib_export
+from tensorflow.contrib.estimator.python.estimator import saved_model_estimator
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.export import export_output
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session
+from tensorflow.python.training import training
+
+
+def dummy_input_fn():
+ return dataset_ops.Dataset.from_tensors((
+ {'x': constant_op.constant([[1], [-2]], dtype=dtypes.int64)},
+ constant_op.constant([[4], [-3]], dtype=dtypes.float32))).repeat()
+
+
+def dummy_input_fn_features_only():
+ return dataset_ops.Dataset.from_tensors(
+ {'x': constant_op.constant([[5], [6]], dtype=dtypes.int64)}).repeat()
+
+
+def dummy_supervised_receiver_fn():
+ feature_spec = {
+ 'x': array_ops.placeholder(
+ dtype=dtypes.int64, shape=(2, 1), name='feature_x'),
+ }
+ label_spec = array_ops.placeholder(
+ dtype=dtypes.float32, shape=[2, 1], name='truth')
+ return export.build_raw_supervised_input_receiver_fn(
+ feature_spec, label_spec)
+
+
+def dummy_serving_receiver_fn():
+ feature_spec = {'x': array_ops.placeholder(
+ dtype=dtypes.int64, shape=(2, 1), name='feature_x'),}
+ return export.build_raw_serving_input_receiver_fn(feature_spec)
+
+
+def model_fn_diff_modes(features, labels, mode):
+ _, _ = features, labels
+ v = variables.Variable(21, name='some_var')
+ train_op = None
+ loss = constant_op.constant(104)
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ loss = constant_op.constant(105)
+ predictions = constant_op.constant([501])
+ train_op = control_flow_ops.group(
+ state_ops.assign_add(training.get_global_step(), 1),
+ state_ops.assign_add(v, 3))
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ loss = constant_op.constant(106)
+ predictions = constant_op.constant([502])
+ else:
+ loss = constant_op.constant(107)
+ predictions = constant_op.constant([503])
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=loss,
+ train_op=train_op,
+ eval_metric_ops={
+ 'abs_err': metrics_lib.mean_absolute_error(
+ constant_op.constant(0), predictions)},
+ predictions=predictions)
+
+
+class SavedModelEstimatorTest(test.TestCase):
+
+ def setUp(self):
+ self.tmpdirs = []
+
+ def tearDown(self):
+ for tmpdir in self.tmpdirs:
+ # gfile.DeleteRecursively fails in the windows cmake test, so use shutil.
+ shutil.rmtree(tmpdir, ignore_errors=True)
+ self.tmpdirs = []
+
+ def _get_tmp_dir(self):
+ tmpdir = tempfile.mkdtemp()
+ self.tmpdirs.append(tmpdir)
+ return tmpdir
+
+ def _export_estimator(self, train=True, evaluate=True, predict=True,
+ model_fn=model_fn_diff_modes):
+ est = estimator.Estimator(model_fn, self._get_tmp_dir())
+ est.train(input_fn=dummy_input_fn, steps=10)
+
+ input_receiver_fn_map = {}
+ if train:
+ input_receiver_fn_map[model_fn_lib.ModeKeys.TRAIN] = (
+ dummy_supervised_receiver_fn())
+ if evaluate:
+ input_receiver_fn_map[model_fn_lib.ModeKeys.EVAL] = (
+ dummy_supervised_receiver_fn())
+ if predict:
+ input_receiver_fn_map[model_fn_lib.ModeKeys.PREDICT] = (
+ dummy_serving_receiver_fn())
+
+ export_base_path = self._get_tmp_dir()
+ export_dir = contrib_export.export_all_saved_models(
+ est, export_base_path, input_receiver_fn_map)
+ return export_dir
+
+ def test_load_all_modes(self):
+ sme = saved_model_estimator.SavedModelEstimator(
+ self._export_estimator(), self._get_tmp_dir())
+ sme.train(input_fn=dummy_input_fn, steps=1)
+ sme.train(input_fn=dummy_input_fn, steps=2)
+ self.assertEqual(13, sme.get_variable_value('global_step'))
+ self.assertEqual(60, sme.get_variable_value('some_var'))
+
+ eval_results = sme.evaluate(dummy_input_fn, steps=5)
+
+ self.assertEqual(13, eval_results['global_step'])
+ self.assertEqual(106, eval_results['loss'])
+ self.assertEqual(502, eval_results['metrics/abs_err'])
+
+ predictions = next(sme.predict(dummy_input_fn_features_only))
+ self.assertDictEqual({'output': 503}, predictions)
+
+ def test_load_all_modes_no_train(self):
+ """Ensure that all functions can be used without requiring a ckpt."""
+ sme = saved_model_estimator.SavedModelEstimator(
+ self._export_estimator(), self._get_tmp_dir())
+ eval_results = sme.evaluate(dummy_input_fn, steps=5)
+ self.assertEqual(10, eval_results['global_step'])
+ self.assertEqual(106, eval_results['loss'])
+ self.assertEqual(502, eval_results['metrics/abs_err'])
+
+ predictions = next(sme.predict(dummy_input_fn_features_only))
+ self.assertDictEqual({'output': 503}, predictions)
+
+ def test_partial_exported_estimator(self):
+ sme1 = saved_model_estimator.SavedModelEstimator(
+ self._export_estimator(train=False, predict=False), self._get_tmp_dir())
+ sme1.evaluate(dummy_input_fn, steps=5)
+ with self.assertRaisesRegexp(RuntimeError, 'train mode is not available'):
+ sme1.train(input_fn=dummy_input_fn, steps=1)
+ with self.assertRaisesRegexp(RuntimeError, 'infer mode is not available'):
+ next(sme1.predict(dummy_input_fn_features_only))
+
+ sme2 = saved_model_estimator.SavedModelEstimator(
+ self._export_estimator(evaluate=False), self._get_tmp_dir())
+ sme2.train(input_fn=dummy_input_fn, steps=1)
+ next(sme2.predict(dummy_input_fn_features_only))
+ with self.assertRaisesRegexp(RuntimeError, 'eval mode is not available'):
+ sme2.evaluate(dummy_input_fn, steps=5)
+
+ def test_with_incorrect_input(self):
+ sme = saved_model_estimator.SavedModelEstimator(
+ self._export_estimator(), self._get_tmp_dir())
+
+ def bad_shape_input_fn():
+ return dataset_ops.Dataset.from_tensors((
+ {'x': constant_op.constant([1, 2], dtype=dtypes.int64)},
+ constant_op.constant([1, 2], dtype=dtypes.float32)))
+
+ with self.assertRaisesRegexp(ValueError, 'Expected shape'):
+ sme.train(bad_shape_input_fn, steps=1)
+
+ def bad_dtype_input_fn():
+ return dataset_ops.Dataset.from_tensors((
+ {'x': constant_op.constant([[1], [1]], dtype=dtypes.int32)},
+ constant_op.constant([[1], [1]], dtype=dtypes.int64)))
+
+ with self.assertRaisesRegexp(ValueError, 'Expected dtype'):
+ sme.train(bad_dtype_input_fn, steps=1)
+
+ def test_input_fn_with_global_step(self):
+ sme = saved_model_estimator.SavedModelEstimator(
+ self._export_estimator(), self._get_tmp_dir())
+
+ def bad_input_fn():
+ training.get_or_create_global_step()
+ return dataset_ops.Dataset.from_tensors((
+ {'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)},
+ constant_op.constant([[1], [1]], dtype=dtypes.float32)))
+
+ with self.assertRaisesRegexp(RuntimeError,
+ 'Graph must not contain a global step tensor'):
+ sme.train(bad_input_fn, steps=1)
+
+ def test_re_export_saved_model_serving_only(self):
+ sme = saved_model_estimator.SavedModelEstimator(
+ self._export_estimator(), self._get_tmp_dir())
+ sme.train(dummy_input_fn, steps=3)
+ self.assertEqual(13, sme.get_variable_value('global_step'))
+ self.assertEqual(60, sme.get_variable_value('some_var'))
+
+ predictions = next(sme.predict(dummy_input_fn_features_only))
+ self.assertDictEqual({'output': 503}, predictions)
+
+ # Export SavedModel, and test that the variable and prediction values are
+ # the same.
+ sme_export_dir = sme.export_savedmodel(
+ self._get_tmp_dir(), dummy_serving_receiver_fn())
+
+ sme2 = saved_model_estimator.SavedModelEstimator(
+ sme_export_dir, self._get_tmp_dir())
+ self.assertEqual(60, sme.get_variable_value('some_var'))
+ self.assertEqual(13, sme.get_variable_value('global_step'))
+
+ predictions = next(sme2.predict(dummy_input_fn_features_only))
+ self.assertDictEqual({'output': 503}, predictions)
+
+ def test_re_export_saved_model(self):
+ sme = saved_model_estimator.SavedModelEstimator(
+ self._export_estimator(), self._get_tmp_dir())
+ self.assertDictEqual(
+ {'loss': 106, 'metrics/abs_err': 502, 'global_step': 10},
+ sme.evaluate(dummy_input_fn, steps=1))
+
+ sme.train(dummy_input_fn, steps=3)
+ self.assertDictEqual(
+ {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13},
+ sme.evaluate(dummy_input_fn, steps=1))
+ self.assertEqual(60, sme.get_variable_value('some_var'))
+
+ predictions = next(sme.predict(dummy_input_fn_features_only))
+ self.assertDictEqual({'output': 503}, predictions)
+
+ # Export SavedModel for all modes
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.TRAIN: dummy_supervised_receiver_fn(),
+ model_fn_lib.ModeKeys.EVAL: dummy_supervised_receiver_fn(),
+ model_fn_lib.ModeKeys.PREDICT: dummy_serving_receiver_fn()}
+ sme_export_dir = contrib_export.export_all_saved_models(
+ sme, self._get_tmp_dir(), input_receiver_fn_map)
+
+ sme2 = saved_model_estimator.SavedModelEstimator(
+ sme_export_dir, self._get_tmp_dir())
+ self.assertDictEqual(
+ {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13},
+ sme.evaluate(dummy_input_fn, steps=1))
+ self.assertEqual(60, sme.get_variable_value('some_var'))
+
+ sme.train(dummy_input_fn, steps=7)
+ self.assertEqual(20, sme.get_variable_value('global_step'))
+
+ predictions = next(sme2.predict(dummy_input_fn_features_only))
+ self.assertDictEqual({'output': 503}, predictions)
+
+ def test_load_saved_model_from_serving_only(self):
+ def model_fn(features, labels, mode):
+ _, _ = features, labels
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=constant_op.constant([103]),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ predictions=constant_op.constant([502]),
+ export_outputs={'test': export_output.ClassificationOutput(
+ constant_op.constant([[32.]]))})
+
+ est = estimator.Estimator(model_fn, self._get_tmp_dir())
+ est.train(input_fn=dummy_input_fn, steps=10)
+
+ def serving_input_receiver_fn():
+ return export.ServingInputReceiver(
+ {'test-features': constant_op.constant([[1], [1]])},
+ array_ops.placeholder(dtype=dtypes.string))
+
+ export_dir = est.export_savedmodel(
+ self._get_tmp_dir(), serving_input_receiver_fn)
+
+ sme = saved_model_estimator.SavedModelEstimator(
+ export_dir, self._get_tmp_dir())
+
+ def input_fn():
+ return {'inputs': constant_op.constant('someinputstr')}
+
+ prediction = next(sme.predict(input_fn))
+ self.assertDictEqual({'scores': 32}, prediction)
+
+ def test_with_local_init_op(self):
+ def model_fn(features, labels, mode):
+ _, _ = features, labels
+ v = variables.Variable(21, name='some_var')
+ scaffold = monitored_session.Scaffold(
+ local_init_op=state_ops.assign_add(v, -3).op
+ )
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ scaffold=scaffold,
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ loss=array_ops.identity(v))
+ export_dir = self._export_estimator(predict=False, model_fn=model_fn)
+ sme = saved_model_estimator.SavedModelEstimator(
+ export_dir, self._get_tmp_dir())
+
+ eval_results1 = sme.evaluate(dummy_input_fn, steps=2)
+ self.assertEqual(15, eval_results1['loss'])
+
+ sme.train(dummy_input_fn, steps=1)
+ self.assertEqual(15, sme.get_variable_value('some_var'))
+
+ eval_results2 = sme.evaluate(dummy_input_fn, steps=5)
+ self.assertEqual(12, eval_results2['loss'])
+
+ def test_with_working_input_fn(self):
+ def model_fn(features, labels, mode):
+ loss = None
+ if labels is not None:
+ loss = labels[0][0] + labels[1][0]
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=loss,
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ predictions={'features_0': array_ops.identity([features['x'][0][0]]),
+ 'features_1': array_ops.identity([features['x'][1][0]])})
+
+ sme = saved_model_estimator.SavedModelEstimator(
+ self._export_estimator(model_fn=model_fn), self._get_tmp_dir())
+ eval_results = sme.evaluate(dummy_input_fn, steps=1)
+ self.assertEqual(1, eval_results['loss'])
+
+ predictions = next(sme.predict(dummy_input_fn_features_only))
+ self.assertDictEqual({'features_0': 5, 'features_1': 6}, predictions)
+
+ def test_control_dependency(self):
+ # Control dependencies are saved with "^" prepended to the input name. The
+ # input map must include control dependencies as well.
+ def model_fn(features, labels, mode):
+ _ = labels
+ with ops.control_dependencies([features['x']]):
+ loss = features['x'][1][0]
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=loss,
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+ sme = saved_model_estimator.SavedModelEstimator(
+ self._export_estimator(train=False, predict=False, model_fn=model_fn),
+ self._get_tmp_dir())
+ sme.evaluate(dummy_input_fn, steps=1) # Should run without error
+
+
+if __name__ == '__main__':
+ test.main()