Diffstat (limited to 'tensorflow/core/kernels/mkl_fused_batch_norm_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_fused_batch_norm_op.cc | 360
1 file changed, 245 insertions(+), 115 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index a761562a4b..8340a91d05 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -703,27 +703,31 @@ class MklFusedBatchNormOp : public OpKernel {
   void Compute(OpKernelContext* context) override {
     try {
       auto cpu_engine = engine(engine::cpu, 0);
-      const size_t src_index = 0;    // index of src input tensor
-      const size_t scale_index = 1;  // index of scale tensor
-      const size_t shift_index = 2;  // index of shift tensor
-      const size_t mean_index = 3;   // index of est_mean tensor
-      const size_t var_index = 4;    // index of est_variance tensor
-
-      const Tensor& src_tensor = MklGetInput(context, src_index);
-      const Tensor& scale_tensor = MklGetInput(context, scale_index);
-      const Tensor& shift_tensor = MklGetInput(context, shift_index);
-      const Tensor& est_mean_tensor = MklGetInput(context, mean_index);
-      const Tensor& est_variance_tensor = MklGetInput(context, var_index);
-
+      const size_t kSrcIndex = 0;       // index of src input tensor
+      const size_t kScaleIndex = 1;     // index of scale tensor
+      const size_t kShiftIndex = 2;     // index of shift tensor
+      const size_t kMeanIndex = 3;      // index of est_mean tensor
+      const size_t kVarianceIndex = 4;  // index of est_variance tensor
+
+      const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
+      const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
+      const Tensor& shift_tensor = MklGetInput(context, kShiftIndex);
+      const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex);
+      const Tensor& est_variance_tensor = MklGetInput(context,
+                                                      kVarianceIndex);
+
+      TensorShape tf_shape_src;
       MklDnnShape dnn_shape_src;
-      GetMklShape(context, src_index, &dnn_shape_src);
+      GetMklShape(context, kSrcIndex, &dnn_shape_src);
 
       if (dnn_shape_src.IsMklTensor()) {
+        tf_shape_src = dnn_shape_src.GetTfShape();
         OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
                     errors::InvalidArgument(
                         "input must be 4-dimensional",
                         src_tensor.shape().DebugString()));
       } else {
+        tf_shape_src = src_tensor.shape();
         OP_REQUIRES(context, src_tensor.dims() == 4,
                     errors::InvalidArgument(
                         "input must be 4-dimensional",
@@ -756,39 +760,35 @@ class MklFusedBatchNormOp : public OpKernel {
                         est_variance_tensor.shape().DebugString()));
       }
 
+      // special case: input with 0 element and 0 batch size
+      Tensor* dst_tensor = nullptr;
+      if (tf_shape_src.num_elements() == 0) {
+        HandleEmptyInput(context,
+                         tf_shape_src,
+                         scale_tensor.shape(),
+                         &dst_tensor);
+        return;
+      }
+
       if (dnn_shape_src.IsMklTensor())
         depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
       else
         ExtractParams(context);
 
       // Indices of output tensors
-      const size_t dst_index = 0;
-      const size_t batch_mean_index = 1;
-      const size_t batch_variance_index = 2;
-      const size_t saved_mean_index = 3;
-      const size_t saved_variance_index = 4;
+      const size_t kDstIndex = 0;
 
-      // allocate batch mean output tensor
+      // allocate 4 output TF tensors
       Tensor* batch_mean_tensor = nullptr;
-      MklDnnShape mkl_shape_batch_mean;
-      mkl_shape_batch_mean.SetMklTensor(false);
-      AllocateOutputSetMklShape(context,
-                                batch_mean_index,
-                                &batch_mean_tensor,
-                                scale_tensor.shape(),
-                                mkl_shape_batch_mean);
-      CHECK_NOTNULL(batch_mean_tensor);
-
-      // Batch variance
       Tensor* batch_variance_tensor = nullptr;
-      MklDnnShape mkl_shape_batch_variance;
-      mkl_shape_batch_variance.SetMklTensor(false);
-      AllocateOutputSetMklShape(context,
-                                batch_variance_index,
-                                &batch_variance_tensor,
-                                scale_tensor.shape(),
-                                mkl_shape_batch_variance);
-      CHECK_NOTNULL(batch_variance_tensor);
+      Tensor* saved_mean_tensor = nullptr;
+      Tensor* saved_variance_tensor = nullptr;
+      AllocateTFOutputs(context,
+                        scale_tensor.shape(),
+                        &batch_mean_tensor,
+                        &batch_variance_tensor,
+                        &saved_mean_tensor,
+                        &saved_variance_tensor);
 
       if (is_training_)
         SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
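The second hunk above adds an early-out for empty inputs: when the incoming
tensor has zero elements, HandleEmptyInput allocates all the outputs and
returns before any MKL-DNN primitive is built. A minimal standalone sketch of
that guard pattern (plain C++ with hypothetical names, not the kernel code
itself):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Batch statistics of an empty batch are undefined, so the guard
    // reports NaN for them and skips the normalization entirely.
    bool HandleIfEmpty(std::size_t num_elements, std::size_t depth,
                       std::vector<float>* dst,
                       std::vector<float>* batch_mean,
                       std::vector<float>* batch_variance) {
      if (num_elements != 0) return false;  // run the real kernel instead
      dst->clear();                         // empty input -> empty output
      batch_mean->assign(depth, NAN);
      batch_variance->assign(depth, NAN);
      return true;                          // caller returns early
    }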
@@ -844,26 +844,6 @@ class MklFusedBatchNormOp : public OpKernel {
         weights_data[k + depth_] = shift_tf[k];
       }
 
-      // Mean and variance (without Bessel's correction) saved for backward
-      // computation to serve as pre-computed mean and variance.
-      Tensor* saved_mean_tensor = nullptr;
-      MklDnnShape mkl_shape_saved_mean;
-      mkl_shape_saved_mean.SetMklTensor(false);
-      AllocateOutputSetMklShape(context, saved_mean_index,
-                                &saved_mean_tensor,
-                                scale_tensor.shape(),
-                                mkl_shape_saved_mean);
-      CHECK_NOTNULL(saved_mean_tensor);
-
-      Tensor* saved_variance_tensor = nullptr;
-      MklDnnShape mkl_shape_saved_variance;
-      mkl_shape_saved_variance.SetMklTensor(false);
-      AllocateOutputSetMklShape(context, saved_variance_index,
-                                &saved_variance_tensor,
-                                scale_tensor.shape(),
-                                mkl_shape_saved_variance);
-      CHECK_NOTNULL(saved_variance_tensor);
-
       // set mean primitive
       auto mean_desc = memory::desc({1, depth_},
                                     MklDnnType<T>(),
@@ -902,7 +882,6 @@ class MklFusedBatchNormOp : public OpKernel {
       // allocate dst tensor
       MklDnnShape dnn_shape_dst;
       TensorShape tf_shape_dst;
-      Tensor* dst_tensor = nullptr;
       if (dnn_shape_src.IsMklTensor()) {
         dnn_shape_dst.SetMklTensor(true);
         auto dst_pd = bnrm_fwd_pd.dst_primitive_desc();
@@ -915,7 +894,7 @@ class MklFusedBatchNormOp : public OpKernel {
         dnn_shape_dst.SetMklTensor(false);
         tf_shape_dst = src_tensor.shape();
       }
-      AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
+      AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor,
                                 tf_shape_dst, dnn_shape_dst);
 
       // Output of batchnorm has same shape as input.
@@ -958,10 +937,8 @@ class MklFusedBatchNormOp : public OpKernel {
           size_t adjust_size = orig_size - 1;
           adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
         }
-        T* batch_variance_data_tf = reinterpret_cast<T*>(
-            batch_variance_tensor->flat<T>().data());
         for (int k=0; k < depth_; k++)
-          batch_variance_data_tf[k] =
+          batch_variance_tensor->flat<T>().data()[k] =
               (reinterpret_cast<T*>(variance_m.get_data_handle()))[k]
               * adjust_factor;
     } catch (mkldnn::error &e) {
@@ -994,8 +971,100 @@ class MklFusedBatchNormOp : public OpKernel {
     variance_values_ = reinterpret_cast<T*>(
         const_cast<T*>(variance.flat<T>().data()));
   }
-};
+
+  void HandleEmptyInput(OpKernelContext* context,
+                        TensorShape tf_shape_src,
+                        TensorShape tf_shape_scale,
+                        Tensor** dst_tensor) {
+    CHECK_NOTNULL(dst_tensor);
+
+    const size_t kDstIndex = 0;
+    MklDnnShape dnn_shape_dst;
+    dnn_shape_dst.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, kDstIndex, dst_tensor,
+                              tf_shape_src, dnn_shape_dst);
+    CHECK_NOTNULL(*dst_tensor);
+    memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0,
+           (*dst_tensor)->tensor_data().size());
+
+    Tensor* batch_mean_tensor = nullptr;
+    Tensor* batch_variance_tensor = nullptr;
+    Tensor* saved_mean_tensor = nullptr;
+    Tensor* saved_variance_tensor = nullptr;
+    AllocateTFOutputs(context, tf_shape_scale,
+                      &batch_mean_tensor,
+                      &batch_variance_tensor,
+                      &saved_mean_tensor,
+                      &saved_variance_tensor);
+  }
+
+  void AllocateTFOutputs(OpKernelContext* context,
+                         TensorShape tf_shape_scale,
+                         Tensor** batch_mean_tensor,
+                         Tensor** batch_variance_tensor,
+                         Tensor** saved_mean_tensor,
+                         Tensor** saved_variance_tensor) {
+    CHECK_NOTNULL(batch_mean_tensor);
+    CHECK_NOTNULL(batch_variance_tensor);
+    CHECK_NOTNULL(saved_mean_tensor);
+    CHECK_NOTNULL(saved_variance_tensor);
+
+    const size_t kBatchMeanIndex = 1;
+    const size_t kBatchVarianceIndex = 2;
+    const size_t kSavedMeanIndex = 3;
+    const size_t kSavedVarianceIndex = 4;
+
+    // allocate batch mean output tensor
+    MklDnnShape mkl_shape_batch_mean;
+    mkl_shape_batch_mean.SetMklTensor(false);
+    AllocateOutputSetMklShape(context,
+                              kBatchMeanIndex,
+                              batch_mean_tensor,
+                              tf_shape_scale,
+                              mkl_shape_batch_mean);
+    CHECK_NOTNULL(*batch_mean_tensor);
+    // set NAN mean value in case of empty input tensor
+    for (int k=0; k < tf_shape_scale.num_elements(); k++)
+      (*batch_mean_tensor)->flat<T>().data()[k] = NAN;
+
+    // allocate batch variance output tensor
+    MklDnnShape mkl_shape_batch_variance;
+    mkl_shape_batch_variance.SetMklTensor(false);
+    AllocateOutputSetMklShape(context,
+                              kBatchVarianceIndex,
+                              batch_variance_tensor,
+                              tf_shape_scale,
+                              mkl_shape_batch_variance);
+    CHECK_NOTNULL(*batch_variance_tensor);
+    // set NAN variance value in case of empty input tensor
+    for (int k=0; k < tf_shape_scale.num_elements(); k++)
+      (*batch_variance_tensor)->flat<T>().data()[k] = NAN;
+
+    // Mean and variance (without Bessel's correction) saved for backward
+    // computation to serve as pre-computed mean and variance.
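For non-empty training inputs, the first hunk above rescales the variance that
MKL-DNN produces: MKL-DNN computes the population variance (divide by N),
while TensorFlow's batch_variance output applies Bessel's correction (divide
by N - 1). A hedged sketch of that adjustment as a standalone function:

    #include <cstddef>

    float AdjustVariance(float mkl_variance, std::size_t orig_size) {
      // e.g. orig_size = 4: factor = 4/3, so a population variance of
      // 0.75 becomes the unbiased estimate 1.0.
      const float adjust_factor =
          static_cast<float>(orig_size) / static_cast<float>(orig_size - 1);
      return mkl_variance * adjust_factor;
    }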
+    MklDnnShape mkl_shape_saved_mean;
+    mkl_shape_saved_mean.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, kSavedMeanIndex,
+                              saved_mean_tensor,
+                              tf_shape_scale,
+                              mkl_shape_saved_mean);
+    CHECK_NOTNULL(*saved_mean_tensor);
+    // set NAN mean value in case of empty input tensor
+    for (int k=0; k < tf_shape_scale.num_elements(); k++)
+      (*saved_mean_tensor)->flat<T>().data()[k] = NAN;
+
+    MklDnnShape mkl_shape_saved_variance;
+    mkl_shape_saved_variance.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, kSavedVarianceIndex,
+                              saved_variance_tensor,
+                              tf_shape_scale,
+                              mkl_shape_saved_variance);
+    CHECK_NOTNULL(*saved_variance_tensor);
+    // set NAN variance value in case of empty input tensor
+    for (int k=0; k < tf_shape_scale.num_elements(); k++)
+      (*saved_variance_tensor)->flat<T>().data()[k] = NAN;
+  }
+};
 
 template <typename Device, typename T>
 class MklFusedBatchNormGradOp : public OpKernel {
@@ -1009,34 +1078,37 @@ class MklFusedBatchNormGradOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                 errors::InvalidArgument("Invalid data format"));
+    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
   }
 
   void Compute(OpKernelContext* context) override {
     try {
       auto cpu_engine = engine(engine::cpu, 0);
-
-      const size_t diff_dst_index = 0;  // index of diff_dst tensor
-      const size_t src_index = 1;       // index of src input tensor
-      const size_t scale_index = 2;     // index of scale tensor
-      const size_t mean_index = 3;      // index of saved_mean tensor
-      const size_t variance_index = 4;  // index of saved_variance tensor
-      const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
-      const Tensor& src_tensor = MklGetInput(context, src_index);
-      const Tensor& scale_tensor = MklGetInput(context, scale_index);
-      const Tensor& saved_mean_tensor = MklGetInput(context, mean_index);
+      const size_t kDiffDstIndex = 0;   // index of diff_dst tensor
+      const size_t kSrcIndex = 1;       // index of src input tensor
+      const size_t kScaleIndex = 2;     // index of scale tensor
+      const size_t kMeanIndex = 3;      // index of saved_mean tensor
+      const size_t kVarianceIndex = 4;  // index of saved_variance tensor
+      const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
+      const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
+      const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
+      const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex);
       const Tensor& saved_variance_tensor = MklGetInput(context,
-                                                        variance_index);
+                                                        kVarianceIndex);
 
       MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
-      GetMklShape(context, src_index, &dnn_shape_src);
-      GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);
+      GetMklShape(context, kSrcIndex, &dnn_shape_src);
+      GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst);
+      TensorShape tf_shape_src, tf_shape_diff_dst;
+
       if (dnn_shape_diff_dst.IsMklTensor()) {
+        tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
         OP_REQUIRES(context, dnn_shape_diff_dst.GetDimension() == 4,
                     errors::InvalidArgument(
                         "input must be 4-dimensional",
                         diff_dst_tensor.shape().DebugString()));
       } else {
+        tf_shape_diff_dst = diff_dst_tensor.shape();
         OP_REQUIRES(context, diff_dst_tensor.dims() == 4,
                     errors::InvalidArgument(
                         "input must be 4-dimensional",
@@ -1044,11 +1116,13 @@ class MklFusedBatchNormGradOp : public OpKernel {
       }
 
       if (dnn_shape_src.IsMklTensor()) {
+        tf_shape_src = dnn_shape_src.GetTfShape();
         OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
                     errors::InvalidArgument(
                         "input must be 4-dimensional",
                         src_tensor.shape().DebugString()));
       } else {
+        tf_shape_src = src_tensor.shape();
         OP_REQUIRES(context, src_tensor.dims() == 4,
                     errors::InvalidArgument(
                         "input must be 4-dimensional",
@@ -1069,6 +1143,15 @@ class MklFusedBatchNormGradOp : public OpKernel {
                       "saved variance must be 1-dimensional",
                       saved_variance_tensor.shape().DebugString()));
 
+      Tensor* diff_src_tensor = nullptr;
+      if (tf_shape_src.num_elements() == 0 ||
+          tf_shape_diff_dst.num_elements() == 0) {
+        HandleEmptyInput(context, tf_shape_src,
+                         scale_tensor.shape(),
+                         &diff_src_tensor);
+        return;
+      }
+
       if (dnn_shape_src.IsMklTensor())
         depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
       else
@@ -1165,25 +1248,21 @@ class MklFusedBatchNormGradOp : public OpKernel {
       auto diff_weights_m = memory(diff_weights_pd);
 
       auto bnrm_fwd_desc = batch_normalization_forward::desc(
-          prop_kind::forward_training,
-          src.GetUsrMemDesc(),
-          epsilon_,
-          use_scale_shift);
+          prop_kind::forward_training,
+          src.GetUsrMemDesc(),
+          epsilon_,
+          is_training_ ? use_scale_shift :
+                         (use_scale_shift | use_global_stats));
       auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(
           bnrm_fwd_desc, cpu_engine);
 
       // Indices of output tensors
-      const size_t diff_src_index = 0;    // index of diff_src tensor
-      const size_t diff_scale_index = 1;  // index of diff_scale tensor
-      const size_t diff_shift_index = 2;  // index of diff_shift tensor
-      const size_t p1_index = 3;          // index of 1st placeholder tensor
-      const size_t p2_index = 4;          // index of 2nd placeholder tensor
+      const size_t kDiffSrcIndex = 0;  // index of diff_src tensor
 
       // allocate diff_src tensor
       MklDnnShape dnn_shape_diff_src;
       TensorShape tf_shape_diff_src;
-      Tensor* diff_src_tensor = nullptr;
       if (dnn_shape_src.IsMklTensor()) {
         dnn_shape_diff_src.SetMklTensor(true);
         auto diff_src_pd = bnrm_fwd_pd.dst_primitive_desc();
@@ -1201,7 +1280,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
         dnn_shape_diff_src.SetMklTensor(false);
         tf_shape_diff_src = src_tensor.shape();
       }
-      AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
+      AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
                                 tf_shape_diff_src, dnn_shape_diff_src);
 
       diff_src.SetUsrMem(src_md, diff_src_tensor);
@@ -1212,7 +1291,15 @@ class MklFusedBatchNormGradOp : public OpKernel {
           diff_src.GetUsrMemDesc(),
           src.GetUsrMemDesc(),
          epsilon_,
-          use_scale_shift);
+          /* for inference, specify use_global_stats
+             1. on fwd prop, use mean and variance
+                provided as inputs
+             2. on bwd prop, mean and variance are
+                considered as constants. Thus,
+                reduce the amount of MKL computations
+          */
+          is_training_ ? use_scale_shift :
+                         (use_scale_shift | use_global_stats));
       auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc(
           bnrm_bwd_desc, cpu_engine,
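Both primitive descriptors above now derive their flags from is_training_. A
small sketch of the selection logic, using hypothetical stand-ins for the
mkldnn flag constants:

    #include <cstdint>

    enum BatchNormFlags : std::uint32_t {
      use_scale_shift  = 0x1,
      use_global_stats = 0x2,
    };

    std::uint32_t SelectFlags(bool is_training) {
      // Training: batch statistics are computed from the data.
      // Inference: mean and variance arrive as inputs and are treated as
      // constants, so use_global_stats lets the library skip recomputing
      // them on both the forward and backward passes.
      return is_training ? use_scale_shift
                         : (use_scale_shift | use_global_stats);
    }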
@@ -1232,41 +1319,22 @@ class MklFusedBatchNormGradOp : public OpKernel {
       net.push_back(bnrm_bwd_op);
       stream(stream::kind::eager).submit(net).wait();
 
-      // separate out scale and shift grad and copy to individual tensors
-      const TensorShape& tf_shape_scale_shift = scale_tensor.shape();
+      // allocate 4 output TF tensors
       Tensor* diff_scale_tensor = nullptr;
-      MklDnnShape mkl_shape_diff_scale;
-      mkl_shape_diff_scale.SetMklTensor(false);
-      AllocateOutputSetMklShape(context, diff_scale_index, &diff_scale_tensor,
-                                tf_shape_scale_shift, mkl_shape_diff_scale);
-
       Tensor* diff_shift_tensor = nullptr;
-      MklDnnShape mkl_shape_diff_shift;
-      mkl_shape_diff_shift.SetMklTensor(false);
-      AllocateOutputSetMklShape(context, diff_shift_index, &diff_shift_tensor,
-                                tf_shape_scale_shift, mkl_shape_diff_shift);
+      AllocateTFOutputs(context, scale_tensor.shape(),
+                        &diff_scale_tensor,
+                        &diff_shift_tensor);
 
       // copy data: diff_scale and diff_shift
       T* diff_weights_data_dnn = reinterpret_cast<T*>
          (diff_weights_m.get_data_handle());
-      float* diff_scale_data_tf = const_cast<float*>(
-          static_cast<const float*>(diff_scale_tensor->flat<T>().data()));
-      float* diff_shift_data_tf = const_cast<float*>(
-          static_cast<const float*>(diff_shift_tensor->flat<T>().data()));
       for (int i = 0; i < depth_; i++) {
-        diff_scale_data_tf[i] = diff_weights_data_dnn[i];
-        diff_shift_data_tf[i] = diff_weights_data_dnn[i + depth_];
+        diff_scale_tensor->flat<T>().data()[i] =
+            diff_weights_data_dnn[i];
+        diff_shift_tensor->flat<T>().data()[i] =
+            diff_weights_data_dnn[i + depth_];
       }
-
-      // Placeholders for estimated_mean and estimated_variance, which are
-      // used for inference and thus not needed here for gradient computation.
-      Tensor* p1_tensor = nullptr, *p2_tensor = nullptr;
-      MklDnnShape mkl_shape_p;
-      mkl_shape_p.SetMklTensor(false);
-      AllocateOutputSetMklShape(context, p1_index, &p1_tensor,
-                                TensorShape({}), mkl_shape_p);
-      AllocateOutputSetMklShape(context, p2_index, &p2_tensor,
-                                TensorShape({}), mkl_shape_p);
     } catch (mkldnn::error &e) {
       string error_msg = "Status: " + std::to_string(e.status) +
                          ", message: " + string(e.message) +
@@ -1282,12 +1350,74 @@ class MklFusedBatchNormGradOp : public OpKernel {
   T epsilon_;
   TensorFormat tensor_format_;
   int depth_;  // batch normalization is done for per channel.
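The copy loop above relies on the packed layout MKL-DNN uses for the
scale/shift gradients: a single buffer of length 2 * depth holding the scale
gradients in [0, depth) followed by the shift gradients in [depth, 2 * depth).
A hedged sketch of that unpacking, independent of the kernel:

    #include <vector>

    void SplitDiffWeights(const float* diff_weights, int depth,
                          std::vector<float>* diff_scale,
                          std::vector<float>* diff_shift) {
      diff_scale->assign(diff_weights, diff_weights + depth);
      diff_shift->assign(diff_weights + depth, diff_weights + 2 * depth);
    }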
+  bool is_training_;
 
   void ExtractParams(OpKernelContext* context) {
     const Tensor& input = MklGetInput(context, 0);
     depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
   }
 
+  void HandleEmptyInput(OpKernelContext* context,
+                        TensorShape tf_shape_src,
+                        TensorShape tf_shape_scale_shift,
+                        Tensor** diff_src_tensor) {
+    const size_t kDiffSrcIndex = 0;
+
+    MklDnnShape dnn_shape_diff_src;
+    dnn_shape_diff_src.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
+                              tf_shape_src, dnn_shape_diff_src);
+    for (size_t i=0; i < (*diff_src_tensor)->shape().num_elements(); i++)
+      (*diff_src_tensor)->flat<T>().data()[i] = 0;
+
+    Tensor* diff_scale_tensor = nullptr;
+    Tensor* diff_shift_tensor = nullptr;
+    AllocateTFOutputs(context,
+                      tf_shape_scale_shift,
+                      &diff_scale_tensor,
+                      &diff_shift_tensor);
+  }
+
+  void AllocateTFOutputs(OpKernelContext* context,
+                         TensorShape tf_shape_scale_shift,
+                         Tensor** diff_scale_tensor,
+                         Tensor** diff_shift_tensor) {
+    CHECK_NOTNULL(diff_scale_tensor);
+    CHECK_NOTNULL(diff_shift_tensor);
+
+    const size_t kDiffScaleIndex = 1;
+    const size_t kDiffShiftIndex = 2;
+    const size_t kP1Index = 3;
+    const size_t kP2Index = 4;
+
+    // separate out scale and shift grad and copy to individual tensors
+    MklDnnShape mkl_shape_diff_scale;
+    mkl_shape_diff_scale.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
+                              tf_shape_scale_shift, mkl_shape_diff_scale);
+    CHECK_NOTNULL(*diff_scale_tensor);
+    for (size_t i=0; i < (*diff_scale_tensor)->shape().num_elements(); i++)
+      (*diff_scale_tensor)->flat<T>().data()[i] = 0;
+
+    MklDnnShape mkl_shape_diff_shift;
+    mkl_shape_diff_shift.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
+                              tf_shape_scale_shift, mkl_shape_diff_shift);
+    CHECK_NOTNULL(*diff_shift_tensor);
+    for (size_t i=0; i < (*diff_shift_tensor)->shape().num_elements(); i++)
+      (*diff_shift_tensor)->flat<T>().data()[i] = 0;
+
+    // Placeholders for estimated_mean and estimated_variance, which are
+    // used for inference and thus not needed here for gradient computation.
+    Tensor* p1_tensor = nullptr, *p2_tensor = nullptr;
+    MklDnnShape mkl_shape_p;
+    mkl_shape_p.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, kP1Index, &p1_tensor,
+                              TensorShape({}), mkl_shape_p);
+    AllocateOutputSetMklShape(context, kP2Index, &p2_tensor,
+                              TensorShape({}), mkl_shape_p);
+  }
+
   memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }