diff options
author | 2018-10-04 05:34:55 -0700 | |
---|---|---|
committer | 2018-10-04 05:38:38 -0700 | |
commit | 28f239fdfa0c94f715fccf0197ab6c3c8df27d28 (patch) | |
tree | 9a5998dc21a3428c292e597afce0a4a15ffeee26 /tensorflow/compiler/tf2xla | |
parent | 9cd6cab4f85f1f35c6532da3fb68839294d44ee4 (diff) |
Implement DataFormatVecPermute for XLA.
Also clear "_kernel" attributes of nodes if they are set to "host".
This is not meaningful when processing the graph for XLA, and it
would prevent finding the registered XLA kernel.
PiperOrigin-RevId: 215722216
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/permute_op.cc | 98 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 11 |
3 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 3e823254d3..9a7130f253 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -62,6 +62,7 @@ tf_kernel_library( "one_hot_op.cc", "pack_op.cc", "pad_op.cc", + "permute_op.cc", "pooling_ops.cc", "qr_op.cc", "quantize_and_dequantize_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc new file mode 100644 index 0000000000..0764e5503d --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <string> +#include <vector> + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +class DataFormatVecPermuteOp : public XlaOpKernel { + public: + explicit DataFormatVecPermuteOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_)); + OP_REQUIRES( + ctx, src_format_.size() == 4, + errors::InvalidArgument("Data format should have 4 characters")); + TensorFormat data_format; + OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_)); + OP_REQUIRES( + ctx, dst_format_.size() == 4, + errors::InvalidArgument("Data format should have 4 characters")); + OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format), + errors::InvalidArgument("Invalid data format")); + } + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + const TensorShape input_tensor_shape = ctx->InputShape(0); + int input_rank = input_tensor_shape.dims(); + OP_REQUIRES(ctx, input_rank == 1 || input_rank == 2, + errors::InvalidArgument( + "Input must be a vector or matrix, but got shape ", + input_tensor_shape.DebugString())); + OP_REQUIRES( + ctx, input_tensor_shape.dim_size(0) == 4, + errors::InvalidArgument( + "First dimension of input must be of size 4, but got shape ", + input_tensor_shape.DebugString())); + if (input_rank == 2) { + OP_REQUIRES( + ctx, input_tensor_shape.dim_size(1) == 2, + errors::InvalidArgument( + "Second dimension of 2D input must be of size 2, but got shape ", + input_tensor_shape.DebugString())); + } + std::vector<int32> dst_indices(4, 0); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + if (src_format_[i] == dst_format_[j]) { + dst_indices[i] = j; + break; + } + } + } + auto keys = xla::ConstantR1(builder, absl::Span<const int32>(dst_indices)); + if (input_rank == 2) { + keys = xla::BroadcastInDim( + keys, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), {0}); + } + auto sorted = xla::Sort(keys, ctx->Input(0), 0); + auto output = xla::GetTupleElement(sorted, 1); + ctx->SetOutput(0, output); + } + + private: + string src_format_; + string dst_format_; + + TF_DISALLOW_COPY_AND_ASSIGN(DataFormatVecPermuteOp); +}; + +// TODO(b/115384656): Support DT_INT64. +REGISTER_XLA_OP(Name("DataFormatVecPermute").TypeConstraint("T", DT_INT32), + DataFormatVecPermuteOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index d5094e8ec5..b2c57e8880 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -194,6 +194,17 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, std::unique_ptr<Graph> graph = GetGraph(fbody); + // Clear the "_kernel" attribute if it is set to "host". This is used to + // indicate that a computation should happen on the host instead of the + // accelerator, but doesn't make sense in XLA. + const char* const kKernelAttr = "_kernel"; + for (Node* n : graph->nodes()) { + string value; + if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") { + n->ClearAttr(kKernelAttr); + } + } + // _Arg and _Retval nodes don't exist in the stored subgraph for the function; // they are added by the function body looked up. Therefore, they don't have // core assignments here. |