aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-05-04 08:56:01 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-04 10:15:42 -0700
commit641c9824d4c08b5d7c6ae4c3f26b0607f0dea619 (patch)
tree8e8da0f09a0d50841ee86a1a3c50b9c8758b0686
parent65044bc25981e4e060ad5c34d9a520a0561775c3 (diff)
Make contrib real_valued_column cross compatible with core feature_column builders.
Change: 155090692
-rw-r--r--tensorflow/contrib/layers/BUILD1
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column.py28
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_ops_test.py17
3 files changed, 43 insertions, 3 deletions
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index aba8eabe10..fe661a5625 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -108,6 +108,7 @@ tf_custom_op_py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python/feature_column",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index d6d5bf2294..04fe2370d1 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -136,8 +136,10 @@ from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.layers.python.ops import bucketization_op
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
+from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -1497,9 +1499,12 @@ def _real_valued_var_len_column(column_name,
is_sparse)
-class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
- "_RealValuedColumn",
- ["column_name", "dimension", "default_value", "dtype", "normalizer"])):
+class _RealValuedColumn(
+ _FeatureColumn,
+ fc_core._DenseColumn, # pylint: disable=protected-access
+ collections.namedtuple(
+ "_RealValuedColumn",
+ ["column_name", "dimension", "default_value", "dtype", "normalizer"])):
"""Represents a real valued feature column also known as continuous features.
Instances of this class are immutable. The dictionary returned by InputBuilder
@@ -1569,6 +1574,23 @@ class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
def _to_dense_tensor(self, input_tensor):
return input_tensor
+ @property
+ def _variable_shape(self):
+ return tensor_shape.TensorShape((self.dimension))
+
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ return inputs.get(self)
+
+ def _transform_feature(self, inputs):
+ return math_ops.to_float(
+ self._normalized_input_tensor(inputs.get(self.name)))
+
+ @property
+ def _parse_example_config(self):
+ return self.config
+
def real_valued_column(column_name,
dimension=1,
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index 632836fee4..b2dad0162e 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -27,6 +27,7 @@ from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -610,6 +611,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval())
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllClose(output.eval(),
+ fc_core.make_input_layer(features,
+ [real_valued]).eval())
def testRealValuedColumnWithMultiDimensions(self):
real_valued = feature_column.real_valued_column("price", 2)
@@ -620,6 +625,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval())
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllClose(output.eval(),
+ fc_core.make_input_layer(features,
+ [real_valued]).eval())
def testRealValuedColumnSparse(self):
sparse_real_valued = feature_column._real_valued_var_len_column(
@@ -640,6 +649,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllClose(output.eval(),
+ fc_core.make_input_layer(features,
+ [real_valued]).eval())
def testRealValuedColumnWithMultiDimensionsAndNormalizer(self):
real_valued = feature_column.real_valued_column(
@@ -651,6 +664,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllClose(output.eval(),
+ fc_core.make_input_layer(features,
+ [real_valued]).eval())
def testBucketizedColumnWithNormalizerSucceedsForDNN(self):
bucket = feature_column.bucketized_column(