diff options
author | 2018-06-08 16:46:20 -0700 | |
---|---|---|
committer | 2018-06-08 16:49:10 -0700 | |
commit | f5a1a38a831e9db5a822351f3a3b138ab1cb83b3 (patch) | |
tree | f6352dff2100c563c9e684251beceb39d54c92f0 | |
parent | 00a4d11ac6d60f486b32c317ffddeae9a056cf38 (diff) |
Created a ThreadPoolDevice wrapper to make each op run with the number of threads stored in NodeDef.
PiperOrigin-RevId: 199870879
-rw-r--r-- | tensorflow/core/framework/device_base.h | 4 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 8 |
3 files changed, 25 insertions, 3 deletions
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index ec26d92a61..b59ced869d 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -186,6 +186,10 @@ class DeviceBase { virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; } + const bool has_eigen_cpu_device() const { + return (eigen_cpu_device_ != nullptr); + } + virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() { CHECK(eigen_cpu_device_ != nullptr); return eigen_cpu_device_; diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index ce213a63be..a0f449d64f 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS #include "tensorflow/core/framework/op_kernel.h" #include <unordered_map> #include <utility> #include <vector> +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/graph.pb_text.h" @@ -40,6 +42,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -270,6 +273,19 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs) if (params_->record_tensor_accesses) { referenced_tensors_.Init(); } + if (params->device->has_eigen_cpu_device()) { + int64 block_size = -1, output_size = -1, num_threads = 1; + const Eigen::ThreadPoolDevice* thread_pool = + params_->device->eigen_cpu_device(); + AttrSlice attributes(op_kernel().def()); + if (GetNodeAttr(attributes, "_block_size", &block_size) == Status::OK() && + GetNodeAttr(attributes, "_output_size", &output_size) == Status::OK()) { + num_threads = std::min(Eigen::divup(output_size, block_size), + static_cast<int64>(thread_pool->numThreads())); + eigen_cpu_device_ = MakeUnique<Eigen::ThreadPoolDevice>( + thread_pool->getPool(), num_threads); + } + } } OpKernelContext::~OpKernelContext() { diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 5ebe6976fd..d307078e63 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ -#define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ +#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ #include <functional> @@ -1004,6 +1004,7 @@ class OpKernelContext { // OpKernels can use these eigen devices to carry out their // numerical computation. const Eigen::ThreadPoolDevice& eigen_cpu_device() const { + if (eigen_cpu_device_ != nullptr) return *eigen_cpu_device_; return *device()->eigen_cpu_device(); } const Eigen::GpuDevice& eigen_gpu_device() const { @@ -1139,6 +1140,7 @@ class OpKernelContext { mutable mutex mu_; // mutable so const accessors can acquire the lock gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_); gtl::InlinedVector<TensorValue, 4> outputs_; + std::unique_ptr<Eigen::ThreadPoolDevice> eigen_cpu_device_; // Constructed only if <params->record_tensor_accesses>. ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_); @@ -1576,4 +1578,4 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { } // namespace tensorflow -#endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ +#endif // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ |