aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-09-06 17:25:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 17:29:48 -0700
commit25f93ba1f880e8b092be611d9a343b18136a267b (patch)
tree14c1282420d20ac814f831a4264b600768ae22c8 /tensorflow/python/feature_column
parentd57cac9d95c8a10650e98f38ca9572c7bd6c6548 (diff)
Adding support for FeatureColumn input in Keras models. Modifies the Model.fit() function to support taking in dictionaries of features in.
Support for functional models coming in a subsequent change. PiperOrigin-RevId: 211897153
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/BUILD1
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py16
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py15
3 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 1017d4ba47..ac53a84eef 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -12,6 +12,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":feature_column",
+ ":feature_column_v2",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index aa66ed77e9..28c5c82d2c 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -385,6 +385,10 @@ class FeatureLayer(Layer):
'You can wrap a categorical column with an '
'embedding_column or indicator_column. Given: {}'.format(column))
+ @property
+ def _is_feature_layer(self):
+ return True
+
def build(self, _):
for column in sorted(self._feature_columns, key=lambda x: x.name):
if isinstance(column, SharedEmbeddingColumn):
@@ -409,7 +413,13 @@ class FeatureLayer(Layer):
A `Tensor` which represents input layer of a model. Its shape
is (batch_size, first_layer_dimension) and its dtype is `float32`.
first_layer_dimension is determined based on given `feature_columns`.
+
+ Raises:
+ ValueError: If features are not a dictionary.
"""
+ if not isinstance(features, dict):
+ raise ValueError('We expected a dictionary here. Instead we got: ',
+ features)
transformation_cache = FeatureTransformationCache(features)
output_tensors = []
ordered_columns = []
@@ -431,6 +441,12 @@ class FeatureLayer(Layer):
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
+ def compute_output_shape(self, input_shape):
+ total_elements = 0
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ total_elements += column.variable_shape.num_elements()
+ return (input_shape[0], total_elements)
+
def linear_model(features,
feature_columns,
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 6b343ecf3e..58168e0f9e 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -2786,6 +2786,21 @@ class FeatureLayerTest(test.TestCase):
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
+ def test_compute_output_shape(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2', shape=4)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [5., 6.]],
+ 'price2': [[3., 4., 5., 6.], [7., 8., 9., 10.]]
+ }
+ feature_layer = FeatureLayer([price1, price2])
+ self.assertEqual((None, 6), feature_layer.compute_output_shape((None,)))
+ net = feature_layer(features)
+ with _initialized_session():
+ self.assertAllClose(
+ [[1., 2., 3., 4., 5., 6.], [5., 6., 7., 8., 9., 10.]], net.eval())
+
def test_raises_if_shape_mismatch(self):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():