aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Zakaria Haque <zakaria@google.com>2017-03-17 09:09:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-17 10:47:14 -0700
commitbb7cadbc5a886fded011a65776a42ea2b37498c3 (patch)
tree1186d8346f510eeb8820a9d9da4015a8da8c36be
parent5a95c76c8e0c8eb6e84707c484341cac14e989c9 (diff)
Removes an unnecessary check that blocks using multihead with custom heads.
Change: 150452316
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head_test.py2
2 files changed, 0 insertions, 6 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 6d62ff95d7..e3092ca800 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -1441,10 +1441,6 @@ class _MultiHead(Head):
"""
self._logits_dimension = 0
for head in heads:
- # TODO(ptucker): Change this, and add head_name to MultiHead, to support
- # nested MultiHeads.
- if not isinstance(head, _SingleHead):
- raise ValueError("Members of MultiHead must be SingleHead.")
if not head.head_name:
raise ValueError("Members of MultiHead must have names.")
self._logits_dimension += head.logits_dimension
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index a9bd011ac0..8bb6012887 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -1380,8 +1380,6 @@ class MultiHeadTest(test.TestCase):
n_classes=4, label_name="label")
with self.assertRaisesRegexp(ValueError, "must have names"):
head_lib.multi_head((named_head, unnamed_head))
- with self.assertRaisesRegexp(ValueError, "must be SingleHead"):
- head_lib.multi_head((named_head, head_lib.multi_head((named_head,))))
def testTrainWithNoneTrainOpFn(self):
head1 = head_lib.multi_class_head(