aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/estimators/linear.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/linear.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear.py23
1 files changed, 9 insertions, 14 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index faf78a3675..d7f1017a46 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -149,21 +149,16 @@ def _linear_model_fn(features, labels, mode, params, config=None):
values=tuple(six.itervalues(features)),
partitioner=partitioner) as scope:
if joint_weights:
- logits, _, _ = (
- layers.joint_weighted_sum_from_feature_columns(
- columns_to_tensors=features,
- feature_columns=feature_columns,
- num_outputs=head.logits_dimension,
- weight_collections=[parent_scope],
- scope=scope))
+ layer_fn = layers.joint_weighted_sum_from_feature_columns
else:
- logits, _, _ = (
- layers.weighted_sum_from_feature_columns(
- columns_to_tensors=features,
- feature_columns=feature_columns,
- num_outputs=head.logits_dimension,
- weight_collections=[parent_scope],
- scope=scope))
+ layer_fn = layers.weighted_sum_from_feature_columns
+
+ logits, _, _ = layer_fn(
+ columns_to_tensors=features,
+ feature_columns=feature_columns,
+ num_outputs=head.logits_dimension,
+ weight_collections=[parent_scope],
+ scope=scope)
def _train_op_fn(loss):
global_step = contrib_variables.get_global_step()