aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Chris Leary <leary@google.com>2017-07-26 18:39:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-26 18:43:32 -0700
commit32e198f2d5787ca81aba89bf073e4eb380769253 (patch)
tree72051598edeaaedde73ae4658cabaa4cadc1b3a8
parent9b30dc3a824fd277fcd622a458b25f26c0db7b72 (diff)
[TF:XLA] Add tf.cross support.
See #11788 PiperOrigin-RevId: 163287731
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py18
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cross_op.cc87
3 files changed, 106 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 9eaede7f40..83cfd2ea75 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -765,6 +765,24 @@ class BinaryOpsTest(XLATestCase):
np.array([1, 0], dtype=np.int32),
expected=np.array([[1, 3], [2, 4]], dtype=dtype))
+ def testCross(self):
+ for dtype in self.float_types:
+ self._testBinary(
+ gen_math_ops.cross,
+ np.zeros((4, 3), dtype=dtype),
+ np.zeros((4, 3), dtype=dtype),
+ expected=np.zeros((4, 3), dtype=dtype))
+ self._testBinary(
+ gen_math_ops.cross,
+ np.array([1, 2, 3], dtype=dtype),
+ np.array([4, 5, 6], dtype=dtype),
+ expected=np.array([-3, 6, -3], dtype=dtype))
+ self._testBinary(
+ gen_math_ops.cross,
+ np.array([[1, 2, 3], [10, 11, 12]], dtype=dtype),
+ np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype),
+ expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype))
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 546e9be864..b114b7e6f8 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -24,6 +24,7 @@ tf_kernel_library(
"concat_op.cc",
"const_op.cc",
"conv_ops.cc",
+ "cross_op.cc",
"cwise_ops.cc",
"depthwise_conv_ops.cc",
"diag_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
new file mode 100644
index 0000000000..3df8c00f1b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+class CrossOp : public XlaOpKernel {
+ public:
+ explicit CrossOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape in0_shape = ctx->InputShape(0);
+ TensorShape in1_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, in0_shape == in1_shape,
+ errors::InvalidArgument("Both inputs must be of same shape: ",
+ in0_shape.DebugString(), " vs. ",
+ in1_shape.DebugString()));
+ OP_REQUIRES(ctx, in0_shape.dims() >= 1,
+ errors::InvalidArgument("Input must be at least 1D",
+ in0_shape.DebugString()));
+
+ auto inner_dim = in0_shape.dim_size(in0_shape.dims() - 1);
+ OP_REQUIRES(ctx, inner_dim == 3,
+ errors::FailedPrecondition(
+ "Cross-products are only defined for 3-element vectors."));
+
+ // in0 is a [...,X,Y,Z,3]
+ // in1 is the same shape as in0
+ // So slice 0 is: in0[...,:,:,:,0:1]
+ // So slice 1 is: in0[...,:,:,:,1:2]
+ // So slice 2 is: in0[...,:,:,:,2:3]
+
+ std::vector<int64> starts(in0_shape.dims(), 0);
+ std::vector<int64> limits;
+ for (auto dim_size : in0_shape.dim_sizes()) {
+ limits.push_back(dim_size);
+ }
+ std::vector<int64> strides(in0_shape.dims(), 1);
+
+ xla::ComputationBuilder* b = ctx->builder();
+ auto in0 = ctx->Input(0);
+ auto in1 = ctx->Input(1);
+ starts.back() = 0;
+ limits.back() = 1;
+ auto u1 = b->Slice(in0, starts, limits, strides);
+ auto v1 = b->Slice(in1, starts, limits, strides);
+ starts.back() = 1;
+ limits.back() = 2;
+ auto u2 = b->Slice(in0, starts, limits, strides);
+ auto v2 = b->Slice(in1, starts, limits, strides);
+ starts.back() = 2;
+ limits.back() = 3;
+ auto u3 = b->Slice(in0, starts, limits, strides);
+ auto v3 = b->Slice(in1, starts, limits, strides);
+
+ auto s1 = b->Sub(b->Mul(u2, v3), b->Mul(u3, v2));
+ auto s2 = b->Sub(b->Mul(u3, v1), b->Mul(u1, v3));
+ auto s3 = b->Sub(b->Mul(u1, v2), b->Mul(u2, v1));
+ auto output = b->ConcatInDim({s1, s2, s3}, in0_shape.dims() - 1);
+
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(CrossOp);
+};
+
+REGISTER_XLA_OP(Name("Cross"), CrossOp);
+
+} // namespace
+} // namespace tensorflow