aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator
diff options
context:
space:
mode:
authorGravatar Katherine Wu <kathywu@google.com>2018-07-20 15:45:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 15:51:41 -0700
commit6c528feaf820bdde820833ad24e05167adb5daa7 (patch)
treecdffac07b9e343e03958b734ac9553102bbd4ccf /tensorflow/contrib/estimator
parent5e876a8c25819070d78aa96595943afa207a6671 (diff)
Automated rollback of commit 8257891f378027a1a7c0403ba6ba0aeb313496a0
PiperOrigin-RevId: 205466000
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--tensorflow/contrib/estimator/BUILD41
-rw-r--r--tensorflow/contrib/estimator/__init__.py5
-rw-r--r--tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py445
-rw-r--r--tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py369
4 files changed, 0 insertions, 860 deletions
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 349f48f7f7..1aa3df8d8d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -28,7 +28,6 @@ py_library(
":multi_head",
":replicate_model_fn",
":rnn",
- ":saved_model_estimator",
"//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -466,43 +465,3 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
-
-py_library(
- name = "saved_model_estimator",
- srcs = ["python/estimator/saved_model_estimator.py"],
- deps = [
- ":export",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:training",
- "//tensorflow/python/estimator",
- "//tensorflow/python/estimator:export",
- "//tensorflow/python/estimator:model_fn",
- "//tensorflow/python/saved_model",
- ],
-)
-
-py_test(
- name = "saved_model_estimator_test",
- size = "medium",
- srcs = ["python/estimator/saved_model_estimator_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":export",
- ":saved_model_estimator",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:platform",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator",
- "//tensorflow/python/estimator:export_export",
- "//tensorflow/python/estimator:export_output",
- "//tensorflow/python/estimator:model_fn",
- ],
-)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index e1453ae1d0..09fcfd66a1 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -33,8 +33,6 @@ from tensorflow.contrib.estimator.python.estimator.logit_fns import *
from tensorflow.contrib.estimator.python.estimator.multi_head import *
from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import *
from tensorflow.contrib.estimator.python.estimator.rnn import *
-from tensorflow.contrib.estimator.python.estimator.saved_model_estimator import *
-from tensorflow.python.estimator.export.export import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
@@ -72,9 +70,6 @@ _allowed_symbols = [
'stop_if_higher_hook',
'stop_if_no_increase_hook',
'stop_if_no_decrease_hook',
- 'build_raw_supervised_input_receiver_fn',
- 'build_supervised_input_receiver_fn_from_input_fn',
- 'SavedModelEstimator'
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py
deleted file mode 100644
index 22188fe663..0000000000
--- a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator.py
+++ /dev/null
@@ -1,445 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Class that creates an Estimator from a SavedModel."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import six
-
-from tensorflow.python.estimator import estimator as estimator_lib
-from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator.export import export as export_lib
-from tensorflow.python.estimator.export import export_output
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.saved_model import constants
-from tensorflow.python.saved_model import loader_impl
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import monitored_session
-from tensorflow.python.training import training_util
-
-
-class SavedModelEstimator(estimator_lib.Estimator):
- """Create an Estimator from a SavedModel.
-
- Only SavedModels exported with
- `tf.contrib.estimator.export_all_saved_models()` or
- `tf.estimator.Estimator.export_savedmodel()` are supported for this class.
-
- Example with `tf.estimator.DNNClassifier`:
-
- **Step 1: Create and train DNNClassifier.**
- ```python
- feature1 = tf.feature_column.embedding_column(
- tf.feature_column.categorical_column_with_vocabulary_list(
- key='feature1', vocabulary_list=('green', 'yellow')), dimension=1)
- feature2 = tf.feature_column.numeric_column(key='feature2', default_value=0.0)
-
- classifier = tf.estimator.DNNClassifier(
- hidden_units=[4,2], feature_columns=[feature1, feature2])
-
- def input_fn():
- features = {'feature1': tf.constant(['green', 'green', 'yellow']),
- 'feature2': tf.constant([3.5, 4.2, 6.1])}
- label = tf.constant([1., 0., 0.])
- return tf.data.Dataset.from_tensors((features, label)).repeat()
-
- classifier.train(input_fn=input_fn, steps=10)
- ```
-
- **Step 2: Export classifier.**
- First, build functions that specify the expected inputs.
- ```python
- # During train and evaluation, both the features and labels should be defined.
- supervised_input_receiver_fn = (
- tf.contrib.estimator.build_raw_supervised_input_receiver_fn(
- {'feature1': tf.placeholder(dtype=tf.string, shape=[None]),
- 'feature2': tf.placeholder(dtype=tf.float32, shape=[None])},
- tf.placeholder(dtype=tf.float32, shape=[None])))
-
- # During predict mode, expect to receive a `tf.Example` proto, so a parsing
- # function is used.
- serving_input_receiver_fn = (
- tf.estimator.export.build_parsing_serving_input_receiver_fn(
- tf.feature_column.make_parse_example_spec([feature1, feature2])))
- ```
-
- Next, export the model as a SavedModel. A timestamped directory will be
- created (for example `/tmp/export_all/1234567890`).
- ```python
- # Option 1: Save all modes (train, eval, predict)
- export_dir = tf.contrib.estimator.export_all_saved_models(
- classifier, '/tmp/export_all',
- {tf.estimator.ModeKeys.TRAIN: supervised_input_receiver_fn,
- tf.estimator.ModeKeys.EVAL: supervised_input_receiver_fn,
- tf.estimator.ModeKeys.PREDICT: serving_input_receiver_fn})
-
- # Option 2: Only export predict mode
- export_dir = classifier.export_savedmodel(
- '/tmp/export_predict', serving_input_receiver_fn)
- ```
-
- **Step 3: Create a SavedModelEstimator from the exported SavedModel.**
- ```python
- est = tf.contrib.estimator.SavedModelEstimator(export_dir)
-
- # If all modes were exported, you can immediately evaluate and predict, or
- # continue training. Otherwise only predict is available.
- eval_results = est.evaluate(input_fn=input_fn, steps=1)
- print(eval_results)
-
- est.train(input_fn=input_fn, steps=20)
-
- def predict_input_fn():
- example = example_pb2.Example()
- example.features.feature['feature1'].bytes_list.value.extend(['yellow'])
- example.features.feature['feature2'].float_list.value.extend([1.])
- return {'inputs':tf.constant([example.SerializeToString()])}
-
- predictions = est.predict(predict_input_fn)
- print(next(predictions))
- ```
- """
-
- def __init__(self, saved_model_dir, model_dir=None):
- """Initialize a SavedModelEstimator.
-
- The SavedModelEstimator loads its model function and variable values from
- the graphs defined in the SavedModel. There is no option to pass in
- `RunConfig` or `params` arguments, because the model function graph is
- defined statically in the SavedModel.
-
- Args:
- saved_model_dir: Directory containing SavedModel protobuf and subfolders.
- model_dir: Directory to save new checkpoints during training.
-
- Raises:
- NotImplementedError: If a DistributionStrategy is defined in the config.
- Unless the SavedModelEstimator is subclassed, this shouldn't happen.
- """
- checkpoint = estimator_lib._get_saved_model_ckpt(saved_model_dir) # pylint: disable=protected-access
- vars_to_warm_start = [name for name, _ in
- checkpoint_utils.list_variables(checkpoint)]
- warm_start_settings = estimator_lib.WarmStartSettings(
- ckpt_to_initialize_from=checkpoint,
- vars_to_warm_start=vars_to_warm_start)
-
- super(SavedModelEstimator, self).__init__(
- model_fn=self._model_fn_from_saved_model, model_dir=model_dir,
- warm_start_from=warm_start_settings)
- if self._distribution is not None:
- raise NotImplementedError(
- 'SavedModelEstimator currently does not support '
- 'DistributionStrategy.')
- self.saved_model_dir = saved_model_dir
- self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir)
- self._available_modes = self._extract_available_modes()
-
- def _extract_available_modes(self):
- """Return list of modes found in SavedModel."""
- available_modes = []
- logging.info('Checking available modes for SavedModelEstimator.')
- for mode in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
- model_fn_lib.ModeKeys.PREDICT]:
- try:
- self._get_meta_graph_def_for_mode(mode)
- except RuntimeError:
- logging.warning('%s mode not found in SavedModel.' % mode)
- continue
-
- if self._get_signature_def_for_mode(mode) is not None:
- available_modes.append(mode)
-
- logging.info('Available modes for Estimator: %s' % available_modes)
- return available_modes
-
- def _validate_mode(self, mode):
- """Make sure that mode can be run using the SavedModel."""
- if mode not in self._available_modes:
- raise RuntimeError('%s mode is not available in the SavedModel. Use '
- 'saved_model_cli to check that the Metagraph for this '
- 'mode has been exported.' % mode)
-
- def _get_meta_graph_def_for_mode(self, mode):
- tags = model_fn_lib.EXPORT_TAG_MAP[mode]
- return self.saved_model_loader.get_meta_graph_def_from_tags(tags)
-
- def _get_signature_def_for_mode(self, mode):
- meta_graph_def = self._get_meta_graph_def_for_mode(mode)
- sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
- if mode == model_fn_lib.ModeKeys.PREDICT else mode)
- if sig_def_key not in meta_graph_def.signature_def:
- logging.warning('Metagraph for mode %s was found, but SignatureDef with'
- ' key \"%s\" is missing.' % (mode, sig_def_key))
- return None
- return meta_graph_def.signature_def[sig_def_key]
-
- def _create_and_assert_global_step(self, graph):
- # Do nothing here. The global step variable will be created/loaded from the
- # SavedModel. If a global step variable were created here, the result
- # will be two duplicate global step variables, causing issues during
- # the warm-start phase.
- # Due to the global variable being created in the model function, this may
- # cause issues when running DistributionStrategy. Thus, DistributionStrategy
- # is not yet supported with SavedModelEstimator.
- pass
-
- def _model_fn_from_saved_model(self, features, labels, mode):
- """Load a SavedModel graph and return an EstimatorSpec."""
- # TODO(kathywu): Model function loads placeholders from the graph. Calling
- # export_all_saved_models creates another placeholder for the inputs, on top
- # of the original placeholders. There should be a way to avoid this.
- self._validate_mode(mode)
-
- g = ops.get_default_graph()
- if training_util.get_global_step(g) is not None:
- raise RuntimeError(
- 'Graph must not contain a global step tensor before the SavedModel is'
- ' loaded. Please make sure that the input function does not create a '
- 'global step.')
-
- # Extract SignatureDef for information about the input and output tensors.
- signature_def = self._get_signature_def_for_mode(mode)
-
- # Generate input map for replacing the inputs in the SavedModel graph with
- # the provided features and labels.
- input_map = _generate_input_map(signature_def, features, labels)
-
- # Create a list of the names of output tensors. When the graph is loaded,
- # names of the output tensors may be remapped. This ensures that the correct
- # tensors are returned in the EstimatorSpec.
- output_tensor_names = [
- value.name for value in six.itervalues(signature_def.outputs)]
-
- # Load the graph. `output_tensors` contains output `Tensors` in the same
- # same order as the `output_tensor_names` list.
- tags = model_fn_lib.EXPORT_TAG_MAP[mode]
- _, output_tensors = self.saved_model_loader.load_graph(
- g, tags, input_map=input_map, return_elements=output_tensor_names)
-
- # Create a scaffold from the MetaGraphDef that contains ops to initialize
- # the graph. This should mirror the steps from _add_meta_graph_for_mode(),
- # which creates a MetaGraphDef from the EstimatorSpec's scaffold.
- scaffold = monitored_session.Scaffold(
- local_init_op=loader_impl._get_legacy_init_op_tensor( # pylint: disable=protected-access
- self._get_meta_graph_def_for_mode(mode)))
-
- # Ensure that a global step tensor has been created.
- global_step_tensor = training_util.get_global_step(g)
- training_util.assert_global_step(global_step_tensor)
-
- # Extract values to return in the EstimatorSpec.
- output_map = dict(zip(output_tensor_names, output_tensors))
- outputs = {key: output_map[value.name]
- for key, value in six.iteritems(signature_def.outputs)}
-
- loss, predictions, metrics = _validate_and_extract_outputs(
- mode, outputs, signature_def.method_name)
-
- train_op = ops.get_collection(constants.TRAIN_OP_KEY)
- if len(train_op) > 1:
- raise RuntimeError('Multiple ops found in the train_op collection.')
- train_op = None if not train_op else train_op[0]
-
- _clear_saved_model_collections()
- return model_fn_lib.EstimatorSpec(
- scaffold=scaffold,
- mode=mode,
- loss=loss,
- train_op=train_op,
- predictions=predictions,
- eval_metric_ops=metrics)
-
-
-def _clear_saved_model_collections():
- """Clear collections that are expected empty when exporting a SavedModel.
-
- The SavedModel builder uses these collections to track ops necessary to
- restore the graph state. These collections are expected to be empty before
- MetaGraphs are added to the builder.
- """
- del ops.get_collection_ref(constants.ASSETS_KEY)[:]
- del ops.get_collection_ref(constants.LEGACY_INIT_OP_KEY)[:]
- del ops.get_collection_ref(constants.MAIN_OP_KEY)[:]
- del ops.get_collection_ref(constants.TRAIN_OP_KEY)[:]
-
-
-def _generate_input_map(signature_def, features, labels):
- """Return dict mapping an input tensor name to a feature or label tensor.
-
- Args:
- signature_def: SignatureDef loaded from SavedModel
- features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
- `SparseTensor`, specifying the features to be passed to the model.
- labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
- `SparseTensor`, specifying the labels to be passed to the model. May be
- `None`.
-
- Returns:
- dict mapping string names of inputs to features or labels tensors
-
- Raises:
- ValueError: if SignatureDef inputs are not completely mapped by the input
- features and labels.
- """
- # pylint: disable=protected-access
- if not isinstance(features, dict):
- features = {export_lib._SINGLE_FEATURE_DEFAULT_NAME: features}
- if labels is not None and not isinstance(labels, dict):
- labels = {export_lib._SINGLE_LABEL_DEFAULT_NAME: labels}
- # pylint: enable=protected-access
-
- inputs = signature_def.inputs
- input_map = {}
- for key, tensor_info in six.iteritems(inputs):
- input_name = tensor_info.name
- if ':' in input_name:
- input_name = input_name[:input_name.find(':')]
-
- # When tensors are used as control inputs for operations, their names are
- # prepended with a '^' character in the GraphDef. To handle possible control
- # flow edge cases, control input names must be included in the input map.
- control_dependency_name = '^' + input_name
-
- if key in features:
- _check_same_dtype_and_shape(features[key], tensor_info, key)
- input_map[input_name] = input_map[control_dependency_name] = features[key]
- elif labels is not None and key in labels:
- _check_same_dtype_and_shape(labels[key], tensor_info, key)
- input_map[input_name] = input_map[control_dependency_name] = labels[key]
- else:
- raise ValueError(
- 'Key \"%s\" not found in features or labels passed in to the model '
- 'function. All required keys: %s' % (key, inputs.keys()))
-
- return input_map
-
-
-def _check_same_dtype_and_shape(tensor, tensor_info, name):
- """Validate that tensor has the same properties as the TensorInfo proto.
-
- Args:
- tensor: a `Tensor` object.
- tensor_info: a `TensorInfo` proto.
- name: Name of the input (to identify Tensor if an error is raised).
-
- Raises:
- ValueError: If the tensor shape or dtype don't match the TensorInfo
- """
- dtype_error = (tensor.dtype != dtypes.DType(tensor_info.dtype))
- shape_error = not tensor.shape.is_compatible_with(tensor_info.tensor_shape)
-
- if dtype_error or shape_error:
- msg = 'Tensor shape and/or dtype validation failed for input %s:' % name
- if dtype_error:
- msg += ('\n\tExpected dtype: %s, Got: %s'
- % (dtypes.DType(tensor_info.dtype), tensor.dtype))
- if shape_error:
- msg += ('\n\tExpected shape: %s, Got: %s'
- % (tensor_shape.TensorShape(tensor_info.tensor_shape),
- tensor.shape))
-
- raise ValueError(msg)
-
-
-def _extract_eval_metrics(output_dict):
- """Return a eval metric dict extracted from the output_dict.
-
- Eval metrics consist of a value tensor and an update op. Both must be in the
- passed-in tensor dictionary for an eval metric to be added to the returned
- dictionary.
-
- Args:
- output_dict: a dict that maps strings to tensors.
-
- Returns:
- dict mapping strings to (value, update_op) tuples.
- """
- # pylint: disable=protected-access
- metric_ops = {}
- separator_char = export_output._SupervisedOutput._SEPARATOR_CHAR
-
- for key, tensor in six.iteritems(output_dict):
- split_key = key.split(separator_char)
-
- # The metric name may contain the separator character, so recreate its name.
- metric_name = separator_char.join(split_key[:-1])
-
- if split_key[0] == export_output._SupervisedOutput.METRICS_NAME:
- # If the key ends with the value suffix, and there is a corresponding
- # key ending with the update_op suffix, then add tensors to metrics dict.
- if split_key[-1] == export_output._SupervisedOutput.METRIC_VALUE_SUFFIX:
- update_op = ''.join(
- [metric_name, separator_char,
- export_output._SupervisedOutput.METRIC_UPDATE_SUFFIX])
- if update_op in output_dict:
- update_op_tensor = output_dict[update_op]
- metric_ops[metric_name] = (tensor, update_op_tensor)
-
- # pylint: enable=protected-access
- return metric_ops
-
-
-def _validate_and_extract_outputs(mode, output_dict, method_name):
- """Extract values from SignatureDef output dictionary.
-
- Args:
- mode: One of the modes enumerated in `tf.estimator.ModeKeys`.
- output_dict: dict of string SignatureDef keys to `Tensor`.
- method_name: Method name of the SignatureDef as a string.
-
- Returns:
- Tuple of (
- loss: `Tensor` object,
- predictions: dictionary mapping string keys to `Tensor` objects,
- metrics: dictionary mapping string keys to a tuple of two `Tensor` objects
- )
-
- Raises:
- RuntimeError: raised if SignatureDef has an invalid method name for the mode
- """
- # pylint: disable=protected-access
- loss, predictions, metrics = None, None, None
-
- if mode == model_fn_lib.ModeKeys.PREDICT:
- predictions = output_dict
- else:
- # Validate that the SignatureDef's method name matches the expected name for
- # the given mode.
- expected_method_name = signature_constants.SUPERVISED_TRAIN_METHOD_NAME
- if mode == model_fn_lib.ModeKeys.EVAL:
- expected_method_name = signature_constants.SUPERVISED_EVAL_METHOD_NAME
- if method_name != expected_method_name:
- raise RuntimeError(
- 'Invalid SignatureDef method name for mode %s.\n\tExpected: %s\n\t'
- 'Got: %s\nPlease ensure that the SavedModel was exported with '
- '`tf.contrib.estimator.export_all_saved_models()`.' %
- (mode, expected_method_name, method_name))
-
- # Extract loss, metrics and predictions from the output dict.
- loss = output_dict[export_output._SupervisedOutput.LOSS_NAME]
- metrics = _extract_eval_metrics(output_dict)
- predictions = {
- key: value for key, value in six.iteritems(output_dict)
- if key.split(export_output._SupervisedOutput._SEPARATOR_CHAR)[0] == (
- export_output._SupervisedOutput.PREDICTIONS_NAME)}
-
- # pylint: enable=protected-access
- return loss, predictions, metrics
diff --git a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py b/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py
deleted file mode 100644
index 718da1367c..0000000000
--- a/tensorflow/contrib/estimator/python/estimator/saved_model_estimator_test.py
+++ /dev/null
@@ -1,369 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for SavedModelEstimator."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import shutil
-import tempfile
-
-from tensorflow.contrib.estimator.python.estimator import export as contrib_export
-from tensorflow.contrib.estimator.python.estimator import saved_model_estimator
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator.export import export
-from tensorflow.python.estimator.export import export_output
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import metrics as metrics_lib
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-from tensorflow.python.training import monitored_session
-from tensorflow.python.training import training
-
-
-def dummy_input_fn():
- return dataset_ops.Dataset.from_tensors((
- {'x': constant_op.constant([[1], [-2]], dtype=dtypes.int64)},
- constant_op.constant([[4], [-3]], dtype=dtypes.float32))).repeat()
-
-
-def dummy_input_fn_features_only():
- return dataset_ops.Dataset.from_tensors(
- {'x': constant_op.constant([[5], [6]], dtype=dtypes.int64)}).repeat()
-
-
-def dummy_supervised_receiver_fn():
- feature_spec = {
- 'x': array_ops.placeholder(
- dtype=dtypes.int64, shape=(2, 1), name='feature_x'),
- }
- label_spec = array_ops.placeholder(
- dtype=dtypes.float32, shape=[2, 1], name='truth')
- return export.build_raw_supervised_input_receiver_fn(
- feature_spec, label_spec)
-
-
-def dummy_serving_receiver_fn():
- feature_spec = {'x': array_ops.placeholder(
- dtype=dtypes.int64, shape=(2, 1), name='feature_x'),}
- return export.build_raw_serving_input_receiver_fn(feature_spec)
-
-
-def model_fn_diff_modes(features, labels, mode):
- _, _ = features, labels
- v = variables.Variable(21, name='some_var')
- train_op = None
- loss = constant_op.constant(104)
- if mode == model_fn_lib.ModeKeys.TRAIN:
- loss = constant_op.constant(105)
- predictions = constant_op.constant([501])
- train_op = control_flow_ops.group(
- state_ops.assign_add(training.get_global_step(), 1),
- state_ops.assign_add(v, 3))
- elif mode == model_fn_lib.ModeKeys.EVAL:
- loss = constant_op.constant(106)
- predictions = constant_op.constant([502])
- else:
- loss = constant_op.constant(107)
- predictions = constant_op.constant([503])
- return model_fn_lib.EstimatorSpec(
- mode,
- loss=loss,
- train_op=train_op,
- eval_metric_ops={
- 'abs_err': metrics_lib.mean_absolute_error(
- constant_op.constant(0), predictions)},
- predictions=predictions)
-
-
-class SavedModelEstimatorTest(test.TestCase):
-
- def setUp(self):
- self.tmpdirs = []
-
- def tearDown(self):
- for tmpdir in self.tmpdirs:
- # gfile.DeleteRecursively fails in the windows cmake test, so use shutil.
- shutil.rmtree(tmpdir, ignore_errors=True)
- self.tmpdirs = []
-
- def _get_tmp_dir(self):
- tmpdir = tempfile.mkdtemp()
- self.tmpdirs.append(tmpdir)
- return tmpdir
-
- def _export_estimator(self, train=True, evaluate=True, predict=True,
- model_fn=model_fn_diff_modes):
- est = estimator.Estimator(model_fn, self._get_tmp_dir())
- est.train(input_fn=dummy_input_fn, steps=10)
-
- input_receiver_fn_map = {}
- if train:
- input_receiver_fn_map[model_fn_lib.ModeKeys.TRAIN] = (
- dummy_supervised_receiver_fn())
- if evaluate:
- input_receiver_fn_map[model_fn_lib.ModeKeys.EVAL] = (
- dummy_supervised_receiver_fn())
- if predict:
- input_receiver_fn_map[model_fn_lib.ModeKeys.PREDICT] = (
- dummy_serving_receiver_fn())
-
- export_base_path = self._get_tmp_dir()
- export_dir = contrib_export.export_all_saved_models(
- est, export_base_path, input_receiver_fn_map)
- return export_dir
-
- def test_load_all_modes(self):
- sme = saved_model_estimator.SavedModelEstimator(
- self._export_estimator(), self._get_tmp_dir())
- sme.train(input_fn=dummy_input_fn, steps=1)
- sme.train(input_fn=dummy_input_fn, steps=2)
- self.assertEqual(13, sme.get_variable_value('global_step'))
- self.assertEqual(60, sme.get_variable_value('some_var'))
-
- eval_results = sme.evaluate(dummy_input_fn, steps=5)
-
- self.assertEqual(13, eval_results['global_step'])
- self.assertEqual(106, eval_results['loss'])
- self.assertEqual(502, eval_results['metrics/abs_err'])
-
- predictions = next(sme.predict(dummy_input_fn_features_only))
- self.assertDictEqual({'output': 503}, predictions)
-
- def test_load_all_modes_no_train(self):
- """Ensure that all functions can be used without requiring a ckpt."""
- sme = saved_model_estimator.SavedModelEstimator(
- self._export_estimator(), self._get_tmp_dir())
- eval_results = sme.evaluate(dummy_input_fn, steps=5)
- self.assertEqual(10, eval_results['global_step'])
- self.assertEqual(106, eval_results['loss'])
- self.assertEqual(502, eval_results['metrics/abs_err'])
-
- predictions = next(sme.predict(dummy_input_fn_features_only))
- self.assertDictEqual({'output': 503}, predictions)
-
- def test_partial_exported_estimator(self):
- sme1 = saved_model_estimator.SavedModelEstimator(
- self._export_estimator(train=False, predict=False), self._get_tmp_dir())
- sme1.evaluate(dummy_input_fn, steps=5)
- with self.assertRaisesRegexp(RuntimeError, 'train mode is not available'):
- sme1.train(input_fn=dummy_input_fn, steps=1)
- with self.assertRaisesRegexp(RuntimeError, 'infer mode is not available'):
- next(sme1.predict(dummy_input_fn_features_only))
-
- sme2 = saved_model_estimator.SavedModelEstimator(
- self._export_estimator(evaluate=False), self._get_tmp_dir())
- sme2.train(input_fn=dummy_input_fn, steps=1)
- next(sme2.predict(dummy_input_fn_features_only))
- with self.assertRaisesRegexp(RuntimeError, 'eval mode is not available'):
- sme2.evaluate(dummy_input_fn, steps=5)
-
- def test_with_incorrect_input(self):
- sme = saved_model_estimator.SavedModelEstimator(
- self._export_estimator(), self._get_tmp_dir())
-
- def bad_shape_input_fn():
- return dataset_ops.Dataset.from_tensors((
- {'x': constant_op.constant([1, 2], dtype=dtypes.int64)},
- constant_op.constant([1, 2], dtype=dtypes.float32)))
-
- with self.assertRaisesRegexp(ValueError, 'Expected shape'):
- sme.train(bad_shape_input_fn, steps=1)
-
- def bad_dtype_input_fn():
- return dataset_ops.Dataset.from_tensors((
- {'x': constant_op.constant([[1], [1]], dtype=dtypes.int32)},
- constant_op.constant([[1], [1]], dtype=dtypes.int64)))
-
- with self.assertRaisesRegexp(ValueError, 'Expected dtype'):
- sme.train(bad_dtype_input_fn, steps=1)
-
- def test_input_fn_with_global_step(self):
- sme = saved_model_estimator.SavedModelEstimator(
- self._export_estimator(), self._get_tmp_dir())
-
- def bad_input_fn():
- training.get_or_create_global_step()
- return dataset_ops.Dataset.from_tensors((
- {'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)},
- constant_op.constant([[1], [1]], dtype=dtypes.float32)))
-
- with self.assertRaisesRegexp(RuntimeError,
- 'Graph must not contain a global step tensor'):
- sme.train(bad_input_fn, steps=1)
-
- def test_re_export_saved_model_serving_only(self):
- sme = saved_model_estimator.SavedModelEstimator(
- self._export_estimator(), self._get_tmp_dir())
- sme.train(dummy_input_fn, steps=3)
- self.assertEqual(13, sme.get_variable_value('global_step'))
- self.assertEqual(60, sme.get_variable_value('some_var'))
-
- predictions = next(sme.predict(dummy_input_fn_features_only))
- self.assertDictEqual({'output': 503}, predictions)
-
- # Export SavedModel, and test that the variable and prediction values are
- # the same.
- sme_export_dir = sme.export_savedmodel(
- self._get_tmp_dir(), dummy_serving_receiver_fn())
-
- sme2 = saved_model_estimator.SavedModelEstimator(
- sme_export_dir, self._get_tmp_dir())
- self.assertEqual(60, sme.get_variable_value('some_var'))
- self.assertEqual(13, sme.get_variable_value('global_step'))
-
- predictions = next(sme2.predict(dummy_input_fn_features_only))
- self.assertDictEqual({'output': 503}, predictions)
-
- def test_re_export_saved_model(self):
- sme = saved_model_estimator.SavedModelEstimator(
- self._export_estimator(), self._get_tmp_dir())
- self.assertDictEqual(
- {'loss': 106, 'metrics/abs_err': 502, 'global_step': 10},
- sme.evaluate(dummy_input_fn, steps=1))
-
- sme.train(dummy_input_fn, steps=3)
- self.assertDictEqual(
- {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13},
- sme.evaluate(dummy_input_fn, steps=1))
- self.assertEqual(60, sme.get_variable_value('some_var'))
-
- predictions = next(sme.predict(dummy_input_fn_features_only))
- self.assertDictEqual({'output': 503}, predictions)
-
- # Export SavedModel for all modes
- input_receiver_fn_map = {
- model_fn_lib.ModeKeys.TRAIN: dummy_supervised_receiver_fn(),
- model_fn_lib.ModeKeys.EVAL: dummy_supervised_receiver_fn(),
- model_fn_lib.ModeKeys.PREDICT: dummy_serving_receiver_fn()}
- sme_export_dir = contrib_export.export_all_saved_models(
- sme, self._get_tmp_dir(), input_receiver_fn_map)
-
- sme2 = saved_model_estimator.SavedModelEstimator(
- sme_export_dir, self._get_tmp_dir())
- self.assertDictEqual(
- {'loss': 106, 'metrics/abs_err': 502, 'global_step': 13},
- sme.evaluate(dummy_input_fn, steps=1))
- self.assertEqual(60, sme.get_variable_value('some_var'))
-
- sme.train(dummy_input_fn, steps=7)
- self.assertEqual(20, sme.get_variable_value('global_step'))
-
- predictions = next(sme2.predict(dummy_input_fn_features_only))
- self.assertDictEqual({'output': 503}, predictions)
-
- def test_load_saved_model_from_serving_only(self):
- def model_fn(features, labels, mode):
- _, _ = features, labels
- return model_fn_lib.EstimatorSpec(
- mode,
- loss=constant_op.constant([103]),
- train_op=state_ops.assign_add(training.get_global_step(), 1),
- predictions=constant_op.constant([502]),
- export_outputs={'test': export_output.ClassificationOutput(
- constant_op.constant([[32.]]))})
-
- est = estimator.Estimator(model_fn, self._get_tmp_dir())
- est.train(input_fn=dummy_input_fn, steps=10)
-
- def serving_input_receiver_fn():
- return export.ServingInputReceiver(
- {'test-features': constant_op.constant([[1], [1]])},
- array_ops.placeholder(dtype=dtypes.string))
-
- export_dir = est.export_savedmodel(
- self._get_tmp_dir(), serving_input_receiver_fn)
-
- sme = saved_model_estimator.SavedModelEstimator(
- export_dir, self._get_tmp_dir())
-
- def input_fn():
- return {'inputs': constant_op.constant('someinputstr')}
-
- prediction = next(sme.predict(input_fn))
- self.assertDictEqual({'scores': 32}, prediction)
-
- def test_with_local_init_op(self):
- def model_fn(features, labels, mode):
- _, _ = features, labels
- v = variables.Variable(21, name='some_var')
- scaffold = monitored_session.Scaffold(
- local_init_op=state_ops.assign_add(v, -3).op
- )
- return model_fn_lib.EstimatorSpec(
- mode,
- scaffold=scaffold,
- train_op=state_ops.assign_add(training.get_global_step(), 1),
- loss=array_ops.identity(v))
- export_dir = self._export_estimator(predict=False, model_fn=model_fn)
- sme = saved_model_estimator.SavedModelEstimator(
- export_dir, self._get_tmp_dir())
-
- eval_results1 = sme.evaluate(dummy_input_fn, steps=2)
- self.assertEqual(15, eval_results1['loss'])
-
- sme.train(dummy_input_fn, steps=1)
- self.assertEqual(15, sme.get_variable_value('some_var'))
-
- eval_results2 = sme.evaluate(dummy_input_fn, steps=5)
- self.assertEqual(12, eval_results2['loss'])
-
- def test_with_working_input_fn(self):
- def model_fn(features, labels, mode):
- loss = None
- if labels is not None:
- loss = labels[0][0] + labels[1][0]
- return model_fn_lib.EstimatorSpec(
- mode,
- loss=loss,
- train_op=state_ops.assign_add(training.get_global_step(), 1),
- predictions={'features_0': array_ops.identity([features['x'][0][0]]),
- 'features_1': array_ops.identity([features['x'][1][0]])})
-
- sme = saved_model_estimator.SavedModelEstimator(
- self._export_estimator(model_fn=model_fn), self._get_tmp_dir())
- eval_results = sme.evaluate(dummy_input_fn, steps=1)
- self.assertEqual(1, eval_results['loss'])
-
- predictions = next(sme.predict(dummy_input_fn_features_only))
- self.assertDictEqual({'features_0': 5, 'features_1': 6}, predictions)
-
- def test_control_dependency(self):
- # Control dependencies are saved with "^" appended to the start of the input
- # name. The input map must include control dependencies as well.
- def model_fn(features, labels, mode):
- _ = labels
- with ops.control_dependencies([features['x']]):
- loss = features['x'][1][0]
- return model_fn_lib.EstimatorSpec(
- mode,
- loss=loss,
- train_op=state_ops.assign_add(training.get_global_step(), 1))
- sme = saved_model_estimator.SavedModelEstimator(
- self._export_estimator(train=False, predict=False, model_fn=model_fn),
- self._get_tmp_dir())
- sme.evaluate(dummy_input_fn, steps=1) # Should run without error
-
-
-if __name__ == '__main__':
- test.main()