diff options
Diffstat (limited to 'tensorflow/python/feature_column/feature_column.py')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 25 |
1 files changed, 19 insertions, 6 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 2246d2f3e9..9984379e9d 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -169,7 +169,8 @@ def _internal_input_layer(features, weight_collections=None, trainable=True, cols_to_vars=None, - scope=None): + scope=None, + cols_to_output_tensors=None): """See input_layer. `scope` is a name or variable scope to use.""" feature_columns = _normalize_feature_columns(feature_columns) @@ -202,14 +203,17 @@ def _internal_input_layer(features, trainable=trainable) num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access batch_size = array_ops.shape(tensor)[0] - output_tensors.append( - array_ops.reshape(tensor, shape=(batch_size, num_elements))) + output_tensor = array_ops.reshape( + tensor, shape=(batch_size, num_elements)) + output_tensors.append(output_tensor) if cols_to_vars is not None: # Retrieve any variables created (some _DenseColumn's don't create # variables, in which case an empty list is returned). cols_to_vars[column] = ops.get_collection( ops.GraphKeys.GLOBAL_VARIABLES, scope=variable_scope.get_variable_scope().name) + if cols_to_output_tensors is not None: + cols_to_output_tensors[column] = output_tensor _verify_static_batch_size_equality(output_tensors, ordered_columns) return array_ops.concat(output_tensors, 1) @@ -219,7 +223,8 @@ def input_layer(features, feature_columns, weight_collections=None, trainable=True, - cols_to_vars=None): + cols_to_vars=None, + cols_to_output_tensors=None): """Returns a dense `Tensor` as input layer based on given `feature_columns`. Generally a single example in training data is described with FeatureColumns. @@ -264,6 +269,9 @@ def input_layer(features, dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10), <tf.Variable 'some_variable:1' shape=(5, 10)]} If a column creates no variables, its value will be an empty list. + cols_to_output_tensors: If not `None`, must be a dictionary that will be + filled with a mapping from '_FeatureColumn' to the associated + output `Tensor`s. Returns: A `Tensor` which represents input layer of a model. Its shape @@ -273,8 +281,13 @@ def input_layer(features, Raises: ValueError: if an item in `feature_columns` is not a `_DenseColumn`. """ - return _internal_input_layer(features, feature_columns, weight_collections, - trainable, cols_to_vars) + return _internal_input_layer( + features, + feature_columns, + weight_collections=weight_collections, + trainable=trainable, + cols_to_vars=cols_to_vars, + cols_to_output_tensors=cols_to_output_tensors) # TODO(akshayka): InputLayer should be a subclass of Layer, and it |