aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-08 16:47:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-08 17:01:56 -0800
commitbb829c60f578ab43438497deea4113e4446a67da (patch)
tree84eabf66606e51fa9cb20d68d89a8f8e3804e949
parent265e61d0d6ad5a003d3ed13ab15fe7a8155a5e45 (diff)
Cleanup: Consolidate specification of what constitutes a crossable column into
a single function. Change: 138582455
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index d9259dbaa1..ec94f7ace9 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -1544,9 +1544,13 @@ class _CrossedColumn(_FeatureColumn,
"""
@staticmethod
- def _is_crossable(column):
- return isinstance(column,
- (_SparseColumn, _CrossedColumn, _BucketizedColumn))
+ def _assert_is_crossable(column):
+ if isinstance(column, (_SparseColumn, _CrossedColumn, _BucketizedColumn)):
+ return
+ raise TypeError("columns must be a set of _SparseColumn, "
+ "_CrossedColumn, or _BucketizedColumn instances. "
+ "(column {} is a {})".format(column,
+ column.__class__.__name__))
def __new__(cls,
columns,
@@ -1556,10 +1560,7 @@ class _CrossedColumn(_FeatureColumn,
ckpt_to_load_from=None,
tensor_name_in_ckpt=None):
for column in columns:
- if not _CrossedColumn._is_crossable(column):
- raise TypeError("columns must be a set of _SparseColumn, "
- "_CrossedColumn, or _BucketizedColumn instances. "
- "column: {}".format(column))
+ _CrossedColumn._assert_is_crossable(column)
if len(columns) < 2:
raise ValueError("columns must contain at least 2 elements. "