From d3cb8bb8fca4402a053db98f86fb5540690f7430 Mon Sep 17 00:00:00 2001 From: Mustafa Ispir Date: Fri, 9 Jun 2017 12:04:50 -0700 Subject: Add string weight-column support to parsing spec util. Update the doc. PiperOrigin-RevId: 158545612 --- tensorflow/python/estimator/canned/parsing_utils.py | 18 ++++++++++++++---- .../python/estimator/canned/parsing_utils_test.py | 17 +++++++++++++++-- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/estimator/canned/parsing_utils.py b/tensorflow/python/estimator/canned/parsing_utils.py index 82d9a15444..af584965bb 100644 --- a/tensorflow/python/estimator/canned/parsing_utils.py +++ b/tensorflow/python/estimator/canned/parsing_utils.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six + from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import dtypes from tensorflow.python.ops import parsing_ops @@ -70,24 +72,29 @@ def classifier_parse_example_spec(feature_columns, estimator = DNNClassifier( n_classes=1000, feature_columns=feature_columns, - label_key='my-label', + weight_column='example-weight', label_vocabulary=['photos', 'keep', ...], hidden_units=[256, 64, 16]) # This label configuration tells the classifier the following: - # * label is retrieved with key 'my-label'. + # * weights are retrieved with key 'example-weight' # * label is string and can be one of the following ['photos', 'keep', ...] # * integer id for label 'photos' is 0, 'keep' is 1, ... # Input builders def input_fn_train(): # Returns a dictionary which also contains labels. - return tf.contrib.learn.read_keyed_batch_features( + features = tf.contrib.learn.read_keyed_batch_features( file_pattern=train_files, batch_size=batch_size, # creates parsing configuration for tf.parse_example features=tf.estimator.classifier_parse_example_spec( - feature_columns, label_key='my-label', label_dtype=tf.string), + feature_columns, + label_key='my-label', + label_dtype=tf.string, + weight_column='example-weight'), reader=tf.RecordIOReader) + labels = features.pop('my-label') + return features, labels estimator.train(input_fn=input_fn_train) ``` @@ -135,6 +142,9 @@ def classifier_parse_example_spec(feature_columns, if weight_column is None: return parsing_spec + if isinstance(weight_column, six.string_types): + weight_column = fc.numeric_column(weight_column) + if not isinstance(weight_column, fc._NumericColumn): # pylint: disable=protected-access raise ValueError('weight_column should be an instance of ' 'tf.feature_column.numeric_column. ' diff --git a/tensorflow/python/estimator/canned/parsing_utils_test.py b/tensorflow/python/estimator/canned/parsing_utils_test.py index 607c47912b..c83823e750 100644 --- a/tensorflow/python/estimator/canned/parsing_utils_test.py +++ b/tensorflow/python/estimator/canned/parsing_utils_test.py @@ -63,7 +63,19 @@ class ClassifierParseExampleSpec(test.TestCase): } self.assertDictEqual(expected_spec, parsing_spec) - def test_weight_column(self): + def test_weight_column_as_string(self): + parsing_spec = parsing_utils.classifier_parse_example_spec( + feature_columns=[fc.numeric_column('a')], + label_key='b', + weight_column='c') + expected_spec = { + 'a': parsing_ops.FixedLenFeature((1,), dtype=dtypes.float32), + 'b': parsing_ops.FixedLenFeature((1,), dtype=dtypes.int64), + 'c': parsing_ops.FixedLenFeature((1,), dtype=dtypes.float32), + } + self.assertDictEqual(expected_spec, parsing_spec) + + def test_weight_column_as_numeric_column(self): parsing_spec = parsing_utils.classifier_parse_example_spec( feature_columns=[fc.numeric_column('a')], label_key='b', @@ -92,10 +104,11 @@ class ClassifierParseExampleSpec(test.TestCase): def test_weight_column_should_be_a_numeric_column(self): with self.assertRaisesRegexp(ValueError, 'tf.feature_column.numeric_column'): + not_a_numeric_column = 3 parsing_utils.classifier_parse_example_spec( feature_columns=[fc.numeric_column('a')], label_key='b', - weight_column='NotANumericColumn') + weight_column=not_a_numeric_column) if __name__ == '__main__': -- cgit v1.2.3