aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_softmax_op.cc
diff options
context:
space:
mode:
authorGravatar mbhuiyan <mohammad.ashraf.bhuiyan@intel.com>2018-08-30 11:20:50 -0700
committerGravatar mbhuiyan <mohammad.ashraf.bhuiyan@intel.com>2018-08-30 11:20:50 -0700
commit6d2ea449ac4661e23b0ff61516d27a1b728b54b6 (patch)
tree2a707b254185041606b5dd644255962fcfd5ee02 /tensorflow/core/kernels/mkl_softmax_op.cc
parentf54671240a5429a7146cc1a01efc73db12d0aab5 (diff)
fixed clang formatting
Diffstat (limited to 'tensorflow/core/kernels/mkl_softmax_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_softmax_op.cc33
1 files changed, 15 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index fa9855103a..1a01ee37be 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -17,13 +17,13 @@ limitations under the License.
#ifdef INTEL_MKL
#ifndef INTEL_MKL_ML_ONLY
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/tensor_format.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/util/mkl_util.h"
@@ -50,40 +50,39 @@ class MklSoftmaxOp : public OpKernel {
// src_tensor now points to the 0-th input of global data struct "context"
size_t src_idx = 0;
const Tensor& src_tensor = MklGetInput(context, src_idx);
-
- // Get number of dimensions of src
const int input_dims = src_tensor.dims();
// Add: get MklShape
MklDnnShape src_mkl_shape;
GetMklShape(context, src_idx, &src_mkl_shape);
- // Set layout type based on input_dims
+ // src_dims is the dimenstion of src_tensor
+ // dim of the dst will also be same as src_dims
+ auto src_tf_shape = src_mkl_shape.IsMklTensor()
+ ? src_mkl_shape.GetTfShape()
+ : src_tensor.shape();
+ auto src_dims = TFShapeToMklDnnDims(src_tf_shape);
+ auto output_dims = src_dims;
memory::format layout_type;
switch (input_dims) {
case 1:
layout_type = memory::format::x;
+ break;
case 2:
layout_type = memory::format::nc;
+ break;
case 3:
layout_type = memory::format::tnc;
+ break;
case 4:
layout_type = memory::format::nchw;
+ break;
case 5:
layout_type = memory::format::ncdhw;
+ break;
default:
- ctx->SetStatus(
- errors::Unimplemented(" MKL softmax does not support 1 > input_dims > 5 ");
+ OP_REQUIRES_OK(context, errors::Aborted("Input dims must be <= 5:"));
}
-
- // src_dims is the dimenstion of src_tensor
- // dim of the dst will also be same as src_dims
- auto src_tf_shape = src_mkl_shape.IsMklTensor()
- ? src_mkl_shape.GetTfShape()
- : src_tensor.shape();
- auto src_dims = TFShapeToMklDnnDims(src_tf_shape);
- auto output_dims = src_dims;
-
// Create softmax memory for src, dst: both are defined in mkl_util.h,
// they are wrapper
MklDnnData<T> src(&cpu_engine);
@@ -108,9 +107,7 @@ class MklSoftmaxOp : public OpKernel {
src.SetOpMemDesc(src_dims, layout_type);
// creating a memory descriptor
-
- // axis to which softmax will be applied
- // Always softmax is applied to the last axis
+ // passing outermost dim as default axis
int axis = input_dims - 1;
auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring,
src.GetOpMemDesc(), axis);