diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/linear.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/linear.py | 23 |
1 files changed, 9 insertions, 14 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index faf78a3675..d7f1017a46 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -149,21 +149,16 @@ def _linear_model_fn(features, labels, mode, params, config=None): values=tuple(six.itervalues(features)), partitioner=partitioner) as scope: if joint_weights: - logits, _, _ = ( - layers.joint_weighted_sum_from_feature_columns( - columns_to_tensors=features, - feature_columns=feature_columns, - num_outputs=head.logits_dimension, - weight_collections=[parent_scope], - scope=scope)) + layer_fn = layers.joint_weighted_sum_from_feature_columns else: - logits, _, _ = ( - layers.weighted_sum_from_feature_columns( - columns_to_tensors=features, - feature_columns=feature_columns, - num_outputs=head.logits_dimension, - weight_collections=[parent_scope], - scope=scope)) + layer_fn = layers.weighted_sum_from_feature_columns + + logits, _, _ = layer_fn( + columns_to_tensors=features, + feature_columns=feature_columns, + num_outputs=head.logits_dimension, + weight_collections=[parent_scope], + scope=scope) def _train_op_fn(loss): global_step = contrib_variables.get_global_step() |