aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-13 18:01:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-13 18:03:50 -0700
commit462a7e063169010899ce0fa9534f6d7c980f1116 (patch)
treebd55a540786567459d8a6d756bcb21fd40079839 /tensorflow/python/feature_column
parent1babacb30c63e7a5231c3aaaac79bc56f68bf3ec (diff)
Add sequential functionality to _SharedEmbeddingColumn.
PiperOrigin-RevId: 200485876
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/feature_column.py46
1 files changed, 44 insertions, 2 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index af2ead9b84..f959b5e484 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -2553,7 +2553,7 @@ def _get_graph_for_variable(var):
class _SharedEmbeddingColumn(
- _DenseColumn,
+ _DenseColumn, _SequenceDenseColumn,
collections.namedtuple(
'_SharedEmbeddingColumn',
('categorical_column', 'dimension', 'combiner', 'initializer',
@@ -2600,7 +2600,11 @@ class _SharedEmbeddingColumn(
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ def _get_dense_tensor_internal(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ """Private method that follows the signature of _get_dense_tensor."""
# This method is called from a variable_scope with name _var_scope_name,
# which is shared among all shared embeddings. Open a name_scope here, so
# that the ops for different columns have distinct names.
@@ -2641,6 +2645,44 @@ class _SharedEmbeddingColumn(
name='%s_weights' % self.name,
max_norm=self.max_norm)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ if isinstance(self.categorical_column, _SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must not be of type _SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ return self._get_dense_tensor_internal(
+ inputs=inputs,
+ weight_collections=weight_collections,
+ trainable=trainable)
+
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ if not isinstance(self.categorical_column, _SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must be of type _SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access
+ inputs=inputs,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sparse_tensors.id_tensor)
+ return _SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
def _create_tuple(shape, value):
"""Returns a tuple with given shape and filled with value."""