aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_cwise_ops_common.cc
diff options
context:
space:
mode:
authorGravatar Vivek Rane <vivek.v.rane@intel.com>2017-07-09 23:16:36 -0700
committerGravatar Vivek Rane <vivek.v.rane@intel.com>2017-08-28 14:23:14 -0700
commit2546f7930e4488dd7fa6f482ed1fb389d2d32774 (patch)
tree09cd587a3a2690801fbe1e885e039ed8fe5b6d01 /tensorflow/core/kernels/mkl_cwise_ops_common.cc
parent668db64a5d612d5f96b5d87772ce6ff6531fc035 (diff)
Added MKL element-wise ops that utilize eigen ops as their back-end. Also added an input-conversion op that ensures that shapes of both input tensors are compatible (same or broadcastable). Added SquaredDifference to layout pass, and fixed the test for layout pass (it assumed Add/Mul/Sub would not be substituted with Mkl ops) Fixed missing edge deletion for 2 edges Fixing condition for checking broadcast Added more sanity checks Changed check for incoming control edges to code that moves control edges from the elementwise node to the inputconversion node. Fixed bug in CHECK for input types, which did not consider output number of source while checking datatype Fixed unit test bugs and added SquaredDifference elementwise op for CycleGAN Fixed merge issue (code duplication) Ran buildifier and clang-format on changed code Fixed a couple of merge issues and clang-format changes outside the modified code
Diffstat (limited to 'tensorflow/core/kernels/mkl_cwise_ops_common.cc')
-rw-r--r--tensorflow/core/kernels/mkl_cwise_ops_common.cc88
1 file changed, 88 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
new file mode 100644
index 0000000000..7fc633c254
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
@@ -0,0 +1,88 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+#include <iostream>
+#include <vector>
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+#include "tensorflow/core/util/mkl_util.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Device, typename Functor>
+class MklBinaryOp : public BinaryOp<Device, Functor> {
+ public:
+ explicit MklBinaryOp(OpKernelConstruction* context)
+ : BinaryOp<Device, Functor>(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ auto in0 = context->input(0);
+ auto in1 = context->input(1);
+ VLOG(1) << "Shapes (start mklbinaryop compute): "
+ << in0.shape().DebugString() << " _and_ "
+ << in1.shape().DebugString();
+
+ // Call the TensorFlow BinaryOp Compute method
+ BinaryOp<Device, Functor>::Compute(context);
+
+ auto out = context->mutable_output(0);
+ VLOG(1) << "Shapes (output): " << out->shape().DebugString();
+
+ // Pass input shape through to ouput shape
+ ForwardMklMetaDataInToOut(context, 0, 0);
+
+ out = context->mutable_output(0);
+ VLOG(1) << "Shapes (output): " << out->shape().DebugString();
+ }
+};
+
+//---------- Registration macros for various element-wise ops -----------
+// We will need to redefine "REGISTER" to include the mkl_op_registry flag
+#pragma push_macro("REGISTER")
+#undef REGISTER
+#define REGISTER(OP, D, N, F, T) \
+ REGISTER_KERNEL_BUILDER(Name(N) \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ OP<D##Device, F<T>>);
+
+REGISTER5(MklBinaryOp, CPU, "_MklAdd", functor::add, float, Eigen::half, double,
+ int32, int64);
+REGISTER7(MklBinaryOp, CPU, "_MklSub", functor::sub, float, Eigen::half, double,
+ int32, int64, complex64, complex128);
+REGISTER5(MklBinaryOp, CPU, "_MklMul", functor::mul, float, Eigen::half, double,
+ uint8, int32);
+REGISTER5(MklBinaryOp, CPU, "_MklMaximum", functor::maximum, float, Eigen::half,
+ double, int32, int64);
+REGISTER5(MklBinaryOp, CPU, "_MklSquaredDifference",
+ functor::squared_difference, float, Eigen::half, double, int32,
+ int64);
+
+#undef REGISTER
+#pragma pop_macro("REGISTER")
+//-----------------------------------------------------------------------
+
+} // end namespace tensorflow
+
+#endif // INTEL_MKL