diff options
author | 2018-08-30 11:20:50 -0700 | |
---|---|---|
committer | 2018-08-30 11:20:50 -0700 | |
commit | 6d2ea449ac4661e23b0ff61516d27a1b728b54b6 (patch) | |
tree | 2a707b254185041606b5dd644255962fcfd5ee02 /tensorflow/core/kernels/mkl_softmax_op.cc | |
parent | f54671240a5429a7146cc1a01efc73db12d0aab5 (diff) |
fixed clang formatting
Diffstat (limited to 'tensorflow/core/kernels/mkl_softmax_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_softmax_op.cc | 33 |
1 files changed, 15 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc index fa9855103a..1a01ee37be 100644 --- a/tensorflow/core/kernels/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl_softmax_op.cc @@ -17,13 +17,13 @@ limitations under the License. #ifdef INTEL_MKL #ifndef INTEL_MKL_ML_ONLY +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/tensor_format.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/util/mkl_util.h" @@ -50,40 +50,39 @@ class MklSoftmaxOp : public OpKernel { // src_tensor now points to the 0-th input of global data struct "context" size_t src_idx = 0; const Tensor& src_tensor = MklGetInput(context, src_idx); - - // Get number of dimensions of src const int input_dims = src_tensor.dims(); // Add: get MklShape MklDnnShape src_mkl_shape; GetMklShape(context, src_idx, &src_mkl_shape); - // Set layout type based on input_dims + // src_dims is the dimenstion of src_tensor + // dim of the dst will also be same as src_dims + auto src_tf_shape = src_mkl_shape.IsMklTensor() + ? src_mkl_shape.GetTfShape() + : src_tensor.shape(); + auto src_dims = TFShapeToMklDnnDims(src_tf_shape); + auto output_dims = src_dims; memory::format layout_type; switch (input_dims) { case 1: layout_type = memory::format::x; + break; case 2: layout_type = memory::format::nc; + break; case 3: layout_type = memory::format::tnc; + break; case 4: layout_type = memory::format::nchw; + break; case 5: layout_type = memory::format::ncdhw; + break; default: - ctx->SetStatus( - errors::Unimplemented(" MKL softmax does not support 1 > input_dims > 5 "); + OP_REQUIRES_OK(context, errors::Aborted("Input dims must be <= 5:")); } - - // src_dims is the dimenstion of src_tensor - // dim of the dst will also be same as src_dims - auto src_tf_shape = src_mkl_shape.IsMklTensor() - ? src_mkl_shape.GetTfShape() - : src_tensor.shape(); - auto src_dims = TFShapeToMklDnnDims(src_tf_shape); - auto output_dims = src_dims; - // Create softmax memory for src, dst: both are defined in mkl_util.h, // they are wrapper MklDnnData<T> src(&cpu_engine); @@ -108,9 +107,7 @@ class MklSoftmaxOp : public OpKernel { src.SetOpMemDesc(src_dims, layout_type); // creating a memory descriptor - - // axis to which softmax will be applied - // Always softmax is applied to the last axis + // passing outermost dim as default axis int axis = input_dims - 1; auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring, src.GetOpMemDesc(), axis); |