aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-02-01 13:32:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-01 13:50:35 -0800
commitda15e57fec882e1614fb9a45dffe76dd48d7ec2d (patch)
tree529a957fa6227d5436e78b9ecd3e3660dafd1e00
parent82ee4b74b0c34f3f0fbbdc3b99c3a59453519af2 (diff)
[TF:XLA] Add a placeholder implementation of Log1p (via log(1+x), which is not numerically accurate for x near 0).
Make some cleanups to unary_ops_test.py. Change: 146282294
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py161
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc39
-rw-r--r--tensorflow/compiler/tf2xla/op_registrations.cc4
3 files changed, 119 insertions, 85 deletions
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index ff565a9815..f0b80d1ffd 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
@@ -32,7 +33,19 @@ from tensorflow.python.platform import googletest
class UnaryOpsTest(XLATestCase):
"""Test cases for unary operators."""
- def _testUnary(self, op, inp, expected, equality_test=None):
+ def _assertOpOutputMatchesExpected(self, op, inp, expected,
+ equality_test=None, rtol=1e-3, atol=1e-5):
+ """Verifies that 'op' produces 'expected' when fed input 'inp' .
+
+ Args:
+ op: operator to test
+ inp: numpy input array to use as input to 'op'.
+ expected: numpy array representing the expected output of 'op'.
+ equality_test: either None, or a function that tests two numpy arrays for
+ equality. If None, self.assertAllClose is used.
+ rtol: relative tolerance for equality test.
+ atol: absolute tolerance for equality test.
+ """
with self.test_session() as session:
with self.test_scope():
pinp = array_ops.placeholder(
@@ -41,110 +54,117 @@ class UnaryOpsTest(XLATestCase):
result = session.run(output, {pinp: inp})
if equality_test is None:
equality_test = self.assertAllClose
- equality_test(result, expected, rtol=1e-3)
+ equality_test(result, expected, rtol=rtol, atol=atol)
- def ListsAreClose(self, result, expected, rtol):
+ def ListsAreClose(self, result, expected, rtol, atol):
"""Tests closeness of two lists of floats."""
self.assertEqual(len(result), len(expected))
- for i in range(len(result)):
- self.assertAllClose(result[i], expected[i], rtol)
+ for i in xrange(len(result)):
+ self.assertAllClose(result[i], expected[i], rtol, atol)
def testAllTypeOps(self):
for dtype in self.numeric_types:
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.diag,
np.array([1, 2, 3, 4], dtype=dtype),
np.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.diag_part,
np.arange(36).reshape([2, 3, 2, 3]).astype(dtype),
np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.identity,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[-1, 1]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.matrix_diag,
np.array([[1, 2], [3, 4]], dtype=dtype),
np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.matrix_diag_part,
np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype),
np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.prevent_gradient,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[-1, 1]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.squeeze,
np.array([[[[[]]]]], dtype=dtype),
expected=np.array([], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.squeeze,
np.array([[[1], [2]]], dtype=dtype),
expected=np.array([1, 2], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.squeeze,
np.array([[[1]], [[2]]], dtype=dtype),
expected=np.array([1, 2], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.squeeze,
np.array([[[1, 2], [3, 4]]], dtype=dtype),
expected=np.array([[1, 2], [3, 4]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.stop_gradient,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[-1, 1]], dtype=dtype))
def testFloatOps(self):
for dtype in self.float_types:
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.ceil,
np.array([[-1.7, 1.2]], dtype=dtype),
expected=np.array([[-1, 2]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.exp,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[0.36787945, 2.7182817]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.floor,
np.array([[-1.7, 1.2]], dtype=dtype),
expected=np.array([[-2, 1]], dtype=dtype))
# Tests for tf.nn ops.
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0))
# TODO(b/31644876): enable this test case when fixed.
- # self._testUnary(tf.nn.l2_loss, dtype(4), dtype(10))
+ # self._assertOpOutputMatchesExpected(tf.nn.l2_loss, dtype(4), dtype(10))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.reciprocal,
np.array([[1, 2]], dtype=dtype),
expected=np.array([[1, 0.5]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.log,
np.array([[1, 2]], dtype=dtype),
expected=np.array([[0, 0.69314718]], dtype=dtype))
- self._testUnary(
+ # TODO(b/34703906): improve log1p implementation and make tolerance
+ # tighter.
+ self._assertOpOutputMatchesExpected(
+ math_ops.log1p,
+ np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
+ expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)))
+
+ self._assertOpOutputMatchesExpected(
math_ops.rsqrt,
np.array([[4, 16]], dtype=dtype),
expected=np.array([[0.5, 0.25]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.sigmoid,
np.array(
[[1, 1, 1, 1],
@@ -155,12 +175,12 @@ class UnaryOpsTest(XLATestCase):
[0.7310586, 0.880797, 0.95257413, 0.98201376]],
dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.sqrt,
np.array([[4, 9]], dtype=dtype),
expected=np.array([[2, 3]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.tanh,
np.array(
[[1, 1, 1, 1],
@@ -171,7 +191,7 @@ class UnaryOpsTest(XLATestCase):
[0.76159418, 0.96402758, 0.99505478, 0.99932933]],
dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
nn_ops.log_softmax,
np.array(
[[1, 1, 1, 1],
@@ -182,17 +202,17 @@ class UnaryOpsTest(XLATestCase):
[-3.4401896, -2.4401896, -1.4401897, -0.44018969]],
dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
nn_ops.relu,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[0, 1]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
nn_ops.relu6,
np.array([[-0.05, 6.05, 5]], dtype=dtype),
expected=np.array([[0, 6, 5]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
nn_ops.softmax,
np.array(
[[1, 1, 1, 1],
@@ -203,49 +223,50 @@ class UnaryOpsTest(XLATestCase):
[0.032058604, 0.087144323, 0.23688284, 0.64391428]],
dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
nn_ops.softplus,
np.array([[-2, 0, 8]], dtype=dtype),
expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
def testNumericOps(self):
for dtype in self.numeric_types:
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.abs,
np.array([[2, -1]], dtype=dtype),
expected=np.array([[2, 1]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.negative,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([[1, -1]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.square,
np.array([[-2, 3]], dtype=dtype),
expected=np.array([[4, 9]], dtype=dtype))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.zeros_like,
np.array([[4, 3], [2, 1]], dtype=dtype),
expected=np.array([[0, 0], [0, 0]], dtype=dtype))
def testLogicalOps(self):
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
math_ops.logical_not,
np.array([[True, False], [False, True]], dtype=np.bool),
expected=np.array([[False, True], [True, False]], dtype=np.bool))
def testBiasAddGrad(self):
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
gen_nn_ops.bias_add_grad,
np.array([[1., 2.], [3., 4.]], dtype=np.float32),
expected=np.array([4., 6.], dtype=np.float32))
- self._testUnary(lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"),
- np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]],
- dtype=np.float32),
- expected=np.array([10., 26.], dtype=np.float32))
+ self._assertOpOutputMatchesExpected(
+ lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"),
+ np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]],
+ dtype=np.float32),
+ expected=np.array([10., 26.], dtype=np.float32))
def testCast(self):
shapes = [[], [4], [2, 3], [2, 0, 4]]
@@ -257,13 +278,13 @@ class UnaryOpsTest(XLATestCase):
src = src.reshape(shape)
dst = src.astype(dst_type.as_numpy_dtype)
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
src,
expected=dst)
def testInvertPermutation(self):
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.invert_permutation,
np.array([1, 2, 0], np.int32),
expected=np.array([2, 0, 1], dtype=np.int32))
@@ -271,17 +292,18 @@ class UnaryOpsTest(XLATestCase):
def testRank(self):
rank_op = lambda x: array_ops.rank_internal(x, optimize=False)
for dtype in self.numeric_types:
- self._testUnary(rank_op, dtype(7), expected=np.int32(0))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
+ rank_op, dtype(7), expected=np.int32(0))
+ self._assertOpOutputMatchesExpected(
rank_op, np.array(
[[], []], dtype=dtype), expected=np.int32(2))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
rank_op, np.array(
[-1, 1], dtype=dtype), expected=np.int32(1))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
rank_op, np.array(
[[-1, 1]], dtype=dtype), expected=np.int32(2))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
rank_op,
np.array([[-1], [1], [4]], dtype=dtype),
expected=np.int32(2))
@@ -289,20 +311,21 @@ class UnaryOpsTest(XLATestCase):
def testShape(self):
shape_op = lambda x: array_ops.shape_internal(x, optimize=False)
for dtype in self.numeric_types:
- self._testUnary(shape_op, dtype(7), expected=np.array([], dtype=np.int32))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
+ shape_op, dtype(7), expected=np.array([], dtype=np.int32))
+ self._assertOpOutputMatchesExpected(
shape_op,
np.array([[], []], dtype=dtype),
expected=np.array([2, 0], dtype=np.int32))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
shape_op,
np.array([-1, 1], dtype=dtype),
expected=np.array([2], dtype=np.int32))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
shape_op,
np.array([[-1, 1]], dtype=dtype),
expected=np.array([1, 2], dtype=np.int32))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
shape_op,
np.array([[-1], [1], [4]], dtype=dtype),
expected=np.array([3, 1], dtype=np.int32))
@@ -310,20 +333,21 @@ class UnaryOpsTest(XLATestCase):
def testSize(self):
size_op = lambda x: array_ops.size_internal(x, optimize=False)
for dtype in self.numeric_types:
- self._testUnary(size_op, dtype(7), expected=np.int32(1))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
+ size_op, dtype(7), expected=np.int32(1))
+ self._assertOpOutputMatchesExpected(
size_op, np.array([[], []], dtype=dtype), expected=np.int32(0))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
size_op, np.array([-1, 1], dtype=dtype), expected=np.int32(2))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
size_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2))
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
size_op,
np.array([[-1], [1], [4]], dtype=dtype),
expected=np.int32(3))
def testUnpack(self):
- self._testUnary(
+ self._assertOpOutputMatchesExpected(
array_ops.unstack,
np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
expected=[
@@ -333,13 +357,14 @@ class UnaryOpsTest(XLATestCase):
],
equality_test=self.ListsAreClose)
- self._testUnary(lambda x: array_ops.unstack(x, axis=1),
- np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
- expected=[
- np.array([1., 3., 5.], dtype=np.float32),
- np.array([2., 4., 6.], dtype=np.float32),
- ],
- equality_test=self.ListsAreClose)
+ self._assertOpOutputMatchesExpected(
+ lambda x: array_ops.unstack(x, axis=1),
+ np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
+ expected=[
+ np.array([1., 3., 5.], dtype=np.float32),
+ np.array([2., 4., 6.], dtype=np.float32),
+ ],
+ equality_test=self.ListsAreClose)
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index eced089b32..c3ba1a7a8b 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -33,7 +33,7 @@ namespace {
public: \
explicit Name##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \
void Compile(XlaOpKernelContext* ctx) { \
- xla::ComputationBuilder& b = *ctx->builder(); \
+ xla::ComputationBuilder* b = ctx->builder(); \
xla::ComputationDataHandle x = ctx->Input(0); \
xla::ComputationDataHandle y = COMPUTATION; \
ctx->SetOutput(0, y); \
@@ -42,27 +42,32 @@ namespace {
REGISTER_XLA_OP(#Name, Name##Op);
// Return x if x>0, otherwise -x.
-XLAJIT_MAKE_UNARY(Abs, b.Abs(x));
-XLAJIT_MAKE_UNARY(Ceil, b.Ceil(x));
-XLAJIT_MAKE_UNARY(Exp, b.Exp(x));
-XLAJIT_MAKE_UNARY(Floor, b.Floor(x));
+XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
+XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x));
+XLAJIT_MAKE_UNARY(Exp, b->Exp(x));
+XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
-XLAJIT_MAKE_UNARY(Sign, b.Sign(x));
+XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
// Return 1/x
-XLAJIT_MAKE_UNARY(Inv, b.Div(XlaHelpers::One(&b, input_type(0)), x));
-XLAJIT_MAKE_UNARY(Reciprocal, b.Div(XlaHelpers::One(&b, input_type(0)), x));
-XLAJIT_MAKE_UNARY(Log, b.Log(x));
-XLAJIT_MAKE_UNARY(LogicalNot, b.LogicalNot(x));
-XLAJIT_MAKE_UNARY(Neg, b.Neg(x));
+XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x));
+XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x));
+XLAJIT_MAKE_UNARY(Log, b->Log(x));
+
+// TODO(b/34703906): use a more accurate implementation of log1p.
+XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x)));
+
+XLAJIT_MAKE_UNARY(LogicalNot, b->LogicalNot(x));
+XLAJIT_MAKE_UNARY(Neg, b->Neg(x));
XLAJIT_MAKE_UNARY(Rsqrt,
- b.Pow(x, XlaHelpers::FloatLiteral(&b, input_type(0), -0.5)));
-XLAJIT_MAKE_UNARY(Sigmoid, b.Map({x}, *ctx->GetOrCreateSigmoid(input_type(0))));
+ b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
+XLAJIT_MAKE_UNARY(Sigmoid,
+ b->Map({x}, *ctx->GetOrCreateSigmoid(input_type(0))));
XLAJIT_MAKE_UNARY(Softplus,
- b.Log(b.Add(b.Exp(x), XlaHelpers::One(&b, input_type(0)))));
+ b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0)))));
XLAJIT_MAKE_UNARY(Sqrt,
- b.Pow(x, XlaHelpers::FloatLiteral(&b, input_type(0), 0.5)));
-XLAJIT_MAKE_UNARY(Square, b.Mul(x, x));
-XLAJIT_MAKE_UNARY(Tanh, b.Tanh(x));
+ b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));
+XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x));
#undef XLAJIT_MAKE_UNARY
diff --git a/tensorflow/compiler/tf2xla/op_registrations.cc b/tensorflow/compiler/tf2xla/op_registrations.cc
index d1a7abb22c..e32070efa3 100644
--- a/tensorflow/compiler/tf2xla/op_registrations.cc
+++ b/tensorflow/compiler/tf2xla/op_registrations.cc
@@ -117,6 +117,8 @@ REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
Name("LinSpace").TypeConstraint("T", kCpuFloatTypes));
REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
Name("Log").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Log1p").TypeConstraint("T", kCpuFloatTypes));
REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalAnd"));
REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalNot"));
REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalOr"));
@@ -358,6 +360,8 @@ REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
Name("LinSpace").TypeConstraint("T", kGpuFloatTypes));
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
Name("Log").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Log1p").TypeConstraint("T", kGpuFloatTypes));
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalAnd"));
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalNot"));
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalOr"));