From 509898b471efdd7c66b3d5b8edc7e3b5be4adb6f Mon Sep 17 00:00:00 2001 From: Mustafa Ispir Date: Thu, 2 Mar 2017 16:15:48 -0800 Subject: Error out if a subclass of Estimator overrides a member of Estimator. Change: 149060568 --- tensorflow/python/estimator/estimator.py | 17 +++++++++++++- tensorflow/python/estimator/estimator_test.py | 32 +++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 5d4472556c..a52d3fb2e0 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -70,7 +70,7 @@ class Estimator(object): inspect it. The structure of params is therefore entirely up to the developer. """ - def __init__(self, model_fn=None, model_dir=None, config=None, params=None): + def __init__(self, model_fn, model_dir=None, config=None, params=None): """Constructs an `Estimator` instance. Args: @@ -109,7 +109,10 @@ class Estimator(object): Raises: ValueError: parameters of `model_fn` don't match `params`. + ValueError: if this is called via a subclass and if that class overrides + a member of `Estimator`. """ + self._assert_members_are_not_overridden() # Model directory. self._model_dir = model_dir if self._model_dir is None: @@ -312,6 +315,18 @@ class Estimator(object): for key, value in six.iteritems(preds_evaluated) } + def _assert_members_are_not_overridden(self): + estimator_members = set([m for m in Estimator.__dict__.keys() + if not m.startswith('__')]) + subclass_members = set(self.__class__.__dict__.keys()) + common_members = estimator_members & subclass_members + overriden_members = [m for m in common_members + if Estimator.__dict__[m] != self.__class__.__dict__[m]] + if overriden_members: + raise ValueError( + 'Subclasses of Estimator cannot override members of Estimator. ' + '{} does override {}'.format(self.__class__, overriden_members)) + def _get_features_from_input_fn(self, input_fn): result = input_fn() if not ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 1aadba0743..9748be6da5 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -33,6 +33,38 @@ from tensorflow.python.training import session_run_hook from tensorflow.python.training import training +def dummy_model_fn(features, labels, params): + _, _, _ = features, labels, params + + +class EstimatorInheritanceConstraintTest(test.TestCase): + """Tests that sub classes cannot override methods of Estimator.""" + + def test_override_a_method(self): + class _Estimator(estimator.Estimator): + + def __init__(self): + super(_Estimator, self).__init__(model_fn=dummy_model_fn) + + def predict(self, input_fn, predict_keys=None, hooks=None): + pass + + with self.assertRaisesRegexp( + ValueError, 'cannot override members of Estimator.*predict'): + _Estimator() + + def test_extension_of_api_is_ok(self): + class _Estimator(estimator.Estimator): + + def __init__(self): + super(_Estimator, self).__init__(model_fn=dummy_model_fn) + + def predict_proba(self, input_fn, predict_keys=None, hooks=None): + pass + + _Estimator() + + class EstimatorConstructorTest(test.TestCase): def test_config_must_be_a_run_config(self): -- cgit v1.2.3