aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/unary_ops_composition.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/unary_ops_composition.cc')
-rw-r--r--tensorflow/core/kernels/unary_ops_composition.cc432
1 files changed, 432 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/unary_ops_composition.cc b/tensorflow/core/kernels/unary_ops_composition.cc
new file mode 100644
index 0000000000..0c2cb1b39f
--- /dev/null
+++ b/tensorflow/core/kernels/unary_ops_composition.cc
@@ -0,0 +1,432 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/cwise_ops.h"
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/kernels/relu_op_functor.h"
+
+namespace tensorflow {
+
+template <typename T>
+class UnaryOpsComposition; // forward declare kernel
+
+template <typename T>
+struct UnaryOpsCompositionSupport;
+
+template <typename T>
+struct UnaryOpsCompositionBase {
+ using InputBuffer = typename TTypes<T>::ConstFlat;
+ using OutputBuffer = typename TTypes<T>::Flat;
+
+ using ComputeFn = void (*)(const InputBuffer&, OutputBuffer*);
+
+ struct ComputeFnRegistration {
+ ComputeFn compute_fn;
+ int cost;
+ };
+
+ bool HasComputeFn(const string& name) {
+ return compute_fns.find(name) != compute_fns.end();
+ }
+
+ protected:
+ void RegisterComputeFn(const string& name, ComputeFn compute_fn, int cost) {
+ VLOG(5) << "Register compute fn: name=" << name << " cost=" << cost;
+ compute_fns[name] = {compute_fn, cost};
+ }
+
+ private:
+ friend class UnaryOpsComposition<T>;
+
+ Status ExportComputeFns(const std::vector<string>& op_names,
+ std::vector<ComputeFn>* fns, int* cost) {
+ for (const string& op_name : op_names) {
+ auto it = compute_fns.find(op_name);
+ if (it == compute_fns.end())
+ return errors::InvalidArgument(
+ "Do not have a compute function registered for op: ", op_name);
+
+ const ComputeFnRegistration& reg = it->second;
+ fns->push_back(reg.compute_fn);
+ *cost += reg.cost;
+ }
+
+ return Status::OK();
+ }
+
+ std::unordered_map<string, ComputeFnRegistration> compute_fns;
+};
+
+template <typename T>
+class UnaryOpsComposition : public OpKernel {
+ public:
+ using Kernel = UnaryOpsComposition<T>;
+
+ using Scalar = T;
+ using Packet = typename Eigen::internal::packet_traits<T>::type;
+
+ using Support = UnaryOpsCompositionSupport<T>;
+
+ using InputBuffer = typename Support::InputBuffer;
+ using OutputBuffer = typename Support::OutputBuffer;
+ using ComputeFn = typename Support::ComputeFn;
+
+ explicit UnaryOpsComposition(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("op_names", &op_names_));
+
+ OP_REQUIRES(context, !op_names_.empty(),
+ errors::InvalidArgument(
+ "Unary op composition must have at least one op"));
+
+ OP_REQUIRES_OK(context,
+ support_.ExportComputeFns(op_names_, &fns_, &cost_));
+
+ VLOG(2) << "Composed unary op: [" << str_util::Join(op_names_, ", ")
+ << "]; cost=" << cost_;
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& in = ctx->input(0);
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(
+ ctx, ctx->forward_input_or_allocate_output({0}, 0, in.shape(), &out));
+
+ InputBuffer in_flat = in.flat<T>();
+ OutputBuffer out_flat = out->flat<T>();
+
+ const std::size_t num_fns = fns_.size();
+ auto compute_fn = [this, &in_flat, &out_flat, &num_fns](int64 begin,
+ int64 end) {
+ int64 len = end - begin;
+ const InputBuffer in_slice(in_flat.data() + begin, len);
+ const InputBuffer scratch_slice(out_flat.data() + begin, len);
+ OutputBuffer out_slice(out_flat.data() + begin, len);
+
+ fns_[0](in_slice, &out_slice);
+ for (int i = 1; i < num_fns; ++i) {
+ fns_[i](scratch_slice, &out_slice);
+ }
+ };
+
+ const CPUDevice& device = ctx->eigen_device<CPUDevice>();
+ const int kOverheadCycles = static_cast<int>(num_fns) * 10;
+ Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T) * num_fns,
+ /*bytes_stored=*/sizeof(T) * num_fns,
+ kOverheadCycles + cost_);
+ device.parallelFor(in.NumElements(), cost, AlignBlockSize,
+ std::move(compute_fn));
+ }
+
+ private:
+ static const int kPacketSize = Eigen::internal::unpacket_traits<Packet>::size;
+
+ static inline int64 AlignBlockSize(int64 block_size) {
+ // Align block size to packet size and account for unrolling in run above.
+ if (block_size >= 16 * kPacketSize) {
+ return (block_size + 4 * kPacketSize - 1) & ~(4 * kPacketSize - 1);
+ }
+ // Aligning to 4 * PacketSize would increase block size by more than 25%.
+ return (block_size + kPacketSize - 1) & ~(kPacketSize - 1);
+ }
+
+ Support support_;
+
+ std::vector<string> op_names_;
+ std::vector<ComputeFn> fns_;
+ int cost_ = 0;
+};
+
+// Register compute functions for UnaryOp functors.
+#define REGISTER_COMPUTE_FN_HELPER(name, functor) \
+ static_assert(std::is_same<functor::in_type, functor::out_type>::value, \
+ "Functor must have same input and output types"); \
+ \
+ static inline void Compute##name(const InputBuffer& in, OutputBuffer* out) { \
+ *out = in.unaryExpr(functor::func()); \
+ } \
+ static inline int Cost##name() { \
+ return Eigen::internal::functor_traits<functor::func>::Cost; \
+ }
+
+// Register compute function for the Relu/Relu6/Elu/Selu.
+#define REGISTER_RELU_HELPER() \
+ template <typename T> \
+ using functor_traits = Eigen::internal::functor_traits<T>; \
+ \
+ static inline void ComputeRelu(const InputBuffer& in, OutputBuffer* out) { \
+ auto relu = functor::Relu<Eigen::DefaultDevice, T>(); \
+ relu(Eigen::DefaultDevice(), in, *out); \
+ } \
+ \
+ static inline int CostRelu() { \
+ return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost; \
+ } \
+ \
+ static inline void ComputeRelu6(const InputBuffer& in, OutputBuffer* out) { \
+ auto relu6 = functor::Relu6<Eigen::DefaultDevice, T>(); \
+ relu6(Eigen::DefaultDevice(), in, *out); \
+ } \
+ \
+ static inline int CostRelu6() { \
+ return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost + \
+ functor_traits<Eigen::internal::scalar_min_op<T>>::Cost; \
+ } \
+ static inline void ComputeElu(const InputBuffer& in, OutputBuffer* out) { \
+ auto elu = functor::Elu<Eigen::DefaultDevice, T>(); \
+ elu(Eigen::DefaultDevice(), in, *out); \
+ } \
+ \
+ static inline int CostElu() { \
+ return functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost + \
+ Eigen::NumTraits<T>::MulCost; \
+ } \
+ static inline void ComputeSelu(const InputBuffer& in, OutputBuffer* out) { \
+ auto selu = functor::Selu<Eigen::DefaultDevice, T>(); \
+ selu(Eigen::DefaultDevice(), in, *out); \
+ } \
+ \
+ static inline int CostSelu() { \
+ return 2 * (functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost + \
+ Eigen::NumTraits<T>::MulCost); \
+ }
+
+#define REGISTER_COMPUTE_FN(func) \
+ RegisterComputeFn(#func, Compute##func, Cost##func());
+
+template <>
+struct UnaryOpsCompositionSupport<float> : UnaryOpsCompositionBase<float> {
+ using T = float;
+
+ UnaryOpsCompositionSupport() {
+ // UnaryOp functors.
+ REGISTER_COMPUTE_FN(Abs);
+ REGISTER_COMPUTE_FN(Acos);
+ REGISTER_COMPUTE_FN(Acosh);
+ REGISTER_COMPUTE_FN(Asin);
+ REGISTER_COMPUTE_FN(Asinh);
+ REGISTER_COMPUTE_FN(Atan);
+ REGISTER_COMPUTE_FN(Atanh);
+ REGISTER_COMPUTE_FN(Ceil);
+ REGISTER_COMPUTE_FN(Cos);
+ REGISTER_COMPUTE_FN(Cosh);
+ REGISTER_COMPUTE_FN(Expm1);
+ REGISTER_COMPUTE_FN(Exp);
+ REGISTER_COMPUTE_FN(Floor);
+ REGISTER_COMPUTE_FN(Inv);
+ REGISTER_COMPUTE_FN(Log);
+ REGISTER_COMPUTE_FN(Log1p);
+ REGISTER_COMPUTE_FN(Neg);
+ REGISTER_COMPUTE_FN(Reciprocal);
+ REGISTER_COMPUTE_FN(Rint);
+ REGISTER_COMPUTE_FN(Round);
+ REGISTER_COMPUTE_FN(Rsqrt);
+ REGISTER_COMPUTE_FN(Sigmoid);
+ REGISTER_COMPUTE_FN(Sin);
+ REGISTER_COMPUTE_FN(Sinh);
+ REGISTER_COMPUTE_FN(Sqrt);
+ REGISTER_COMPUTE_FN(Square);
+ REGISTER_COMPUTE_FN(Tan);
+ REGISTER_COMPUTE_FN(Tanh);
+
+ // Additional compute functions not defined via UnaryOp functors.
+ REGISTER_COMPUTE_FN(Elu);
+ REGISTER_COMPUTE_FN(Relu);
+ REGISTER_COMPUTE_FN(Relu6);
+ REGISTER_COMPUTE_FN(Selu);
+ }
+
+ REGISTER_RELU_HELPER();
+
+ // clang-format off
+ REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>);
+ REGISTER_COMPUTE_FN_HELPER(Acos, functor::acos<T>);
+ REGISTER_COMPUTE_FN_HELPER(Acosh, functor::acosh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Asin, functor::asin<T>);
+ REGISTER_COMPUTE_FN_HELPER(Asinh, functor::asinh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Atan, functor::atan<T>);
+ REGISTER_COMPUTE_FN_HELPER(Atanh, functor::atanh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>);
+ REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>);
+ REGISTER_COMPUTE_FN_HELPER(Cosh, functor::cosh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>);
+ REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>);
+ REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>);
+ REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>);
+ REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>);
+ REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Rint, functor::rint<T>);
+ REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>);
+ REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sinh, functor::sinh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>);
+ REGISTER_COMPUTE_FN_HELPER(Tan, functor::tan<T>);
+ REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>);
+ // clang-format on
+};
+
+template <>
+struct UnaryOpsCompositionSupport<Eigen::half>
+ : UnaryOpsCompositionBase<Eigen::half> {
+ using T = Eigen::half;
+
+ UnaryOpsCompositionSupport() {
+ REGISTER_COMPUTE_FN(Abs);
+ REGISTER_COMPUTE_FN(Ceil);
+ REGISTER_COMPUTE_FN(Cos);
+ REGISTER_COMPUTE_FN(Expm1);
+ REGISTER_COMPUTE_FN(Exp);
+ REGISTER_COMPUTE_FN(Floor);
+ REGISTER_COMPUTE_FN(Inv);
+ REGISTER_COMPUTE_FN(Log);
+ REGISTER_COMPUTE_FN(Log1p);
+ REGISTER_COMPUTE_FN(Neg);
+ REGISTER_COMPUTE_FN(Reciprocal);
+ REGISTER_COMPUTE_FN(Round);
+ REGISTER_COMPUTE_FN(Rsqrt);
+ REGISTER_COMPUTE_FN(Sigmoid);
+ REGISTER_COMPUTE_FN(Sin);
+ REGISTER_COMPUTE_FN(Sqrt);
+ REGISTER_COMPUTE_FN(Square);
+ REGISTER_COMPUTE_FN(Tanh);
+ // Additional compute functions not defined via UnaryOp functors.
+ REGISTER_COMPUTE_FN(Elu);
+ REGISTER_COMPUTE_FN(Relu);
+ REGISTER_COMPUTE_FN(Relu6);
+ REGISTER_COMPUTE_FN(Selu);
+ }
+
+ REGISTER_RELU_HELPER();
+
+ // clang-format off
+ REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>);
+ REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>);
+ REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>);
+ REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>);
+ REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>);
+ REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>);
+ REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>);
+ REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>);
+ REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>);
+ REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>);
+ REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>);
+ // clang-format on
+};
+
+template <>
+struct UnaryOpsCompositionSupport<double> : UnaryOpsCompositionBase<double> {
+ using T = double;
+
+ UnaryOpsCompositionSupport() {
+ REGISTER_COMPUTE_FN(Abs);
+ REGISTER_COMPUTE_FN(Acos);
+ REGISTER_COMPUTE_FN(Acosh);
+ REGISTER_COMPUTE_FN(Asin);
+ REGISTER_COMPUTE_FN(Asinh);
+ REGISTER_COMPUTE_FN(Atan);
+ REGISTER_COMPUTE_FN(Atanh);
+ REGISTER_COMPUTE_FN(Ceil);
+ REGISTER_COMPUTE_FN(Cos);
+ REGISTER_COMPUTE_FN(Cosh);
+ REGISTER_COMPUTE_FN(Expm1);
+ REGISTER_COMPUTE_FN(Exp);
+ REGISTER_COMPUTE_FN(Floor);
+ REGISTER_COMPUTE_FN(Inv);
+ REGISTER_COMPUTE_FN(Log);
+ REGISTER_COMPUTE_FN(Log1p);
+ REGISTER_COMPUTE_FN(Neg);
+ REGISTER_COMPUTE_FN(Reciprocal);
+ REGISTER_COMPUTE_FN(Rint);
+ REGISTER_COMPUTE_FN(Round);
+ REGISTER_COMPUTE_FN(Rsqrt);
+ REGISTER_COMPUTE_FN(Sigmoid);
+ REGISTER_COMPUTE_FN(Sin);
+ REGISTER_COMPUTE_FN(Sinh);
+ REGISTER_COMPUTE_FN(Sqrt);
+ REGISTER_COMPUTE_FN(Square);
+ REGISTER_COMPUTE_FN(Tan);
+ REGISTER_COMPUTE_FN(Tanh);
+ // Additional compute functions not defined via UnaryOp functors.
+ REGISTER_COMPUTE_FN(Elu);
+ REGISTER_COMPUTE_FN(Relu);
+ REGISTER_COMPUTE_FN(Relu6);
+ REGISTER_COMPUTE_FN(Selu);
+ }
+
+ REGISTER_RELU_HELPER();
+
+ // clang-format off
+ REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>);
+ REGISTER_COMPUTE_FN_HELPER(Acos, functor::acos<T>);
+ REGISTER_COMPUTE_FN_HELPER(Acosh, functor::acosh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Asin, functor::asin<T>);
+ REGISTER_COMPUTE_FN_HELPER(Asinh, functor::asinh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Atan, functor::atan<T>);
+ REGISTER_COMPUTE_FN_HELPER(Atanh, functor::atanh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>);
+ REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>);
+ REGISTER_COMPUTE_FN_HELPER(Cosh, functor::cosh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>);
+ REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>);
+ REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>);
+ REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>);
+ REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>);
+ REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Rint, functor::rint<T>);
+ REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>);
+ REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sinh, functor::sinh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>);
+ REGISTER_COMPUTE_FN_HELPER(Tan, functor::tan<T>);
+ REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>);
+ // clang-format on
+};
+
+// Register the CPU kernels.
+#define REGISTER_CPU(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_UnaryOpsComposition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ UnaryOpsComposition<T>);
+
+REGISTER_CPU(float);
+REGISTER_CPU(Eigen::half);
+REGISTER_CPU(double);
+
+#undef REGISTER_CPU
+
+} // namespace tensorflow