Diffstat (limited to 'tensorflow/core/kernels')
73 files changed, 5039 insertions, 2492 deletions
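Note ahead of the diff: one of the larger changes below adds a Truncate attribute to the Cast kernels, implemented via an LSBZeroSetter functor that zeroes the low-order mantissa bits that would be lost in a narrowing cast. The standalone C++ sketch below only illustrates that bit-masking idea; it is not the TensorFlow code itself, and the helper name and the use of memcpy (instead of the diff's reinterpret_cast) are this note's own choices.

// Illustrative sketch of mantissa-LSB truncation before a narrowing cast.
// float has 24 mantissa bits (incl. the hidden bit), IEEE half has 11,
// so the 13 least significant bits are cleared before casting.
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>

float ZeroLowMantissaBits(float t, int n) {
  // Leave NaNs untouched, mirroring the diff's helper: clearing mantissa
  // bits of a NaN could otherwise turn it into an infinity.
  if (!std::isnan(t)) {
    uint32_t bits;
    std::memcpy(&bits, &t, sizeof(bits));
    bits &= (0xFFFFFFFFu << n);
    std::memcpy(&t, &bits, sizeof(bits));
  }
  return t;
}

int main() {
  const int kBits = 24 - 11;      // float mantissa width minus half mantissa width
  float x = 1.0000001f;           // 1 + 2^-23, i.e. one float ULP above 1
  std::cout << ZeroLowMantissaBits(x, kBits) << "\n";  // prints 1: the LSBs are dropped
}

In the actual kernels (per the cast_op.h hunks below), the same masking is applied elementwise through Eigen's unaryExpr before the cast whenever the new truncate flag is set, and NaN inputs are deliberately left to the non-truncating path.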
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 2cb54bd973..d142e36772 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -22,6 +22,7 @@ package_group( "//learning/brain/research/sparse_matrix/...", "//learning/faster_training/...", "//tensorflow/...", + "//third_party/car/...", ], ) @@ -124,6 +125,7 @@ tf_kernel_library( ":bounds_check", ":dense_update_functor", ":ops_util", + ":training_op_helpers", ":variable_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -781,7 +783,7 @@ tf_kernel_library( tf_kernel_library( name = "quantize_and_dequantize_op", prefix = "quantize_and_dequantize_op", - deps = ARRAY_DEPS, + deps = ARRAY_DEPS + [":cwise_op"], ) tf_kernel_library( @@ -2346,6 +2348,22 @@ tf_cuda_cc_test( ) tf_cuda_cc_test( + name = "crop_and_resize_op_benchmark_test", + srcs = ["crop_and_resize_op_benchmark_test.cc"], + deps = [ + ":image", + ":ops_testutil", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cuda_cc_test( name = "resize_benchmark_test", srcs = ["resize_op_benchmark_test.cc"], deps = [ @@ -2836,6 +2854,8 @@ tf_kernel_library( srcs = [] + if_mkl([ "mkl_batch_matmul_op.cc", ]), + # <prefix>*impl.h are excluded by default from the CPU build, add explicitly. + hdrs = ["batch_matmul_op_impl.h"], # Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true, # to avoid long compiling time. See https://github.com/tensorflow/tensorflow/issues/10521 copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]), @@ -3772,7 +3792,7 @@ tf_kernel_library( "spacetodepth_op.h", "spacetodepth_op_gpu.cu.cc", ], - visibility = ["//visibility:private"], + visibility = [":friends"], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -4869,6 +4889,7 @@ filegroup( "fill_functor.cc", "fill_functor.h", "function_ops.cc", + "function_ops.h", "gather_functor.h", "gather_nd_op.cc", "gather_nd_op.h", @@ -5350,10 +5371,6 @@ cc_library( srcs = if_android(["decode_image_op.cc"]), copts = tf_copts(), linkopts = ["-ldl"], - tags = [ - "manual", - "notap", - ], visibility = ["//visibility:public"], deps = [ "//tensorflow/core:android_gif_internal", diff --git a/tensorflow/core/kernels/as_string_op.cc b/tensorflow/core/kernels/as_string_op.cc index a7757d1361..e6d6c40f76 100644 --- a/tensorflow/core/kernels/as_string_op.cc +++ b/tensorflow/core/kernels/as_string_op.cc @@ -47,6 +47,7 @@ class AsStringOp : public OpKernel { case DT_FLOAT: case DT_DOUBLE: case DT_COMPLEX64: + case DT_COMPLEX128: break; default: OP_REQUIRES(ctx, !(scientific || shortest), @@ -83,6 +84,7 @@ class AsStringOp : public OpKernel { case DT_FLOAT: case DT_DOUBLE: case DT_COMPLEX64: + case DT_COMPLEX128: if (shortest) { strings::Appendf(&format_, "g"); } else if (scientific) { @@ -100,7 +102,7 @@ class AsStringOp : public OpKernel { DataTypeString(dtype))); } - if (dtype == DT_COMPLEX64) { + if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) { format_ = strings::Printf("(%s,%s)", format_.c_str(), format_.c_str()); } } @@ -144,6 +146,13 @@ class AsStringOp : public OpKernel { format_.c_str(), input_flat(i).real(), input_flat(i).imag()); } } break; + case (DT_COMPLEX128): { + const auto& input_flat = input_tensor->flat<complex128>(); + for (int i = 0; i < input_flat.size(); ++i) { + output_flat(i) = strings::Printf( + 
format_.c_str(), input_flat(i).real(), input_flat(i).imag()); + } + } break; default: bool can_encode_type = false; OP_REQUIRES(context, can_encode_type, diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc index fe259c1634..aa7a2752e8 100644 --- a/tensorflow/core/kernels/batch_matmul_op_real.cc +++ b/tensorflow/core/kernels/batch_matmul_op_real.cc @@ -31,8 +31,7 @@ TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU); #if GOOGLE_CUDA TF_CALL_float(REGISTER_BATCH_MATMUL_GPU); TF_CALL_double(REGISTER_BATCH_MATMUL_GPU); -// TODO(csigg): Implement Stream::ThenBlasGemv for Eigen::half and uncomment. -// TF_CALL_half(REGISTER_BATCH_MATMUL_GPU); +TF_CALL_half(REGISTER_BATCH_MATMUL_GPU); #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h index b77c14d012..656b6ced6d 100644 --- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h @@ -147,13 +147,21 @@ class AdaptiveSharedBatchScheduler // Tracks processing latency and adjusts in_flight_batches_limit to minimize. void CallbackWrapper(const internal::ASBSBatch<TaskType>* batch, - BatchProcessor callback); + BatchProcessor callback, bool is_express); // Schedules batch if in_flight_batches_limit_ is not met. void MaybeScheduleNextBatch() EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Schedules the earliest closed batch in batches_ + // if batch_thread_pool_ has an idle thead. + // Batches scheduled this way are called express batches. + // Express batches are not limited by in_flight_batches_limit_, and + // their latencies will not affect in_flight_batches_limit_. + void MaybeScheduleClosedBatch() EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Notifies scheduler of non-empty batch which is eligible for processing. - void AddBatch(const internal::ASBSBatch<TaskType>* batch); + void AddBatch(const internal::ASBSBatch<TaskType>* batch, + bool also_schedule_closed_batch); // Removes queue from scheduler. void RemoveQueue(const internal::ASBSQueue<TaskType>* queue); @@ -180,8 +188,10 @@ class AdaptiveSharedBatchScheduler // results in an actual cap of 3 80% of the time, and 4 20% of the time. double in_flight_batches_limit_ GUARDED_BY(mu_); - // Number of batches currently being processed. + // Number of regular batches currently being processed. int64 in_flight_batches_ GUARDED_BY(mu_) = 0; + // Number of express batches currently being processed. + int64 in_flight_express_batches_ GUARDED_BY(mu_) = 0; // RNG engine and distribution. 
std::default_random_engine rand_engine_; @@ -363,10 +373,14 @@ Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue( template <typename TaskType> void AdaptiveSharedBatchScheduler<TaskType>::AddBatch( - const internal::ASBSBatch<TaskType>* batch) { + const internal::ASBSBatch<TaskType>* batch, + bool also_schedule_closed_batch) { mutex_lock l(mu_); batches_.push_back(batch); MaybeScheduleNextBatch(); + if (also_schedule_closed_batch) { + MaybeScheduleClosedBatch(); + } } template <typename TaskType> @@ -407,19 +421,45 @@ void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatch() { batch->queue()->ReleaseBatch(batch); batch_thread_pool_->Schedule( std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this, - batch, queues_and_callbacks_[batch->queue()])); + batch, queues_and_callbacks_[batch->queue()], false)); in_flight_batches_++; } template <typename TaskType> +void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleClosedBatch() { + if (in_flight_batches_ + in_flight_express_batches_ >= + options_.num_batch_threads) { + return; + } + for (auto it = batches_.begin(); it != batches_.end(); it++) { + if ((*it)->IsClosed()) { + const internal::ASBSBatch<TaskType>* batch = *it; + batches_.erase(it); + batch->queue()->ReleaseBatch(batch); + batch_thread_pool_->Schedule( + std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, + this, batch, queues_and_callbacks_[batch->queue()], true)); + in_flight_express_batches_++; + return; + } + } +} + +template <typename TaskType> void AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper( const internal::ASBSBatch<TaskType>* batch, - AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback) { + AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback, + bool is_express) { int64 start_time = batch->creation_time_micros(); callback(std::unique_ptr<Batch<TaskType>>( const_cast<internal::ASBSBatch<TaskType>*>(batch))); int64 end_time = GetEnv()->NowMicros(); mutex_lock l(mu_); + if (is_express) { + in_flight_express_batches_--; + MaybeScheduleClosedBatch(); + return; + } in_flight_batches_--; batch_count_++; batch_latency_sum_ += end_time - start_time; @@ -496,6 +536,7 @@ Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) { " is larger than maximum batch size ", options_.max_batch_size); } + bool is_old_batch_closed = false; { mutex_lock l(mu_); // Current batch is full, create another if allowed. @@ -505,6 +546,7 @@ Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) { return errors::Unavailable("The batch scheduling queue is full"); } current_batch_->Close(); + is_old_batch_closed = true; current_batch_ = nullptr; } if (!current_batch_) { @@ -516,7 +558,8 @@ Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) { num_enqueued_tasks_++; } // AddBatch must be called outside of lock, since it may call ReleaseBatch. 
- if (new_batch != nullptr) scheduler_->AddBatch(new_batch); + if (new_batch != nullptr) + scheduler_->AddBatch(new_batch, is_old_batch_closed); return Status::OK(); } diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index b4c97df38b..0478c93280 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -59,6 +59,8 @@ CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_)); + // Quantized data types use the same underlying format as their non quantized // version so we use the non quantized implementation for casting. if (external_dst_dtype_ == DT_QUINT8) { @@ -100,7 +102,7 @@ void CastOpBase::Compute(OpKernelContext* ctx) { Tensor* out = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out)); out->set_dtype(dst_dtype_); - work_(ctx, in, out); + work_(ctx, in, out, use_truncation_); out->set_dtype(external_dst_dtype_); } } diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h index aae1e7ff19..527ab528c9 100644 --- a/tensorflow/core/kernels/cast_op.h +++ b/tensorflow/core/kernels/cast_op.h @@ -24,8 +24,71 @@ limitations under the License. #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/types.h" +// Note that the GPU cast functor templates need to be instantiated unlike the +// CPU ones, and hence their specializations are different than that for CPUs. +#ifdef SPECIALIZE_FOR_GPUS +#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \ + template <typename Device> \ + struct CastFunctor<Device, OUT_TYPE, IN_OUT> { \ + void operator()(const Device& d, \ + typename TTypes<OUT_TYPE>::Flat out_tensor, \ + typename TTypes<IN_OUT>::ConstFlat in_tensor, \ + bool truncate = false) { \ + if (truncate) { \ + out_tensor.device(d) = \ + in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \ + .template cast<OUT_TYPE>(); \ + } else { \ + out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ + } \ + } \ + }; \ + template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>; +#else +#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \ + template <> \ + struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> { \ + void operator()(const DEVICE& d, \ + typename TTypes<OUT_TYPE>::Flat out_tensor, \ + typename TTypes<IN_OUT>::ConstFlat in_tensor, \ + bool truncate = false) { \ + if (truncate) { \ + out_tensor.device(d) = \ + in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \ + .template cast<OUT_TYPE>(); \ + } else { \ + out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \ + } \ + } \ + }; +#endif + +#define CAST_FUNCTORS(devname) \ + SPECIALIZE_CAST(devname, float, double) \ + SPECIALIZE_CAST(devname, float, std::complex<double>) \ + SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \ + SPECIALIZE_CAST(devname, std::complex<float>, double) \ + SPECIALIZE_CAST(devname, Eigen::half, double) \ + SPECIALIZE_CAST(devname, Eigen::half, float) \ + SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>) \ + SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>) \ + SPECIALIZE_CAST(devname, bfloat16, float) \ + template <typename OUT_TYPE, typename IN_OUT> \ + struct CastFunctor<devname, OUT_TYPE, IN_OUT> { \ + void operator()(const devname& d, \ + typename TTypes<OUT_TYPE>::Flat out_tensor, \ + typename TTypes<IN_OUT>::ConstFlat in_tensor, \ + bool truncate = false) { \ + out_tensor.device(d) = 
in_tensor.template cast<OUT_TYPE>(); \ + } \ + }; + namespace tensorflow { +typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*, + bool trunc)> + CastFunctorType; + // Common base class of Cast kernels class CastOpBase : public OpKernel { public: @@ -38,8 +101,8 @@ class CastOpBase : public OpKernel { DataType dst_dtype_; DataType external_src_dtype_; DataType external_dst_dtype_; - std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr; - + bool use_truncation_; + CastFunctorType work_ = nullptr; Status Unimplemented(); TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase); @@ -56,6 +119,23 @@ class CpuCastOp : public CastOpBase { namespace functor { +template <typename I> +constexpr int MantissaWidth() { + return std::numeric_limits<I>::digits; +} + +template <> +constexpr int MantissaWidth<Eigen::half>() { + // Remember, there's 1 hidden bit + return 10 + 1; +} + +template <> +constexpr int MantissaWidth<bfloat16>() { + // Remember, there's 1 hidden bit + return 7 + 1; +} + template <typename Device, typename Tout, typename Tin> void Cast(const Device& d, typename TTypes<Tout>::Flat o, typename TTypes<Tin>::ConstFlat i) { @@ -65,7 +145,85 @@ void Cast(const Device& d, typename TTypes<Tout>::Flat o, template <typename Device, typename Tout, typename Tin> struct CastFunctor { void operator()(const Device& d, typename TTypes<Tout>::Flat o, - typename TTypes<Tin>::ConstFlat i); + typename TTypes<Tin>::ConstFlat i, bool truncate = false); +}; + +// Only enable LSBZeroSetterHelper for 64 and 32 bit input data types. +// Specialize for others if needed in future. +template <typename I> +typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { + // Only zero the bits for non-NaNs. + // For NaNs, let the non-truncation version handle it. + if (!std::isnan(t)) { + uint64_t* p = reinterpret_cast<uint64_t*>(&t); + *p &= (0xFFFFFFFFFFFFFFFF << n); + } +} + +template <typename I> +typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { + // Only zero the bits for non-NaNs. + // For NaNs, let the non-truncation version handle it. 
+ if (!std::isnan(t)) { + uint32_t* p = reinterpret_cast<uint32_t*>(&t); + *p &= (0xFFFFFFFF << n); + } +} + +// Set n least significant bits to 0 +template <typename I, typename O> +struct LSBZeroSetter { + EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter) + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const { + constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I t = a; + LSBZeroSetterHelper(t, bits); + return t; + } +}; + +template <typename I, typename O> +struct LSBZeroSetter<std::complex<I>, std::complex<O>> { + EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter) + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()( + const std::complex<I>& a) const { + constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I re = std::real(a); + I img = std::imag(a); + LSBZeroSetterHelper(re, bits); + LSBZeroSetterHelper(img, bits); + std::complex<I> toReturn(re, img); + return toReturn; + } +}; + +template <typename I, typename O> +struct LSBZeroSetter<std::complex<I>, O> { + EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter) + // Sets the 16 LSBits of the float to 0 + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()( + const std::complex<I>& a) const { + constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I re = std::real(a); + I img = std::imag(a); + LSBZeroSetterHelper(re, bits); + LSBZeroSetterHelper(img, bits); + std::complex<I> toReturn(re, img); + return toReturn; + } }; } // end namespace functor diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc index 607e7f5efd..036996fca2 100644 --- a/tensorflow/core/kernels/cast_op_gpu.cu.cc +++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc @@ -18,22 +18,19 @@ limitations under the License. 
#define EIGEN_USE_GPU #include "tensorflow/core/framework/bfloat16.h" +#define SPECIALIZE_FOR_GPUS #include "tensorflow/core/kernels/cast_op.h" +#undef SPECIALIZE_FOR_GPUS namespace tensorflow { namespace functor { typedef Eigen::GpuDevice GPUDevice; -template <typename O, typename I> -struct CastFunctor<GPUDevice, O, I> { - void operator()(const GPUDevice& d, typename TTypes<O>::Flat o, - typename TTypes<I>::ConstFlat i) { - Cast<GPUDevice, O, I>(d, o, i); - } -}; +CAST_FUNCTORS(GPUDevice); #define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I> + #define DEFINE_ALL_FROM(in_type) \ DEFINE(in_type, bool); \ DEFINE(in_type, uint8); \ @@ -59,14 +56,43 @@ DEFINE_ALL_FROM(int8); DEFINE_ALL_FROM(int16); DEFINE_ALL_FROM(int32); DEFINE_ALL_FROM(int64); -DEFINE_ALL_FROM(Eigen::half); -DEFINE_ALL_FROM(float); DEFINE_ALL_FROM(double); -DEFINE_ALL_FROM(std::complex<float>); DEFINE_ALL_FROM(std::complex<double>); -DEFINE(bfloat16, float); DEFINE(float, bfloat16); +#define DEFINE_ALL_TO_FLOAT(out_type) \ + DEFINE(out_type, bool); \ + DEFINE(out_type, uint8); \ + DEFINE(out_type, uint16); \ + DEFINE(out_type, uint32); \ + DEFINE(out_type, uint64); \ + DEFINE(out_type, int8); \ + DEFINE(out_type, int16); \ + DEFINE(out_type, int32); \ + DEFINE(out_type, int64); \ + DEFINE(out_type, Eigen::half); \ + DEFINE(out_type, float); \ + DEFINE(out_type, std::complex<float>) + +#define DEFINE_ALL_TO_HALF(out_type) \ + DEFINE(out_type, bool); \ + DEFINE(out_type, uint8); \ + DEFINE(out_type, uint16); \ + DEFINE(out_type, uint32); \ + DEFINE(out_type, uint64); \ + DEFINE(out_type, int8); \ + DEFINE(out_type, int16); \ + DEFINE(out_type, int32); \ + DEFINE(out_type, int64); \ + DEFINE(out_type, Eigen::half) + +DEFINE_ALL_TO_HALF(Eigen::half); +DEFINE_ALL_TO_HALF(bfloat16); +DEFINE_ALL_TO_FLOAT(float); +DEFINE_ALL_TO_FLOAT(std::complex<float>); + +#undef DEFINE_ALL_TO_FLOAT +#undef DEFINE_ALL_TO_HALF #undef DEFINE_ALL_FROM #undef DEFINE diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h index fe821b25df..b899bac681 100644 --- a/tensorflow/core/kernels/cast_op_impl.h +++ b/tensorflow/core/kernels/cast_op_impl.h @@ -25,22 +25,10 @@ namespace tensorflow { namespace functor { -template <typename O, typename I> -struct CastFunctor<Eigen::ThreadPoolDevice, O, I> { - void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<O>::Flat o, - typename TTypes<I>::ConstFlat i) { - o.device(d) = i.template cast<O>(); - } -}; +CAST_FUNCTORS(Eigen::ThreadPoolDevice); #ifdef TENSORFLOW_USE_SYCL -template <typename O, typename I> -struct CastFunctor<Eigen::SyclDevice, O, I> { - void operator()(const Eigen::SyclDevice& d, typename TTypes<O>::Flat o, - typename TTypes<I>::ConstFlat i) { - o.device(d) = i.template cast<O>(); - } -}; +CAST_FUNCTORS(Eigen::SyclDevice); #endif // TENSORFLOW_USE_SYCL } // namespace functor @@ -68,139 +56,103 @@ struct CastFunctor<Eigen::SyclDevice, O, I> { CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \ FN(arg0, arg1, bfloat16); -#define CAST_CASE(DEVICE, IN, OUT) \ - if (DataTypeToEnum<OUT>::value == dst_dtype) { \ - return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { \ - functor::CastFunctor<DEVICE, OUT, IN> func; \ - func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \ - }; \ +#define CAST_CASE(DEVICE, IN, OUT) \ + if (DataTypeToEnum<OUT>::value == dst_dtype) { \ + return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out, \ + bool truncate) { \ + functor::CastFunctor<DEVICE, OUT, IN> func; \ + 
func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>(), \ + truncate); \ + }; \ } // The functions below are implemented in the cast_op_impl_*.cc files. -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromBool(DataType dst_dtype); +CastFunctorType GetCpuCastFromBool(DataType dst_dtype); + +CastFunctorType GetCpuCastFromUint8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint8(DataType dst_dtype); +CastFunctorType GetCpuCastFromUint16(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint16(DataType dst_dtype); +CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint32(DataType dst_dtype); +CastFunctorType GetCpuCastFromUint32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint64(DataType dst_dtype); +CastFunctorType GetCpuCastFromUint64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt8(DataType dst_dtype); +CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt16(DataType dst_dtype); +CastFunctorType GetCpuCastFromInt16(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt32(DataType dst_dtype); +CastFunctorType GetCpuCastFromInt32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt64(DataType dst_dtype); +CastFunctorType GetCpuCastFromInt64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromHalf(DataType dst_dtype); +CastFunctorType GetCpuCastFromHalf(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromFloat(DataType dst_dtype); +CastFunctorType GetCpuCastFromFloat(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromDouble(DataType dst_dtype); +CastFunctorType GetCpuCastFromDouble(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromComplex64(DataType dst_dtype); +CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromComplex128(DataType dst_dtype); +CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromBfloat(DataType dst_dtype); +CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype); #if GOOGLE_CUDA // Same, for GPU. 
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromBool(DataType dst_dtype); +CastFunctorType GetGpuCastFromBool(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint8(DataType dst_dtype); +CastFunctorType GetGpuCastFromUint8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint16(DataType dst_dtype); +CastFunctorType GetGpuCastFromUint16(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint32(DataType dst_dtype); +CastFunctorType GetGpuCastFromInt8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint64(DataType dst_dtype); +CastFunctorType GetGpuCastFromUint32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt8(DataType dst_dtype); +CastFunctorType GetGpuCastFromUint64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt16(DataType dst_dtype); +CastFunctorType GetGpuCastFromInt16(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt32(DataType dst_dtype); +CastFunctorType GetGpuCastFromInt32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt64(DataType dst_dtype); +CastFunctorType GetGpuCastFromInt64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromHalf(DataType dst_dtype); +CastFunctorType GetGpuCastFromHalf(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromFloat(DataType dst_dtype); +CastFunctorType GetGpuCastFromFloat(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromDouble(DataType dst_dtype); +CastFunctorType GetGpuCastFromDouble(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromComplex64(DataType dst_dtype); +CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromComplex128(DataType dst_dtype); +CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromBfloat(DataType dst_dtype); +CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype); #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromBool(DataType dst_dtype); +CastFunctorType GetSyclCastFromBool(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint8(DataType dst_dtype); +CastFunctorType GetSyclCastFromUint8(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint16(DataType dst_dtype); +CastFunctorType GetSyclCastFromUint16(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint32(DataType dst_dtype); +CastFunctorType GetSyclCastFromUint32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint64(DataType dst_dtype); +CastFunctorType GetSyclCastFromUint64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt16(DataType dst_dtype); +CastFunctorType GetSyclCastFromInt16(DataType dst_dtype); -std::function<void(OpKernelContext*, const 
Tensor&, Tensor*)> -GetSyclCastFromInt32(DataType dst_dtype); +CastFunctorType GetSyclCastFromInt32(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt64(DataType dst_dtype); +CastFunctorType GetSyclCastFromInt64(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromFloat(DataType dst_dtype); +CastFunctorType GetSyclCastFromFloat(DataType dst_dtype); -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromDouble(DataType dst_dtype); +CastFunctorType GetSyclCastFromDouble(DataType dst_dtype); #endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_impl_bfloat.cc b/tensorflow/core/kernels/cast_op_impl_bfloat.cc index bfa7ba0d47..96aae15608 100644 --- a/tensorflow/core/kernels/cast_op_impl_bfloat.cc +++ b/tensorflow/core/kernels/cast_op_impl_bfloat.cc @@ -22,20 +22,19 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromBfloat(DataType dst_dtype) { +CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, bfloat16); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromBfloat(DataType dst_dtype) { +CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype) { if (dst_dtype == DT_FLOAT) { - return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { + return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out, + bool truncate) { functor::CastFunctor<GPUDevice, float, bfloat16> func; func(ctx->eigen_device<GPUDevice>(), out->flat<float>(), - inp.flat<bfloat16>()); + inp.flat<bfloat16>(), truncate); }; } return nullptr; diff --git a/tensorflow/core/kernels/cast_op_impl_bool.cc b/tensorflow/core/kernels/cast_op_impl_bool.cc index c5c7394b43..792d4781f2 100644 --- a/tensorflow/core/kernels/cast_op_impl_bool.cc +++ b/tensorflow/core/kernels/cast_op_impl_bool.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromBool(DataType dst_dtype) { +CastFunctorType GetCpuCastFromBool(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, bool); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromBool(DataType dst_dtype) { +CastFunctorType GetGpuCastFromBool(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, bool); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromBool(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromBool(DataType dst_dtype) { +CastFunctorType GetSyclCastFromBool(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, bool); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_complex128.cc b/tensorflow/core/kernels/cast_op_impl_complex128.cc index 52899d58cd..9a184e5954 100644 --- a/tensorflow/core/kernels/cast_op_impl_complex128.cc +++ b/tensorflow/core/kernels/cast_op_impl_complex128.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromComplex128(DataType dst_dtype) { +CastFunctorType 
GetCpuCastFromComplex128(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<double>); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromComplex128(DataType dst_dtype) { +CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, std::complex<double>); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_complex64.cc b/tensorflow/core/kernels/cast_op_impl_complex64.cc index 617bda53d5..77bc620b46 100644 --- a/tensorflow/core/kernels/cast_op_impl_complex64.cc +++ b/tensorflow/core/kernels/cast_op_impl_complex64.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromComplex64(DataType dst_dtype) { +CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<float>); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromComplex64(DataType dst_dtype) { +CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, std::complex<float>); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_double.cc b/tensorflow/core/kernels/cast_op_impl_double.cc index 7dc485ddad..ff9056897f 100644 --- a/tensorflow/core/kernels/cast_op_impl_double.cc +++ b/tensorflow/core/kernels/cast_op_impl_double.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromDouble(DataType dst_dtype) { +CastFunctorType GetCpuCastFromDouble(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, double); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromDouble(DataType dst_dtype) { +CastFunctorType GetGpuCastFromDouble(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, double); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromDouble(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromDouble(DataType dst_dtype) { +CastFunctorType GetSyclCastFromDouble(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, double); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_float.cc b/tensorflow/core/kernels/cast_op_impl_float.cc index 1c933914fd..f1e8f0e37b 100644 --- a/tensorflow/core/kernels/cast_op_impl_float.cc +++ b/tensorflow/core/kernels/cast_op_impl_float.cc @@ -22,15 +22,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromFloat(DataType dst_dtype) { +CastFunctorType GetCpuCastFromFloat(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, float); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromFloat(DataType dst_dtype) { +CastFunctorType GetGpuCastFromFloat(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, GPUDevice, float); return nullptr; } @@ -38,8 +36,7 @@ GetGpuCastFromFloat(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> 
-GetSyclCastFromFloat(DataType dst_dtype) { +CastFunctorType GetSyclCastFromFloat(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, float); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_half.cc b/tensorflow/core/kernels/cast_op_impl_half.cc index ef4b94e326..5da3a01352 100644 --- a/tensorflow/core/kernels/cast_op_impl_half.cc +++ b/tensorflow/core/kernels/cast_op_impl_half.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromHalf(DataType dst_dtype) { +CastFunctorType GetCpuCastFromHalf(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, Eigen::half); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromHalf(DataType dst_dtype) { +CastFunctorType GetGpuCastFromHalf(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, Eigen::half); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_int16.cc b/tensorflow/core/kernels/cast_op_impl_int16.cc index 59360f7445..440ee88fb5 100644 --- a/tensorflow/core/kernels/cast_op_impl_int16.cc +++ b/tensorflow/core/kernels/cast_op_impl_int16.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt16(DataType dst_dtype) { +CastFunctorType GetCpuCastFromInt16(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, int16); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt16(DataType dst_dtype) { +CastFunctorType GetGpuCastFromInt16(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int16); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromInt16(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt16(DataType dst_dtype) { +CastFunctorType GetSyclCastFromInt16(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int16); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_int32.cc b/tensorflow/core/kernels/cast_op_impl_int32.cc index a867392fde..4b3e7efddc 100644 --- a/tensorflow/core/kernels/cast_op_impl_int32.cc +++ b/tensorflow/core/kernels/cast_op_impl_int32.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt32(DataType dst_dtype) { +CastFunctorType GetCpuCastFromInt32(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, int32); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt32(DataType dst_dtype) { +CastFunctorType GetGpuCastFromInt32(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int32); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromInt32(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt32(DataType dst_dtype) { +CastFunctorType GetSyclCastFromInt32(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int32); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_int64.cc b/tensorflow/core/kernels/cast_op_impl_int64.cc index 
467a8f6c89..0f711aa560 100644 --- a/tensorflow/core/kernels/cast_op_impl_int64.cc +++ b/tensorflow/core/kernels/cast_op_impl_int64.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt64(DataType dst_dtype) { +CastFunctorType GetCpuCastFromInt64(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, int64); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt64(DataType dst_dtype) { +CastFunctorType GetGpuCastFromInt64(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int64); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromInt64(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt64(DataType dst_dtype) { +CastFunctorType GetSyclCastFromInt64(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int64); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_int8.cc b/tensorflow/core/kernels/cast_op_impl_int8.cc index 21002a4321..eac185d5a0 100644 --- a/tensorflow/core/kernels/cast_op_impl_int8.cc +++ b/tensorflow/core/kernels/cast_op_impl_int8.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromInt8(DataType dst_dtype) { +CastFunctorType GetCpuCastFromInt8(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, int8); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromInt8(DataType dst_dtype) { +CastFunctorType GetGpuCastFromInt8(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int8); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromInt8(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromInt8(DataType dst_dtype) { +CastFunctorType GetSyclCastFromInt8(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int8); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_uint16.cc b/tensorflow/core/kernels/cast_op_impl_uint16.cc index cd829bae2a..3aebbdc1f3 100644 --- a/tensorflow/core/kernels/cast_op_impl_uint16.cc +++ b/tensorflow/core/kernels/cast_op_impl_uint16.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint16(DataType dst_dtype) { +CastFunctorType GetCpuCastFromUint16(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, uint16); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint16(DataType dst_dtype) { +CastFunctorType GetGpuCastFromUint16(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint16); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromUint16(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint16(DataType dst_dtype) { +CastFunctorType GetSyclCastFromUint16(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint16); return nullptr; } diff --git 
a/tensorflow/core/kernels/cast_op_impl_uint32.cc b/tensorflow/core/kernels/cast_op_impl_uint32.cc index d1a854d98b..86f5961bcc 100644 --- a/tensorflow/core/kernels/cast_op_impl_uint32.cc +++ b/tensorflow/core/kernels/cast_op_impl_uint32.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint32(DataType dst_dtype) { +CastFunctorType GetCpuCastFromUint32(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, uint32); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint32(DataType dst_dtype) { +CastFunctorType GetGpuCastFromUint32(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint32); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromUint32(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint32(DataType dst_dtype) { +CastFunctorType GetSyclCastFromUint32(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint32); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_uint64.cc b/tensorflow/core/kernels/cast_op_impl_uint64.cc index 604e0424fc..6478c266ee 100644 --- a/tensorflow/core/kernels/cast_op_impl_uint64.cc +++ b/tensorflow/core/kernels/cast_op_impl_uint64.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint64(DataType dst_dtype) { +CastFunctorType GetCpuCastFromUint64(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, uint64); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint64(DataType dst_dtype) { +CastFunctorType GetGpuCastFromUint64(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint64); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromUint64(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint64(DataType dst_dtype) { +CastFunctorType GetSyclCastFromUint64(DataType dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint64); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_impl_uint8.cc b/tensorflow/core/kernels/cast_op_impl_uint8.cc index 2d1a6f3a4e..b22547a23e 100644 --- a/tensorflow/core/kernels/cast_op_impl_uint8.cc +++ b/tensorflow/core/kernels/cast_op_impl_uint8.cc @@ -20,15 +20,13 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetCpuCastFromUint8(DataType dst_dtype) { +CastFunctorType GetCpuCastFromUint8(DataType dst_dtype) { CURRY_TYPES3(CAST_CASE, CPUDevice, uint8); return nullptr; } #if GOOGLE_CUDA -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetGpuCastFromUint8(DataType dst_dtype) { +CastFunctorType GetGpuCastFromUint8(DataType dst_dtype) { CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint8); return nullptr; } @@ -36,8 +34,7 @@ GetGpuCastFromUint8(DataType dst_dtype) { #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -std::function<void(OpKernelContext*, const Tensor&, Tensor*)> -GetSyclCastFromUint8(DataType dst_dtype) { +CastFunctorType GetSyclCastFromUint8(DataType 
dst_dtype) { CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint8); return nullptr; } diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc index 9bbf7afb16..cb305de5e3 100644 --- a/tensorflow/core/kernels/cast_op_test.cc +++ b/tensorflow/core/kernels/cast_op_test.cc @@ -40,17 +40,27 @@ static Graph* Cast(int num) { class CastOpTest : public OpsTestBase { protected: - void MakeOp(DataType src, DataType dst) { - TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast") - .Input(FakeInput(src)) - .Attr("SrcT", src) - .Attr("DstT", dst) - .Finalize(node_def())); + void MakeOp(DataType src, DataType dst, bool trunc = false) { + if (trunc) { + TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast") + .Input(FakeInput(src)) + .Attr("SrcT", src) + .Attr("DstT", dst) + .Attr("Truncate", true) + .Finalize(node_def())); + } else { + TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast") + .Input(FakeInput(src)) + .Attr("SrcT", src) + .Attr("DstT", dst) + .Finalize(node_def())); + } + TF_EXPECT_OK(InitOp()); } template <typename INPUT, typename OUTPUT> - void CheckCast() { + void CheckCast(bool trunc = false) { DataType in_type = DataTypeToEnum<INPUT>::v(); DataType out_type = DataTypeToEnum<OUTPUT>::v(); MakeOp(in_type, out_type); @@ -64,8 +74,9 @@ class CastOpTest : public OpsTestBase { } }; -#define TEST_CAST(in, out) \ - TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } +#define TEST_CAST(in, out) \ + TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } \ + TEST_F(CastOpTest, TestCast2##_##in##_##out) { CheckCast<in, out>(true); } #define TEST_ALL_CASTS_FROM(in) \ TEST_CAST(in, uint8); \ diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc index c281153795..1236f27051 100644 --- a/tensorflow/core/kernels/conv_ops_test.cc +++ b/tensorflow/core/kernels/conv_ops_test.cc @@ -229,7 +229,7 @@ class FusedResizePadConvOpTest : public OpsTestBase { std::vector<Tensor> fused_tensors; TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors)); - test::ExpectTensorNear<T>(unfused_tensors[0], fused_tensors[0], 1e-5); + test::ExpectClose(unfused_tensors[0], fused_tensors[0]); } template <typename T> @@ -282,7 +282,7 @@ class FusedResizePadConvOpTest : public OpsTestBase { std::vector<Tensor> fused_tensors; TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors)); - test::ExpectTensorNear<T>(unfused_tensors[0], fused_tensors[0], 1e-5); + test::ExpectClose(unfused_tensors[0], fused_tensors[0]); } }; diff --git a/tensorflow/core/kernels/crop_and_resize_op_benchmark_test.cc b/tensorflow/core/kernels/crop_and_resize_op_benchmark_test.cc new file mode 100644 index 0000000000..d7ca64bea0 --- /dev/null +++ b/tensorflow/core/kernels/crop_and_resize_op_benchmark_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +static Graph* BM_CropAndResize(int batches, int width, int height, int depth, + int crop_height, int crop_width) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor in(DT_FLOAT, TensorShape({batches, height, width, depth})); + in.flat<float>().setRandom(); + Tensor boxes(DT_FLOAT, TensorShape({batches, 4})); + auto boxes_tensor = boxes.matrix<float>(); + Tensor box_ind(DT_INT32, TensorShape({batches})); + auto box_ind_flat = box_ind.flat<int32>(); + for (int i = 0; i < batches; ++i) { + boxes_tensor(i, 0) = 0.2; + boxes_tensor(i, 1) = 0.2; + boxes_tensor(i, 2) = 0.8; + boxes_tensor(i, 3) = 0.7; + box_ind_flat(i) = i; + } + Tensor crop_size(DT_INT32, TensorShape({2})); + auto crop_size_flat = crop_size.flat<int32>(); + crop_size_flat(0) = crop_height; + crop_size_flat(1) = crop_width; + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CropAndResize") + .Input(test::graph::Constant(g, in)) + .Input(test::graph::Constant(g, boxes)) + .Input(test::graph::Constant(g, box_ind)) + .Input(test::graph::Constant(g, crop_size)) + .Finalize(g, &ret)); + return g; +} + +#define BM_CropAndResizeDev(DEVICE, B, W, H, D, CH, CW) \ + static void BM_CropAndResize_##DEVICE##_##B##_##W##_##H##_##D##_##CH##_##CW( \ + int iters) { \ + testing::ItemsProcessed(iters* B* W* H* D); \ + test::Benchmark(#DEVICE, BM_CropAndResize(B, W, H, D, CH, CW)).Run(iters); \ + } \ + BENCHMARK(BM_CropAndResize_##DEVICE##_##B##_##W##_##H##_##D##_##CH##_##CW); + +// Benchmark results using CPU:Intel Haswell with HyperThreading (6 cores) +// Benchmark Time(ns) CPU(ns) Iterations +// BM_CropAndResize_cpu_1_640_640_3_512_512 7078765 7173520 100 163.361M items/s +// BM_CropAndResize_cpu_1_640_640_1_512_512 3801232 3914692 185 99.784M items/s +// BM_CropAndResize_cpu_1_80_80_512_7_7 182470 241767 2941 1.372G items/s + +BM_CropAndResizeDev(cpu, 1, 640, 640, 3, 512, 512); +BM_CropAndResizeDev(cpu, 1, 640, 640, 1, 512, 512); +BM_CropAndResizeDev(cpu, 1, 80, 80, 512, 7, 7); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_tan.cc b/tensorflow/core/kernels/cwise_op_tan.cc index c1a25767d3..90762fb1b0 100644 --- a/tensorflow/core/kernels/cwise_op_tan.cc +++ b/tensorflow/core/kernels/cwise_op_tan.cc @@ -16,7 +16,8 @@ limitations under the License. 
#include "tensorflow/core/kernels/cwise_ops_common.h" namespace tensorflow { -REGISTER2(UnaryOp, CPU, "Tan", functor::tan, float, double); +REGISTER4(UnaryOp, CPU, "Tan", functor::tan, float, double, complex64, + complex128); #if GOOGLE_CUDA REGISTER2(UnaryOp, GPU, "Tan", functor::tan, float, double); diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index e04fa20414..607a694dba 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -177,6 +177,19 @@ tf_kernel_library( ) tf_kernel_library( + name = "filter_by_component_dataset_op", + srcs = ["filter_by_component_dataset_op.cc"], + deps = [ + ":dataset", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_kernel_library( name = "map_dataset_op", srcs = ["map_dataset_op.cc"], deps = [ @@ -204,12 +217,28 @@ tf_kernel_library( ], ) +cc_library( + name = "parallel_map_iterator", + srcs = ["parallel_map_iterator.cc"], + hdrs = ["parallel_map_iterator.h"], + deps = [ + ":dataset", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + tf_kernel_library( name = "parallel_map_dataset_op", srcs = ["parallel_map_dataset_op.cc"], deps = [ ":captured_function", ":dataset", + ":parallel_map_iterator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -222,6 +251,7 @@ tf_kernel_library( tf_kernel_library( name = "generator_dataset_op", srcs = ["generator_dataset_op.cc"], + hdrs = ["generator_dataset_op.h"], deps = [ ":captured_function", "//tensorflow/core:core_cpu_internal", @@ -314,6 +344,7 @@ tf_cc_test( tf_kernel_library( name = "prefetch_dataset_op", srcs = ["prefetch_dataset_op.cc"], + hdrs = ["prefetch_dataset_op.h"], deps = [ ":dataset", ":prefetch_autotuner", @@ -535,9 +566,11 @@ tf_kernel_library( tf_kernel_library( name = "iterator_ops", srcs = ["iterator_ops.cc"], + hdrs = ["iterator_ops.h"], deps = [ ":dataset", ":dataset_utils", + ":optional_ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -550,6 +583,20 @@ tf_kernel_library( ) tf_kernel_library( + name = "optional_ops", + srcs = ["optional_ops.cc"], + hdrs = ["optional_ops.h"], + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_kernel_library( name = "cache_dataset_ops", srcs = ["cache_dataset_ops.cc"], deps = [ @@ -605,6 +652,7 @@ tf_kernel_library( ":dataset", ":dataset_ops", ":dense_to_sparse_batch_dataset_op", + ":filter_by_component_dataset_op", ":filter_dataset_op", ":flat_map_dataset_op", ":generator_dataset_op", @@ -614,7 +662,9 @@ tf_kernel_library( ":iterator_ops", ":map_and_batch_dataset_op", ":map_dataset_op", + ":map_defun_op", ":optimize_dataset_op", + ":optional_ops", ":padded_batch_dataset_op", ":parallel_interleave_dataset_op", ":parallel_map_dataset_op", @@ -655,3 +705,15 @@ tf_kernel_library( "//tensorflow/core/kernels:ops_util", ], ) + +tf_kernel_library( + name = "map_defun_op", + srcs = ["map_defun_op.cc"], + deps = [ + "//tensorflow/core:core_cpu_internal", + 
"//tensorflow/core:framework", + "//tensorflow/core:functional_ops_op_lib", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index ed4932bf32..86b0840aea 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -39,7 +39,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { ParseScalarArgument<string>(ctx, "filename", &filename)); if (filename.empty()) { - *output = new MemoryDataset(input); + *output = new MemoryDataset(ctx, input); } else { *output = new FileDataset(ctx, input, filename, ctx->env()); } @@ -68,8 +68,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>(new FileCacheIterator( - {this, strings::StrCat(prefix, "::FileCacheIterator")})); + return std::unique_ptr<IteratorBase>( + new FileIterator({this, strings::StrCat(prefix, "::FileIterator")})); } const DataTypeVector& output_dtypes() const override { @@ -105,9 +105,9 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { tensor_index); } - class FileCacheIterator : public DatasetIterator<FileDataset> { + class FileIterator : public DatasetIterator<FileDataset> { public: - explicit FileCacheIterator(const Params& params) + explicit FileIterator(const Params& params) : DatasetIterator<FileDataset>(params) { if (params.dataset->env_ ->FileExists(MetaFilename(params.dataset->filename_)) @@ -526,7 +526,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { enum Mode { read, write }; Mode mode_ GUARDED_BY(mu_); std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_); - }; // FileCacheIterator + }; // FileIterator const DatasetBase* const input_; const string filename_; @@ -538,9 +538,10 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { const string tensor_format_string_; }; // FileDataset - class MemoryDataset : public DatasetBase { + class MemoryDataset : public GraphDatasetBase { public: - explicit MemoryDataset(const DatasetBase* input) : input_(input) { + explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input) + : GraphDatasetBase(ctx), input_(input), cache_(new MemoryCache()) { input->Ref(); } @@ -548,18 +549,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - mutex_lock l(mu_); - if (cache_) { - return std::unique_ptr<IteratorBase>(new MemoryReaderIterator( - {this, strings::StrCat(prefix, "::MemoryReader")}, cache_.get())); - } - if (!writer_iterator_created_) { - writer_iterator_created_ = true; - return std::unique_ptr<IteratorBase>(new MemoryWriterIterator( - {this, strings::StrCat(prefix, "::MemoryWriter")})); - } - return std::unique_ptr<IteratorBase>(new DuplicateWriterIterator( - {this, strings::StrCat(prefix, "::DuplicateWriter")})); + return std::unique_ptr<IteratorBase>(new MemoryIterator( + {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_)); } const DataTypeVector& output_dtypes() const override { @@ -574,114 +565,321 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { return "CacheDatasetOp::MemoryDataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node)); + Node* 
filename_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(string(""), &filename_node)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_node, filename_node}, output)); + return Status::OK(); + } + private: - // MemoryWriterIterator passes through and appends items from the input - // dataset to its vector. + // A thread-safe data structure for caching dataset elements. // - // This iterator is used when dataset->cache_ is null. After buffering - // the tensors in memory, upon exhausing the underlying iterator, they are - // updated into the parent dataset's cache_ pointer. - class MemoryWriterIterator : public DatasetIterator<MemoryDataset> { + // The expected use is that a single `MemoryWriterIterator` populates the + // cache with dataset elements. Once all elements are cached, the cache can + // be used by one or more `MemoryReaderIterator`s. + class MemoryCache { public: - explicit MemoryWriterIterator(const Params& params) - : DatasetIterator<MemoryDataset>(params), - cache_(new std::vector<std::vector<Tensor>>) {} + MemoryCache() = default; - ~MemoryWriterIterator() override { + // Marks the cache as completed. + void Complete() { mutex_lock l(mu_); - if (cache_) { - LOG(ERROR) - << "The calling iterator did not fully read the dataset we were " - "attempting to cache. In order to avoid unexpected truncation " - "of the sequence, the current [partially cached] sequence " - "will be dropped. This can occur if you have a sequence " - "similar to `dataset.cache().take(k).repeat()`. Instead, swap " - "the order (i.e. `dataset.take(k).cache().repeat()`)"; - mutex_lock l2(dataset()->mu_); - dataset()->writer_iterator_created_ = false; - } + completed_ = true; } - Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + // Returns whether the cache is claimed. + bool IsClaimed() { + tf_shared_lock l(mu_); + return claimed_; } - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { + // Returns whether the cache is completed. + bool IsCompleted() { + tf_shared_lock l(mu_); + return completed_; + } + + // Attempts to claim the cache, returning whether the cache was claimed. + bool MaybeClaim() { mutex_lock l(mu_); - TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - if (*end_of_sequence) { - // Guard on cache_ to not crash if GetNext is called a second time - // after *end_of_sequence == true - if (cache_) { - mutex_lock l(dataset()->mu_); - DCHECK(dataset()->writer_iterator_created_); - DCHECK(!dataset()->cache_); - cache_.swap(dataset()->cache_); - } - return Status::OK(); + if (!claimed_) { + claimed_ = true; + return true; } - cache_->emplace_back(*out_tensors); - return Status::OK(); + return false; + } + + // Resets the cache. + void Reset() { + mutex_lock l(mu_); + claimed_ = false; + completed_ = false; + cache_.clear(); + } + + // Returns the element at the given index. + const std::vector<Tensor>& at(int64 index) { + tf_shared_lock l(mu_); + DCHECK(index < cache_.size()); + return cache_[index]; + } + + // Adds the element to the cache. + void emplace_back(std::vector<Tensor> element) { + mutex_lock l(mu_); + cache_.emplace_back(std::move(element)); + } + + // Returns the size of the cache. 
+ size_t size() { + tf_shared_lock l(mu_); + return cache_.size(); } private: mutex mu_; - std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); - std::unique_ptr<std::vector<std::vector<Tensor>>> cache_ GUARDED_BY(mu_); - }; // MemoryWriterIterator - - class MemoryReaderIterator : public DatasetIterator<MemoryDataset> { + // Determines whether a writer has claimed the cache. + bool claimed_ GUARDED_BY(mu_) = false; + // Determines whether all elements of the dataset have been cached. + bool completed_ GUARDED_BY(mu_) = false; + std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_); + }; + + class MemoryIterator : public DatasetIterator<MemoryDataset> { public: - explicit MemoryReaderIterator( - const Params& params, const std::vector<std::vector<Tensor>>* cache) - : DatasetIterator<MemoryDataset>(params), cache_(cache), index_(0) { - CHECK(cache); + explicit MemoryIterator(const Params& params, + const std::shared_ptr<MemoryCache>& cache) + : DatasetIterator<MemoryDataset>(params), cache_(cache) { + mode_ = cache->MaybeClaim() ? Mode::write : Mode::read; + InitializeIterator(); + } + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + if (mode_ == Mode::read && !cache_->IsCompleted()) { + return errors::Internal( + "Cache should only be read after it has been completed."); + } + return iterator_->Initialize(ctx); } Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - if (index_ < cache_->size()) { - const std::vector<Tensor>& cache_tensors = (*cache_)[index_]; - out_tensors->insert(out_tensors->begin(), cache_tensors.begin(), - cache_tensors.end()); - index_++; - *end_of_sequence = false; - return Status::OK(); - } else { - *end_of_sequence = true; - return Status::OK(); + return iterator_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_)); + if (cache_->IsClaimed()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("cache_claimed"), "")); + size_t cache_size = cache_->size(); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("cache_size"), cache_size)); + for (size_t i = 0; i < cache_size; i++) { + auto& element = cache_->at(i); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("cache[", i, "].size")), + element.size())); + for (size_t j = 0; j < element.size(); ++j) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("cache[", i, "][", j, "]")), + element[j])); + } + } + if (cache_->IsCompleted()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("cache_completed"), "")); + } } + return SaveParent(writer, iterator_); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + iterator_.reset(); + cache_->Reset(); + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("mode"), &temp)); + mode_ = static_cast<Mode>(temp); + } + if (reader->Contains(full_name("cache_claimed"))) { + CHECK(cache_->MaybeClaim()); + size_t cache_size; + { + int64 temp; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("cache_size"), &temp)); + cache_size = static_cast<size_t>(temp); + } + for (size_t i = 0; i < cache_size; ++i) { + std::vector<Tensor> element; + size_t element_size; + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("cache[", i, "].size")), &temp)); + 
element_size = static_cast<size_t>(temp); + } + element.reserve(element_size); + for (size_t j = 0; j < element_size; ++j) { + element.emplace_back(); + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("cache[", i, "][", j, "]")), + &element.back())); + } + cache_->emplace_back(std::move(element)); + } + if (reader->Contains(full_name("cache_completed"))) { + cache_->Complete(); + } + } + InitializeIterator(); + TF_RETURN_IF_ERROR(iterator_->Initialize(ctx)); + return RestoreParent(ctx, reader, iterator_); } private: - mutex mu_; - const std::vector<std::vector<Tensor>>* const cache_; - size_t index_ GUARDED_BY(mu_); - }; // MemoryReaderIterator + class MemoryWriterIterator : public DatasetIterator<MemoryDataset> { + public: + explicit MemoryWriterIterator(const Params& params, + const std::shared_ptr<MemoryCache>& cache) + : DatasetIterator<MemoryDataset>(params), cache_(cache) { + CHECK(cache_); + } - class DuplicateWriterIterator : public DatasetIterator<MemoryDataset> { - public: - explicit DuplicateWriterIterator(const Params& params) - : DatasetIterator<MemoryDataset>(params) {} + ~MemoryWriterIterator() override { + mutex_lock l(mu_); + if (cache_->size() > 0 && !cache_->IsCompleted()) { + LOG(WARNING) + << "The calling iterator did not fully read the dataset being " + "cached. In order to avoid unexpected truncation of the " + "dataset, the partially cached contents of the dataset" + "will be discarded. This can happen if you have an input " + "pipeline similar to `dataset.cache().take(k).repeat()`. " + "You should use `dataset.take(k).cache().repeat()` instead."; + cache_->Reset(); + } + } - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - return errors::AlreadyExists( - "There appears to be a concurrent caching iterator running."); + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + if (*end_of_sequence) { + cache_->Complete(); + return Status::OK(); + } + cache_->emplace_back(*out_tensors); + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + return SaveParent(writer, input_impl_); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + return RestoreParent(ctx, reader, input_impl_); + } + + private: + mutex mu_; + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + std::shared_ptr<MemoryCache> cache_; + }; // MemoryWriterIterator + + class MemoryReaderIterator : public DatasetIterator<MemoryDataset> { + public: + explicit MemoryReaderIterator(const Params& params, + const std::shared_ptr<MemoryCache>& cache) + : DatasetIterator<MemoryDataset>(params), cache_(cache), index_(0) { + CHECK(cache); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("index"), index_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("index"), &temp)); + index_ = static_cast<size_t>(temp); + } + return Status::OK(); + } 
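Aside (not part of the patch): a minimal tf.data sketch of the write-once/read-many cache semantics and of the ordering pitfall named in the LOG(WARNING) above, assuming the standard Python tf.data API of this era. The first iterator over a cached dataset claims the cache and writes it; the cache is only marked complete when that writer reaches end-of-sequence, so a cache() placed before take(k) never completes and its partial contents are discarded.

    import tensorflow as tf

    ds = tf.data.Dataset.range(100)

    # cache() sees take(5)'s early stop as an incomplete pass, so the partial
    # cache is discarded and the warning above is logged instead of reusing it.
    risky = ds.cache().take(5).repeat(3)

    # take(5) before cache() lets the cache writer reach end-of-sequence after
    # five elements, so later repetitions read from the completed cache.
    preferred = ds.take(5).cache().repeat(3)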
+ + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (index_ < cache_->size()) { + const std::vector<Tensor>& cache_tensors = cache_->at(index_); + out_tensors->insert(out_tensors->begin(), cache_tensors.begin(), + cache_tensors.end()); + index_++; + *end_of_sequence = false; + return Status::OK(); + } else { + *end_of_sequence = true; + return Status::OK(); + } + } + + private: + mutex mu_; + const std::shared_ptr<MemoryCache> cache_; + size_t index_ GUARDED_BY(mu_); + }; // MemoryReaderIterator + + void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + switch (mode_) { + case Mode::read: + iterator_.reset( + new MemoryReaderIterator({dataset(), prefix()}, cache_)); + break; + case Mode::write: + iterator_.reset( + new MemoryWriterIterator({dataset(), prefix()}, cache_)); + } } - }; // DuplicateWriterIterator + + mutex mu_; + std::shared_ptr<MemoryCache> cache_; + enum Mode { read, write }; + Mode mode_ GUARDED_BY(mu_); + std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_); + }; // MemoryIterator const DatasetBase* const input_; - mutable mutex mu_; - mutable std::unique_ptr<std::vector<std::vector<Tensor>>> cache_ - GUARDED_BY(mu_); - mutable bool writer_iterator_created_ GUARDED_BY(mu_) = false; + const std::shared_ptr<MemoryCache> cache_; }; // MemoryDataset }; // CacheDatasetOp diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc new file mode 100644 index 0000000000..8b29456354 --- /dev/null +++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc @@ -0,0 +1,169 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { + +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. +// TODO(prazek): Filter already has a logic of filtering by the given tensor, +// but it must return both components. We could introduce kernel like +// DropComponentDatasetOp and use FilterDataset for filtering. 
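For orientation, a hedged sketch (not part of the patch) of the behavior this new kernel provides, expressed with existing tf.data ops: each input element carries a scalar bool as its last component; elements whose bool is false are dropped, and the bool itself is removed from the surviving elements. This is the filter-then-drop-component combination the TODO above refers to.

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices(
        (tf.constant([1, 2, 3, 4]), tf.constant([True, False, True, False])))

    # Rough tf.data equivalent: keep elements whose last component is True,
    # then drop that component from the output. FilterByLastComponentDataset
    # performs both steps in a single dataset op.
    equivalent = ds.filter(lambda x, keep: keep).map(lambda x, keep: x)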
+class FilterByLastComponentDatasetOp : public UnaryDatasetOpKernel { + public: + explicit FilterByLastComponentDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + *output = new Dataset(ctx, input, output_types_, output_shapes_); + } + + private: + const int graph_def_version_; + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; + + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const DataTypeVector& output_types, + std::vector<PartialTensorShape> output_shapes) + : GraphDatasetBase(ctx), + input_(input), + output_types_(output_types), + output_shapes_(std::move(output_shapes)) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<Iterator>(new Iterator( + {this, strings::StrCat(prefix, "::FilterByLastComponent")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "FilterByLastComponentDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, {std::make_pair(0, input_graph_node)}, // Single tensor inputs. + {}, {}, output)); + return Status::OK(); + } + + private: + const DatasetBase* const input_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + // NOTE(mrry): This method is thread-safe as long as `input_impl_` is + // thread-safe. However, if multiple threads enter this method, outputs + // may be observed in a non-deterministic order. + bool matched; + do { + { + tf_shared_lock l(mu_); + if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + } + if (*end_of_sequence) { + mutex_lock l(mu_); + input_impl_.reset(); + return Status::OK(); + } + + matched = out_tensors->back().scalar<bool>()(); + out_tensors->pop_back(); + if (!matched) { + // Clear the output tensor list since it didn't match. 
+ out_tensors->clear(); + } + } while (!matched); + *end_of_sequence = false; + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + return Status::OK(); + } + + private: + mutex mu_; + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + }; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("FilterByLastComponentDataset").Device(DEVICE_CPU), + FilterByLastComponentDatasetOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index 0981e42ba1..c4dd849b8b 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -15,192 +15,174 @@ limitations under the License. #include <iterator> #include <vector> -#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/kernels/data/generator_dataset_op.h" + #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { -namespace { - // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class GeneratorDatasetOp : public DatasetOpKernel { +class GeneratorDatasetOp::Dataset : public GraphDatasetBase { public: - explicit GeneratorDatasetOp(OpKernelConstruction* ctx) - : DatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func, + std::unique_ptr<CapturedFunction> next_func, + std::unique_ptr<CapturedFunction> finalize_func, + const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes) + : GraphDatasetBase(ctx), + init_func_(std::move(init_func)), + next_func_(std::move(next_func)), + finalize_func_(std::move(finalize_func)), + output_types_(output_types), + output_shapes_(output_shapes) {} + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Generator")})); } - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - OpInputList init_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args", - &init_func_other_args_input)); - std::vector<Tensor> init_func_other_args; - init_func_other_args.reserve(init_func_other_args_input.size()); - for (const Tensor& t : init_func_other_args_input) { - init_func_other_args.push_back(t); - } - std::unique_ptr<CapturedFunction> init_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create( - init_func_, std::move(init_func_other_args), &init_func)); - - OpInputList next_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args", - &next_func_other_args_input)); - std::vector<Tensor> 
next_func_other_args; - next_func_other_args.reserve(next_func_other_args_input.size()); - for (const Tensor& t : next_func_other_args_input) { - next_func_other_args.push_back(t); - } - std::unique_ptr<CapturedFunction> next_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create( - next_func_, std::move(next_func_other_args), &next_func)); - - OpInputList finalize_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args", - &finalize_func_other_args_input)); - std::vector<Tensor> finalize_func_other_args; - finalize_func_other_args.reserve(finalize_func_other_args_input.size()); - for (const Tensor& t : finalize_func_other_args_input) { - finalize_func_other_args.push_back(t); - } - std::unique_ptr<CapturedFunction> finalize_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - finalize_func_, std::move(finalize_func_other_args), - &finalize_func)); - - *output = - new Dataset(ctx, std::move(init_func), std::move(next_func), - std::move(finalize_func), output_types_, output_shapes_); + const DataTypeVector& output_dtypes() const override { return output_types_; } + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; } + string DebugString() const override { return "GeneratorDatasetOp::Dataset"; } + private: - class Dataset : public GraphDatasetBase { + class Iterator : public DatasetIterator<Dataset> { public: - Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func, - std::unique_ptr<CapturedFunction> next_func, - std::unique_ptr<CapturedFunction> finalize_func, - const DataTypeVector& output_types, - const std::vector<PartialTensorShape>& output_shapes) - : GraphDatasetBase(ctx), - init_func_(std::move(init_func)), - next_func_(std::move(next_func)), - finalize_func_(std::move(finalize_func)), - output_types_(output_types), - output_shapes_(output_shapes) {} - - std::unique_ptr<IteratorBase> MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Generator")})); - } - - const DataTypeVector& output_dtypes() const override { - return output_types_; - } - const std::vector<PartialTensorShape>& output_shapes() const override { - return output_shapes_; - } - - string DebugString() const override { - return "GeneratorDatasetOp::Dataset"; - } - - private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} - - ~Iterator() override { - if (!finalized_) { - std::vector<Tensor> ignored; - Status s = - dataset()->finalize_func_->RunInstantiated(state_, &ignored); - if (!s.ok()) { - LOG(WARNING) - << "Error occurred when finalizing GeneratorDataset iterator: " - << s; - } + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + ~Iterator() override { + if (!finalized_) { + std::vector<Tensor> ignored; + Status s = dataset()->finalize_func_->RunInstantiated(state_, &ignored); + if (!s.ok()) { + LOG(WARNING) + << "Error occurred when finalizing GeneratorDataset iterator: " + << s; } } + } - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - - if (!initialized_) { - TF_RETURN_IF_ERROR( - dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); - // Explicitly instantiate the finalize function here so that - // we can invoke it in the destructor. 
- TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); - initialized_ = true; - } + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + + if (!initialized_) { + TF_RETURN_IF_ERROR( + dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); + // Explicitly instantiate the finalize function here so that + // we can invoke it in the destructor. + TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); + initialized_ = true; + } - if (finalized_) { - *end_of_sequence = true; - return Status::OK(); - } + if (finalized_) { + *end_of_sequence = true; + return Status::OK(); + } - Status s = dataset()->next_func_->RunWithBorrowedArgs(ctx, state_, - out_tensors); - if (s.ok()) { - *end_of_sequence = false; - } else if (errors::IsOutOfRange(s)) { - // `next_func` may deliberately raise `errors::OutOfRange` - // to indicate that we should terminate the iteration. - s = Status::OK(); - *end_of_sequence = true; - - // NOTE(mrry): We ignore any tensors returned by the - // finalize function. - std::vector<Tensor> ignored; - TF_RETURN_IF_ERROR( - dataset()->finalize_func_->RunInstantiated(state_, &ignored)); - finalized_ = true; - } - return s; + Status s = + dataset()->next_func_->RunWithBorrowedArgs(ctx, state_, out_tensors); + if (s.ok()) { + *end_of_sequence = false; + } else if (errors::IsOutOfRange(s)) { + // `next_func` may deliberately raise `errors::OutOfRange` + // to indicate that we should terminate the iteration. + s = Status::OK(); + *end_of_sequence = true; + + // NOTE(mrry): We ignore any tensors returned by the + // finalize function. + std::vector<Tensor> ignored; + TF_RETURN_IF_ERROR( + dataset()->finalize_func_->RunInstantiated(state_, &ignored)); + finalized_ = true; } + return s; + } - private: - mutex mu_; - bool initialized_ GUARDED_BY(mu_) = false; - bool finalized_ GUARDED_BY(mu_) = false; - std::vector<Tensor> state_ GUARDED_BY(mu_); - }; - - const std::unique_ptr<CapturedFunction> init_func_; - const std::unique_ptr<CapturedFunction> next_func_; - const std::unique_ptr<CapturedFunction> finalize_func_; - const DataTypeVector output_types_; - const std::vector<PartialTensorShape> output_shapes_; + private: + mutex mu_; + bool initialized_ GUARDED_BY(mu_) = false; + bool finalized_ GUARDED_BY(mu_) = false; + std::vector<Tensor> state_ GUARDED_BY(mu_); }; - DataTypeVector output_types_; - std::vector<PartialTensorShape> output_shapes_; - NameAttrList init_func_; - NameAttrList next_func_; - NameAttrList finalize_func_; + const std::unique_ptr<CapturedFunction> init_func_; + const std::unique_ptr<CapturedFunction> next_func_; + const std::unique_ptr<CapturedFunction> finalize_func_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; }; +GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); +} + +void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx, + DatasetBase** output) { + OpInputList init_func_other_args_input; + OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args", + &init_func_other_args_input)); + std::vector<Tensor> init_func_other_args; + 
init_func_other_args.reserve(init_func_other_args_input.size()); + for (const Tensor& t : init_func_other_args_input) { + init_func_other_args.push_back(t); + } + std::unique_ptr<CapturedFunction> init_func; + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args), + &init_func)); + + OpInputList next_func_other_args_input; + OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args", + &next_func_other_args_input)); + std::vector<Tensor> next_func_other_args; + next_func_other_args.reserve(next_func_other_args_input.size()); + for (const Tensor& t : next_func_other_args_input) { + next_func_other_args.push_back(t); + } + std::unique_ptr<CapturedFunction> next_func; + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args), + &next_func)); + + OpInputList finalize_func_other_args_input; + OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args", + &finalize_func_other_args_input)); + std::vector<Tensor> finalize_func_other_args; + finalize_func_other_args.reserve(finalize_func_other_args_input.size()); + for (const Tensor& t : finalize_func_other_args_input) { + finalize_func_other_args.push_back(t); + } + std::unique_ptr<CapturedFunction> finalize_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create( + finalize_func_, std::move(finalize_func_other_args), + &finalize_func)); + + *output = + new Dataset(ctx, std::move(init_func), std::move(next_func), + std::move(finalize_func), output_types_, output_shapes_); +} + REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU), GeneratorDatasetOp); REGISTER_KERNEL_BUILDER( Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"), GeneratorDatasetOp); -} // namespace - } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h new file mode 100644 index 0000000000..3f84fa9c2e --- /dev/null +++ b/tensorflow/core/kernels/data/generator_dataset_op.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/kernels/data/captured_function.h" + +namespace tensorflow { + +class GeneratorDatasetOp : public DatasetOpKernel { + public: + explicit GeneratorDatasetOp(OpKernelConstruction* ctx); + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; + NameAttrList init_func_; + NameAttrList next_func_; + NameAttrList finalize_func_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index da489db7c8..e2df14337c 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" + #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/common_runtime/threadpool_device.h" @@ -23,8 +24,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/kernels/data/optional_ops.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -80,6 +81,8 @@ Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, return Status::OK(); } +} // namespace + class IteratorResource : public ResourceBase { public: IteratorResource(const DataTypeVector& output_dtypes, @@ -437,300 +440,179 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, // Note that IteratorHandleOp holds a reference to the resource it creates. If // cleaning up resources with DestroyResourceOp is important, consider creating // resource containers with AnonymousIteratorHandleOp instead. -class IteratorHandleOp : public OpKernel { - public: - explicit IteratorHandleOp(OpKernelConstruction* ctx) - : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); - } +IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); +} - // The resource is deleted from the resource manager only when it is private - // to kernel. 
Ideally the resource should be deleted when it is no longer held - // by anyone, but it would break backward compatibility. - ~IteratorHandleOp() override { - if (resource_ != nullptr) { - resource_->Unref(); - if (cinfo_.resource_is_private_to_kernel()) { - if (!cinfo_.resource_manager() - ->template Delete<IteratorResource>(cinfo_.container(), - cinfo_.name()) - .ok()) { - // Do nothing; the resource can have been deleted by session resets. - } +// The resource is deleted from the resource manager only when it is private +// to kernel. Ideally the resource should be deleted when it is no longer held +// by anyone, but it would break backward compatibility. +IteratorHandleOp::~IteratorHandleOp() { + if (resource_ != nullptr) { + resource_->Unref(); + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->template Delete<IteratorResource>(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. } } } +} - void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { - { - mutex_lock l(mu_); - if (resource_ == nullptr) { - FunctionLibraryRuntime* lib; - std::unique_ptr<DeviceMgr> device_mgr(nullptr); - std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr); - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr); - // If the iterator is shared then we construct a new FLR, and pass that - // in. NOTE(mrry,rohanj): In this case it is not possible to call remote - // functions from the iterator. We may add this functionality if there - // is sufficient demand, but it will require a significant refactoring. - if (!name_.empty()) { - lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr); - } else { - OP_REQUIRES_OK(context, context->function_library()->Clone( - &flib_def, &pflr, &lib)); - } - - ResourceMgr* mgr = context->resource_manager(); - OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); - - IteratorResource* resource; - OP_REQUIRES_OK( - context, - mgr->LookupOrCreate<IteratorResource>( - cinfo_.container(), cinfo_.name(), &resource, - [lib, &device_mgr, &flib_def, &pflr, - this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new IteratorResource( - output_dtypes_, output_shapes_, graph_def_version_, - std::move(device_mgr), std::move(flib_def), - std::move(pflr), lib); - return Status::OK(); - })); +void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + if (resource_ == nullptr) { + FunctionLibraryRuntime* lib; + std::unique_ptr<DeviceMgr> device_mgr(nullptr); + std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr); + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr); + // If the iterator is shared then we construct a new FLR, and pass that + // in. NOTE(mrry,rohanj): In this case it is not possible to call remote + // functions from the iterator. We may add this functionality if there + // is sufficient demand, but it will require a significant refactoring. 
+ if (!name_.empty()) { + lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr); + } else { + OP_REQUIRES_OK(context, context->function_library()->Clone( + &flib_def, &pflr, &lib)); + } - Status s = VerifyResource(resource); - if (TF_PREDICT_FALSE(!s.ok())) { - resource->Unref(); - context->SetStatus(s); - return; - } + ResourceMgr* mgr = context->resource_manager(); + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); - resource_ = resource; + IteratorResource* resource; + OP_REQUIRES_OK( + context, + mgr->LookupOrCreate<IteratorResource>( + cinfo_.container(), cinfo_.name(), &resource, + [lib, &device_mgr, &flib_def, &pflr, this](IteratorResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new IteratorResource( + output_dtypes_, output_shapes_, graph_def_version_, + std::move(device_mgr), std::move(flib_def), + std::move(pflr), lib); + return Status::OK(); + })); + + Status s = VerifyResource(resource); + if (TF_PREDICT_FALSE(!s.ok())) { + resource->Unref(); + context->SetStatus(s); + return; } - } - OP_REQUIRES_OK(context, MakeResourceHandleToOutput( - context, 0, cinfo_.container(), cinfo_.name(), - MakeTypeIndex<IteratorResource>())); - } - - private: - // During the first Compute(), resource is either created or looked up using - // shared_name. In the latter case, the resource found should be verified if - // it is compatible with this op's configuration. The verification may fail in - // cases such as two graphs asking queues of the same shared name to have - // inconsistent capacities. - Status VerifyResource(IteratorResource* resource) { - TF_RETURN_IF_ERROR( - VerifyTypesMatch(output_dtypes_, resource->output_dtypes())); - TF_RETURN_IF_ERROR( - VerifyShapesCompatible(output_shapes_, resource->output_shapes())); - return Status::OK(); - } - template <typename To, typename From> // use like this: down_cast<T*>(foo); - static inline To down_cast(From* f) { // so we only accept pointers - static_assert( - (std::is_base_of<From, typename std::remove_pointer<To>::type>::value), - "target type not derived from source type"); - - // We skip the assert and hence the dynamic_cast if RTTI is disabled. -#if !defined(__GNUC__) || defined(__GXX_RTTI) - // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. - assert(f == nullptr || dynamic_cast<To>(f) != nullptr); -#endif // !defined(__GNUC__) || defined(__GXX_RTTI) - return static_cast<To>(f); + resource_ = resource; + } } + OP_REQUIRES_OK(context, MakeResourceHandleToOutput( + context, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex<IteratorResource>())); +} - FunctionLibraryRuntime* CreatePrivateFLR( - OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr, - std::unique_ptr<FunctionLibraryDefinition>* flib_def, - std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) { - // Wrap the existing device in order to see any captured resources - // in its resource manager. The existing device will outlive the - // IteratorResource, because we are storing the IteratorResource - // in that device's resource manager. 
- Device* wrapped_device = RenamedDevice::NewRenamedDevice( - ctx->device()->name(), down_cast<Device*>(ctx->device()), - false /* owns_underlying */, false /* isolate_session_state */); - device_mgr->reset(new DeviceMgr({wrapped_device})); - flib_def->reset(new FunctionLibraryDefinition( - *ctx->function_library()->GetFunctionLibraryDefinition())); - pflr->reset(new ProcessFunctionLibraryRuntime( - device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(), - {} /* TODO(mrry): OptimizerOptions? */, - nullptr /* TODO(mrry): ClusterFLR */)); - - return (*pflr)->GetFLR(ctx->device()->name()); - } +Status IteratorHandleOp::VerifyResource(IteratorResource* resource) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_dtypes_, resource->output_dtypes())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, resource->output_shapes())); + return Status::OK(); +} - mutex mu_; - ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. - IteratorResource* resource_ GUARDED_BY(mu_) = nullptr; - DataTypeVector output_dtypes_; - std::vector<PartialTensorShape> output_shapes_; - const int graph_def_version_; - string name_; -}; +FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR( + OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr, + std::unique_ptr<FunctionLibraryDefinition>* flib_def, + std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) { + // Wrap the existing device in order to see any captured resources + // in its resource manager. The existing device will outlive the + // IteratorResource, because we are storing the IteratorResource + // in that device's resource manager. + Device* wrapped_device = RenamedDevice::NewRenamedDevice( + ctx->device()->name(), down_cast<Device*>(ctx->device()), + false /* owns_underlying */, false /* isolate_session_state */); + device_mgr->reset(new DeviceMgr({wrapped_device})); + flib_def->reset(new FunctionLibraryDefinition( + *ctx->function_library()->GetFunctionLibraryDefinition())); + pflr->reset(new ProcessFunctionLibraryRuntime( + device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(), + {} /* TODO(mrry): OptimizerOptions? */, + nullptr /* TODO(mrry): ClusterFLR */)); + + return (*pflr)->GetFLR(ctx->device()->name()); +} // Like IteratorHandleOp, but creates handles which are never shared, and does // not hold a reference to these handles. The latter is important for eager // execution, since OpKernel instances generally live as long as the program // running them. 
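An illustrative aside on why this matters, not part of the patch and assuming the TF 1.x eager APIs of this period (whether the eager iterator is wired to this particular op at this revision is an assumption on my part): an OpKernel can outlive any single iterator, so a handle op that kept a reference to every resource it created would pin those iterators for the life of the process. Anonymous, unshared handles let each eager iterator own its resource and release it independently.

    import tensorflow as tf
    tf.enable_eager_execution()

    dataset = tf.data.Dataset.range(10)
    for _ in range(3):
        # Each pass constructs a fresh, unshared iterator resource that can be
        # released when the Python iterator object is discarded, rather than
        # being kept alive by a long-lived kernel instance.
        it = tf.contrib.eager.Iterator(dataset)
        print(next(it))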
-class AnonymousIteratorHandleOp : public OpKernel { - public: - explicit AnonymousIteratorHandleOp(OpKernelConstruction* context) - : OpKernel(context), graph_def_version_(context->graph_def_version()) { - OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_)); - OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_)); - } +AnonymousIteratorHandleOp::AnonymousIteratorHandleOp( + OpKernelConstruction* context) + : OpKernel(context), graph_def_version_(context->graph_def_version()) { + OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_)); + OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_)); +} - void Compute(OpKernelContext* context) override { - FunctionLibraryRuntime* lib; - std::unique_ptr<DeviceMgr> device_mgr(nullptr); - std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr); - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr); - OP_REQUIRES_OK(context, - context->function_library()->Clone(&flib_def, &pflr, &lib)); +void AnonymousIteratorHandleOp::Compute(OpKernelContext* context) { + FunctionLibraryRuntime* lib; + std::unique_ptr<DeviceMgr> device_mgr(nullptr); + std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr); + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr); + OP_REQUIRES_OK(context, + context->function_library()->Clone(&flib_def, &pflr, &lib)); - ResourceMgr* mgr = context->resource_manager(); + ResourceMgr* mgr = context->resource_manager(); - const string container_name = "AnonymousIterator"; - string unique_name; - { - mutex_lock l(static_resource_lookup_mutex_); - while (true) { // Find an unused name - IteratorResource* existing_resource = nullptr; - unique_name = strings::StrCat("AnonymousIterator", current_id_++); - Status status = mgr->Lookup<IteratorResource>( - container_name, unique_name, &existing_resource); - if (status.code() == error::NOT_FOUND) { - break; - } - OP_REQUIRES_OK(context, status); - existing_resource->Unref(); + const string container_name = "AnonymousIterator"; + string unique_name; + { + mutex_lock l(static_resource_lookup_mutex_); + while (true) { // Find an unused name + IteratorResource* existing_resource = nullptr; + unique_name = strings::StrCat("AnonymousIterator", current_id_++); + Status status = mgr->Lookup<IteratorResource>(container_name, unique_name, + &existing_resource); + if (status.code() == error::NOT_FOUND) { + break; } - IteratorResource* new_resource = new IteratorResource( - output_dtypes_, output_shapes_, graph_def_version_, - std::move(device_mgr), std::move(flib_def), std::move(pflr), lib); - // Create the resource with our chosen name under the resource lookup - // mutex to avoid another kernel racily creating a resource with this - // name. - OP_REQUIRES_OK(context, mgr->Create<IteratorResource>( - container_name, unique_name, new_resource)); + OP_REQUIRES_OK(context, status); + existing_resource->Unref(); } - OP_REQUIRES_OK(context, MakeResourceHandleToOutput( - context, 0, container_name, unique_name, - MakeTypeIndex<IteratorResource>())); + IteratorResource* new_resource = new IteratorResource( + output_dtypes_, output_shapes_, graph_def_version_, + std::move(device_mgr), std::move(flib_def), std::move(pflr), lib); + // Create the resource with our chosen name under the resource lookup + // mutex to avoid another kernel racily creating a resource with this + // name. 
+ OP_REQUIRES_OK(context, mgr->Create<IteratorResource>( + container_name, unique_name, new_resource)); } - - private: - // Coordinates Iterator unique name creation across AnonymousIteratorHandleOp - // instances. - static mutex static_resource_lookup_mutex_; - // current_id_ is just a hint for creating unique names. If it turns out - // there's a collision (e.g. because another AnonymousIteratorHandleOp - // instance is generating handles) we'll just skip that id. - static int64 current_id_ GUARDED_BY(static_resource_lookup_mutex_); - DataTypeVector output_dtypes_; - std::vector<PartialTensorShape> output_shapes_; - const int graph_def_version_; -}; + OP_REQUIRES_OK(context, MakeResourceHandleToOutput( + context, 0, container_name, unique_name, + MakeTypeIndex<IteratorResource>())); +} // Static initializers for AnonymousIteratorHandleOp id counting. mutex AnonymousIteratorHandleOp::static_resource_lookup_mutex_{ LINKER_INITIALIZED}; int64 AnonymousIteratorHandleOp::current_id_(0); -class MakeIteratorOp : public OpKernel { - public: - explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - DatasetBase* dataset; - OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); - IteratorResource* iterator_resource; - OP_REQUIRES_OK( - ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource)); - core::ScopedUnref unref(iterator_resource); - - IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); - std::unique_ptr<IteratorBase> iterator; - OP_REQUIRES_OK(ctx, - dataset->MakeIterator(&iter_ctx, "Iterator", &iterator)); - OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator))); - } -}; - -// A simple background worker that executes closures asynchronously and without -// blocking. -// -// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel` -// to avoid blocking an executor thread that may be required by the blocking -// work. -// -// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this -// purpose because its current implementation (in Eigen) uses a finite-length -// queue and will block the caller when full. This can lead to deadlock under -// heavy load. Since the number of concurrent work items in each user of a -// `BackgroundWorker` is at most one per op invocation, the dynamic allocation -// overhead is tolerable. -class BackgroundWorker { - public: - BackgroundWorker(Env* env, const string& name) { - thread_.reset(env->StartThread({} /* thread_options */, name, - [this]() { WorkerLoop(); })); - } - - ~BackgroundWorker() { - { - mutex_lock l(mu_); - cancelled_ = true; - } - cond_var_.notify_one(); - // Block until the background thread has terminated. - // - // NOTE(mrry): We explicitly free and join the thread here because - // `WorkerLoop()` uses other members of this object, and so we must join - // the thread before destroying them. 
- thread_.reset(); - } - - void Schedule(std::function<void()> work_item) { - { - mutex_lock l(mu_); - work_queue_.push_back(std::move(work_item)); - } - cond_var_.notify_one(); - } - - private: - void WorkerLoop() { - while (true) { - std::function<void()> work_item = nullptr; - { - mutex_lock l(mu_); - while (!cancelled_ && work_queue_.empty()) { - cond_var_.wait(l); - } - if (cancelled_) { - return; - } - DCHECK(!work_queue_.empty()); - work_item = std::move(work_queue_.front()); - work_queue_.pop_front(); - } - DCHECK(work_item != nullptr); - work_item(); - } - } - - std::unique_ptr<Thread> thread_; - mutex mu_; - condition_variable cond_var_; - bool cancelled_ GUARDED_BY(mu_) = false; - std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_); -}; +void MakeIteratorOp::Compute(OpKernelContext* ctx) { + DatasetBase* dataset; + OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); + IteratorResource* iterator_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource)); + core::ScopedUnref unref(iterator_resource); + + IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx); + std::unique_ptr<IteratorBase> iterator; + OP_REQUIRES_OK(ctx, dataset->MakeIterator(&iter_ctx, "Iterator", &iterator)); + OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator))); +} class ToSingleElementOp : public AsyncOpKernel { public: @@ -995,13 +877,92 @@ class OneShotIteratorOp : public AsyncOpKernel { const int graph_def_version_; }; -class IteratorGetNextOp : public AsyncOpKernel { +void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { + IteratorResource* iterator; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); + // The call to `iterator->GetNext()` may block and depend on an + // inter-op thread pool thread, so we issue the call from the + // owned thread pool. + background_worker_.Schedule(std::bind( + [ctx, iterator](DoneCallback done) { + std::vector<Tensor> components; + bool end_of_sequence = false; + + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = *(ctx->runner()); + params.function_library = iterator->function_library(); + DeviceBase* device = ctx->function_library()->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; + IteratorContext iter_ctx(std::move(params)); + + Status s = iterator->GetNext(&iter_ctx, &components, &end_of_sequence); + // NOTE(mrry): We must unref the iterator before calling `done()`, to + // avoid destruction races. + iterator->Unref(); + + if (!s.ok()) { + ctx->SetStatus(s); + } else if (end_of_sequence) { + ctx->SetStatus(errors::OutOfRange("End of sequence")); + } else { + for (int i = 0; i < components.size(); ++i) { + // TODO(mrry): Check that the shapes match the shape attrs. 
+ ctx->set_output(i, components[i]); + } + } + done(); + }, + std::move(done))); +} + +class IteratorGetNextSyncOp : public OpKernel { + public: + explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + IteratorResource* iterator; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); + core::ScopedUnref unref_iterator(iterator); + + std::vector<Tensor> components; + bool end_of_sequence = false; + + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = *(ctx->runner()); + params.function_library = iterator->function_library(); + DeviceBase* device = ctx->function_library()->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; + IteratorContext iter_ctx(std::move(params)); + + OP_REQUIRES_OK(ctx, + iterator->GetNext(&iter_ctx, &components, &end_of_sequence)); + OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence")); + + for (int i = 0; i < components.size(); ++i) { + // TODO(mrry): Check that the shapes match the shape attrs. + ctx->set_output(i, components[i]); + } + } +}; + +class IteratorGetNextAsOptionalOp : public AsyncOpKernel { public: - explicit IteratorGetNextOp(OpKernelConstruction* ctx) + explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx), - background_worker_(ctx->env(), - strings::StrCat("iterator_get_next_thread_", - SanitizeThreadSuffix(name()))) {} + background_worker_( + ctx->env(), strings::StrCat("iterator_get_next_as_optional_thread_", + SanitizeThreadSuffix(name()))) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { IteratorResource* iterator; @@ -1011,7 +972,7 @@ class IteratorGetNextOp : public AsyncOpKernel { // inter-op thread pool thread, so we issue the call from the // owned thread pool. background_worker_.Schedule(std::bind( - [ctx, iterator](DoneCallback done) { + [this, ctx, iterator](DoneCallback done) { std::vector<Tensor> components; bool end_of_sequence = false; @@ -1034,12 +995,32 @@ class IteratorGetNextOp : public AsyncOpKernel { if (!s.ok()) { ctx->SetStatus(s); } else if (end_of_sequence) { - ctx->SetStatus(errors::OutOfRange("End of sequence")); + OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done); } else { for (int i = 0; i < components.size(); ++i) { - // TODO(mrry): Check that the shapes match the shape attrs. - ctx->set_output(i, components[i]); + OP_REQUIRES_ASYNC( + ctx, components[i].dtype() == output_types_[i], + errors::InvalidArgument( + "The given optional does not match the expected type for " + "component ", + i, ". Expected: ", DataTypeString(output_types_[i]), + ". Actual: ", DataTypeString(components[i].dtype()), "."), + done); + OP_REQUIRES_ASYNC( + ctx, + output_shapes_[i].IsCompatibleWith(components[i].shape()), + errors::InvalidArgument( + "The given optional does not match the expected shape " + "for component ", + i, ". Expected: ", output_shapes_[i].DebugString(), + ". 
Actual: ", components[i].shape().DebugString(), "."), + done); } + + OP_REQUIRES_OK_ASYNC( + ctx, + WriteOptionalWithValueToOutput(ctx, 0, std::move(components)), + done); } done(); }, @@ -1048,126 +1029,80 @@ class IteratorGetNextOp : public AsyncOpKernel { private: BackgroundWorker background_worker_; + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; }; -class IteratorGetNextSyncOp : public OpKernel { - public: - explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - IteratorResource* iterator; - OP_REQUIRES_OK(ctx, - LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); - core::ScopedUnref unref_iterator(iterator); - - std::vector<Tensor> components; - bool end_of_sequence = false; - - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); - params.function_library = iterator->function_library(); - DeviceBase* device = ctx->function_library()->device(); - params.allocator_getter = [device](AllocatorAttributes attrs) { - return device->GetAllocator(attrs); - }; - IteratorContext iter_ctx(std::move(params)); - - OP_REQUIRES_OK(ctx, - iterator->GetNext(&iter_ctx, &components, &end_of_sequence)); - OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence")); - - for (int i = 0; i < components.size(); ++i) { - // TODO(mrry): Check that the shapes match the shape attrs. - ctx->set_output(i, components[i]); - } - } -}; - -class IteratorToStringHandleOp : public OpKernel { - public: - explicit IteratorToStringHandleOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - const Tensor& resource_handle_t = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), - errors::InvalidArgument("resource_handle must be a scalar")); - - // Validate that the handle corresponds to a real resource, and - // that it is an IteratorResource. - IteratorResource* iterator_resource; - OP_REQUIRES_OK( - ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); - iterator_resource->Unref(); +void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) { + const Tensor& resource_handle_t = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), + errors::InvalidArgument("resource_handle must be a scalar")); + + // Validate that the handle corresponds to a real resource, and + // that it is an IteratorResource. 
+  IteratorResource* iterator_resource;
+  OP_REQUIRES_OK(
+      ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
+  iterator_resource->Unref();
+
+  Tensor* string_handle_t;
+  OP_REQUIRES_OK(ctx,
+                 ctx->allocate_output(0, TensorShape({}), &string_handle_t));
+  string_handle_t->scalar<string>()() =
+      resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
+}
 
-  Tensor* string_handle_t;
-  OP_REQUIRES_OK(ctx,
-                 ctx->allocate_output(0, TensorShape({}), &string_handle_t));
-  string_handle_t->scalar<string>()() =
-      resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
-  }
-};
+IteratorFromStringHandleOp::IteratorFromStringHandleOp(
+    OpKernelConstruction* ctx)
+    : OpKernel(ctx) {
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  OP_REQUIRES(
+      ctx,
+      output_dtypes_.empty() || output_shapes_.empty() ||
+          output_dtypes_.size() == output_shapes_.size(),
+      errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
+                              "are set, they must have the same length."));
+}
 
-class IteratorFromStringHandleOp : public OpKernel {
- public:
-  explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx)
-      : OpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
-    OP_REQUIRES(
-        ctx,
-        output_dtypes_.empty() || output_shapes_.empty() ||
-            output_dtypes_.size() == output_shapes_.size(),
-        errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
-                                "are set, they must have the same length."));
+void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
+  const Tensor& string_handle_t = ctx->input(0);
+  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
+              errors::InvalidArgument("string_handle must be a scalar"));
+
+  ResourceHandle resource_handle;
+  OP_REQUIRES(
+      ctx, resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
+      errors::InvalidArgument(
+          "Could not parse string_handle as a valid ResourceHandle"));
+
+  OP_REQUIRES(
+      ctx, resource_handle.device() == ctx->device()->attributes().name(),
+      errors::InvalidArgument("Attempted to create an iterator on device \"",
+                              ctx->device()->attributes().name(),
+                              "\" from handle defined on device \"",
+                              resource_handle.device(), "\""));
+
+  // Validate that the handle corresponds to a real resource, and
+  // that it is an IteratorResource. 
+ IteratorResource* iterator_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource)); + core::ScopedUnref unref_iterator(iterator_resource); + if (!output_dtypes_.empty()) { + OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_, + iterator_resource->output_dtypes())); } - - void Compute(OpKernelContext* ctx) override { - const Tensor& string_handle_t = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()), - errors::InvalidArgument("string_handle must be a scalar")); - - ResourceHandle resource_handle; - OP_REQUIRES( - ctx, - resource_handle.ParseFromString(string_handle_t.scalar<string>()()), - errors::InvalidArgument( - "Could not parse string_handle as a valid ResourceHandle")); - - OP_REQUIRES( - ctx, resource_handle.device() == ctx->device()->attributes().name(), - errors::InvalidArgument("Attempted create an iterator on device \"", - ctx->device()->attributes().name(), - "\" from handle defined on device \"", - resource_handle.device(), "\"")); - - // Validate that the handle corresponds to a real resource, and - // that it is an IteratorResource. - IteratorResource* iterator_resource; + if (!output_shapes_.empty()) { OP_REQUIRES_OK(ctx, - LookupResource(ctx, resource_handle, &iterator_resource)); - core::ScopedUnref unref_iterator(iterator_resource); - if (!output_dtypes_.empty()) { - OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_, - iterator_resource->output_dtypes())); - } - if (!output_shapes_.empty()) { - OP_REQUIRES_OK( - ctx, VerifyShapesCompatible(output_shapes_, - iterator_resource->output_shapes())); - } - - Tensor* resource_handle_t; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t)); - resource_handle_t->scalar<ResourceHandle>()() = resource_handle; + VerifyShapesCompatible(output_shapes_, + iterator_resource->output_shapes())); } - private: - DataTypeVector output_dtypes_; - std::vector<PartialTensorShape> output_shapes_; -}; + Tensor* resource_handle_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &resource_handle_t)); + resource_handle_t->scalar<ResourceHandle>()() = resource_handle; +} class SerializeIteratorOp : public OpKernel { public: @@ -1240,6 +1175,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU), IteratorGetNextSyncOp); REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_GPU), IteratorGetNextSyncOp); +REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE_CPU), + IteratorGetNextAsOptionalOp); +REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE_GPU), + IteratorGetNextAsOptionalOp); REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU), IteratorToStringHandleOp); REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") @@ -1259,6 +1198,4 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU), DeserializeIteratorOp); -} // namespace - } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h new file mode 100644 index 0000000000..e426febcce --- /dev/null +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -0,0 +1,140 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ + +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/ops_util.h" + +namespace tensorflow { + +class IteratorResource; + +class IteratorHandleOp : public OpKernel { + public: + explicit IteratorHandleOp(OpKernelConstruction* ctx); + + // The resource is deleted from the resource manager only when it is private + // to kernel. Ideally the resource should be deleted when it is no longer held + // by anyone, but it would break backward compatibility. + ~IteratorHandleOp() override; + + void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_); + + private: + // During the first Compute(), resource is either created or looked up using + // shared_name. In the latter case, the resource found should be verified if + // it is compatible with this op's configuration. The verification may fail in + // cases such as two graphs asking queues of the same shared name to have + // inconsistent capacities. + Status VerifyResource(IteratorResource* resource); + + template <typename To, typename From> // use like this: down_cast<T*>(foo); + static inline To down_cast(From* f) { // so we only accept pointers + static_assert( + (std::is_base_of<From, typename std::remove_pointer<To>::type>::value), + "target type not derived from source type"); + + // We skip the assert and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. + assert(f == nullptr || dynamic_cast<To>(f) != nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + return static_cast<To>(f); + } + + FunctionLibraryRuntime* CreatePrivateFLR( + OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr, + std::unique_ptr<FunctionLibraryDefinition>* flib_def, + std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr); + + mutex mu_; + ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. + IteratorResource* resource_ GUARDED_BY(mu_) = nullptr; + DataTypeVector output_dtypes_; + std::vector<PartialTensorShape> output_shapes_; + const int graph_def_version_; + string name_; +}; + +// Like IteratorHandleOp, but creates handles which are never shared, and does +// not hold a reference to these handles. The latter is important for eager +// execution, since OpKernel instances generally live as long as the program +// running them. +class AnonymousIteratorHandleOp : public OpKernel { + public: + explicit AnonymousIteratorHandleOp(OpKernelConstruction* context); + + void Compute(OpKernelContext* context) override; + + private: + // Coordinates Iterator unique name creation across AnonymousIteratorHandleOp + // instances. + static mutex static_resource_lookup_mutex_; + // current_id_ is just a hint for creating unique names. If it turns out + // there's a collision (e.g. 
because another AnonymousIteratorHandleOp + // instance is generating handles) we'll just skip that id. + static int64 current_id_ GUARDED_BY(static_resource_lookup_mutex_); + DataTypeVector output_dtypes_; + std::vector<PartialTensorShape> output_shapes_; + const int graph_def_version_; +}; + +class MakeIteratorOp : public OpKernel { + public: + explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; +}; + +class IteratorGetNextOp : public AsyncOpKernel { + public: + explicit IteratorGetNextOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + background_worker_(ctx->env(), + strings::StrCat("iterator_get_next_thread_", + SanitizeThreadSuffix(name()))) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + private: + BackgroundWorker background_worker_; +}; + +class IteratorToStringHandleOp : public OpKernel { + public: + explicit IteratorToStringHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; +}; + +class IteratorFromStringHandleOp : public OpKernel { + public: + explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + DataTypeVector output_dtypes_; + std::vector<PartialTensorShape> output_shapes_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc new file mode 100644 index 0000000000..d66716ef66 --- /dev/null +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -0,0 +1,192 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/util/batch_util.h"
+#include "tensorflow/core/util/reffed_status_callback.h"
+
+namespace tensorflow {
+namespace {
+
+void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
+                   bool always_collect_stats) {
+  opts->step_id = ctx->step_id();
+  opts->rendezvous = ctx->rendezvous();
+  opts->cancellation_manager = ctx->cancellation_manager();
+  if (always_collect_stats) {
+    opts->stats_collector = ctx->stats_collector();
+  }
+  opts->runner = ctx->runner();
+}
+
+class MapDefunOp : public AsyncOpKernel {
+ public:
+  explicit MapDefunOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+    auto func_lib = ctx->function_library();
+    OP_REQUIRES(ctx, func_lib != nullptr,
+                errors::Internal("No function library."));
+    const NameAttrList* func;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func));
+    OP_REQUIRES_OK(ctx,
+                   func_lib->Instantiate(func->name(), AttrSlice(&func->attr()),
+                                         &func_handle_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+
+    OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
+                errors::InvalidArgument("Must have at least one input."));
+    OP_REQUIRES(ctx, ctx->num_outputs() >= 1,
+                errors::InvalidArgument("Must have at least one output."));
+    OP_REQUIRES(ctx, ctx->num_outputs() == output_shapes_.size(),
+                errors::InvalidArgument(
+                    "Length of output_shapes and output_types must match."));
+  }
+
+  ~MapDefunOp() override {}
+
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+    int64 batch_size = ctx->input(0).dim_size(0);
+    // Inputs
+    auto* args = new std::vector<Tensor>;
+    auto* arg_shapes = new std::vector<TensorShape>;
+    arg_shapes->reserve(ctx->num_inputs());
+    args->reserve(ctx->num_inputs());
+
+    for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+      args->push_back(ctx->input(i));
+      arg_shapes->push_back(ctx->input(i).shape());
+      arg_shapes->at(i).RemoveDim(0);  // Remove the first batch dimension
+      OP_REQUIRES_ASYNC(
+          ctx, batch_size == ctx->input(i).dim_size(0),
+          errors::InvalidArgument("All inputs must have the same dimension 0."),
+          done);
+    }
+
+    // Outputs
+    auto* output = new OpOutputList;
+    OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
+
+    for (size_t i = 0; i < output_types().size(); ++i) {
+      Tensor* out = nullptr;
+      TensorShape output_shape = output_shapes_.at(i);
+      output_shape.InsertDim(0, batch_size);
+      OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done);
+    }
+
+    SetRunOptions(ctx, &opts_, false);
+
+    // Run loop
+    StatusCallback callback = std::bind(
+        [](OpKernelContext* ctx, std::vector<Tensor>* args,
+           std::vector<TensorShape>* arg_shapes, OpOutputList* output,
+           DoneCallback& done, const Status& status) {
+          delete args;
+          delete arg_shapes;
+          delete output;
+          ctx->SetStatus(status);
+          done();
+        },
+        ctx, args, arg_shapes, output, std::move(done), std::placeholders::_1);
+
+    auto* refcounted = new ReffedStatusCallback(std::move(callback));
+
+    for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) {
+      // Start from i = 1 because refcounted is initialized with refcount = 1
+      refcounted->Ref();
+    }
+    for (size_t i = 0; i < 
static_cast<size_t>(batch_size); ++i) { + auto* call_frame = + new MapFunctionCallFrame(*args, *arg_shapes, output, this, i); + ctx->function_library()->Run( + opts_, func_handle_, call_frame, + [call_frame, refcounted](const Status& func_status) { + delete call_frame; + refcounted->UpdateStatus(func_status); + refcounted->Unref(); + }); + } + } + + private: + FunctionLibraryRuntime::Handle func_handle_; + FunctionLibraryRuntime::Options opts_; + std::vector<TensorShape> output_shapes_; + + class MapFunctionCallFrame : public CallFrameInterface { + public: + MapFunctionCallFrame(const std::vector<Tensor>& args, + const std::vector<TensorShape>& arg_shapes, + OpOutputList* output, OpKernel* kernel, size_t iter) + : args_(args), + arg_shapes_(arg_shapes), + output_(output), + kernel_(kernel), + iter_(iter) {} + + ~MapFunctionCallFrame() override {} + + size_t num_args() const override { return args_.size(); } + size_t num_retvals() const override { + return static_cast<size_t>(kernel_->num_outputs()); + } + + Status GetArg(int index, Tensor* val) const override { + if (index < 0 || index >= args_.size()) { + return errors::InvalidArgument( + "Mismatch in number of function inputs."); + } + bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1), + arg_shapes_.at(index)); + if (!result) { + return errors::Internal("GetArg failed."); + } else if (!val->IsAligned()) { + // Ensure alignment + *val = tensor::DeepCopy(*val); + } + + return Status::OK(); + } + + Status SetRetval(int index, const Tensor& val) override { + if (index < 0 || index >= kernel_->num_outputs()) { + return errors::InvalidArgument( + "Mismatch in number of function outputs."); + } + + if (val.dtype() != kernel_->output_type(index)) { + return errors::InvalidArgument( + "Mismatch in function return type and expected output type for " + "output: ", + index); + } + return batch_util::CopyElementToSlice(val, (*output_)[index], iter_); + } + + private: + const std::vector<Tensor>& args_; + const std::vector<TensorShape>& arg_shapes_; + OpOutputList* output_; + const OpKernel* kernel_; + const size_t iter_; + }; +}; // namespace + +REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc new file mode 100644 index 0000000000..cfac45dbc7 --- /dev/null +++ b/tensorflow/core/kernels/data/optional_ops.cc @@ -0,0 +1,270 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/kernels/data/optional_ops.h" + +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_op_registry.h" + +namespace tensorflow { +namespace { +const char kOptionalVariantTypeName[] = "tensorflow::data::Optional"; + +// An `OptionalVariant` can represent either an "actual value" (a tuple of +// tensors) or "none", and may be stored in a DT_VARIANT tensor. +class OptionalVariant { + public: + // Create an `OptionalVariant` with no actual value. + OptionalVariant() : values_(nullptr) {} + + // Create an `OptionalVariant` with the actual value given by the tuple of + // tensors in `values`. + explicit OptionalVariant(std::vector<Tensor> values) + : values_(new std::vector<Tensor>(std::move(values))) {} + + OptionalVariant(const OptionalVariant& other) : values_(other.values_) {} + + // Returns true if `this` represents an actual value. + bool has_value() const { return values_ != nullptr; } + + // REQUIRES: `this->has_value()` must be true. + const std::vector<Tensor>& get_values() const { + CHECK(values_) << "Tried to get values from an empty OptionalVariant"; + return *values_; + } + + // Implementations of the necessary methods for using `OptionalVariant` + // objects in DT_VARIANT tensors. + string TypeName() const { return kOptionalVariantTypeName; } + void Encode(VariantTensorData* data) const { + data->set_metadata(values_ != nullptr); + if (values_ != nullptr) { + for (const auto& t : *values_) { + *(data->add_tensors()) = t; + } + } + } + + bool Decode(const VariantTensorData& data) { + if (data.type_name() != TypeName()) { + return false; + } + bool has_value = false; + if (!data.get_metadata(&has_value)) { + return false; + } + if (has_value) { + values_.reset(new std::vector<Tensor>(data.tensors())); + } else { + values_.reset(); + } + return true; + } + + string DebugString() const { + if (values_) { + return strings::StrCat("OptionalVariant<", "values: (", + str_util::Join(*values_, ", ", + [](string* s, const Tensor& elem) { + *s = elem.DebugString(); + }), + ")>"); + } else { + return strings::StrCat("OptionalVariant<None>"); + } + } + + private: + std::shared_ptr<const std::vector<Tensor>> values_; +}; + +class OptionalNoneOp : public OpKernel { + public: + explicit OptionalNoneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + OP_REQUIRES_OK(ctx, WriteOptionalNoneToOutput(ctx, 0)); + } +}; + +class OptionalFromValueOp : public OpKernel { + public: + explicit OptionalFromValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + OpInputList components_input; + OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input)); + std::vector<Tensor> components; + components.reserve(components_input.size()); + for (const Tensor& component_t : components_input) { + components.push_back(component_t); + } + OP_REQUIRES_OK( + ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components))); + } +}; + +class OptionalHasValueOp : public OpKernel { + public: + explicit OptionalHasValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor* optional_input; + OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input)); + OP_REQUIRES(ctx, 
TensorShapeUtils::IsScalar(optional_input->shape()),
+                errors::InvalidArgument(
+                    "Input to OptionalHasValue must be a scalar tensor "
+                    "containing an OptionalVariant object."));
+    const OptionalVariant* optional =
+        optional_input->scalar<Variant>()().get<OptionalVariant>();
+    OP_REQUIRES(
+        ctx, optional != nullptr,
+        errors::InvalidArgument(
+            "Input to OptionalHasValue must be an OptionalVariant object."));
+    Tensor* result;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &result));
+    result->scalar<bool>()() = optional->has_value();
+  }
+};
+
+class OptionalGetValueOp : public OpKernel {
+ public:
+  explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* optional_input;
+    OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
+                errors::InvalidArgument(
+                    "Input to OptionalGetValue must be a scalar tensor "
+                    "containing an OptionalVariant object."));
+    const OptionalVariant* optional =
+        optional_input->scalar<Variant>()().get<OptionalVariant>();
+    OP_REQUIRES(
+        ctx, optional != nullptr,
+        errors::InvalidArgument(
+            "Input to OptionalGetValue must be an OptionalVariant object."));
+    OP_REQUIRES(
+        ctx, optional->has_value(),
+        errors::InvalidArgument("The given optional does not have a value."));
+    const auto& components = optional->get_values();
+    for (int i = 0; i < components.size(); ++i) {
+      OP_REQUIRES(
+          ctx, components[i].dtype() == output_types_[i],
+          errors::InvalidArgument(
+              "The given optional does not match the expected type for "
+              "component ",
+              i, ". Expected: ", DataTypeString(output_types_[i]),
+              ". Actual: ", DataTypeString(components[i].dtype()), "."));
+      OP_REQUIRES(ctx,
+                  output_shapes_[i].IsCompatibleWith(components[i].shape()),
+                  errors::InvalidArgument(
+                      "The given optional does not match the expected shape "
+                      "for component ",
+                      i, ". Expected: ", output_shapes_[i].DebugString(),
+                      ". 
Actual: ", components[i].shape().DebugString(), ".")); + ctx->set_output(i, components[i]); + } + } + + private: + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; +}; + +REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_CPU), + OptionalNoneOp); +REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_GPU), + OptionalNoneOp); +REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE_CPU), + OptionalFromValueOp); +REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE_GPU), + OptionalFromValueOp); + +REGISTER_KERNEL_BUILDER(Name("OptionalHasValue").Device(DEVICE_CPU), + OptionalHasValueOp); +REGISTER_KERNEL_BUILDER( + Name("OptionalHasValue").Device(DEVICE_GPU).HostMemory("has_value"), + OptionalHasValueOp); +REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_CPU), + OptionalGetValueOp); +REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_GPU), + OptionalGetValueOp); + +static Status OptionalDeviceCopy( + const OptionalVariant& from, OptionalVariant* to, + const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { + if (from.has_value()) { + const std::vector<Tensor>& from_values = from.get_values(); + std::vector<Tensor> to_values; + to_values.reserve(from_values.size()); + for (const Tensor& t : from_values) { + if (DMAHelper::CanUseDMA(&t)) { + Tensor tmp(t.dtype()); + TF_RETURN_IF_ERROR(copy(t, &tmp)); + to_values.push_back(std::move(tmp)); + } else { + to_values.push_back(t); + } + } + *to = OptionalVariant(std::move(to_values)); + } else { + *to = from; + } + return Status::OK(); +} + +#define REGISTER_OPTIONAL_COPY(DIRECTION) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ + OptionalVariant, DIRECTION, kOptionalVariantTypeName, \ + OptionalDeviceCopy) + +REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); +REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); +REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE); + +REGISTER_UNARY_VARIANT_DECODE_FUNCTION(OptionalVariant, + kOptionalVariantTypeName); + +} // namespace + +Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index, + std::vector<Tensor> value) { + OptionalVariant v(std::move(value)); + Tensor* variant_t; + AllocatorAttributes cpu_alloc; + cpu_alloc.set_on_host(true); + TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}), + &variant_t, cpu_alloc)); + variant_t->scalar<Variant>()() = v; + return Status::OK(); +} + +Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) { + OptionalVariant v; + Tensor* variant_t; + AllocatorAttributes cpu_alloc; + cpu_alloc.set_on_host(true); + TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}), + &variant_t, cpu_alloc)); + variant_t->scalar<Variant>()() = v; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h new file mode 100644 index 0000000000..6f25567678 --- /dev/null +++ b/tensorflow/core/kernels/data/optional_ops.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ + +#include <vector> + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant_tensor_data.h" + +namespace tensorflow { + +// Stores a DT_VARIANT value representing an Optional with the given value +// in the `output_index`^th output of the given kernel execution context. +Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index, + std::vector<Tensor> value); + +// Stores a DT_VARIANT value representing an Optional with no value +// in the `output_index`^th output of the given kernel execution context. +Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 15f3dc3b1d..b736b33c2e 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/parallel_map_iterator.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/random/random.h" @@ -87,8 +88,16 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::ParallelMap")})); + auto map_func = [this](IteratorContext* ctx, + std::vector<Tensor> input_element, + std::vector<Tensor>* result, StatusCallback done) { + captured_func_->RunAsync(ctx, std::move(input_element), result, + std::move(done)); + }; + + return NewParallelMapIterator( + {this, strings::StrCat(prefix, "::ParallelMap")}, input_, + std::move(map_func), num_parallel_calls_); } const DataTypeVector& output_dtypes() const override { @@ -148,279 +157,6 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { } private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} - - ~Iterator() override { - // TODO(mrry): Replace this cancellation logic with a - // CancellationManager. The syntax would be more heavyweight, - // but it would be possible to thread a cancellation manager - // through the IteratorContext to upstream, - // potentially-blocking iterators, when we add these. - mutex_lock l(mu_); - // Cancel the runner thread. - cancelled_ = true; - cond_var_.notify_all(); - // Wait for all in-flight calls to complete. 
- while (num_calls_ > 0) { - cond_var_.wait(l); - } - } - - Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - std::shared_ptr<InvocationResult> result; - { - mutex_lock l(mu_); - EnsureRunnerThreadStarted(ctx); - while (invocation_results_.empty()) { - cond_var_.wait(l); - } - std::swap(result, invocation_results_.front()); - invocation_results_.pop_front(); - } - cond_var_.notify_all(); - result->notification.WaitForNotification(); - return ProcessResult(result, out_tensors, end_of_sequence); - } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - // Wait for all in-flight calls to complete. - while (num_calls_ > 0) { - cond_var_.wait(l); - } - CHECK_EQ(num_calls_, 0); - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name("invocation_results.size"), invocation_results_.size())); - for (size_t i = 0; i < invocation_results_.size(); i++) { - std::shared_ptr<InvocationResult> result = invocation_results_[i]; - TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status)); - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat("invocation_results[", i, "].size")), - result->return_values.size())); - for (size_t j = 0; j < result->return_values.size(); j++) { - TF_RETURN_IF_ERROR(writer->WriteTensor( - full_name( - strings::StrCat("invocation_results[", i, "][", j, "]")), - result->return_values[j])); - } - if (result->end_of_input) { - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat("invocation_results[", i, - "].end_of_input")), - "")); - } - } - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); - int64 invocation_results_size; - TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name("invocation_results.size"), &invocation_results_size)); - for (size_t i = 0; i < invocation_results_size; i++) { - std::shared_ptr<InvocationResult> result(new InvocationResult()); - invocation_results_.push_back(result); - TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status)); - size_t num_return_values; - { - int64 size; - TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name(strings::StrCat("invocation_results[", i, "].size")), - &size)); - num_return_values = static_cast<size_t>(size); - if (num_return_values != size) { - return errors::InvalidArgument(strings::StrCat( - full_name( - strings::StrCat("invocation_results[", i, "].size")), - ": ", size, " is not a valid value of type size_t.")); - } - } - result->return_values.reserve(num_return_values); - for (size_t j = 0; j < num_return_values; j++) { - result->return_values.emplace_back(); - TF_RETURN_IF_ERROR( - reader->ReadTensor(full_name(strings::StrCat( - "invocation_results[", i, "][", j, "]")), - &result->return_values.back())); - } - result->end_of_input = reader->Contains(full_name( - strings::StrCat("invocation_results[", i, "].end_of_input"))); - result->notification.Notify(); - } - return Status::OK(); - } - - private: - struct InvocationResult { - Notification notification; - Status status; - std::vector<Tensor> return_values; - bool end_of_input; - }; - - void EnsureRunnerThreadStarted(IteratorContext* ctx) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if 
(!runner_thread_) { - std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); - runner_thread_.reset(ctx->env()->StartThread( - {}, "runner_thread", - std::bind(&Iterator::RunnerThread, this, ctx_copy))); - } - } - - void CallCompleted(const std::shared_ptr<InvocationResult>& result) - LOCKS_EXCLUDED(mu_) { - { - mutex_lock l(mu_); - num_calls_--; - } - result->notification.Notify(); - cond_var_.notify_all(); - } - - void CallFunction(const std::shared_ptr<IteratorContext>& ctx, - const std::shared_ptr<InvocationResult>& result) - LOCKS_EXCLUDED(mu_) { - // Get the next input element. - std::vector<Tensor> input_element; - result->status = input_impl_->GetNext(ctx.get(), &input_element, - &result->end_of_input); - if (result->end_of_input || !result->status.ok()) { - CallCompleted(result); - return; - } - - // Call `func_(input_element)`, store the result in - // `result->return_values`, and notify `result->notification` to unblock - // a consumer. - auto done = [this, result](Status status) { - result->status.Update(status); - CallCompleted(result); - }; - dataset()->captured_func_->RunAsync(ctx.get(), std::move(input_element), - &result->return_values, done); - } - - int64 MaxInvocationResults() { return dataset()->num_parallel_calls_; } - - Status ProcessResult(const std::shared_ptr<InvocationResult>& result, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) { - if (!result->end_of_input && result->status.ok()) { - *out_tensors = std::move(result->return_values); - *end_of_sequence = false; - return Status::OK(); - } - if (errors::IsOutOfRange(result->status)) { - // `f` may deliberately raise `errors::OutOfRange` to indicate that we - // should terminate the iteration early. - *end_of_sequence = true; - return Status::OK(); - } - *end_of_sequence = result->end_of_input; - return result->status; - } - - void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { - std::vector<std::shared_ptr<InvocationResult>> new_calls; - new_calls.reserve(dataset()->num_parallel_calls_); - while (true) { - { - mutex_lock l(mu_); - while (!cancelled_ && - (num_calls_ >= dataset()->num_parallel_calls_ || - invocation_results_.size() >= MaxInvocationResults())) { - cond_var_.wait(l); - } - if (cancelled_) { - return; - } - while (num_calls_ < dataset()->num_parallel_calls_ && - invocation_results_.size() < MaxInvocationResults()) { - invocation_results_.emplace_back(new InvocationResult()); - new_calls.push_back(invocation_results_.back()); - num_calls_++; - } - } - cond_var_.notify_all(); - for (const auto& call : new_calls) { - CallFunction(ctx, call); - } - new_calls.clear(); - } - } - - Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, - const Status& status) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - TF_RETURN_IF_ERROR(writer->WriteScalar( - CodeKey(index), static_cast<int64>(status.code()))); - if (!status.ok()) { - TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), - status.error_message())); - } - return Status::OK(); - } - - Status ReadStatusLocked(IteratorStateReader* reader, size_t index, - Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - int64 code_int; - TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); - error::Code code = static_cast<error::Code>(code_int); - - if (code != error::Code::OK) { - string error_message; - TF_RETURN_IF_ERROR( - reader->ReadScalar(ErrorMessageKey(index), &error_message)); - *status = Status(code, error_message); - } else { - *status = Status::OK(); - } - return Status::OK(); - } - - string 
CodeKey(size_t index) { - return full_name( - strings::StrCat("invocation_results[", index, "].code")); - } - - string ErrorMessageKey(size_t index) { - return full_name( - strings::StrCat("invocation_results[", index, "].error_message")); - } - - // Used for coordination between the main thread and the runner thread. - mutex mu_; - // Used for coordination between the main thread and the runner thread. In - // particular, the runner thread should only schedule new calls when the - // number of in-flight calls is less than the user specified level of - // parallelism and there are slots available in the `invocation_results_` - // buffer. - condition_variable cond_var_; - // Counts the number of outstanding calls. - int64 num_calls_ GUARDED_BY(mu_) = 0; - std::unique_ptr<IteratorBase> input_impl_; - // Buffer for storing the invocation results. - std::deque<std::shared_ptr<InvocationResult>> invocation_results_ - GUARDED_BY(mu_); - std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); - bool cancelled_ GUARDED_BY(mu_) = false; - }; - const DatasetBase* const input_; const NameAttrList func_; const int32 num_parallel_calls_; diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc new file mode 100644 index 0000000000..10549df25e --- /dev/null +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -0,0 +1,318 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/data/parallel_map_iterator.h" + +#include <deque> +#include <functional> +#include <utility> +#include <vector> + +namespace tensorflow { +namespace { + +class ParallelMapIterator : public DatasetBaseIterator { + public: + explicit ParallelMapIterator( + const typename DatasetBaseIterator::BaseParams& params, + const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, + int32 num_parallel_calls) + : DatasetBaseIterator(params), + input_dataset_(input_dataset), + map_func_(std::move(map_func)), + num_parallel_calls_(num_parallel_calls) {} + + ~ParallelMapIterator() override { + // TODO(mrry): Replace this cancellation logic with a + // CancellationManager. The syntax would be more heavyweight, + // but it would be possible to thread a cancellation manager + // through the IteratorContext to upstream, + // potentially-blocking iterators, when we add these. + mutex_lock l(mu_); + // Cancel the runner thread. + cancelled_ = true; + cond_var_.notify_all(); + // Wait for all in-flight calls to complete. 
+ while (num_calls_ > 0) { + cond_var_.wait(l); + } + } + + Status Initialize(IteratorContext* ctx) override { + return input_dataset_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + std::shared_ptr<InvocationResult> result; + { + mutex_lock l(mu_); + EnsureRunnerThreadStarted(ctx); + while (invocation_results_.empty()) { + cond_var_.wait(l); + } + std::swap(result, invocation_results_.front()); + invocation_results_.pop_front(); + } + cond_var_.notify_all(); + result->notification.WaitForNotification(); + return ProcessResult(result, out_tensors, end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + // Wait for all in-flight calls to complete. + while (num_calls_ > 0) { + cond_var_.wait(l); + } + CHECK_EQ(num_calls_, 0); + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("invocation_results.size"), + invocation_results_.size())); + for (size_t i = 0; i < invocation_results_.size(); i++) { + std::shared_ptr<InvocationResult> result = invocation_results_[i]; + TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + result->return_values.size())); + for (size_t j = 0; j < result->return_values.size(); j++) { + TF_RETURN_IF_ERROR( + writer->WriteTensor(full_name(strings::StrCat( + "invocation_results[", i, "][", j, "]")), + result->return_values[j])); + } + if (result->end_of_input) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name( + strings::StrCat("invocation_results[", i, "].end_of_input")), + "")); + } + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + int64 invocation_results_size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name("invocation_results.size"), &invocation_results_size)); + for (size_t i = 0; i < invocation_results_size; i++) { + std::shared_ptr<InvocationResult> result(new InvocationResult()); + invocation_results_.push_back(result); + TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status)); + size_t num_return_values; + { + int64 size; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name(strings::StrCat( + "invocation_results[", i, "].size")), + &size)); + num_return_values = static_cast<size_t>(size); + if (num_return_values != size) { + return errors::InvalidArgument(strings::StrCat( + full_name( + strings::StrCat("invocation_results[", i, "].size")), + ": ", size, " is not a valid value of type size_t.")); + } + } + result->return_values.reserve(num_return_values); + for (size_t j = 0; j < num_return_values; j++) { + result->return_values.emplace_back(); + TF_RETURN_IF_ERROR( + reader->ReadTensor(full_name(strings::StrCat( + "invocation_results[", i, "][", j, "]")), + &result->return_values.back())); + } + result->end_of_input = reader->Contains(full_name( + strings::StrCat("invocation_results[", i, "].end_of_input"))); + result->notification.Notify(); + } + return Status::OK(); + } + + private: + struct InvocationResult { + Notification notification; + Status status; + std::vector<Tensor> return_values; + bool end_of_input; + }; + + void EnsureRunnerThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if 
(!runner_thread_) { + std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); + runner_thread_.reset(ctx->env()->StartThread( + {}, "runner_thread", + std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy))); + } + } + + void CallCompleted(const std::shared_ptr<InvocationResult>& result) + LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + num_calls_--; + } + result->notification.Notify(); + cond_var_.notify_all(); + } + + void CallFunction(const std::shared_ptr<IteratorContext>& ctx, + const std::shared_ptr<InvocationResult>& result) + LOCKS_EXCLUDED(mu_) { + // Get the next input element. + std::vector<Tensor> input_element; + result->status = + input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input); + if (result->end_of_input || !result->status.ok()) { + CallCompleted(result); + return; + } + + // Call `func_(input_element)`, store the result in + // `result->return_values`, and notify `result->notification` to unblock + // a consumer. + auto done = [this, result](Status status) { + result->status.Update(status); + CallCompleted(result); + }; + + map_func_(ctx.get(), std::move(input_element), &result->return_values, + std::move(done)); + } + + int64 MaxInvocationResults() { return num_parallel_calls_; } + + Status ProcessResult(const std::shared_ptr<InvocationResult>& result, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) { + if (!result->end_of_input && result->status.ok()) { + *out_tensors = std::move(result->return_values); + *end_of_sequence = false; + return Status::OK(); + } + if (errors::IsOutOfRange(result->status)) { + // `f` may deliberately raise `errors::OutOfRange` to indicate that we + // should terminate the iteration early. + *end_of_sequence = true; + return Status::OK(); + } + *end_of_sequence = result->end_of_input; + return result->status; + } + + void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { + std::vector<std::shared_ptr<InvocationResult>> new_calls; + new_calls.reserve(num_parallel_calls_); + while (true) { + { + mutex_lock l(mu_); + while (!cancelled_ && + (num_calls_ >= num_parallel_calls_ || + invocation_results_.size() >= MaxInvocationResults())) { + cond_var_.wait(l); + } + if (cancelled_) { + return; + } + while (num_calls_ < num_parallel_calls_ && + invocation_results_.size() < MaxInvocationResults()) { + invocation_results_.emplace_back(new InvocationResult()); + new_calls.push_back(invocation_results_.back()); + num_calls_++; + } + } + cond_var_.notify_all(); + for (const auto& call : new_calls) { + CallFunction(ctx, call); + } + new_calls.clear(); + } + } + + Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, + const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code()))); + if (!status.ok()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(ErrorMessageKey(index), status.error_message())); + } + return Status::OK(); + } + + Status ReadStatusLocked(IteratorStateReader* reader, size_t index, + Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 code_int; + TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); + error::Code code = static_cast<error::Code>(code_int); + + if (code != error::Code::OK) { + string error_message; + TF_RETURN_IF_ERROR( + reader->ReadScalar(ErrorMessageKey(index), &error_message)); + *status = Status(code, error_message); + } else { + *status = Status::OK(); + } + return Status::OK(); + } + + string CodeKey(size_t index) { + return full_name( + 
strings::StrCat("invocation_results[", index, "].code")); + } + + string ErrorMessageKey(size_t index) { + return full_name( + strings::StrCat("invocation_results[", index, "].error_message")); + } + + const DatasetBase* const input_dataset_; // Not owned. + const ParallelMapIteratorFunction map_func_; + const int32 num_parallel_calls_; + // Used for coordination between the main thread and the runner thread. + mutex mu_; + // Used for coordination between the main thread and the runner thread. In + // particular, the runner thread should only schedule new calls when the + // number of in-flight calls is less than the user specified level of + // parallelism and there are slots available in the `invocation_results_` + // buffer. + condition_variable cond_var_; + // Counts the number of outstanding calls. + int64 num_calls_ GUARDED_BY(mu_) = 0; + std::unique_ptr<IteratorBase> input_impl_; + // Buffer for storing the invocation results. + std::deque<std::shared_ptr<InvocationResult>> invocation_results_ + GUARDED_BY(mu_); + std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); + bool cancelled_ GUARDED_BY(mu_) = false; +}; + +} // namespace + +std::unique_ptr<IteratorBase> NewParallelMapIterator( + const DatasetBaseIterator::BaseParams& params, + const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, + int32 num_parallel_calls) { + return std::unique_ptr<IteratorBase>(new ParallelMapIterator( + params, input_dataset, std::move(map_func), num_parallel_calls)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h new file mode 100644 index 0000000000..2ce36c3869 --- /dev/null +++ b/tensorflow/core/kernels/data/parallel_map_iterator.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_ + +#include <memory> + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { + +// A function that transforms elements of one dataset into another +// asynchronously. The arguments are: +// 1. An `IteratorContext*` for the context in which the function should +// execute. +// 2. A `std::vector<Tensor>` containing the input element. +// 3. A `std::vector<Tensor>*` to which the function will write the result. +// 4. A `StatusCallback` that should be invoked when the function is complete. +using ParallelMapIteratorFunction = + std::function<void(IteratorContext*, std::vector<Tensor>, + std::vector<Tensor>*, StatusCallback)>; + +// Returns a new iterator that applies `map_func` to the elements of +// `input_dataset` using the given degree of parallelism. 
+std::unique_ptr<IteratorBase> NewParallelMapIterator( + const DatasetBaseIterator::BaseParams& params, + const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, + int32 num_parallel_calls); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_ diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index cc16108dce..9000842840 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -14,347 +14,335 @@ limitations under the License. ==============================================================================*/ #include <deque> +#include "tensorflow/core/kernels/data/prefetch_dataset_op.h" + #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/data/dataset.h" -#include "tensorflow/core/kernels/data/prefetch_autotuner.h" #include "tensorflow/core/lib/core/error_codes.pb.h" namespace tensorflow { -namespace { - // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. -class PrefetchDatasetOp : public UnaryDatasetOpKernel { +class PrefetchDatasetOp::Dataset : public GraphDatasetBase { public: - explicit PrefetchDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) {} - - protected: - void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) override { - int64 buffer_size; - OP_REQUIRES_OK( - ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); - OP_REQUIRES(ctx, - buffer_size >= 0 || buffer_size == PrefetchAutotuner::kAutoTune, - errors::InvalidArgument("buffer_size must be >= 0")); - - *output = new Dataset(ctx, input, buffer_size); + Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size) + : GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) { + input_->Ref(); } - private: - class Dataset : public GraphDatasetBase { - public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size) - : GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) { - input_->Ref(); - } + ~Dataset() override { input_->Unref(); } - ~Dataset() override { input_->Unref(); } + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Prefetch")})); + } - std::unique_ptr<IteratorBase> MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Prefetch")})); - } + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } - const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); - } - const std::vector<PartialTensorShape>& output_shapes() const override { - return input_->output_shapes(); - } + const std::vector<PartialTensorShape>& output_shapes() const override { + return input_->output_shapes(); + } - string DebugString() const override { return "PrefetchDatasetOp::Dataset"; } + string DebugString() const override { return "PrefetchDatasetOp::Dataset"; } - protected: - Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); - Node* buffer_size = 
nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); - TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_graph_node, buffer_size}, output)); - return Status::OK(); - } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + Node* buffer_size = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_graph_node, buffer_size}, output)); + return Status::OK(); + } - private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params), - auto_tuner_(params.dataset->buffer_size_) {} - - ~Iterator() override { - // Signal the prefetch thread to terminate it. We will then - // join that thread when we delete `this->prefetch_thread_`. - // - // TODO(mrry): Replace this cancellation logic with a - // CancellationManager. The syntax would be more heavyweight, - // but it would be possible to thread a cancellation manager - // through the IteratorContext to upstream, - // potentially-blocking iterators, when we add these. - { - mutex_lock l(mu_); - cancelled_ = true; - cond_var_.notify_all(); - } - } + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + auto_tuner_(params.dataset->buffer_size_) {} - Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + ~Iterator() override { + // Signal the prefetch thread to terminate it. We will then + // join that thread when we delete `this->prefetch_thread_`. + // + // TODO(mrry): Replace this cancellation logic with a + // CancellationManager. The syntax would be more heavyweight, + // but it would be possible to thread a cancellation manager + // through the IteratorContext to upstream, + // potentially-blocking iterators, when we add these. + { + mutex_lock l(mu_); + cancelled_ = true; + cond_var_.notify_all(); } + } - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx)); - // Wait until the next element in the buffer has been - // produced, or we are shutting down. - while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ && - auto_tuner_.buffer_limit() != 0) { - auto_tuner_.RecordEmpty(); - cond_var_.wait(l); - } + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } - if (cancelled_) { - return errors::Cancelled( - "PrefetchDatasetOp::Dataset::Iterator::GetNext"); - } + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx)); + // Wait until the next element in the buffer has been + // produced, or we are shutting down. 
+ while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ && + auto_tuner_.buffer_limit() != 0) { + auto_tuner_.RecordEmpty(); + cond_var_.wait(l); + } - if (!buffer_.empty()) { - return Consume(out_tensors, end_of_sequence); - } + if (cancelled_) { + return errors::Cancelled( + "PrefetchDatasetOp::Dataset::Iterator::GetNext"); + } - if (prefetch_thread_finished_) { - *end_of_sequence = true; - return Status::OK(); - } + if (!buffer_.empty()) { + return Consume(out_tensors, end_of_sequence); + } - DCHECK_EQ(auto_tuner_.buffer_limit(), 0); + if (prefetch_thread_finished_) { + *end_of_sequence = true; + return Status::OK(); } - mutex_lock parent_l(parent_mu_); - mutex_lock l(mu_); - return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + DCHECK_EQ(auto_tuner_.buffer_limit(), 0); } - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - // Acquire both locks to ensure that the prefetch thread and - // all GetNext threads are blocked. - mutex_lock parent_l(parent_mu_); - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("buffer_size"), buffer_.size())); - for (size_t i = 0; i < buffer_.size(); i++) { - auto& buffer_element = buffer_[i]; - TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status)); - if (buffer_element.status.ok()) { - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat("buffer[", i, "].size")), - buffer_element.value.size())); - for (size_t j = 0; j < buffer_element.value.size(); j++) { - TF_RETURN_IF_ERROR(writer->WriteTensor( - full_name(strings::StrCat("buffer[", i, "][", j, "]")), - buffer_element.value[j])); - } + mutex_lock parent_l(parent_mu_); + mutex_lock l(mu_); + return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + // Acquire both locks to ensure that the prefetch thread and + // all GetNext threads are blocked. 
+ mutex_lock parent_l(parent_mu_); + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("buffer_size"), buffer_.size())); + for (size_t i = 0; i < buffer_.size(); i++) { + auto& buffer_element = buffer_[i]; + TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status)); + if (buffer_element.status.ok()) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("buffer[", i, "].size")), + buffer_element.value.size())); + for (size_t j = 0; j < buffer_element.value.size(); j++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("buffer[", i, "][", j, "]")), + buffer_element.value[j])); } } - return Status::OK(); } + return Status::OK(); + } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock parent_l(parent_mu_); - mutex_lock l(mu_); - buffer_.clear(); - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); - size_t buffer_size; - { - int64 temp; - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("buffer_size"), &temp)); - buffer_size = static_cast<size_t>(temp); - } - for (size_t i = 0; i < buffer_size; i++) { - buffer_.emplace_back(); - auto& buffer_element = buffer_.back(); - TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status)); - if (buffer_element.status.ok()) { - size_t value_size; - { - int64 temp; - TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name(strings::StrCat("buffer[", i, "].size")), &temp)); - value_size = static_cast<size_t>(temp); - } - buffer_element.value.reserve(value_size); - for (size_t j = 0; j < value_size; j++) { - buffer_element.value.emplace_back(); - TF_RETURN_IF_ERROR(reader->ReadTensor( - full_name(strings::StrCat("buffer[", i, "][", j, "]")), - &buffer_element.value.back())); - } + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock parent_l(parent_mu_); + mutex_lock l(mu_); + buffer_.clear(); + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + size_t buffer_size; + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("buffer_size"), &temp)); + buffer_size = static_cast<size_t>(temp); + } + for (size_t i = 0; i < buffer_size; i++) { + buffer_.emplace_back(); + auto& buffer_element = buffer_.back(); + TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status)); + if (buffer_element.status.ok()) { + size_t value_size; + { + int64 temp; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("buffer[", i, "].size")), &temp)); + value_size = static_cast<size_t>(temp); + } + buffer_element.value.reserve(value_size); + for (size_t j = 0; j < value_size; j++) { + buffer_element.value.emplace_back(); + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("buffer[", i, "][", j, "]")), + &buffer_element.value.back())); } } - return Status::OK(); } + return Status::OK(); + } - private: - // A buffer element comprises a status and (if that status is - // OK) a vector of tensors, representing an element of the input dataset. - struct BufferElement { - // The producer sets `status` if getting the input element fails. - Status status; - // The buffered data element. - std::vector<Tensor> value; - }; - - Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // A new element is available. Forward the status from computing it, and - // (if we successfully got an element) the output values. 
- Status s = buffer_.front().status; - if (s.ok()) { - *out_tensors = std::move(buffer_.front().value); - } - buffer_.pop_front(); - *end_of_sequence = false; - - // Wake the prefetch thread, in case it has been waiting for space - // in the buffer. Also wake up threads from other calls to GetNext. - // - // TODO(mrry): Consider using different condition variables for - // GetNext and Prefetch. - cond_var_.notify_all(); - return s; - } + private: + // A buffer element comprises a status and (if that status is + // OK) a vector of tensors, representing an element of the input dataset. + struct BufferElement { + // The producer sets `status` if getting the input element fails. + Status status; + // The buffered data element. + std::vector<Tensor> value; + }; - Status EnsurePrefetchThreadStarted(IteratorContext* ctx) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!prefetch_thread_) { - prefetch_thread_.reset( - ctx->env()->StartThread({}, "prefetch_thread", - std::bind(&Iterator::PrefetchThread, this, - new IteratorContext(*ctx)))); - } - return Status::OK(); + Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // A new element is available. Forward the status from computing it, and + // (if we successfully got an element) the output values. + Status s = buffer_.front().status; + if (s.ok()) { + *out_tensors = std::move(buffer_.front().value); } + buffer_.pop_front(); + *end_of_sequence = false; - // Prefetches elements of the input, storing results in an internal - // buffer. + // Wake the prefetch thread, in case it has been waiting for space + // in the buffer. Also wake up threads from other calls to GetNext. // - // It owns the iterator context passed to it. - void PrefetchThread(IteratorContext* ctx) { - std::unique_ptr<IteratorContext> cleanup(ctx); - while (true) { - std::vector<Tensor> value; + // TODO(mrry): Consider using different condition variables for + // GetNext and Prefetch. + cond_var_.notify_all(); + return s; + } - // 1. Wait for a slot in the buffer. - { - mutex_lock l(mu_); - while (!cancelled_ && - buffer_.size() >= auto_tuner_.buffer_limit()) { - cond_var_.wait(l); - } - - if (cancelled_) { - return; - } - } + Status EnsurePrefetchThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!prefetch_thread_) { + prefetch_thread_.reset( + ctx->env()->StartThread({}, "prefetch_thread", + std::bind(&Iterator::PrefetchThread, this, + new IteratorContext(*ctx)))); + } + return Status::OK(); + } - // 2. Read the next element. - // Acquire the parent lock since we will be reading an element - // from the input iterator. Note that we do not wish to release - // this lock till we have added the fetched element to the - // `buffer_` else there will be local state that may be missed - // by SaveInternal. - mutex_lock parent_l(parent_mu_); - bool end_of_sequence; - BufferElement buffer_element; - buffer_element.status = input_impl_->GetNext( - ctx, &buffer_element.value, &end_of_sequence); - if (buffer_element.status.ok() && end_of_sequence) { - mutex_lock l(mu_); - prefetch_thread_finished_ = true; - cond_var_.notify_all(); - return; + // Prefetches elements of the input, storing results in an internal + // buffer. + // + // It owns the iterator context passed to it. + void PrefetchThread(IteratorContext* ctx) { + std::unique_ptr<IteratorContext> cleanup(ctx); + while (true) { + std::vector<Tensor> value; + + // 1. Wait for a slot in the buffer. 
+ { + mutex_lock l(mu_); + while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) { + cond_var_.wait(l); } - // 3. Signal that the element has been produced. - { - mutex_lock l(mu_); - buffer_.push_back(std::move(buffer_element)); - cond_var_.notify_all(); + if (cancelled_) { + return; } } - } - Status WriteStatus(IteratorStateWriter* writer, size_t index, - const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - TF_RETURN_IF_ERROR(writer->WriteScalar( - CodeKey(index), static_cast<int64>(status.code()))); - if (!status.ok()) { - TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), - status.error_message())); + // 2. Read the next element. + // Acquire the parent lock since we will be reading an element + // from the input iterator. Note that we do not wish to release + // this lock till we have added the fetched element to the + // `buffer_` else there will be local state that may be missed + // by SaveInternal. + mutex_lock parent_l(parent_mu_); + bool end_of_sequence; + BufferElement buffer_element; + buffer_element.status = + input_impl_->GetNext(ctx, &buffer_element.value, &end_of_sequence); + if (buffer_element.status.ok() && end_of_sequence) { + mutex_lock l(mu_); + prefetch_thread_finished_ = true; + cond_var_.notify_all(); + return; } - return Status::OK(); - } - Status ReadStatus(IteratorStateReader* reader, size_t index, - Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - int64 code_int; - TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); - error::Code code = static_cast<error::Code>(code_int); - - if (code != error::Code::OK) { - string error_message; - TF_RETURN_IF_ERROR( - reader->ReadScalar(ErrorMessageKey(index), &error_message)); - *status = Status(code, error_message); - } else { - *status = Status::OK(); + // 3. Signal that the element has been produced. + { + mutex_lock l(mu_); + buffer_.push_back(std::move(buffer_element)); + cond_var_.notify_all(); } - return Status::OK(); } + } - string CodeKey(size_t index) { - return full_name(strings::StrCat("status[", index, "].code")); + Status WriteStatus(IteratorStateWriter* writer, size_t index, + const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + CodeKey(index), static_cast<int64>(status.code()))); + if (!status.ok()) { + TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), + status.error_message())); } + return Status::OK(); + } - string ErrorMessageKey(size_t index) { - return full_name(strings::StrCat("status[", index, "].error_message")); + Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 code_int; + TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); + error::Code code = static_cast<error::Code>(code_int); + + if (code != error::Code::OK) { + string error_message; + TF_RETURN_IF_ERROR( + reader->ReadScalar(ErrorMessageKey(index), &error_message)); + *status = Status(code, error_message); + } else { + *status = Status::OK(); } + return Status::OK(); + } - // This mutex is used to ensure exclusivity between multiple threads - // reading/writing this iterator's local state. - mutex mu_; - // This mutex is used to ensure exclusivity between multiple threads - // accessing the parent iterator. We keep this separate from `mu_` to - // allow prefetching to run in parallel with GetNext calls. 
- mutex parent_mu_ ACQUIRED_BEFORE(mu_); - std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_); - condition_variable cond_var_; - PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); - std::deque<BufferElement> buffer_ GUARDED_BY(mu_); - std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_); - bool cancelled_ GUARDED_BY(mu_) = false; - bool prefetch_thread_finished_ GUARDED_BY(mu_) = false; - }; + string CodeKey(size_t index) { + return full_name(strings::StrCat("status[", index, "].code")); + } + + string ErrorMessageKey(size_t index) { + return full_name(strings::StrCat("status[", index, "].error_message")); + } - const DatasetBase* const input_; - const int64 buffer_size_; + // This mutex is used to ensure exclusivity between multiple threads + // reading/writing this iterator's local state. + mutex mu_; + // This mutex is used to ensure exclusivity between multiple threads + // accessing the parent iterator. We keep this separate from `mu_` to + // allow prefetching to run in parallel with GetNext calls. + mutex parent_mu_ ACQUIRED_BEFORE(mu_); + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_); + condition_variable cond_var_; + PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); + std::deque<BufferElement> buffer_ GUARDED_BY(mu_); + std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_); + bool cancelled_ GUARDED_BY(mu_) = false; + bool prefetch_thread_finished_ GUARDED_BY(mu_) = false; }; + const DatasetBase* const input_; + const int64 buffer_size_; }; +void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) { + int64 buffer_size; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); + OP_REQUIRES(ctx, + buffer_size >= 0 || buffer_size == PrefetchAutotuner::kAutoTune, + errors::InvalidArgument("buffer_size must be >= 0")); + + *output = new Dataset(ctx, input, buffer_size); +} + REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU), PrefetchDatasetOp); REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") @@ -363,6 +351,4 @@ REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") .HostMemory("input_dataset") .HostMemory("handle"), PrefetchDatasetOp); -} // namespace - } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h b/tensorflow/core/kernels/data/prefetch_dataset_op.h new file mode 100644 index 0000000000..c40c4b00da --- /dev/null +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_ + +#include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/prefetch_autotuner.h" + +namespace tensorflow { + +class PrefetchDatasetOp : public UnaryDatasetOpKernel { + public: + explicit PrefetchDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc index 754c32b6ca..58ec3d4495 100644 --- a/tensorflow/core/kernels/data/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc @@ -390,7 +390,7 @@ class FeatureStatsDatasetOp : public UnaryDatasetOpKernel { for (const auto& feature_list : example.feature_lists().feature_list()) { - stats_aggregator->IncrementCounter("feature_lists_count", "reainer", + stats_aggregator->IncrementCounter("feature_lists_count", "trainer", 1); for (const auto& feature : feature_list.second.feature()) { feature_values_list_size_sum += AddStatsFeatureValues(feature); diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index d5c33c0188..bfdabc3a9f 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -16,13 +16,13 @@ limitations under the License. #include <deque> #include <vector> +#include "tensorflow/core/kernels/function_ops.h" + #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/memory_types.h" -#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/gradients.h" @@ -33,64 +33,40 @@ limitations under the License. namespace tensorflow { -static const char* const kArgOp = FunctionLibraryDefinition::kArgOp; -static const char* const kRetOp = FunctionLibraryDefinition::kRetOp; static const char* const kGradientOp = FunctionLibraryDefinition::kGradientOp; -class ArgOp : public OpKernel { - public: - explicit ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); - } - - void Compute(OpKernelContext* ctx) override { - auto frame = ctx->call_frame(); - OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame")); - Tensor val; - OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val)); - OP_REQUIRES(ctx, val.dtype() == dtype_, - errors::InvalidArgument( - "Type mismatch: actual ", DataTypeString(val.dtype()), - " vs. 
expect ", DataTypeString(dtype_))); - ctx->set_output(0, val); - } - - bool IsExpensive() override { return false; } - - private: - int index_; - DataType dtype_; - - TF_DISALLOW_COPY_AND_ASSIGN(ArgOp); -}; - -class RetvalOp : public OpKernel { - public: - explicit RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); - } - - void Compute(OpKernelContext* ctx) override { - const Tensor& val = ctx->input(0); - OP_REQUIRES(ctx, val.dtype() == dtype_, - errors::InvalidArgument( - "Type mismatch: actual ", DataTypeString(val.dtype()), - " vs. expect ", DataTypeString(dtype_))); - auto frame = ctx->call_frame(); - OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame")); - OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val)); - } - - bool IsExpensive() override { return false; } - - private: - int index_; - DataType dtype_; - - TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); -}; +ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); +} + +void ArgOp::Compute(OpKernelContext* ctx) { + auto frame = ctx->call_frame(); + OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame")); + Tensor val; + OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val)); + OP_REQUIRES(ctx, val.dtype() == dtype_, + errors::InvalidArgument("Type mismatch: actual ", + DataTypeString(val.dtype()), + " vs. expect ", DataTypeString(dtype_))); + ctx->set_output(0, val); +} + +RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_)); +} + +void RetvalOp::Compute(OpKernelContext* ctx) { + const Tensor& val = ctx->input(0); + OP_REQUIRES(ctx, val.dtype() == dtype_, + errors::InvalidArgument("Type mismatch: actual ", + DataTypeString(val.dtype()), + " vs. 
expect ", DataTypeString(dtype_))); + auto frame = ctx->call_frame(); + OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame")); + OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val)); +} REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp); REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp); @@ -304,123 +280,105 @@ REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL), #endif // TENSORFLOW_USE_SYCL -class RemoteCallOp : public AsyncOpKernel { - public: - explicit RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { - OP_REQUIRES_OK(ctx, - ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_)); - } - - ~RemoteCallOp() override {} - - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - FunctionLibraryRuntime* lib = ctx->function_library(); - OP_REQUIRES_ASYNC(ctx, lib != nullptr, - errors::Internal("No function library is provided."), - done); - - const string& source_device = lib->device()->name(); - const Tensor* target; - OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); - string target_device; - OP_REQUIRES_OK_ASYNC( - ctx, - DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()(), - source_device, &target_device), - done); - - AttrValueMap attr_values = func_.attr(); - FunctionLibraryRuntime::InstantiateOptions instantiate_opts; - instantiate_opts.target = target_device; - - FunctionTarget function_target = {target_device, lib}; - - FunctionLibraryRuntime::Handle handle; - { - mutex_lock l(mu_); - auto cached_entry = handle_cache_.find(function_target); - if (cached_entry != handle_cache_.end()) { - handle = cached_entry->second; - } else { - VLOG(1) << "Instantiating " << func_.name() << " on " << target_device; - tracing::ScopedActivity activity(strings::StrCat( - "RemoteCall: Instantiate: ", func_.name(), " on ", target_device)); - OP_REQUIRES_OK_ASYNC( - ctx, - lib->Instantiate(func_.name(), AttrSlice(&attr_values), - instantiate_opts, &handle), - done); - auto insert_result = handle_cache_.insert({function_target, handle}); - CHECK(insert_result.second) << "Insert unsuccessful."; - VLOG(1) << "Instantiated " << func_.name() << " on " << target_device - << ", resulting in handle: " << handle << " flr: " << lib; - } +RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { + OP_REQUIRES_OK(ctx, + ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_)); +} + +void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { + FunctionLibraryRuntime* lib = ctx->function_library(); + OP_REQUIRES_ASYNC(ctx, lib != nullptr, + errors::Internal("No function library is provided."), done); + + const string& source_device = lib->device()->name(); + const Tensor* target; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); + string target_device; + OP_REQUIRES_OK_ASYNC( + ctx, + DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()(), + source_device, &target_device), + done); + + AttrValueMap attr_values = func_.attr(); + FunctionLibraryRuntime::InstantiateOptions instantiate_opts; + instantiate_opts.target = target_device; + + FunctionTarget function_target = {target_device, lib}; + + FunctionLibraryRuntime::Handle handle; + { + mutex_lock l(mu_); + auto cached_entry = 
handle_cache_.find(function_target); + if (cached_entry != handle_cache_.end()) { + handle = cached_entry->second; + } else { + VLOG(1) << "Instantiating " << func_.name() << " on " << target_device; + tracing::ScopedActivity activity(strings::StrCat( + "RemoteCall: Instantiate: ", func_.name(), " on ", target_device)); + OP_REQUIRES_OK_ASYNC( + ctx, + lib->Instantiate(func_.name(), AttrSlice(&attr_values), + instantiate_opts, &handle), + done); + auto insert_result = handle_cache_.insert({function_target, handle}); + CHECK(insert_result.second) << "Insert unsuccessful."; + VLOG(1) << "Instantiated " << func_.name() << " on " << target_device + << ", resulting in handle: " << handle << " flr: " << lib; } + } - OpInputList arguments; - OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done); + OpInputList arguments; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done); - FunctionLibraryRuntime::Options opts; - opts.step_id = ctx->step_id(); - opts.runner = ctx->runner(); - opts.source_device = source_device; - if (opts.source_device != target_device) { - opts.remote_execution = true; - } - opts.create_rendezvous = true; - std::vector<Tensor> args; - args.reserve(arguments.size()); - for (const Tensor& argument : arguments) { - args.push_back(argument); - } - for (const auto& dtype : input_dtypes_) { - AllocatorAttributes arg_alloc_attrs; - if (DataTypeAlwaysOnHost(dtype)) { - arg_alloc_attrs.set_on_host(true); - } - opts.args_alloc_attrs.push_back(arg_alloc_attrs); + FunctionLibraryRuntime::Options opts; + opts.step_id = ctx->step_id(); + opts.runner = ctx->runner(); + opts.source_device = source_device; + if (opts.source_device != target_device) { + opts.remote_execution = true; + } + opts.create_rendezvous = true; + std::vector<Tensor> args; + args.reserve(arguments.size()); + for (const Tensor& argument : arguments) { + args.push_back(argument); + } + for (const auto& dtype : input_dtypes_) { + AllocatorAttributes arg_alloc_attrs; + if (DataTypeAlwaysOnHost(dtype)) { + arg_alloc_attrs.set_on_host(true); } - for (const auto& dtype : output_dtypes_) { - AllocatorAttributes ret_alloc_attrs; - if (DataTypeAlwaysOnHost(dtype)) { - ret_alloc_attrs.set_on_host(true); - } - opts.rets_alloc_attrs.push_back(ret_alloc_attrs); + opts.args_alloc_attrs.push_back(arg_alloc_attrs); + } + for (const auto& dtype : output_dtypes_) { + AllocatorAttributes ret_alloc_attrs; + if (DataTypeAlwaysOnHost(dtype)) { + ret_alloc_attrs.set_on_host(true); } - auto* rets = new std::vector<Tensor>; - auto* activity = new tracing::ScopedActivity(strings::StrCat( - "RemoteCall: Run: ", func_.name(), " on ", target_device)); - VLOG(1) << "Running " << func_.name() << " on " << target_device - << " with handle: " << handle; - lib->Run(opts, handle, args, rets, - [rets, activity, done, ctx](const Status& status) { - if (!status.ok()) { - ctx->SetStatus(status); - } else { - for (size_t i = 0; i < rets->size(); ++i) { - ctx->set_output(i, (*rets)[i]); - } - } - delete rets; - delete activity; - done(); - }); + opts.rets_alloc_attrs.push_back(ret_alloc_attrs); } - - private: - NameAttrList func_; - DataTypeVector input_dtypes_; - DataTypeVector output_dtypes_; - - mutex mu_; - typedef std::pair<string, FunctionLibraryRuntime*> FunctionTarget; - std::map<FunctionTarget, FunctionLibraryRuntime::Handle> handle_cache_ - GUARDED_BY(mu_); - - TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp); -}; + auto* rets = new std::vector<Tensor>; + auto* activity = new tracing::ScopedActivity(strings::StrCat( + 
"RemoteCall: Run: ", func_.name(), " on ", target_device)); + VLOG(1) << "Running " << func_.name() << " on " << target_device + << " with handle: " << handle; + lib->Run(opts, handle, args, rets, + [rets, activity, done, ctx](const Status& status) { + if (!status.ok()) { + ctx->SetStatus(status); + } else { + for (size_t i = 0; i < rets->size(); ++i) { + ctx->set_output(i, (*rets)[i]); + } + } + delete rets; + delete activity; + done(); + }); +} REGISTER_KERNEL_BUILDER( Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp); diff --git a/tensorflow/core/kernels/function_ops.h b/tensorflow/core/kernels/function_ops.h new file mode 100644 index 0000000000..9e88cc6d8c --- /dev/null +++ b/tensorflow/core/kernels/function_ops.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_FUNCTION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_FUNCTION_OPS_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +static const char* const kArgOp = FunctionLibraryDefinition::kArgOp; +static const char* const kRetOp = FunctionLibraryDefinition::kRetOp; + +class ArgOp : public OpKernel { + public: + explicit ArgOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + bool IsExpensive() override { return false; } + + private: + int index_; + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(ArgOp); +}; + +class RetvalOp : public OpKernel { + public: + explicit RetvalOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + bool IsExpensive() override { return false; } + + private: + int index_; + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); +}; + +class RemoteCallOp : public AsyncOpKernel { + public: + explicit RemoteCallOp(OpKernelConstruction* ctx); + + ~RemoteCallOp() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + private: + NameAttrList func_; + DataTypeVector input_dtypes_; + DataTypeVector output_dtypes_; + + mutex mu_; + typedef std::pair<string, FunctionLibraryRuntime*> FunctionTarget; + std::map<FunctionTarget, FunctionLibraryRuntime::Handle> handle_cache_ + GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_KERNELS_FUNCTION_OPS_H_ diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index cb285bf732..1529d2e336 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -127,31 +127,47 @@ class IfOp : public AsyncOpKernel { explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { auto lib = ctx->function_library(); OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library")); - const NameAttrList* func; - OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", 
&func)); - OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &then_handle_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &func)); - OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &then_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &else_func_)); } ~IfOp() override {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + auto lib = ctx->function_library(); + OP_REQUIRES_ASYNC(ctx, lib != nullptr, + errors::Internal("No function library"), done); + + // TODO(b/37549631): Because this op has `SetIsStateful()` in its op + // registration, this kernel may be shared by multiple subgraphs, which have + // different associated `FunctionLibraryRuntime` objects and hence different + // `FHandle` namespaces. So we must call Instantiate() to make sure we get + // the correct function handles with respect to `lib`. Note the underlying + // `lib->Instantiate()` caches the created function handles, so calling + // `Instantiate()` repeatedly on the same `lib` and function is cheap. + FHandle then_handle; + FHandle else_handle; + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, then_func_, &then_handle), done); + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, else_func_, &else_handle), done); + bool cond; OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond)); - (new State(this, ctx, cond, done))->Start(); + (new State(this, ctx, cond, then_handle, else_handle, done))->Start(); } private: - FHandle then_handle_; - FHandle else_handle_; + NameAttrList then_func_; + NameAttrList else_func_; class State { public: - State(IfOp* kernel, OpKernelContext* ctx, bool cond, DoneCallback done) + State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle, + FHandle else_handle, DoneCallback done) : kernel_(kernel), ctx_(ctx), cond_(cond), + then_handle_(then_handle), + else_handle_(else_handle), done_(std::move(done)), lib_(CHECK_NOTNULL(ctx_->function_library())) { SetRunOptions(ctx_, &opts_, true /* always_collect_stats */); @@ -163,7 +179,7 @@ class IfOp : public AsyncOpKernel { ~State() {} void Start() { - FHandle handle = cond_ ? kernel_->then_handle_ : kernel_->else_handle_; + FHandle handle = cond_ ? then_handle_ : else_handle_; rets_.clear(); lib_->Run( // Evaluate one of the branch. @@ -184,6 +200,8 @@ class IfOp : public AsyncOpKernel { IfOp* const kernel_; OpKernelContext* const ctx_; const bool cond_; + FHandle then_handle_; + FHandle else_handle_; DoneCallback done_; FunctionLibraryRuntime* const lib_; FunctionLibraryRuntime::Options opts_; @@ -200,6 +218,10 @@ REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"), REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp); REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp); +REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp); +REGISTER_KERNEL_BUILDER( + Name("StatelessIf").Device(DEVICE_GPU).HostMemory("cond"), IfOp); + class WhileOp : public AsyncOpKernel { public: explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { @@ -214,30 +236,17 @@ class WhileOp : public AsyncOpKernel { OP_REQUIRES_ASYNC(ctx, lib != nullptr, errors::Internal("No function library"), done); - // TODO(b/37549631): Because this op has `SetIsStateful()` in its - // op registration, this kernel may be shared by multiple - // subgraphs, which have different associated - // `FunctionLibraryRuntime` objects and hence different `FHandle` - // namespaces. 
We currently work around this by caching the map - // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two - // functions this op uses. + // TODO(b/37549631): Because this op has `SetIsStateful()` in its op + // registration, this kernel may be shared by multiple subgraphs, which have + // different associated `FunctionLibraryRuntime` objects and hence different + // `FHandle` namespaces. So we must call Instantiate() to make sure we get + // the correct function handles with respect to `lib`. Note the underlying + // `lib->Instantiate()` caches the created function handles, so calling + // `Instantiate()` repeatedly on the same `lib` and function is cheap. FHandle cond_handle; FHandle body_handle; - { - mutex_lock l(mu_); - const auto iter = handles_.find(lib); - if (iter == handles_.end()) { - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), - done); - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), - done); - handles_[lib] = {cond_handle, body_handle}; - } else { - cond_handle = iter->second.first; - body_handle = iter->second.second; - } - } - + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done); + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done); (new State(this, ctx, cond_handle, body_handle, done))->Start(); } @@ -245,10 +254,6 @@ class WhileOp : public AsyncOpKernel { NameAttrList cond_func_; NameAttrList body_func_; - mutex mu_; - std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>> - handles_ GUARDED_BY(mu_); - class State { public: State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle, @@ -378,6 +383,9 @@ REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp); REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp); REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp); +REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp); +REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_GPU), WhileOp); + Status GetScalar(OpKernelContext* ctx, int index, int32* value, const char* label) { Tensor t = ctx->input(index); diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index f99dd643f7..d89f1592bd 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -45,6 +45,24 @@ struct FusedBatchNorm; template <typename Device, typename T, typename U> struct FusedBatchNormGrad; +template <bool IsSame, typename Y, typename X, typename T> +struct CastIfNecessary { + static inline void process( + Y& y, X& x_shifted, const Eigen::DSizes<Eigen::Index, 2>& rest_by_depth, + const CPUDevice& d) { + y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>(); + } +}; + +template <typename Y, typename X, typename T> +struct CastIfNecessary<true, Y, X, T> { + static inline void process( + Y& y, X& x_shifted, const Eigen::DSizes<Eigen::Index, 2>& rest_by_depth, + const CPUDevice& d) { + y.reshape(rest_by_depth).device(d) = x_shifted; + } +}; + template <typename T, typename U> struct FusedBatchNorm<CPUDevice, T, U> { void operator()(OpKernelContext* context, const Tensor& x_input, @@ -125,7 +143,11 @@ struct FusedBatchNorm<CPUDevice, T, U> { auto x_shifted = x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec); - y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>(); + // Explicitly checks the types of T and U and only casts x_shifted when + // T != U. 
(Not doing so caused a 35-50% performance slowdown for + // some compiler flags.) + CastIfNecessary<std::is_same<T, U>::value, decltype(y), decltype(x_shifted), + T>::process(y, x_shifted, rest_by_depth, d); } }; diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc index 80376c61aa..5d4737549b 100644 --- a/tensorflow/core/kernels/matmul_op.cc +++ b/tensorflow/core/kernels/matmul_op.cc @@ -578,25 +578,41 @@ struct MatMulFunctor<SYCLDevice, T> { .Label("cublas"), \ MatMulOp<GPUDevice, T, true /* cublas */>) -#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML) +#if defined(INTEL_MKL) -// MKL does not support half and int32 types for matrix-multiplication, so -// register the kernel to use default Eigen based implementations for these -// types. Registration for NO-LABEL version is in mkl_matmul_op.cc -TF_CALL_float(REGISTER_CPU_EIGEN); -TF_CALL_double(REGISTER_CPU_EIGEN); +// MKL does not support half, bfloat16 and int32 types for +// matrix-multiplication, so register the kernel to use default Eigen based +// implementations for these types. REGISTER_CPU defines two versions - Eigen +// label and NO-LABEL TF_CALL_half(REGISTER_CPU); TF_CALL_bfloat16(REGISTER_CPU); - TF_CALL_int32(REGISTER_CPU); + +// Float is supported in both MKL DNN as well as in MKL ML +// Registration for NO-LABEL version is in mkl_matmul_op.cc for types supported +// by MKL. However we define Eigen label version here just to pass a few unit +// tests +TF_CALL_float(REGISTER_CPU_EIGEN); + +// MKL DNN does not support complex64/complex128/double, if user specifies +// to use only opensource MKL DNN then use default implementation for these +// types otherwise use GEMM from MKL ML binary + +#if defined(DO_NOT_USE_ML) +TF_CALL_complex64(REGISTER_CPU); +TF_CALL_complex128(REGISTER_CPU); +TF_CALL_double(REGISTER_CPU); +#else // DO_NOT_USE_ML TF_CALL_complex64(REGISTER_CPU_EIGEN); TF_CALL_complex128(REGISTER_CPU_EIGEN); -#else +TF_CALL_double(REGISTER_CPU_EIGEN); +#endif + +#else // INTEL MKL TF_CALL_float(REGISTER_CPU); TF_CALL_double(REGISTER_CPU); TF_CALL_half(REGISTER_CPU); TF_CALL_bfloat16(REGISTER_CPU); - TF_CALL_int32(REGISTER_CPU); TF_CALL_complex64(REGISTER_CPU); TF_CALL_complex128(REGISTER_CPU); diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc index d545d34fdf..d3566c2e37 100644 --- a/tensorflow/core/kernels/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc @@ -442,7 +442,6 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); const Tensor& input_tensor = MklGetInput(context, this->kInputTensorIndexInput); MklDnnShape dnn_shape_input; @@ -450,14 +449,14 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> { this->SanityCheckInput(context, input_tensor, dnn_shape_input); if (!context->status().ok()) return; - MklDnnData<T> dnn_data_input(&cpu_engine); - MklDnnData<T> dnn_data_output(&cpu_engine); + MklDnnData<T> dnn_data_input(&cpu_engine_); // initialize variables for the pooling op MklPoolParameters pool_params; // Get the input tensor and initialize the pooling parameters - this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params, - &dnn_data_input); + TensorShape input_tensor_shape = input_tensor.shape(); + this->InitMklPoolParameters(context, &pool_params, dnn_shape_input, + input_tensor_shape); OP_REQUIRES_OK(context, context->status()); // Declare output tensor @@ 
-467,65 +466,62 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> { // If input is an empty tensor, allocate an empty output tensor and return if (input_tensor.NumElements() == 0) { - MklDnnShape output_mkl_shape; - output_mkl_shape.SetMklTensor(false); - TensorShape output_tf_shape; - if (pool_params.data_format == TensorFormat::FORMAT_NCHW) { - output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order); - } else { - memory::dims output_dims_NHWC_order; - output_dims_NHWC_order = {pool_params.tensor_in_batch, - static_cast<int>(pool_params.out_height), - static_cast<int>(pool_params.out_width), - pool_params.out_depth}; - output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order); - } const int kOutputIndex = 0; - AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor, - output_tf_shape, output_mkl_shape); - CHECK_NOTNULL(output_tensor); + this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params, + output_dims_mkl_order, &output_tensor); return; } - // If input is in Mkl layout, then just get the memory format from it - // directly, instead of using input data_format to AvgPool. - if (dnn_shape_input.IsMklTensor()) { - dnn_data_output.SetUsrMem( - output_dims_mkl_order, - static_cast<memory::format>( - dnn_data_input.GetUsrMemDesc().data.format)); - - } else { - dnn_data_output.SetUsrMem(output_dims_mkl_order, - this->data_format_mkldnn_); - } - - // describe the memory layout - dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any); - - // 3. create a pooling primitive descriptor - auto pool_desc = pooling_forward::desc( - prop_kind::forward, algorithm::pooling_avg_exclude_padding, - dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(), - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast<int>(pool_params.pad_top), - static_cast<int>(pool_params.pad_left)}), - memory::dims({static_cast<int>(pool_params.pad_bottom), - static_cast<int>(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_prim_desc = - pooling_forward::primitive_desc(pool_desc, cpu_engine); - - this->AllocateOutputTensor(context, pool_prim_desc, output_dims_mkl_order, + memory::dims filter_dims, strides, padding_left, padding_right; + this->PoolParamsToDims(&pool_params, &filter_dims, &strides, + &padding_left, &padding_right); + + // Get the input memory descriptor + memory::desc input_md = + dnn_shape_input.IsMklTensor() + ? dnn_shape_input.GetMklLayout() + : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape, + this->data_format_tf_), + MklDnnType<T>(), this->data_format_mkldnn_); + + // Get src/filter/stride/padding information + memory::dims src_dims = + dnn_shape_input.IsMklTensor() + ? 
dnn_shape_input.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(), + this->data_format_tf_); + + // Get an average pooling primitive from the op pool + MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr; + MklPoolingParams fwdParams(src_dims, output_dims_mkl_order, filter_dims, + strides, padding_left, padding_right, + algorithm::pooling_avg_exclude_padding); + pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams); + + // allocate output tensor + this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()), + output_dims_mkl_order, this->data_format_mkldnn_, &output_tensor); CHECK_NOTNULL(output_tensor); OP_REQUIRES_OK(context, context->status()); - dnn_data_output.SetUsrMemDataHandle(output_tensor); - this->PrepareAndExecuteNet(pool_prim_desc, &dnn_data_input, - &dnn_data_output); + // check whether we need to reorder src + const T* src_data = input_tensor.flat<T>().data(); + if (input_md.data.format != pooling_fwd->GetSrcMemoryFormat()) { + dnn_data_input.SetUsrMem(input_md, &input_tensor); + auto src_target_primitive_desc = memory::primitive_desc( + {{src_dims}, MklDnnType<T>(), pooling_fwd->GetSrcMemoryFormat()}, + cpu_engine_); + dnn_data_input.CheckReorderToOpMem(src_target_primitive_desc); + src_data = const_cast<T*>( + reinterpret_cast<T*>(dnn_data_input.GetOpMem().get_data_handle())); + } + + T* dst_data = output_tensor->flat<T>().data(); + + // execute pooling + pooling_fwd->Execute(src_data, dst_data); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -535,9 +531,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> { errors::Aborted("Operation received an exception:", error_msg)); } } // Compute -}; // MklAvgPoolingOp -//----------------------------------------------------------------------------- + private: + engine cpu_engine_ = engine(engine::cpu, 0); +}; // MklAvgPoolingOp template <class Device, class T> class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> { @@ -547,91 +544,78 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); - MklDnnShape original_input_mkl_shape, input_gradient_mkl_shape; - const Tensor& tensor_in_shape = + const Tensor& orig_input_tensor = MklGetInput(context, kInputTensorIndexInputShape); - const Tensor& input_gradient_tensor = + const Tensor& grad_tensor = MklGetInput(context, kInputTensorIndexInputGradient); - GetMklShape(context, kInputTensorIndexInputShape, - &original_input_mkl_shape); - GetMklShape(context, kInputTensorIndexInputGradient, - &input_gradient_mkl_shape); - SanityCheckInputs(context, tensor_in_shape, input_gradient_tensor, - original_input_mkl_shape, input_gradient_mkl_shape); + MklDnnShape orig_input_mkl_shape, grad_mkl_shape; + GetMklShape(context, kInputTensorIndexInputShape, &orig_input_mkl_shape); + GetMklShape(context, kInputTensorIndexInputGradient, &grad_mkl_shape); if (!context->status().ok()) return; // Used to allocate output_diff_src/diff_src - // and create pool_fwd mdm desc - // 0. Input("orig_input_shape: int32") //NOT a T Tensor! - // 1. 
Input("grad: T") - - MklDnnData<T> input_gradient_diff_dst(&cpu_engine); - MklDnnData<T> output_diff_src(&cpu_engine); - Tensor* output_tensor_diff_src = nullptr; - TensorShape original_input_shape; + MklDnnData<T> grad_dnn_data(&cpu_engine_); MklPoolParameters pool_params; - memory::dims output_dims_mkl_order, original_input_dims_nchw; - // Configure the original input memory descriptor - memory::desc original_input_md = ConfigureOriginalInput( - context, tensor_in_shape, original_input_mkl_shape, - &original_input_dims_nchw, &pool_params, &original_input_shape); - - // configure the original output memory descriptor - // by definition, the shape of the original output is the same - // as the shape of the gradient diff_dst - memory::desc original_output_md = this->ConfigureOriginalOutput( - pool_params, input_gradient_mkl_shape, output_dims_mkl_order); - - memory::desc target_diff_dst_md = this->ConfigureInputGradient( - input_gradient_mkl_shape, input_gradient_tensor, - &input_gradient_diff_dst, original_output_md); - // The shape of the output diff src needs to be the same shape as the - // original input. But we will set its format to be same as the format of - // input gradient. We won't use format of original input since it will - // always be in Tensorflow layout (given that AvgPoolGrad gets shape of - // the input rather than actual input). - output_diff_src.SetUsrMem( - original_input_dims_nchw, - static_cast<memory::format>(target_diff_dst_md.data.format)); - - // Create the forward pooling primitive descriptor so we can reference it - // in the backward pooling primitive descriptor - auto pool_fwd_desc = pooling_forward::desc( - prop_kind::forward, algorithm::pooling_avg_exclude_padding, - original_input_md, original_output_md, - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast<int>(pool_params.pad_top), - static_cast<int>(pool_params.pad_left)}), - memory::dims({static_cast<int>(pool_params.pad_bottom), - static_cast<int>(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_fwd_prim_desc = - pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine); - - auto pool_bkwd_desc = pooling_backward::desc( - algorithm::pooling_avg_exclude_padding, - output_diff_src.GetUsrMemDesc(), target_diff_dst_md, - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast<int>(pool_params.pad_top), - static_cast<int>(pool_params.pad_left)}), - memory::dims({static_cast<int>(pool_params.pad_bottom), - static_cast<int>(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_bkwd_prim_desc = pooling_backward::primitive_desc( - pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc); - this->AllocateOutputTensor( - context, pool_bkwd_prim_desc, original_input_dims_nchw, - this->data_format_mkldnn_, &output_tensor_diff_src); - - output_diff_src.SetUsrMemDataHandle(output_tensor_diff_src); - - this->PrepareAndExecuteNet( - pool_bkwd_prim_desc, &input_gradient_diff_dst, &output_diff_src, - memory::primitive_desc(target_diff_dst_md, cpu_engine)); + auto shape_vec = orig_input_tensor.vec<int32>(); + TensorShape orig_input_shape; + for (int i = 0; i < orig_input_tensor.NumElements(); i++) { + orig_input_shape.AddDim(shape_vec(i)); + } + this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape, + orig_input_shape); + + memory::dims 
filter_dims, strides, padding_left, padding_right; + this->PoolParamsToDims(&pool_params, &filter_dims, &strides, + &padding_left, &padding_right); + + memory::dims orig_input_dims_mkl_order = + orig_input_mkl_shape.IsMklTensor() + ? orig_input_mkl_shape.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(orig_input_shape, + this->data_format_tf_); + + memory::dims diff_dst_dims = + grad_mkl_shape.IsMklTensor() + ? grad_mkl_shape.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(), + this->data_format_tf_); + memory::dims output_dims_mkl_order; + this->GetOutputDims(pool_params, &output_dims_mkl_order); + + MklPoolingParams bwdParams(orig_input_dims_mkl_order, + output_dims_mkl_order, filter_dims, strides, + padding_left, padding_right, + algorithm::pooling_avg_exclude_padding); + MklPoolingBwdPrimitive<T>* pooling_bwd = + MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); + + Tensor* output_tensor = nullptr; + this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()), + orig_input_dims_mkl_order, + this->data_format_mkldnn_, &output_tensor); + // get diff_dst memory::desc + memory::desc diff_dst_md = + grad_mkl_shape.IsMklTensor() + ? grad_mkl_shape.GetMklLayout() + : memory::desc(diff_dst_dims, MklDnnType<T>(), + this->data_format_mkldnn_); + // Check whether we need to reorder diff_dst + const T* diff_dst_data = grad_tensor.flat<T>().data(); + if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) { + auto target_diff_dst = memory::primitive_desc( + {{diff_dst_dims}, MklDnnType<T>(), pooling_bwd->GetDiffDstFormat()}, + cpu_engine_); + grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor); + grad_dnn_data.CheckReorderToOpMem(target_diff_dst); + diff_dst_data = const_cast<T*>( + reinterpret_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle())); + } + + T* diff_src_data = output_tensor->flat<T>().data(); + + // execute pooling op + pooling_bwd->Execute(diff_dst_data, diff_src_data); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -639,33 +623,14 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> { OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", error_msg)); } - } // Compute + } private: // 0. Input("orig_input_shape: int32") // 1. Input("grad: T") const int kInputTensorIndexInputShape = 0; const int kInputTensorIndexInputGradient = 1; - - memory::desc ConfigureOriginalInput( - OpKernelContext* context, const Tensor& tensor_original_input_shape, - const MklDnnShape& original_input_mkl_shape, - memory::dims* original_input_dims_mkl_order, - MklPoolParameters* pool_params, TensorShape* input_tensor_shape) { - CHECK_NOTNULL(original_input_dims_mkl_order); - CHECK_NOTNULL(pool_params); - CHECK_NOTNULL(input_tensor_shape); - // For AvgPoolGrad, we only get the size of the original input because - // The original data is irrelvant. 
- auto shape_vec = tensor_original_input_shape.vec<int32>(); - for (int64 i = 0; i < tensor_original_input_shape.NumElements(); ++i) { - input_tensor_shape->AddDim(shape_vec(i)); - } - - return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput( - context, tensor_original_input_shape, original_input_mkl_shape, - original_input_dims_mkl_order, pool_params, *input_tensor_shape); - } + engine cpu_engine_ = engine(engine::cpu, 0); void SanityCheckInputs(OpKernelContext* context, const Tensor& tensor_in_shape, diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc index 6f490cdc23..d8efb1be3e 100644 --- a/tensorflow/core/kernels/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl_concat_op.cc @@ -308,11 +308,9 @@ class MklConcatOp : public OpKernel { } if (invoke_eigen) { - string msg = std::string("Invoking Eigen version of Concat. Reason:") + - (!is_concat_dim_channel - ? std::string("Concat dimension is not channel") - : std::string("Not all tensors are in Mkl layout")); - VLOG(1) << "_MklConcatOp: " << msg; + VLOG(1) << "_MklConcatOp: Invoking Eigen version of Concat. Reason:" + << (!is_concat_dim_channel ? "Concat dimension is not channel" + : "Not all tensors are in Mkl layout"); CallEigenVersion(context, input_tensors, input_shapes); return; } diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index a370037d97..b73a119a88 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -328,9 +328,8 @@ class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> { return instance_; } - static std::string CreateKey( - const MklConvBwdFilterParams& convBwdFilterDims) { - std::string prefix = "conv2d_bwd_filter"; + static string CreateKey(const MklConvBwdFilterParams& convBwdFilterDims) { + string prefix = "conv2d_bwd_filter"; FactoryKeyCreator key_creator; key_creator.AddAsKey(prefix); key_creator.AddAsKey(convBwdFilterDims.src_dims); @@ -346,13 +345,13 @@ class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> { MklPrimitive* GetConv2dBwdFilter( const MklConvBwdFilterParams& convBwdFilterDims) { - std::string key = CreateKey(convBwdFilterDims); + string key = CreateKey(convBwdFilterDims); return this->GetOp(key); } void SetConv2dBwdFilter( const MklConvBwdFilterParams& convBwdFilterDims, MklPrimitive* op) { - std::string key = CreateKey(convBwdFilterDims); + string key = CreateKey(convBwdFilterDims); this->SetOp(key, op); } }; diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index b0f7faaa1a..39498f1a80 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -265,9 +265,8 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> { return instance_; } - static std::string CreateKey( - const MklConvBwdInputParams& convBwdInputDims) { - std::string prefix = "conv2d_bwd_input"; + static string CreateKey(const MklConvBwdInputParams& convBwdInputDims) { + string prefix = "conv2d_bwd_input"; FactoryKeyCreator key_creator; key_creator.AddAsKey(prefix); key_creator.AddAsKey(convBwdInputDims.diff_src_dims); @@ -282,13 +281,13 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> { MklPrimitive* GetConv2dBwdInput( const MklConvBwdInputParams& convBwdInputDims) { - std::string key = CreateKey(convBwdInputDims); + string key = 
CreateKey(convBwdInputDims); return this->GetOp(key); } void SetConv2dBwdInput( const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) { - std::string key = CreateKey(convBwdInputDims); + string key = CreateKey(convBwdInputDims); this->SetOp(key, op); } }; diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index b568973220..62396eeb8b 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -18,7 +18,6 @@ limitations under the License. #include <string.h> #include <map> -#include <string> #include <vector> #include <memory> @@ -35,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/padding.h" @@ -298,8 +298,8 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> { return instance_; } - static std::string CreateKey(const MklConvFwdParams& convFwdDims) { - std::string prefix = "conv2d_fwd_"; + static string CreateKey(const MklConvFwdParams& convFwdDims) { + string prefix = "conv2d_fwd_"; FactoryKeyCreator key_creator; key_creator.AddAsKey(prefix); key_creator.AddAsKey(convFwdDims.src_dims); @@ -314,12 +314,12 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> { } MklPrimitive* GetConv2DFwd(const MklConvFwdParams& convFwdDims) { - std::string key = CreateKey(convFwdDims); + string key = CreateKey(convFwdDims); return this->GetOp(key); } void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) { - std::string key = CreateKey(convFwdDims); + string key = CreateKey(convFwdDims); this->SetOp(key, op); } }; @@ -930,10 +930,9 @@ class MklConv2DOp : public OpKernel { conv2d_fwd->Execute(src_data, filter_data, dst_data); } } catch (mkldnn::error &e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + std::string(e.message) + - ", in file " + std::string(__FILE__) + ":" + - std::to_string(__LINE__); + string error_msg = tensorflow::strings::StrCat( + "Status: ", e.status, ", message: ", string(e.message), ", in file ", + __FILE__, ":", __LINE__); OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:", error_msg)); } diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index 5e1a5001dc..3f154ff33b 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -17,7 +17,6 @@ limitations under the License. 
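// The CreateKey() changes above only swap std::string for TF's string alias;
// the key itself is still the primitive parameters flattened into one lookup
// string, so identical shapes/strides/paddings hit the same cached primitive
// via GetOp()/SetOp(). A standalone toy version of that idea (the real
// FactoryKeyCreator lives in the MKL utility headers; this sketch is for
// illustration only):
#include <sstream>
#include <string>
#include <vector>

std::string MakePrimitiveKey(const std::string& prefix,
                             const std::vector<int>& src_dims,
                             const std::vector<int>& strides) {
  std::ostringstream key;
  key << prefix;
  for (int d : src_dims) key << ":" << d;  // e.g. "conv2d_fwd_:1:3:224:224"
  for (int s : strides) key << ":" << s;
  return key.str();  // equal parameters -> equal key -> primitive reuse
}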
#define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ #include <limits> -#include <string> #include <vector> #include <memory> diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index 3fe660cf96..0149e78db5 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -262,6 +262,7 @@ class MklFusedBatchNormOp : public OpKernel { } void MklCreateInputLayout(OpKernelContext* context) { + const Tensor& input = MklGetInput(context, 0); bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor(); if (input_in_mkl_format) { mkl_lt_input = @@ -544,6 +545,7 @@ class MklFusedBatchNormGradOp : public OpKernel { } void MklCreateInputLayout(OpKernelContext* context) { + const Tensor& input = MklGetInput(context, 0); bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor(); if (input_in_mkl_format) { mkl_lt_input = @@ -684,6 +686,466 @@ class MklFusedBatchNormGradOp : public OpKernel { #ifndef INTEL_MKL_ML +struct MklBatchNormFwdParams { + memory::dims src_dims; + int depth; + float eps; + bool training; + + MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps, + bool training) + : src_dims(src_dims), depth(depth), eps(eps), training(training) {} +}; + +template <typename T> +class MklFusedBatchNormFwdPrimitive : public MklPrimitive { + public: + explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams) + : cpu_engine_(engine::cpu, 0) { + context_.fwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager)); + if (context_.bn_fwd == nullptr) Setup(fwdParams); + } + + ~MklFusedBatchNormFwdPrimitive() {} + + // BatchNormalization forward execute + // src_data: input data buffer of src + // weights_data: input data buffer of weights + // dst_data: output data buffer of dst + // mean_data: output data buffer of means + // variance_data: output data buffer of variances + void Execute(const T* src_data, const T* weights_data, T* dst_data, + T* mean_data, T* variance_data) { + context_.src_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(src_data))); + context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); + + if (context_.flags & use_scale_shift) + context_.weights_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(weights_data))); + + if ((context_.pkind == prop_kind::forward_training) || + (context_.flags & use_global_stats)) { + context_.mean_mem->set_data_handle(static_cast<void*>(mean_data)); + context_.variance_mem->set_data_handle(static_cast<void*>(variance_data)); + } + + // execution + context_.fwd_stream->submit(context_.fwd_primitives); + + context_.src_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); + + if (context_.flags & use_scale_shift) + context_.weights_mem->set_data_handle(DummyData); + + if ((context_.pkind == prop_kind::forward_training) || + (context_.flags & use_global_stats)) { + context_.mean_mem->set_data_handle(DummyData); + context_.variance_mem->set_data_handle(DummyData); + } + } + + memory::primitive_desc GetDstPd() const { + return (*context_.dst_mem).get_primitive_desc(); + } + + mkldnn_memory_format_t GetSrcFmt() const { + return (*context_.src_mem).get_primitive_desc().desc().data.format; + } + + mkldnn_memory_format_t GetDstFmt() const { + return (*context_.dst_mem).get_primitive_desc().desc().data.format; + } + + private: + // Primitive reuse context for BatchNorm fwd op + struct BatchNormFwdContext { + // flags indict if it is training 
or inference mode + int64 flags; + + // algorithm + mkldnn::prop_kind pkind; + + // Mkldnn Memory + std::shared_ptr<mkldnn::memory> src_mem; + std::shared_ptr<mkldnn::memory> weights_mem; + std::shared_ptr<mkldnn::memory> dst_mem; + std::shared_ptr<mkldnn::memory> mean_mem; + std::shared_ptr<mkldnn::memory> variance_mem; + + // BatchNorm forward primitive + std::shared_ptr<mkldnn::primitive> bn_fwd; + std::shared_ptr<mkldnn::stream> fwd_stream; + std::vector<mkldnn::primitive> fwd_primitives; + + BatchNormFwdContext() + : flags(0), + pkind(mkldnn::forward_training), + src_mem(nullptr), + weights_mem(nullptr), + dst_mem(nullptr), + mean_mem(nullptr), + variance_mem(nullptr), + bn_fwd(nullptr), + fwd_stream(nullptr) {} + }; + + void Setup(const MklBatchNormFwdParams& fwdParams) { + context_.flags = fwdParams.training ? use_scale_shift + : (use_scale_shift | use_global_stats); + context_.pkind = fwdParams.training ? prop_kind::forward_training + : prop_kind::forward_scoring; + + // memory desc + auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>(), + get_desired_format(fwdParams.src_dims[1])); + + // fwd desc & primitive desc + auto fwd_desc = batch_normalization_forward::desc( + context_.pkind, src_md, fwdParams.eps, context_.flags); + auto fwd_pd = + batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_); + + // memory primitive + context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); + context_.dst_mem.reset(new memory(fwd_pd.dst_primitive_desc(), DummyData)); + + if (context_.flags & use_scale_shift) { + auto weights_desc = memory::desc({2, fwdParams.depth}, MklDnnType<T>(), + memory::format::nc); + context_.weights_mem.reset( + new memory({weights_desc, cpu_engine_}, DummyData)); + } + + if (fwdParams.training || (context_.flags & use_global_stats)) { + auto mean_desc = memory::desc({1, fwdParams.depth}, MklDnnType<T>(), + memory::format::nc); + context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData)); + + auto variance_desc = + memory::desc({1, fwdParams.depth}, MklDnnType<T>(), memory::nc); + context_.variance_mem.reset( + new memory({variance_desc, cpu_engine_}, DummyData)); + } + + // BatchNorm forward primitive + if (!fwdParams.training && !(context_.flags & use_global_stats)) { + if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, *context_.weights_mem, + *context_.dst_mem)); + } else { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, *context_.dst_mem)); + } + } else if (context_.flags & use_global_stats) { + if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem, + (const primitive::at)*context_.variance_mem, *context_.weights_mem, + *context_.dst_mem)); + } else { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem, + (const primitive::at)*context_.variance_mem, *context_.dst_mem)); + } + } else { + if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, *context_.weights_mem, *context_.dst_mem, + *context_.mean_mem, *context_.variance_mem)); + } else { + context_.bn_fwd.reset(new batch_normalization_forward( + fwd_pd, *context_.src_mem, *context_.dst_mem, *context_.mean_mem, + 
*context_.variance_mem)); + } + } + + context_.fwd_primitives.push_back(*context_.bn_fwd); + } + + mkldnn::memory::desc get_desc_data(const mkldnn::memory& m) const { + return m.get_primitive_desc().desc().data; + } + + struct BatchNormFwdContext context_; + engine cpu_engine_; +}; + +template <typename T> +class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> { + public: + static MklFusedBatchNormFwdPrimitive<T>* Get( + const MklBatchNormFwdParams& fwdParams) { + auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T>*>( + MklFusedBatchNormFwdPrimitiveFactory<T>::GetInstance().GetBatchNormFwd( + fwdParams)); + + if (bn_fwd == nullptr) { + bn_fwd = new MklFusedBatchNormFwdPrimitive<T>(fwdParams); + MklFusedBatchNormFwdPrimitiveFactory<T>::GetInstance().SetBatchNormFwd( + fwdParams, bn_fwd); + } + return bn_fwd; + } + + static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() { + static MklFusedBatchNormFwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklFusedBatchNormFwdPrimitiveFactory() {} + ~MklFusedBatchNormFwdPrimitiveFactory() {} + + static std::string CreateKey(const MklBatchNormFwdParams& fwdParams) { + std::string prefix = "bn_fwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(fwdParams.src_dims); + key_creator.AddAsKey<int>(fwdParams.depth); + key_creator.AddAsKey<float>(fwdParams.eps); + key_creator.AddAsKey<bool>(fwdParams.training); + return key_creator.GetKey(); + } + + MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) { + std::string key = CreateKey(fwdParams); + return this->GetOp(key); + } + + void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams, + MklPrimitive* op) { + std::string key = CreateKey(fwdParams); + this->SetOp(key, op); + } +}; + +struct MklBatchNormBwdParams { + memory::dims src_dims; + memory::dims diff_dst_dims; + int depth; + float eps; + bool training; + + MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims, + int depth, float eps, bool training) + : src_dims(src_dims), + diff_dst_dims(diff_dst_dims), + depth(depth), + eps(eps), + training(training) {} +}; + +template <typename T> +class MklFusedBatchNormBwdPrimitive : public MklPrimitive { + public: + explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams) + : cpu_engine_(engine::cpu, 0) { + context_.bwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager)); + if (context_.bn_bwd == nullptr) Setup(bwdParams); + } + + ~MklFusedBatchNormBwdPrimitive() {} + + // BatchNormalization backward execute + // src_data: input data buffer of src + // mean_data: input data buffer of mean + // variance_data: input data buffer of variance + // diff_dst_data: input data buffer of diff_dst + // weights_data: input data buffer of weights + // diff_src_data: output data buffer of diff_src + // diff_weights_data: output data buffer of diff_weights + void Execute(const T* src_data, const T* mean_data, const T* variance_data, + const T* diff_dst_data, const T* weights_data, T* diff_src_data, + T* diff_weights_data) { + context_.src_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(src_data))); + context_.mean_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(mean_data))); + context_.variance_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(variance_data))); + context_.diff_dst_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(diff_dst_data))); + + if (context_.flags & use_scale_shift) { + context_.weights_mem->set_data_handle( + 
static_cast<void*>(const_cast<T*>(weights_data))); + context_.diff_weights_mem->set_data_handle( + static_cast<void*>(diff_weights_data)); + } + + context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); + + // execution + context_.bwd_stream->submit(context_.bwd_primitives); + + context_.src_mem->set_data_handle(DummyData); + context_.mean_mem->set_data_handle(DummyData); + context_.variance_mem->set_data_handle(DummyData); + context_.diff_dst_mem->set_data_handle(DummyData); + if (context_.flags & use_scale_shift) { + context_.weights_mem->set_data_handle(DummyData); + context_.diff_weights_mem->set_data_handle(DummyData); + } + context_.diff_src_mem->set_data_handle(DummyData); + } + + mkldnn_memory_format_t GetSrcFmt() { + return (*context_.src_mem).get_primitive_desc().desc().data.format; + } + + mkldnn_memory_format_t GetDiffDstFmt() { + return (*context_.diff_dst_mem).get_primitive_desc().desc().data.format; + } + + memory::primitive_desc GetDiffSrcPd() { + return (*context_.diff_src_mem).get_primitive_desc(); + } + + private: + struct BatchNormBwdContext { + // Flags to indicate whether it is training or inference + int64 flags; + + // MKLDNN memory + std::shared_ptr<mkldnn::memory> src_mem; + std::shared_ptr<mkldnn::memory> mean_mem; + std::shared_ptr<mkldnn::memory> variance_mem; + std::shared_ptr<mkldnn::memory> diff_dst_mem; + std::shared_ptr<mkldnn::memory> weights_mem; + std::shared_ptr<mkldnn::memory> diff_weights_mem; + std::shared_ptr<mkldnn::memory> diff_src_mem; + + // Batch Norm primitive + std::shared_ptr<mkldnn::primitive> bn_bwd; + std::vector<mkldnn::primitive> bwd_primitives; + std::shared_ptr<mkldnn::stream> bwd_stream; + + BatchNormBwdContext() + : src_mem(nullptr), + mean_mem(nullptr), + variance_mem(nullptr), + diff_dst_mem(nullptr), + weights_mem(nullptr), + diff_weights_mem(nullptr), + diff_src_mem(nullptr), + bwd_stream(nullptr) {} + }; + + void Setup(const MklBatchNormBwdParams& bwdParams) { + context_.flags = bwdParams.training ? use_scale_shift + : (use_scale_shift | use_global_stats); + + // memory desc + auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(), + get_desired_format(bwdParams.src_dims[1])); + auto diff_dst_md = + memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>(), + get_desired_format(bwdParams.diff_dst_dims[1])); + auto variance_desc = + memory::desc({1, bwdParams.depth}, MklDnnType<T>(), memory::nc); + auto mean_desc = + memory::desc({1, bwdParams.depth}, MklDnnType<T>(), memory::format::nc); + auto weights_desc = + memory::desc({2, bwdParams.depth}, MklDnnType<T>(), memory::format::nc); + auto diff_weights_desc = weights_desc; + + // fwd desc & primitive desc + auto fwd_desc = batch_normalization_forward::desc( + prop_kind::forward_training, src_md, bwdParams.eps, + bwdParams.training ? use_scale_shift + : (use_scale_shift | use_global_stats)); + auto fwd_pd = + batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_); + + // BatchNorm backward primtive + // + // For inference, specify use_global_stats + // 1. on fwd propagation, use mean and variance provided as inputs. + // 2. on bwd propagation, mean and variance are considered as constants. + // Thus, reduce the amount of MKL computation. + auto bwd_desc = batch_normalization_backward::desc( + prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, + bwdParams.training ? 
use_scale_shift + : (use_scale_shift | use_global_stats)); + auto bn_bwd_pd = batch_normalization_backward::primitive_desc( + bwd_desc, cpu_engine_, fwd_pd); + + // memory primitive + context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); + context_.diff_dst_mem.reset( + new memory({diff_dst_md, cpu_engine_}, DummyData)); + context_.variance_mem.reset( + new memory({variance_desc, cpu_engine_}, DummyData)); + context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData)); + context_.weights_mem.reset( + new memory({weights_desc, cpu_engine_}, DummyData)); + context_.diff_weights_mem.reset( + new memory({diff_weights_desc, cpu_engine_}, DummyData)); + context_.diff_src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); + + context_.bn_bwd.reset(new batch_normalization_backward( + bn_bwd_pd, *context_.src_mem, *context_.mean_mem, + *context_.variance_mem, *context_.diff_dst_mem, *context_.weights_mem, + *context_.diff_src_mem, *context_.diff_weights_mem)); + context_.bwd_primitives.push_back(*context_.bn_bwd); + } + + struct BatchNormBwdContext context_; + engine cpu_engine_; +}; + +template <typename T> +class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> { + public: + static MklFusedBatchNormBwdPrimitive<T>* Get( + const MklBatchNormBwdParams& bwdParams) { + auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T>*>( + MklFusedBatchNormBwdPrimitiveFactory<T>::GetInstance().GetBatchNormBwd( + bwdParams)); + if (bn_bwd == nullptr) { + bn_bwd = new MklFusedBatchNormBwdPrimitive<T>(bwdParams); + MklFusedBatchNormBwdPrimitiveFactory<T>::GetInstance().SetBatchNormBwd( + bwdParams, bn_bwd); + } + return bn_bwd; + } + + static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() { + static MklFusedBatchNormBwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklFusedBatchNormBwdPrimitiveFactory() {} + ~MklFusedBatchNormBwdPrimitiveFactory() {} + + static std::string CreateKey(const MklBatchNormBwdParams& bwdParams) { + std::string prefix = "bn_bwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(bwdParams.src_dims); + key_creator.AddAsKey(bwdParams.diff_dst_dims); + key_creator.AddAsKey<int>(bwdParams.depth); + key_creator.AddAsKey<float>(bwdParams.eps); + key_creator.AddAsKey<bool>(bwdParams.training); + return key_creator.GetKey(); + } + + MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) { + std::string key = CreateKey(bwdParams); + return this->GetOp(key); + } + + void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams, + MklPrimitive* op) { + std::string key = CreateKey(bwdParams); + this->SetOp(key, op); + } +}; + template <typename Device, typename T> class MklFusedBatchNormOp : public OpKernel { public: @@ -701,7 +1163,6 @@ class MklFusedBatchNormOp : public OpKernel { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); const size_t kSrcIndex = 0; // index of src input tensor const size_t kScaleIndex = 1; // index of scale tensor const size_t kShiftIndex = 2; // index of shift tensor @@ -786,7 +1247,7 @@ class MklFusedBatchNormOp : public OpKernel { SetMeanVariance(est_mean_tensor, est_variance_tensor); MklDnnData<T> src(&cpu_engine); - MklDnnData<T> dst(&cpu_engine); + MklDnnData<T> weights(&cpu_engine); memory::format format_m; if (dnn_shape_src.IsMklTensor()) { @@ -800,123 +1261,102 @@ class MklFusedBatchNormOp : public OpKernel { } // set src primitive - memory::dims src_dims; - if 
(dnn_shape_src.IsMklTensor()) { - src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(), - tensor_format_); - } else { - src_dims = - TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); - } + memory::dims src_dims = + dnn_shape_src.IsMklTensor() + ? dnn_shape_src.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); auto src_md = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetMklLayout() : memory::desc(src_dims, MklDnnType<T>(), format_m); - src.SetUsrMem(src_md, &src_tensor); - // set weights primitive // MKL-DNN packs scale & shift as "weights": // <scale>...<scale><shift>...<shift> - auto weights_desc = memory::desc({2, static_cast<int>(depth_)}, - MklDnnType<T>(), memory::format::nc); - auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine); - auto weights_m = memory(weights_pd); - T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle()); - T* scale_tf = - reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data())); - T* shift_tf = - reinterpret_cast<T*>(const_cast<T*>(shift_tensor.flat<T>().data())); + weights.AllocateBuffer(2 * depth_ * sizeof(T)); + T* weights_data = reinterpret_cast<T*>(weights.GetAllocatedBuffer()); + const T* scale_tf = scale_tensor.flat<T>().data(); + const T* shift_tf = shift_tensor.flat<T>().data(); - for (int k = 0; k < depth_; k++) { - weights_data[k] = scale_tf[k]; - weights_data[k + depth_] = shift_tf[k]; - } - - // set mean primitive - auto mean_desc = memory::desc({1, static_cast<int>(depth_)}, - MklDnnType<T>(), memory::format::nc); - auto mean_pd = memory::primitive_desc(mean_desc, cpu_engine); + std::memcpy(weights_data, scale_tf, depth_ * sizeof(T)); + std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(T)); char* saved_mean_data_tf = reinterpret_cast<char*>(saved_mean_tensor->flat<T>().data()); std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_), depth_ * sizeof(T)); - auto mean_m = - memory(mean_pd, reinterpret_cast<void*>(saved_mean_data_tf)); - // set variance primitive - auto variance_desc = memory::desc({1, static_cast<int>(depth_)}, - MklDnnType<T>(), memory::format::nc); - auto variance_pd = memory::primitive_desc(variance_desc, cpu_engine); char* saved_variance_data_tf = reinterpret_cast<char*>(saved_variance_tensor->flat<T>().data()); std::memcpy(saved_variance_data_tf, reinterpret_cast<char*>(variance_values_), depth_ * sizeof(T)); - auto variance_m = memory(variance_pd, saved_variance_data_tf); - - prop_kind pk = (is_training_) ? prop_kind::forward_training - : prop_kind::forward_scoring; - auto bnrm_fwd_desc = batch_normalization_forward::desc( - pk, src.GetUsrMemDesc(), epsilon_, - is_training_ ? 
use_scale_shift - : (use_scale_shift | use_global_stats)); - auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc( - bnrm_fwd_desc, cpu_engine); - - // allocate dst tensor + + // get batchnorm op from the pool + MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_); + MklFusedBatchNormFwdPrimitive<T>* bn_fwd = + MklFusedBatchNormFwdPrimitiveFactory<T>::Get(fwdParams); + + // check if reorder is needed for src, weights, mean, variance + const T* src_data = src_tensor.flat<T>().data(); + if (src_md.data.format != bn_fwd->GetSrcFmt()) { + src.SetUsrMem(src_md, &src_tensor); + auto src_target = memory::primitive_desc( + {{src_dims}, + MklDnnType<T>(), + static_cast<memory::format>(bn_fwd->GetSrcFmt())}, + cpu_engine); + src.CheckReorderToOpMem(src_target); + src_data = const_cast<T*>( + reinterpret_cast<T*>(src.GetOpMem().get_data_handle())); + } + + // allocate output (dst) tensor; always set it as MKL-DNN layout MklDnnShape dnn_shape_dst; TensorShape tf_shape_dst; - if (dnn_shape_src.IsMklTensor()) { - dnn_shape_dst.SetMklTensor(true); - auto dst_pd = bnrm_fwd_pd.dst_primitive_desc(); - dnn_shape_dst.SetMklLayout(&dst_pd); - dnn_shape_dst.SetElemType(MklDnnType<T>()); - dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), src_dims, - format_m); - tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); - } else { - dnn_shape_dst.SetMklTensor(false); - tf_shape_dst = src_tensor.shape(); - } + dnn_shape_dst.SetMklTensor(true); + auto dst_pd = bn_fwd->GetDstPd(); + dnn_shape_dst.SetMklLayout(&dst_pd); + dnn_shape_dst.SetElemType(MklDnnType<T>()); + auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension() + : src_tensor.shape().dims(); + dnn_shape_dst.SetTfLayout(ndims, src_dims, format_m); + tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst, dnn_shape_dst); - // Output of batchnorm has same shape as input. - dst.SetUsrMem(src_md, dst_tensor); + T* weights_op_data = weights_data; + T* mean_op_data = saved_mean_tensor->flat<T>().data(); + T* variance_op_data = saved_variance_tensor->flat<T>().data(); + T* dst_data = dst_tensor->flat<T>().data(); - primitive bnrm_fwd_op; - if (is_training_) { - bnrm_fwd_op = - batch_normalization_forward(bnrm_fwd_pd, src.GetOpMem(), weights_m, - dst.GetOpMem(), mean_m, variance_m); - } else { - bnrm_fwd_op = batch_normalization_forward( - bnrm_fwd_pd, src.GetOpMem(), mean_m, variance_m, - (const primitive::at)weights_m, dst.GetOpMem()); - } - std::vector<primitive> net; - net.push_back(bnrm_fwd_op); - stream(stream::kind::eager).submit(net).wait(); + // execution + bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data, + variance_op_data); // copy batch_mean data - T* batch_mean_data_tf = - reinterpret_cast<T*>(batch_mean_tensor->flat<T>().data()); + T* batch_mean_data_tf = batch_mean_tensor->flat<T>().data(); std::memcpy(reinterpret_cast<char*>(batch_mean_data_tf), - reinterpret_cast<char*>(mean_m.get_data_handle()), + reinterpret_cast<char*>(saved_mean_data_tf), depth_ * sizeof(T)); + // TODO(yli135): OpMem is same as usr mem since + // since its format is hard-coded as nc when primitive is created. 
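// The weights buffer prepared above packs scale followed by shift into one
// contiguous 2 x depth array, the "<scale>...<scale><shift>...<shift>" layout
// the use_scale_shift weights input expects. A standalone sketch just to make
// the packing concrete (PackScaleShift is an illustrative name, not part of
// the patch):
#include <cstring>

void PackScaleShift(const float* scale, const float* shift, int depth,
                    float* weights /* capacity 2 * depth */) {
  std::memcpy(weights, scale, depth * sizeof(float));          // [0, depth)
  std::memcpy(weights + depth, shift, depth * sizeof(float));  // [depth, 2*depth)
}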
// copy batch_variance data with Bessel's correction - // if training mode is on float adjust_factor = 1.0; if (is_training_) { size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3]; size_t adjust_size = orig_size - 1; adjust_factor = (static_cast<float>(orig_size)) / adjust_size; } - for (int k = 0; k < depth_; k++) - batch_variance_tensor->flat<T>().data()[k] = - (reinterpret_cast<T*>(variance_m.get_data_handle()))[k] * - adjust_factor; + + auto variance_data = reinterpret_cast<T*>(saved_variance_data_tf); + auto batch_variance_data = batch_variance_tensor->flat<T>().data(); + if (is_training_) { + for (int k = 0; k < depth_; k++) { + batch_variance_data[k] = variance_data[k] * adjust_factor; + } + } else { + std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(T)); + } } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -933,7 +1373,8 @@ class MklFusedBatchNormOp : public OpKernel { bool is_training_; T* mean_values_; T* variance_values_; - int depth_; // batch normalization is done for per channel. + size_t depth_; // batch normalization is done for per channel. + engine cpu_engine = engine(engine::cpu, 0); void ExtractParams(OpKernelContext* context) { const Tensor& input = MklGetInput(context, 0); @@ -990,8 +1431,9 @@ class MklFusedBatchNormOp : public OpKernel { tf_shape_scale, mkl_shape_batch_mean); CHECK_NOTNULL(*batch_mean_tensor); // set NAN mean value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*batch_mean_tensor)->flat<T>().data()[k] = NAN; + int num_elements = tf_shape_scale.num_elements(); + auto batch_mean_data = (*batch_mean_tensor)->flat<T>().data(); + std::fill_n(batch_mean_data, num_elements, NAN); // allocate batch variance output tensor MklDnnShape mkl_shape_batch_variance; @@ -1001,8 +1443,8 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_batch_variance); CHECK_NOTNULL(*batch_variance_tensor); // set NAN variance value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*batch_variance_tensor)->flat<T>().data()[k] = NAN; + auto batch_variance_data = (*batch_variance_tensor)->flat<T>().data(); + std::fill_n(batch_variance_data, num_elements, NAN); // Mean and variance (without Bessel's correction) saved for backward // computation to serve as pre-computed mean and variance. 
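// The adjust_factor above converts the biased (population) variance the
// primitive produces into the Bessel-corrected sample variance reported as
// batch_variance: var_corrected = var_biased * n / (n - 1), where n = N*H*W
// is the number of values reduced per channel. A tiny worked example with
// made-up sizes:
#include <cstdio>

int main() {
  const long n = 32 * 28 * 28;      // N*H*W for one channel (example values)
  const float biased_var = 0.25f;   // what the primitive computed
  const float adjust = static_cast<float>(n) / (n - 1);
  std::printf("adjust=%f corrected=%f\n", adjust, biased_var * adjust);
  // adjust is ~1.00004 here; the correction only matters for small n.
  return 0;
}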
@@ -1012,8 +1454,8 @@ class MklFusedBatchNormOp : public OpKernel { tf_shape_scale, mkl_shape_saved_mean); CHECK_NOTNULL(*saved_mean_tensor); // set NAN mean value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*saved_mean_tensor)->flat<T>().data()[k] = NAN; + auto saved_mean_data = (*saved_mean_tensor)->flat<T>().data(); + std::fill_n(saved_mean_data, num_elements, NAN); MklDnnShape mkl_shape_saved_variance; mkl_shape_saved_variance.SetMklTensor(false); @@ -1022,8 +1464,8 @@ class MklFusedBatchNormOp : public OpKernel { mkl_shape_saved_variance); CHECK_NOTNULL(*saved_variance_tensor); // set NAN variance value in case of empty input tensor - for (int k = 0; k < tf_shape_scale.num_elements(); k++) - (*saved_variance_tensor)->flat<T>().data()[k] = NAN; + auto saved_variance_data = (*saved_variance_tensor)->flat<T>().data(); + std::fill_n(saved_variance_data, num_elements, NAN); } }; @@ -1044,12 +1486,12 @@ class MklFusedBatchNormGradOp : public OpKernel { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); const size_t kDiffDstIndex = 0; // index of diff_dst tensor const size_t kSrcIndex = 1; // index of src input tensor const size_t kScaleIndex = 2; // index of scale tensor const size_t kMeanIndex = 3; // index of saved_mean tensor const size_t kVarianceIndex = 4; // index of saved_variance tensor + const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex); const Tensor& src_tensor = MklGetInput(context, kSrcIndex); const Tensor& scale_tensor = MklGetInput(context, kScaleIndex); @@ -1060,8 +1502,8 @@ class MklFusedBatchNormGradOp : public OpKernel { MklDnnShape dnn_shape_src, dnn_shape_diff_dst; GetMklShape(context, kSrcIndex, &dnn_shape_src); GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst); - TensorShape tf_shape_src, tf_shape_diff_dst; + TensorShape tf_shape_src, tf_shape_diff_dst; if (dnn_shape_diff_dst.IsMklTensor()) { tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape(); OP_REQUIRES( @@ -1102,6 +1544,7 @@ class MklFusedBatchNormGradOp : public OpKernel { saved_variance_tensor.shape().DebugString())); Tensor* diff_src_tensor = nullptr; + // special case: input with 0 element and 0 batch size if (tf_shape_src.num_elements() == 0 || tf_shape_diff_dst.num_elements() == 0) { HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(), @@ -1117,189 +1560,127 @@ class MklFusedBatchNormGradOp : public OpKernel { ExtractParams(context); } - MklDnnData<T> src(&cpu_engine); - MklDnnData<T> mean(&cpu_engine); - MklDnnData<T> variance(&cpu_engine); - MklDnnData<T> diff_dst(&cpu_engine); - MklDnnData<T> diff_src(&cpu_engine); - - memory::dims src_dims, diff_dst_dims; - if (dnn_shape_src.IsMklTensor()) - src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(), - tensor_format_); - else - src_dims = - TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); - - if (dnn_shape_diff_dst.IsMklTensor()) - diff_dst_dims = TFShapeToMklDnnDimsInNCHW( - dnn_shape_diff_dst.GetTfShape(), tensor_format_); - else - diff_dst_dims = - TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_); - - // set src and diff_dst primitives according to input layout - memory::desc src_md({}, memory::data_undef, memory::format_undef); - memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef); + memory::format format_m; if (dnn_shape_src.IsMklTensor()) { - src_md = dnn_shape_src.GetMklLayout(); - } else { - src_md = memory::desc(src_dims, MklDnnType<T>(), - 
TFDataFormatToMklDnnDataFormat(tensor_format_)); - } - if (dnn_shape_diff_dst.IsMklTensor()) { - diff_dst_md = dnn_shape_diff_dst.GetMklLayout(); + if (dnn_shape_src.IsTensorInNCHWFormat()) + format_m = memory::format::nchw; + else + format_m = memory::format::nhwc; } else { - diff_dst_md = memory::desc(diff_dst_dims, MklDnnType<T>(), - TFDataFormatToMklDnnDataFormat(tensor_format_)); + format_m = TFDataFormatToMklDnnDataFormat(tensor_format_); } - src.SetUsrMem(src_md, &src_tensor); - diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); - - // weights -- DNN packs scales/shifts as weights in order of - // scale, ..., scale, shift, ..., shift - auto weights_desc = - memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc); - auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine); - auto weights_m = memory(weights_pd); - T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle()); - T* scale_tf = - reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data())); + + MklDnnData<T> src(&cpu_engine); + MklDnnData<T> diff_dst(&cpu_engine); + MklDnnData<T> weights(&cpu_engine); + MklDnnData<T> diff_weights(&cpu_engine); + + memory::dims src_dims = + dnn_shape_src.IsMklTensor() + ? dnn_shape_src.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); + memory::dims diff_dst_dims = + dnn_shape_diff_dst.IsMklTensor() + ? dnn_shape_diff_dst.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), + tensor_format_); + + // set src and diff_dst primitive descriptors + memory::desc src_md = + dnn_shape_src.IsMklTensor() + ? dnn_shape_src.GetMklLayout() + : memory::desc(src_dims, MklDnnType<T>(), format_m); + memory::desc diff_dst_md = + dnn_shape_diff_dst.IsMklTensor() + ? dnn_shape_diff_dst.GetMklLayout() + : memory::desc(diff_dst_dims, MklDnnType<T>(), format_m); + + // weights -- MKL DNN packs scales/ shifts as weights in order + // of scale, ..., scale, shift, ...., shift + weights.AllocateBuffer(2 * depth_ * sizeof(T)); + T* weights_data_tf = reinterpret_cast<T*>(weights.GetAllocatedBuffer()); + const T* scale_tf = scale_tensor.flat<T>().data(); for (int k = 0; k < depth_; k++) { - weights_data[k] = scale_tf[k]; - weights_data[k + depth_] = 0; + weights_data_tf[k] = scale_tf[k]; + weights_data_tf[k + depth_] = 0; } - // set mean primitive - memory::dims mv_dims = GetMeanVarianceDims(); - mean.SetUsrMem(mv_dims, memory::format::nc, - const_cast<void*>(static_cast<const void*>( - saved_mean_tensor.flat<T>().data()))); - mean.SetOpMemDesc(mv_dims, memory::format::nc); - - // set variance primitive - variance.SetUsrMem(mv_dims, memory::format::nc, - const_cast<void*>(static_cast<const void*>( - saved_variance_tensor.flat<T>().data()))); - variance.SetOpMemDesc(mv_dims, memory::format::nc); - - // set diff_weight primitive - auto diff_weights_desc = - memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc); - auto diff_weights_pd = - memory::primitive_desc(diff_weights_desc, cpu_engine); - auto diff_weights_m = memory(diff_weights_pd); - - auto bnrm_fwd_desc = batch_normalization_forward::desc( - prop_kind::forward_training, src.GetUsrMemDesc(), epsilon_, - is_training_ ? 
use_scale_shift - : (use_scale_shift | use_global_stats)); - auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc( - bnrm_fwd_desc, cpu_engine); + diff_weights.AllocateBuffer(2 * depth_ * sizeof(T)); + + MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_, + is_training_); + MklFusedBatchNormBwdPrimitive<T>* bn_bwd = + MklFusedBatchNormBwdPrimitiveFactory<T>::Get(bwdParams); + + // check if src/diff_dst need to be reordered + const T* src_data = src_tensor.flat<T>().data(); + if (src_md.data.format != bn_bwd->GetSrcFmt()) { + src.SetUsrMem(src_md, &src_tensor); + auto src_target = memory::primitive_desc( + {{src_dims}, + MklDnnType<T>(), + static_cast<memory::format>(bn_bwd->GetSrcFmt())}, + cpu_engine); + src.CheckReorderToOpMem(src_target); + src_data = const_cast<T*>( + reinterpret_cast<T*>(src.GetOpMem().get_data_handle())); + } + + const T* diff_dst_data = diff_dst_tensor.flat<T>().data(); + if (diff_dst_md.data.format != bn_bwd->GetDiffDstFmt()) { + diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); + auto diff_dst_target = memory::primitive_desc( + {{diff_dst_dims}, + MklDnnType<T>(), + static_cast<memory::format>(bn_bwd->GetDiffDstFmt())}, + cpu_engine); + diff_dst.CheckReorderToOpMem(diff_dst_target); + diff_dst_data = const_cast<T*>( + reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle())); + } // Indices of output tensors const size_t kDiffSrcIndex = 0; // index of diff_src tensor - // allocate diff_src tensor + // allocate output tensor: diff_src, always set as MKL-DNN layout MklDnnShape dnn_shape_diff_src; TensorShape tf_shape_diff_src; - - // MKL-DNN's BN primitive not provide API to fetch internal format - // set common_md as OpMem - // src and diff_dst will reorder to common_md - // diff_src will set as common_md - memory::desc common_md({}, memory::data_undef, memory::format_undef); - if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) { - if (dnn_shape_src.IsMklTensor()) { - common_md = dnn_shape_src.GetMklLayout(); - } else { - common_md = dnn_shape_diff_dst.GetMklLayout(); - } - } else { - common_md = memory::desc(src_dims, MklDnnType<T>(), - TFDataFormatToMklDnnDataFormat(tensor_format_)); - } - // if any of src and diff_dst as mkl layout, - // then we set diff_src as mkl layout - if (dnn_shape_src.IsMklTensor() || - dnn_shape_diff_dst.IsMklTensor()) { - dnn_shape_diff_src.SetMklTensor(true); - // set diff_src's mkl layout as common_md - auto diff_src_pd = memory::primitive_desc(common_md, cpu_engine); - dnn_shape_diff_src.SetMklLayout(&diff_src_pd); - dnn_shape_diff_src.SetElemType(MklDnnType<T>()); - if (dnn_shape_src.IsMklTensor()) { - dnn_shape_diff_src.SetTfLayout( - dnn_shape_src.GetDimension(), - src_dims, - dnn_shape_src.GetTfDataFormat()); - dnn_shape_diff_src.SetTfDimOrder( - dnn_shape_src.GetDimension(), - tensor_format_); - } else { - dnn_shape_diff_src.SetTfLayout( - dnn_shape_diff_dst.GetDimension(), - src_dims, - dnn_shape_diff_dst.GetTfDataFormat()); - dnn_shape_diff_src.SetTfDimOrder( - dnn_shape_diff_dst.GetDimension(), - tensor_format_); - } - tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); - } else { - dnn_shape_diff_src.SetMklTensor(false); - // both src and diff_dst are TensorFlow layout, - // so it is OK to get TensorFlow shape. 
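// Both the forward and backward primitives above choose their flags the same
// way: training keeps only use_scale_shift, while inference also sets
// use_global_stats so the supplied mean/variance are treated as constants and
// the statistics pass is skipped. Restated as a one-line helper (the name is
// illustrative, not part of the patch):
inline int64 BatchNormFlags(bool is_training) {
  return is_training ? use_scale_shift
                     : (use_scale_shift | use_global_stats);
}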
- tf_shape_diff_src = src_tensor.shape(); - } + dnn_shape_diff_src.SetMklTensor(true); + auto diff_src_pd = bn_bwd->GetDiffSrcPd(); + dnn_shape_diff_src.SetMklLayout(&diff_src_pd); + dnn_shape_diff_src.SetElemType(MklDnnType<T>()); + dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, format_m); + dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_); + tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor, tf_shape_diff_src, dnn_shape_diff_src); - // set diff_src - diff_src.SetUsrMem(common_md, diff_src_tensor); - - prop_kind pk = prop_kind::backward; - auto bnrm_bwd_desc = batch_normalization_backward::desc( - pk, common_md, common_md, epsilon_, - /* for inference, specify use_global_stats - 1. on fwd prop, use mean and variance - provided as inputs - 2. on bwd prop, mean and variance are - considered as constants. Thus, - reduce the amout of MKL computations - */ - is_training_ ? use_scale_shift - : (use_scale_shift | use_global_stats)); - auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc( - bnrm_bwd_desc, cpu_engine, bnrm_fwd_pd); - - std::vector<primitive> net; - src.CheckReorderToOpMem(memory::primitive_desc(common_md, - cpu_engine), &net); - diff_dst.CheckReorderToOpMem(memory::primitive_desc(common_md, - cpu_engine), &net); - - auto bnrm_bwd_op = batch_normalization_backward( - bnrm_bwd_pd, src.GetOpMem(), mean.GetOpMem(), variance.GetOpMem(), - diff_dst.GetOpMem(), weights_m, diff_src.GetOpMem(), diff_weights_m); - - net.push_back(bnrm_bwd_op); - stream(stream::kind::eager).submit(net).wait(); - - // allocate 4 output TF tensors + T* mean_data = + static_cast<T*>(const_cast<T*>(saved_mean_tensor.flat<T>().data())); + T* variance_data = static_cast<T*>( + const_cast<T*>(saved_variance_tensor.flat<T>().data())); + T* weights_data = weights_data_tf; + T* diff_src_data = static_cast<T*>(diff_src_tensor->flat<T>().data()); + T* diff_weights_data = static_cast<T*>(diff_weights.GetAllocatedBuffer()); + // Execute + bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data, + weights_data, diff_src_data, diff_weights_data); + + // allocate output TF tensors: diff_scale and diff_shift Tensor* diff_scale_tensor = nullptr; Tensor* diff_shift_tensor = nullptr; AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor, &diff_shift_tensor); // copy data: diff_scale and diff_shift - T* diff_weights_data_dnn = - reinterpret_cast<T*>(diff_weights_m.get_data_handle()); - for (int i = 0; i < depth_; i++) { - diff_scale_tensor->flat<T>().data()[i] = diff_weights_data_dnn[i]; - diff_shift_tensor->flat<T>().data()[i] = - diff_weights_data_dnn[i + depth_]; - } + auto diff_scale_data = diff_scale_tensor->flat<T>().data(); + auto diff_shift_data = diff_shift_tensor->flat<T>().data(); + std::memcpy(reinterpret_cast<char*>(diff_scale_data), + reinterpret_cast<char*>(diff_weights_data), + depth_ * sizeof(T)); + std::memcpy(reinterpret_cast<char*>(diff_shift_data), + reinterpret_cast<char*>(diff_weights_data + depth_), + depth_ * sizeof(T)); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -1315,6 +1696,7 @@ class MklFusedBatchNormGradOp : public OpKernel { TensorFormat tensor_format_; int depth_; // batch normalization is done for per channel. 
bool is_training_; + engine cpu_engine = engine(engine::cpu, 0); void ExtractParams(OpKernelContext* context) { const Tensor& input = MklGetInput(context, 0); @@ -1330,8 +1712,8 @@ class MklFusedBatchNormGradOp : public OpKernel { dnn_shape_diff_src.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor, tf_shape_src, dnn_shape_diff_src); - for (size_t i = 0; i < (*diff_src_tensor)->shape().num_elements(); i++) - (*diff_src_tensor)->flat<T>().data()[i] = 0; + auto diff_src_data = (*diff_src_tensor)->flat<T>().data(); + std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(), 0); Tensor* diff_scale_tensor = nullptr; Tensor* diff_shift_tensor = nullptr; @@ -1357,16 +1739,18 @@ class MklFusedBatchNormGradOp : public OpKernel { AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor, tf_shape_scale_shift, mkl_shape_diff_scale); CHECK_NOTNULL(*diff_scale_tensor); - for (size_t i = 0; i < (*diff_scale_tensor)->shape().num_elements(); i++) - (*diff_scale_tensor)->flat<T>().data()[i] = 0; + auto diff_scale_data = (*diff_scale_tensor)->flat<T>().data(); + std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(), + 0); MklDnnShape mkl_shape_diff_shift; mkl_shape_diff_shift.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor, tf_shape_scale_shift, mkl_shape_diff_shift); CHECK_NOTNULL(*diff_shift_tensor); - for (size_t i = 0; i < (*diff_shift_tensor)->shape().num_elements(); i++) - (*diff_shift_tensor)->flat<T>().data()[i] = 0; + auto diff_shift_data = (*diff_shift_tensor)->flat<T>().data(); + std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(), + 0); // Placeholders for estimated_mean and estimated_variance, which are // used for inference and thus not needed here for gradient computation. diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 62c0404891..fd261433a0 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -23,14 +23,20 @@ limitations under the License. // and when it is undefined at build time, this file becomes an empty // compilation unit -#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML) +#if defined(INTEL_MKL) -#include "mkl_cblas.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/fill_functor.h" +// This header file is part of MKL ML, need equivalent file in MKL DNN +#ifndef DO_NOT_USE_ML +#include "mkl_cblas.h" +#else +#include "mkldnn.h" +#endif + namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; @@ -100,7 +106,6 @@ class MklMatMulOp : public OpKernel { private: bool transpose_a_; bool transpose_b_; - // -------------------------------------------------------------------------- // // @brief Matrix-Matrix Multiplication with FP32 tensors, a, b, c using CBLAS @@ -150,11 +155,26 @@ class MklMatMulOp : public OpKernel { // 1.0 and 0.0 respectively. const float alpha = 1.0f; const float beta = 0.0f; +#if defined(DO_NOT_USE_ML) + const char* const ftrans[] = {"N", "T", "C"}; + int index_transa = transa ? 1 : 0; + int index_transb = transb ? 
1 : 0; + VLOG(2) << "MKL DNN SGEMM called"; + // MKL DNN only supports the Fortran api and requires column major while + // Tensorflow uses row major so we reverse the order A and B + mkldnn_sgemm(ftrans[index_transb], ftrans[index_transa], &n, &m, &k, &alpha, + b, &ldb, a, &lda, &beta, c, &ldc); +#else + // MKL ML binary uses CBLAS API cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +#endif } + // MKLDNN only supports SGEMM +#ifndef DO_NOT_USE_ML + // Matrix-Matrix Multiplication with FP64 tensors. For detailed info about // parameters, look at FP32 function description. void MklBlasGemm(bool transa, bool transb, const int m, const int n, @@ -197,6 +217,7 @@ class MklMatMulOp : public OpKernel { reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta, reinterpret_cast<MKL_Complex16*>(c), ldc); } +#endif }; #define REGISTER_CPU(T) \ @@ -207,9 +228,12 @@ class MklMatMulOp : public OpKernel { // TODO(inteltf) Consider template specialization when adding/removing // additional types TF_CALL_float(REGISTER_CPU); + +#ifndef DO_NOT_USE_ML TF_CALL_double(REGISTER_CPU); TF_CALL_complex64(REGISTER_CPU); TF_CALL_complex128(REGISTER_CPU); +#endif } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index ea537524b1..0a2151566e 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -119,6 +119,7 @@ class MklMaxPoolingOp : public OpKernel { mkl_out_shape); Tensor* workspace_tensor; + void* workspace_buf = nullptr; TensorShape workspace_shape; mkl_workspace_shape.SetMklTensor(false); @@ -510,7 +511,6 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { void Compute(OpKernelContext* context) override { try { - auto cpu_engine = engine(engine::cpu, 0); const Tensor& input_tensor = MklGetInput(context, this->kInputTensorIndexInput); MklDnnShape dnn_shape_input; @@ -525,8 +525,9 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { // initialize variables for the pooling op MklPoolParameters pool_params; // Get the input tensor and initialize the pooling parameters - this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params, - &dnn_data_input); + TensorShape input_tensor_shape = input_tensor.shape(); + this->InitMklPoolParameters(context, &pool_params, dnn_shape_input, + input_tensor_shape); OP_REQUIRES_OK(context, context->status()); // Declare output tensor @@ -534,44 +535,70 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { memory::dims output_dims_mkl_order; this->GetOutputDims(pool_params, &output_dims_mkl_order); - // If input is in Mkl layout, then just get the memory format from it - // directly, instead of using input data_format to MaxPool. 
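// The operand swap in the mkldnn_sgemm call above leans on (A*B)^T = B^T*A^T:
// a row-major C is byte-for-byte a column-major C^T, so asking a column-major
// GEMM for B*A with m and n exchanged writes the row-major product A*B
// directly into c. A standalone check with naive loops (GemmColMajor is a
// stand-in for illustration, not the MKL-DNN routine):
#include <cstdio>

// Naive column-major GEMM: C(MxN) = A(MxK) * B(KxN), no transposes.
void GemmColMajor(int M, int N, int K, const float* A, int lda,
                  const float* B, int ldb, float* C, int ldc) {
  for (int j = 0; j < N; ++j)
    for (int i = 0; i < M; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < K; ++p) acc += A[i + p * lda] * B[p + j * ldb];
      C[i + j * ldc] = acc;
    }
}

int main() {
  const int m = 2, k = 3, n = 2;
  const float a[m * k] = {1, 2, 3, 4, 5, 6};     // row-major m x k
  const float b[k * n] = {7, 8, 9, 10, 11, 12};  // row-major k x n
  float c[m * n] = {0};                          // receives row-major a*b
  // Same swap as the call above: pass (b, a) and (n, m, k).
  GemmColMajor(n, m, k, b, n, a, k, c, n);
  std::printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]);  // 58 64 / 139 154
  return 0;
}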
- if (dnn_shape_input.IsMklTensor()) { - dnn_data_output.SetUsrMem( - output_dims_mkl_order, - static_cast<memory::format>( - dnn_data_input.GetUsrMemDesc().data.format)); - } else { - dnn_data_output.SetUsrMem(output_dims_mkl_order, - this->data_format_mkldnn_); + // If input is an empty tensor, allocate an empty output tensor and return + if (input_tensor.NumElements() == 0) { + const int kOutputIndex = 0; + this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params, + output_dims_mkl_order, &output_tensor); + return; } - // describe the memory layout; let mkl-dnn choose the best for the op - dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any); - - auto pool_desc = pooling_forward::desc( - prop_kind::forward, algorithm::pooling_max, - dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(), - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast<int>(pool_params.pad_top), - static_cast<int>(pool_params.pad_left)}), - memory::dims({static_cast<int>(pool_params.pad_bottom), - static_cast<int>(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_fwd_desc = - pooling_forward::primitive_desc(pool_desc, cpu_engine); - - this->AllocateOutputTensor(context, pool_fwd_desc, output_dims_mkl_order, + // Get the input memory descriptor + memory::desc input_md = + dnn_shape_input.IsMklTensor() + ? dnn_shape_input.GetMklLayout() + : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape, + this->data_format_tf_), + MklDnnType<T>(), this->data_format_mkldnn_); + + // Get src/filter/stride/padding information + memory::dims src_dims = + dnn_shape_input.IsMklTensor() + ? dnn_shape_input.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(), + this->data_format_tf_); + + memory::dims filter_dims, strides, padding_left, padding_right; + this->PoolParamsToDims(&pool_params, &filter_dims, &strides, + &padding_left, &padding_right); + + // Get a pooling op from the cached pool + MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr; + MklPoolingParams fwdParams(src_dims, output_dims_mkl_order, filter_dims, + strides, padding_left, padding_right, + algorithm::pooling_max); + pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams); + + // allocate output tensor + this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()), + output_dims_mkl_order, this->data_format_mkldnn_, &output_tensor); OP_REQUIRES_OK(context, context->status()); - dnn_data_output.SetUsrMemDataHandle(output_tensor); + dnn_data_output.SetUsrMem(output_dims_mkl_order, + pooling_fwd->GetDstMemoryFormat(), + output_tensor); - AllocateWorkspaceTensor(context, pool_fwd_desc, &dnn_data_wksp); + AllocateWorkspaceTensor(context, *(pooling_fwd->GetPoolingFwdPd()), + &dnn_data_wksp); OP_REQUIRES_OK(context, context->status()); - this->PrepareAndExecuteNet(pool_fwd_desc, &dnn_data_input, - &dnn_data_output, &dnn_data_wksp); + // check wehther we need to reorder src + const T* src_data = input_tensor.flat<T>().data(); + if (input_md.data.format != pooling_fwd->GetSrcMemoryFormat()) { + dnn_data_input.SetUsrMem(input_md, &input_tensor); + auto src_target_primitive_desc = memory::primitive_desc( + {{src_dims}, MklDnnType<T>(), pooling_fwd->GetSrcMemoryFormat()}, + cpu_engine); + dnn_data_input.CheckReorderToOpMem(src_target_primitive_desc); + src_data = const_cast<T*>( + reinterpret_cast<T*>(dnn_data_input.GetOpMem().get_data_handle())); + } + + T* 
dst_data = output_tensor->flat<T>().data(); + void* ws_data = dnn_data_wksp.GetOpMem().get_data_handle(); + + // execute pooling op + pooling_fwd->Execute(src_data, dst_data, ws_data); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + @@ -579,10 +606,11 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", error_msg)); } - } // Compute + } private: const int kOutputTensorIndexWorkspace = 1; + engine cpu_engine = engine(engine::cpu, 0); void AllocateWorkspaceTensor( OpKernelContext* context, @@ -616,98 +644,105 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { public: explicit MklMaxPoolingGradOp(OpKernelConstruction* context) : MklPoolingBackwardOpBase<T>(context) {} - void Compute(OpKernelContext* context) override { try { auto cpu_engine = engine(engine::cpu, 0); const Tensor& orig_input_tensor = MklGetInput(context, kInputTensorIndexOrigInput); - const Tensor& orig_output_tensor = - MklGetInput(context, kInputTensorIndexOrigOutput); const Tensor& grad_tensor = MklGetInput(context, kInputTensorIndexGradient); const Tensor& workspace_tensor = MklGetInput(context, kInputTensorIndexWorkspace); - MklDnnShape orig_input_mkl_shape, orig_output_mkl_shape, grad_mkl_shape, - workspace_mkl_shape; + MklDnnShape orig_input_mkl_shape, grad_mkl_shape; GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape); - GetMklShape(context, kInputTensorIndexOrigOutput, &orig_output_mkl_shape); GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape); - GetMklShape(context, kInputTensorIndexWorkspace, &workspace_mkl_shape); - - SanityCheckInputs(context, orig_input_tensor, orig_output_tensor, - grad_tensor, workspace_tensor, orig_input_mkl_shape, - orig_output_mkl_shape, grad_mkl_shape, - workspace_mkl_shape); if (!context->status().ok()) return; MklDnnData<T> grad_dnn_data(&cpu_engine); MklDnnData<uint8> workspace_dnn_data(&cpu_engine); - MklDnnData<T> output_dnn_data(&cpu_engine); - Tensor* output_tensor = nullptr; + MklPoolParameters pool_params; - TensorShape orig_input_shape; - memory::dims output_dims_mkl_order, orig_input_dims_mkl_order; - memory::desc original_input_md = ConfigureOriginalInput( - context, orig_input_tensor, orig_input_mkl_shape, - &orig_input_dims_mkl_order, &pool_params, &orig_input_shape); - - memory::desc original_output_md = this->ConfigureOriginalOutput( - pool_params, orig_output_mkl_shape, output_dims_mkl_order); - - memory::desc target_diff_dst_md = this->ConfigureInputGradient( - grad_mkl_shape, grad_tensor, &grad_dnn_data, original_output_md); - - output_dnn_data.SetUsrMem(original_input_md); - - // Create the forward pooling primitive descriptor so we can - // pass it as a hint to the backward pooling primitive descriptor - auto pool_fwd_desc = pooling_forward::desc( - prop_kind::forward, algorithm::pooling_max, original_input_md, - original_output_md, - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast<int>(pool_params.pad_top), - static_cast<int>(pool_params.pad_left)}), - memory::dims({static_cast<int>(pool_params.pad_bottom), - static_cast<int>(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_fwd_prim_desc = - pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine); - - auto pool_bkwd_desc = pooling_backward::desc( - 
algorithm::pooling_max, output_dnn_data.GetUsrMemDesc(), - target_diff_dst_md, - memory::dims({pool_params.row_stride, pool_params.col_stride}), - memory::dims({pool_params.window_rows, pool_params.window_cols}), - memory::dims({static_cast<int>(pool_params.pad_top), - static_cast<int>(pool_params.pad_left)}), - memory::dims({static_cast<int>(pool_params.pad_bottom), - static_cast<int>(pool_params.pad_right)}), - TFPaddingToMklDnnPadding(this->padding_)); - auto pool_bkwd_prim_desc = pooling_backward::primitive_desc( - pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc); - - this->AllocateOutputTensor(context, pool_bkwd_prim_desc, + TensorShape orig_input_shape = orig_input_tensor.shape(); + this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape, + orig_input_shape); + + memory::dims filter_dims, strides, padding_left, padding_right; + this->PoolParamsToDims(&pool_params, &filter_dims, &strides, + &padding_left, &padding_right); + + memory::dims diff_dst_dims = + grad_mkl_shape.IsMklTensor() + ? grad_mkl_shape.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(), + this->data_format_tf_); + memory::dims orig_input_dims_mkl_order = + orig_input_mkl_shape.IsMklTensor() + ? orig_input_mkl_shape.GetSizesAsMklDnnDims() + : TFShapeToMklDnnDimsInNCHW(orig_input_shape, + this->data_format_tf_); + + memory::dims output_dims_mkl_order; + this->GetOutputDims(pool_params, &output_dims_mkl_order); + + MklPoolingParams bwdParams( + orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, + strides, padding_left, padding_right, algorithm::pooling_max); + MklPoolingBwdPrimitive<T>* pooling_bwd = + MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); + + // allocate output tensor and memory primitive + Tensor* output_tensor = nullptr; + this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()), orig_input_dims_mkl_order, this->data_format_mkldnn_, &output_tensor); - output_dnn_data.SetUsrMemDataHandle(output_tensor); - - ConfigureWorkspace(workspace_tensor, - pool_fwd_prim_desc.workspace_primitive_desc(), - &workspace_dnn_data); - this->PrepareAndExecuteNet( - pool_bkwd_prim_desc, &grad_dnn_data, &output_dnn_data, - memory::primitive_desc(target_diff_dst_md, cpu_engine), - &workspace_dnn_data); + // get diff_dst mem desc + memory::desc diff_dst_md = + grad_mkl_shape.IsMklTensor() + ? 
grad_mkl_shape.GetMklLayout() + : memory::desc(diff_dst_dims, MklDnnType<T>(), + this->data_format_mkldnn_); + // check if diff_dst needs to be reordered + const T* diff_dst_data = grad_tensor.flat<T>().data(); + if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) { + auto target_diff_dst = memory::primitive_desc( + {{diff_dst_dims}, MklDnnType<T>(), pooling_bwd->GetDiffDstFormat()}, + cpu_engine); + grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor); + grad_dnn_data.CheckReorderToOpMem(target_diff_dst); + diff_dst_data = const_cast<T*>( + reinterpret_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle())); + } + + void* ws_data = static_cast<void*>( + const_cast<uint8*>(workspace_tensor.flat<uint8>().data())); + ; + auto ws_md = + pooling_bwd->GetPoolingFwdPd()->workspace_primitive_desc().desc(); + if (ws_md.data.format != pooling_bwd->GetWorkspaceFormat()) { + memory::dims ws_dims; + ws_dims.assign(ws_md.data.dims, ws_md.data.dims + ws_md.data.ndims); + auto target_ws = + memory::primitive_desc({{ws_dims}, + pooling_bwd->GetWorkspaceDataType(), + pooling_bwd->GetWorkspaceFormat()}, + cpu_engine); + workspace_dnn_data.SetUsrMem(ws_md, &workspace_tensor); + workspace_dnn_data.CheckReorderToOpMem(target_ws); + ws_data = workspace_dnn_data.GetOpMem().get_data_handle(); + } + + T* diff_src_data = output_tensor->flat<T>().data(); + + // execute pooling + pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data); } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + + string error_msg = "Status:" + std::to_string(e.status) + + ", message: " + string(e.message) + ". in file " + string(__FILE__) + ":" + std::to_string(__LINE__); OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", error_msg)); } - } // Compute + } private: // .Input("orig_input: T") @@ -718,18 +753,6 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { const int kInputTensorIndexOrigOutput = 1; const int kInputTensorIndexGradient = 2; const int kInputTensorIndexWorkspace = 3; - // Output("output: T") in Base Class - - memory::desc ConfigureOriginalInput( - OpKernelContext* context, const Tensor& tensor_original_input, - const MklDnnShape& original_input_mkl_shape, - memory::dims* original_input_dims_mkl_order, - MklPoolParameters* pool_params, TensorShape* input_tensor_shape) { - *input_tensor_shape = tensor_original_input.shape(); - return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput( - context, tensor_original_input, original_input_mkl_shape, - original_input_dims_mkl_order, pool_params, *input_tensor_shape); - } void ConfigureWorkspace(const Tensor& workspace_tensor, memory::primitive_desc workspace_pd, diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc index 5ef6ce2a57..915878d9ea 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc @@ -24,6 +24,187 @@ limitations under the License. 
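The rewritten MklMaxPoolingGradOp::Compute above no longer builds mkldnn descriptors and primitives on every invocation: it describes the pooling geometry once in an MklPoolingParams, fetches a cached MklPoolingBwdPrimitive from the factory, reorders only those inputs whose layout differs from what the cached primitive expects, and hands raw data pointers to Execute. A condensed sketch of that sequence (parameter and shape setup, workspace handling, and error checks elided; names as in the hunk above):

    MklPoolingParams bwdParams(orig_input_dims_mkl_order, output_dims_mkl_order,
                               filter_dims, strides, padding_left, padding_right,
                               algorithm::pooling_max);
    MklPoolingBwdPrimitive<T>* pooling_bwd =
        MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);  // reused across steps

    // Reorder the incoming gradient only if its layout differs from the one
    // the cached primitive was created with.
    const T* diff_dst_data = grad_tensor.flat<T>().data();
    if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) {
      grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
      grad_dnn_data.CheckReorderToOpMem(target_diff_dst);
      diff_dst_data =
          static_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle());
    }
    pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);

Inside Execute the cached primitive only swaps data handles onto its pre-built mkldnn memory objects (and resets them to DummyData afterwards), which is what makes the caching safe; the Setup and Execute implementations follow in the hunk below.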
namespace tensorflow { +#ifndef INTEL_MKL_ML + +using mkldnn::pooling_avg; +using mkldnn::pooling_avg_exclude_padding; +using mkldnn::pooling_avg_include_padding; +using mkldnn::pooling_max; +using mkldnn::prop_kind; + +template <typename T> +void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) { + if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg && + fwdParams.alg_kind != pooling_avg_include_padding && + fwdParams.alg_kind != pooling_avg_exclude_padding) { + assert("Pooling algorithm kind is not supported\n"); + } + + context_.alg_kind = fwdParams.alg_kind; + // create memory desc + // FIXME: Pooling doesn't expose to get the src_primitive_desc, + // so src format is currently hard-coded. + // A utility function is used to do this, + // which may be broken with future CPU architectures + context_.src_md.reset( + new memory::desc({fwdParams.src_dims}, MklDnnType<T>(), + get_desired_format(fwdParams.src_dims[1]))); + context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(), + memory::format::any)); + + // create a pooling descriptor + context_.fwd_desc.reset(new pooling_forward::desc( + prop_kind::forward_training, fwdParams.alg_kind, *context_.src_md, + *context_.dst_md, fwdParams.strides, fwdParams.filter_dims, + fwdParams.padding_left, fwdParams.padding_right, padding_kind::zero)); + context_.fwd_pd.reset( + new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_)); + + // store expected primitive format + context_.src_fmt = get_desired_format(fwdParams.src_dims[1]); + context_.dst_fmt = static_cast<mkldnn::memory::format>( + context_.fwd_pd.get()->dst_primitive_desc().desc().data.format); + + // create MKL-DNN internal memory object with dummy data + context_.src_mem.reset(new memory( + {{{fwdParams.src_dims}, MklDnnType<T>(), context_.src_fmt}, cpu_engine_}, + DummyData)); + context_.dst_mem.reset( + new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); + + // for max pooling, need to return workspace(ws) for backward computing + if (fwdParams.alg_kind == pooling_max) { + auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data; + // store workspace's dims and format to create workspace tensor + context_.ws_fmt = static_cast<mkldnn::memory::format>(ws_pd.format); + context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims); + context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type); + context_.ws_size = + context_.fwd_pd.get()->workspace_primitive_desc().get_size(); + context_.ws_mem.reset(new memory( + context_.fwd_pd.get()->workspace_primitive_desc(), DummyData)); + context_.fwd.reset(new pooling_forward(*context_.fwd_pd, *context_.src_mem, + *context_.dst_mem, + *context_.ws_mem)); + } else { + context_.fwd.reset(new pooling_forward(*context_.fwd_pd, *context_.src_mem, + *context_.dst_mem)); + } + + context_.fwd_primitives.push_back(*context_.fwd); +} + +template <typename T> +void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, + void* ws_data) { + context_.src_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(src_data))); + context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); + if (context_.alg_kind == pooling_max) { // max pooling must have ws + assert(ws_data != nullptr); + context_.ws_mem->set_data_handle(ws_data); + } + context_.fwd_stream->submit(context_.fwd_primitives); + + // set back data handle + context_.src_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); + if 
(context_.alg_kind == pooling_max) { // max pooling must have ws + assert(ws_data != nullptr); + context_.ws_mem->set_data_handle(DummyData); + } +} + +template class MklPoolingFwdPrimitive<float>; + +template <typename T> +void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) { + if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg && + bwdParams.alg_kind != pooling_avg_include_padding && + bwdParams.alg_kind != pooling_avg_exclude_padding) { + assert("Pooling algorithm kind is not supported\n"); + } + context_.alg_kind = bwdParams.alg_kind; + + // Create memory desc + context_.diff_src_md.reset(new memory::desc( + {bwdParams.src_dims}, MklDnnType<T>(), memory::format::any)); + context_.diff_dst_md.reset( + new memory::desc({bwdParams.dst_dims}, MklDnnType<T>(), + get_desired_format(bwdParams.dst_dims[1]))); + context_.bwd_desc.reset(new pooling_backward::desc( + bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md, + bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left, + bwdParams.padding_right, padding_kind::zero)); + + // create a forward primitive, + // which will be used as a hint for creating backward primitive + context_.fwd_desc.reset(new pooling_forward::desc( + prop_kind::forward_training, bwdParams.alg_kind, *context_.diff_src_md, + *context_.diff_dst_md, bwdParams.strides, bwdParams.filter_dims, + bwdParams.padding_left, bwdParams.padding_right, padding_kind::zero)); + context_.fwd_pd.reset( + new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine)); + context_.bwd_pd.reset(new pooling_backward::primitive_desc( + *context_.bwd_desc, cpu_engine, *context_.fwd_pd)); + + // store expected primitive format + context_.diff_src_fmt = static_cast<mkldnn::memory::format>( + context_.bwd_pd.get()->diff_src_primitive_desc().desc().data.format); + context_.diff_dst_fmt = get_desired_format(bwdParams.dst_dims[1]); + + // create MKL-DNN internal memory object with dummy data + context_.diff_src_mem.reset( + new memory(context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData)); + context_.diff_dst_mem.reset(new memory( + {{{bwdParams.dst_dims}, MklDnnType<T>(), context_.diff_dst_fmt}, + cpu_engine}, + DummyData)); + + // for max pooling, need to return workspace for backward + if (bwdParams.alg_kind == pooling_max) { + auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data; + context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims); + context_.ws_fmt = get_desired_format(context_.ws_dims[1]); + context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type); + context_.ws_mem.reset(new memory( + {{{context_.ws_dims}, context_.ws_dt, context_.ws_fmt}, cpu_engine}, + DummyData)); + context_.bwd.reset( + new pooling_backward(*context_.bwd_pd, *context_.diff_dst_mem, + *context_.ws_mem, *context_.diff_src_mem)); + } else { + context_.bwd.reset(new pooling_backward( + *context_.bwd_pd, *context_.diff_dst_mem, *context_.diff_src_mem)); + } + context_.bwd_primitives.push_back(*context_.bwd); +} + +template <typename T> +void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, + T* diff_src_data, const void* ws_data) { + context_.diff_dst_mem->set_data_handle( + static_cast<void*>(const_cast<T*>(diff_dst_data))); + context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); + if (context_.alg_kind == pooling_max) { + assert(ws_data != nullptr); + context_.ws_mem->set_data_handle(const_cast<void*>(ws_data)); + } + + 
context_.bwd_stream->submit(context_.bwd_primitives); + // set back data handle + context_.diff_dst_mem->set_data_handle(DummyData); + context_.diff_src_mem->set_data_handle(DummyData); + if (context_.alg_kind == pooling_max) { + assert(ws_data != nullptr); + context_.ws_mem->set_data_handle(DummyData); + } +} + +template class MklPoolingBwdPrimitive<float>; + +#endif + // Initialization for TensorFlow format void MklPoolParameters::Init(OpKernelContext* context, const std::vector<int32>& ksize, diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h index c0dfed7d7d..9c516afbd0 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_ #ifdef INTEL_MKL -#include <string> +#include <memory> #include <vector> #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/padding.h" @@ -32,6 +32,326 @@ using mkldnn::stream; namespace tensorflow { +#ifndef INTEL_MKL_ML + +using mkldnn::memory; +using mkldnn::pooling_avg; +using mkldnn::pooling_avg_exclude_padding; +using mkldnn::pooling_avg_include_padding; +using mkldnn::pooling_max; +using mkldnn::prop_kind; + +struct MklPoolingParams { + memory::dims src_dims; + memory::dims dst_dims; + memory::dims filter_dims; + memory::dims strides; + memory::dims padding_left; + memory::dims padding_right; + mkldnn::algorithm alg_kind; + + MklPoolingParams(memory::dims src_dims, memory::dims dst_dims, + memory::dims filter_dims, memory::dims strides, + memory::dims padding_left, memory::dims padding_right, + mkldnn::algorithm alg_kind) + : src_dims(src_dims), + dst_dims(dst_dims), + filter_dims(filter_dims), + strides(strides), + padding_left(padding_left), + padding_right(padding_right), + alg_kind(alg_kind) {} +}; + +template <typename T> +class MklPoolingFwdPrimitive : public MklPrimitive { + public: + explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams) + : cpu_engine_(engine::cpu, 0) { + context_.fwd_stream.reset(new stream(stream::kind::eager)); + if (context_.fwd == nullptr) Setup(fwdParams); + } + + ~MklPoolingFwdPrimitive() {} + + // Pooling forward execute + // src_data: input data buffer of src + // ws_data: output data buffer of workspace + // dst_data: output data buffer of dst + void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr); + + std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd() + const { + return context_.fwd_pd; + } + + memory::format GetSrcMemoryFormat() const { return context_.src_fmt; } + + memory::format GetDstMemoryFormat() const { return context_.dst_fmt; } + + private: + void Setup(const MklPoolingParams& fwdParams); + + struct PoolingFwdContext { + // algorithm + mkldnn::algorithm alg_kind; + + // expected memory format + memory::format src_fmt; + memory::format dst_fmt; + memory::format ws_fmt; + + // workspace shape + memory::dims ws_dims; + memory::data_type ws_dt; + size_t ws_size; + + // MKL-DNN memory, just dummy data + std::shared_ptr<mkldnn::memory> ws_mem; + std::shared_ptr<mkldnn::memory> src_mem; + std::shared_ptr<mkldnn::memory> dst_mem; + + // desc & primitive desc + std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc; + std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd; + + // memory desc + std::shared_ptr<mkldnn::memory::desc> src_md; + std::shared_ptr<mkldnn::memory::desc> dst_md; + + // Pooling primitive 
+ std::shared_ptr<mkldnn::pooling_forward> fwd; + std::shared_ptr<mkldnn::stream> fwd_stream; + std::vector<mkldnn::primitive> fwd_primitives; + + PoolingFwdContext() + : src_fmt(memory::format::any), + dst_fmt(memory::format::any), + ws_fmt(memory::format::any), + ws_mem(nullptr), + src_mem(nullptr), + dst_mem(nullptr), + fwd_desc(nullptr), + fwd_pd(nullptr), + src_md(nullptr), + dst_md(nullptr), + fwd(nullptr), + fwd_stream(nullptr) {} + }; + + struct PoolingFwdContext context_; + engine cpu_engine_; +}; + +template <typename T> +class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory<T> { + public: + static MklPoolingFwdPrimitive<T>* Get(const MklPoolingParams& fwdParams) { + MklPoolingFwdPrimitive<T>* pooling_forward = nullptr; + + // Get pooling primitive from the pool + pooling_forward = static_cast<MklPoolingFwdPrimitive<T>*>( + MklPoolingFwdPrimitiveFactory<T>::GetInstance().GetPoolingFwd( + fwdParams)); + + if (pooling_forward == nullptr) { + pooling_forward = new MklPoolingFwdPrimitive<T>(fwdParams); + MklPoolingFwdPrimitiveFactory<T>::GetInstance().SetPoolingFwd( + fwdParams, pooling_forward); + } + return pooling_forward; + } + + static MklPoolingFwdPrimitiveFactory& GetInstance() { + static MklPoolingFwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklPoolingFwdPrimitiveFactory() {} + ~MklPoolingFwdPrimitiveFactory() {} + + // The key to be created will be used to get/set pooling + // primitive op from reuse perspective. + // A pooling key is a string which concates key parameters + // as well as algorithm kind (max versus avg). + static std::string CreateKey(const MklPoolingParams& fwdParams) { + std::string prefix = "pooling_fwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(fwdParams.src_dims); + key_creator.AddAsKey(fwdParams.dst_dims); + key_creator.AddAsKey(fwdParams.filter_dims); + key_creator.AddAsKey(fwdParams.strides); + key_creator.AddAsKey(fwdParams.padding_left); + key_creator.AddAsKey(fwdParams.padding_right); + key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind)); + return key_creator.GetKey(); + } + + MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) { + std::string key = CreateKey(fwdParams); + return this->GetOp(key); + } + + void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive* op) { + std::string key = CreateKey(fwdParams); + this->SetOp(key, op); + } +}; + +template <typename T> +class MklPoolingBwdPrimitive : public MklPrimitive { + public: + explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams) + : cpu_engine(engine::cpu, 0) { + context_.bwd_stream.reset(new stream(stream::kind::eager)); + if (context_.bwd == nullptr) Setup(bwdParams); + } + + ~MklPoolingBwdPrimitive() {} + + // Pooling backward execute + // diff_dst_data: input data buffer of diff_dst + // diff_src_data: output data buffer of diff_src + // ws_data: input data buffer of workspace + void Execute(const T* diff_dst_data, T* diff_src_data, + const void* ws_data = nullptr); + + public: + std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd() + const { + return context_.fwd_pd; + } + std::shared_ptr<mkldnn::pooling_backward::primitive_desc> GetPoolingBwdPd() + const { + return context_.bwd_pd; + } + + memory::format GetDiffDstFormat() const { return context_.diff_dst_fmt; } + + mkldnn::memory::data_type GetWorkspaceDataType() const { + return context_.ws_dt; + } + memory::format GetWorkspaceFormat() const { return context_.ws_fmt; } + + private: + void 
Setup(const MklPoolingParams& bwdParams); + + // Primitive reuse context for pooling bwd ops + struct PoolingBwdContext { + // algorithm + mkldnn::algorithm alg_kind; + + // expected memory format + mkldnn::memory::format diff_src_fmt; + mkldnn::memory::format diff_dst_fmt; + mkldnn::memory::format ws_fmt; + + // workspace attribute + mkldnn::memory::dims ws_dims; + mkldnn::memory::data_type ws_dt; + + // MKL-DNN memory + std::shared_ptr<mkldnn::memory> ws_mem; + std::shared_ptr<mkldnn::memory> diff_src_mem; + std::shared_ptr<mkldnn::memory> diff_dst_mem; + + // memory desc + std::shared_ptr<mkldnn::memory::desc> diff_src_md; + std::shared_ptr<mkldnn::memory::desc> diff_dst_md; + + // desc & primitive desc + std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc; + std::shared_ptr<mkldnn::pooling_backward::desc> bwd_desc; + std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd; + std::shared_ptr<mkldnn::pooling_backward::primitive_desc> bwd_pd; + + // pooling primitive + std::shared_ptr<mkldnn::pooling_backward> bwd; + std::shared_ptr<mkldnn::stream> bwd_stream; + + std::vector<mkldnn::primitive> bwd_primitives; + + PoolingBwdContext() + : diff_src_fmt(memory::format::any), + diff_dst_fmt(memory::format::any), + ws_fmt(memory::format::any), + ws_mem(nullptr), + diff_src_mem(nullptr), + diff_dst_mem(nullptr), + diff_src_md(nullptr), + diff_dst_md(nullptr), + fwd_desc(nullptr), + bwd_desc(nullptr), + fwd_pd(nullptr), + bwd_pd(nullptr), + bwd(nullptr), + bwd_stream(nullptr) {} + }; + + struct PoolingBwdContext context_; + engine cpu_engine; +}; + +template <typename T> +class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory<T> { + public: + static MklPoolingBwdPrimitive<T>* Get(const MklPoolingParams& bwdParams) { + MklPoolingBwdPrimitive<T>* pooling_backward = nullptr; + + // Find a pooling backward primitive from the pool + // If it does not exist, create a new one + pooling_backward = static_cast<MklPoolingBwdPrimitive<T>*>( + MklPoolingBwdPrimitiveFactory<T>::GetInstance().GetPoolingBwd( + bwdParams)); + if (pooling_backward == nullptr) { + pooling_backward = new MklPoolingBwdPrimitive<T>(bwdParams); + MklPoolingBwdPrimitiveFactory<T>::GetInstance().SetPoolingBwd( + bwdParams, pooling_backward); + } + return pooling_backward; + } + + static MklPoolingBwdPrimitiveFactory& GetInstance() { + static MklPoolingBwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklPoolingBwdPrimitiveFactory() {} + ~MklPoolingBwdPrimitiveFactory() {} + + // The key to be created will be used to get/set pooling + // primitive op from reuse perspective. + // A pooling key is a string which concates key parameters + // as well as algorithm kind (max versus avg). 
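Both factories rely on the same memoization scheme: MklPrimitiveFactory keeps a string-keyed registry, and the key concatenates every parameter that influences the generated primitive (source and destination dims, filter, strides, paddings, and the algorithm kind), so pooling ops with identical geometry share one primitive while any difference produces a new entry. A minimal sketch of how such a key is assembled (the dims are made-up illustration values; FactoryKeyCreator as used in this header):

    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(std::string("pooling_bwd"));
    key_creator.AddAsKey(memory::dims({1, 16, 28, 28}));  // hypothetical src_dims
    key_creator.AddAsKey(memory::dims({1, 16, 14, 14}));  // hypothetical dst_dims
    key_creator.AddAsKey<int>(static_cast<int>(algorithm::pooling_max));
    std::string key = key_creator.GetKey();  // looked up via GetOp(), stored via SetOp()

The CreateKey method that follows builds exactly this kind of key for the backward factory, mirroring the forward factory above.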
+ static std::string CreateKey(const MklPoolingParams& bwdParams) { + std::string prefix = "pooling_bwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(bwdParams.src_dims); + key_creator.AddAsKey(bwdParams.dst_dims); + key_creator.AddAsKey(bwdParams.filter_dims); + key_creator.AddAsKey(bwdParams.strides); + key_creator.AddAsKey(bwdParams.padding_left); + key_creator.AddAsKey(bwdParams.padding_right); + key_creator.AddAsKey<int>(static_cast<int>(bwdParams.alg_kind)); + return key_creator.GetKey(); + } + + MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) { + std::string key = CreateKey(bwdParams); + return this->GetOp(key); + } + + void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive* op) { + std::string key = CreateKey(bwdParams); + this->SetOp(key, op); + } +}; +#endif + typedef Eigen::ThreadPoolDevice CPUDevice; struct MklPoolParameters { @@ -163,6 +483,41 @@ class MklPoolingOpBase : public OpKernel { } } + void PoolParamsToDims(const MklPoolParameters* pool_params, + memory::dims* filter_dims, memory::dims* strides, + memory::dims* padding_left, + memory::dims* padding_right) { + *filter_dims = {pool_params->window_rows, pool_params->window_cols}; + *strides = {pool_params->row_stride, pool_params->col_stride}; + *padding_left = {static_cast<int>(pool_params->pad_top), + static_cast<int>(pool_params->pad_left)}; + *padding_right = {static_cast<int>(pool_params->pad_bottom), + static_cast<int>(pool_params->pad_right)}; + } + + void AllocateEmptyOutputTensor(OpKernelContext* context, + const int kOutputIndex, + MklPoolParameters* pool_params, + const memory::dims output_dims_mkl_order, + Tensor** output_tensor) { + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(false); + TensorShape output_tf_shape; + if (pool_params->data_format == TensorFormat::FORMAT_NCHW) { + output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order); + } else { + memory::dims output_dims_NHWC_order; + output_dims_NHWC_order = {pool_params->tensor_in_batch, + static_cast<int>(pool_params->out_height), + static_cast<int>(pool_params->out_width), + pool_params->out_depth}; + output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order); + } + AllocateOutputSetMklShape(context, kOutputIndex, output_tensor, + output_tf_shape, output_mkl_shape); + CHECK_NOTNULL(output_tensor); + } + // Checks to make sure that the memory we need to allocate // is a multiple of sizeof(T) // returns the number of elements @@ -235,23 +590,6 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> { CHECK_NOTNULL(*output_tensor); } - void PrepareAndExecuteNet( - const pooling_forward::primitive_desc& pool_fwd_desc, - const MklDnnData<T>* src, MklDnnData<T>* dst, - MklDnnData<uint8>* wksp = nullptr) { - std::vector<primitive> net; - - // Create pooling primitive and add it to net - if (wksp != nullptr) { - net.push_back(pooling_forward(pool_fwd_desc, src->GetOpMem(), - dst->GetOpMem(), wksp->GetOpMem())); - } else { - net.push_back( - pooling_forward(pool_fwd_desc, src->GetOpMem(), dst->GetOpMem())); - } - stream(stream::kind::eager).submit(net).wait(); - } - void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor, const MklDnnShape& input_mkl_shape) { if (!input_mkl_shape.IsMklTensor()) { @@ -301,67 +639,6 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> { CHECK_NOTNULL(*output_tensor); } - void PrepareAndExecuteNet( - const pooling_backward::primitive_desc& pool_bkwd_desc, - MklDnnData<T>* input_gradient_diff_dst, 
MklDnnData<T>* output_diff_src, - const memory::primitive_desc& target_diff_dst_pd, - const MklDnnData<uint8>* workspace = nullptr) { - std::vector<primitive> net; - - // If the input gradient isn't in the same format as the output - // reorder it to the same format as the output - input_gradient_diff_dst->CheckReorderToOpMem(target_diff_dst_pd, &net); - - // Create pooling primitive and add it to net - if (nullptr == workspace) { - net.push_back(pooling_backward(pool_bkwd_desc, - input_gradient_diff_dst->GetOpMem(), - output_diff_src->GetOpMem())); - } else { - net.push_back( - pooling_backward(pool_bkwd_desc, input_gradient_diff_dst->GetOpMem(), - workspace->GetOpMem(), output_diff_src->GetOpMem())); - } - stream(stream::kind::eager).submit(net).wait(); - } - - // Max Pooling and Avg Pooling have slightly different implementations - // Takes the Tensor containing original input data and the original - // mkl Dnn Shape and populates other data - memory::desc ConfigureOriginalInput( - OpKernelContext* context, const Tensor& tensor_original_input_shape, - const MklDnnShape& original_input_mkl_shape, - memory::dims* original_input_dims_nchw, MklPoolParameters* pool_params, - const TensorShape& input_tensor_shape) { - CHECK_NOTNULL(original_input_dims_nchw); - CHECK_NOTNULL(pool_params); - this->InitMklPoolParameters(context, pool_params, original_input_mkl_shape, - input_tensor_shape); - - *original_input_dims_nchw = - original_input_mkl_shape.IsMklTensor() - ? original_input_mkl_shape.GetSizesAsMklDnnDims() - : TFShapeToMklDnnDimsInNCHW(input_tensor_shape, - this->data_format_tf_); - - return original_input_mkl_shape.IsMklTensor() - ? original_input_mkl_shape.GetMklLayout() - : memory::desc(*original_input_dims_nchw, MklDnnType<T>(), - this->data_format_mkldnn_); - } - - memory::desc ConfigureOriginalOutput( - const MklPoolParameters& pool_params, - const MklDnnShape& original_output_mkl_shape, - memory::dims output_dims_mkl_order) { - this->GetOutputDims(pool_params, &output_dims_mkl_order); - - return original_output_mkl_shape.IsMklTensor() - ? original_output_mkl_shape.GetMklLayout() - : memory::desc(output_dims_mkl_order, MklDnnType<T>(), - this->data_format_mkldnn_); - } - memory::desc ConfigureInputGradient( const MklDnnShape& input_gradient_mkl_shape, const Tensor& input_gradient_tensor, diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc index 02ea9fc068..9c536df215 100644 --- a/tensorflow/core/kernels/mkl_reshape_op.cc +++ b/tensorflow/core/kernels/mkl_reshape_op.cc @@ -152,8 +152,12 @@ class MklReshapeOp : public OpKernel { // If Tensorflow's data format and the underlying format maintained by // MKLDNN are equivalent (both are NHWC or both are NCHW), then we can // safely return true. + // @todo: Future do not force skip reorder for all blocked format. 
Use + // blocking_desc_is_equal() for checking all the stride arrays in + // mkl-dnn/blob/master/src/common/type_helpers.hpp auto input_mkl_md = mkl_shape_input.GetMklLayout(); - if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format) { + if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format && + mkl_shape_input.GetTfDataFormat() != memory::format::blocked) { ret = true; } diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index f59843a07a..c7d0d4de0d 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -121,10 +121,11 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn( std::placeholders::_1, std::placeholders::_2, threshold); } -void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores, - int num_boxes, const Tensor& max_output_size, - const float score_threshold, - std::function<bool(int, int)> suppress_check_fn) { +void DoNonMaxSuppressionOp( + OpKernelContext* context, const Tensor& scores, int num_boxes, + const Tensor& max_output_size, const float score_threshold, + const std::function<bool(int, int)>& suppress_check_fn, + bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) { const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes); std::vector<float> scores_data(num_boxes); @@ -172,6 +173,15 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores, } } + int num_valid_outputs = selected.size(); + if (pad_to_max_output_size) { + selected.resize(output_size, 0); + selected_scores.resize(output_size, 0); + } + if (ptr_num_valid_outputs) { + *ptr_num_valid_outputs = num_valid_outputs; + } + // Allocate output tensors Tensor* output_indices = nullptr; TensorShape output_shape({static_cast<int>(selected.size())}); @@ -262,54 +272,106 @@ class NonMaxSuppressionV2Op : public OpKernel { } }; -template <typename Device> -class NonMaxSuppressionV3Op : public OpKernel { +class NonMaxSuppressionV3V4Base : public OpKernel { public: - explicit NonMaxSuppressionV3Op(OpKernelConstruction* context) + explicit NonMaxSuppressionV3V4Base(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { // boxes: [num_boxes, 4] - const Tensor& boxes = context->input(0); + boxes_ = context->input(0); // scores: [num_boxes] - const Tensor& scores = context->input(1); + scores_ = context->input(1); // max_output_size: scalar - const Tensor& max_output_size = context->input(2); + max_output_size_ = context->input(2); OP_REQUIRES( - context, TensorShapeUtils::IsScalar(max_output_size.shape()), + context, TensorShapeUtils::IsScalar(max_output_size_.shape()), errors::InvalidArgument("max_output_size must be 0-D, got shape ", - max_output_size.shape().DebugString())); + max_output_size_.shape().DebugString())); // iou_threshold: scalar const Tensor& iou_threshold = context->input(3); OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()), errors::InvalidArgument("iou_threshold must be 0-D, got shape ", iou_threshold.shape().DebugString())); - const float iou_threshold_val = iou_threshold.scalar<float>()(); - + iou_threshold_val_ = iou_threshold.scalar<float>()(); + OP_REQUIRES(context, iou_threshold_val_ >= 0 && iou_threshold_val_ <= 1, + errors::InvalidArgument("iou_threshold must be in [0, 1]")); // score_threshold: scalar const Tensor& score_threshold = context->input(4); OP_REQUIRES( context, 
TensorShapeUtils::IsScalar(score_threshold.shape()), errors::InvalidArgument("score_threshold must be 0-D, got shape ", score_threshold.shape().DebugString())); - const float score_threshold_val = score_threshold.scalar<float>()(); + score_threshold_val_ = score_threshold.scalar<float>()(); - OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1, - errors::InvalidArgument("iou_threshold must be in [0, 1]")); - int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, &num_boxes); - CheckScoreSizes(context, num_boxes, scores); + num_boxes_ = 0; + ParseAndCheckBoxSizes(context, boxes_, &num_boxes_); + CheckScoreSizes(context, num_boxes_, scores_); if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoComputeAndPostProcess(context); + } + + protected: + virtual void DoComputeAndPostProcess(OpKernelContext* context) = 0; + + Tensor boxes_; + Tensor scores_; + Tensor max_output_size_; + int num_boxes_; + float iou_threshold_val_; + float score_threshold_val_; +}; + +template <typename Device> +class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { + public: + explicit NonMaxSuppressionV3Op(OpKernelConstruction* context) + : NonMaxSuppressionV3V4Base(context) {} + + protected: + void DoComputeAndPostProcess(OpKernelContext* context) override { + auto suppress_check_fn = + CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + + DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn); } }; template <typename Device> +class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { + public: + explicit NonMaxSuppressionV4Op(OpKernelConstruction* context) + : NonMaxSuppressionV3V4Base(context) { + OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size", + &pad_to_max_output_size_)); + } + + protected: + void DoComputeAndPostProcess(OpKernelContext* context) override { + auto suppress_check_fn = + CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + int num_valid_outputs; + + DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn, + pad_to_max_output_size_, &num_valid_outputs); + + // Allocate scalar output tensor for number of indices computed. 
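NonMaxSuppressionV4 shares all of its input validation with V3 through the new NonMaxSuppressionV3V4Base; the behavioural addition is the optional padding. When pad_to_max_output_size is set, DoNonMaxSuppressionOp zero-pads the selected indices out to max_output_size and reports how many leading entries are genuine selections. A small illustration of that contract (values taken from the PadFive unit test further down):

    // Three boxes survive suppression, but the caller asked for 5 outputs.
    std::vector<int> selected = {3, 0, 5};
    int num_valid_outputs = selected.size();  // 3, surfaced as the op's second output
    selected.resize(/*output_size=*/5, 0);    // first output becomes {3, 0, 5, 0, 0}

The scalar allocation that follows materialises that count as output 1 of the kernel.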
+ Tensor* num_outputs_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 1, tensorflow::TensorShape{}, &num_outputs_t)); + num_outputs_t->scalar<int32>().setConstant(num_valid_outputs); + } + + private: + bool pad_to_max_output_size_; +}; + +template <typename Device> class NonMaxSuppressionWithOverlapsOp : public OpKernel { public: explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context) @@ -365,6 +427,9 @@ REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU), NonMaxSuppressionV3Op<CPUDevice>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU), + NonMaxSuppressionV4Op<CPUDevice>); + REGISTER_KERNEL_BUILDER( Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU), NonMaxSuppressionWithOverlapsOp<CPUDevice>); diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc index 055161a35f..c321849f40 100644 --- a/tensorflow/core/kernels/non_max_suppression_op_test.cc +++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc @@ -570,6 +570,61 @@ TEST_F(NonMaxSuppressionV3OpTest, TestEmptyInput) { } // +// NonMaxSuppressionV4Op Tests +// + +class NonMaxSuppressionV4OpTest : public OpsTestBase { + protected: + void MakeOp() { + TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV4") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("pad_to_max_output_size", true) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(NonMaxSuppressionV4OpTest, TestSelectFromThreeClustersPadFive) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {5}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + AddInputFromArray<float>(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + const auto expected_indices = test::AsTensor<int>({3, 0, 5, 0, 0}); + test::ExpectTensorEqual<int>(expected_indices, *GetOutput(0)); + Tensor expected_num_valid = test::AsScalar<int>(3); + test::ExpectTensorEqual<int>(expected_num_valid, *GetOutput(1)); +} + +TEST_F(NonMaxSuppressionV4OpTest, TestSelectFromThreeClustersPadFiveScoreThr) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {6}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + AddInputFromArray<float>(TensorShape({}), {0.4f}); + TF_ASSERT_OK(RunOpKernel()); + + const auto expected_indices = test::AsTensor<int>({3, 0, 0, 0, 0, 0}); + test::ExpectTensorEqual<int>(expected_indices, *GetOutput(0)); + Tensor expected_num_valid = test::AsScalar<int>(2); + test::ExpectTensorEqual<int>(expected_num_valid, *GetOutput(1)); +} + +// // NonMaxSuppressionWithOverlapsOp Tests // diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index a7a9609c21..33ed044dae 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ 
-114,8 +114,16 @@ class PartitionedCallOp : public AsyncOpKernel { // The FunctionLibraryRuntime's library cannot be mutated from within // an OpKernel, so functions are instantiated in an overlay library. - overlay_lib_.reset(new FunctionLibraryDefinition( - *lib->GetFunctionLibraryDefinition())); + OP_REQUIRES_ASYNC( + ctx, overlay_libs_.find(lib) == overlay_libs_.end(), + errors::Internal("Found an overlay library but did not " + "find cached function partitions; " + "this indicates a bug."), + done); + FunctionLibraryDefinition* overlay_lib = + new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition()); + overlay_libs_.emplace(lib, overlay_lib); + auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>(); for (const auto& pair : subgraphs) { // TODO(akshayka): Fail gracefully if the set of devices corresponds @@ -125,13 +133,13 @@ class PartitionedCallOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC( ctx, UpdateArgAndRetMetadata(target, subgraph.get()), done); FunctionDef shard; - string unique_name = UniquifyFunctionName(func_.name()); + string unique_name = UniquifyFunctionName(overlay_lib, func_.name()); OP_REQUIRES_OK_ASYNC( ctx, GraphToFunctionDef(*subgraph, unique_name, &shard), done); - OP_REQUIRES_OK_ASYNC(ctx, overlay_lib_->AddFunctionDef(shard), done); + OP_REQUIRES_OK_ASYNC(ctx, overlay_lib->AddFunctionDef(shard), done); FunctionLibraryRuntime::InstantiateOptions opts; opts.target = target; - opts.overlay_lib = overlay_lib_.get(); + opts.overlay_lib = overlay_lib; FHandle handle; OP_REQUIRES_OK_ASYNC( ctx, @@ -399,10 +407,11 @@ class PartitionedCallOp : public AsyncOpKernel { } } - string UniquifyFunctionName(const string& name) { + string UniquifyFunctionName(const FunctionLibraryDefinition* function_library, + const string& name) { for (;; ++suffix_) { const string candidate = strings::StrCat(name, "_", suffix_); - if (overlay_lib_->Find(candidate) == nullptr) { + if (function_library->Find(candidate) == nullptr) { return candidate; } } @@ -410,14 +419,16 @@ class PartitionedCallOp : public AsyncOpKernel { NameAttrList func_; string local_device_name_; - // Function shards are added to `overlay_lib_`. - std::unique_ptr<FunctionLibraryDefinition> overlay_lib_; - // Contains maps from device names to handles of function shards, keyed by + // Contains maps from device names to handles of function partitions, keyed by // FunctionLibraryRuntime pointers. (Because this kernel may be instantiated // for a stateful op, different invocations of it may use different FLRs.) gtl::FlatMap<FunctionLibraryRuntime*, std::unique_ptr<gtl::FlatMap<string, FHandle>>> function_handles_ GUARDED_BY(mu_); + // Function partitions are added to overlay libraries. + gtl::FlatMap<FunctionLibraryRuntime*, + std::unique_ptr<FunctionLibraryDefinition>> + overlay_libs_ GUARDED_BY(mu_); // Map from device name to the indices of the arguments and return values // placed on that device. Read-only after the first invocation. gtl::FlatMap<string, ArgAndRetIndices> arg_and_ret_indices_; @@ -427,7 +438,7 @@ class PartitionedCallOp : public AsyncOpKernel { mutex mu_; - // Used to uniquify function names in `overlay_lib_`. + // Used to uniquify function names in `overlay_libs_`. 
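PartitionedCallOp turns the single overlay_lib_ into a map keyed by FunctionLibraryRuntime pointer because a stateful kernel instance can be driven by different FLRs, and each FLR needs its own overlay library holding that FLR's function partitions. A sketch of the look-up-or-create step this implies (hypothetical helper; the real Compute asserts the overlay does not yet exist because the function-handle cache is consulted first):

    FunctionLibraryDefinition* GetOrCreateOverlay(FunctionLibraryRuntime* lib) {
      auto iter = overlay_libs_.find(lib);
      if (iter != overlay_libs_.end()) return iter->second.get();
      FunctionLibraryDefinition* overlay_lib =
          new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition());
      overlay_libs_.emplace(lib, overlay_lib);  // owned by the unique_ptr map
      return overlay_lib;
    }

The suffix_ counter that follows supports the function-name uniquification described in the comment above, now performed per overlay library.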
uint32 suffix_ = 0; }; REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.h b/tensorflow/core/kernels/quantize_and_dequantize_op.h index 782263e4e9..6b0c5e5a46 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op.h +++ b/tensorflow/core/kernels/quantize_and_dequantize_op.h @@ -19,6 +19,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/cwise_ops.h" namespace tensorflow { namespace functor { @@ -89,17 +90,14 @@ struct QuantizeAndDequantizeOneScaleImpl { // min_range and max_range - because we may have changed either min_range // or max_range. out.device(d) = - ((input.cwiseMin(max_range).cwiseMax(min_range) - min_range) * scale + - T(0.5)) - .floor() * - inverse_scale + - min_range; + (input.cwiseMin(max_range).cwiseMax(min_range) * scale) + .unaryExpr(Eigen::internal::scalar_round_op_google<T>()) * + inverse_scale; } else { - // No need to clamp to min_range and max_range in this case as they were - // measured from the tensor. out.device(d) = - ((input - min_range) * scale + T(0.5)).floor() * inverse_scale + - min_range; + (input * scale) + .unaryExpr(Eigen::internal::scalar_round_op_google<T>()) * + inverse_scale; } } }; diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc index 629c698503..cddabf8a99 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc +++ b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc @@ -226,13 +226,13 @@ TEST_F(QuantizeAndDequantizeTest, Convert_2D_tensor_with_int8_range_given) { AddInputFromArray<float>(TensorShape({}), {1.0}); // Max // Note that the range is given as [-1, 1]. - // With int8, the tensor is quantized to {-102, -63, 0, 38, 102, 70, -128, + // With int8, the tensor is quantized to {-102, -64, 0, 38, 102, 70, -128, // 127}. // Scale is: 1/127 TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 4})); test::FillValues<float>( - &expected, {-102.0 / 127, -63.0 / 127, 0, 38.0 / 127, 102.0 / 127, + &expected, {-102.0 / 127, -64.0 / 127, 0, 38.0 / 127, 102.0 / 127, 70.0 / 127, -128.0 / 127, 1}); test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5); } @@ -257,13 +257,13 @@ TEST_F(QuantizeAndDequantizeTest, Convert_2D_tensor_with_int8_range_given_V3) { AddInputFromArray<int32>(TensorShape({}), {8}); // num_bits // Note that the range is given as [-1, 1]. - // With int8, the tensor is quantized to {-102, -63, 0, 38, 102, 70, -128, + // With int8, the tensor is quantized to {-102, -64, 0, 38, 102, 70, -128, // 127}. // Scale is: 1/127 TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 4})); test::FillValues<float>( - &expected, {-102.0 / 127, -63.0 / 127, 0, 38.0 / 127, 102.0 / 127, + &expected, {-102.0 / 127, -64.0 / 127, 0, 38.0 / 127, 102.0 / 127, 70.0 / 127, -128.0 / 127, 1}); test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5); } @@ -285,11 +285,11 @@ TEST_F(QuantizeAndDequantizeTest, Convert_4D_tensor_with_uint8_range_given) { AddInputFromArray<float>(TensorShape({}), {1.0}); // Max // Note that the range is given as [0, 1]. 
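The quantize-and-dequantize functor now clamps, scales, and rounds with scalar_round_op_google instead of the hand-rolled floor(x * scale + 0.5), and the test expectations change on exactly the two tie values (-63 becomes -64, 77 becomes 76), which is consistent with tf.round-style round-half-to-even. A standalone check of those two values (assuming ties-to-even, which std::nearbyintf also gives under the default rounding mode):

    #include <cmath>
    #include <cstdio>

    int main() {
      float a = -0.5f * 127.0f;  // exactly -63.5: floor(a + 0.5f) = -63, ties-to-even = -64
      float b = 0.3f * 255.0f;   // the float product is exactly 76.5: floor(b + 0.5f) = 77, ties-to-even = 76
      std::printf("%.0f %.0f\n", std::nearbyintf(a), std::nearbyintf(b));  // prints: -64 76
      return 0;
    }

The remaining uint8 test expectations below change for the same reason.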
- // With int8, the tensor is quantized to {0, 0, 77, 204} + // With int8, the tensor is quantized to {0, 0, 76, 204} // Scale is: 1/255 TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 1, 1})); - test::FillValues<float>(&expected, {0, 0, 77.0 / 255, 204.0 / 255}); + test::FillValues<float>(&expected, {0, 0, 76.0 / 255, 204.0 / 255}); test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5); } @@ -311,11 +311,11 @@ TEST_F(QuantizeAndDequantizeTest, Convert_4D_tensor_with_uint8_range_given_V3) { AddInputFromArray<int32>(TensorShape({}), {8}); // num_bits // Note that the range is given as [0, 1]. - // With int8, the tensor is quantized to {0, 0, 77, 204} + // With int8, the tensor is quantized to {0, 0, 76, 204} // Scale is: 1/255 TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 1, 1})); - test::FillValues<float>(&expected, {0, 0, 77.0 / 255, 204.0 / 255}); + test::FillValues<float>(&expected, {0, 0, 76.0 / 255, 204.0 / 255}); test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5); } diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index c5292e1ae1..cab9eb729d 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -213,64 +213,32 @@ class AssignVariableOp : public OpKernel { "Variable and value dtypes don't match; respectively, ", dtype_, " and ", context->input(1).dtype())); Var* variable = nullptr; - OP_REQUIRES_OK( - context, - LookupOrCreateResource<Var>( - context, HandleFromInput(context, 0), &variable, - [this, context](Var** ptr) { - *ptr = new Var(dtype_); - PersistentTensor unused; - Tensor* tmp; - AllocatorAttributes attr; - if (!relax_constraints_) { - attr.set_gpu_compatible(true); - attr.set_nic_compatible(true); - } - TF_RETURN_IF_ERROR(context->allocate_persistent( - dtype_, context->input(1).shape(), &unused, &tmp, attr)); - *(*ptr)->tensor() = *tmp; - return Status::OK(); - })); + const Tensor& value = context->input(1); + // Note: every resource-variable-manipulating op assumes copy-on-write + // semantics, and creates a copy of the variable's Tensor if its refcount is + // bigger than 1 when we try to modify it. This means we never need to copy + // the original tensor for AssignVariableOp; even if there are other live + // users of it we know none can modify it so this is always safe (even in + // esoteric cases where the same tensor is used to initialize multiple + // variables or the tensor is a constant this is safe, as future writes will + // trigger copies). + OP_REQUIRES_OK(context, LookupOrCreateResource<Var>( + context, HandleFromInput(context, 0), &variable, + [this, &value](Var** ptr) { + *ptr = new Var(dtype_); + *(*ptr)->tensor() = value; + (*ptr)->is_initialized = true; + return Status::OK(); + })); core::ScopedUnref s(variable); - OP_REQUIRES(context, variable->tensor()->dtype() == dtype_, errors::InvalidArgument( "Trying to assign variable with wrong dtype. Expected ", DataTypeString(variable->tensor()->dtype()), " got ", DataTypeString(dtype_))); - - const Tensor& value = context->input(1); - AllocatorAttributes attr; - if (!relax_constraints_) { - attr.set_gpu_compatible(true); - attr.set_nic_compatible(true); - } - - // Copying is unnecessary if we are the last user of the value - // tensor, we can just adopt the input tensor's buffer instead. 
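AssignVariableOp now simply adopts the incoming tensor and relies on the copy-on-write convention spelled out in the new comment above: any op that mutates a resource variable must first make sure the buffer is not shared, copying it when the refcount is greater than one. A minimal sketch of what that convention looks like on the writer side (hypothetical helper, not part of this change; RefCountIsOne, allocate_temp, and DenseUpdate as used elsewhere in these kernels):

    // Hypothetical: give `var` a private buffer before an in-place update.
    template <typename Device, typename T>
    Status EnsurePrivateCopy(OpKernelContext* ctx, Var* var) {
      if (!var->tensor()->RefCountIsOne()) {
        Tensor copy;
        TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(),
                                              var->tensor()->shape(), &copy));
        functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
        copy_functor(ctx->eigen_device<Device>(), copy.flat<T>(),
                     var->tensor()->flat<T>());
        *var->tensor() = copy;  // the writer now owns the only reference
      }
      return Status::OK();
    }

The removed lines that follow are the old in-kernel aliasing/copy path that this convention makes unnecessary for plain assignment.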
- std::unique_ptr<Tensor> input_alias = context->forward_input( - 1, OpKernelContext::Params::kNoReservation /*output_index*/, dtype_, - value.shape(), DEVICE_MEMORY, attr); mutex_lock ml(*variable->mu()); variable->is_initialized = true; - if (input_alias) { - *variable->tensor() = *input_alias; - return; - } - - // Need to copy, but maybe we can re-use variable's buffer? - if (!variable->tensor()->RefCountIsOne() || - !variable->tensor()->shape().IsSameSize(value.shape())) { - // Copy to new buffer - PersistentTensor unused; - Tensor* tmp; - OP_REQUIRES_OK(context, context->allocate_persistent( - dtype_, value.shape(), &unused, &tmp, attr)); - *variable->tensor() = *tmp; - } - functor::DenseUpdate<Device, T, ASSIGN> copy_functor; - copy_functor(context->eigen_device<Device>(), variable->tensor()->flat<T>(), - value.flat<T>()); + *variable->tensor() = value; } private: diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index 990bd2bff9..e335e38bdc 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -23,7 +23,9 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -95,7 +97,7 @@ void SaveTensors( return tensor_names_flat(a) < tensor_names_flat(b); }); - for (size_t i : sorted_name_idx) { + for (const size_t i : sorted_name_idx) { const string& name = tensor_names_flat(i); const Tensor& input = context->input(i + kFixedInputs); TensorShape shape(input.shape()); @@ -226,43 +228,53 @@ void RestoreTensor(OpKernelContext* context, #undef READER_COPY } -Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, - const Tensor& tensor_names, - const Tensor& shape_and_slices, - gtl::ArraySlice<DataType> dtypes) { - const string& prefix_string = prefix.scalar<string>()(); +namespace { - const auto& tensor_names_flat = tensor_names.flat<string>(); - const auto& shape_and_slices_flat = shape_and_slices.flat<string>(); +// Tensors larger than this threshold will be restored from a thread-pool. +const int64 kLargeShapeThreshold = 16 << 20; // 16M - // Sort lookup keys to improve locality when reading multiple tensors. - std::vector<size_t> sorted_name_idx(tensor_names_flat.size()); - std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0); - std::sort(sorted_name_idx.begin(), sorted_name_idx.end(), - [&tensor_names_flat](size_t a, size_t b) { - return tensor_names_flat(a) < tensor_names_flat(b); - }); +// A restore operation for a single tensor. Small tensors may be restored +// directly from the op thread to improve read locality. Large tensors can be +// restored from a thread pool: this requires creating a separate BundleReader +// for each restore. +struct RestoreOp { + RestoreOp& operator=(const RestoreOp&) = delete; - BundleReader reader(Env::Default(), prefix_string); - TF_RETURN_IF_ERROR(reader.status()); + bool should_run_in_pool(BundleReader* reader) const { + TensorShape restored_full_shape; - // TODO(zongheng): potential optimization: one Seek() in first lookup. 
- // TODO(zongheng): consider measuring speed and issuing concurrent lookups - // within a fixed memory budget. - TensorShape restored_full_shape; - Tensor* restored_tensor = nullptr; - for (auto i : sorted_name_idx) { - const string& tensor_name = tensor_names_flat(i); - const string& shape_and_slice = shape_and_slices_flat(i); + // Ignore status here; we'll catch the error later. + if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) { + return false; + } + + return restored_full_shape.num_elements() > kLargeShapeThreshold; + } + + // Run this restore operation using a new BundleReader. + void run_with_new_reader() { + BundleReader reader(Env::Default(), reader_prefix); + if (!reader.status().ok()) { + status = reader.status(); + return; + } + + status = run(&reader); + } + Status run(BundleReader* reader) { + TensorShape restored_full_shape; TF_RETURN_IF_ERROR( - reader.LookupTensorShape(tensor_name, &restored_full_shape)); + reader->LookupTensorShape(tensor_name, &restored_full_shape)); + VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : " + << restored_full_shape.num_elements(); + Tensor* restored_tensor; if (shape_and_slice.empty()) { // Lookup the full tensor. TF_RETURN_IF_ERROR( - context->allocate_output(i, restored_full_shape, &restored_tensor)); - TF_RETURN_IF_ERROR(reader.Lookup(tensor_name, restored_tensor)); + context->allocate_output(idx, restored_full_shape, &restored_tensor)); + TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor)); } else { // Lookup the slice. TensorShape parsed_full_shape; @@ -272,6 +284,7 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, TF_RETURN_IF_ERROR( checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape, &parsed_slice, &parsed_slice_shape)); + if (!restored_full_shape.IsSameSize(parsed_full_shape)) { return errors::InvalidArgument( "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ", @@ -279,19 +292,113 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, " does not match the shape stored in checkpoint: ", restored_full_shape.DebugString()); } - TF_RETURN_IF_ERROR( - context->allocate_output(i, parsed_slice_shape, &restored_tensor)); + context->allocate_output(idx, parsed_slice_shape, &restored_tensor)); TF_RETURN_IF_ERROR( - reader.LookupSlice(tensor_name, parsed_slice, restored_tensor)); + reader->LookupSlice(tensor_name, parsed_slice, restored_tensor)); + } + return Status::OK(); + } + + OpKernelContext* context; + size_t idx; + string tensor_name; + string shape_and_slice; + string reader_prefix; + + ::tensorflow::Status status; +}; + +} // namespace + +Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, + const Tensor& tensor_names, + const Tensor& shape_and_slices, + gtl::ArraySlice<DataType> dtypes) { + const string& prefix_string = prefix.scalar<string>()(); + + const auto& tensor_names_flat = tensor_names.flat<string>(); + const auto& shape_and_slices_flat = shape_and_slices.flat<string>(); + + // Sort lookup keys to improve locality when reading multiple tensors. 
+ std::vector<size_t> sorted_name_idx(tensor_names_flat.size()); + std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0); + std::sort(sorted_name_idx.begin(), sorted_name_idx.end(), + [&tensor_names_flat](size_t a, size_t b) { + return tensor_names_flat(a) < tensor_names_flat(b); + }); + + std::vector<std::unique_ptr<RestoreOp> > pool_restore_ops; + std::vector<std::unique_ptr<RestoreOp> > direct_restore_ops; + + BundleReader default_reader(Env::Default(), prefix_string); + TF_RETURN_IF_ERROR(default_reader.status()); + + std::vector<string> mismatched_errors; + for (const size_t i : sorted_name_idx) { + TensorShape restored_full_shape; + DataType original_dtype; + const string& tensor_name = tensor_names_flat(i); + TF_RETURN_IF_ERROR(default_reader.LookupDtypeAndShape( + tensor_name, &original_dtype, &restored_full_shape)); + if (dtypes[i] != original_dtype) { + string error_msg = strings::StrCat( + "tensor_name = ", tensor_name, "; expected dtype ", + DataTypeString(dtypes[i]), " does not equal original dtype ", + DataTypeString(original_dtype)); + mismatched_errors.emplace_back(error_msg); + } + } + if (!mismatched_errors.empty()) { + const string error_msg = str_util::Join(mismatched_errors, "\n"); + return errors::InvalidArgument(error_msg); + } + + for (auto i : sorted_name_idx) { + const string& tensor_name = tensor_names_flat(i); + const string& shape_and_slice = shape_and_slices_flat(i); + auto op = + new RestoreOp{context, i, tensor_name, shape_and_slice, prefix_string}; + if (op->should_run_in_pool(&default_reader)) { + pool_restore_ops.emplace_back(op); + } else { + direct_restore_ops.emplace_back(op); + } + } + + { + // Schedule any threaded operations first, skipping thread pool creation if + // we don't have any expensive operations. + std::unique_ptr<thread::ThreadPool> reader_pool; + if (!pool_restore_ops.empty()) { + reader_pool.reset( + new thread::ThreadPool(Env::Default(), "restore_tensors", 8)); + for (auto& op : pool_restore_ops) { + reader_pool->Schedule([&op]() { op->run_with_new_reader(); }); + } } - if (dtypes[i] != restored_tensor->dtype()) { + + // Read small tensors from the op thread + for (auto& op : direct_restore_ops) { + TF_RETURN_IF_ERROR(op->run(&default_reader)); + } + } + + // Check status of pool ops; this must come after the pool shuts down. 
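RestoreTensorsV2 now validates every requested dtype against the checkpoint up front and then splits the restores: tensors with more than kLargeShapeThreshold (16M) elements are handed to a dedicated 8-thread pool, each worker opening its own BundleReader, while small tensors are read on the op thread for locality. Destroying the scoped thread pool joins all workers, so per-op status can only be inspected afterwards. A stripped-down sketch of that schedule-join-check shape (RestoreOp internals and the direct restores elided):

    {
      std::unique_ptr<thread::ThreadPool> reader_pool;
      if (!pool_restore_ops.empty()) {
        reader_pool.reset(
            new thread::ThreadPool(Env::Default(), "restore_tensors", 8));
        for (auto& op : pool_restore_ops) {
          // Each scheduled op opens its own BundleReader in run_with_new_reader().
          reader_pool->Schedule([&op]() { op->run_with_new_reader(); });
        }
      }
      // ... small tensors are restored here, on the op thread ...
    }  // ~ThreadPool() joins every scheduled restore before status is read.

The loop that follows can therefore safely propagate each op->status.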
+ for (auto& op : pool_restore_ops) { + TF_RETURN_IF_ERROR(op->status); + } + + for (auto i : sorted_name_idx) { + const string& tensor_name = tensor_names_flat(i); + if (dtypes[i] != context->mutable_output(i)->dtype()) { return errors::InvalidArgument( "tensor_name = ", tensor_name, "; expected dtype ", DataTypeString(dtypes[i]), " does not equal restored dtype ", - DataTypeString(restored_tensor->dtype())); + DataTypeString(context->mutable_output(i)->dtype())); } } + return Status::OK(); } diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc index e72608945b..93a753787a 100644 --- a/tensorflow/core/kernels/softmax_op.cc +++ b/tensorflow/core/kernels/softmax_op.cc @@ -61,15 +61,16 @@ class SoftmaxOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& logits_in = context->input(0); - OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()), - errors::InvalidArgument("logits must be 2-dimensional")); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(logits_in.shape()), + errors::InvalidArgument("logits must have >= 1 dimension, got ", + logits_in.shape().DebugString())); Tensor* softmax_out = nullptr; OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( {0}, 0, logits_in.shape(), &softmax_out)); if (logits_in.NumElements() > 0) { functor::SoftmaxFunctor<Device, T> functor; - functor(context->eigen_device<Device>(), logits_in.matrix<T>(), - softmax_out->matrix<T>(), log_); + functor(context->eigen_device<Device>(), logits_in.flat_inner_dims<T>(), + softmax_out->flat_inner_dims<T>(), log_); } } diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc index b63dcbb163..d1e677feb0 100644 --- a/tensorflow/core/kernels/softmax_op_gpu.cu.cc +++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc @@ -134,11 +134,12 @@ class SoftmaxOpGPU : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& logits_in_ = context->input(0); - auto logits_in = logits_in_.matrix<T>(); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(logits_in_.shape()), + errors::InvalidArgument("logits must have >= 1 dimension, got ", + logits_in_.shape().DebugString())); + auto logits_in = logits_in_.flat_inner_dims<T>(); const int rows = logits_in.dimension(0); const int cols = logits_in.dimension(1); - OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in_.shape()), - errors::InvalidArgument("logits must be 2-dimensional")); Tensor* softmax_out = nullptr; OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( {0}, 0, logits_in_.shape(), &softmax_out)); diff --git a/tensorflow/core/kernels/spacetobatch_op.cc b/tensorflow/core/kernels/spacetobatch_op.cc index fdc08ec8e3..64f1b0d661 100644 --- a/tensorflow/core/kernels/spacetobatch_op.cc +++ b/tensorflow/core/kernels/spacetobatch_op.cc @@ -42,29 +42,29 @@ typedef Eigen::GpuDevice GPUDevice; namespace { template <typename Device, typename T> -void SpaceToBatchOpCompute(OpKernelContext* context, - const Tensor& orig_input_tensor, - const Tensor& orig_block_shape, - const Tensor& orig_paddings) { +Status SpaceToBatchOpCompute(OpKernelContext* context, + const Tensor& orig_input_tensor, + const Tensor& orig_block_shape, + const Tensor& orig_paddings) { const int input_dims = orig_input_tensor.dims(); - OP_REQUIRES( - context, TensorShapeUtils::IsVector(orig_block_shape.shape()), - errors::InvalidArgument("block_shape rank should be 1 instead of ", - orig_block_shape.dims())); + if 
diff --git a/tensorflow/core/kernels/spacetobatch_op.cc b/tensorflow/core/kernels/spacetobatch_op.cc
index fdc08ec8e3..64f1b0d661 100644
--- a/tensorflow/core/kernels/spacetobatch_op.cc
+++ b/tensorflow/core/kernels/spacetobatch_op.cc
@@ -42,29 +42,29 @@ typedef Eigen::GpuDevice GPUDevice;
 namespace {
 
 template <typename Device, typename T>
-void SpaceToBatchOpCompute(OpKernelContext* context,
-                           const Tensor& orig_input_tensor,
-                           const Tensor& orig_block_shape,
-                           const Tensor& orig_paddings) {
+Status SpaceToBatchOpCompute(OpKernelContext* context,
+                             const Tensor& orig_input_tensor,
+                             const Tensor& orig_block_shape,
+                             const Tensor& orig_paddings) {
   const int input_dims = orig_input_tensor.dims();
-  OP_REQUIRES(
-      context, TensorShapeUtils::IsVector(orig_block_shape.shape()),
-      errors::InvalidArgument("block_shape rank should be 1 instead of ",
-                              orig_block_shape.dims()));
+  if (!TensorShapeUtils::IsVector(orig_block_shape.shape())) {
+    return errors::InvalidArgument("block_shape rank should be 1 instead of ",
+                                   orig_block_shape.dims());
+  }
 
   const int block_dims = orig_block_shape.dim_size(0);
-  OP_REQUIRES(
-      context, orig_input_tensor.dims() >= 1 + block_dims,
-      errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
-                              " instead of ", orig_input_tensor.dims()));
-
-  OP_REQUIRES(context,
-              TensorShapeUtils::IsMatrix(orig_paddings.shape()) &&
-                  block_dims == orig_paddings.dim_size(0) &&
-                  2 == orig_paddings.dim_size(1),
-              errors::InvalidArgument("paddings should have shape [",
-                                      block_dims, ", 2] instead of ",
-                                      orig_paddings.shape().DebugString()));
+  if (orig_input_tensor.dims() < 1 + block_dims) {
+    return errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
+                                   " instead of ", orig_input_tensor.dims());
+  }
+
+  if (!(TensorShapeUtils::IsMatrix(orig_paddings.shape()) &&
+        block_dims == orig_paddings.dim_size(0) &&
+        2 == orig_paddings.dim_size(1))) {
+    return errors::InvalidArgument("paddings should have shape [", block_dims,
+                                   ", 2] instead of ",
+                                   orig_paddings.shape().DebugString());
+  }
 
   // To avoid out-of-bounds access in the case that the block_shape and/or
   // paddings tensors are concurrently modified, we must copy the values.
@@ -101,22 +101,23 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
   for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
     block_shape_product *= block_shape[block_dim];
   }
-  OP_REQUIRES(
-      context, block_shape_product > 0,
-      errors::InvalidArgument("Product of block sizes must be positive, got ",
-                              block_shape_product));
+  if (block_shape_product <= 0) {
+    return errors::InvalidArgument(
+        "Product of block sizes must be positive, got ", block_shape_product);
+  }
 
   const int internal_block_dims =
       block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
-  OP_REQUIRES(context, internal_block_dims <= kMaxSpaceToBatchBlockDims,
-              errors::InvalidArgument(
-                  "Maximum number of non-combined block dimensions is ",
-                  internal_block_dims, " but must not exceed ",
-                  kMaxSpaceToBatchBlockDims));
+  if (internal_block_dims > kMaxSpaceToBatchBlockDims) {
+    return errors::InvalidArgument(
+        "Maximum number of non-combined block dimensions is ",
+        internal_block_dims, " but must not exceed ",
+        kMaxSpaceToBatchBlockDims);
+  }
 
   if (internal_block_dims == 0) {
     context->set_output(0, orig_input_tensor);
-    return;
+    return Status::OK();
   }
 
   // For the purpose of computing the result, the input will be treated as
@@ -146,16 +147,18 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
        block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
     const int64 pad_start = paddings[2 * block_dim],
                 pad_end = paddings[2 * block_dim + 1];
-    OP_REQUIRES(context, pad_start >= 0 && pad_end >= 0,
-                errors::InvalidArgument("Paddings must be non-negative"));
+    if (pad_start < 0 || pad_end < 0) {
+      return errors::InvalidArgument("Paddings must be non-negative");
+    }
     const int64 input_size = orig_input_tensor.dim_size(block_dim + 1);
     const int64 block_shape_value = block_shape[block_dim];
     const int64 padded_size = input_size + pad_start + pad_end;
-    OP_REQUIRES(
-        context, padded_size % block_shape_value == 0,
-        errors::InvalidArgument("padded_shape[", block_dim, "]=", padded_size,
-                                " is not divisible by block_shape[", block_dim,
-                                "]=", block_shape_value));
+    if (padded_size % block_shape_value != 0) {
+      return errors::InvalidArgument("padded_shape[", block_dim,
+                                     "]=", padded_size,
+                                     " is not divisible by block_shape[",
+                                     block_dim, "]=", block_shape_value);
+    }
     internal_input_shape.AddDim(input_size);
     const int64 output_size = padded_size / block_shape_value;
     internal_output_shape.AddDim(output_size);
@@ -174,29 +177,29 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
 
   // Allocate output tensor.
   Tensor* output_tensor = nullptr;
-  OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape,
-                                                   &output_tensor));
+  TF_RETURN_IF_ERROR(
+      context->allocate_output(0, external_output_shape, &output_tensor));
 
   const int64* internal_paddings = &paddings[2 * removed_prefix_block_dims];
   const int64* internal_block_shape = &block_shape[removed_prefix_block_dims];
 
   switch (internal_block_dims) {
-#define TF_SPACETOBATCH_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS)                      \
-  case NUM_BLOCK_DIMS: {                                                     \
-    OP_REQUIRES_OK(                                                          \
-        context,                                                             \
-        (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, false>()(   \
-            context->eigen_device<Device>(),                                 \
-            orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>(                 \
-                internal_input_shape.dim_sizes()),                           \
-            internal_block_shape, internal_paddings,                         \
-            output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>(                    \
-                internal_output_shape.dim_sizes()))));                       \
-  } break;                                                                   \
+#define TF_SPACETOBATCH_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS)                    \
+  case NUM_BLOCK_DIMS: {                                                   \
+    TF_RETURN_IF_ERROR(                                                    \
+        functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, false>()(  \
+            context->eigen_device<Device>(),                               \
+            orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>(               \
+                internal_input_shape.dim_sizes()),                         \
+            internal_block_shape, internal_paddings,                       \
+            output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>(                  \
+                internal_output_shape.dim_sizes())));                      \
+  } break;                                                                 \
     /**/
     TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_SPACETOBATCH_BLOCK_DIMS_CASE)
 #undef TF_SPACETOBATCH_BLOCK_DIMS_CASE
   }
+  return Status::OK();
 }
 
 }  // namespace
@@ -211,8 +214,9 @@ class SpaceToBatchNDOp : public OpKernel {
     const Tensor& orig_input_tensor = context->input(0);
     const Tensor& orig_block_shape = context->input(1);
     const Tensor& orig_paddings = context->input(2);
-    SpaceToBatchOpCompute<Device, T>(context, orig_input_tensor,
-                                     orig_block_shape, orig_paddings);
+    OP_REQUIRES_OK(context, SpaceToBatchOpCompute<Device, T>(
+                                context, orig_input_tensor, orig_block_shape,
+                                orig_paddings));
   }
 };
 
@@ -241,7 +245,8 @@ class SpaceToBatchOp : public OpKernel {
     OP_REQUIRES(context, kRequiredDims == dims,
                 errors::InvalidArgument("Input rank should be: ", kRequiredDims,
                                         "instead of: ", dims));
-    SpaceToBatchOpCompute<Device, T>(context, in0, block_shape_, in1);
+    OP_REQUIRES_OK(context, SpaceToBatchOpCompute<Device, T>(
+                                context, in0, block_shape_, in1));
   }
 
  private:
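The spacetobatch_op.cc hunks convert OP_REQUIRES checks, which can only bail out of a void Compute(), into early returns from a Status-returning helper that both SpaceToBatchNDOp and SpaceToBatchOp then wrap with OP_REQUIRES_OK. Below is a minimal standalone sketch of that refactoring pattern using simplified stand-ins for Status and TF_RETURN_IF_ERROR; the names and types are hypothetical, not the real TensorFlow ones.

    // Illustration only: a shared helper reports errors by value, and each
    // caller decides how to surface them (here via a RETURN_IF_ERROR macro).
    #include <iostream>
    #include <string>

    struct Status {
      bool ok = true;
      std::string msg;
      static Status OK() { return {}; }
      static Status InvalidArgument(const std::string& m) { return {false, m}; }
    };

    #define RETURN_IF_ERROR(expr) \
      do {                        \
        Status _s = (expr);       \
        if (!_s.ok) return _s;    \
      } while (0)

    // Shared validation/compute helper: returns a Status instead of touching
    // an op kernel context directly, so several kernels can reuse it.
    Status SpaceToBatchLikeCompute(int block_dims, int input_rank) {
      if (block_dims < 1)
        return Status::InvalidArgument("block_shape rank should be 1");
      if (input_rank < 1 + block_dims)
        return Status::InvalidArgument("input rank too small");
      // ... the actual data reshuffling would go here ...
      return Status::OK();
    }

    // A caller in the style of OP_REQUIRES_OK: propagate or report the error.
    Status ComputeWrapper() {
      RETURN_IF_ERROR(SpaceToBatchLikeCompute(/*block_dims=*/2, /*input_rank=*/4));
      return Status::OK();
    }

    int main() {
      Status s = ComputeWrapper();
      std::cout << (s.ok ? "ok" : s.msg) << "\n";
      return s.ok ? 0 : 1;
    }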
block_shape[", + block_dim, "]=", block_shape_value); + } internal_input_shape.AddDim(input_size); const int64 output_size = padded_size / block_shape_value; internal_output_shape.AddDim(output_size); @@ -174,29 +177,29 @@ void SpaceToBatchOpCompute(OpKernelContext* context, // Allocate output tensor. Tensor* output_tensor = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape, - &output_tensor)); + TF_RETURN_IF_ERROR( + context->allocate_output(0, external_output_shape, &output_tensor)); const int64* internal_paddings = &paddings[2 * removed_prefix_block_dims]; const int64* internal_block_shape = &block_shape[removed_prefix_block_dims]; switch (internal_block_dims) { -#define TF_SPACETOBATCH_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS) \ - case NUM_BLOCK_DIMS: { \ - OP_REQUIRES_OK( \ - context, \ - (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, false>()( \ - context->eigen_device<Device>(), \ - orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>( \ - internal_input_shape.dim_sizes()), \ - internal_block_shape, internal_paddings, \ - output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>( \ - internal_output_shape.dim_sizes())))); \ - } break; \ +#define TF_SPACETOBATCH_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS) \ + case NUM_BLOCK_DIMS: { \ + TF_RETURN_IF_ERROR( \ + functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, false>()( \ + context->eigen_device<Device>(), \ + orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>( \ + internal_input_shape.dim_sizes()), \ + internal_block_shape, internal_paddings, \ + output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>( \ + internal_output_shape.dim_sizes()))); \ + } break; \ /**/ TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_SPACETOBATCH_BLOCK_DIMS_CASE) #undef TF_SPACETOBATCH_BLOCK_DIMS_CASE } + return Status::OK(); } } // namespace @@ -211,8 +214,9 @@ class SpaceToBatchNDOp : public OpKernel { const Tensor& orig_input_tensor = context->input(0); const Tensor& orig_block_shape = context->input(1); const Tensor& orig_paddings = context->input(2); - SpaceToBatchOpCompute<Device, T>(context, orig_input_tensor, - orig_block_shape, orig_paddings); + OP_REQUIRES_OK(context, SpaceToBatchOpCompute<Device, T>( + context, orig_input_tensor, orig_block_shape, + orig_paddings)); } }; @@ -241,7 +245,8 @@ class SpaceToBatchOp : public OpKernel { OP_REQUIRES(context, kRequiredDims == dims, errors::InvalidArgument("Input rank should be: ", kRequiredDims, "instead of: ", dims)); - SpaceToBatchOpCompute<Device, T>(context, in0, block_shape_, in1); + OP_REQUIRES_OK(context, SpaceToBatchOpCompute<Device, T>( + context, in0, block_shape_, in1)); } private: diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 1e3e92a68a..59fdc2262a 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/training_op_helpers.h" #include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -304,6 +305,9 @@ class StridedSliceAssignOp : public OpKernel { Var* v; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &v)); + mutex_lock ml(*v->mu()); + OP_REQUIRES_OK(context, + PrepareToUpdateVariable<Device, T>(context, v->tensor())); old_lhs = *v->tensor(); OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value, errors::InvalidArgument( diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc index f288e124ee..d3c4f62071 100644 --- a/tensorflow/core/kernels/training_op_helpers.cc +++ b/tensorflow/core/kernels/training_op_helpers.cc @@ -39,8 +39,15 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) { // GetInputTensor which will signal a failure. std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) { + bool any_resource = false; + for (auto i : input_ids) { + if (ctx->input_dtype(i) == DT_RESOURCE) { + any_resource = true; + break; + } + } std::vector<mutex_lock> locks; - if (!do_lock) { + if (!do_lock && !any_resource) { return locks; } std::vector<mutex*> mutexes; diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h index 7e56e15450..765335d3a0 100644 --- a/tensorflow/core/kernels/training_op_helpers.h +++ b/tensorflow/core/kernels/training_op_helpers.h @@ -80,18 +80,8 @@ Status GetInputTensorFromVariable(OpKernelContext* ctx, int input, Var* var; TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var)); core::ScopedUnref unref_var(var); - if (lock_held) { - TF_RETURN_IF_ERROR( - PrepareToUpdateVariable<Device, T>(ctx, var->tensor())); - *out = *var->tensor(); - } else { - mutex_lock ml(*var->mu()); - if (!sparse) { - TF_RETURN_IF_ERROR( - PrepareToUpdateVariable<Device, T>(ctx, var->tensor())); - } - *out = *var->tensor(); - } + TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(ctx, var->tensor())); + *out = *var->tensor(); return Status::OK(); } *out = ctx->mutable_input(input, lock_held); |