diff options
Diffstat (limited to 'tensorflow/python/feature_column/feature_column_v2.py')
-rw-r--r-- | tensorflow/python/feature_column/feature_column_v2.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index aa66ed77e9..28c5c82d2c 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -385,6 +385,10 @@ class FeatureLayer(Layer): 'You can wrap a categorical column with an ' 'embedding_column or indicator_column. Given: {}'.format(column)) + @property + def _is_feature_layer(self): + return True + def build(self, _): for column in sorted(self._feature_columns, key=lambda x: x.name): if isinstance(column, SharedEmbeddingColumn): @@ -409,7 +413,13 @@ class FeatureLayer(Layer): A `Tensor` which represents input layer of a model. Its shape is (batch_size, first_layer_dimension) and its dtype is `float32`. first_layer_dimension is determined based on given `feature_columns`. + + Raises: + ValueError: If features are not a dictionary. """ + if not isinstance(features, dict): + raise ValueError('We expected a dictionary here. Instead we got: ', + features) transformation_cache = FeatureTransformationCache(features) output_tensors = [] ordered_columns = [] @@ -431,6 +441,12 @@ class FeatureLayer(Layer): _verify_static_batch_size_equality(output_tensors, ordered_columns) return array_ops.concat(output_tensors, 1) + def compute_output_shape(self, input_shape): + total_elements = 0 + for column in sorted(self._feature_columns, key=lambda x: x.name): + total_elements += column.variable_shape.num_elements() + return (input_shape[0], total_elements) + def linear_model(features, feature_columns, |