diff options
author | Rohan Jain <rohanj@google.com> | 2018-09-06 17:25:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-06 17:29:48 -0700 |
commit | 25f93ba1f880e8b092be611d9a343b18136a267b (patch) | |
tree | 14c1282420d20ac814f831a4264b600768ae22c8 /tensorflow/python/feature_column | |
parent | d57cac9d95c8a10650e98f38ca9572c7bd6c6548 (diff) |
Adding support for FeatureColumn input in Keras models. Modifies the Model.fit() function to support taking in dictionaries of features in.
Support for functional models coming in a subsequent change.
PiperOrigin-RevId: 211897153
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r-- | tensorflow/python/feature_column/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/feature_column/feature_column_v2.py | 16 | ||||
-rw-r--r-- | tensorflow/python/feature_column/feature_column_v2_test.py | 15 |
3 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index 1017d4ba47..ac53a84eef 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -12,6 +12,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":feature_column", + ":feature_column_v2", "//tensorflow/python:util", ], ) diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index aa66ed77e9..28c5c82d2c 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -385,6 +385,10 @@ class FeatureLayer(Layer): 'You can wrap a categorical column with an ' 'embedding_column or indicator_column. Given: {}'.format(column)) + @property + def _is_feature_layer(self): + return True + def build(self, _): for column in sorted(self._feature_columns, key=lambda x: x.name): if isinstance(column, SharedEmbeddingColumn): @@ -409,7 +413,13 @@ class FeatureLayer(Layer): A `Tensor` which represents input layer of a model. Its shape is (batch_size, first_layer_dimension) and its dtype is `float32`. first_layer_dimension is determined based on given `feature_columns`. + + Raises: + ValueError: If features are not a dictionary. """ + if not isinstance(features, dict): + raise ValueError('We expected a dictionary here. Instead we got: ', + features) transformation_cache = FeatureTransformationCache(features) output_tensors = [] ordered_columns = [] @@ -431,6 +441,12 @@ class FeatureLayer(Layer): _verify_static_batch_size_equality(output_tensors, ordered_columns) return array_ops.concat(output_tensors, 1) + def compute_output_shape(self, input_shape): + total_elements = 0 + for column in sorted(self._feature_columns, key=lambda x: x.name): + total_elements += column.variable_shape.num_elements() + return (input_shape[0], total_elements) + def linear_model(features, feature_columns, diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py index 6b343ecf3e..58168e0f9e 100644 --- a/tensorflow/python/feature_column/feature_column_v2_test.py +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -2786,6 +2786,21 @@ class FeatureLayerTest(test.TestCase): with _initialized_session(): self.assertAllClose([[1., 2.], [5., 6.]], net.eval()) + def test_compute_output_shape(self): + price1 = fc.numeric_column('price1', shape=2) + price2 = fc.numeric_column('price2', shape=4) + with ops.Graph().as_default(): + features = { + 'price1': [[1., 2.], [5., 6.]], + 'price2': [[3., 4., 5., 6.], [7., 8., 9., 10.]] + } + feature_layer = FeatureLayer([price1, price2]) + self.assertEqual((None, 6), feature_layer.compute_output_shape((None,))) + net = feature_layer(features) + with _initialized_session(): + self.assertAllClose( + [[1., 2., 3., 4., 5., 6.], [5., 6., 7., 8., 9., 10.]], net.eval()) + def test_raises_if_shape_mismatch(self): price = fc.numeric_column('price', shape=2) with ops.Graph().as_default(): |