diff options
-rw-r--r-- | tensorflow/core/common_runtime/executor.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/gpu/gpu_device.cc | 64 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/gpu/gpu_device.h | 11 | ||||
-rw-r--r-- | tensorflow/core/framework/device_base.h | 22 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.cc | 55 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 150 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel_test.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/core_ops_test.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/kernels/ops_testutil.h | 22 | ||||
-rw-r--r-- | tensorflow/core/kernels/restore_op_test.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/segment_reduction_ops_test.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/sparse_to_dense_op_test.cc | 2 |
12 files changed, 215 insertions, 146 deletions
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index ef9de781c0..c6733b03bc 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -823,6 +823,9 @@ namespace { OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) { OpKernelContext::Params* ret = new OpKernelContext::Params; *ret = p; + // Ensure the copy of Params will make a new eigen GPU device if + // necessary. + ret->eigen_gpu_device = nullptr; ret->inputs = new TensorValueVec(*p.inputs); ret->input_device_contexts = new DeviceContextVec(*p.input_device_contexts); ret->input_alloc_attrs = new AllocatorAttributeVec(*p.input_alloc_attrs); @@ -831,6 +834,8 @@ OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) { // Helpers to delete 'p' and copies made by CopyParams. void DeleteParams(OpKernelContext::Params* p) { + // No need to delete p->eigen_gpu_device since that is deleted in + // p's destructor delete p->inputs; delete p->input_device_contexts; delete p->input_alloc_attrs; @@ -929,7 +934,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { if (async) { // Asynchronous computes. auto pcopy = CopyParams(params); - auto ctx = new OpKernelContext(*pcopy); + auto ctx = new OpKernelContext(pcopy); auto done = [this, tagged_node, item, first_input, ctx, stats, pcopy, device]() { VLOG(2) << this << " Async kernel done: " @@ -967,7 +972,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) { device->ComputeAsync(async, ctx, done); } else { // Synchronous computes. - OpKernelContext ctx(params); + OpKernelContext ctx(¶ms); if (stats_collector_) nodestats::SetOpStart(stats); device->Compute(CHECK_NOTNULL(op_kernel), &ctx); if (stats_collector_) nodestats::SetOpEnd(stats); diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 44874c5189..1efe7b7251 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -73,9 +73,14 @@ namespace tensorflow { #if defined(__GCUDACC__) || defined(__GCUDACC_HOST__) class EigenAllocator : public ::Eigen::Allocator { public: - explicit EigenAllocator(gpu::Stream* stream, ::tensorflow::Allocator* alloc, - EventMgr* em) - : stream_(stream), allocator_(alloc), em_(em) {} + EigenAllocator() {} + + void Reinitialize(gpu::Stream* stream, ::tensorflow::Allocator* alloc, + EventMgr* em) { + stream_ = stream; + allocator_ = alloc; + em_ = em; + } void* allocate(size_t num_bytes) const override { void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes); @@ -103,10 +108,12 @@ class EigenAllocator : public ::Eigen::Allocator { #else class EigenCudaStreamDevice : public ::Eigen::StreamInterface { public: - EigenCudaStreamDevice(const cudaStream_t* cuda_stream, int gpu_id, - ::tensorflow::Allocator* alloc) - : stream_(cuda_stream), allocator_(alloc) { - Eigen::initializeDeviceProp(); + EigenCudaStreamDevice() { Eigen::initializeDeviceProp(); } + + void Reinitialize(const cudaStream_t* cuda_stream, int gpu_id, + ::tensorflow::Allocator* alloc) { + stream_ = cuda_stream; + allocator_ = alloc; device_prop_ = &Eigen::m_deviceProperties[gpu_id]; } @@ -391,10 +398,11 @@ namespace { #if defined(__GCUDACC__) || defined(__GCUDACC_HOST__) class ConcretePerOpGpuDevice : public PerOpGpuDevice { public: - explicit ConcretePerOpGpuDevice(gpu::Stream* stream, - Allocator* base_allocator, - ::tensorflow::EventMgr* em) - : allocator_(stream, base_allocator, em), device_(stream, &allocator_) {} + void Reinitialize(gpu::Stream* stream, Allocator* base_allocator, + ::tensorflow::EventMgr* em) { + allocator_.Reinitialize(stream, base_allocator, em); + device_.Reinitialize(stream, &allocator_); + } const Eigen::GpuDevice& device() const override { return device_; } @@ -405,10 +413,12 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice { #else class ConcretePerOpGpuDevice : public PerOpGpuDevice { public: - explicit ConcretePerOpGpuDevice(const cudaStream_t* cuda_stream, int gpu_id, - Allocator* base_allocator) - : stream_device_(cuda_stream, gpu_id, base_allocator), - device_(&stream_device_) {} + ConcretePerOpGpuDevice() : device_(&stream_device_) {} + + void Reinitialize(const cudaStream_t* cuda_stream, int gpu_id, + Allocator* base_allocator) { + stream_device_.Reinitialize(cuda_stream, gpu_id, base_allocator); + } const Eigen::GpuDevice& device() const override { return device_; } @@ -419,28 +429,36 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice { #endif } // namespace -const PerOpGpuDevice* BaseGPUDevice::NewDevice(int stream_id, - Allocator* allocator) { +void BaseGPUDevice::ReinitializeDevice(PerOpGpuDevice* device, int stream_id, + Allocator* allocator) { + ConcretePerOpGpuDevice* concrete_device = + dynamic_cast<ConcretePerOpGpuDevice*>(device); + DCHECK(concrete_device); #if defined(__GCUDACC__) || defined(__GCUDACC_HOST__) - return new ConcretePerOpGpuDevice(streams_[stream_id], allocator, em_.get()); + concrete_device->Reinitialize(streams_[stream_id], allocator, em_.get()); #else const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>( streams_[stream_id]->implementation()->CudaStreamMemberHack()); - return new ConcretePerOpGpuDevice(cuda_stream, gpu_id_, allocator); + concrete_device->Reinitialize(cuda_stream, gpu_id_, allocator); #endif } -const PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice(DeviceContext* dc, - Allocator* allocator) { +PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice() { + return new ConcretePerOpGpuDevice(); +} + +void BaseGPUDevice::ReinitializeGpuDevice(PerOpGpuDevice* device, + DeviceContext* dc, + Allocator* allocator) { if (dc) { const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc); const int stream_id = gpu_dc->stream_id(); VLOG(1) << " eigen_gpu_device(" << dc << ") => stream[" << stream_id << "]"; CHECK_LT(stream_id, streams_.size()); - return NewDevice(stream_id, allocator); + ReinitializeDevice(device, stream_id, allocator); } else { - return NewDevice(0, allocator); + ReinitializeDevice(device, 0, allocator); } } diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h index 374bd73517..02a2279174 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.h +++ b/tensorflow/core/common_runtime/gpu/gpu_device.h @@ -38,8 +38,6 @@ limitations under the License. namespace tensorflow { -class EigenAllocator; - class BaseGPUDevice : public LocalDevice { public: BaseGPUDevice(const SessionOptions& options, const string& name, @@ -74,8 +72,10 @@ class BaseGPUDevice : public LocalDevice { Tensor* tensor) override; // The caller owns the returned device. - const PerOpGpuDevice* MakeGpuDevice(DeviceContext* dc, - Allocator* allocator) override; + PerOpGpuDevice* MakeGpuDevice() override; + + void ReinitializeGpuDevice(PerOpGpuDevice* device, DeviceContext* dc, + Allocator* allocator) override; protected: Allocator* gpu_allocator_; // not owned @@ -90,7 +90,8 @@ class BaseGPUDevice : public LocalDevice { const bool sync_every_op_ = false; std::unique_ptr<EventMgr> em_; - const PerOpGpuDevice* NewDevice(int stream_id, Allocator* allocator); + void ReinitializeDevice(PerOpGpuDevice* device, int stream_id, + Allocator* allocator); }; class BaseGPUDeviceFactory : public DeviceFactory { diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 23944a5f43..50b85e2c17 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -48,7 +48,9 @@ namespace thread { class ThreadPool; } -// A wrapper for an Eigen Gpu Device that includes per-op state +// A wrapper for an Eigen Gpu Device that includes per-op state. The +// class is defined even for non-GPU devices since the +// OpKernelContext::Params structure wants to fill it in. class PerOpGpuDevice { public: virtual ~PerOpGpuDevice() {} @@ -161,14 +163,16 @@ class DeviceBase { return eigen_cpu_device_; } - // The caller owns the returned device and must free it by calling - // DisposeGpuDevice below - virtual const PerOpGpuDevice* MakeGpuDevice(DeviceContext* /*dc*/, - Allocator* /*allocator*/) { - // The OpKernelContext calls this even for devices that do not - // implement an eigen_gpu_device - return nullptr; - } + // Caller owns the return value. The OpKernelContext calls this even + // for devices that do not implement an eigen_gpu_device. Overridden + // by GPU devices to return a derived type. + virtual PerOpGpuDevice* MakeGpuDevice() { return nullptr; } + + // This is overridden by GPU devices to reinitialize the derived + // type returned by MakeGpuDevice. + virtual void ReinitializeGpuDevice(PerOpGpuDevice* /*device*/, + DeviceContext* /*dc*/, + Allocator* /*allocator*/) {} virtual const DeviceAttributes& attributes() const { LOG(FATAL) << "Device does not implement attributes()"; diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index f2b9a9e295..fcc426af87 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -211,11 +211,13 @@ Status OpKernelConstruction::allocate_persistent( // OpKernelContext ----------------------------------------------------------- -OpKernelContext::OpKernelContext(const Params& params) - : params_(params), outputs_(params.op_kernel->output_types().size()) { +OpKernelContext::OpKernelContext(Params* params) + : params_(params), outputs_(params_->op_kernel->output_types().size()) { Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); - eigen_gpu_device_ = params_.device->MakeGpuDevice(params_.op_device_context, - eigen_gpu_allocator); + params_->ensure_eigen_gpu_device(); + params_->device->ReinitializeGpuDevice(params_->eigen_gpu_device, + params_->op_device_context, + eigen_gpu_allocator); } OpKernelContext::~OpKernelContext() { @@ -224,30 +226,29 @@ OpKernelContext::~OpKernelContext() { delete value.tensor; } } - delete eigen_gpu_device_; } Status OpKernelContext::input(const string& name, const Tensor** tensor) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued input name '", name, "' when single-valued input was " "expected"); } - if ((*params_.inputs)[start].is_ref()) { + if ((*params_->inputs)[start].is_ref()) { return errors::InvalidArgument("OpKernel used ref input name '", name, "' when immutable input was expected"); } - *tensor = (*params_.inputs)[start].tensor; + *tensor = (*params_->inputs)[start].tensor; record_tensor_reference(**tensor); return Status::OK(); } Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued input name '", name, @@ -260,22 +261,22 @@ Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) { Status OpKernelContext::mutable_input(const string& name, Tensor* tensor, bool lock_held) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued input name '", name, "' when single-valued input was expected"); } - if (!(*params_.inputs)[start].is_ref()) { + if (!(*params_->inputs)[start].is_ref()) { return errors::InvalidArgument("OpKernel used immutable input name '", name, "' when ref input was expected"); } // return a copy of the Ref acquired while holding the mutex if (lock_held) { - *tensor = *(*params_.inputs)[start].tensor; + *tensor = *(*params_->inputs)[start].tensor; } else { mutex_lock l(*input_ref_mutex(start)); - *tensor = *(*params_.inputs)[start].tensor; + *tensor = *(*params_->inputs)[start].tensor; } record_tensor_reference(*tensor); return Status::OK(); @@ -285,13 +286,13 @@ Status OpKernelContext::replace_ref_input(const string& name, const Tensor& tensor, bool lock_held) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued input name '", name, "' when single-valued input was expected"); } - if (!(*params_.inputs)[start].is_ref()) { + if (!(*params_->inputs)[start].is_ref()) { return errors::InvalidArgument("OpKernel used immutable input name '", name, "' when ref input was expected"); } @@ -301,7 +302,7 @@ Status OpKernelContext::replace_ref_input(const string& name, Status OpKernelContext::input_list(const string& name, OpInputList* list) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); *list = OpInputList(this, start, stop); return Status::OK(); } @@ -309,14 +310,14 @@ Status OpKernelContext::input_list(const string& name, OpInputList* list) { Status OpKernelContext::mutable_input_list(const string& name, OpMutableInputList* list) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); *list = OpMutableInputList(this, start, stop); return Status::OK(); } Status OpKernelContext::output_list(const string& name, OpOutputList* list) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); *list = OpOutputList(this, start, stop); return Status::OK(); } @@ -325,7 +326,7 @@ Status OpKernelContext::allocate_output(const string& name, const TensorShape& shape, Tensor** tensor) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued output name '", name, @@ -340,7 +341,7 @@ Status OpKernelContext::allocate_output(const string& name, Tensor** tensor, AllocatorAttributes attr) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued output name '", name, @@ -352,7 +353,7 @@ Status OpKernelContext::allocate_output(const string& name, Status OpKernelContext::set_output(const string& name, const Tensor& tensor) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued output name '", name, @@ -366,7 +367,7 @@ Status OpKernelContext::set_output(const string& name, const Tensor& tensor) { Status OpKernelContext::set_output_ref(const string& name, mutex* mu, Tensor* tensor_for_ref) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued output name '", name, @@ -379,7 +380,7 @@ Status OpKernelContext::set_output_ref(const string& name, mutex* mu, Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued output name '", name, @@ -392,7 +393,7 @@ Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) { Status OpKernelContext::release_output(const string& name, TensorValue* value) { int start, stop; - TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { return errors::InvalidArgument("OpKernel used list-valued output name '", name, @@ -404,7 +405,7 @@ Status OpKernelContext::release_output(const string& name, TensorValue* value) { } bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { - const auto& inputs = *params_.inputs; + const auto& inputs = *params_->inputs; for (size_t i = 1; i < inputs.size(); ++i) { if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) { SetStatus(errors::InvalidArgument( @@ -421,10 +422,10 @@ bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) { DataTypeVector inputs; - for (const TensorValue& t : *params_.inputs) { + for (const TensorValue& t : *params_->inputs) { inputs.push_back(t.is_ref() ? MakeRefType(t->dtype()) : t->dtype()); } - DataTypeVector outputs = params_.op_kernel->output_types(); + DataTypeVector outputs = params_->op_kernel->output_types(); return MatchSignatureHelper(expected_inputs, expected_outputs, inputs, outputs); } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 5234b4cef3..671c6de04f 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -430,13 +430,43 @@ class OpKernelContext { typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator; // TODO(zhifengc): Do some cleanup of Params. + // The Params struct is passed in to initialize an OpKernelContext, + // and must outlive the OpKernelContext. struct Params { + ~Params() { delete eigen_gpu_device; } + // The op kernel being computed. OpKernel* op_kernel = nullptr; // The device on which the kernel is running. DeviceBase* device = nullptr; + // The Eigen GPU device wrapper, which may include a per-op + // wrapped allocator. The concrete type of this object depends on + // the type of this->device, so eigen_gpu_device can't be an + // inline member and must be heap allocated. However, we don't + // want to allocate a new eigen_gpu_device for every Op that is + // executed. Instead this member is allocated on first use using + // ensure_eigen_gpu_device, and then if the Params structure is + // re-used for subsequent Ops, the eigen_gpu_device is + // ReInitialized in the OpKernelContext constructor. Unlike the + // other pointers in Params, this one is owned by Params. + PerOpGpuDevice* eigen_gpu_device = nullptr; + + inline void ensure_eigen_gpu_device() { + DCHECK(device); + if (nullptr == eigen_gpu_device) { + // Surprisingly, MakeGpuDevice will return nullptr if the + // device is not a GPU device. This is ok, since those devices + // will never use eigen_gpu_device. It seems better to have + // ensure_eigen_gpu_device fall through and regenerate the + // nullptr every time an OpKernelContext is instantiated, than + // to do an unneccessary allocation of a dummy eigen GPU + // device for CPU device Ops. + eigen_gpu_device = device->MakeGpuDevice(); + } + } + bool track_allocations = false; // Array indexed by output number for this node @@ -478,16 +508,18 @@ class OpKernelContext { // TensorSliceReaderCache support. checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; }; - explicit OpKernelContext(const Params& params); + + // params must outlive the OpKernelContext. + explicit OpKernelContext(Params* params); ~OpKernelContext(); - Env* env() const { return params_.device->env(); } + Env* env() const { return params_->device->env(); } - const OpKernel& op_kernel() const { return *params_.op_kernel; } + const OpKernel& op_kernel() const { return *params_->op_kernel; } // Input/output signature. - int num_inputs() const { return params_.inputs->size(); } + int num_inputs() const { return params_->inputs->size(); } DataType input_dtype(int index) const; int num_outputs() const { return outputs_.size(); } DataType expected_output_dtype(int index) const; @@ -756,7 +788,7 @@ class OpKernelContext { template <typename T> T* op_device_context(); DeviceContext* op_device_context() { - DeviceContext* ret = params_.op_device_context; + DeviceContext* ret = params_->op_device_context; if (ret == nullptr) { auto* dev_info = device()->tensorflow_gpu_device_info(); if (dev_info) ret = dev_info->default_context; @@ -766,12 +798,12 @@ class OpKernelContext { AllocatorAttributes input_alloc_attr(int index) const { DCHECK_GE(index, 0); - DCHECK_LT(index, params_.input_alloc_attrs->size()); - return (*params_.input_alloc_attrs)[index]; + DCHECK_LT(index, params_->input_alloc_attrs->size()); + return (*params_->input_alloc_attrs)[index]; } AllocatorAttributes output_alloc_attr(int index) const { - return params_.output_attr_array[index]; + return params_->output_attr_array[index]; } gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators() const { @@ -784,25 +816,25 @@ class OpKernelContext { // // An op kernel communicates with outside environment through // Rendezvous Send() and Recv(). - Rendezvous* rendezvous() const { return params_.rendezvous; } + Rendezvous* rendezvous() const { return params_->rendezvous; } // Function call support. // // If this kernel invocation is within a function execution, // call_frame() returns the call frame for the function call. - FunctionCallFrame* call_frame() const { return params_.call_frame; } + FunctionCallFrame* call_frame() const { return params_->call_frame; } // If not nullptr, the kernel invoke functions defined in the // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...). FunctionLibraryRuntime* function_library() const { - return params_.function_library; + return params_->function_library; } // Shared resources accessible to this kernel. - ResourceMgr* resource_manager() const { return params_.resource_manager; } + ResourceMgr* resource_manager() const { return params_->resource_manager; } checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const { - return params_.slice_reader_cache; + return params_->slice_reader_cache; } // Execution. @@ -813,7 +845,7 @@ class OpKernelContext { return *device()->eigen_cpu_device(); } const Eigen::GpuDevice& eigen_gpu_device() const { - return eigen_gpu_device_->device(); + return params_->eigen_gpu_device->device(); } template <typename EigenDeviceType> const EigenDeviceType& eigen_device() const; @@ -837,19 +869,19 @@ class OpKernelContext { // EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an // example of how to use this API. CancellationManager* cancellation_manager() const { - return params_.cancellation_manager; + return params_->cancellation_manager; } // Other accessors. // For control flow. - FrameAndIter frame_iter() const { return params_.frame_iter; } - bool is_input_dead() const { return params_.is_input_dead; } + FrameAndIter frame_iter() const { return params_->frame_iter; } + bool is_input_dead() const { return params_->is_input_dead; } bool* is_output_dead() { return &is_output_dead_; } // May be used, e.g., to get GPU handles, etc. // TODO(tucker): Add example usage. - DeviceBase* device() const { return params_.device; } + DeviceBase* device() const { return params_->device; } // Retrieve list of referenced tensors in out_vector. Once this is // called, it is not legal to reference any more tensors. Should @@ -858,14 +890,14 @@ class OpKernelContext { // Per-step resource manager for use by white-listed internal ops. ResourceMgr* step_resource_manager() const { - return params_.step_resource_manager; + return params_->step_resource_manager; } private: Allocator* get_allocator(AllocatorAttributes attr) { Allocator* allocator = - params_.device->GetStepAllocator(attr, step_resource_manager()); - if (params_.track_allocations) { + params_->device->GetStepAllocator(attr, step_resource_manager()); + if (params_->track_allocations) { mutex_lock lock(mu_); for (const auto& wrapped : wrapped_allocators_) { if (wrapped.first == allocator) { @@ -908,9 +940,7 @@ class OpKernelContext { void NotifyUseOfPersistentTensor(const Tensor& tensor); Status status_; - Params params_; // immutable after construction. - const PerOpGpuDevice* eigen_gpu_device_; // owned, with a per-op - // wrapped allocator + Params* params_; // not owned mutable mutex mu_; // mutable so const accessors can acquire the lock gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_); gtl::InlinedVector<TensorValue, 4> outputs_; @@ -1035,8 +1065,8 @@ class OpKernelRegistrar { inline DataType OpKernelContext::input_dtype(int index) const { DCHECK_GE(index, 0); - DCHECK_LT(index, params_.inputs->size()); - const TensorValue& value((*params_.inputs)[index]); + DCHECK_LT(index, params_->inputs->size()); + const TensorValue& value((*params_->inputs)[index]); if (value.is_ref()) { return MakeRefType(value->dtype()); } else { @@ -1046,12 +1076,12 @@ inline DataType OpKernelContext::input_dtype(int index) const { inline DataType OpKernelContext::expected_output_dtype(int index) const { DCHECK_GE(index, 0); - DCHECK_LT(index, params_.op_kernel->output_types().size()); - return params_.op_kernel->output_type(index); + DCHECK_LT(index, params_->op_kernel->output_types().size()); + return params_->op_kernel->output_type(index); } inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) { - if (params_.device->RequiresRecordingAccessedTensors()) { + if (params_->device->RequiresRecordingAccessedTensors()) { mutex_lock l(mu_); // Keep a reference to the underlying memory around. referenced_tensors_.Add(tensor); @@ -1066,25 +1096,25 @@ inline void OpKernelContext::retrieve_accessed_tensors( inline const Tensor& OpKernelContext::input(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_.inputs->size()); - DCHECK(!(*params_.inputs)[index].is_ref()); - const Tensor& tensor = *((*params_.inputs)[index].tensor); + DCHECK_LT(index, params_->inputs->size()); + DCHECK(!(*params_->inputs)[index].is_ref()); + const Tensor& tensor = *((*params_->inputs)[index].tensor); record_tensor_reference(tensor); return tensor; } inline Tensor OpKernelContext::mutable_input(int index, bool lock_held) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_.inputs->size()); - DCHECK((*params_.inputs)[index].is_ref()); + DCHECK_LT(index, params_->inputs->size()); + DCHECK((*params_->inputs)[index].is_ref()); // return a copy of the Ref acquired while holding the mutex if (lock_held) { - Tensor& tensor = *((*params_.inputs)[index].tensor); + Tensor& tensor = *((*params_->inputs)[index].tensor); record_tensor_reference(tensor); return tensor; } else { mutex_lock l(*input_ref_mutex(index)); - Tensor& tensor = *((*params_.inputs)[index].tensor); + Tensor& tensor = *((*params_->inputs)[index].tensor); record_tensor_reference(tensor); return tensor; } @@ -1093,14 +1123,14 @@ inline Tensor OpKernelContext::mutable_input(int index, bool lock_held) { inline void OpKernelContext::replace_ref_input(int index, const Tensor& tensor, bool lock_held) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_.inputs->size()); - DCHECK((*params_.inputs)[index].is_ref()); + DCHECK_LT(index, params_->inputs->size()); + DCHECK((*params_->inputs)[index].is_ref()); // should only modify the tensor while holding the mutex if (lock_held) { - *(*params_.inputs)[index].tensor = tensor; + *(*params_->inputs)[index].tensor = tensor; } else { mutex_lock l(*input_ref_mutex(index)); - *(*params_.inputs)[index].tensor = tensor; + *(*params_->inputs)[index].tensor = tensor; } record_tensor_reference(tensor); } @@ -1108,37 +1138,37 @@ inline void OpKernelContext::replace_ref_input(int index, const Tensor& tensor, inline void OpKernelContext::forward_ref_input_to_ref_output(int input_index, int output_index) { DCHECK_GE(input_index, 0); - DCHECK_LT(input_index, params_.inputs->size()); - DCHECK((*params_.inputs)[input_index].is_ref()); - set_output_ref(output_index, (*params_.inputs)[input_index].mutex_if_ref, - (*params_.inputs)[input_index].tensor); + DCHECK_LT(input_index, params_->inputs->size()); + DCHECK((*params_->inputs)[input_index].is_ref()); + set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref, + (*params_->inputs)[input_index].tensor); } inline void OpKernelContext::delete_ref_input(int index, bool lock_held) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_.inputs->size()); - DCHECK((*params_.inputs)[index].is_ref()); + DCHECK_LT(index, params_->inputs->size()); + DCHECK((*params_->inputs)[index].is_ref()); // should only modify the tensor while holding the mutex if (lock_held) { - delete (*params_.inputs)[index].tensor; + delete (*params_->inputs)[index].tensor; } else { mutex_lock l(*input_ref_mutex(index)); - delete (*params_.inputs)[index].tensor; + delete (*params_->inputs)[index].tensor; } } // no input if tensor == nullptr. inline bool OpKernelContext::has_input(int index) const { DCHECK_GE(index, 0); - DCHECK_LT(index, params_.inputs->size()); - return (*params_.inputs)[index].tensor != nullptr; + DCHECK_LT(index, params_->inputs->size()); + return (*params_->inputs)[index].tensor != nullptr; } inline mutex* OpKernelContext::input_ref_mutex(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_.inputs->size()); - DCHECK((*params_.inputs)[index].is_ref()); - return (*params_.inputs)[index].mutex_if_ref; + DCHECK_LT(index, params_->inputs->size()); + DCHECK((*params_->inputs)[index].is_ref()); + return (*params_->inputs)[index].mutex_if_ref; } inline Status OpKernelContext::allocate_output(int index, @@ -1171,7 +1201,7 @@ inline Status OpKernelContext::allocate_output(int index, AllocatorAttributes attr) { DCHECK_GE(index, 0); DCHECK_LT(index, outputs_.size()); - const DataType type = params_.op_kernel->output_type(index); + const DataType type = params_->op_kernel->output_type(index); DCHECK(!IsRefType(type)); DCHECK(mutable_output(index) == nullptr); Tensor* output_tensor = new Tensor(); @@ -1216,7 +1246,7 @@ inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) { inline void OpKernelContext::set_output(int index, const Tensor& tensor) { DCHECK_GE(index, 0); DCHECK_LT(index, outputs_.size()); - DCHECK(!IsRefType(params_.op_kernel->output_type(index))); + DCHECK(!IsRefType(params_->op_kernel->output_type(index))); DCHECK_EQ(mutable_output(index), nullptr); record_tensor_reference(tensor); outputs_[index] = TensorValue(new Tensor(tensor)); @@ -1226,7 +1256,7 @@ inline void OpKernelContext::set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref) { DCHECK_GE(index, 0); DCHECK_LT(index, outputs_.size()); - DCHECK(IsRefType(params_.op_kernel->output_type(index))); + DCHECK(IsRefType(params_->op_kernel->output_type(index))); record_tensor_reference(*tensor_for_ref); outputs_[index] = TensorValue(mu, tensor_for_ref); } @@ -1257,16 +1287,16 @@ T* OpKernelContext::op_device_context() { template <typename T> T* OpKernelContext::input_device_context(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_.input_device_contexts->size()); + DCHECK_LT(index, params_->input_device_contexts->size()); static_assert(std::is_base_of<DeviceContext, T>::value, "T is not a subclass of DeviceContext"); - return static_cast<T*>((*params_.input_device_contexts)[index]); + return static_cast<T*>((*params_->input_device_contexts)[index]); } inline DeviceContext* OpKernelContext::input_device_context(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_.input_device_contexts->size()); - return (*params_.input_device_contexts)[index]; + DCHECK_LT(index, params_->input_device_contexts->size()); + return (*params_->input_device_contexts)[index]; } inline const Tensor& OpInputList::operator[](int i) const { diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index 6e9eebd376..cf64a77902 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -292,7 +292,7 @@ TEST_F(OpKernelTest, SaveTempFalse) { TF_GRAPH_DEF_VERSION, &status)); EXPECT_TRUE(status.ok()); params.op_kernel = op.get(); - OpKernelContext* ctx = new OpKernelContext(params); + OpKernelContext* ctx = new OpKernelContext(¶ms); Tensor t; EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t)); @@ -316,7 +316,7 @@ TEST_F(OpKernelTest, SaveTempTrue) { TF_GRAPH_DEF_VERSION, &status)); EXPECT_TRUE(status.ok()); params.op_kernel = op.get(); - OpKernelContext* ctx = new OpKernelContext(params); + OpKernelContext* ctx = new OpKernelContext(¶ms); Tensor t; EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t)); diff --git a/tensorflow/core/kernels/core_ops_test.cc b/tensorflow/core/kernels/core_ops_test.cc index 007df30014..52b485f6be 100644 --- a/tensorflow/core/kernels/core_ops_test.cc +++ b/tensorflow/core/kernels/core_ops_test.cc @@ -447,7 +447,7 @@ static void BM_LRNFloat(int iters, int depth, int cols, int rows, std::vector<AllocatorAttributes> attrs; test::SetOutputAttrs(¶ms, &attrs); - std::unique_ptr<OpKernelContext> context(new OpKernelContext(params)); + std::unique_ptr<OpKernelContext> context(new OpKernelContext(¶ms)); op->Compute(context.get()); tensorflow::testing::StartTiming(); @@ -527,7 +527,8 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth, std::vector<AllocatorAttributes> attrs; test::SetOutputAttrs(¶ms, &attrs); - std::unique_ptr<OpKernelContext> avgpool_context(new OpKernelContext(params)); + std::unique_ptr<OpKernelContext> avgpool_context( + new OpKernelContext(¶ms)); op->Compute(avgpool_context.get()); tensorflow::testing::StartTiming(); @@ -631,7 +632,8 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols, std::vector<AllocatorAttributes> attrs; test::SetOutputAttrs(¶ms, &attrs); - std::unique_ptr<OpKernelContext> avgpool_context(new OpKernelContext(params)); + std::unique_ptr<OpKernelContext> avgpool_context( + new OpKernelContext(¶ms)); op->Compute(avgpool_context.get()); tensorflow::testing::StartTiming(); @@ -717,7 +719,8 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth, std::vector<AllocatorAttributes> attrs; test::SetOutputAttrs(¶ms, &attrs); - std::unique_ptr<OpKernelContext> maxpool_context(new OpKernelContext(params)); + std::unique_ptr<OpKernelContext> maxpool_context( + new OpKernelContext(¶ms)); op->Compute(maxpool_context.get()); tensorflow::testing::StartTiming(); @@ -891,7 +894,7 @@ static void BM_ReluFloat(int iters, int batch_size, int rows, int cols, std::vector<AllocatorAttributes> attrs; test::SetOutputAttrs(¶ms, &attrs); - std::unique_ptr<OpKernelContext> relu_context(new OpKernelContext(params)); + std::unique_ptr<OpKernelContext> relu_context(new OpKernelContext(¶ms)); op->Compute(relu_context.get()); tensorflow::testing::StartTiming(); @@ -959,7 +962,8 @@ static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth, std::vector<AllocatorAttributes> attrs; test::SetOutputAttrs(¶ms, &attrs); - std::unique_ptr<OpKernelContext> softmax_context(new OpKernelContext(params)); + std::unique_ptr<OpKernelContext> softmax_context( + new OpKernelContext(¶ms)); op->Compute(softmax_context.get()); tensorflow::testing::StartTiming(); diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h index a0d51667b2..a90cd7218c 100644 --- a/tensorflow/core/kernels/ops_testutil.h +++ b/tensorflow/core/kernels/ops_testutil.h @@ -82,6 +82,7 @@ class OpsTestBase : public ::testing::Test { ~OpsTestBase() override { gtl::STLDeleteElements(&tensors_); context_.reset(nullptr); + params_.reset(nullptr); } void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); } @@ -150,17 +151,21 @@ class OpsTestBase : public ::testing::Test { // // Returns the context's status after running the operation. Status RunOpKernel() { - OpKernelContext::Params params; - params.device = device_.get(); - params.frame_iter = FrameAndIter(0, 0); - params.inputs = &inputs_; - params.op_kernel = kernel_.get(); + // Make sure the old OpKernelContext is deleted before the Params + // it was using. + context_.reset(nullptr); + + params_.reset(new OpKernelContext::Params); + params_.get()->device = device_.get(); + params_.get()->frame_iter = FrameAndIter(0, 0); + params_.get()->inputs = &inputs_; + params_.get()->op_kernel = kernel_.get(); std::vector<AllocatorAttributes> attrs; - test::SetOutputAttrs(¶ms, &attrs); + test::SetOutputAttrs(params_.get(), &attrs); checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; - params.slice_reader_cache = &slice_reader_cache_wrapper; + params_.get()->slice_reader_cache = &slice_reader_cache_wrapper; - context_.reset(new OpKernelContext(params)); + context_.reset(new OpKernelContext(params_.get())); device_->Compute(kernel_.get(), context_.get()); return context_->status(); } @@ -206,6 +211,7 @@ class OpsTestBase : public ::testing::Test { // Owns Tensors. std::vector<Tensor*> tensors_; + std::unique_ptr<OpKernelContext::Params> params_; std::unique_ptr<OpKernelContext> context_; private: diff --git a/tensorflow/core/kernels/restore_op_test.cc b/tensorflow/core/kernels/restore_op_test.cc index 1dbd452843..9c5e52c574 100644 --- a/tensorflow/core/kernels/restore_op_test.cc +++ b/tensorflow/core/kernels/restore_op_test.cc @@ -168,7 +168,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; params.slice_reader_cache = &slice_reader_cache_wrapper; - OpKernelContext ctx(params); + OpKernelContext ctx(¶ms); op->Compute(&ctx); EXPECT_OK(ctx.status()); } @@ -392,7 +392,7 @@ TEST_F(RestoreSliceOpTest, RestoreInt) { checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; params.slice_reader_cache = &slice_reader_cache_wrapper; - OpKernelContext ctx(params); + OpKernelContext ctx(¶ms); op->Compute(&ctx); EXPECT_OK(ctx.status()); } diff --git a/tensorflow/core/kernels/segment_reduction_ops_test.cc b/tensorflow/core/kernels/segment_reduction_ops_test.cc index 88ec897801..23f14376c8 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_test.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_test.cc @@ -77,7 +77,7 @@ static void BM_SegmentReduction(int iters, string reduction, Index num_rows, test::SetOutputAttrs(¶ms, &attrs); std::unique_ptr<OpKernelContext> reduction_context( - new OpKernelContext(params)); + new OpKernelContext(¶ms)); reduction_op->Compute(reduction_context.get()); TF_CHECK_OK(reduction_context->status()); diff --git a/tensorflow/core/kernels/sparse_to_dense_op_test.cc b/tensorflow/core/kernels/sparse_to_dense_op_test.cc index 39a532e053..ed7f199804 100644 --- a/tensorflow/core/kernels/sparse_to_dense_op_test.cc +++ b/tensorflow/core/kernels/sparse_to_dense_op_test.cc @@ -258,7 +258,7 @@ static void BM_SparseToDense(int iters, const int bm_arg) { std::vector<AllocatorAttributes> attrs; test::SetOutputAttrs(¶ms, &attrs); - std::unique_ptr<OpKernelContext> sparse_context(new OpKernelContext(params)); + std::unique_ptr<OpKernelContext> sparse_context(new OpKernelContext(¶ms)); op->Compute(sparse_context.get()); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { |