diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-13 18:01:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-13 18:03:50 -0700 |
commit | 462a7e063169010899ce0fa9534f6d7c980f1116 (patch) | |
tree | bd55a540786567459d8a6d756bcb21fd40079839 /tensorflow/python/feature_column | |
parent | 1babacb30c63e7a5231c3aaaac79bc56f68bf3ec (diff) |
Add sequential functionality to _SharedEmbeddingColumn.
PiperOrigin-RevId: 200485876
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 46 |
1 files changed, 44 insertions, 2 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index af2ead9b84..f959b5e484 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -2553,7 +2553,7 @@ def _get_graph_for_variable(var): class _SharedEmbeddingColumn( - _DenseColumn, + _DenseColumn, _SequenceDenseColumn, collections.namedtuple( '_SharedEmbeddingColumn', ('categorical_column', 'dimension', 'combiner', 'initializer', @@ -2600,7 +2600,11 @@ class _SharedEmbeddingColumn( self._shape = tensor_shape.vector(self.dimension) return self._shape - def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + def _get_dense_tensor_internal(self, + inputs, + weight_collections=None, + trainable=None): + """Private method that follows the signature of _get_dense_tensor.""" # This method is called from a variable_scope with name _var_scope_name, # which is shared among all shared embeddings. Open a name_scope here, so # that the ops for different columns have distinct names. @@ -2641,6 +2645,44 @@ class _SharedEmbeddingColumn( name='%s_weights' % self.name, max_norm=self.max_norm) + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + if isinstance(self.categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must not be of type _SequenceCategoricalColumn. ' + 'Suggested fix A: If you wish to use input_layer, use a ' + 'non-sequence categorical_column_with_*. ' + 'Suggested fix B: If you wish to create sequence input, use ' + 'sequence_input_layer instead of input_layer. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + return self._get_dense_tensor_internal( + inputs=inputs, + weight_collections=weight_collections, + trainable=trainable) + + def _get_sequence_dense_tensor(self, + inputs, + weight_collections=None, + trainable=None): + if not isinstance(self.categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'In embedding_column: {}. ' + 'categorical_column must be of type _SequenceCategoricalColumn ' + 'to use sequence_input_layer. ' + 'Suggested fix: Use one of sequence_categorical_column_with_*. ' + 'Given (type {}): {}'.format(self.name, type(self.categorical_column), + self.categorical_column)) + dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access + inputs=inputs, + weight_collections=weight_collections, + trainable=trainable) + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access + sequence_length = _sequence_length_from_sparse_tensor( + sparse_tensors.id_tensor) + return _SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + def _create_tuple(shape, value): """Returns a tuple with given shape and filled with value.""" |