aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/canned
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-25 07:03:07 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-25 07:10:26 +0800
commitf8ee9799e6a72d4fe24f9fad76d6e6b1b3a01af1 (patch)
tree81269f73be522b541a1a13afbda83b7fb203d2ba /tensorflow/python/estimator/canned
parent407a64b773f15bfe67a2b5b1979134368464b6ff (diff)
ENH: raise exception if unsupported features/columns are given
Diffstat (limited to 'tensorflow/python/estimator/canned')
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py9
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py97
2 files changed, 63 insertions, 43 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index f2a5b9178b..66784fad0c 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -204,6 +204,9 @@ def _generate_feature_name_mapping(sorted_feature_columns):
Returns:
feature_name_mapping: a list of feature names indexed by the feature ids.
+
+ Raises:
+ ValueError: when unsupported features/columns are tried.
"""
names = []
for column in sorted_feature_columns:
@@ -221,8 +224,12 @@ def _generate_feature_name_mapping(sorted_feature_columns):
else:
for num in range(categorical_column._num_buckets): # pylint:disable=protected-access
names.append('{}:{}'.format(column.name, num))
- else:
+ elif isinstance(column, feature_column_lib._BucketizedColumn):
names.append(column.name)
+ else:
+ raise ValueError(
+ 'For now, only bucketized_column and indicator_column is supported '
+ 'but got: {}'.format(column))
return names
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 7620f73425..14c05e024d 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -892,6 +892,49 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
'all empty or contain only a root node'):
est.experimental_feature_importances(normalize=True)
+ def testNegativeFeatureImportances(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ # In order to generate negative feature importances,
+ # we assign an invalid value -1 to tree_weights here.
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ tree_weights: -1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ with self.assertRaisesRegexp(AssertionError, 'non-negative'):
+ est.experimental_feature_importances(normalize=False)
+
+ with self.assertRaisesRegexp(AssertionError, 'non-negative'):
+ est.experimental_feature_importances(normalize=True)
+
def testFeatureImportancesNamesForCategoricalColumn(self):
categorical = feature_column.categorical_column_with_vocabulary_list(
key='categorical', vocabulary_list=('bad', 'good', 'ok'))
@@ -1015,48 +1058,18 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertAllEqual(feature_names_expected, feature_names)
self.assertAllClose([0.5, 0.2, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0], importances)
- def testNegativeFeatureImportances(self):
- est = boosted_trees.BoostedTreesClassifier(
- feature_columns=self._feature_columns,
- n_batches_per_layer=1,
- n_trees=1,
- max_depth=5)
-
- # In order to generate a negative feature importances,
- # We assign an invalid value -1 to tree_weights here.
- tree_ensemble_text = """
- trees {
- nodes {
- bucketized_split {
- feature_id: 1
- left_id: 1
- right_id: 2
- }
- metadata {
- gain: 5.0
- }
- }
- nodes {
- leaf {
- scalar: -0.34
- }
- }
- nodes {
- leaf {
- scalar: 1.34
- }
- }
- }
- tree_weights: -1.0
- """
- self._create_fake_checkpoint_with_tree_ensemble_proto(
- est, tree_ensemble_text)
-
- with self.assertRaisesRegexp(AssertionError, 'non-negative'):
- est.experimental_feature_importances(normalize=False)
-
- with self.assertRaisesRegexp(AssertionError, 'non-negative'):
- est.experimental_feature_importances(normalize=True)
+ def testFeatureImportancesNamesForUnsupportedColumn(self):
+ numeric_col = feature_column.numeric_column(
+ 'continuous', dtype=dtypes.float32)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'only bucketized_column and indicator_column'):
+ _ = boosted_trees.BoostedTreesRegressor(
+ feature_columns=[numeric_col],
+ n_batches_per_layer=1,
+ n_trees=2,
+ learning_rate=1.0,
+ max_depth=1)
class ModelFnTests(test_util.TensorFlowTestCase):