diff options
author | Patrick Nguyen <drpng@google.com> | 2017-12-28 16:04:42 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-28 16:08:58 -0800 |
commit | 20765b3e1ae3b718699592c98aa9805cb874b6d1 (patch) | |
tree | b429a74cd0046404644f34cc8fe6ff2cab78bb85 /tensorflow/contrib/kernel_methods | |
parent | 2e2715baa84720f786b38d1f9cb6887399020d6f (diff) |
Merge changes from github.
PiperOrigin-RevId: 180301735
Diffstat (limited to 'tensorflow/contrib/kernel_methods')
-rw-r--r-- | tensorflow/contrib/kernel_methods/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/kernel_methods/python/losses.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/kernel_methods/python/losses_test.py | 23 |
3 files changed, 28 insertions, 3 deletions
diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index a2f320ab11..eff7dfeb4c 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -83,9 +83,11 @@ py_test( srcs_version = "PY2AND3", deps = [ ":kernel_methods", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index 208b0e1c9d..f182fef067 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -73,13 +73,13 @@ def sparse_multiclass_hinge_loss( labels)) as scope: # Check logits Tensor has valid rank. - logits_shape = logits.get_shape() - logits_rank = logits_shape.ndims + logits_rank = logits.get_shape().ndims if logits_rank != 2: raise ValueError( 'logits should have rank 2 ([batch_size, num_classes]). Given rank is' ' {}'.format(logits_rank)) - batch_size, num_classes = logits_shape[0].value, logits_shape[1].value + logits_shape = array_ops.shape(logits) + batch_size, num_classes = logits_shape[0], logits_shape[1] logits = math_ops.to_float(logits) # Check labels have valid type. diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py index 8a1a5ffe56..d38d8041ce 100644 --- a/tensorflow/contrib/kernel_methods/python/losses_test.py +++ b/tensorflow/contrib/kernel_methods/python/losses_test.py @@ -18,10 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.kernel_methods.python import losses from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -114,6 +117,26 @@ class SparseMulticlassHingeLossTest(test.TestCase): loss = losses.sparse_multiclass_hinge_loss(labels, logits) self.assertAlmostEqual(loss.eval(), 0.0, 3) + def testUnknownShape(self): + """Result keeps same with `testZeroLossInt32Labels`""" + logits_np = np.array([[1.2, -1.4, -1.0], + [1.4, 1.8, 4.0], + [0.5, 1.8, -1.0]]) + labels_np = np.array([0, 2, 1], dtype=np.int32) + + logits_shapes = [[3, 3], # batch_size, num_classes + [None, 3], + [3, None], + [None, None]] + + for batch_size, num_classes in logits_shapes: + with self.test_session(): + logits = array_ops.placeholder(dtypes.float32, shape=(batch_size, num_classes)) + labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,)) + loss = losses.sparse_multiclass_hinge_loss(labels, logits) + result = loss.eval(feed_dict={logits: logits_np, labels: labels_np}) + self.assertAlmostEqual(result, 0.0, 3) + def testCorrectPredictionsSomeClassesInsideMargin(self): """Loss is > 0 even if true class logits are higher than other classes.""" with self.test_session(): |