aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kernel_methods
diff options
context:
space:
mode:
authorGravatar Patrick Nguyen <drpng@google.com>2017-12-28 16:04:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-28 16:08:58 -0800
commit20765b3e1ae3b718699592c98aa9805cb874b6d1 (patch)
treeb429a74cd0046404644f34cc8fe6ff2cab78bb85 /tensorflow/contrib/kernel_methods
parent2e2715baa84720f786b38d1f9cb6887399020d6f (diff)
Merge changes from github.
PiperOrigin-RevId: 180301735
Diffstat (limited to 'tensorflow/contrib/kernel_methods')
-rw-r--r--tensorflow/contrib/kernel_methods/BUILD2
-rw-r--r--tensorflow/contrib/kernel_methods/python/losses.py6
-rw-r--r--tensorflow/contrib/kernel_methods/python/losses_test.py23
3 files changed, 28 insertions, 3 deletions
diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD
index a2f320ab11..eff7dfeb4c 100644
--- a/tensorflow/contrib/kernel_methods/BUILD
+++ b/tensorflow/contrib/kernel_methods/BUILD
@@ -83,9 +83,11 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":kernel_methods",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py
index 208b0e1c9d..f182fef067 100644
--- a/tensorflow/contrib/kernel_methods/python/losses.py
+++ b/tensorflow/contrib/kernel_methods/python/losses.py
@@ -73,13 +73,13 @@ def sparse_multiclass_hinge_loss(
labels)) as scope:
# Check logits Tensor has valid rank.
- logits_shape = logits.get_shape()
- logits_rank = logits_shape.ndims
+ logits_rank = logits.get_shape().ndims
if logits_rank != 2:
raise ValueError(
'logits should have rank 2 ([batch_size, num_classes]). Given rank is'
' {}'.format(logits_rank))
- batch_size, num_classes = logits_shape[0].value, logits_shape[1].value
+ logits_shape = array_ops.shape(logits)
+ batch_size, num_classes = logits_shape[0], logits_shape[1]
logits = math_ops.to_float(logits)
# Check labels have valid type.
diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py
index 8a1a5ffe56..d38d8041ce 100644
--- a/tensorflow/contrib/kernel_methods/python/losses_test.py
+++ b/tensorflow/contrib/kernel_methods/python/losses_test.py
@@ -18,10 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.kernel_methods.python import losses
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -114,6 +117,26 @@ class SparseMulticlassHingeLossTest(test.TestCase):
loss = losses.sparse_multiclass_hinge_loss(labels, logits)
self.assertAlmostEqual(loss.eval(), 0.0, 3)
+ def testUnknownShape(self):
+ """Result keeps same with `testZeroLossInt32Labels`"""
+ logits_np = np.array([[1.2, -1.4, -1.0],
+ [1.4, 1.8, 4.0],
+ [0.5, 1.8, -1.0]])
+ labels_np = np.array([0, 2, 1], dtype=np.int32)
+
+ logits_shapes = [[3, 3], # batch_size, num_classes
+ [None, 3],
+ [3, None],
+ [None, None]]
+
+ for batch_size, num_classes in logits_shapes:
+ with self.test_session():
+ logits = array_ops.placeholder(dtypes.float32, shape=(batch_size, num_classes))
+ labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,))
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits)
+ result = loss.eval(feed_dict={logits: logits_np, labels: labels_np})
+ self.assertAlmostEqual(result, 0.0, 3)
+
def testCorrectPredictionsSomeClassesInsideMargin(self):
"""Loss is > 0 even if true class logits are higher than other classes."""
with self.test_session():