aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-06-09 12:04:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-09 12:08:38 -0700
commitd3cb8bb8fca4402a053db98f86fb5540690f7430 (patch)
tree48133398d625a481dcb4060b1aec09de6bccf57c
parent1700ac827237992143144a5763a72d56b2da7127 (diff)
Add string weight-column support to parsing spec util.
Update the doc. PiperOrigin-RevId: 158545612
-rw-r--r--tensorflow/python/estimator/canned/parsing_utils.py18
-rw-r--r--tensorflow/python/estimator/canned/parsing_utils_test.py17
2 files changed, 29 insertions, 6 deletions
diff --git a/tensorflow/python/estimator/canned/parsing_utils.py b/tensorflow/python/estimator/canned/parsing_utils.py
index 82d9a15444..af584965bb 100644
--- a/tensorflow/python/estimator/canned/parsing_utils.py
+++ b/tensorflow/python/estimator/canned/parsing_utils.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import six
+
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import parsing_ops
@@ -70,24 +72,29 @@ def classifier_parse_example_spec(feature_columns,
estimator = DNNClassifier(
n_classes=1000,
feature_columns=feature_columns,
- label_key='my-label',
+ weight_column='example-weight',
label_vocabulary=['photos', 'keep', ...],
hidden_units=[256, 64, 16])
# This label configuration tells the classifier the following:
- # * label is retrieved with key 'my-label'.
+ # * weights are retrieved with key 'example-weight'
# * label is string and can be one of the following ['photos', 'keep', ...]
# * integer id for label 'photos' is 0, 'keep' is 1, ...
# Input builders
def input_fn_train(): # Returns a dictionary which also contains labels.
- return tf.contrib.learn.read_keyed_batch_features(
+ features = tf.contrib.learn.read_keyed_batch_features(
file_pattern=train_files,
batch_size=batch_size,
# creates parsing configuration for tf.parse_example
features=tf.estimator.classifier_parse_example_spec(
- feature_columns, label_key='my-label', label_dtype=tf.string),
+ feature_columns,
+ label_key='my-label',
+ label_dtype=tf.string,
+ weight_column='example-weight'),
reader=tf.RecordIOReader)
+ labels = features.pop('my-label')
+ return features, labels
estimator.train(input_fn=input_fn_train)
```
@@ -135,6 +142,9 @@ def classifier_parse_example_spec(feature_columns,
if weight_column is None:
return parsing_spec
+ if isinstance(weight_column, six.string_types):
+ weight_column = fc.numeric_column(weight_column)
+
if not isinstance(weight_column, fc._NumericColumn): # pylint: disable=protected-access
raise ValueError('weight_column should be an instance of '
'tf.feature_column.numeric_column. '
diff --git a/tensorflow/python/estimator/canned/parsing_utils_test.py b/tensorflow/python/estimator/canned/parsing_utils_test.py
index 607c47912b..c83823e750 100644
--- a/tensorflow/python/estimator/canned/parsing_utils_test.py
+++ b/tensorflow/python/estimator/canned/parsing_utils_test.py
@@ -63,7 +63,19 @@ class ClassifierParseExampleSpec(test.TestCase):
}
self.assertDictEqual(expected_spec, parsing_spec)
- def test_weight_column(self):
+ def test_weight_column_as_string(self):
+ parsing_spec = parsing_utils.classifier_parse_example_spec(
+ feature_columns=[fc.numeric_column('a')],
+ label_key='b',
+ weight_column='c')
+ expected_spec = {
+ 'a': parsing_ops.FixedLenFeature((1,), dtype=dtypes.float32),
+ 'b': parsing_ops.FixedLenFeature((1,), dtype=dtypes.int64),
+ 'c': parsing_ops.FixedLenFeature((1,), dtype=dtypes.float32),
+ }
+ self.assertDictEqual(expected_spec, parsing_spec)
+
+ def test_weight_column_as_numeric_column(self):
parsing_spec = parsing_utils.classifier_parse_example_spec(
feature_columns=[fc.numeric_column('a')],
label_key='b',
@@ -92,10 +104,11 @@ class ClassifierParseExampleSpec(test.TestCase):
def test_weight_column_should_be_a_numeric_column(self):
with self.assertRaisesRegexp(ValueError,
'tf.feature_column.numeric_column'):
+ not_a_numeric_column = 3
parsing_utils.classifier_parse_example_spec(
feature_columns=[fc.numeric_column('a')],
label_key='b',
- weight_column='NotANumericColumn')
+ weight_column=not_a_numeric_column)
if __name__ == '__main__':