aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/estimators/head.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/head.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index fdc266d582..699a92d38a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -25,8 +25,7 @@ import six
from tensorflow.contrib import framework as framework_lib
from tensorflow.contrib import layers as layers_lib
from tensorflow.contrib import lookup as lookup_lib
-# TODO(ptucker): Use tf.losses and tf.metrics.
-from tensorflow.contrib import losses as losses_lib
+# TODO(ptucker): Use tf.metrics.
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
@@ -44,6 +43,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.ops.losses import losses as losses_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training
@@ -1212,7 +1212,8 @@ class _BinarySvmHead(_SingleHead):
with ops.name_scope(None, "hinge_loss", (logits, labels)) as name:
with ops.control_dependencies((_assert_labels_rank(labels),)):
labels = array_ops.reshape(labels, shape=(-1, 1))
- loss = losses_lib.hinge_loss(logits=logits, labels=labels, scope=name)
+ loss = losses_lib.hinge_loss(labels=labels, logits=logits, scope=name,
+ reduction=losses_lib.Reduction.NONE)
return _compute_weighted_loss(loss, weights)
super(_BinarySvmHead, self).__init__(