author     2018-08-07 10:33:44 -0700
committer  2018-08-07 10:33:57 -0700
commit     90bf05c0d147a7e0c6e48720e17e51233b2bcd3c (patch)
tree       08a6f23e47af61b1c0a72cf962d07ad3646bfdc1
parent     b7950bce77bf74dcf2c11fb2f4bb45f6e673f82d (diff)
parent     8e2f587b95cda1f67deaa4ae315c7540444919f5 (diff)
Merge pull request #19402 from Intel-tensorflow:primreuse_batch_norm
PiperOrigin-RevId: 207737829
-rw-r--r--  tensorflow/core/kernels/mkl_fused_batch_norm_op.cc  908
1 file changed, 646 insertions, 262 deletions
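The substance of this merge is that both MKL batch-norm kernels (forward and gradient) stop constructing a fresh MKL-DNN primitive on every Compute() call and instead fetch a cached MklFusedBatchNormFwdPrimitive / MklFusedBatchNormBwdPrimitive from a per-type primitive factory, keyed on the source dimensions, depth, epsilon and the training flag. The snippet below is a minimal, self-contained sketch of that caching pattern only; the type names (BatchNormParams, FakePrimitive, PrimitiveFactory) are invented for illustration, and the real classes in the diff additionally wrap mkldnn memory, stream and primitive objects.

// Sketch of the keyed primitive-cache pattern used by
// MklFusedBatchNorm{Fwd,Bwd}PrimitiveFactory in the diff below.
// All type names here are hypothetical stand-ins, not TensorFlow APIs.
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

struct BatchNormParams {   // plays the role of MklBatchNormFwdParams
  std::vector<int> src_dims;
  int depth;
  float eps;
  bool training;
};

class FakePrimitive {      // stands in for MklFusedBatchNormFwdPrimitive<T>
 public:
  explicit FakePrimitive(const BatchNormParams& p) : params_(p) {
    std::cout << "building primitive for depth=" << params_.depth << "\n";
  }
  void Execute() { /* the real class swaps data handles and submits a stream */ }

 private:
  BatchNormParams params_;
};

class PrimitiveFactory {   // stands in for MklFusedBatchNormFwdPrimitiveFactory<T>
 public:
  static FakePrimitive* Get(const BatchNormParams& p) {
    auto& cache = Instance().cache_;
    const std::string key = CreateKey(p);
    auto it = cache.find(key);
    if (it == cache.end()) {
      // First request with this shape/epsilon/mode: build and memoize.
      it = cache.emplace(key, std::make_unique<FakePrimitive>(p)).first;
    }
    return it->second.get();
  }

 private:
  static PrimitiveFactory& Instance() {
    static PrimitiveFactory instance;
    return instance;
  }
  // Mirrors CreateKey() in the diff: dims + depth + eps + training flag.
  static std::string CreateKey(const BatchNormParams& p) {
    std::ostringstream key;
    key << "bn_fwd";
    for (int d : p.src_dims) key << "_" << d;
    key << "_" << p.depth << "_" << p.eps << "_" << p.training;
    return key.str();
  }

  std::map<std::string, std::unique_ptr<FakePrimitive>> cache_;
};

int main() {
  BatchNormParams params{{32, 64, 28, 28}, 64, 1e-3f, true};
  FakePrimitive* first = PrimitiveFactory::Get(params);   // builds the primitive
  FakePrimitive* second = PrimitiveFactory::Get(params);  // reuses the cached one
  first->Execute();
  return first == second ? 0 : 1;
}

In the kernels themselves the cached object also owns the mkldnn memory, stream and primitive handles; Execute() only swaps data handles via set_data_handle() and submits the stream, so primitive descriptors are no longer rebuilt on every call.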
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index 3fe660cf96..0149e78db5 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -262,6 +262,7 @@ class MklFusedBatchNormOp : public OpKernel { } void MklCreateInputLayout(OpKernelContext* context) { + const Tensor& input = MklGetInput(context, 0); bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor(); if (input_in_mkl_format) { mkl_lt_input = @@ -544,6 +545,7 @@ class MklFusedBatchNormGradOp : public OpKernel { } void MklCreateInputLayout(OpKernelContext* context) { + const Tensor& input = MklGetInput(context, 0); bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor(); if (input_in_mkl_format) { mkl_lt_input = @@ -684,6 +686,466 @@ class MklFusedBatchNormGradOp : public OpKernel { #ifndef INTEL_MKL_ML +struct MklBatchNormFwdParams { + memory::dims src_dims; + int depth; + float eps; + bool training; + + MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps, + bool training) + : src_dims(src_dims), depth(depth), eps(eps), training(training) {} +}; + +template <typename T> +class MklFusedBatchNormFwdPrimitive : public MklPrimitive { + public: + explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams) + : cpu_engine_(engine::cpu, 0) { + context_.fwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager)); + if (context_.bn_fwd == nullptr) Setup(fwdParams); + } + + ~MklFusedBatchNormFwdPrimitive() {} + + // BatchNormalization forward execute + // src_data: input data buffer of src + // weights_data: input data buffer of weights + // dst_data: output data buffer of dst + // mean_data: output data buffer of means + // variance_data: output data buffer of variances + void Execute(const T* src_data, const T* weights_data, T* dst_data, + T* mean_data, T* variance_data) { + context_.src_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(src_data))); + context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); + + if (context_.flags & use_scale_shift) + context_.weights_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(weights_data))); + + if ((context_.pkind == prop_kind::forward_training) || + (context_.flags & use_global_stats)) { + context_.mean_mem->set_data_handle(static_cast<void*>(mean_data)); + context_.variance_mem->set_data_handle(static_cast<void*>(variance_data)); + } + + // execution + context_.fwd_stream->submit(context_.fwd_primitives); + + context_.src_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); + + if (context_.flags & use_scale_shift) + context_.weights_mem->set_data_handle(DummyData); + + if ((context_.pkind == prop_kind::forward_training) || + (context_.flags & use_global_stats)) { + context_.mean_mem->set_data_handle(DummyData); + context_.variance_mem->set_data_handle(DummyData); + } + } + + memory::primitive_desc GetDstPd() const { + return (*context_.dst_mem).get_primitive_desc(); + } + + mkldnn_memory_format_t GetSrcFmt() const { + return (*context_.src_mem).get_primitive_desc().desc().data.format; + } + + mkldnn_memory_format_t GetDstFmt() const { + return (*context_.dst_mem).get_primitive_desc().desc().data.format; + } + + private: + // Primitive reuse context for BatchNorm fwd op + struct BatchNormFwdContext { + // flags indict if it is training or inference mode + int64 flags; + + // algorithm + mkldnn::prop_kind pkind; + + // Mkldnn Memory + 
std::shared_ptr<mkldnn::memory> src_mem; + std::shared_ptr<mkldnn::memory> weights_mem; + std::shared_ptr<mkldnn::memory> dst_mem; + std::shared_ptr<mkldnn::memory> mean_mem; + std::shared_ptr<mkldnn::memory> variance_mem; + + // BatchNorm forward primitive + std::shared_ptr<mkldnn::primitive> bn_fwd; + std::shared_ptr<mkldnn::stream> fwd_stream; + std::vector<mkldnn::primitive> fwd_primitives; + + BatchNormFwdContext() + : flags(0), + pkind(mkldnn::forward_training), + src_mem(nullptr), + weights_mem(nullptr), + dst_mem(nullptr), + mean_mem(nullptr), + variance_mem(nullptr), + bn_fwd(nullptr), + fwd_stream(nullptr) {} + }; + + void Setup(const MklBatchNormFwdParams& fwdParams) { + context_.flags = fwdParams.training ? use_scale_shift + : (use_scale_shift | use_global_stats); + context_.pkind = fwdParams.training ? prop_kind::forward_training + : prop_kind::forward_scoring; + + // memory desc + auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>(), + get_desired_format(fwdParams.src_dims[1])); + + // fwd desc & primitive desc + auto fwd_desc = batch_normalization_forward::desc( + context_.pkind, src_md, fwdParams.eps, context_.flags); + auto fwd_pd = + batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_); + + // memory primitive + context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); + context_.dst_mem.reset(new memory(fwd_pd.dst_primitive_desc(), DummyData)); + + if (context_.flags & use_scale_shift) { + auto weights_desc = memory::desc({2, fwdParams.depth}, MklDnnType<T>(), + memory::format::nc); + context_.weights_mem.reset( + new memory({weights_desc, cpu_engine_}, DummyData)); + } + + if (fwdParams.training || (context_.flags & use_global_stats)) { + auto mean_desc = memory::desc({1, fwdParams.depth}, MklDnnType<T>(), + memory::format::nc); + context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData)); + + auto variance_desc = + memory::desc({1, fwdParams.depth}, MklDnnType<T>(), memory::nc); + context_.variance_mem.reset( + new memory({variance_desc, cpu_engine_}, DummyData)); + } + + // BatchNorm forward primitive + if (!fwdParams.training && !(context_.flags & use_global_stats)) { + if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, *context_.weights_mem, + *context_.dst_mem)); + } else { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, *context_.dst_mem)); + } + } else if (context_.flags & use_global_stats) { + if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem, + (const primitive::at)*context_.variance_mem, *context_.weights_mem, + *context_.dst_mem)); + } else { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem, + (const primitive::at)*context_.variance_mem, *context_.dst_mem)); + } + } else { + if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, *context_.weights_mem, *context_.dst_mem, + *context_.mean_mem, *context_.variance_mem)); + } else { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, *context_.dst_mem, *context_.mean_mem, + *context_.variance_mem)); + } + } + + context_.fwd_primitives.push_back(*context_.bn_fwd); + } + + 
mkldnn::memory::desc get_desc_data(const mkldnn::memory& m) const { + return m.get_primitive_desc().desc().data; + } + + struct BatchNormFwdContext context_; + engine cpu_engine_; +}; + +template <typename T> +class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> { + public: + static MklFusedBatchNormFwdPrimitive<T>* Get( + const MklBatchNormFwdParams& fwdParams) { + auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T>*>( + MklFusedBatchNormFwdPrimitiveFactory<T>::GetInstance().GetBatchNormFwd( + fwdParams)); + + if (bn_fwd == nullptr) { + bn_fwd = new MklFusedBatchNormFwdPrimitive<T>(fwdParams); + MklFusedBatchNormFwdPrimitiveFactory<T>::GetInstance().SetBatchNormFwd( + fwdParams, bn_fwd); + } + return bn_fwd; + } + + static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() { + static MklFusedBatchNormFwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklFusedBatchNormFwdPrimitiveFactory() {} + ~MklFusedBatchNormFwdPrimitiveFactory() {} + + static std::string CreateKey(const MklBatchNormFwdParams& fwdParams) { + std::string prefix = "bn_fwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(fwdParams.src_dims); + key_creator.AddAsKey<int>(fwdParams.depth); + key_creator.AddAsKey<float>(fwdParams.eps); + key_creator.AddAsKey<bool>(fwdParams.training); + return key_creator.GetKey(); + } + + MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) { + std::string key = CreateKey(fwdParams); + return this->GetOp(key); + } + + void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams, + MklPrimitive* op) { + std::string key = CreateKey(fwdParams); + this->SetOp(key, op); + } +}; + +struct MklBatchNormBwdParams { + memory::dims src_dims; + memory::dims diff_dst_dims; + int depth; + float eps; + bool training; + + MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims, + int depth, float eps, bool training) + : src_dims(src_dims), + diff_dst_dims(diff_dst_dims), + depth(depth), + eps(eps), + training(training) {} +}; + +template <typename T> +class MklFusedBatchNormBwdPrimitive : public MklPrimitive { + public: + explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams) + : cpu_engine_(engine::cpu, 0) { + context_.bwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager)); + if (context_.bn_bwd == nullptr) Setup(bwdParams); + } + + ~MklFusedBatchNormBwdPrimitive() {} + + // BatchNormalization backward execute + // src_data: input data buffer of src + // mean_data: input data buffer of mean + // variance_data: input data buffer of variance + // diff_dst_data: input data buffer of diff_dst + // weights_data: input data buffer of weights + // diff_src_data: output data buffer of diff_src + // diff_weights_data: output data buffer of diff_weights + void Execute(const T* src_data, const T* mean_data, const T* variance_data, + const T* diff_dst_data, const T* weights_data, T* diff_src_data, + T* diff_weights_data) { + context_.src_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(src_data))); + context_.mean_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(mean_data))); + context_.variance_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(variance_data))); + context_.diff_dst_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(diff_dst_data))); + + if (context_.flags & use_scale_shift) { + context_.weights_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(weights_data))); + context_.diff_weights_mem->set_data_handle( + 
static_cast<void*>(diff_weights_data)); + } + + context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); + + // execution + context_.bwd_stream->submit(context_.bwd_primitives); + + context_.src_mem->set_data_handle(DummyData); + context_.mean_mem->set_data_handle(DummyData); + context_.variance_mem->set_data_handle(DummyData); + context_.diff_dst_mem->set_data_handle(DummyData); + if (context_.flags & use_scale_shift) { + context_.weights_mem->set_data_handle(DummyData); + context_.diff_weights_mem->set_data_handle(DummyData); + } + context_.diff_src_mem->set_data_handle(DummyData); + } + + mkldnn_memory_format_t GetSrcFmt() { + return (*context_.src_mem).get_primitive_desc().desc().data.format; + } + + mkldnn_memory_format_t GetDiffDstFmt() { + return (*context_.diff_dst_mem).get_primitive_desc().desc().data.format; + } + + memory::primitive_desc GetDiffSrcPd() { + return (*context_.diff_src_mem).get_primitive_desc(); + } + + private: + struct BatchNormBwdContext { + // Flags to indicate whether it is training or inference + int64 flags; + + // MKLDNN memory + std::shared_ptr<mkldnn::memory> src_mem; + std::shared_ptr<mkldnn::memory> mean_mem; + std::shared_ptr<mkldnn::memory> variance_mem; + std::shared_ptr<mkldnn::memory> diff_dst_mem; + std::shared_ptr<mkldnn::memory> weights_mem; + std::shared_ptr<mkldnn::memory> diff_weights_mem; + std::shared_ptr<mkldnn::memory> diff_src_mem; + + // Batch Norm primitive + std::shared_ptr<mkldnn::primitive> bn_bwd; + std::vector<mkldnn::primitive> bwd_primitives; + std::shared_ptr<mkldnn::stream> bwd_stream; + + BatchNormBwdContext() + : src_mem(nullptr), + mean_mem(nullptr), + variance_mem(nullptr), + diff_dst_mem(nullptr), + weights_mem(nullptr), + diff_weights_mem(nullptr), + diff_src_mem(nullptr), + bwd_stream(nullptr) {} + }; + + void Setup(const MklBatchNormBwdParams& bwdParams) { + context_.flags = bwdParams.training ? use_scale_shift + : (use_scale_shift | use_global_stats); + + // memory desc + auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(), + get_desired_format(bwdParams.src_dims[1])); + auto diff_dst_md = + memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>(), + get_desired_format(bwdParams.diff_dst_dims[1])); + auto variance_desc = + memory::desc({1, bwdParams.depth}, MklDnnType<T>(), memory::nc); + auto mean_desc = + memory::desc({1, bwdParams.depth}, MklDnnType<T>(), memory::format::nc); + auto weights_desc = + memory::desc({2, bwdParams.depth}, MklDnnType<T>(), memory::format::nc); + auto diff_weights_desc = weights_desc; + + // fwd desc & primitive desc + auto fwd_desc = batch_normalization_forward::desc( + prop_kind::forward_training, src_md, bwdParams.eps, + bwdParams.training ? use_scale_shift + : (use_scale_shift | use_global_stats)); + auto fwd_pd = + batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_); + + // BatchNorm backward primtive + // + // For inference, specify use_global_stats + // 1. on fwd propagation, use mean and variance provided as inputs. + // 2. on bwd propagation, mean and variance are considered as constants. + // Thus, reduce the amount of MKL computation. + auto bwd_desc = batch_normalization_backward::desc( + prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, + bwdParams.training ? 
use_scale_shift + : (use_scale_shift | use_global_stats)); + auto bn_bwd_pd = batch_normalization_backward::primitive_desc( + bwd_desc, cpu_engine_, fwd_pd); + + // memory primitive + context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); + context_.diff_dst_mem.reset( + new memory({diff_dst_md, cpu_engine_}, DummyData)); + context_.variance_mem.reset( + new memory({variance_desc, cpu_engine_}, DummyData)); + context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData)); + context_.weights_mem.reset( + new memory({weights_desc, cpu_engine_}, DummyData)); + context_.diff_weights_mem.reset( + new memory({diff_weights_desc, cpu_engine_}, DummyData)); + context_.diff_src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); + + context_.bn_bwd.reset(new batch_normalization_backward( + bn_bwd_pd, *context_.src_mem, *context_.mean_mem, + *context_.variance_mem, *context_.diff_dst_mem, *context_.weights_mem, + *context_.diff_src_mem, *context_.diff_weights_mem)); + context_.bwd_primitives.push_back(*context_.bn_bwd); + } + + struct BatchNormBwdContext context_; + engine cpu_engine_; +}; + +template <typename T> +class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> { + public: + static MklFusedBatchNormBwdPrimitive<T>* Get( + const MklBatchNormBwdParams& bwdParams) { + auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T>*>( + MklFusedBatchNormBwdPrimitiveFactory<T>::GetInstance().GetBatchNormBwd( + bwdParams)); + if (bn_bwd == nullptr) { + bn_bwd = new MklFusedBatchNormBwdPrimitive<T>(bwdParams); + MklFusedBatchNormBwdPrimitiveFactory<T>::GetInstance().SetBatchNormBwd( + bwdParams, bn_bwd); + } + return bn_bwd; + } + + static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() { + static MklFusedBatchNormBwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklFusedBatchNormBwdPrimitiveFactory() {} + ~MklFusedBatchNormBwdPrimitiveFactory() {} + + static std::string CreateKey(const MklBatchNormBwdParams& bwdParams) { + std::string prefix = "bn_bwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(bwdParams.src_dims); + key_creator.AddAsKey(bwdParams.diff_dst_dims); + key_creator.AddAsKey<int>(bwdParams.depth); + key_creator.AddAsKey<float>(bwdParams.eps); + key_creator.AddAsKey<bool>(bwdParams.training); + return key_creator.GetKey(); + } + + MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) { + std::string key = CreateKey(bwdParams); + return this->GetOp(key); + } + + void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams, + MklPrimitive* op) { + std::string key = CreateKey(bwdParams); + this->SetOp(key, op); + } +}; + template <typename Device, typename T> class MklFusedBatchNormOp : public OpKernel { public: @@ -701,7 +1163,6 @@ class MklFusedBatchNormOp : public OpKernel { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); const size_t kSrcIndex = 0; // index of src input tensor const size_t kScaleIndex = 1; // index of scale tensor const size_t kShiftIndex = 2; // index of shift tensor @@ -786,7 +1247,7 @@ class MklFusedBatchNormOp : public OpKernel { SetMeanVariance(est_mean_tensor, est_variance_tensor); MklDnnData<T> src(&cpu_engine); - MklDnnData<T> dst(&cpu_engine); + MklDnnData<T> weights(&cpu_engine); memory::format format_m; if (dnn_shape_src.IsMklTensor()) { @@ -800,123 +1261,102 @@ class MklFusedBatchNormOp : public OpKernel { } // set src primitive - memory::dims src_dims; - if 
(dnn_shape_src.IsMklTensor()) { - src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(), - tensor_format_); - } else { - src_dims = - TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); - } + memory::dims src_dims = + dnn_shape_src.IsMklTensor() + ? dnn_shape_src.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); auto src_md = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetMklLayout() : memory::desc(src_dims, MklDnnType<T>(), format_m); - src.SetUsrMem(src_md, &src_tensor); - // set weights primitive // MKL-DNN packs scale & shift as "weights": // <scale>...<scale><shift>...<shift> - auto weights_desc = memory::desc({2, static_cast<int>(depth_)}, - MklDnnType<T>(), memory::format::nc); - auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine); - auto weights_m = memory(weights_pd); - T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle()); - T* scale_tf = - reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data())); - T* shift_tf = - reinterpret_cast<T*>(const_cast<T*>(shift_tensor.flat<T>().data())); + weights.AllocateBuffer(2 * depth_ * sizeof(T)); + T* weights_data = reinterpret_cast<T*>(weights.GetAllocatedBuffer()); + const T* scale_tf = scale_tensor.flat<T>().data(); + const T* shift_tf = shift_tensor.flat<T>().data(); - for (int k = 0; k < depth_; k++) { - weights_data[k] = scale_tf[k]; - weights_data[k + depth_] = shift_tf[k]; - } - - // set mean primitive - auto mean_desc = memory::desc({1, static_cast<int>(depth_)}, - MklDnnType<T>(), memory::format::nc); - auto mean_pd = memory::primitive_desc(mean_desc, cpu_engine); + std::memcpy(weights_data, scale_tf, depth_ * sizeof(T)); + std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(T)); char* saved_mean_data_tf = reinterpret_cast<char*>(saved_mean_tensor->flat<T>().data()); std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_), depth_ * sizeof(T)); - auto mean_m = - memory(mean_pd, reinterpret_cast<void*>(saved_mean_data_tf)); - // set variance primitive - auto variance_desc = memory::desc({1, static_cast<int>(depth_)}, - MklDnnType<T>(), memory::format::nc); - auto variance_pd = memory::primitive_desc(variance_desc, cpu_engine); char* saved_variance_data_tf = reinterpret_cast<char*>(saved_variance_tensor->flat<T>().data()); std::memcpy(saved_variance_data_tf, reinterpret_cast<char*>(variance_values_), depth_ * sizeof(T)); - auto variance_m = memory(variance_pd, saved_variance_data_tf); - - prop_kind pk = (is_training_) ? prop_kind::forward_training - : prop_kind::forward_scoring; - auto bnrm_fwd_desc = batch_normalization_forward::desc( - pk, src.GetUsrMemDesc(), epsilon_, - is_training_ ? 
use_scale_shift - : (use_scale_shift | use_global_stats)); - auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc( - bnrm_fwd_desc, cpu_engine); - - // allocate dst tensor + + // get batchnorm op from the pool + MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_); + MklFusedBatchNormFwdPrimitive<T>* bn_fwd = + MklFusedBatchNormFwdPrimitiveFactory<T>::Get(fwdParams); + + // check if reorder is needed for src, weights, mean, variance + const T* src_data = src_tensor.flat<T>().data(); + if (src_md.data.format != bn_fwd->GetSrcFmt()) { + src.SetUsrMem(src_md, &src_tensor); + auto src_target = memory::primitive_desc( + {{src_dims}, + MklDnnType<T>(), + static_cast<memory::format>(bn_fwd->GetSrcFmt())}, + cpu_engine); + src.CheckReorderToOpMem(src_target); + src_data = const_cast<T*>( + reinterpret_cast<T*>(src.GetOpMem().get_data_handle())); + } + + // allocate output (dst) tensor; always set it as MKL-DNN layout MklDnnShape dnn_shape_dst; TensorShape tf_shape_dst; - if (dnn_shape_src.IsMklTensor()) { - dnn_shape_dst.SetMklTensor(true); - auto dst_pd = bnrm_fwd_pd.dst_primitive_desc(); - dnn_shape_dst.SetMklLayout(&dst_pd); - dnn_shape_dst.SetElemType(MklDnnType<T>()); - dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), src_dims, - format_m); - tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); - } else { - dnn_shape_dst.SetMklTensor(false); - tf_shape_dst = src_tensor.shape(); - } + dnn_shape_dst.SetMklTensor(true); + auto dst_pd = bn_fwd->GetDstPd(); + dnn_shape_dst.SetMklLayout(&dst_pd); + dnn_shape_dst.SetElemType(MklDnnType<T>()); + auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension() + : src_tensor.shape().dims(); + dnn_shape_dst.SetTfLayout(ndims, src_dims, format_m); + tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst, dnn_shape_dst); - // Output of batchnorm has same shape as input. - dst.SetUsrMem(src_md, dst_tensor); + T* weights_op_data = weights_data; + T* mean_op_data = saved_mean_tensor->flat<T>().data(); + T* variance_op_data = saved_variance_tensor->flat<T>().data(); + T* dst_data = dst_tensor->flat<T>().data(); - primitive bnrm_fwd_op; - if (is_training_) { - bnrm_fwd_op = - batch_normalization_forward(bnrm_fwd_pd, src.GetOpMem(), weights_m, - dst.GetOpMem(), mean_m, variance_m); - } else { - bnrm_fwd_op = batch_normalization_forward( - bnrm_fwd_pd, src.GetOpMem(), mean_m, variance_m, - (const primitive::at)weights_m, dst.GetOpMem()); - } - std::vector<primitive> net; - net.push_back(bnrm_fwd_op); - stream(stream::kind::eager).submit(net).wait(); + // execution + bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data, + variance_op_data); // copy batch_mean data - T* batch_mean_data_tf = - reinterpret_cast<T*>(batch_mean_tensor->flat<T>().data()); + T* batch_mean_data_tf = batch_mean_tensor->flat<T>().data(); std::memcpy(reinterpret_cast<char*>(batch_mean_data_tf), - reinterpret_cast<char*>(mean_m.get_data_handle()), + reinterpret_cast<char*>(saved_mean_data_tf), depth_ * sizeof(T)); + // TODO(yli135): OpMem is same as usr mem since + // since its format is hard-coded as nc when primitive is created. 
// copy batch_variance data with Bessel's correction - // if training mode is on float adjust_factor = 1.0; if (is_training_) { size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3]; size_t adjust_size = orig_size - 1; adjust_factor = (static_cast<float>(orig_size)) / adjust_size; } - for (int k = 0; k < depth_; k++) - batch_variance_tensor->flat<T>().data()[k] = - (reinterpret_cast<T*>(variance_m.get_data_handle()))[k] * - adjust_factor; + + auto variance_data = reinterpret_cast<T*>(saved_variance_data_tf); + auto batch_variance_data = batch_variance_tensor->flat<T>().data(); + if (is_training_) { + for (int k = 0; k < depth_; k++) { + batch_variance_data[k] = variance_data[k] * adjust_factor; + } + } else { + std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(T)); + } } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -933,7 +1373,8 @@ class MklFusedBatchNormOp : public OpKernel { bool is_training_; T* mean_values_; T* variance_values_; - int depth_; // batch normalization is done for per channel. + size_t depth_; // batch normalization is done for per channel. + engine cpu_engine = engine(engine::cpu, 0); void ExtractParams(OpKernelContext* context) { const Tensor& input = MklGetInput(context, 0); @@ -990,8 +1431,9 @@ class MklFusedBatchNormOp : public OpKernel { tf_shape_scale, mkl_shape_batch_mean); CHECK_NOTNULL(*batch_mean_tensor); // set NAN mean value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*batch_mean_tensor)->flat<T>().data()[k] = NAN; + int num_elements = tf_shape_scale.num_elements(); + auto batch_mean_data = (*batch_mean_tensor)->flat<T>().data(); + std::fill_n(batch_mean_data, num_elements, NAN); // allocate batch variance output tensor MklDnnShape mkl_shape_batch_variance; @@ -1001,8 +1443,8 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_batch_variance); CHECK_NOTNULL(*batch_variance_tensor); // set NAN variance value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*batch_variance_tensor)->flat<T>().data()[k] = NAN; + auto batch_variance_data = (*batch_variance_tensor)->flat<T>().data(); + std::fill_n(batch_variance_data, num_elements, NAN); // Mean and variance (without Bessel's correction) saved for backward // computation to serve as pre-computed mean and variance. 
@@ -1012,8 +1454,8 @@ class MklFusedBatchNormOp : public OpKernel { tf_shape_scale, mkl_shape_saved_mean); CHECK_NOTNULL(*saved_mean_tensor); // set NAN mean value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*saved_mean_tensor)->flat<T>().data()[k] = NAN; + auto saved_mean_data = (*saved_mean_tensor)->flat<T>().data(); + std::fill_n(saved_mean_data, num_elements, NAN); MklDnnShape mkl_shape_saved_variance; mkl_shape_saved_variance.SetMklTensor(false); @@ -1022,8 +1464,8 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_saved_variance); CHECK_NOTNULL(*saved_variance_tensor); // set NAN variance value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*saved_variance_tensor)->flat<T>().data()[k] = NAN; + auto saved_variance_data = (*saved_variance_tensor)->flat<T>().data(); + std::fill_n(saved_variance_data, num_elements, NAN); } }; @@ -1044,12 +1486,12 @@ class MklFusedBatchNormGradOp : public OpKernel { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); const size_t kDiffDstIndex = 0; // index of diff_dst tensor const size_t kSrcIndex = 1; // index of src input tensor const size_t kScaleIndex = 2; // index of scale tensor const size_t kMeanIndex = 3; // index of saved_mean tensor const size_t kVarianceIndex = 4; // index of saved_variance tensor + const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex); const Tensor& src_tensor = MklGetInput(context, kSrcIndex); const Tensor& scale_tensor = MklGetInput(context, kScaleIndex); @@ -1060,8 +1502,8 @@ class MklFusedBatchNormGradOp : public OpKernel { MklDnnShape dnn_shape_src, dnn_shape_diff_dst; GetMklShape(context, kSrcIndex, &dnn_shape_src); GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst); - TensorShape tf_shape_src, tf_shape_diff_dst; + TensorShape tf_shape_src, tf_shape_diff_dst; if (dnn_shape_diff_dst.IsMklTensor()) { tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape(); OP_REQUIRES( @@ -1102,6 +1544,7 @@ class MklFusedBatchNormGradOp : public OpKernel { saved_variance_tensor.shape().DebugString())); Tensor* diff_src_tensor = nullptr; + // special case: input with 0 element and 0 batch size if (tf_shape_src.num_elements() == 0 || tf_shape_diff_dst.num_elements() == 0) { HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(), @@ -1117,189 +1560,127 @@ class MklFusedBatchNormGradOp : public OpKernel { ExtractParams(context); } - MklDnnData<T> src(&cpu_engine); - MklDnnData<T> mean(&cpu_engine); - MklDnnData<T> variance(&cpu_engine); - MklDnnData<T> diff_dst(&cpu_engine); - MklDnnData<T> diff_src(&cpu_engine); - - memory::dims src_dims, diff_dst_dims; - if (dnn_shape_src.IsMklTensor()) - src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(), - tensor_format_); - else - src_dims = - TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); - - if (dnn_shape_diff_dst.IsMklTensor()) - diff_dst_dims = TFShapeToMklDnnDimsInNCHW( - dnn_shape_diff_dst.GetTfShape(), tensor_format_); - else - diff_dst_dims = - TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_); - - // set src and diff_dst primitives according to input layout - memory::desc src_md({}, memory::data_undef, memory::format_undef); - memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef); + memory::format format_m; if (dnn_shape_src.IsMklTensor()) { - src_md = dnn_shape_src.GetMklLayout(); - } else { - src_md = memory::desc(src_dims, MklDnnType<T>(), - 
TFDataFormatToMklDnnDataFormat(tensor_format_)); - } - if (dnn_shape_diff_dst.IsMklTensor()) { - diff_dst_md = dnn_shape_diff_dst.GetMklLayout(); + if (dnn_shape_src.IsTensorInNCHWFormat()) + format_m = memory::format::nchw; + else + format_m = memory::format::nhwc; } else { - diff_dst_md = memory::desc(diff_dst_dims, MklDnnType<T>(), - TFDataFormatToMklDnnDataFormat(tensor_format_)); + format_m = TFDataFormatToMklDnnDataFormat(tensor_format_); } - src.SetUsrMem(src_md, &src_tensor); - diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); - - // weights -- DNN packs scales/shifts as weights in order of - // scale, ..., scale, shift, ..., shift - auto weights_desc = - memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc); - auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine); - auto weights_m = memory(weights_pd); - T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle()); - T* scale_tf = - reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data())); + + MklDnnData<T> src(&cpu_engine); + MklDnnData<T> diff_dst(&cpu_engine); + MklDnnData<T> weights(&cpu_engine); + MklDnnData<T> diff_weights(&cpu_engine); + + memory::dims src_dims = + dnn_shape_src.IsMklTensor() + ? dnn_shape_src.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); + memory::dims diff_dst_dims = + dnn_shape_diff_dst.IsMklTensor() + ? dnn_shape_diff_dst.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), + tensor_format_); + + // set src and diff_dst primitive descriptors + memory::desc src_md = + dnn_shape_src.IsMklTensor() + ? dnn_shape_src.GetMklLayout() + : memory::desc(src_dims, MklDnnType<T>(), format_m); + memory::desc diff_dst_md = + dnn_shape_diff_dst.IsMklTensor() + ? dnn_shape_diff_dst.GetMklLayout() + : memory::desc(diff_dst_dims, MklDnnType<T>(), format_m); + + // weights -- MKL DNN packs scales/ shifts as weights in order + // of scale, ..., scale, shift, ...., shift + weights.AllocateBuffer(2 * depth_ * sizeof(T)); + T* weights_data_tf = reinterpret_cast<T*>(weights.GetAllocatedBuffer()); + const T* scale_tf = scale_tensor.flat<T>().data(); for (int k = 0; k < depth_; k++) { - weights_data[k] = scale_tf[k]; - weights_data[k + depth_] = 0; + weights_data_tf[k] = scale_tf[k]; + weights_data_tf[k + depth_] = 0; } - // set mean primitive - memory::dims mv_dims = GetMeanVarianceDims(); - mean.SetUsrMem(mv_dims, memory::format::nc, - const_cast<void*>(static_cast<const void*>( - saved_mean_tensor.flat<T>().data()))); - mean.SetOpMemDesc(mv_dims, memory::format::nc); - - // set variance primitive - variance.SetUsrMem(mv_dims, memory::format::nc, - const_cast<void*>(static_cast<const void*>( - saved_variance_tensor.flat<T>().data()))); - variance.SetOpMemDesc(mv_dims, memory::format::nc); - - // set diff_weight primitive - auto diff_weights_desc = - memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc); - auto diff_weights_pd = - memory::primitive_desc(diff_weights_desc, cpu_engine); - auto diff_weights_m = memory(diff_weights_pd); - - auto bnrm_fwd_desc = batch_normalization_forward::desc( - prop_kind::forward_training, src.GetUsrMemDesc(), epsilon_, - is_training_ ? 
use_scale_shift - : (use_scale_shift | use_global_stats)); - auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc( - bnrm_fwd_desc, cpu_engine); + diff_weights.AllocateBuffer(2 * depth_ * sizeof(T)); + + MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_, + is_training_); + MklFusedBatchNormBwdPrimitive<T>* bn_bwd = + MklFusedBatchNormBwdPrimitiveFactory<T>::Get(bwdParams); + + // check if src/diff_dst need to be reordered + const T* src_data = src_tensor.flat<T>().data(); + if (src_md.data.format != bn_bwd->GetSrcFmt()) { + src.SetUsrMem(src_md, &src_tensor); + auto src_target = memory::primitive_desc( + {{src_dims}, + MklDnnType<T>(), + static_cast<memory::format>(bn_bwd->GetSrcFmt())}, + cpu_engine); + src.CheckReorderToOpMem(src_target); + src_data = const_cast<T*>( + reinterpret_cast<T*>(src.GetOpMem().get_data_handle())); + } + + const T* diff_dst_data = diff_dst_tensor.flat<T>().data(); + if (diff_dst_md.data.format != bn_bwd->GetDiffDstFmt()) { + diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); + auto diff_dst_target = memory::primitive_desc( + {{diff_dst_dims}, + MklDnnType<T>(), + static_cast<memory::format>(bn_bwd->GetDiffDstFmt())}, + cpu_engine); + diff_dst.CheckReorderToOpMem(diff_dst_target); + diff_dst_data = const_cast<T*>( + reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle())); + } // Indices of output tensors const size_t kDiffSrcIndex = 0; // index of diff_src tensor - // allocate diff_src tensor + // allocate output tensor: diff_src, always set as MKL-DNN layout MklDnnShape dnn_shape_diff_src; TensorShape tf_shape_diff_src; - - // MKL-DNN's BN primitive not provide API to fetch internal format - // set common_md as OpMem - // src and diff_dst will reorder to common_md - // diff_src will set as common_md - memory::desc common_md({}, memory::data_undef, memory::format_undef); - if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) { - if (dnn_shape_src.IsMklTensor()) { - common_md = dnn_shape_src.GetMklLayout(); - } else { - common_md = dnn_shape_diff_dst.GetMklLayout(); - } - } else { - common_md = memory::desc(src_dims, MklDnnType<T>(), - TFDataFormatToMklDnnDataFormat(tensor_format_)); - } - // if any of src and diff_dst as mkl layout, - // then we set diff_src as mkl layout - if (dnn_shape_src.IsMklTensor() || - dnn_shape_diff_dst.IsMklTensor()) { - dnn_shape_diff_src.SetMklTensor(true); - // set diff_src's mkl layout as common_md - auto diff_src_pd = memory::primitive_desc(common_md, cpu_engine); - dnn_shape_diff_src.SetMklLayout(&diff_src_pd); - dnn_shape_diff_src.SetElemType(MklDnnType<T>()); - if (dnn_shape_src.IsMklTensor()) { - dnn_shape_diff_src.SetTfLayout( - dnn_shape_src.GetDimension(), - src_dims, - dnn_shape_src.GetTfDataFormat()); - dnn_shape_diff_src.SetTfDimOrder( - dnn_shape_src.GetDimension(), - tensor_format_); - } else { - dnn_shape_diff_src.SetTfLayout( - dnn_shape_diff_dst.GetDimension(), - src_dims, - dnn_shape_diff_dst.GetTfDataFormat()); - dnn_shape_diff_src.SetTfDimOrder( - dnn_shape_diff_dst.GetDimension(), - tensor_format_); - } - tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); - } else { - dnn_shape_diff_src.SetMklTensor(false); - // both src and diff_dst are TensorFlow layout, - // so it is OK to get TensorFlow shape. 
- tf_shape_diff_src = src_tensor.shape(); - } + dnn_shape_diff_src.SetMklTensor(true); + auto diff_src_pd = bn_bwd->GetDiffSrcPd(); + dnn_shape_diff_src.SetMklLayout(&diff_src_pd); + dnn_shape_diff_src.SetElemType(MklDnnType<T>()); + dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, format_m); + dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_); + tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor, tf_shape_diff_src, dnn_shape_diff_src); - // set diff_src - diff_src.SetUsrMem(common_md, diff_src_tensor); - - prop_kind pk = prop_kind::backward; - auto bnrm_bwd_desc = batch_normalization_backward::desc( - pk, common_md, common_md, epsilon_, - /* for inference, specify use_global_stats - 1. on fwd prop, use mean and variance - provided as inputs - 2. on bwd prop, mean and variance are - considered as constants. Thus, - reduce the amout of MKL computations - */ - is_training_ ? use_scale_shift - : (use_scale_shift | use_global_stats)); - auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc( - bnrm_bwd_desc, cpu_engine, bnrm_fwd_pd); - - std::vector<primitive> net; - src.CheckReorderToOpMem(memory::primitive_desc(common_md, - cpu_engine), &net); - diff_dst.CheckReorderToOpMem(memory::primitive_desc(common_md, - cpu_engine), &net); - - auto bnrm_bwd_op = batch_normalization_backward( - bnrm_bwd_pd, src.GetOpMem(), mean.GetOpMem(), variance.GetOpMem(), - diff_dst.GetOpMem(), weights_m, diff_src.GetOpMem(), diff_weights_m); - - net.push_back(bnrm_bwd_op); - stream(stream::kind::eager).submit(net).wait(); - - // allocate 4 output TF tensors + T* mean_data = + static_cast<T*>(const_cast<T*>(saved_mean_tensor.flat<T>().data())); + T* variance_data = static_cast<T*>( + const_cast<T*>(saved_variance_tensor.flat<T>().data())); + T* weights_data = weights_data_tf; + T* diff_src_data = static_cast<T*>(diff_src_tensor->flat<T>().data()); + T* diff_weights_data = static_cast<T*>(diff_weights.GetAllocatedBuffer()); + // Execute + bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data, + weights_data, diff_src_data, diff_weights_data); + + // allocate output TF tensors: diff_scale and diff_shift Tensor* diff_scale_tensor = nullptr; Tensor* diff_shift_tensor = nullptr; AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor, &diff_shift_tensor); // copy data: diff_scale and diff_shift - T* diff_weights_data_dnn = - reinterpret_cast<T*>(diff_weights_m.get_data_handle()); - for (int i = 0; i < depth_; i++) { - diff_scale_tensor->flat<T>().data()[i] = diff_weights_data_dnn[i]; - diff_shift_tensor->flat<T>().data()[i] = - diff_weights_data_dnn[i + depth_]; - } + auto diff_scale_data = diff_scale_tensor->flat<T>().data(); + auto diff_shift_data = diff_shift_tensor->flat<T>().data(); + std::memcpy(reinterpret_cast<char*>(diff_scale_data), + reinterpret_cast<char*>(diff_weights_data), + depth_ * sizeof(T)); + std::memcpy(reinterpret_cast<char*>(diff_shift_data), + reinterpret_cast<char*>(diff_weights_data + depth_), + depth_ * sizeof(T)); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -1315,6 +1696,7 @@ class MklFusedBatchNormGradOp : public OpKernel { TensorFormat tensor_format_; int depth_; // batch normalization is done for per channel. 
bool is_training_; + engine cpu_engine = engine(engine::cpu, 0); void ExtractParams(OpKernelContext* context) { const Tensor& input = MklGetInput(context, 0); @@ -1330,8 +1712,8 @@ class MklFusedBatchNormGradOp : public OpKernel { dnn_shape_diff_src.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor, tf_shape_src, dnn_shape_diff_src); - for (size_t i = 0; i < (*diff_src_tensor)->shape().num_elements(); i++) - (*diff_src_tensor)->flat<T>().data()[i] = 0; + auto diff_src_data = (*diff_src_tensor)->flat<T>().data(); + std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(), 0); Tensor* diff_scale_tensor = nullptr; Tensor* diff_shift_tensor = nullptr; @@ -1357,16 +1739,18 @@ class MklFusedBatchNormGradOp : public OpKernel { AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor, tf_shape_scale_shift, mkl_shape_diff_scale); CHECK_NOTNULL(*diff_scale_tensor); - for (size_t i = 0; i < (*diff_scale_tensor)->shape().num_elements(); i++) - (*diff_scale_tensor)->flat<T>().data()[i] = 0; + auto diff_scale_data = (*diff_scale_tensor)->flat<T>().data(); + std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(), + 0); MklDnnShape mkl_shape_diff_shift; mkl_shape_diff_shift.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor, tf_shape_scale_shift, mkl_shape_diff_shift); CHECK_NOTNULL(*diff_shift_tensor); - for (size_t i = 0; i < (*diff_shift_tensor)->shape().num_elements(); i++) - (*diff_shift_tensor)->flat<T>().data()[i] = 0; + auto diff_shift_data = (*diff_shift_tensor)->flat<T>().data(); + std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(), + 0); // Placeholders for estimated_mean and estimated_variance, which are // used for inference and thus not needed here for gradient computation. |
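One detail of the rewritten forward path that is easy to miss in the diff above: MKL-DNN reports the population variance, while TensorFlow's batch_variance output carries Bessel's correction in training mode, so the kernel rescales by orig_size / (orig_size - 1) with orig_size = N * H * W (src_dims[0] * src_dims[2] * src_dims[3]); in inference mode the values are copied unchanged. A small standalone illustration of that adjustment, with made-up dimensions and variances:

// Standalone illustration of the Bessel's-correction rescaling done in the
// training branch of MklFusedBatchNormOp::Compute(); all values are invented.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // src_dims in NCHW order, as the kernel uses them: N, C, H, W.
  const std::size_t n = 8, c = 4, h = 5, w = 5;
  const std::size_t orig_size = n * h * w;  // elements per channel
  const float adjust_factor =
      static_cast<float>(orig_size) / static_cast<float>(orig_size - 1);

  // Per-channel population variance as produced by the primitive (made up).
  std::vector<float> saved_variance = {0.25f, 1.0f, 0.5f, 2.0f};
  std::vector<float> batch_variance(c);
  for (std::size_t k = 0; k < c; ++k) {
    // Training mode: apply the correction; inference would copy the values as-is.
    batch_variance[k] = saved_variance[k] * adjust_factor;
  }

  for (float v : batch_variance) std::cout << v << "\n";
  return 0;
}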