aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-10-04 05:34:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 05:38:38 -0700
commit28f239fdfa0c94f715fccf0197ab6c3c8df27d28 (patch)
tree9a5998dc21a3428c292e597afce0a4a15ffeee26 /tensorflow/compiler
parent9cd6cab4f85f1f35c6532da3fb68839294d44ee4 (diff)
Implement DataFormatVecPermute for XLA.
Also clear "_kernel" attributes of nodes if they are set to "host". This is not meaningful when processing the graph for XLA, and it would prevent finding the registered XLA kernel. PiperOrigin-RevId: 215722216
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tests/BUILD13
-rw-r--r--tensorflow/compiler/tests/permute_test.py80
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/permute_op.cc98
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc11
5 files changed, 203 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 822fedf121..ee36729fd1 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1029,6 +1029,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "permute_test",
+ size = "small",
+ srcs = ["permute_test.py"],
+ deps = [
+ "//tensorflow/compiler/tests:xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:nn_ops",
+ ],
+)
+
+tf_xla_py_test(
name = "xla_device_test",
size = "small",
srcs = ["xla_device_test.py"],
diff --git a/tensorflow/compiler/tests/permute_test.py b/tensorflow/compiler/tests/permute_test.py
new file mode 100644
index 0000000000..dbb9274df4
--- /dev/null
+++ b/tensorflow/compiler/tests/permute_test.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the DataFormatVecPermute operator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+
+
+class XlaPermuteOpTest(xla_test.XLATestCase):
+
+ def _runPermuteAndCompare(self, x, src_format, dst_format, expected):
+ with self.cached_session() as session:
+ with self.test_scope():
+ placeholder = array_ops.placeholder(dtypes.as_dtype(x.dtype), x.shape)
+ param = {placeholder: x}
+ output = nn_ops.data_format_vec_permute(
+ placeholder, src_format=src_format, dst_format=dst_format)
+ result = session.run(output, param)
+ self.assertAllEqual(result, expected)
+
+ def testNHWCToNCHW(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9])
+
+ def testNCHWToNHWC(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4])
+
+ def testNHWCToHWNC(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "HWNC", [4, 9, 7, 3])
+
+ def testHWNCToNHWC(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "HWNC", "NHWC", [9, 7, 4, 3])
+
+ def testNHWCToNCHW2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "NCHW",
+ [[7, 4], [5, 1], [9, 3], [4, 5]])
+
+ def testNHWCToHWNC2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "HWNC",
+ [[9, 3], [4, 5], [7, 4], [5, 1]])
+
+ def testHWNCToNHWC2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "HWNC", "NHWC",
+ [[4, 5], [7, 4], [9, 3], [5, 1]])
+
+ def testNCHWToNHWC2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NCHW", "NHWC",
+ [[7, 4], [4, 5], [5, 1], [9, 3]])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 3e823254d3..9a7130f253 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -62,6 +62,7 @@ tf_kernel_library(
"one_hot_op.cc",
"pack_op.cc",
"pad_op.cc",
+ "permute_op.cc",
"pooling_ops.cc",
"qr_op.cc",
"quantize_and_dequantize_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc
new file mode 100644
index 0000000000..0764e5503d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+class DataFormatVecPermuteOp : public XlaOpKernel {
+ public:
+ explicit DataFormatVecPermuteOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_));
+ OP_REQUIRES(
+ ctx, src_format_.size() == 4,
+ errors::InvalidArgument("Data format should have 4 characters"));
+ TensorFormat data_format;
+ OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_));
+ OP_REQUIRES(
+ ctx, dst_format_.size() == 4,
+ errors::InvalidArgument("Data format should have 4 characters"));
+ OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format),
+ errors::InvalidArgument("Invalid data format"));
+ }
+ void Compile(XlaOpKernelContext* ctx) override {
+ auto builder = ctx->builder();
+ const TensorShape input_tensor_shape = ctx->InputShape(0);
+ int input_rank = input_tensor_shape.dims();
+ OP_REQUIRES(ctx, input_rank == 1 || input_rank == 2,
+ errors::InvalidArgument(
+ "Input must be a vector or matrix, but got shape ",
+ input_tensor_shape.DebugString()));
+ OP_REQUIRES(
+ ctx, input_tensor_shape.dim_size(0) == 4,
+ errors::InvalidArgument(
+ "First dimension of input must be of size 4, but got shape ",
+ input_tensor_shape.DebugString()));
+ if (input_rank == 2) {
+ OP_REQUIRES(
+ ctx, input_tensor_shape.dim_size(1) == 2,
+ errors::InvalidArgument(
+ "Second dimension of 2D input must be of size 2, but got shape ",
+ input_tensor_shape.DebugString()));
+ }
+ std::vector<int32> dst_indices(4, 0);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ if (src_format_[i] == dst_format_[j]) {
+ dst_indices[i] = j;
+ break;
+ }
+ }
+ }
+ auto keys = xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
+ if (input_rank == 2) {
+ keys = xla::BroadcastInDim(
+ keys, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), {0});
+ }
+ auto sorted = xla::Sort(keys, ctx->Input(0), 0);
+ auto output = xla::GetTupleElement(sorted, 1);
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ string src_format_;
+ string dst_format_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DataFormatVecPermuteOp);
+};
+
+// TODO(b/115384656): Support DT_INT64.
+REGISTER_XLA_OP(Name("DataFormatVecPermute").TypeConstraint("T", DT_INT32),
+ DataFormatVecPermuteOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index d5094e8ec5..b2c57e8880 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -194,6 +194,17 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
std::unique_ptr<Graph> graph = GetGraph(fbody);
+ // Clear the "_kernel" attribute if it is set to "host". This is used to
+ // indicate that a computation should happen on the host instead of the
+ // accelerator, but doesn't make sense in XLA.
+ const char* const kKernelAttr = "_kernel";
+ for (Node* n : graph->nodes()) {
+ string value;
+ if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") {
+ n->ClearAttr(kKernelAttr);
+ }
+ }
+
// _Arg and _Retval nodes don't exist in the stored subgraph for the function;
// they are added by the function body looked up. Therefore, they don't have
// core assignments here.