aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column/feature_column.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/feature_column/feature_column.py')
-rw-r--r--tensorflow/python/feature_column/feature_column.py25
1 files changed, 19 insertions, 6 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 2246d2f3e9..9984379e9d 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -169,7 +169,8 @@ def _internal_input_layer(features,
weight_collections=None,
trainable=True,
cols_to_vars=None,
- scope=None):
+ scope=None,
+ cols_to_output_tensors=None):
"""See input_layer. `scope` is a name or variable scope to use."""
feature_columns = _normalize_feature_columns(feature_columns)
@@ -202,14 +203,17 @@ def _internal_input_layer(features,
trainable=trainable)
num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
batch_size = array_ops.shape(tensor)[0]
- output_tensors.append(
- array_ops.reshape(tensor, shape=(batch_size, num_elements)))
+ output_tensor = array_ops.reshape(
+ tensor, shape=(batch_size, num_elements))
+ output_tensors.append(output_tensor)
if cols_to_vars is not None:
# Retrieve any variables created (some _DenseColumn's don't create
# variables, in which case an empty list is returned).
cols_to_vars[column] = ops.get_collection(
ops.GraphKeys.GLOBAL_VARIABLES,
scope=variable_scope.get_variable_scope().name)
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[column] = output_tensor
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
@@ -219,7 +223,8 @@ def input_layer(features,
feature_columns,
weight_collections=None,
trainable=True,
- cols_to_vars=None):
+ cols_to_vars=None,
+ cols_to_output_tensors=None):
"""Returns a dense `Tensor` as input layer based on given `feature_columns`.
Generally a single example in training data is described with FeatureColumns.
@@ -264,6 +269,9 @@ def input_layer(features,
dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10),
<tf.Variable 'some_variable:1' shape=(5, 10)]}
If a column creates no variables, its value will be an empty list.
+ cols_to_output_tensors: If not `None`, must be a dictionary that will be
+ filled with a mapping from '_FeatureColumn' to the associated
+ output `Tensor`s.
Returns:
A `Tensor` which represents input layer of a model. Its shape
@@ -273,8 +281,13 @@ def input_layer(features,
Raises:
ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
"""
- return _internal_input_layer(features, feature_columns, weight_collections,
- trainable, cols_to_vars)
+ return _internal_input_layer(
+ features,
+ feature_columns,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ cols_to_vars=cols_to_vars,
+ cols_to_output_tensors=cols_to_output_tensors)
# TODO(akshayka): InputLayer should be a subclass of Layer, and it