/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA lowerings for resource-variable ops: initialization and shape queries,
// reads, (compound) assignments, gathers, and scatter-based updates.

#include <functional>
#include <utility>

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace {

// Lowers VarIsInitializedOp to a scalar boolean constant. The value is taken
// from the XlaResource bound to input 0 while compiling, so the emitted
// computation contains only a constant.
class VarIsInitializedOp : public XlaOpKernel {
 public:
  explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    XlaResource* variable = nullptr;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
    ctx->SetOutput(
        0, xla::ConstantR0<bool>(ctx->builder(), variable->initialized()));
  }
};
REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);

// Lowers VariableShape to a constant 1-D tensor holding the dimensions of the
// variable in input 0. The element type comes from the `out_type` attr.
class VariableShapeOp : public XlaOpKernel {
 public:
  explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType variable_dtype = DT_INVALID;
    TensorShape shape;
    OP_REQUIRES_OK(ctx,
                   ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape));
    // One element per dimension of the variable's shape.
    Tensor shape_constant(out_dtype_, TensorShape({shape.dims()}));
    OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant));
    ctx->SetConstantOutput(0, shape_constant);
  }

 private:
  DataType out_dtype_;  // Element type of the emitted shape vector.
};
REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp);

// Lowers ReadVariableOp: outputs the current value of the variable in input 0,
// read with the dtype given by the `dtype` attr.
class ReadVariableOp : public XlaOpKernel {
 public:
  explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaOp handle;
    OP_REQUIRES_OK(
        ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
    ctx->SetOutput(0, handle);
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ReadVariableOp").CompilationOnly(), ReadVariableOp);

// Lowers AssignVariableOp: stores input 1 into the variable in input 0, using
// input 1's dtype.
class AssignVariableOp : public XlaOpKernel {
 public:
  explicit AssignVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    OP_REQUIRES_OK(ctx,
                   ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1)));
  }
};
REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp);

// Shared read-modify-write lowering for AssignAddVariableOp and
// AssignSubVariableOp. Reads the variable in input 0 using the dtype of the
// value operand (input 1), computes `combine(current, value)`, and assigns the
// result back to the variable with the same dtype. OP_REQUIRES_OK reports the
// error and returns early if the read or the assignment fails.
template <typename Combine>
void CompileReadModifyWrite(XlaOpKernelContext* ctx, Combine combine) {
  const DataType type = ctx->input_type(1);
  xla::XlaOp handle;
  OP_REQUIRES_OK(ctx,
                 ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
  handle = combine(handle, ctx->Input(1));
  OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}

// Lowers AssignAddVariableOp: variable <- variable + input 1.
class AssignAddVariableOp : public XlaOpKernel {
 public:
  explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    CompileReadModifyWrite(
        ctx, [](const xla::XlaOp& current, const xla::XlaOp& value) {
          return xla::Add(current, value);
        });
  }
};
REGISTER_XLA_OP(
    Name("AssignAddVariableOp").TypeConstraint("dtype", kNumericTypes),
    AssignAddVariableOp);

// Lowers AssignSubVariableOp: variable <- variable - input 1.
class AssignSubVariableOp : public XlaOpKernel {
 public:
  explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    CompileReadModifyWrite(
        ctx, [](const xla::XlaOp& current, const xla::XlaOp& value) {
          return xla::Sub(current, value);
        });
  }
};
REGISTER_XLA_OP(
    Name("AssignSubVariableOp").TypeConstraint("dtype", kNumericTypes),
    AssignSubVariableOp);

// Lowers ResourceGather: gathers from the variable in input 0 along axis 0,
// using the (non-ND) indices in input 1.
class ResourceGatherOp : public XlaOpKernel {
 public:
  explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* builder = ctx->builder();

    // The variable is read with the op's declared output dtype.
    const DataType type = ctx->expected_output_dtype(0);

    TensorShape resource_shape;
    xla::XlaOp resource_handle;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
                                               &resource_handle));

    const xla::XlaOp indices = ctx->Input(1);
    const TensorShape indices_shape = ctx->InputShape(1);
    const DataType index_type = ctx->input_type(1);

    xla::XlaOp gather;
    OP_REQUIRES_OK(
        ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape,
                       /*axis=*/0, /*indices_are_nd=*/false, type, index_type,
                       builder, &gather));
    ctx->SetOutput(0, gather);
  }
};
REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp);

// Element combiner forwarded to XlaScatter. It receives two operands
// (presumably the existing buffer value and the update, in that order; confirm
// in tf2xla/lib/scatter.h) plus the builder, and returns the value to store.
// The *Update ops pass an empty combiner. NOTE(review): per XlaScatter's
// contract this presumably means "overwrite"; confirm in scatter.h.
using ScatterCombiner = std::function<xla::XlaOp(
    const xla::XlaOp&, const xla::XlaOp&, xla::XlaBuilder*)>;

// Shared element combiners for the scatter ops below. The builder argument is
// part of the ScatterCombiner signature but is not needed by these
// element-wise operations.
xla::XlaOp CombineAdd(const xla::XlaOp& x, const xla::XlaOp& y,
                      xla::XlaBuilder* /*builder*/) {
  return xla::Add(x, y);
}
xla::XlaOp CombineSub(const xla::XlaOp& x, const xla::XlaOp& y,
                      xla::XlaBuilder* /*builder*/) {
  return xla::Sub(x, y);
}
xla::XlaOp CombineMul(const xla::XlaOp& x, const xla::XlaOp& y,
                      xla::XlaBuilder* /*builder*/) {
  return xla::Mul(x, y);
}
xla::XlaOp CombineDiv(const xla::XlaOp& x, const xla::XlaOp& y,
                      xla::XlaBuilder* /*builder*/) {
  return xla::Div(x, y);
}
xla::XlaOp CombineMin(const xla::XlaOp& x, const xla::XlaOp& y,
                      xla::XlaBuilder* /*builder*/) {
  return xla::Min(x, y);
}
xla::XlaOp CombineMax(const xla::XlaOp& x, const xla::XlaOp& y,
                      xla::XlaBuilder* /*builder*/) {
  return xla::Max(x, y);
}

// Common lowering for the ResourceScatter* family. Reads the variable in
// input 0 using the dtype of the updates (input 2), then scatters the updates
// at the indices in input 1 via XlaScatter with `combiner_`. The result is
// assigned back to the variable.
class ResourceScatterOp : public XlaOpKernel {
 public:
  // `indices_are_vectors` is true for the ScatterNd variants registered below.
  // It is forwarded unchanged to XlaScatter. `combiner` may be empty (see
  // ScatterCombiner).
  explicit ResourceScatterOp(OpKernelConstruction* context,
                             bool indices_are_vectors, ScatterCombiner combiner)
      : XlaOpKernel(context),
        indices_are_vectors_(indices_are_vectors),
        combiner_(std::move(combiner)) {}

  void Compile(XlaOpKernelContext* context) override {
    xla::XlaBuilder* builder = context->builder();

    const DataType dtype = context->input_type(2);
    xla::XlaOp var_value;
    // Only the variable's value is needed here. Its shape was never consulted,
    // so it is not requested.
    OP_REQUIRES_OK(context, context->ReadVariableInput(
                                0, dtype, /*shape=*/nullptr, &var_value));

    const xla::XlaOp indices = context->Input(1);
    const xla::XlaOp updates = context->Input(2);

    auto result = XlaScatter(var_value, updates, indices, indices_are_vectors_,
                             combiner_, builder);
    OP_REQUIRES_OK(context, result.status());
    OP_REQUIRES_OK(context,
                   context->AssignVariable(0, dtype, result.ValueOrDie()));
  }

 private:
  const bool indices_are_vectors_;
  const ScatterCombiner combiner_;
};

// ResourceScatterAdd: combines existing values with updates via Add.
class ResourceScatterAddOp : public ResourceScatterOp {
 public:
  explicit ResourceScatterAddOp(OpKernelConstruction* context)
      : ResourceScatterOp(context, /*indices_are_vectors=*/false, CombineAdd) {}
};
REGISTER_XLA_OP(Name("ResourceScatterAdd"), ResourceScatterAddOp);

// ResourceScatterSub: combines existing values with updates via Sub.
class ResourceScatterSubOp : public ResourceScatterOp {
 public:
  explicit ResourceScatterSubOp(OpKernelConstruction* context)
      : ResourceScatterOp(context, /*indices_are_vectors=*/false, CombineSub) {}
};
REGISTER_XLA_OP(Name("ResourceScatterSub"), ResourceScatterSubOp);

// ResourceScatterMul: combines existing values with updates via Mul.
class ResourceScatterMulOp : public ResourceScatterOp {
 public:
  explicit ResourceScatterMulOp(OpKernelConstruction* context)
      : ResourceScatterOp(context, /*indices_are_vectors=*/false, CombineMul) {}
};
REGISTER_XLA_OP(Name("ResourceScatterMul"), ResourceScatterMulOp);

// ResourceScatterDiv: combines existing values with updates via Div.
class ResourceScatterDivOp : public ResourceScatterOp {
 public:
  explicit ResourceScatterDivOp(OpKernelConstruction* context)
      : ResourceScatterOp(context, /*indices_are_vectors=*/false, CombineDiv) {}
};
REGISTER_XLA_OP(Name("ResourceScatterDiv"), ResourceScatterDivOp);

// ResourceScatterMin: combines existing values with updates via Min.
class ResourceScatterMinOp : public ResourceScatterOp {
 public:
  explicit ResourceScatterMinOp(OpKernelConstruction* context)
      : ResourceScatterOp(context, /*indices_are_vectors=*/false, CombineMin) {}
};
REGISTER_XLA_OP(Name("ResourceScatterMin"), ResourceScatterMinOp);

// ResourceScatterMax: combines existing values with updates via Max.
class ResourceScatterMaxOp : public ResourceScatterOp {
 public:
  explicit ResourceScatterMaxOp(OpKernelConstruction* context)
      : ResourceScatterOp(context, /*indices_are_vectors=*/false, CombineMax) {}
};
REGISTER_XLA_OP(Name("ResourceScatterMax"), ResourceScatterMaxOp);

// ResourceScatterUpdate: empty combiner (see ScatterCombiner).
class ResourceScatterUpdateOp : public ResourceScatterOp {
 public:
  explicit ResourceScatterUpdateOp(OpKernelConstruction* context)
      : ResourceScatterOp(context, /*indices_are_vectors=*/false,
                          /*combiner=*/{}) {}
};
REGISTER_XLA_OP(Name("ResourceScatterUpdate"), ResourceScatterUpdateOp);

// ResourceScatterNdUpdate: ND indices with an empty combiner.
class ResourceScatterNdUpdateOp : public ResourceScatterOp {
 public:
  explicit ResourceScatterNdUpdateOp(OpKernelConstruction* context)
      : ResourceScatterOp(context, /*indices_are_vectors=*/true,
                          /*combiner=*/{}) {}
};
REGISTER_XLA_OP(Name("ResourceScatterNdUpdate"), ResourceScatterNdUpdateOp);

// ResourceScatterNdAdd: ND indices, combining existing values with updates
// via Add.
class ResourceScatterNdAddOp : public ResourceScatterOp {
 public:
  explicit ResourceScatterNdAddOp(OpKernelConstruction* context)
      : ResourceScatterOp(context, /*indices_are_vectors=*/true,
                          /*combiner=*/CombineAdd) {}
};
REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp);

}  // namespace
}  // namespace tensorflow