aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2017-06-20 16:18:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-20 16:22:02 -0700
commit5856f9ea6d84d041b882d19d3b357104c17563fc (patch)
tree7851650d989018720c09cfb4807e53631c9a6f82
parent35af7113de0f15360246234f76e5dda5e927c556 (diff)
Automated g4 rollback of changelist 159583264
PiperOrigin-RevId: 159630408
-rw-r--r--tensorflow/python/estimator/estimator.py41
-rw-r--r--tensorflow/python/estimator/estimator_test.py95
2 files changed, 8 insertions, 128 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index ff44914779..ab49f36e5e 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -357,7 +357,7 @@ class Estimator(object):
}
def _assert_members_are_not_overridden(self):
- allowed_overrides = set(['_call_input_fn', '_create_global_step'])
+ allowed_overrides = set(['_create_global_step'])
estimator_members = set([m for m in Estimator.__dict__.keys()
if not m.startswith('__')])
subclass_members = set(self.__class__.__dict__.keys())
@@ -485,7 +485,7 @@ class Estimator(object):
return export_dir
def _get_features_from_input_fn(self, input_fn):
- result = self._call_input_fn(input_fn)
+ result = input_fn()
if not ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
logging.warning('Input graph does not contain a QueueRunner. '
'That means predict yields forever. '
@@ -549,32 +549,6 @@ class Estimator(object):
assert step.dtype.is_integer
return step
- def _call_input_fn(self, input_fn):
- """Calls the input function.
-
- Args:
- input_fn: The input function.
-
- Returns:
- Either features or (features, labels) where features and labels are:
- features - `Tensor` or dictionary of string feature name to `Tensor`.
- labels - `Tensor` or dictionary of `Tensor` with labels.
-
- Raises:
- ValueError: if input_fn takes invalid arguments.
- """
- input_fn_args = _fn_args(input_fn)
- for arg in input_fn_args:
- if arg not in ('config', 'params'):
- raise ValueError('input_fn should not include argument {}.'.format(arg))
- kwargs = {}
- if 'params' in input_fn_args:
- kwargs['params'] = self.params
- if 'config' in input_fn_args:
- kwargs['config'] = self.config
- with ops.device('/cpu:0'):
- return input_fn(**kwargs)
-
def _call_model_fn(self, features, labels, mode):
"""Calls model function.
@@ -589,7 +563,7 @@ class Estimator(object):
Raises:
ValueError: if model_fn returns invalid objects.
"""
- model_fn_args = _fn_args(self._model_fn)
+ model_fn_args = _model_fn_args(self._model_fn)
kwargs = {}
if 'mode' in model_fn_args:
kwargs['mode'] = mode
@@ -610,7 +584,8 @@ class Estimator(object):
with ops.Graph().as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step_tensor = self._create_and_assert_global_step(g)
- features, labels = self._call_input_fn(input_fn)
+ with ops.device('/cpu:0'):
+ features, labels = input_fn()
estimator_spec = self._call_model_fn(features, labels,
model_fn_lib.ModeKeys.TRAIN)
ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
@@ -691,7 +666,7 @@ class Estimator(object):
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
global_step_tensor = self._create_and_assert_global_step(g)
- features, labels = self._call_input_fn(input_fn)
+ features, labels = input_fn()
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.EVAL)
@@ -774,7 +749,7 @@ def _get_replica_device_setter(config):
return None
-def _fn_args(fn):
+def _model_fn_args(fn):
"""Get argument names for function-like object.
Args:
@@ -799,7 +774,7 @@ def _fn_args(fn):
def _verify_model_fn_args(model_fn, params):
"""Verifies model fn arguments."""
- args = set(_fn_args(model_fn))
+ args = set(_model_fn_args(model_fn))
if 'features' not in args:
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
if 'labels' not in args:
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 338239ca23..b86afece43 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -120,9 +120,6 @@ class EstimatorInheritanceConstraintTest(test.TestCase):
def __init__(self):
super(_Estimator, self).__init__(model_fn=dummy_model_fn)
- def _call_input_fn(self, input_fn):
- return input_fn()
-
def _create_global_step(self, graph):
pass
@@ -328,48 +325,6 @@ def _make_input_fn(features, labels):
class EstimatorTrainTest(test.TestCase):
- def test_bad_input_fn_args(self):
- expected_params = {'batch_size': 10}
- expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
-
- def _model_fn(features, labels, mode, params, config):
- del params, config
- return model_fn_global_step_incrementer(features, labels, mode)
-
- def _input_fn(params, config, not_allowed):
- del not_allowed
- self.assertEqual(expected_params, params)
- self.assertEqual(4321, config.tf_random_seed)
- return dummy_input_fn()
-
- est = estimator.Estimator(model_fn=_model_fn,
- params=expected_params,
- config=expected_config)
- with self.assertRaisesRegexp(ValueError, 'should not include argument'):
- est.train(_input_fn, steps=1)
-
- def test_input_fn_args(self):
- expected_params = {'batch_size': 10}
- expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
- input_fn_call_count = [0]
-
- def _model_fn(features, labels, mode, params, config):
- del params, config
- return model_fn_global_step_incrementer(features, labels, mode)
-
- def _input_fn(params, config):
- input_fn_call_count[0] += 1
- self.assertEqual(expected_params, params)
- self.assertEqual(4321, config.tf_random_seed)
- return dummy_input_fn()
-
- est = estimator.Estimator(model_fn=_model_fn,
- params=expected_params,
- config=expected_config)
- self.assertEqual(0, input_fn_call_count[0])
- est.train(_input_fn, steps=1)
- self.assertEqual(1, input_fn_call_count[0])
-
def test_minimal_model_fn_args(self):
expected_features = {'x': 42., 'y': 43.}
expected_labels = 44.
@@ -710,29 +665,6 @@ class _StepCounterHook(session_run_hook.SessionRunHook):
class EstimatorEvaluateTest(test.TestCase):
- def test_input_fn_args(self):
- expected_params = {'batch_size': 10}
- expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
- input_fn_call_count = [0]
-
- def _model_fn(features, labels, mode, params, config):
- del params, config
- return model_fn_global_step_incrementer(features, labels, mode)
-
- def _input_fn(params, config):
- input_fn_call_count[0] += 1
- self.assertEqual(expected_params, params)
- self.assertEqual(4321, config.tf_random_seed)
- return dummy_input_fn()
-
- est = estimator.Estimator(model_fn=_model_fn,
- params=expected_params,
- config=expected_config)
- est.train(dummy_input_fn, steps=1)
- self.assertEqual(0, input_fn_call_count[0])
- est.evaluate(_input_fn, steps=1)
- self.assertEqual(1, input_fn_call_count[0])
-
def test_model_fn_must_return_estimator_spec(self):
def _model_fn(features, labels, mode):
_, _ = features, labels
@@ -934,33 +866,6 @@ class EstimatorEvaluateTest(test.TestCase):
class EstimatorPredictTest(test.TestCase):
- def test_input_fn_args(self):
- expected_params = {'batch_size': 10}
- expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
- input_fn_call_count = [0]
-
- def _model_fn(features, labels, mode, params, config):
- del features, labels, params, config
- return model_fn_lib.EstimatorSpec(
- mode,
- loss=constant_op.constant(0.),
- train_op=state_ops.assign_add(training.get_global_step(), 1),
- predictions=constant_op.constant([[10.]]))
-
- def _input_fn(params, config):
- input_fn_call_count[0] += 1
- self.assertEqual(expected_params, params)
- self.assertEqual(4321, config.tf_random_seed)
- return dummy_input_fn()
-
- est = estimator.Estimator(model_fn=_model_fn,
- params=expected_params,
- config=expected_config)
- est.train(dummy_input_fn, steps=1)
- self.assertEqual(0, input_fn_call_count[0])
- next(est.predict(_input_fn))
- self.assertEqual(1, input_fn_call_count[0])
-
def test_no_trained_model_in_model_dir(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
with self.assertRaisesRegexp(ValueError,