Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/bias_ops.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/bias_ops.cc  119
1 file changed, 119 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
new file mode 100644
index 0000000000..217e82304e
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <numeric>
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+class BiasOp : public XlaOpKernel {
+ public:
+ explicit BiasOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ string data_format;
+ if (ctx->GetAttr("data_format", &data_format).ok()) {
+ OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ } else {
+ data_format_ = FORMAT_NHWC;
+ }
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const TensorShape bias_shape = ctx->InputShape(1);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrixOrHigher(input_shape),
+ errors::InvalidArgument("Input tensor must be at least 2D: ",
+ input_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias_shape),
+ errors::InvalidArgument("Biases must be 1D: ",
+ bias_shape.DebugString()));
+ int feature_dim = (data_format_ == FORMAT_NHWC) ? input_shape.dims() - 1
+ : input_shape.dims() - 3;
+ OP_REQUIRES(
+ ctx, feature_dim >= 0,
+ errors::InvalidArgument("Input tensor does not have enough dimensions "
+ "to contain the feature dimension"));
+ OP_REQUIRES(
+ ctx, bias_shape.dim_size(0) == input_shape.dim_size(feature_dim),
+ errors::InvalidArgument(
+ "Must provide as many biases as the last dimension "
+ "of the input tensor: ",
+ bias_shape.DebugString(), " vs. ", input_shape.DebugString()));
+
+ xla::ComputationDataHandle result =
+ ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim});
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ TensorFormat data_format_;
+};
+
+REGISTER_XLA_OP("BiasAdd", BiasOp);
+REGISTER_XLA_OP("BiasAddV1", BiasOp);
+
+class BiasAddGradOp : public XlaOpKernel {
+ public:
+ explicit BiasAddGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ string data_format;
+ if (ctx->GetAttr("data_format", &data_format).ok()) {
+ OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ } else {
+ data_format_ = FORMAT_NHWC;
+ }
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape out_backprop_shape = ctx->InputShape(0);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrixOrHigher(out_backprop_shape),
+ errors::InvalidArgument("Input tensor must be at least 2D: ",
+ out_backprop_shape.DebugString()));
+
+ int feature_dim = (data_format_ == FORMAT_NHWC)
+ ? out_backprop_shape.dims() - 1
+ : out_backprop_shape.dims() - 3;
+ OP_REQUIRES(
+ ctx, feature_dim >= 0,
+ errors::InvalidArgument("Input tensor does not have enough dimensions "
+ "to contain the feature dimension"));
+
+ std::vector<int64> reduce_dims(out_backprop_shape.dims() - 1);
+ std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0);
+ std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(),
+ feature_dim + 1);
+ xla::ComputationDataHandle result = ctx->builder()->Reduce(
+ ctx->Input(0), XlaHelpers::Zero(ctx->builder(), input_type(0)),
+ *ctx->GetOrCreateAdd(input_type(0)), reduce_dims);
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ TensorFormat data_format_;
+};
+
+REGISTER_XLA_OP("BiasAddGrad", BiasAddGradOp);
+
+} // namespace
+} // namespace tensorflow
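
Note on the forward kernel: BiasOp lowers BiasAdd/BiasAddV1 to a single XLA Add whose broadcast_dimensions list contains only feature_dim, so the rank-1 bias is broadcast along the channel axis (the last axis for NHWC, the axis at rank - 3 for NCHW) and added element-wise. Below is a minimal, hypothetical stand-alone sketch of that broadcast rule using plain C++ loops rather than the XLA builder; the file name, FeatureDim helper, shapes, and values are illustrative only and not part of the change.

// bias_add_sketch.cc -- hypothetical illustration; not part of the TensorFlow
// build. Compile with: g++ -std=c++11 bias_add_sketch.cc
#include <cstddef>
#include <iostream>
#include <vector>

// Mirrors the kernel's rule: the feature (channel) dimension is the last
// axis for NHWC and the axis at rank - 3 for NCHW.
int FeatureDim(bool is_nhwc, int rank) {
  return is_nhwc ? rank - 1 : rank - 3;
}

int main() {
  // A tiny NHWC input of shape [1, 2, 2, 3] and a bias of shape [3].
  const std::vector<int> shape = {1, 2, 2, 3};
  std::vector<float> input(1 * 2 * 2 * 3);
  for (std::size_t i = 0; i < input.size(); ++i) {
    input[i] = static_cast<float>(i);
  }
  const std::vector<float> bias = {10.f, 20.f, 30.f};

  const int feature_dim =
      FeatureDim(/*is_nhwc=*/true, static_cast<int>(shape.size()));  // == 3
  const int channels = shape[feature_dim];

  // Adding with broadcast_dimensions == {feature_dim} means bias[c] is added
  // to every input element whose index along the feature dimension is c. For
  // NHWC the feature axis is innermost, so the flattened index modulo the
  // channel count gives that index directly.
  std::vector<float> output(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    output[i] = input[i] + bias[i % channels];
  }

  for (float v : output) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}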
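
Note on the gradient kernel: BiasAddGradOp computes the bias gradient by summing out_backprop over every dimension except feature_dim; the two std::iota calls build exactly that list of reduce dimensions, and Reduce then folds the input with the cached scalar add computation from GetOrCreateAdd. A minimal, hypothetical sketch of the reduce_dims construction follows; the rank and feature_dim values are illustrative, and std::int64_t stands in for TensorFlow's int64.

// bias_add_grad_sketch.cc -- hypothetical illustration; not part of the
// TensorFlow build. Compile with: g++ -std=c++11 bias_add_grad_sketch.cc
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Assume a rank-4 out_backprop. feature_dim is rank - 1 for NHWC and
  // rank - 3 for NCHW, exactly as in the kernel.
  const int rank = 4;
  const int feature_dim = 3;  // NHWC; use 1 for NCHW

  // The bias gradient is the sum of out_backprop over every dimension except
  // feature_dim. The two std::iota calls fill [0, feature_dim) and then
  // (feature_dim, rank), so feature_dim itself is skipped.
  std::vector<std::int64_t> reduce_dims(rank - 1);
  std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0);
  std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(),
            feature_dim + 1);

  // NHWC prints "0 1 2" (reduce batch and spatial dims, keep channels);
  // NCHW would print "0 2 3".
  for (std::int64_t d : reduce_dims) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}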