diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_softmax_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_softmax_op.cc | 163 |
1 files changed, 163 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc new file mode 100644 index 0000000000..896d562933 --- /dev/null +++ b/tensorflow/core/kernels/mkl_softmax_op.cc @@ -0,0 +1,163 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/nn_ops.cc. +#ifdef INTEL_MKL +#ifdef INTEL_MKL_DNN + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/tensor_format.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "mkldnn.h" +#include "mkldnn_types.h" +#include "tensorflow/core/platform/default/logging.h" +#include "tensorflow/core/util/mkl_util.h" + +#include "mkldnn.hpp" +using mkldnn::stream; +using mkldnn::prop_kind; +using mkldnn::softmax_forward; + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + + + +template <typename Device, typename T> +class MklSoftmaxOp : public OpKernel { + public: + ~MklSoftmaxOp() {} + + explicit MklSoftmaxOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + try { + auto cpu_engine = engine(engine::cpu, 0); + + // src_tensor now points to the 0-th input of global data struct "context" + size_t src_idx = 0; + const Tensor& src_tensor = MklGetInput(context, src_idx); + + // Add: get MklShape + MklDnnShape src_mkl_shape; + GetMklShape(context, src_idx, &src_mkl_shape); + + + // src_dims is the dimenstion of src_tensor + // dim of the dst will also be same as src_dims + auto src_tf_shape = src_mkl_shape.IsMklTensor() ? + src_mkl_shape.GetTfShape() : src_tensor.shape(); + auto src_dims = TFShapeToMklDnnDims(src_tf_shape); + auto output_dims = src_dims; + + // Create softmax memory for src, dst: both are defined in mkl_util.h, + // they are wrapper + MklDnnData<T> src(&cpu_engine); + MklDnnData<T> dst(&cpu_engine); + + // If input is in MKL layout, then simply grab input layout; otherwise, + // construct input Tf layout. For TF layout, although input shape + // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's + // layout + auto src_md = src_mkl_shape.IsMklTensor() + ? src_mkl_shape.GetMklLayout() + : memory::desc(src_dims, MklDnnType<T>(), + memory::format::nc); + + // src: setting memory descriptor and op memory descriptor + // Basically following two functions maps the TF "src_tensor" to mkl + // tensor object "src" + // following functions are in mkl_util.h + // data format is "nc" for src and dst; since the src and dst buffer is + // always in 2D shape + src.SetUsrMem(src_md, &src_tensor); + src.SetOpMemDesc(src_dims, memory::format::nc); + + // creating a memory descriptor + int axis = 1; // axis to which softmax will be applied + auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring, + src.GetOpMemDesc(), axis); + auto softmax_fwd_pd = softmax_forward::primitive_desc(softmax_fwd_desc, + cpu_engine); + + // add: output + Tensor* output_tensor = nullptr; + MklDnnShape output_mkl_shape; + TensorShape output_tf_shape; // shape of output TF tensor. + // Softmax MklDnn output layout is same as input layout. + auto dst_pd = src.GetUsrMemPrimDesc(); + + // if input is MKL shape, ouput is also MKL shape. + // if input is TF shape, output is also TF shape + if (src_mkl_shape.IsMklTensor()) { + output_mkl_shape.SetMklTensor(true); + output_mkl_shape.SetMklLayout(&dst_pd); + output_mkl_shape.SetElemType(MklDnnType<T>()); + output_mkl_shape.SetTfLayout(output_dims.size(), output_dims, + memory::format::nc); + output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T))); + } else { // then output is also TF shape + output_mkl_shape.SetMklTensor(false); + output_tf_shape = MklDnnDimsToTFShape(output_dims); + } + // Allocate output shape (MKL or TF based on the above) + AllocateOutputSetMklShape(context, 0, &output_tensor, output_tf_shape, + output_mkl_shape); + + // Output_dims and input_dims are same + dst.SetUsrMem(src_md, output_tensor); + + // finally creating the "softmax op" using the primitive descriptor, src + // and dst + auto softmax_fwd = + softmax_forward(softmax_fwd_pd, src.GetOpMem(), dst.GetOpMem()); + + // execute net (pushing to the stream) + // following 3 are common for all mkl dnn ops + std::vector<primitive> net; + net.push_back(softmax_fwd); + stream(stream::kind::eager).submit(net).wait(); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); + } + } +}; + +/* Register DNN kernels for supported operations and supported types - right now + * it is only Softmax and f32 */ +#define REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES(type) \ + REGISTER_KERNEL_BUILDER(Name("_MklSoftmax") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklSoftmaxOp<CPUDevice, type>); +TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES); + + +} // namespace tensorflow + +#endif // INTEL_MKL_DNN +#endif // INTEL_MKL |