From da15e57fec882e1614fb9a45dffe76dd48d7ec2d Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Wed, 1 Feb 2017 13:32:45 -0800
Subject: [TF:XLA] Add a placeholder implementation of Log1p (via log(1+x),
 which is not numerically accurate for x near 0).

Make some cleanups to unary_ops_test.py.

Change: 146282294
---
 tensorflow/compiler/tests/unary_ops_test.py     | 161 ++++++++++++++----
 tensorflow/compiler/tf2xla/kernels/unary_ops.cc |  39 +++---
 tensorflow/compiler/tf2xla/op_registrations.cc  |   4 +
 3 files changed, 119 insertions(+), 85 deletions(-)

diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index ff565a9815..f0b80d1ffd 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.compiler.tests.xla_test import XLATestCase
 from tensorflow.python.framework import dtypes
@@ -32,7 +33,19 @@ from tensorflow.python.platform import googletest
 class UnaryOpsTest(XLATestCase):
   """Test cases for unary operators."""
 
-  def _testUnary(self, op, inp, expected, equality_test=None):
+  def _assertOpOutputMatchesExpected(self, op, inp, expected,
+                                     equality_test=None, rtol=1e-3, atol=1e-5):
+    """Verifies that 'op' produces 'expected' when fed input 'inp'.
+
+    Args:
+      op: operator to test
+      inp: numpy input array to use as input to 'op'.
+      expected: numpy array representing the expected output of 'op'.
+      equality_test: either None, or a function that tests two numpy arrays for
+        equality. If None, self.assertAllClose is used.
+      rtol: relative tolerance for equality test.
+      atol: absolute tolerance for equality test.
+ """ with self.test_session() as session: with self.test_scope(): pinp = array_ops.placeholder( @@ -41,110 +54,117 @@ class UnaryOpsTest(XLATestCase): result = session.run(output, {pinp: inp}) if equality_test is None: equality_test = self.assertAllClose - equality_test(result, expected, rtol=1e-3) + equality_test(result, expected, rtol=rtol, atol=atol) - def ListsAreClose(self, result, expected, rtol): + def ListsAreClose(self, result, expected, rtol, atol): """Tests closeness of two lists of floats.""" self.assertEqual(len(result), len(expected)) - for i in range(len(result)): - self.assertAllClose(result[i], expected[i], rtol) + for i in xrange(len(result)): + self.assertAllClose(result[i], expected[i], rtol, atol) def testAllTypeOps(self): for dtype in self.numeric_types: - self._testUnary( + self._assertOpOutputMatchesExpected( array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype), np.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( array_ops.diag_part, np.arange(36).reshape([2, 3, 2, 3]).astype(dtype), np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( array_ops.identity, np.array([[-1, 1]], dtype=dtype), expected=np.array([[-1, 1]], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype), np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( array_ops.matrix_diag_part, np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype), np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( array_ops.prevent_gradient, np.array([[-1, 1]], dtype=dtype), expected=np.array([[-1, 1]], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( array_ops.squeeze, np.array([[[[[]]]]], dtype=dtype), expected=np.array([], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( array_ops.squeeze, np.array([[[1], [2]]], dtype=dtype), expected=np.array([1, 2], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( array_ops.squeeze, np.array([[[1]], [[2]]], dtype=dtype), expected=np.array([1, 2], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( array_ops.squeeze, np.array([[[1, 2], [3, 4]]], dtype=dtype), expected=np.array([[1, 2], [3, 4]], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( array_ops.stop_gradient, np.array([[-1, 1]], dtype=dtype), expected=np.array([[-1, 1]], dtype=dtype)) def testFloatOps(self): for dtype in self.float_types: - self._testUnary( + self._assertOpOutputMatchesExpected( math_ops.ceil, np.array([[-1.7, 1.2]], dtype=dtype), expected=np.array([[-1, 2]], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( math_ops.exp, np.array([[-1, 1]], dtype=dtype), expected=np.array([[0.36787945, 2.7182817]], dtype=dtype)) - self._testUnary( + self._assertOpOutputMatchesExpected( math_ops.floor, np.array([[-1.7, 1.2]], dtype=dtype), expected=np.array([[-2, 1]], dtype=dtype)) # Tests for tf.nn ops. - self._testUnary( + self._assertOpOutputMatchesExpected( nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0)) # TODO(b/31644876): enable this test case when fixed. 
-      # self._testUnary(tf.nn.l2_loss, dtype(4), dtype(10))
+      # self._assertOpOutputMatchesExpected(tf.nn.l2_loss, dtype(4), dtype(10))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          math_ops.reciprocal,
          np.array([[1, 2]], dtype=dtype),
          expected=np.array([[1, 0.5]], dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          math_ops.log,
          np.array([[1, 2]], dtype=dtype),
          expected=np.array([[0, 0.69314718]], dtype=dtype))
 
-      self._testUnary(
+      # TODO(b/34703906): improve log1p implementation and make tolerance
+      # tighter.
+      self._assertOpOutputMatchesExpected(
+          math_ops.log1p,
+          np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
+          expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)))
+
+      self._assertOpOutputMatchesExpected(
          math_ops.rsqrt,
          np.array([[4, 16]], dtype=dtype),
          expected=np.array([[0.5, 0.25]], dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          math_ops.sigmoid,
          np.array(
              [[1, 1, 1, 1],
@@ -155,12 +175,12 @@ class UnaryOpsTest(XLATestCase):
              [0.7310586, 0.880797, 0.95257413, 0.98201376]],
              dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          math_ops.sqrt,
          np.array([[4, 9]], dtype=dtype),
          expected=np.array([[2, 3]], dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          math_ops.tanh,
          np.array(
              [[1, 1, 1, 1],
@@ -171,7 +191,7 @@ class UnaryOpsTest(XLATestCase):
              [0.76159418, 0.96402758, 0.99505478, 0.99932933]],
              dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          nn_ops.log_softmax,
          np.array(
              [[1, 1, 1, 1],
@@ -182,17 +202,17 @@ class UnaryOpsTest(XLATestCase):
              [-3.4401896, -2.4401896, -1.4401897, -0.44018969]],
              dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          nn_ops.relu,
          np.array([[-1, 1]], dtype=dtype),
          expected=np.array([[0, 1]], dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          nn_ops.relu6,
          np.array([[-0.05, 6.05, 5]], dtype=dtype),
          expected=np.array([[0, 6, 5]], dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          nn_ops.softmax,
          np.array(
              [[1, 1, 1, 1],
@@ -203,49 +223,50 @@ class UnaryOpsTest(XLATestCase):
              [0.032058604, 0.087144323, 0.23688284, 0.64391428]],
              dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          nn_ops.softplus,
          np.array([[-2, 0, 8]], dtype=dtype),
          expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
 
   def testNumericOps(self):
     for dtype in self.numeric_types:
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          math_ops.abs,
          np.array([[2, -1]], dtype=dtype),
          expected=np.array([[2, 1]], dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          math_ops.negative,
          np.array([[-1, 1]], dtype=dtype),
          expected=np.array([[1, -1]], dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          math_ops.square,
          np.array([[-2, 3]], dtype=dtype),
          expected=np.array([[4, 9]], dtype=dtype))
 
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          array_ops.zeros_like,
          np.array([[4, 3], [2, 1]], dtype=dtype),
          expected=np.array([[0, 0], [0, 0]], dtype=dtype))
 
   def testLogicalOps(self):
-    self._testUnary(
+    self._assertOpOutputMatchesExpected(
        math_ops.logical_not,
        np.array([[True, False], [False, True]], dtype=np.bool),
        expected=np.array([[False, True], [True, False]], dtype=np.bool))
 
   def testBiasAddGrad(self):
-    self._testUnary(
+    self._assertOpOutputMatchesExpected(
        gen_nn_ops.bias_add_grad,
        np.array([[1., 2.], [3., 4.]],
                dtype=np.float32),
        expected=np.array([4., 6.], dtype=np.float32))
 
-    self._testUnary(lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"),
-                    np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]],
-                             dtype=np.float32),
-                    expected=np.array([10., 26.], dtype=np.float32))
+    self._assertOpOutputMatchesExpected(
+        lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"),
+        np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]],
+                 dtype=np.float32),
+        expected=np.array([10., 26.], dtype=np.float32))
 
   def testCast(self):
     shapes = [[], [4], [2, 3], [2, 0, 4]]
@@ -257,13 +278,13 @@ class UnaryOpsTest(XLATestCase):
         src = src.reshape(shape)
         dst = src.astype(dst_type.as_numpy_dtype)
-        self._testUnary(
+        self._assertOpOutputMatchesExpected(
            lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
            src,
            expected=dst)
 
   def testInvertPermutation(self):
-    self._testUnary(
+    self._assertOpOutputMatchesExpected(
        array_ops.invert_permutation,
        np.array([1, 2, 0], np.int32),
        expected=np.array([2, 0, 1], dtype=np.int32))
@@ -271,17 +292,18 @@
   def testRank(self):
     rank_op = lambda x: array_ops.rank_internal(x, optimize=False)
     for dtype in self.numeric_types:
-      self._testUnary(rank_op, dtype(7), expected=np.int32(0))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
+          rank_op, dtype(7), expected=np.int32(0))
+      self._assertOpOutputMatchesExpected(
          rank_op, np.array(
              [[], []], dtype=dtype), expected=np.int32(2))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          rank_op, np.array(
              [-1, 1], dtype=dtype), expected=np.int32(1))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          rank_op, np.array(
              [[-1, 1]], dtype=dtype), expected=np.int32(2))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          rank_op,
          np.array([[-1], [1], [4]], dtype=dtype),
          expected=np.int32(2))
@@ -289,20 +311,21 @@
   def testShape(self):
     shape_op = lambda x: array_ops.shape_internal(x, optimize=False)
     for dtype in self.numeric_types:
-      self._testUnary(shape_op, dtype(7), expected=np.array([], dtype=np.int32))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
+          shape_op, dtype(7), expected=np.array([], dtype=np.int32))
+      self._assertOpOutputMatchesExpected(
          shape_op,
          np.array([[], []], dtype=dtype),
          expected=np.array([2, 0], dtype=np.int32))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          shape_op,
          np.array([-1, 1], dtype=dtype),
          expected=np.array([2], dtype=np.int32))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          shape_op,
          np.array([[-1, 1]], dtype=dtype),
          expected=np.array([1, 2], dtype=np.int32))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          shape_op,
          np.array([[-1], [1], [4]], dtype=dtype),
          expected=np.array([3, 1], dtype=np.int32))
@@ -310,20 +333,21 @@
   def testSize(self):
     size_op = lambda x: array_ops.size_internal(x, optimize=False)
     for dtype in self.numeric_types:
-      self._testUnary(size_op, dtype(7), expected=np.int32(1))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
+          size_op, dtype(7), expected=np.int32(1))
+      self._assertOpOutputMatchesExpected(
          size_op, np.array([[], []], dtype=dtype), expected=np.int32(0))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          size_op, np.array([-1, 1], dtype=dtype), expected=np.int32(2))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          size_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2))
-      self._testUnary(
+      self._assertOpOutputMatchesExpected(
          size_op,
          np.array([[-1], [1], [4]], dtype=dtype),
          expected=np.int32(3))
 
   def testUnpack(self):
-    self._testUnary(
+    self._assertOpOutputMatchesExpected(
        array_ops.unstack,
        np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
        expected=[
@@ -333,13 +357,14 @@ class UnaryOpsTest(XLATestCase):
        ],
        equality_test=self.ListsAreClose)
 
-    self._testUnary(lambda x: array_ops.unstack(x, axis=1),
-                    np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
-                    expected=[
-                        np.array([1., 3., 5.], dtype=np.float32),
-                        np.array([2., 4., 6.], dtype=np.float32),
-                    ],
-                    equality_test=self.ListsAreClose)
+    self._assertOpOutputMatchesExpected(
+        lambda x: array_ops.unstack(x, axis=1),
+        np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
+        expected=[
+            np.array([1., 3., 5.], dtype=np.float32),
+            np.array([2., 4., 6.], dtype=np.float32),
+        ],
+        equality_test=self.ListsAreClose)
 
 
 if __name__ == "__main__":

diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index eced089b32..c3ba1a7a8b 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -33,7 +33,7 @@ namespace {
    public:                                                              \
     explicit Name##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}  \
     void Compile(XlaOpKernelContext* ctx) {                             \
-      xla::ComputationBuilder& b = *ctx->builder();                     \
+      xla::ComputationBuilder* b = ctx->builder();                      \
       xla::ComputationDataHandle x = ctx->Input(0);                     \
       xla::ComputationDataHandle y = COMPUTATION;                       \
       ctx->SetOutput(0, y);                                             \
@@ -42,27 +42,32 @@ namespace {
   REGISTER_XLA_OP(#Name, Name##Op);
 
 // Return x if x>0, otherwise -x.
-XLAJIT_MAKE_UNARY(Abs, b.Abs(x));
-XLAJIT_MAKE_UNARY(Ceil, b.Ceil(x));
-XLAJIT_MAKE_UNARY(Exp, b.Exp(x));
-XLAJIT_MAKE_UNARY(Floor, b.Floor(x));
+XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
+XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x));
+XLAJIT_MAKE_UNARY(Exp, b->Exp(x));
+XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
 // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
-XLAJIT_MAKE_UNARY(Sign, b.Sign(x));
+XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
 // Return 1/x
-XLAJIT_MAKE_UNARY(Inv, b.Div(XlaHelpers::One(&b, input_type(0)), x));
-XLAJIT_MAKE_UNARY(Reciprocal, b.Div(XlaHelpers::One(&b, input_type(0)), x));
-XLAJIT_MAKE_UNARY(Log, b.Log(x));
-XLAJIT_MAKE_UNARY(LogicalNot, b.LogicalNot(x));
-XLAJIT_MAKE_UNARY(Neg, b.Neg(x));
+XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x));
+XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x));
+XLAJIT_MAKE_UNARY(Log, b->Log(x));
+
+// TODO(b/34703906): use a more accurate implementation of log1p.
+XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x)));
+
+XLAJIT_MAKE_UNARY(LogicalNot, b->LogicalNot(x));
+XLAJIT_MAKE_UNARY(Neg, b->Neg(x));
 XLAJIT_MAKE_UNARY(Rsqrt,
-                  b.Pow(x, XlaHelpers::FloatLiteral(&b, input_type(0), -0.5)));
-XLAJIT_MAKE_UNARY(Sigmoid, b.Map({x}, *ctx->GetOrCreateSigmoid(input_type(0))));
+                  b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
+XLAJIT_MAKE_UNARY(Sigmoid,
+                  b->Map({x}, *ctx->GetOrCreateSigmoid(input_type(0))));
 XLAJIT_MAKE_UNARY(Softplus,
-                  b.Log(b.Add(b.Exp(x), XlaHelpers::One(&b, input_type(0)))));
+                  b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0)))));
 XLAJIT_MAKE_UNARY(Sqrt,
-                  b.Pow(x, XlaHelpers::FloatLiteral(&b, input_type(0), 0.5)));
-XLAJIT_MAKE_UNARY(Square, b.Mul(x, x));
-XLAJIT_MAKE_UNARY(Tanh, b.Tanh(x));
+                  b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));
+XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x));
 
 #undef XLAJIT_MAKE_UNARY

diff --git a/tensorflow/compiler/tf2xla/op_registrations.cc b/tensorflow/compiler/tf2xla/op_registrations.cc
index d1a7abb22c..e32070efa3 100644
--- a/tensorflow/compiler/tf2xla/op_registrations.cc
+++ b/tensorflow/compiler/tf2xla/op_registrations.cc
@@ -117,6 +117,8 @@ REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
                     Name("LinSpace").TypeConstraint("T", kCpuFloatTypes));
 REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
                     Name("Log").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+                    Name("Log1p").TypeConstraint("T", kCpuFloatTypes));
 REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalAnd"));
 REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalNot"));
 REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalOr"));
@@ -358,6 +360,8 @@ REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
                     Name("LinSpace").TypeConstraint("T", kGpuFloatTypes));
 REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
                     Name("Log").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+                    Name("Log1p").TypeConstraint("T", kGpuFloatTypes));
 REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalAnd"));
 REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalNot"));
 REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalOr"));
-- 
cgit v1.2.3
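Background on the TODO(b/34703906) above: computing log1p as log(1 + x) loses all accuracy once |x| drops below the machine epsilon of the type, because the sum 1 + x rounds to exactly 1 and the logarithm collapses to 0, even though log1p(x) ~= x - x**2/2 for small x. A minimal NumPy sketch of the failure, plus one standard remedy (the Goldberg/Kahan compensation, which needs only add, subtract, multiply, divide, and log -- the same primitives the placeholder kernel uses; log1p_compensated is an illustrative name, not an API from the patch):

    import numpy as np

    x = np.float32(1e-8)  # below float32 machine epsilon (~1.19e-7)

    naive = np.log(np.float32(1) + x)  # 1 + x rounds to exactly 1.0 -> 0.0
    accurate = np.log1p(x)             # ~1e-8; log1p never forms 1 + x
    print(naive, accurate)             # 0.0 1e-08


    def log1p_compensated(x):
      """log1p built from +, -, *, / and log alone (Goldberg/Kahan)."""
      u = 1.0 + x
      if u == 1.0:
        return x  # 1 + x rounded to 1; log(1 + x) ~= x at this magnitude
      # Multiplying log(u) by x / (u - 1) cancels the rounding error made
      # when forming u, recovering nearly full precision for small x.
      return np.log(u) * x / (u - 1.0)

An XLA version of the same guard would presumably express the branch as an elementwise select over the whole array rather than a Python if, which is roughly what a fix for b/34703906 could look like.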