aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-03-02 16:15:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-02 16:26:28 -0800
commit509898b471efdd7c66b3d5b8edc7e3b5be4adb6f (patch)
tree36dc3661ae1d48a46fc381b7b1e3b64263bfc5a7
parent06fd46a9c0faa054629b2747ca31ede86ce41deb (diff)
Error out if a subclass of Estimator overrides a member of Estimator.
Change: 149060568
-rw-r--r--tensorflow/python/estimator/estimator.py17
-rw-r--r--tensorflow/python/estimator/estimator_test.py32
2 files changed, 48 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 5d4472556c..a52d3fb2e0 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -70,7 +70,7 @@ class Estimator(object):
inspect it. The structure of params is therefore entirely up to the developer.
"""
- def __init__(self, model_fn=None, model_dir=None, config=None, params=None):
+ def __init__(self, model_fn, model_dir=None, config=None, params=None):
"""Constructs an `Estimator` instance.
Args:
@@ -109,7 +109,10 @@ class Estimator(object):
Raises:
ValueError: parameters of `model_fn` don't match `params`.
+ ValueError: if this is called via a subclass and if that class overrides
+ a member of `Estimator`.
"""
+ self._assert_members_are_not_overridden()
# Model directory.
self._model_dir = model_dir
if self._model_dir is None:
@@ -312,6 +315,18 @@ class Estimator(object):
for key, value in six.iteritems(preds_evaluated)
}
+ def _assert_members_are_not_overridden(self):
+ estimator_members = set([m for m in Estimator.__dict__.keys()
+ if not m.startswith('__')])
+ subclass_members = set(self.__class__.__dict__.keys())
+ common_members = estimator_members & subclass_members
+ overriden_members = [m for m in common_members
+ if Estimator.__dict__[m] != self.__class__.__dict__[m]]
+ if overriden_members:
+ raise ValueError(
+ 'Subclasses of Estimator cannot override members of Estimator. '
+ '{} does override {}'.format(self.__class__, overriden_members))
+
def _get_features_from_input_fn(self, input_fn):
result = input_fn()
if not ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 1aadba0743..9748be6da5 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -33,6 +33,38 @@ from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
+def dummy_model_fn(features, labels, params):
+ _, _, _ = features, labels, params
+
+
+class EstimatorInheritanceConstraintTest(test.TestCase):
+ """Tests that sub classes cannot override methods of Estimator."""
+
+ def test_override_a_method(self):
+ class _Estimator(estimator.Estimator):
+
+ def __init__(self):
+ super(_Estimator, self).__init__(model_fn=dummy_model_fn)
+
+ def predict(self, input_fn, predict_keys=None, hooks=None):
+ pass
+
+ with self.assertRaisesRegexp(
+ ValueError, 'cannot override members of Estimator.*predict'):
+ _Estimator()
+
+ def test_extension_of_api_is_ok(self):
+ class _Estimator(estimator.Estimator):
+
+ def __init__(self):
+ super(_Estimator, self).__init__(model_fn=dummy_model_fn)
+
+ def predict_proba(self, input_fn, predict_keys=None, hooks=None):
+ pass
+
+ _Estimator()
+
+
class EstimatorConstructorTest(test.TestCase):
def test_config_must_be_a_run_config(self):