aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column/feature_column_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/feature_column/feature_column_v2.py')
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index aa66ed77e9..28c5c82d2c 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -385,6 +385,10 @@ class FeatureLayer(Layer):
'You can wrap a categorical column with an '
'embedding_column or indicator_column. Given: {}'.format(column))
+ @property
+ def _is_feature_layer(self):
+ return True
+
def build(self, _):
for column in sorted(self._feature_columns, key=lambda x: x.name):
if isinstance(column, SharedEmbeddingColumn):
@@ -409,7 +413,13 @@ class FeatureLayer(Layer):
A `Tensor` which represents input layer of a model. Its shape
is (batch_size, first_layer_dimension) and its dtype is `float32`.
first_layer_dimension is determined based on given `feature_columns`.
+
+ Raises:
+ ValueError: If features are not a dictionary.
"""
+ if not isinstance(features, dict):
+ raise ValueError('We expected a dictionary here. Instead we got: ',
+ features)
transformation_cache = FeatureTransformationCache(features)
output_tensors = []
ordered_columns = []
@@ -431,6 +441,12 @@ class FeatureLayer(Layer):
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
+ def compute_output_shape(self, input_shape):
+ total_elements = 0
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ total_elements += column.variable_shape.num_elements()
+ return (input_shape[0], total_elements)
+
def linear_model(features,
feature_columns,