diff options
author | 2016-11-08 16:47:42 -0800 | |
---|---|---|
committer | 2016-11-08 17:01:56 -0800 | |
commit | bb829c60f578ab43438497deea4113e4446a67da (patch) | |
tree | 84eabf66606e51fa9cb20d68d89a8f8e3804e949 | |
parent | 265e61d0d6ad5a003d3ed13ab15fe7a8155a5e45 (diff) |
Cleanup: Consolidate specification of what constitutes a crossable column into
a single function.
Change: 138582455
-rw-r--r-- | tensorflow/contrib/layers/python/layers/feature_column.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index d9259dbaa1..ec94f7ace9 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -1544,9 +1544,13 @@ class _CrossedColumn(_FeatureColumn, """ @staticmethod - def _is_crossable(column): - return isinstance(column, - (_SparseColumn, _CrossedColumn, _BucketizedColumn)) + def _assert_is_crossable(column): + if isinstance(column, (_SparseColumn, _CrossedColumn, _BucketizedColumn)): + return + raise TypeError("columns must be a set of _SparseColumn, " + "_CrossedColumn, or _BucketizedColumn instances. " + "(column {} is a {})".format(column, + column.__class__.__name__)) def __new__(cls, columns, @@ -1556,10 +1560,7 @@ class _CrossedColumn(_FeatureColumn, ckpt_to_load_from=None, tensor_name_in_ckpt=None): for column in columns: - if not _CrossedColumn._is_crossable(column): - raise TypeError("columns must be a set of _SparseColumn, " - "_CrossedColumn, or _BucketizedColumn instances. " - "column: {}".format(column)) + _CrossedColumn._assert_is_crossable(column) if len(columns) < 2: raise ValueError("columns must contain at least 2 elements. " |