diff options
author | Mustafa Ispir <ispir@google.com> | 2017-05-04 08:56:01 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-04 10:15:42 -0700 |
commit | 641c9824d4c08b5d7c6ae4c3f26b0607f0dea619 (patch) | |
tree | 8e8da0f09a0d50841ee86a1a3c50b9c8758b0686 | |
parent | 65044bc25981e4e060ad5c34d9a520a0561775c3 (diff) |
Make contrib real_valued_column cross compatible with core feature_column builders.
Change: 155090692
-rw-r--r-- | tensorflow/contrib/layers/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/feature_column.py | 28 | ||||
-rw-r--r-- | tensorflow/contrib/layers/python/layers/feature_column_ops_test.py | 17 |
3 files changed, 43 insertions, 3 deletions
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index aba8eabe10..fe661a5625 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -108,6 +108,7 @@ tf_custom_op_py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/feature_column", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index d6d5bf2294..04fe2370d1 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -136,8 +136,10 @@ from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.layers.python.ops import bucketization_op from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops +from tensorflow.python.feature_column import feature_column as fc_core from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor as sparse_tensor_py +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -1497,9 +1499,12 @@ def _real_valued_var_len_column(column_name, is_sparse) -class _RealValuedColumn(_FeatureColumn, collections.namedtuple( - "_RealValuedColumn", - ["column_name", "dimension", "default_value", "dtype", "normalizer"])): +class _RealValuedColumn( + _FeatureColumn, + fc_core._DenseColumn, # pylint: disable=protected-access + collections.namedtuple( + "_RealValuedColumn", + ["column_name", "dimension", "default_value", "dtype", "normalizer"])): """Represents a real valued feature column also known as continuous features. Instances of this class are immutable. The dictionary returned by InputBuilder @@ -1569,6 +1574,23 @@ class _RealValuedColumn(_FeatureColumn, collections.namedtuple( def _to_dense_tensor(self, input_tensor): return input_tensor + @property + def _variable_shape(self): + return tensor_shape.TensorShape((self.dimension)) + + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + del weight_collections + del trainable + return inputs.get(self) + + def _transform_feature(self, inputs): + return math_ops.to_float( + self._normalized_input_tensor(inputs.get(self.name))) + + @property + def _parse_example_config(self): + return self.config + def real_valued_column(column_name, dimension=1, diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index 632836fee4..b2dad0162e 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -27,6 +27,7 @@ from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.feature_column import feature_column as fc_core from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -610,6 +611,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval()) + # Verify cross compatibility: Core builder output should equal to contrib. + self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testRealValuedColumnWithMultiDimensions(self): real_valued = feature_column.real_valued_column("price", 2) @@ -620,6 +625,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval()) + # Verify cross compatibility: Core builder output should equal to contrib. + self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testRealValuedColumnSparse(self): sparse_real_valued = feature_column._real_valued_var_len_column( @@ -640,6 +649,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval() - 2) + # Verify cross compatibility: Core builder output should equal to contrib. + self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testRealValuedColumnWithMultiDimensionsAndNormalizer(self): real_valued = feature_column.real_valued_column( @@ -651,6 +664,10 @@ class CreateInputLayersForDNNsTest(test.TestCase): [real_valued]) with self.test_session(): self.assertAllClose(output.eval(), features["price"].eval() - 2) + # Verify cross compatibility: Core builder output should equal to contrib. + self.assertAllClose(output.eval(), + fc_core.make_input_layer(features, + [real_valued]).eval()) def testBucketizedColumnWithNormalizerSucceedsForDNN(self): bucket = feature_column.bucketized_column( |