From f5a1a38a831e9db5a822351f3a3b138ab1cb83b3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 8 Jun 2018 16:46:20 -0700 Subject: Created a ThreadPoolDevice wrapper to make each op run with the number of threads stored in NodeDef. PiperOrigin-RevId: 199870879 --- tensorflow/core/framework/device_base.h | 4 ++++ tensorflow/core/framework/op_kernel.cc | 16 ++++++++++++++++ tensorflow/core/framework/op_kernel.h | 8 +++++--- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index ec26d92a61..b59ced869d 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -186,6 +186,10 @@ class DeviceBase { virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; } + const bool has_eigen_cpu_device() const { + return (eigen_cpu_device_ != nullptr); + } + virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() { CHECK(eigen_cpu_device_ != nullptr); return eigen_cpu_device_; diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index ce213a63be..a0f449d64f 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS #include "tensorflow/core/framework/op_kernel.h" #include #include #include +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/graph.pb_text.h" @@ -40,6 +42,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -270,6 +273,19 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs) if (params_->record_tensor_accesses) { referenced_tensors_.Init(); } + if (params->device->has_eigen_cpu_device()) { + int64 block_size = -1, output_size = -1, num_threads = 1; + const Eigen::ThreadPoolDevice* thread_pool = + params_->device->eigen_cpu_device(); + AttrSlice attributes(op_kernel().def()); + if (GetNodeAttr(attributes, "_block_size", &block_size) == Status::OK() && + GetNodeAttr(attributes, "_output_size", &output_size) == Status::OK()) { + num_threads = std::min(Eigen::divup(output_size, block_size), + static_cast<int64>(thread_pool->numThreads())); + eigen_cpu_device_ = MakeUnique<Eigen::ThreadPoolDevice>( + thread_pool->getPool(), num_threads); + } + } } OpKernelContext::~OpKernelContext() { diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 5ebe6976fd..d307078e63 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ -#define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ +#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ #include @@ -1004,6 +1004,7 @@ class OpKernelContext { // OpKernels can use these eigen devices to carry out their // numerical computation. 
const Eigen::ThreadPoolDevice& eigen_cpu_device() const { + if (eigen_cpu_device_ != nullptr) return *eigen_cpu_device_; return *device()->eigen_cpu_device(); } const Eigen::GpuDevice& eigen_gpu_device() const { @@ -1139,6 +1140,7 @@ class OpKernelContext { mutable mutex mu_; // mutable so const accessors can acquire the lock gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_); gtl::InlinedVector<TensorValue, 4> outputs_; + std::unique_ptr<Eigen::ThreadPoolDevice> eigen_cpu_device_; // Constructed only if <params->record_tensor_accesses>. ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_); @@ -1576,4 +1578,4 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { } // namespace tensorflow -#endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ +#endif // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ -- cgit v1.2.3