Diffstat (limited to 'tensorflow/core/kernels/mkl_fused_batch_norm_op.cc')
-rw-r--r--  tensorflow/core/kernels/mkl_fused_batch_norm_op.cc  360
1 file changed, 245 insertions(+), 115 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index a761562a4b..8340a91d05 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -703,27 +703,31 @@ class MklFusedBatchNormOp : public OpKernel {
void Compute(OpKernelContext* context) override {
try {
auto cpu_engine = engine(engine::cpu, 0);
- const size_t src_index = 0; // index of src input tensor
- const size_t scale_index = 1; // index of scale tensor
- const size_t shift_index = 2; // index of shift tensor
- const size_t mean_index = 3; // index of est_mean tensor
- const size_t var_index = 4; // index of est_variance tensor
-
- const Tensor& src_tensor = MklGetInput(context, src_index);
- const Tensor& scale_tensor = MklGetInput(context, scale_index);
- const Tensor& shift_tensor = MklGetInput(context, shift_index);
- const Tensor& est_mean_tensor = MklGetInput(context, mean_index);
- const Tensor& est_variance_tensor = MklGetInput(context, var_index);
-
+ const size_t kSrcIndex = 0; // index of src input tensor
+ const size_t kScaleIndex = 1; // index of scale tensor
+ const size_t kShiftIndex = 2; // index of shift tensor
+ const size_t kMeanIndex = 3; // index of est_mean tensor
+ const size_t kVarianceIndex = 4; // index of est_variance tensor
+
+ const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
+ const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
+ const Tensor& shift_tensor = MklGetInput(context, kShiftIndex);
+ const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex);
+ const Tensor& est_variance_tensor = MklGetInput(context,
+ kVarianceIndex);
+
+ TensorShape tf_shape_src;
MklDnnShape dnn_shape_src;
- GetMklShape(context, src_index, &dnn_shape_src);
+ GetMklShape(context, kSrcIndex, &dnn_shape_src);
if (dnn_shape_src.IsMklTensor()) {
+ tf_shape_src = dnn_shape_src.GetTfShape();
OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
errors::InvalidArgument(
"input must be 4-dimensional",
src_tensor.shape().DebugString()));
} else {
+ tf_shape_src = src_tensor.shape();
OP_REQUIRES(context, src_tensor.dims() == 4,
errors::InvalidArgument(
"input must be 4-dimensional",
@@ -756,39 +760,35 @@ class MklFusedBatchNormOp : public OpKernel {
est_variance_tensor.shape().DebugString()));
}
+ // special case: input with 0 elements (e.g. 0 batch size)
+ Tensor* dst_tensor = nullptr;
+ if (tf_shape_src.num_elements() == 0) {
+ HandleEmptyInput(context,
+ tf_shape_src,
+ scale_tensor.shape(),
+ &dst_tensor);
+ return;
+ }
+
if (dnn_shape_src.IsMklTensor())
depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
else
ExtractParams(context);
// Indices of output tensors
- const size_t dst_index = 0;
- const size_t batch_mean_index = 1;
- const size_t batch_variance_index = 2;
- const size_t saved_mean_index = 3;
- const size_t saved_variance_index = 4;
+ const size_t kDstIndex = 0;
- // allocate batch mean output tensor
+ // allocate 4 output TF tensors
Tensor* batch_mean_tensor = nullptr;
- MklDnnShape mkl_shape_batch_mean;
- mkl_shape_batch_mean.SetMklTensor(false);
- AllocateOutputSetMklShape(context,
- batch_mean_index,
- &batch_mean_tensor,
- scale_tensor.shape(),
- mkl_shape_batch_mean);
- CHECK_NOTNULL(batch_mean_tensor);
-
- // Batch variance
Tensor* batch_variance_tensor = nullptr;
- MklDnnShape mkl_shape_batch_variance;
- mkl_shape_batch_variance.SetMklTensor(false);
- AllocateOutputSetMklShape(context,
- batch_variance_index,
- &batch_variance_tensor,
- scale_tensor.shape(),
- mkl_shape_batch_variance);
- CHECK_NOTNULL(batch_variance_tensor);
+ Tensor* saved_mean_tensor = nullptr;
+ Tensor* saved_variance_tensor = nullptr;
+ AllocateTFOutputs(context,
+ scale_tensor.shape(),
+ &batch_mean_tensor,
+ &batch_variance_tensor,
+ &saved_mean_tensor,
+ &saved_variance_tensor);
if (is_training_)
SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
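Note: the new HandleEmptyInput / AllocateTFOutputs path short-circuits the MKL-DNN primitive when the input has zero elements (for example a zero batch size). A standalone model of the contract visible in this patch: the data output is zero-filled and every per-channel statistic is NaN. The function below is illustrative, not the kernel's API:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Standalone model of the zero-element fast path: no MKL-DNN primitive is
    // built, the data output is zero-filled (it has zero elements in practice)
    // and all per-channel statistics are NaN.
    struct EmptyBatchNormOutputs {
      std::vector<float> y;
      std::vector<float> batch_mean, batch_variance;
      std::vector<float> saved_mean, saved_variance;
    };

    EmptyBatchNormOutputs HandleEmptyInputSketch(std::size_t num_src_elements,
                                                 std::size_t depth) {
      EmptyBatchNormOutputs out;
      out.y.assign(num_src_elements, 0.0f);
      const float nan = std::nanf("");
      out.batch_mean.assign(depth, nan);
      out.batch_variance.assign(depth, nan);
      out.saved_mean.assign(depth, nan);
      out.saved_variance.assign(depth, nan);
      return out;
    }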
@@ -844,26 +844,6 @@ class MklFusedBatchNormOp : public OpKernel {
weights_data[k + depth_] = shift_tf[k];
}
- // Mean and variance (without Bessel's correction) saved for backward
- // computation to serve as pre-computed mean and variance.
- Tensor* saved_mean_tensor = nullptr;
- MklDnnShape mkl_shape_saved_mean;
- mkl_shape_saved_mean.SetMklTensor(false);
- AllocateOutputSetMklShape(context, saved_mean_index,
- &saved_mean_tensor,
- scale_tensor.shape(),
- mkl_shape_saved_mean);
- CHECK_NOTNULL(saved_mean_tensor);
-
- Tensor* saved_variance_tensor = nullptr;
- MklDnnShape mkl_shape_saved_variance;
- mkl_shape_saved_variance.SetMklTensor(false);
- AllocateOutputSetMklShape(context, saved_variance_index,
- &saved_variance_tensor,
- scale_tensor.shape(),
- mkl_shape_saved_variance);
- CHECK_NOTNULL(saved_variance_tensor);
-
// set mean primitive
auto mean_desc = memory::desc({1, depth_},
MklDnnType<T>(),
@@ -902,7 +882,6 @@ class MklFusedBatchNormOp : public OpKernel {
// allocate dst tensor
MklDnnShape dnn_shape_dst;
TensorShape tf_shape_dst;
- Tensor* dst_tensor = nullptr;
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_dst.SetMklTensor(true);
auto dst_pd = bnrm_fwd_pd.dst_primitive_desc();
@@ -915,7 +894,7 @@ class MklFusedBatchNormOp : public OpKernel {
dnn_shape_dst.SetMklTensor(false);
tf_shape_dst = src_tensor.shape();
}
- AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
+ AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor,
tf_shape_dst, dnn_shape_dst);
// Output of batchnorm has same shape as input.
@@ -958,10 +937,8 @@ class MklFusedBatchNormOp : public OpKernel {
size_t adjust_size = orig_size - 1;
adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
}
- T* batch_variance_data_tf = reinterpret_cast<T*>(
- batch_variance_tensor->flat<T>().data());
for (int k=0; k < depth_; k++)
- batch_variance_data_tf[k] =
+ batch_variance_tensor->flat<T>().data()[k] =
(reinterpret_cast<T*>(variance_m.get_data_handle()))[k]
* adjust_factor;
} catch (mkldnn::error &e) {
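Note: adjust_factor above converts the population variance produced by the training primitive into the Bessel-corrected estimate that TensorFlow reports as batch_variance: with orig_size elements per channel, the factor is orig_size / (orig_size - 1). A small self-contained check of that arithmetic (the numbers in main are illustrative):

    #include <cassert>
    #include <cstddef>

    // orig_size values per channel: population_variance * orig_size / (orig_size - 1)
    // is the sample (Bessel-corrected) variance reported as batch_variance.
    float AdjustVariance(float population_variance, std::size_t orig_size) {
      const std::size_t adjust_size = orig_size - 1;
      const float adjust_factor = static_cast<float>(orig_size) / adjust_size;
      return population_variance * adjust_factor;
    }

    int main() {
      // Data {1, 2, 3, 4}: mean 2.5, population variance 1.25.
      // Bessel-corrected: 1.25 * 4 / 3 = 5/3 ~= 1.6667.
      const float v = AdjustVariance(1.25f, 4);
      assert(v > 1.66f && v < 1.67f);
      return 0;
    }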
@@ -994,8 +971,100 @@ class MklFusedBatchNormOp : public OpKernel {
variance_values_ = reinterpret_cast<T*>(
const_cast<T*>(variance.flat<T>().data()));
}
-};
+ void HandleEmptyInput(OpKernelContext* context,
+ TensorShape tf_shape_src,
+ TensorShape tf_shape_scale,
+ Tensor** dst_tensor) {
+ CHECK_NOTNULL(dst_tensor);
+
+ const size_t kDstIndex = 0;
+ MklDnnShape dnn_shape_dst;
+ dnn_shape_dst.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, kDstIndex, dst_tensor,
+ tf_shape_src, dnn_shape_dst);
+ CHECK_NOTNULL(*dst_tensor);
+ memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0,
+ (*dst_tensor)->tensor_data().size());
+
+ Tensor* batch_mean_tensor = nullptr;
+ Tensor* batch_variance_tensor = nullptr;
+ Tensor* saved_mean_tensor = nullptr;
+ Tensor* saved_variance_tensor = nullptr;
+ AllocateTFOutputs(context, tf_shape_scale,
+ &batch_mean_tensor,
+ &batch_variance_tensor,
+ &saved_mean_tensor,
+ &saved_variance_tensor);
+ }
+
+ void AllocateTFOutputs(OpKernelContext* context,
+ TensorShape tf_shape_scale,
+ Tensor** batch_mean_tensor,
+ Tensor** batch_variance_tensor,
+ Tensor** saved_mean_tensor,
+ Tensor** saved_variance_tensor) {
+ CHECK_NOTNULL(batch_mean_tensor);
+ CHECK_NOTNULL(batch_variance_tensor);
+ CHECK_NOTNULL(saved_mean_tensor);
+ CHECK_NOTNULL(saved_variance_tensor);
+
+ const size_t kBatchMeanIndex = 1;
+ const size_t kBatchVarianceIndex = 2;
+ const size_t kSavedMeanIndex = 3;
+ const size_t kSavedVarianceIndex = 4;
+
+ // allocate batch mean output tensor
+ MklDnnShape mkl_shape_batch_mean;
+ mkl_shape_batch_mean.SetMklTensor(false);
+ AllocateOutputSetMklShape(context,
+ kBatchMeanIndex,
+ batch_mean_tensor,
+ tf_shape_scale,
+ mkl_shape_batch_mean);
+ CHECK_NOTNULL(*batch_mean_tensor);
+ // set NAN mean value in case of empty input tensor
+ for (int k=0; k < tf_shape_scale.num_elements(); k++)
+ (*batch_mean_tensor)->flat<T>().data()[k] = NAN;
+
+ // allocate batch variance output tensor
+ MklDnnShape mkl_shape_batch_variance;
+ mkl_shape_batch_variance.SetMklTensor(false);
+ AllocateOutputSetMklShape(context,
+ kBatchVarianceIndex,
+ batch_variance_tensor,
+ tf_shape_scale,
+ mkl_shape_batch_variance);
+ CHECK_NOTNULL(*batch_variance_tensor);
+ // set NAN variance value in case of empty input tensor
+ for (int k=0; k < tf_shape_scale.num_elements(); k++)
+ (*batch_variance_tensor)->flat<T>().data()[k] = NAN;
+
+ // Mean and variance (without Bessel's correction) saved for backward
+ // computation to serve as pre-computed mean and variance.
+ MklDnnShape mkl_shape_saved_mean;
+ mkl_shape_saved_mean.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, kSavedMeanIndex,
+ saved_mean_tensor,
+ tf_shape_scale,
+ mkl_shape_saved_mean);
+ CHECK_NOTNULL(*saved_mean_tensor);
+ // set NAN mean value in case of empty input tensor
+ for (int k=0; k < tf_shape_scale.num_elements(); k++)
+ (*saved_mean_tensor)->flat<T>().data()[k] = NAN;
+
+ MklDnnShape mkl_shape_saved_variance;
+ mkl_shape_saved_variance.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, kSavedVarianceIndex,
+ saved_variance_tensor,
+ tf_shape_scale,
+ mkl_shape_saved_variance);
+ CHECK_NOTNULL(*saved_variance_tensor);
+ // set NAN variance value in case of empty input tensor
+ for (int k=0; k < tf_shape_scale.num_elements(); k++)
+ (*saved_variance_tensor)->flat<T>().data()[k] = NAN;
+ }
+};
template <typename Device, typename T>
class MklFusedBatchNormGradOp : public OpKernel {
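Note: after the refactor, the forward kernel allocates its outputs in two code paths (normal and empty input) but always in the same slots, and the gradient kernel below follows an analogous convention. For reference, the index assignments used throughout this patch, written as illustrative enums (not part of the change itself):

    #include <cstddef>

    // Output slots of the forward kernel (y plus four statistics) and of the
    // gradient kernel (three gradients plus two placeholder outputs).
    enum FwdOutput : std::size_t {
      kDst = 0,            // y, same shape as the input
      kBatchMean = 1,      // per-channel batch mean
      kBatchVariance = 2,  // per-channel batch variance (Bessel-corrected)
      kSavedMean = 3,      // cached mean for the backward pass
      kSavedVariance = 4   // cached variance (no Bessel correction)
    };
    enum GradOutput : std::size_t {
      kDiffSrc = 0,    // dL/dx
      kDiffScale = 1,  // dL/dscale
      kDiffShift = 2,  // dL/dshift
      kP1 = 3,         // placeholder (estimated_mean slot, unused in the grad)
      kP2 = 4          // placeholder (estimated_variance slot, unused)
    };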
@@ -1009,34 +1078,37 @@ class MklFusedBatchNormGradOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
}
void Compute(OpKernelContext* context) override {
try {
auto cpu_engine = engine(engine::cpu, 0);
-
- const size_t diff_dst_index = 0; // index of diff_dst tensor
- const size_t src_index = 1; // index of src input tensor
- const size_t scale_index = 2; // index of scale tensor
- const size_t mean_index = 3; // index of saved_mean tensor
- const size_t variance_index = 4; // index of saved_variance tensor
- const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
- const Tensor& src_tensor = MklGetInput(context, src_index);
- const Tensor& scale_tensor = MklGetInput(context, scale_index);
- const Tensor& saved_mean_tensor = MklGetInput(context, mean_index);
+ const size_t kDiffDstIndex = 0; // index of diff_dst tensor
+ const size_t kSrcIndex = 1; // index of src input tensor
+ const size_t kScaleIndex = 2; // index of scale tensor
+ const size_t kMeanIndex = 3; // index of saved_mean tensor
+ const size_t kVarianceIndex = 4; // index of saved_variance tensor
+ const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
+ const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
+ const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
+ const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex);
const Tensor& saved_variance_tensor = MklGetInput(context,
- variance_index);
+ kVarianceIndex);
MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
- GetMklShape(context, src_index, &dnn_shape_src);
- GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
+ GetMklShape(context, kSrcIndex, &dnn_shape_src);
+ GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst);
+ TensorShape tf_shape_src, tf_shape_diff_dst;
if (dnn_shape_diff_dst.IsMklTensor()) {
+ tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
OP_REQUIRES(context, dnn_shape_diff_dst.GetDimension() == 4,
errors::InvalidArgument(
"input must be 4-dimensional",
diff_dst_tensor.shape().DebugString()));
} else {
+ tf_shape_diff_dst = diff_dst_tensor.shape();
OP_REQUIRES(context, diff_dst_tensor.dims() == 4,
errors::InvalidArgument(
"input must be 4-dimensional",
@@ -1044,11 +1116,13 @@ class MklFusedBatchNormGradOp : public OpKernel {
}
if (dnn_shape_src.IsMklTensor()) {
+ tf_shape_src = dnn_shape_src.GetTfShape();
OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
errors::InvalidArgument(
"input must be 4-dimensional",
src_tensor.shape().DebugString()));
} else {
+ tf_shape_src = src_tensor.shape();
OP_REQUIRES(context, src_tensor.dims() == 4,
errors::InvalidArgument(
"input must be 4-dimensional",
@@ -1069,6 +1143,15 @@ class MklFusedBatchNormGradOp : public OpKernel {
"saved variance must be 1-dimensional",
saved_variance_tensor.shape().DebugString()));
+ Tensor* diff_src_tensor = nullptr;
+ if (tf_shape_src.num_elements() == 0 ||
+ tf_shape_diff_dst.num_elements() == 0) {
+ HandleEmptyInput(context, tf_shape_src,
+ scale_tensor.shape(),
+ &diff_src_tensor);
+ return;
+ }
+
if (dnn_shape_src.IsMklTensor())
depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
else
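Note: the hunk above gives the gradient kernel the same zero-element fast path as the forward kernel, except that every gradient output is zero rather than NaN: an empty batch contributes nothing to dL/dx, dL/dscale or dL/dshift. A standalone model of that contract (illustrative names, not the real kernel code):

    #include <cstddef>
    #include <vector>

    // Standalone model of the gradient op's zero-element fast path: all
    // gradient outputs are zero-filled.
    struct EmptyBatchNormGradOutputs {
      std::vector<float> diff_src;    // same (empty) shape as the input
      std::vector<float> diff_scale;  // one zero per channel
      std::vector<float> diff_shift;  // one zero per channel
    };

    EmptyBatchNormGradOutputs HandleEmptyGradSketch(std::size_t num_src_elements,
                                                    std::size_t depth) {
      return {std::vector<float>(num_src_elements, 0.0f),
              std::vector<float>(depth, 0.0f),
              std::vector<float>(depth, 0.0f)};
    }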
@@ -1165,25 +1248,21 @@ class MklFusedBatchNormGradOp : public OpKernel {
auto diff_weights_m = memory(diff_weights_pd);
auto bnrm_fwd_desc = batch_normalization_forward::desc(
- prop_kind::forward_training,
- src.GetUsrMemDesc(),
- epsilon_,
- use_scale_shift);
+ prop_kind::forward_training,
+ src.GetUsrMemDesc(),
+ epsilon_,
+ is_training_ ? use_scale_shift :
+ (use_scale_shift | use_global_stats));
auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(
bnrm_fwd_desc,
cpu_engine);
// Indices of output tensors
- const size_t diff_src_index = 0; // index of diff_src tensor
- const size_t diff_scale_index = 1; // index of diff_scale tensor
- const size_t diff_shift_index = 2; // index of diff_shift tensor
- const size_t p1_index = 3; // index of 1st placeholder tensor
- const size_t p2_index = 4; // index of 2nd placeholder tensor
+ const size_t kDiffSrcIndex = 0; // index of diff_src tensor
// allocate diff_src tensor
MklDnnShape dnn_shape_diff_src;
TensorShape tf_shape_diff_src;
- Tensor* diff_src_tensor = nullptr;
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_diff_src.SetMklTensor(true);
auto diff_src_pd = bnrm_fwd_pd.dst_primitive_desc();
@@ -1201,7 +1280,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
dnn_shape_diff_src.SetMklTensor(false);
tf_shape_diff_src = src_tensor.shape();
}
- AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
+ AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
tf_shape_diff_src, dnn_shape_diff_src);
diff_src.SetUsrMem(src_md, diff_src_tensor);
@@ -1212,7 +1291,15 @@ class MklFusedBatchNormGradOp : public OpKernel {
diff_src.GetUsrMemDesc(),
src.GetUsrMemDesc(),
epsilon_,
- use_scale_shift);
+ /* for inference, specify use_global_stats
+ 1. on fwd prop, use mean and variance
+ provided as inputs
+ 2. on bwd prop, mean and variance are
+ considered as constants. This
+ reduces the amount of MKL computation
+ */
+ is_training_ ? use_scale_shift :
+ (use_scale_shift | use_global_stats));
auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc(
bnrm_bwd_desc,
cpu_engine,
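Note: both normalization descriptors in the gradient op (the forward descriptor above and this backward descriptor) now OR in use_global_stats when is_training_ is false, so the saved mean/variance inputs are treated as constants and MKL-DNN skips recomputing batch statistics. The flag choice reduces to the helper below; the numeric values are placeholders and the real constants come from mkldnn.hpp:

    // Illustrative flag values; the real constants (use_scale_shift,
    // use_global_stats) come from mkldnn.hpp.
    constexpr unsigned kUseScaleShift = 0x1;
    constexpr unsigned kUseGlobalStats = 0x2;

    // Training: statistics are recomputed from the batch, so only the
    // scale/shift flag is set.  Inference: the saved mean/variance are
    // constants, so use_global_stats is OR-ed in and MKL-DNN can skip the
    // statistics computation.
    unsigned BatchNormDescFlags(bool is_training) {
      return is_training ? kUseScaleShift : (kUseScaleShift | kUseGlobalStats);
    }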
@@ -1232,41 +1319,22 @@ class MklFusedBatchNormGradOp : public OpKernel {
net.push_back(bnrm_bwd_op);
stream(stream::kind::eager).submit(net).wait();
- // separate out scale and shift grad and copy to individual tensors
- const TensorShape& tf_shape_scale_shift = scale_tensor.shape();
+ // allocate 4 output TF tensors
Tensor* diff_scale_tensor = nullptr;
- MklDnnShape mkl_shape_diff_scale;
- mkl_shape_diff_scale.SetMklTensor(false);
- AllocateOutputSetMklShape(context, diff_scale_index, &diff_scale_tensor,
- tf_shape_scale_shift, mkl_shape_diff_scale);
-
Tensor* diff_shift_tensor = nullptr;
- MklDnnShape mkl_shape_diff_shift;
- mkl_shape_diff_shift.SetMklTensor(false);
- AllocateOutputSetMklShape(context, diff_shift_index, &diff_shift_tensor,
- tf_shape_scale_shift, mkl_shape_diff_shift);
+ AllocateTFOutputs(context, scale_tensor.shape(),
+ &diff_scale_tensor,
+ &diff_shift_tensor);
// copy data: diff_scale and diff_shift
T* diff_weights_data_dnn = reinterpret_cast<T*>
(diff_weights_m.get_data_handle());
- float* diff_scale_data_tf = const_cast<float*>(
- static_cast<const float*>(diff_scale_tensor->flat<T>().data()));
- float* diff_shift_data_tf = const_cast<float*>(
- static_cast<const float*>(diff_shift_tensor->flat<T>().data()));
for (int i = 0; i < depth_; i++) {
- diff_scale_data_tf[i] = diff_weights_data_dnn[i];
- diff_shift_data_tf[i] = diff_weights_data_dnn[i + depth_];
+ diff_scale_tensor->flat<T>().data()[i] =
+ diff_weights_data_dnn[i];
+ diff_shift_tensor->flat<T>().data()[i] =
+ diff_weights_data_dnn[i + depth_];
}
-
- // Placeholders for estimated_mean and estimated_variance, which are
- // used for inference and thus not needed here for gradient computation.
- Tensor* p1_tensor = nullptr, *p2_tensor = nullptr;
- MklDnnShape mkl_shape_p;
- mkl_shape_p.SetMklTensor(false);
- AllocateOutputSetMklShape(context, p1_index, &p1_tensor,
- TensorShape({}), mkl_shape_p);
- AllocateOutputSetMklShape(context, p2_index, &p2_tensor,
- TensorShape({}), mkl_shape_p);
} catch (mkldnn::error &e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) +
@@ -1282,12 +1350,74 @@ class MklFusedBatchNormGradOp : public OpKernel {
T epsilon_;
TensorFormat tensor_format_;
int depth_;  // batch normalization is done per channel.
+ bool is_training_;
void ExtractParams(OpKernelContext* context) {
const Tensor& input = MklGetInput(context, 0);
depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
}
+ void HandleEmptyInput(OpKernelContext* context,
+ TensorShape tf_shape_src,
+ TensorShape tf_shape_scale_shift,
+ Tensor** diff_src_tensor) {
+ const size_t kDiffSrcIndex = 0;
+
+ MklDnnShape dnn_shape_diff_src;
+ dnn_shape_diff_src.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
+ tf_shape_src, dnn_shape_diff_src);
+ for (size_t i=0; i < (*diff_src_tensor)->shape().num_elements(); i++)
+ (*diff_src_tensor)->flat<T>().data()[i] = 0;
+
+ Tensor* diff_scale_tensor = nullptr;
+ Tensor* diff_shift_tensor = nullptr;
+ AllocateTFOutputs(context,
+ tf_shape_scale_shift,
+ &diff_scale_tensor,
+ &diff_shift_tensor);
+ }
+
+ void AllocateTFOutputs(OpKernelContext* context,
+ TensorShape tf_shape_scale_shift,
+ Tensor** diff_scale_tensor,
+ Tensor** diff_shift_tensor) {
+ CHECK_NOTNULL(diff_scale_tensor);
+ CHECK_NOTNULL(diff_shift_tensor);
+
+ const size_t kDiffScaleIndex = 1;
+ const size_t kDiffShiftIndex = 2;
+ const size_t kP1Index = 3;
+ const size_t kP2Index = 4;
+
+ // separate out scale and shift grad and copy to individual tensors
+ MklDnnShape mkl_shape_diff_scale;
+ mkl_shape_diff_scale.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
+ tf_shape_scale_shift, mkl_shape_diff_scale);
+ CHECK_NOTNULL(*diff_scale_tensor);
+ for (size_t i=0; i < (*diff_scale_tensor)->shape().num_elements(); i++)
+ (*diff_scale_tensor)->flat<T>().data()[i] = 0;
+
+ MklDnnShape mkl_shape_diff_shift;
+ mkl_shape_diff_shift.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
+ tf_shape_scale_shift, mkl_shape_diff_shift);
+ CHECK_NOTNULL(*diff_shift_tensor);
+ for (size_t i=0; i < (*diff_shift_tensor)->shape().num_elements(); i++)
+ (*diff_shift_tensor)->flat<T>().data()[i] = 0;
+
+ // Placeholders for estimated_mean and estimated_variance, which are
+ // used for inference and thus not needed here for gradient computation.
+ Tensor* p1_tensor = nullptr, *p2_tensor = nullptr;
+ MklDnnShape mkl_shape_p;
+ mkl_shape_p.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, kP1Index, &p1_tensor,
+ TensorShape({}), mkl_shape_p);
+ AllocateOutputSetMklShape(context, kP2Index, &p2_tensor,
+ TensorShape({}), mkl_shape_p);
+ }
+
memory::dims GetMeanVarianceDims() {
return memory::dims({1, depth_});
}
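Note: GetMeanVarianceDims() describes the per-channel statistics as a 1 x depth_ memory in "nc" format (matching the mean/variance memory::desc in the forward op), so a flat index k addresses channel k directly. A trivial standalone stand-in for that helper:

    #include <array>
    #include <cstdint>

    // Stand-in for memory::dims({1, depth_}): a 1 x depth "nc" layout, one
    // statistic per channel, so flattened index k addresses channel k.
    std::array<std::int64_t, 2> MeanVarianceDims(int depth) {
      std::array<std::int64_t, 2> dims = {1, depth};
      return dims;
    }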