diff options
author | 2017-06-20 16:18:12 -0700 | |
---|---|---|
committer | 2017-06-20 16:22:02 -0700 | |
commit | 5856f9ea6d84d041b882d19d3b357104c17563fc (patch) | |
tree | 7851650d989018720c09cfb4807e53631c9a6f82 | |
parent | 35af7113de0f15360246234f76e5dda5e927c556 (diff) |
Automated g4 rollback of changelist 159583264
PiperOrigin-RevId: 159630408
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 41 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 95 |
2 files changed, 8 insertions, 128 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index ff44914779..ab49f36e5e 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -357,7 +357,7 @@ class Estimator(object): } def _assert_members_are_not_overridden(self): - allowed_overrides = set(['_call_input_fn', '_create_global_step']) + allowed_overrides = set(['_create_global_step']) estimator_members = set([m for m in Estimator.__dict__.keys() if not m.startswith('__')]) subclass_members = set(self.__class__.__dict__.keys()) @@ -485,7 +485,7 @@ class Estimator(object): return export_dir def _get_features_from_input_fn(self, input_fn): - result = self._call_input_fn(input_fn) + result = input_fn() if not ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): logging.warning('Input graph does not contain a QueueRunner. ' 'That means predict yields forever. ' @@ -549,32 +549,6 @@ class Estimator(object): assert step.dtype.is_integer return step - def _call_input_fn(self, input_fn): - """Calls the input function. - - Args: - input_fn: The input function. - - Returns: - Either features or (features, labels) where features and labels are: - features - `Tensor` or dictionary of string feature name to `Tensor`. - labels - `Tensor` or dictionary of `Tensor` with labels. - - Raises: - ValueError: if input_fn takes invalid arguments. - """ - input_fn_args = _fn_args(input_fn) - for arg in input_fn_args: - if arg not in ('config', 'params'): - raise ValueError('input_fn should not include argument {}.'.format(arg)) - kwargs = {} - if 'params' in input_fn_args: - kwargs['params'] = self.params - if 'config' in input_fn_args: - kwargs['config'] = self.config - with ops.device('/cpu:0'): - return input_fn(**kwargs) - def _call_model_fn(self, features, labels, mode): """Calls model function. @@ -589,7 +563,7 @@ class Estimator(object): Raises: ValueError: if model_fn returns invalid objects. """ - model_fn_args = _fn_args(self._model_fn) + model_fn_args = _model_fn_args(self._model_fn) kwargs = {} if 'mode' in model_fn_args: kwargs['mode'] = mode @@ -610,7 +584,8 @@ class Estimator(object): with ops.Graph().as_default() as g, g.device(self._device_fn): random_seed.set_random_seed(self._config.tf_random_seed) global_step_tensor = self._create_and_assert_global_step(g) - features, labels = self._call_input_fn(input_fn) + with ops.device('/cpu:0'): + features, labels = input_fn() estimator_spec = self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN) ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss) @@ -691,7 +666,7 @@ class Estimator(object): with ops.Graph().as_default() as g: random_seed.set_random_seed(self._config.tf_random_seed) global_step_tensor = self._create_and_assert_global_step(g) - features, labels = self._call_input_fn(input_fn) + features, labels = input_fn() estimator_spec = self._call_model_fn( features, labels, model_fn_lib.ModeKeys.EVAL) @@ -774,7 +749,7 @@ def _get_replica_device_setter(config): return None -def _fn_args(fn): +def _model_fn_args(fn): """Get argument names for function-like object. Args: @@ -799,7 +774,7 @@ def _fn_args(fn): def _verify_model_fn_args(model_fn, params): """Verifies model fn arguments.""" - args = set(_fn_args(model_fn)) + args = set(_model_fn_args(model_fn)) if 'features' not in args: raise ValueError('model_fn (%s) must include features argument.' % model_fn) if 'labels' not in args: diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 338239ca23..b86afece43 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -120,9 +120,6 @@ class EstimatorInheritanceConstraintTest(test.TestCase): def __init__(self): super(_Estimator, self).__init__(model_fn=dummy_model_fn) - def _call_input_fn(self, input_fn): - return input_fn() - def _create_global_step(self, graph): pass @@ -328,48 +325,6 @@ def _make_input_fn(features, labels): class EstimatorTrainTest(test.TestCase): - def test_bad_input_fn_args(self): - expected_params = {'batch_size': 10} - expected_config = run_config.RunConfig().replace(tf_random_seed=4321) - - def _model_fn(features, labels, mode, params, config): - del params, config - return model_fn_global_step_incrementer(features, labels, mode) - - def _input_fn(params, config, not_allowed): - del not_allowed - self.assertEqual(expected_params, params) - self.assertEqual(4321, config.tf_random_seed) - return dummy_input_fn() - - est = estimator.Estimator(model_fn=_model_fn, - params=expected_params, - config=expected_config) - with self.assertRaisesRegexp(ValueError, 'should not include argument'): - est.train(_input_fn, steps=1) - - def test_input_fn_args(self): - expected_params = {'batch_size': 10} - expected_config = run_config.RunConfig().replace(tf_random_seed=4321) - input_fn_call_count = [0] - - def _model_fn(features, labels, mode, params, config): - del params, config - return model_fn_global_step_incrementer(features, labels, mode) - - def _input_fn(params, config): - input_fn_call_count[0] += 1 - self.assertEqual(expected_params, params) - self.assertEqual(4321, config.tf_random_seed) - return dummy_input_fn() - - est = estimator.Estimator(model_fn=_model_fn, - params=expected_params, - config=expected_config) - self.assertEqual(0, input_fn_call_count[0]) - est.train(_input_fn, steps=1) - self.assertEqual(1, input_fn_call_count[0]) - def test_minimal_model_fn_args(self): expected_features = {'x': 42., 'y': 43.} expected_labels = 44. @@ -710,29 +665,6 @@ class _StepCounterHook(session_run_hook.SessionRunHook): class EstimatorEvaluateTest(test.TestCase): - def test_input_fn_args(self): - expected_params = {'batch_size': 10} - expected_config = run_config.RunConfig().replace(tf_random_seed=4321) - input_fn_call_count = [0] - - def _model_fn(features, labels, mode, params, config): - del params, config - return model_fn_global_step_incrementer(features, labels, mode) - - def _input_fn(params, config): - input_fn_call_count[0] += 1 - self.assertEqual(expected_params, params) - self.assertEqual(4321, config.tf_random_seed) - return dummy_input_fn() - - est = estimator.Estimator(model_fn=_model_fn, - params=expected_params, - config=expected_config) - est.train(dummy_input_fn, steps=1) - self.assertEqual(0, input_fn_call_count[0]) - est.evaluate(_input_fn, steps=1) - self.assertEqual(1, input_fn_call_count[0]) - def test_model_fn_must_return_estimator_spec(self): def _model_fn(features, labels, mode): _, _ = features, labels @@ -934,33 +866,6 @@ class EstimatorEvaluateTest(test.TestCase): class EstimatorPredictTest(test.TestCase): - def test_input_fn_args(self): - expected_params = {'batch_size': 10} - expected_config = run_config.RunConfig().replace(tf_random_seed=4321) - input_fn_call_count = [0] - - def _model_fn(features, labels, mode, params, config): - del features, labels, params, config - return model_fn_lib.EstimatorSpec( - mode, - loss=constant_op.constant(0.), - train_op=state_ops.assign_add(training.get_global_step(), 1), - predictions=constant_op.constant([[10.]])) - - def _input_fn(params, config): - input_fn_call_count[0] += 1 - self.assertEqual(expected_params, params) - self.assertEqual(4321, config.tf_random_seed) - return dummy_input_fn() - - est = estimator.Estimator(model_fn=_model_fn, - params=expected_params, - config=expected_config) - est.train(dummy_input_fn, steps=1) - self.assertEqual(0, input_fn_call_count[0]) - next(est.predict(_input_fn)) - self.assertEqual(1, input_fn_call_count[0]) - def test_no_trained_model_in_model_dir(self): est = estimator.Estimator(model_fn=model_fn_global_step_incrementer) with self.assertRaisesRegexp(ValueError, |