diff options
author | 2018-08-25 07:03:07 +0800 | |
---|---|---|
committer | 2018-08-25 07:10:26 +0800 | |
commit | f8ee9799e6a72d4fe24f9fad76d6e6b1b3a01af1 (patch) | |
tree | 81269f73be522b541a1a13afbda83b7fb203d2ba /tensorflow/python/estimator/canned | |
parent | 407a64b773f15bfe67a2b5b1979134368464b6ff (diff) |
ENH: raise exception if unsupported features/columns are given
Diffstat (limited to 'tensorflow/python/estimator/canned')
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees.py | 9 | ||||
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees_test.py | 97 |
2 files changed, 63 insertions, 43 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index f2a5b9178b..66784fad0c 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -204,6 +204,9 @@ def _generate_feature_name_mapping(sorted_feature_columns): Returns: feature_name_mapping: a list of feature names indexed by the feature ids. + + Raises: + ValueError: when unsupported features/columns are tried. """ names = [] for column in sorted_feature_columns: @@ -221,8 +224,12 @@ def _generate_feature_name_mapping(sorted_feature_columns): else: for num in range(categorical_column._num_buckets): # pylint:disable=protected-access names.append('{}:{}'.format(column.name, num)) - else: + elif isinstance(column, feature_column_lib._BucketizedColumn): names.append(column.name) + else: + raise ValueError( + 'For now, only bucketized_column and indicator_column is supported ' + 'but got: {}'.format(column)) return names diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py index 7620f73425..14c05e024d 100644 --- a/tensorflow/python/estimator/canned/boosted_trees_test.py +++ b/tensorflow/python/estimator/canned/boosted_trees_test.py @@ -892,6 +892,49 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): 'all empty or contain only a root node'): est.experimental_feature_importances(normalize=True) + def testNegativeFeatureImportances(self): + est = boosted_trees.BoostedTreesClassifier( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5) + + # In order to generate a negative feature importances, + # We assign an invalid value -1 to tree_weights here. 
+ tree_ensemble_text = """ + trees { + nodes { + bucketized_split { + feature_id: 1 + left_id: 1 + right_id: 2 + } + metadata { + gain: 5.0 + } + } + nodes { + leaf { + scalar: -0.34 + } + } + nodes { + leaf { + scalar: 1.34 + } + } + } + tree_weights: -1.0 + """ + self._create_fake_checkpoint_with_tree_ensemble_proto( + est, tree_ensemble_text) + + with self.assertRaisesRegexp(AssertionError, 'non-negative'): + est.experimental_feature_importances(normalize=False) + + with self.assertRaisesRegexp(AssertionError, 'non-negative'): + est.experimental_feature_importances(normalize=True) + def testFeatureImportancesNamesForCategoricalColumn(self): categorical = feature_column.categorical_column_with_vocabulary_list( key='categorical', vocabulary_list=('bad', 'good', 'ok')) @@ -1015,48 +1058,18 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertAllEqual(feature_names_expected, feature_names) self.assertAllClose([0.5, 0.2, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0], importances) - def testNegativeFeatureImportances(self): - est = boosted_trees.BoostedTreesClassifier( - feature_columns=self._feature_columns, - n_batches_per_layer=1, - n_trees=1, - max_depth=5) - - # In order to generate a negative feature importances, - # We assign an invalid value -1 to tree_weights here. 
- tree_ensemble_text = """ - trees { - nodes { - bucketized_split { - feature_id: 1 - left_id: 1 - right_id: 2 - } - metadata { - gain: 5.0 - } - } - nodes { - leaf { - scalar: -0.34 - } - } - nodes { - leaf { - scalar: 1.34 - } - } - } - tree_weights: -1.0 - """ - self._create_fake_checkpoint_with_tree_ensemble_proto( - est, tree_ensemble_text) - - with self.assertRaisesRegexp(AssertionError, 'non-negative'): - est.experimental_feature_importances(normalize=False) - - with self.assertRaisesRegexp(AssertionError, 'non-negative'): - est.experimental_feature_importances(normalize=True) + def testFeatureImportancesNamesForUnsupportedColumn(self): + numeric_col = feature_column.numeric_column( + 'continuous', dtype=dtypes.float32) + + with self.assertRaisesRegexp(ValueError, + 'only bucketized_column and indicator_column'): + _ = boosted_trees.BoostedTreesRegressor( + feature_columns=[numeric_col], + n_batches_per_layer=1, + n_trees=2, + learning_rate=1.0, + max_depth=1) class ModelFnTests(test_util.TensorFlowTestCase): |