diff options
author | Katherine Wu <kathywu@google.com> | 2018-07-20 15:45:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-20 15:51:41 -0700 |
commit | 6c528feaf820bdde820833ad24e05167adb5daa7 (patch) | |
tree | cdffac07b9e343e03958b734ac9553102bbd4ccf /tensorflow/contrib/estimator | |
parent | 5e876a8c25819070d78aa96595943afa207a6671 (diff) |
Automated rollback of commit 8257891f378027a1a7c0403ba6ba0aeb313496a0
PiperOrigin-RevId: 205466000
Diffstat (limited to 'tensorflow/contrib/estimator')
4 files changed, 0 insertions, 860 deletions
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 349f48f7f7..1aa3df8d8d 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -28,7 +28,6 @@ py_library( ":multi_head", ":replicate_model_fn", ":rnn", - ":saved_model_estimator", "//tensorflow:tensorflow_py_no_contrib", ], ) @@ -466,43 +465,3 @@ py_test( "@absl_py//absl/testing:parameterized", ], ) - -py_library( - name = "saved_model_estimator", - srcs = ["python/estimator/saved_model_estimator.py"], - deps = [ - ":export", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:training", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:export", - "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/saved_model", - ], -) - -py_test( - name = "saved_model_estimator_test", - size = "medium", - srcs = ["python/estimator/saved_model_estimator_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":export", - ":saved_model_estimator", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:export_export", - "//tensorflow/python/estimator:export_output", - "//tensorflow/python/estimator:model_fn", - ], -) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index e1453ae1d0..09fcfd66a1 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -33,8 +33,6 @@ from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * from tensorflow.contrib.estimator.python.estimator.rnn import * -from tensorflow.contrib.estimator.python.estimator.saved_model_estimator import * -from tensorflow.python.estimator.export.export import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import @@ -72,9 +70,6 @@ _allowed_symbols = [ 'stop_if_higher_hook', 'stop_if_no_increase_hook', 'stop_if_no_decrease_hook', - 'build_raw_supervised_input_receiver_fn', - 'build_supervised_input_receiver_fn_from_input_fn', - 'SavedModelEstimator' ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py deleted file mode 100644 index 22188fe663..0000000000 --- a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Class that creates an Estimator from a SavedModel.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import six - -from tensorflow.python.estimator import estimator as estimator_lib -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.export import export as export_lib -from tensorflow.python.estimator.export import export_output -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.saved_model import constants -from tensorflow.python.saved_model import loader_impl -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import monitored_session -from tensorflow.python.training import training_util - - -class SavedModelEstimator(estimator_lib.Estimator): - """Create an Estimator from a SavedModel. - - Only SavedModels exported with - `tf.contrib.estimator.export_all_saved_models()` or - `tf.estimator.Estimator.export_savedmodel()` are supported for this class. - - Example with `tf.estimator.DNNClassifier`: - - **Step 1: Create and train DNNClassifier.** - ```python - feature1 = tf.feature_column.embedding_column( - tf.feature_column.categorical_column_with_vocabulary_list( - key='feature1', vocabulary_list=('green', 'yellow')), dimension=1) - feature2 = tf.feature_column.numeric_column(key='feature2', default_value=0.0) - - classifier = tf.estimator.DNNClassifier( - hidden_units=[4,2], feature_columns=[feature1, feature2]) - - def input_fn(): - features = {'feature1': tf.constant(['green', 'green', 'yellow']), - 'feature2': tf.constant([3.5, 4.2, 6.1])} - label = tf.constant([1., 0., 0.]) - return tf.data.Dataset.from_tensors((features, label)).repeat() - - classifier.train(input_fn=input_fn, steps=10) - ``` - - **Step 2: Export classifier.** - First, build functions that specify the expected inputs. - ```python - # During train and evaluation, both the features and labels should be defined. - supervised_input_receiver_fn = ( - tf.contrib.estimator.build_raw_supervised_input_receiver_fn( - {'feature1': tf.placeholder(dtype=tf.string, shape=[None]), - 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])}, - tf.placeholder(dtype=tf.float32, shape=[None]))) - - # During predict mode, expect to receive a `tf.Example` proto, so a parsing - # function is used. - serving_input_receiver_fn = ( - tf.estimator.export.build_parsing_serving_input_receiver_fn( - tf.feature_column.make_parse_example_spec([feature1, feature2]))) - ``` - - Next, export the model as a SavedModel. A timestamped directory will be - created (for example `/tmp/export_all/1234567890`). - ```python - # Option 1: Save all modes (train, eval, predict) - export_dir = tf.contrib.estimator.export_all_saved_models( - classifier, '/tmp/export_all', - {tf.estimator.ModeKeys.TRAIN: supervised_input_receiver_fn, - tf.estimator.ModeKeys.EVAL: supervised_input_receiver_fn, - tf.estimator.ModeKeys.PREDICT: serving_input_receiver_fn}) - - # Option 2: Only export predict mode - export_dir = classifier.export_savedmodel( - '/tmp/export_predict', serving_input_receiver_fn) - ``` - - **Step 3: Create a SavedModelEstimator from the exported SavedModel.** - ```python - est = tf.contrib.estimator.SavedModelEstimator(export_dir) - - # If all modes were exported, you can immediately evaluate and predict, or - # continue training. Otherwise only predict is available. - eval_results = est.evaluate(input_fn=input_fn, steps=1) - print(eval_results) - - est.train(input_fn=input_fn, steps=20) - - def predict_input_fn(): - example = example_pb2.Example() - example.features.feature['feature1'].bytes_list.value.extend(['yellow']) - example.features.feature['feature2'].float_list.value.extend([1.]) - return {'inputs':tf.constant([example.SerializeToString()])} - - predictions = est.predict(predict_input_fn) - print(next(predictions)) - ``` - """ - - def __init__(self, saved_model_dir, model_dir=None): - """Initialize a SavedModelEstimator. - - The SavedModelEstimator loads its model function and variable values from - the graphs defined in the SavedModel. There is no option to pass in - `RunConfig` or `params` arguments, because the model function graph is - defined statically in the SavedModel. - - Args: - saved_model_dir: Directory containing SavedModel protobuf and subfolders. - model_dir: Directory to save new checkpoints during training. - - Raises: - NotImplementedError: If a DistributionStrategy is defined in the config. - Unless the SavedModelEstimator is subclassed, this shouldn't happen. - """ - checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir) # pylint: disable=protected-access - vars_to_warm_start = [name for name, _ in - checkpoint_utils.list_variables(checkpoint)] - warm_start_settings = estimator_lib.WarmStartSettings( - ckpt_to_initialize_from=checkpoint, - vars_to_warm_start=vars_to_warm_start) - - super(SavedModelEstimator, self).__init__( - model_fn=self._model_fn_from_saved_model, model_dir=model_dir, - warm_start_from=warm_start_settings) - if self._distribution is not None: - raise NotImplementedError( - 'SavedModelEstimator currently does not support ' - 'DistributionStrategy.') - self.saved_model_dir = saved_model_dir - self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir) - self._available_modes = self._extract_available_modes() - - def _extract_available_modes(self): - """Return list of modes found in SavedModel.""" - available_modes = [] - logging.info('Checking available modes for SavedModelEstimator.') - for mode in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, - model_fn_lib.ModeKeys.PREDICT]: - try: - self._get_meta_graph_def_for_mode(mode) - except RuntimeError: - logging.warning('%s mode not found in SavedModel.' % mode) - continue - - if self._get_signature_def_for_mode(mode) is not None: - available_modes.append(mode) - - logging.info('Available modes for Estimator: %s' % available_modes) - return available_modes - - def _validate_mode(self, mode): - """Make sure that mode can be run using the SavedModel.""" - if mode not in self._available_modes: - raise RuntimeError('%s mode is not available in the SavedModel. Use ' - 'saved_model_cli to check that the Metagraph for this ' - 'mode has been exported.' % mode) - - def _get_meta_graph_def_for_mode(self, mode): - tags = model_fn_lib.EXPORT_TAG_MAP[mode] - return self.saved_model_loader.get_meta_graph_def_from_tags(tags) - - def _get_signature_def_for_mode(self, mode): - meta_graph_def = self._get_meta_graph_def_for_mode(mode) - sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - if mode == model_fn_lib.ModeKeys.PREDICT else mode) - if sig_def_key not in meta_graph_def.signature_def: - logging.warning('Metagraph for mode %s was found, but SignatureDef with' - ' key \"%s\" is missing.' % (mode, sig_def_key)) - return None - return meta_graph_def.signature_def[sig_def_key] - - def _create_and_assert_global_step(self, graph): - # Do nothing here. The global step variable will be created/loaded from the - # SavedModel. If a global step variable were created here, the result - # will be two duplicate global step variables, causing issues during - # the warm-start phase. - # Due to the global variable being created in the model function, this may - # cause issues when running DistributionStrategy. Thus, DistributionStrategy - # is not yet supported with SavedModelEstimator. - pass - - def _model_fn_from_saved_model(self, features, labels, mode): - """Load a SavedModel graph and return an EstimatorSpec.""" - # TODO(kathywu): Model function loads placeholders from the graph. Calling - # export_all_saved_models creates another placeholder for the inputs, on top - # of the original placeholders. There should be a way to avoid this. - self._validate_mode(mode) - - g = ops.get_default_graph() - if training_util.get_global_step(g) is not None: - raise RuntimeError( - 'Graph must not contain a global step tensor before the SavedModel is' - ' loaded. Please make sure that the input function does not create a ' - 'global step.') - - # Extract SignatureDef for information about the input and output tensors. - signature_def = self._get_signature_def_for_mode(mode) - - # Generate input map for replacing the inputs in the SavedModel graph with - # the provided features and labels. - input_map = _generate_input_map(signature_def, features, labels) - - # Create a list of the names of output tensors. When the graph is loaded, - # names of the output tensors may be remapped. This ensures that the correct - # tensors are returned in the EstimatorSpec. - output_tensor_names = [ - value.name for value in six.itervalues(signature_def.outputs)] - - # Load the graph. `output_tensors` contains output `Tensors` in the same - # same order as the `output_tensor_names` list. - tags = model_fn_lib.EXPORT_TAG_MAP[mode] - _, output_tensors = self.saved_model_loader.load_graph( - g, tags, input_map=input_map, return_elements=output_tensor_names) - - # Create a scaffold from the MetaGraphDef that contains ops to initialize - # the graph. This should mirror the steps from _add_meta_graph_for_mode(), - # which creates a MetaGraphDef from the EstimatorSpec's scaffold. - scaffold = monitored_session.Scaffold( - local_init_op=loader_impl._get_legacy_init_op_tensor( # pylint: disable=protected-access - self._get_meta_graph_def_for_mode(mode))) - - # Ensure that a global step tensor has been created. - global_step_tensor = training_util.get_global_step(g) - training_util.assert_global_step(global_step_tensor) - - # Extract values to return in the EstimatorSpec. - output_map = dict(zip(output_tensor_names, output_tensors)) - outputs = {key: output_map[value.name] - for key, value in six.iteritems(signature_def.outputs)} - - loss, predictions, metrics = _validate_and_extract_outputs( - mode, outputs, signature_def.method_name) - - train_op = ops.get_collection(constants.TRAIN_OP_KEY) - if len(train_op) > 1: - raise RuntimeError('Multiple ops found in the train_op collection.') - train_op = None if not train_op else train_op[0] - - _clear_saved_model_collections() - return model_fn_lib.EstimatorSpec( - scaffold=scaffold, - mode=mode, - loss=loss, - train_op=train_op, - predictions=predictions, - eval_metric_ops=metrics) - - -def _clear_saved_model_collections(): - """Clear collections that are expected empty when exporting a SavedModel. - - The SavedModel builder uses these collections to track ops necessary to - restore the graph state. These collections are expected to be empty before - MetaGraphs are added to the builder. - """ - del ops.get_collection_ref(constants.ASSETS_KEY)[:] - del ops.get_collection_ref(constants.LEGACY_INIT_OP_KEY)[:] - del ops.get_collection_ref(constants.MAIN_OP_KEY)[:] - del ops.get_collection_ref(constants.TRAIN_OP_KEY)[:] - - -def _generate_input_map(signature_def, features, labels): - """Return dict mapping an input tensor name to a feature or label tensor. - - Args: - signature_def: SignatureDef loaded from SavedModel - features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or - `SparseTensor`, specifying the features to be passed to the model. - labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or - `SparseTensor`, specifying the labels to be passed to the model. May be - `None`. - - Returns: - dict mapping string names of inputs to features or labels tensors - - Raises: - ValueError: if SignatureDef inputs are not completely mapped by the input - features and labels. - """ - # pylint: disable=protected-access - if not isinstance(features, dict): - features = {export_lib._SINGLE_FEATURE_DEFAULT_NAME: features} - if labels is not None and not isinstance(labels, dict): - labels = {export_lib._SINGLE_LABEL_DEFAULT_NAME: labels} - # pylint: enable=protected-access - - inputs = signature_def.inputs - input_map = {} - for key, tensor_info in six.iteritems(inputs): - input_name = tensor_info.name - if ':' in input_name: - input_name = input_name[:input_name.find(':')] - - # When tensors are used as control inputs for operations, their names are - # prepended with a '^' character in the GraphDef. To handle possible control - # flow edge cases, control input names must be included in the input map. - control_dependency_name = '^' + input_name - - if key in features: - _check_same_dtype_and_shape(features[key], tensor_info, key) - input_map[input_name] = input_map[control_dependency_name] = features[key] - elif labels is not None and key in labels: - _check_same_dtype_and_shape(labels[key], tensor_info, key) - input_map[input_name] = input_map[control_dependency_name] = labels[key] - else: - raise ValueError( - 'Key \"%s\" not found in features or labels passed in to the model ' - 'function. All required keys: %s' % (key, inputs.keys())) - - return input_map - - -def _check_same_dtype_and_shape(tensor, tensor_info, name): - """Validate that tensor has the same properties as the TensorInfo proto. - - Args: - tensor: a `Tensor` object. - tensor_info: a `TensorInfo` proto. - name: Name of the input (to identify Tensor if an error is raised). - - Raises: - ValueError: If the tensor shape or dtype don't match the TensorInfo - """ - dtype_error = (tensor.dtype != dtypes.DType(tensor_info.dtype)) - shape_error = not tensor.shape.is_compatible_with(tensor_info.tensor_shape) - - if dtype_error or shape_error: - msg = 'Tensor shape and/or dtype validation failed for input %s:' % name - if dtype_error: - msg += ('\n\tExpected dtype: %s, Got: %s' - % (dtypes.DType(tensor_info.dtype), tensor.dtype)) - if shape_error: - msg += ('\n\tExpected shape: %s, Got: %s' - % (tensor_shape.TensorShape(tensor_info.tensor_shape), - tensor.shape)) - - raise ValueError(msg) - - -def _extract_eval_metrics(output_dict): - """Return a eval metric dict extracted from the output_dict. - - Eval metrics consist of a value tensor and an update op. Both must be in the - passed-in tensor dictionary for an eval metric to be added to the returned - dictionary. - - Args: - output_dict: a dict that maps strings to tensors. - - Returns: - dict mapping strings to (value, update_op) tuples. - """ - # pylint: disable=protected-access - metric_ops = {} - separator_char = export_output._SupervisedOutput._SEPARATOR_CHAR - - for key, tensor in six.iteritems(output_dict): - split_key = key.split(separator_char) - - # The metric name may contain the separator character, so recreate its name. - metric_name = separator_char.join(split_key[:-1]) - - if split_key[0] == export_output._SupervisedOutput.METRICS_NAME: - # If the key ends with the value suffix, and there is a corresponding - # key ending with the update_op suffix, then add tensors to metrics dict. - if split_key[-1] == export_output._SupervisedOutput.METRIC_VALUE_SUFFIX: - update_op = ''.join( - [metric_name, separator_char, - export_output._SupervisedOutput.METRIC_UPDATE_SUFFIX]) - if update_op in output_dict: - update_op_tensor = output_dict[update_op] - metric_ops[metric_name] = (tensor, update_op_tensor) - - # pylint: enable=protected-access - return metric_ops - - -def _validate_and_extract_outputs(mode, output_dict, method_name): - """Extract values from SignatureDef output dictionary. - - Args: - mode: One of the modes enumerated in `tf.estimator.ModeKeys`. - output_dict: dict of string SignatureDef keys to `Tensor`. - method_name: Method name of the SignatureDef as a string. - - Returns: - Tuple of ( - loss: `Tensor` object, - predictions: dictionary mapping string keys to `Tensor` objects, - metrics: dictionary mapping string keys to a tuple of two `Tensor` objects - ) - - Raises: - RuntimeError: raised if SignatureDef has an invalid method name for the mode - """ - # pylint: disable=protected-access - loss, predictions, metrics = None, None, None - - if mode == model_fn_lib.ModeKeys.PREDICT: - predictions = output_dict - else: - # Validate that the SignatureDef's method name matches the expected name for - # the given mode. - expected_method_name = signature_constants.SUPERVISED_TRAIN_METHOD_NAME - if mode == model_fn_lib.ModeKeys.EVAL: - expected_method_name = signature_constants.SUPERVISED_EVAL_METHOD_NAME - if method_name != expected_method_name: - raise RuntimeError( - 'Invalid SignatureDef method name for mode %s.\n\tExpected: %s\n\t' - 'Got: %s\nPlease ensure that the SavedModel was exported with ' - '`tf.contrib.estimator.export_all_saved_models()`.' % - (mode, expected_method_name, method_name)) - - # Extract loss, metrics and predictions from the output dict. - loss = output_dict[export_output._SupervisedOutput.LOSS_NAME] - metrics = _extract_eval_metrics(output_dict) - predictions = { - key: value for key, value in six.iteritems(output_dict) - if key.split(export_output._SupervisedOutput._SEPARATOR_CHAR)[0] == ( - export_output._SupervisedOutput.PREDICTIONS_NAME)} - - # pylint: enable=protected-access - return loss, predictions, metrics diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py deleted file mode 100644 index 718da1367c..0000000000 --- a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py +++ /dev/null @@ -1,369 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for SavedModelEstimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import shutil -import tempfile - -from tensorflow.contrib.estimator.python.estimator import export as contrib_export -from tensorflow.contrib.estimator.python.estimator import saved_model_estimator -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.estimator import estimator -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.export import export -from tensorflow.python.estimator.export import export_output -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import metrics as metrics_lib -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.training import monitored_session -from tensorflow.python.training import training - - -def dummy_input_fn(): - return dataset_ops.Dataset.from_tensors(( - {'x': constant_op.constant([[1], [-2]], dtype=dtypes.int64)}, - constant_op.constant([[4], [-3]], dtype=dtypes.float32))).repeat() - - -def dummy_input_fn_features_only(): - return dataset_ops.Dataset.from_tensors( - {'x': constant_op.constant([[5], [6]], dtype=dtypes.int64)}).repeat() - - -def dummy_supervised_receiver_fn(): - feature_spec = { - 'x': array_ops.placeholder( - dtype=dtypes.int64, shape=(2, 1), name='feature_x'), - } - label_spec = array_ops.placeholder( - dtype=dtypes.float32, shape=[2, 1], name='truth') - return export.build_raw_supervised_input_receiver_fn( - feature_spec, label_spec) - - -def dummy_serving_receiver_fn(): - feature_spec = {'x': array_ops.placeholder( - dtype=dtypes.int64, shape=(2, 1), name='feature_x'),} - return export.build_raw_serving_input_receiver_fn(feature_spec) - - -def model_fn_diff_modes(features, labels, mode): - _, _ = features, labels - v = variables.Variable(21, name='some_var') - train_op = None - loss = constant_op.constant(104) - if mode == model_fn_lib.ModeKeys.TRAIN: - loss = constant_op.constant(105) - predictions = constant_op.constant([501]) - train_op = control_flow_ops.group( - state_ops.assign_add(training.get_global_step(), 1), - state_ops.assign_add(v, 3)) - elif mode == model_fn_lib.ModeKeys.EVAL: - loss = constant_op.constant(106) - predictions = constant_op.constant([502]) - else: - loss = constant_op.constant(107) - predictions = constant_op.constant([503]) - return model_fn_lib.EstimatorSpec( - mode, - loss=loss, - train_op=train_op, - eval_metric_ops={ - 'abs_err': metrics_lib.mean_absolute_error( - constant_op.constant(0), predictions)}, - predictions=predictions) - - -class SavedModelEstimatorTest(test.TestCase): - - def setUp(self): - self.tmpdirs = [] - - def tearDown(self): - for tmpdir in self.tmpdirs: - # gfile.DeleteRecursively fails in the windows cmake test, so use shutil. - shutil.rmtree(tmpdir, ignore_errors=True) - self.tmpdirs = [] - - def _get_tmp_dir(self): - tmpdir = tempfile.mkdtemp() - self.tmpdirs.append(tmpdir) - return tmpdir - - def _export_estimator(self, train=True, evaluate=True, predict=True, - model_fn=model_fn_diff_modes): - est = estimator.Estimator(model_fn, self._get_tmp_dir()) - est.train(input_fn=dummy_input_fn, steps=10) - - input_receiver_fn_map = {} - if train: - input_receiver_fn_map[model_fn_lib.ModeKeys.TRAIN] = ( - dummy_supervised_receiver_fn()) - if evaluate: - input_receiver_fn_map[model_fn_lib.ModeKeys.EVAL] = ( - dummy_supervised_receiver_fn()) - if predict: - input_receiver_fn_map[model_fn_lib.ModeKeys.PREDICT] = ( - dummy_serving_receiver_fn()) - - export_base_path = self._get_tmp_dir() - export_dir = contrib_export.export_all_saved_models( - est, export_base_path, input_receiver_fn_map) - return export_dir - - def test_load_all_modes(self): - sme = saved_model_estimator.SavedModelEstimator( - self._export_estimator(), self._get_tmp_dir()) - sme.train(input_fn=dummy_input_fn, steps=1) - sme.train(input_fn=dummy_input_fn, steps=2) - self.assertEqual(13, sme.get_variable_value('global_step')) - self.assertEqual(60, sme.get_variable_value('some_var')) - - eval_results = sme.evaluate(dummy_input_fn, steps=5) - - self.assertEqual(13, eval_results['global_step']) - self.assertEqual(106, eval_results['loss']) - self.assertEqual(502, eval_results['metrics/abs_err']) - - predictions = next(sme.predict(dummy_input_fn_features_only)) - self.assertDictEqual({'output': 503}, predictions) - - def test_load_all_modes_no_train(self): - """Ensure that all functions can be used without requiring a ckpt.""" - sme = saved_model_estimator.SavedModelEstimator( - self._export_estimator(), self._get_tmp_dir()) - eval_results = sme.evaluate(dummy_input_fn, steps=5) - self.assertEqual(10, eval_results['global_step']) - self.assertEqual(106, eval_results['loss']) - self.assertEqual(502, eval_results['metrics/abs_err']) - - predictions = next(sme.predict(dummy_input_fn_features_only)) - self.assertDictEqual({'output': 503}, predictions) - - def test_partial_exported_estimator(self): - sme1 = saved_model_estimator.SavedModelEstimator( - self._export_estimator(train=False, predict=False), self._get_tmp_dir()) - sme1.evaluate(dummy_input_fn, steps=5) - with self.assertRaisesRegexp(RuntimeError, 'train mode is not available'): - sme1.train(input_fn=dummy_input_fn, steps=1) - with self.assertRaisesRegexp(RuntimeError, 'infer mode is not available'): - next(sme1.predict(dummy_input_fn_features_only)) - - sme2 = saved_model_estimator.SavedModelEstimator( - self._export_estimator(evaluate=False), self._get_tmp_dir()) - sme2.train(input_fn=dummy_input_fn, steps=1) - next(sme2.predict(dummy_input_fn_features_only)) - with self.assertRaisesRegexp(RuntimeError, 'eval mode is not available'): - sme2.evaluate(dummy_input_fn, steps=5) - - def test_with_incorrect_input(self): - sme = saved_model_estimator.SavedModelEstimator( - self._export_estimator(), self._get_tmp_dir()) - - def bad_shape_input_fn(): - return dataset_ops.Dataset.from_tensors(( - {'x': constant_op.constant([1, 2], dtype=dtypes.int64)}, - constant_op.constant([1, 2], dtype=dtypes.float32))) - - with self.assertRaisesRegexp(ValueError, 'Expected shape'): - sme.train(bad_shape_input_fn, steps=1) - - def bad_dtype_input_fn(): - return dataset_ops.Dataset.from_tensors(( - {'x': constant_op.constant([[1], [1]], dtype=dtypes.int32)}, - constant_op.constant([[1], [1]], dtype=dtypes.int64))) - - with self.assertRaisesRegexp(ValueError, 'Expected dtype'): - sme.train(bad_dtype_input_fn, steps=1) - - def test_input_fn_with_global_step(self): - sme = saved_model_estimator.SavedModelEstimator( - self._export_estimator(), self._get_tmp_dir()) - - def bad_input_fn(): - training.get_or_create_global_step() - return dataset_ops.Dataset.from_tensors(( - {'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)}, - constant_op.constant([[1], [1]], dtype=dtypes.float32))) - - with self.assertRaisesRegexp(RuntimeError, - 'Graph must not contain a global step tensor'): - sme.train(bad_input_fn, steps=1) - - def test_re_export_saved_model_serving_only(self): - sme = saved_model_estimator.SavedModelEstimator( - self._export_estimator(), self._get_tmp_dir()) - sme.train(dummy_input_fn, steps=3) - self.assertEqual(13, sme.get_variable_value('global_step')) - self.assertEqual(60, sme.get_variable_value('some_var')) - - predictions = next(sme.predict(dummy_input_fn_features_only)) - self.assertDictEqual({'output': 503}, predictions) - - # Export SavedModel, and test that the variable and prediction values are - # the same. - sme_export_dir = sme.export_savedmodel( - self._get_tmp_dir(), dummy_serving_receiver_fn()) - - sme2 = saved_model_estimator.SavedModelEstimator( - sme_export_dir, self._get_tmp_dir()) - self.assertEqual(60, sme.get_variable_value('some_var')) - self.assertEqual(13, sme.get_variable_value('global_step')) - - predictions = next(sme2.predict(dummy_input_fn_features_only)) - self.assertDictEqual({'output': 503}, predictions) - - def test_re_export_saved_model(self): - sme = saved_model_estimator.SavedModelEstimator( - self._export_estimator(), self._get_tmp_dir()) - self.assertDictEqual( - {'loss': 106, 'metrics/abs_err': 502, 'global_step': 10}, - sme.evaluate(dummy_input_fn, steps=1)) - - sme.train(dummy_input_fn, steps=3) - self.assertDictEqual( - {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13}, - sme.evaluate(dummy_input_fn, steps=1)) - self.assertEqual(60, sme.get_variable_value('some_var')) - - predictions = next(sme.predict(dummy_input_fn_features_only)) - self.assertDictEqual({'output': 503}, predictions) - - # Export SavedModel for all modes - input_receiver_fn_map = { - model_fn_lib.ModeKeys.TRAIN: dummy_supervised_receiver_fn(), - model_fn_lib.ModeKeys.EVAL: dummy_supervised_receiver_fn(), - model_fn_lib.ModeKeys.PREDICT: dummy_serving_receiver_fn()} - sme_export_dir = contrib_export.export_all_saved_models( - sme, self._get_tmp_dir(), input_receiver_fn_map) - - sme2 = saved_model_estimator.SavedModelEstimator( - sme_export_dir, self._get_tmp_dir()) - self.assertDictEqual( - {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13}, - sme.evaluate(dummy_input_fn, steps=1)) - self.assertEqual(60, sme.get_variable_value('some_var')) - - sme.train(dummy_input_fn, steps=7) - self.assertEqual(20, sme.get_variable_value('global_step')) - - predictions = next(sme2.predict(dummy_input_fn_features_only)) - self.assertDictEqual({'output': 503}, predictions) - - def test_load_saved_model_from_serving_only(self): - def model_fn(features, labels, mode): - _, _ = features, labels - return model_fn_lib.EstimatorSpec( - mode, - loss=constant_op.constant([103]), - train_op=state_ops.assign_add(training.get_global_step(), 1), - predictions=constant_op.constant([502]), - export_outputs={'test': export_output.ClassificationOutput( - constant_op.constant([[32.]]))}) - - est = estimator.Estimator(model_fn, self._get_tmp_dir()) - est.train(input_fn=dummy_input_fn, steps=10) - - def serving_input_receiver_fn(): - return export.ServingInputReceiver( - {'test-features': constant_op.constant([[1], [1]])}, - array_ops.placeholder(dtype=dtypes.string)) - - export_dir = est.export_savedmodel( - self._get_tmp_dir(), serving_input_receiver_fn) - - sme = saved_model_estimator.SavedModelEstimator( - export_dir, self._get_tmp_dir()) - - def input_fn(): - return {'inputs': constant_op.constant('someinputstr')} - - prediction = next(sme.predict(input_fn)) - self.assertDictEqual({'scores': 32}, prediction) - - def test_with_local_init_op(self): - def model_fn(features, labels, mode): - _, _ = features, labels - v = variables.Variable(21, name='some_var') - scaffold = monitored_session.Scaffold( - local_init_op=state_ops.assign_add(v, -3).op - ) - return model_fn_lib.EstimatorSpec( - mode, - scaffold=scaffold, - train_op=state_ops.assign_add(training.get_global_step(), 1), - loss=array_ops.identity(v)) - export_dir = self._export_estimator(predict=False, model_fn=model_fn) - sme = saved_model_estimator.SavedModelEstimator( - export_dir, self._get_tmp_dir()) - - eval_results1 = sme.evaluate(dummy_input_fn, steps=2) - self.assertEqual(15, eval_results1['loss']) - - sme.train(dummy_input_fn, steps=1) - self.assertEqual(15, sme.get_variable_value('some_var')) - - eval_results2 = sme.evaluate(dummy_input_fn, steps=5) - self.assertEqual(12, eval_results2['loss']) - - def test_with_working_input_fn(self): - def model_fn(features, labels, mode): - loss = None - if labels is not None: - loss = labels[0][0] + labels[1][0] - return model_fn_lib.EstimatorSpec( - mode, - loss=loss, - train_op=state_ops.assign_add(training.get_global_step(), 1), - predictions={'features_0': array_ops.identity([features['x'][0][0]]), - 'features_1': array_ops.identity([features['x'][1][0]])}) - - sme = saved_model_estimator.SavedModelEstimator( - self._export_estimator(model_fn=model_fn), self._get_tmp_dir()) - eval_results = sme.evaluate(dummy_input_fn, steps=1) - self.assertEqual(1, eval_results['loss']) - - predictions = next(sme.predict(dummy_input_fn_features_only)) - self.assertDictEqual({'features_0': 5, 'features_1': 6}, predictions) - - def test_control_dependency(self): - # Control dependencies are saved with "^" appended to the start of the input - # name. The input map must include control dependencies as well. - def model_fn(features, labels, mode): - _ = labels - with ops.control_dependencies([features['x']]): - loss = features['x'][1][0] - return model_fn_lib.EstimatorSpec( - mode, - loss=loss, - train_op=state_ops.assign_add(training.get_global_step(), 1)) - sme = saved_model_estimator.SavedModelEstimator( - self._export_estimator(train=False, predict=False, model_fn=model_fn), - self._get_tmp_dir()) - sme.evaluate(dummy_input_fn, steps=1) # Should run without error - - -if __name__ == '__main__': - test.main() |