From bb7cadbc5a886fded011a65776a42ea2b37498c3 Mon Sep 17 00:00:00 2001 From: Zakaria Haque Date: Fri, 17 Mar 2017 09:09:44 -0800 Subject: Removes an unnecessary check that blocks using multihead with custom heads. Change: 150452316 --- tensorflow/contrib/learn/python/learn/estimators/head.py | 4 ---- tensorflow/contrib/learn/python/learn/estimators/head_test.py | 2 -- 2 files changed, 6 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 6d62ff95d7..e3092ca800 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -1441,10 +1441,6 @@ class _MultiHead(Head): """ self._logits_dimension = 0 for head in heads: - # TODO(ptucker): Change this, and add head_name to MultiHead, to support - # nested MultiHeads. - if not isinstance(head, _SingleHead): - raise ValueError("Members of MultiHead must be SingleHead.") if not head.head_name: raise ValueError("Members of MultiHead must have names.") self._logits_dimension += head.logits_dimension diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index a9bd011ac0..8bb6012887 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -1380,8 +1380,6 @@ class MultiHeadTest(test.TestCase): n_classes=4, label_name="label") with self.assertRaisesRegexp(ValueError, "must have names"): head_lib.multi_head((named_head, unnamed_head)) - with self.assertRaisesRegexp(ValueError, "must be SingleHead"): - head_lib.multi_head((named_head, head_lib.multi_head((named_head,)))) def testTrainWithNoneTrainOpFn(self): head1 = head_lib.multi_class_head( -- cgit v1.2.3