aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-08 16:46:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-08 16:49:10 -0700
commitf5a1a38a831e9db5a822351f3a3b138ab1cb83b3 (patch)
treef6352dff2100c563c9e684251beceb39d54c92f0
parent00a4d11ac6d60f486b32c317ffddeae9a056cf38 (diff)
Created a ThreadPoolDevice wrapper to make each op run with the number of threads stored in NodeDef.
PiperOrigin-RevId: 199870879
-rw-r--r--tensorflow/core/framework/device_base.h4
-rw-r--r--tensorflow/core/framework/op_kernel.cc16
-rw-r--r--tensorflow/core/framework/op_kernel.h8
3 files changed, 25 insertions, 3 deletions
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index ec26d92a61..b59ced869d 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -186,6 +186,10 @@ class DeviceBase {
virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; }
+ const bool has_eigen_cpu_device() const {
+ return (eigen_cpu_device_ != nullptr);
+ }
+
virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() {
CHECK(eigen_cpu_device_ != nullptr);
return eigen_cpu_device_;
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index ce213a63be..a0f449d64f 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/op_kernel.h"
#include <unordered_map>
#include <utility>
#include <vector>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
@@ -40,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -270,6 +273,19 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs)
if (params_->record_tensor_accesses) {
referenced_tensors_.Init();
}
+ if (params->device->has_eigen_cpu_device()) {
+ int64 block_size = -1, output_size = -1, num_threads = 1;
+ const Eigen::ThreadPoolDevice* thread_pool =
+ params_->device->eigen_cpu_device();
+ AttrSlice attributes(op_kernel().def());
+ if (GetNodeAttr(attributes, "_block_size", &block_size) == Status::OK() &&
+ GetNodeAttr(attributes, "_output_size", &output_size) == Status::OK()) {
+ num_threads = std::min(Eigen::divup(output_size, block_size),
+ static_cast<int64>(thread_pool->numThreads()));
+ eigen_cpu_device_ = MakeUnique<Eigen::ThreadPoolDevice>(
+ thread_pool->getPool(), num_threads);
+ }
+ }
}
OpKernelContext::~OpKernelContext() {
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 5ebe6976fd..d307078e63 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
-#define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
#include <functional>
@@ -1004,6 +1004,7 @@ class OpKernelContext {
// OpKernels can use these eigen devices to carry out their
// numerical computation.
const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
+ if (eigen_cpu_device_ != nullptr) return *eigen_cpu_device_;
return *device()->eigen_cpu_device();
}
const Eigen::GpuDevice& eigen_gpu_device() const {
@@ -1139,6 +1140,7 @@ class OpKernelContext {
mutable mutex mu_; // mutable so const accessors can acquire the lock
gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
gtl::InlinedVector<TensorValue, 4> outputs_;
+ std::unique_ptr<Eigen::ThreadPoolDevice> eigen_cpu_device_;
// Constructed only if <params->record_tensor_accesses>.
ManualConstructor<UniqueTensorReferences> referenced_tensors_ GUARDED_BY(mu_);
@@ -1576,4 +1578,4 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_