diff options
author | Zakaria Haque <zakaria@google.com> | 2017-03-17 09:09:44 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-17 10:47:14 -0700 |
commit | bb7cadbc5a886fded011a65776a42ea2b37498c3 (patch) | |
tree | 1186d8346f510eeb8820a9d9da4015a8da8c36be | |
parent | 5a95c76c8e0c8eb6e84707c484341cac14e989c9 (diff) |
Removes an unnecessary check that blocks using multihead with custom heads.
Change: 150452316
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head_test.py | 2 |
2 files changed, 0 insertions, 6 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 6d62ff95d7..e3092ca800 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -1441,10 +1441,6 @@ class _MultiHead(Head): """ self._logits_dimension = 0 for head in heads: - # TODO(ptucker): Change this, and add head_name to MultiHead, to support - # nested MultiHeads. - if not isinstance(head, _SingleHead): - raise ValueError("Members of MultiHead must be SingleHead.") if not head.head_name: raise ValueError("Members of MultiHead must have names.") self._logits_dimension += head.logits_dimension diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index a9bd011ac0..8bb6012887 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -1380,8 +1380,6 @@ class MultiHeadTest(test.TestCase): n_classes=4, label_name="label") with self.assertRaisesRegexp(ValueError, "must have names"): head_lib.multi_head((named_head, unnamed_head)) - with self.assertRaisesRegexp(ValueError, "must be SingleHead"): - head_lib.multi_head((named_head, head_lib.multi_head((named_head,)))) def testTrainWithNoneTrainOpFn(self): head1 = head_lib.multi_class_head( |