-rw-r--r--   tensorflow/core/common_runtime/executor.cc               9
-rw-r--r--   tensorflow/core/common_runtime/gpu/gpu_device.cc        64
-rw-r--r--   tensorflow/core/common_runtime/gpu/gpu_device.h         11
-rw-r--r--   tensorflow/core/framework/device_base.h                 22
-rw-r--r--   tensorflow/core/framework/op_kernel.cc                  55
-rw-r--r--   tensorflow/core/framework/op_kernel.h                  150
-rw-r--r--   tensorflow/core/framework/op_kernel_test.cc              4
-rw-r--r--   tensorflow/core/kernels/core_ops_test.cc                16
-rw-r--r--   tensorflow/core/kernels/ops_testutil.h                  22
-rw-r--r--   tensorflow/core/kernels/restore_op_test.cc               4
-rw-r--r--   tensorflow/core/kernels/segment_reduction_ops_test.cc    2
-rw-r--r--   tensorflow/core/kernels/sparse_to_dense_op_test.cc       2
12 files changed, 215 insertions(+), 146 deletions(-)
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index ef9de781c0..c6733b03bc 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -823,6 +823,9 @@ namespace {
OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) {
OpKernelContext::Params* ret = new OpKernelContext::Params;
*ret = p;
+ // Ensure the copy of Params will make a new eigen GPU device if
+ // necessary.
+ ret->eigen_gpu_device = nullptr;
ret->inputs = new TensorValueVec(*p.inputs);
ret->input_device_contexts = new DeviceContextVec(*p.input_device_contexts);
ret->input_alloc_attrs = new AllocatorAttributeVec(*p.input_alloc_attrs);
@@ -831,6 +834,8 @@ OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) {
// Helpers to delete 'p' and copies made by CopyParams.
void DeleteParams(OpKernelContext::Params* p) {
+ // No need to delete p->eigen_gpu_device since that is deleted in
+ // p's destructor
delete p->inputs;
delete p->input_device_contexts;
delete p->input_alloc_attrs;
@@ -929,7 +934,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
if (async) {
// Asynchronous computes.
auto pcopy = CopyParams(params);
- auto ctx = new OpKernelContext(*pcopy);
+ auto ctx = new OpKernelContext(pcopy);
auto done = [this, tagged_node, item, first_input, ctx, stats, pcopy,
device]() {
VLOG(2) << this << " Async kernel done: "
@@ -967,7 +972,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
device->ComputeAsync(async, ctx, done);
} else {
// Synchronous computes.
- OpKernelContext ctx(params);
+ OpKernelContext ctx(&params);
if (stats_collector_) nodestats::SetOpStart(stats);
device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
if (stats_collector_) nodestats::SetOpEnd(stats);
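A minimal sketch of the two lifetimes this file now distinguishes, using simplified stand-in types rather than the real TensorFlow classes (Params, Context, scratch, and the function names below are illustrative only): the synchronous path builds the context on the stack over a Params that trivially outlives it, while the asynchronous path heap-copies the Params, resets the owned pointer so the copy lazily builds its own device wrapper, and keeps both alive until the done-callback fires.

    // Stand-in for OpKernelContext::Params: owns one lazily created pointer.
    struct Params {
      int* scratch = nullptr;         // plays the role of eigen_gpu_device
      ~Params() { delete scratch; }   // Params owns it, nothing else does
    };

    // Stand-in for OpKernelContext: borrows the Params, never owns it.
    struct Context {
      explicit Context(Params* p) : params(p) {}
      Params* params;
    };

    void SyncCompute(Params& params) {
      Context ctx(&params);   // stack context; params outlives it trivially
      // ... device->Compute(kernel, &ctx) ...
      (void)ctx;              // usage elided in this sketch
    }

    void AsyncCompute(const Params& params) {
      Params* pcopy = new Params(params);  // shallow copy, like CopyParams
      pcopy->scratch = nullptr;            // don't share the owned pointer!
      Context* ctx = new Context(pcopy);
      auto done = [ctx, pcopy]() {         // runs after the kernel finishes
        delete ctx;
        delete pcopy;                      // ~Params frees pcopy's own scratch
      };
      // ... device->ComputeAsync(kernel, ctx, done) ...
      (void)done;                          // usage elided in this sketch
    }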
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 44874c5189..1efe7b7251 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -73,9 +73,14 @@ namespace tensorflow {
#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
class EigenAllocator : public ::Eigen::Allocator {
public:
- explicit EigenAllocator(gpu::Stream* stream, ::tensorflow::Allocator* alloc,
- EventMgr* em)
- : stream_(stream), allocator_(alloc), em_(em) {}
+ EigenAllocator() {}
+
+ void Reinitialize(gpu::Stream* stream, ::tensorflow::Allocator* alloc,
+ EventMgr* em) {
+ stream_ = stream;
+ allocator_ = alloc;
+ em_ = em;
+ }
void* allocate(size_t num_bytes) const override {
void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes);
@@ -103,10 +108,12 @@ class EigenAllocator : public ::Eigen::Allocator {
#else
class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
public:
- EigenCudaStreamDevice(const cudaStream_t* cuda_stream, int gpu_id,
- ::tensorflow::Allocator* alloc)
- : stream_(cuda_stream), allocator_(alloc) {
- Eigen::initializeDeviceProp();
+ EigenCudaStreamDevice() { Eigen::initializeDeviceProp(); }
+
+ void Reinitialize(const cudaStream_t* cuda_stream, int gpu_id,
+ ::tensorflow::Allocator* alloc) {
+ stream_ = cuda_stream;
+ allocator_ = alloc;
device_prop_ = &Eigen::m_deviceProperties[gpu_id];
}
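Both the __GCUDACC__ and plain-CUDA branches apply the same refactoring: a one-shot constructor becomes a default constructor plus Reinitialize(), turning the object into a reusable shell that can be re-pointed at fresh per-op state instead of being reallocated. A generic sketch of the idiom, with placeholder types:

    struct Stream {};     // placeholder for gpu::Stream / cudaStream_t
    struct Allocator {};  // placeholder for tensorflow::Allocator

    class ReusableWrapper {
     public:
      ReusableWrapper() {}  // cheap; one-time setup (e.g. device props) goes here

      // Re-point at fresh per-op state; safe to call once per op, forever.
      void Reinitialize(Stream* stream, Allocator* alloc) {
        stream_ = stream;
        allocator_ = alloc;
      }

     private:
      Stream* stream_ = nullptr;        // not owned
      Allocator* allocator_ = nullptr;  // not owned
    };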
@@ -391,10 +398,11 @@ namespace {
#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
class ConcretePerOpGpuDevice : public PerOpGpuDevice {
public:
- explicit ConcretePerOpGpuDevice(gpu::Stream* stream,
- Allocator* base_allocator,
- ::tensorflow::EventMgr* em)
- : allocator_(stream, base_allocator, em), device_(stream, &allocator_) {}
+ void Reinitialize(gpu::Stream* stream, Allocator* base_allocator,
+ ::tensorflow::EventMgr* em) {
+ allocator_.Reinitialize(stream, base_allocator, em);
+ device_.Reinitialize(stream, &allocator_);
+ }
const Eigen::GpuDevice& device() const override { return device_; }
@@ -405,10 +413,12 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
#else
class ConcretePerOpGpuDevice : public PerOpGpuDevice {
public:
- explicit ConcretePerOpGpuDevice(const cudaStream_t* cuda_stream, int gpu_id,
- Allocator* base_allocator)
- : stream_device_(cuda_stream, gpu_id, base_allocator),
- device_(&stream_device_) {}
+ ConcretePerOpGpuDevice() : device_(&stream_device_) {}
+
+ void Reinitialize(const cudaStream_t* cuda_stream, int gpu_id,
+ Allocator* base_allocator) {
+ stream_device_.Reinitialize(cuda_stream, gpu_id, base_allocator);
+ }
const Eigen::GpuDevice& device() const override { return device_; }
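One detail worth noting: device_ is constructed once against &stream_device_ and never rebuilt; Reinitialize only refreshes the stream wrapper's internals, and the Eigen device observes the change through the pointer it already holds. A sketch of that indirection, with placeholder types standing in for the real classes:

    struct StreamWrapper {                      // plays EigenCudaStreamDevice
      const void* stream = nullptr;             // stand-in for cudaStream_t*
      void Reinitialize(const void* s) { stream = s; }
    };

    struct EigenDeviceLike {                    // plays Eigen::GpuDevice
      explicit EigenDeviceLike(const StreamWrapper* w) : wrapper(w) {}
      const void* current_stream() const { return wrapper->stream; }
      const StreamWrapper* wrapper;             // reads through, never copies
    };

    struct PerOpDevice {                        // plays ConcretePerOpGpuDevice
      PerOpDevice() : device(&stream_wrapper) {}
      void Reinitialize(const void* s) { stream_wrapper.Reinitialize(s); }
      StreamWrapper stream_wrapper;   // declared first, so initialized first
      EigenDeviceLike device;         // sees each Reinitialize via the pointer
    };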
@@ -419,28 +429,36 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
#endif
} // namespace
-const PerOpGpuDevice* BaseGPUDevice::NewDevice(int stream_id,
- Allocator* allocator) {
+void BaseGPUDevice::ReinitializeDevice(PerOpGpuDevice* device, int stream_id,
+ Allocator* allocator) {
+ ConcretePerOpGpuDevice* concrete_device =
+ dynamic_cast<ConcretePerOpGpuDevice*>(device);
+ DCHECK(concrete_device);
#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
- return new ConcretePerOpGpuDevice(streams_[stream_id], allocator, em_.get());
+ concrete_device->Reinitialize(streams_[stream_id], allocator, em_.get());
#else
const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
streams_[stream_id]->implementation()->CudaStreamMemberHack());
- return new ConcretePerOpGpuDevice(cuda_stream, gpu_id_, allocator);
+ concrete_device->Reinitialize(cuda_stream, gpu_id_, allocator);
#endif
}
-const PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice(DeviceContext* dc,
- Allocator* allocator) {
+PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice() {
+ return new ConcretePerOpGpuDevice();
+}
+
+void BaseGPUDevice::ReinitializeGpuDevice(PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) {
if (dc) {
const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc);
const int stream_id = gpu_dc->stream_id();
VLOG(1) << " eigen_gpu_device(" << dc << ") => stream[" << stream_id
<< "]";
CHECK_LT(stream_id, streams_.size());
- return NewDevice(stream_id, allocator);
+ ReinitializeDevice(device, stream_id, allocator);
} else {
- return NewDevice(0, allocator);
+ ReinitializeDevice(device, 0, allocator);
}
}
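ReinitializeDevice downcasts because the only PerOpGpuDevice a BaseGPUDevice ever receives is one produced by its own MakeGpuDevice; the dynamic_cast plus DCHECK verifies that invariant in debug builds. A self-contained sketch of the pattern (Base, Derived, and Reinit are illustrative; assert stands in for DCHECK):

    #include <cassert>

    struct Base { virtual ~Base() {} };

    struct Derived : Base {
      void Reinitialize() { /* refresh per-op state */ }
    };

    // Trust the static contract (callers only ever pass Derived), but
    // verify it in debug builds before using the derived interface.
    void Reinit(Base* b) {
      Derived* d = dynamic_cast<Derived*>(b);  // nullptr if contract broken
      assert(d != nullptr);
      d->Reinitialize();
    }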
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 374bd73517..02a2279174 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -38,8 +38,6 @@ limitations under the License.
namespace tensorflow {
-class EigenAllocator;
-
class BaseGPUDevice : public LocalDevice {
public:
BaseGPUDevice(const SessionOptions& options, const string& name,
@@ -74,8 +72,10 @@ class BaseGPUDevice : public LocalDevice {
Tensor* tensor) override;
// The caller owns the returned device.
- const PerOpGpuDevice* MakeGpuDevice(DeviceContext* dc,
- Allocator* allocator) override;
+ PerOpGpuDevice* MakeGpuDevice() override;
+
+ void ReinitializeGpuDevice(PerOpGpuDevice* device, DeviceContext* dc,
+ Allocator* allocator) override;
protected:
Allocator* gpu_allocator_; // not owned
@@ -90,7 +90,8 @@ class BaseGPUDevice : public LocalDevice {
const bool sync_every_op_ = false;
std::unique_ptr<EventMgr> em_;
- const PerOpGpuDevice* NewDevice(int stream_id, Allocator* allocator);
+ void ReinitializeDevice(PerOpGpuDevice* device, int stream_id,
+ Allocator* allocator);
};
class BaseGPUDeviceFactory : public DeviceFactory {
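The split declared here turns one call into a two-call protocol: allocate the wrapper once, then cheaply re-point it per op. Roughly, it is used like this (a sketch assuming the declarations from device_base.h are in scope and that device, dc, and allocator come from the caller; RunManyOps itself is hypothetical, not executor code):

    void RunManyOps(DeviceBase* device, DeviceContext* dc, Allocator* allocator) {
      PerOpGpuDevice* gpu_device = device->MakeGpuDevice();  // nullptr on CPU
      for (int step = 0; step < 1000; ++step) {
        // Per-op: no heap traffic, just re-pointing the existing wrapper.
        device->ReinitializeGpuDevice(gpu_device, dc, allocator);
        // ... run one kernel against gpu_device ...
      }
      delete gpu_device;  // caller owns it; in practice Params::~Params does this
    }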
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 23944a5f43..50b85e2c17 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -48,7 +48,9 @@ namespace thread {
class ThreadPool;
}
-// A wrapper for an Eigen Gpu Device that includes per-op state
+// A wrapper for an Eigen Gpu Device that includes per-op state. The
+// class is defined even for non-GPU devices since the
+// OpKernelContext::Params structure wants to fill it in.
class PerOpGpuDevice {
public:
virtual ~PerOpGpuDevice() {}
@@ -161,14 +163,16 @@ class DeviceBase {
return eigen_cpu_device_;
}
- // The caller owns the returned device and must free it by calling
- // DisposeGpuDevice below
- virtual const PerOpGpuDevice* MakeGpuDevice(DeviceContext* /*dc*/,
- Allocator* /*allocator*/) {
- // The OpKernelContext calls this even for devices that do not
- // implement an eigen_gpu_device
- return nullptr;
- }
+ // Caller owns the return value. The OpKernelContext calls this even
+ // for devices that do not implement an eigen_gpu_device. Overridden
+ // by GPU devices to return a derived type.
+ virtual PerOpGpuDevice* MakeGpuDevice() { return nullptr; }
+
+ // This is overridden by GPU devices to reinitialize the derived
+ // type returned by MakeGpuDevice.
+ virtual void ReinitializeGpuDevice(PerOpGpuDevice* /*device*/,
+ DeviceContext* /*dc*/,
+ Allocator* /*allocator*/) {}
virtual const DeviceAttributes& attributes() const {
LOG(FATAL) << "Device does not implement attributes()";
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index f2b9a9e295..fcc426af87 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -211,11 +211,13 @@ Status OpKernelConstruction::allocate_persistent(
// OpKernelContext -----------------------------------------------------------
-OpKernelContext::OpKernelContext(const Params& params)
- : params_(params), outputs_(params.op_kernel->output_types().size()) {
+OpKernelContext::OpKernelContext(Params* params)
+ : params_(params), outputs_(params_->op_kernel->output_types().size()) {
Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
- eigen_gpu_device_ = params_.device->MakeGpuDevice(params_.op_device_context,
- eigen_gpu_allocator);
+ params_->ensure_eigen_gpu_device();
+ params_->device->ReinitializeGpuDevice(params_->eigen_gpu_device,
+ params_->op_device_context,
+ eigen_gpu_allocator);
}
OpKernelContext::~OpKernelContext() {
@@ -224,30 +226,29 @@ OpKernelContext::~OpKernelContext() {
delete value.tensor;
}
}
- delete eigen_gpu_device_;
}
Status OpKernelContext::input(const string& name, const Tensor** tensor) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was "
"expected");
}
- if ((*params_.inputs)[start].is_ref()) {
+ if ((*params_->inputs)[start].is_ref()) {
return errors::InvalidArgument("OpKernel used ref input name '", name,
"' when immutable input was expected");
}
- *tensor = (*params_.inputs)[start].tensor;
+ *tensor = (*params_->inputs)[start].tensor;
record_tensor_reference(**tensor);
return Status::OK();
}
Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
@@ -260,22 +261,22 @@ Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) {
Status OpKernelContext::mutable_input(const string& name, Tensor* tensor,
bool lock_held) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was expected");
}
- if (!(*params_.inputs)[start].is_ref()) {
+ if (!(*params_->inputs)[start].is_ref()) {
return errors::InvalidArgument("OpKernel used immutable input name '", name,
"' when ref input was expected");
}
// return a copy of the Ref acquired while holding the mutex
if (lock_held) {
- *tensor = *(*params_.inputs)[start].tensor;
+ *tensor = *(*params_->inputs)[start].tensor;
} else {
mutex_lock l(*input_ref_mutex(start));
- *tensor = *(*params_.inputs)[start].tensor;
+ *tensor = *(*params_->inputs)[start].tensor;
}
record_tensor_reference(*tensor);
return Status::OK();
@@ -285,13 +286,13 @@ Status OpKernelContext::replace_ref_input(const string& name,
const Tensor& tensor,
bool lock_held) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was expected");
}
- if (!(*params_.inputs)[start].is_ref()) {
+ if (!(*params_->inputs)[start].is_ref()) {
return errors::InvalidArgument("OpKernel used immutable input name '", name,
"' when ref input was expected");
}
@@ -301,7 +302,7 @@ Status OpKernelContext::replace_ref_input(const string& name,
Status OpKernelContext::input_list(const string& name, OpInputList* list) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
*list = OpInputList(this, start, stop);
return Status::OK();
}
@@ -309,14 +310,14 @@ Status OpKernelContext::input_list(const string& name, OpInputList* list) {
Status OpKernelContext::mutable_input_list(const string& name,
OpMutableInputList* list) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
*list = OpMutableInputList(this, start, stop);
return Status::OK();
}
Status OpKernelContext::output_list(const string& name, OpOutputList* list) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
*list = OpOutputList(this, start, stop);
return Status::OK();
}
@@ -325,7 +326,7 @@ Status OpKernelContext::allocate_output(const string& name,
const TensorShape& shape,
Tensor** tensor) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
@@ -340,7 +341,7 @@ Status OpKernelContext::allocate_output(const string& name,
Tensor** tensor,
AllocatorAttributes attr) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
@@ -352,7 +353,7 @@ Status OpKernelContext::allocate_output(const string& name,
Status OpKernelContext::set_output(const string& name, const Tensor& tensor) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
@@ -366,7 +367,7 @@ Status OpKernelContext::set_output(const string& name, const Tensor& tensor) {
Status OpKernelContext::set_output_ref(const string& name, mutex* mu,
Tensor* tensor_for_ref) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
@@ -379,7 +380,7 @@ Status OpKernelContext::set_output_ref(const string& name, mutex* mu,
Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
@@ -392,7 +393,7 @@ Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) {
Status OpKernelContext::release_output(const string& name, TensorValue* value) {
int start, stop;
- TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
@@ -404,7 +405,7 @@ Status OpKernelContext::release_output(const string& name, TensorValue* value) {
}
bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
- const auto& inputs = *params_.inputs;
+ const auto& inputs = *params_->inputs;
for (size_t i = 1; i < inputs.size(); ++i) {
if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
SetStatus(errors::InvalidArgument(
@@ -421,10 +422,10 @@ bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
const DataTypeSlice expected_outputs) {
DataTypeVector inputs;
- for (const TensorValue& t : *params_.inputs) {
+ for (const TensorValue& t : *params_->inputs) {
inputs.push_back(t.is_ref() ? MakeRefType(t->dtype()) : t->dtype());
}
- DataTypeVector outputs = params_.op_kernel->output_types();
+ DataTypeVector outputs = params_->op_kernel->output_types();
return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
outputs);
}
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 5234b4cef3..671c6de04f 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -430,13 +430,43 @@ class OpKernelContext {
typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
// TODO(zhifengc): Do some cleanup of Params.
+ // The Params struct is passed in to initialize an OpKernelContext,
+ // and must outlive the OpKernelContext.
struct Params {
+ ~Params() { delete eigen_gpu_device; }
+
// The op kernel being computed.
OpKernel* op_kernel = nullptr;
// The device on which the kernel is running.
DeviceBase* device = nullptr;
+ // The Eigen GPU device wrapper, which may include a per-op
+ // wrapped allocator. The concrete type of this object depends on
+ // the type of this->device, so eigen_gpu_device can't be an
+ // inline member and must be heap allocated. However, we don't
+ // want to allocate a new eigen_gpu_device for every Op that is
+ // executed. Instead this member is allocated on first use using
+ // ensure_eigen_gpu_device, and then if the Params structure is
+ // re-used for subsequent Ops, the eigen_gpu_device is
+ // Reinitialized in the OpKernelContext constructor. Unlike the
+ // other pointers in Params, this one is owned by Params.
+ PerOpGpuDevice* eigen_gpu_device = nullptr;
+
+ inline void ensure_eigen_gpu_device() {
+ DCHECK(device);
+ if (nullptr == eigen_gpu_device) {
+ // Surprisingly, MakeGpuDevice will return nullptr if the
+ // device is not a GPU device. This is ok, since those devices
+ // will never use eigen_gpu_device. It seems better to have
+ // ensure_eigen_gpu_device fall through and regenerate the
+ // nullptr every time an OpKernelContext is instantiated, than
+ // to do an unnecessary allocation of a dummy eigen GPU
+ // device for CPU device Ops.
+ eigen_gpu_device = device->MakeGpuDevice();
+ }
+ }
+
bool track_allocations = false;
// Array indexed by output number for this node
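A sketch of the reuse this comment describes, assuming a fully populated Params (the other required fields, inputs, and error handling are elided; some_device and kernels are placeholders):

    OpKernelContext::Params params;
    params.device = some_device;          // plus the other required fields

    for (OpKernel* kernel : kernels) {
      params.op_kernel = kernel;
      // First pass: ensure_eigen_gpu_device() heap-allocates the wrapper
      // (or stores nullptr for CPU devices). Every later pass: the
      // constructor just calls ReinitializeGpuDevice() on the existing one.
      OpKernelContext ctx(&params);
      some_device->Compute(kernel, &ctx);
    }
    // params is destroyed after every context: ~Params() frees the wrapper.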
@@ -478,16 +508,18 @@ class OpKernelContext {
// TensorSliceReaderCache support.
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
};
- explicit OpKernelContext(const Params& params);
+
+ // params must outlive the OpKernelContext.
+ explicit OpKernelContext(Params* params);
~OpKernelContext();
- Env* env() const { return params_.device->env(); }
+ Env* env() const { return params_->device->env(); }
- const OpKernel& op_kernel() const { return *params_.op_kernel; }
+ const OpKernel& op_kernel() const { return *params_->op_kernel; }
// Input/output signature.
- int num_inputs() const { return params_.inputs->size(); }
+ int num_inputs() const { return params_->inputs->size(); }
DataType input_dtype(int index) const;
int num_outputs() const { return outputs_.size(); }
DataType expected_output_dtype(int index) const;
@@ -756,7 +788,7 @@ class OpKernelContext {
template <typename T>
T* op_device_context();
DeviceContext* op_device_context() {
- DeviceContext* ret = params_.op_device_context;
+ DeviceContext* ret = params_->op_device_context;
if (ret == nullptr) {
auto* dev_info = device()->tensorflow_gpu_device_info();
if (dev_info) ret = dev_info->default_context;
@@ -766,12 +798,12 @@ class OpKernelContext {
AllocatorAttributes input_alloc_attr(int index) const {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_.input_alloc_attrs->size());
- return (*params_.input_alloc_attrs)[index];
+ DCHECK_LT(index, params_->input_alloc_attrs->size());
+ return (*params_->input_alloc_attrs)[index];
}
AllocatorAttributes output_alloc_attr(int index) const {
- return params_.output_attr_array[index];
+ return params_->output_attr_array[index];
}
gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators() const {
@@ -784,25 +816,25 @@ class OpKernelContext {
//
// An op kernel communicates with outside environment through
// Rendezvous Send() and Recv().
- Rendezvous* rendezvous() const { return params_.rendezvous; }
+ Rendezvous* rendezvous() const { return params_->rendezvous; }
// Function call support.
//
// If this kernel invocation is within a function execution,
// call_frame() returns the call frame for the function call.
- FunctionCallFrame* call_frame() const { return params_.call_frame; }
+ FunctionCallFrame* call_frame() const { return params_->call_frame; }
// If not nullptr, the kernel invoke functions defined in the
// library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
FunctionLibraryRuntime* function_library() const {
- return params_.function_library;
+ return params_->function_library;
}
// Shared resources accessible to this kernel.
- ResourceMgr* resource_manager() const { return params_.resource_manager; }
+ ResourceMgr* resource_manager() const { return params_->resource_manager; }
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
- return params_.slice_reader_cache;
+ return params_->slice_reader_cache;
}
// Execution.
@@ -813,7 +845,7 @@ class OpKernelContext {
return *device()->eigen_cpu_device();
}
const Eigen::GpuDevice& eigen_gpu_device() const {
- return eigen_gpu_device_->device();
+ return params_->eigen_gpu_device->device();
}
template <typename EigenDeviceType>
const EigenDeviceType& eigen_device() const;
@@ -837,19 +869,19 @@ class OpKernelContext {
// EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an
// example of how to use this API.
CancellationManager* cancellation_manager() const {
- return params_.cancellation_manager;
+ return params_->cancellation_manager;
}
// Other accessors.
// For control flow.
- FrameAndIter frame_iter() const { return params_.frame_iter; }
- bool is_input_dead() const { return params_.is_input_dead; }
+ FrameAndIter frame_iter() const { return params_->frame_iter; }
+ bool is_input_dead() const { return params_->is_input_dead; }
bool* is_output_dead() { return &is_output_dead_; }
// May be used, e.g., to get GPU handles, etc.
// TODO(tucker): Add example usage.
- DeviceBase* device() const { return params_.device; }
+ DeviceBase* device() const { return params_->device; }
// Retrieve list of referenced tensors in out_vector. Once this is
// called, it is not legal to reference any more tensors. Should
@@ -858,14 +890,14 @@ class OpKernelContext {
// Per-step resource manager for use by white-listed internal ops.
ResourceMgr* step_resource_manager() const {
- return params_.step_resource_manager;
+ return params_->step_resource_manager;
}
private:
Allocator* get_allocator(AllocatorAttributes attr) {
Allocator* allocator =
- params_.device->GetStepAllocator(attr, step_resource_manager());
- if (params_.track_allocations) {
+ params_->device->GetStepAllocator(attr, step_resource_manager());
+ if (params_->track_allocations) {
mutex_lock lock(mu_);
for (const auto& wrapped : wrapped_allocators_) {
if (wrapped.first == allocator) {
@@ -908,9 +940,7 @@ class OpKernelContext {
void NotifyUseOfPersistentTensor(const Tensor& tensor);
Status status_;
- Params params_; // immutable after construction.
- const PerOpGpuDevice* eigen_gpu_device_; // owned, with a per-op
- // wrapped allocator
+ Params* params_; // not owned
mutable mutex mu_; // mutable so const accessors can acquire the lock
gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
gtl::InlinedVector<TensorValue, 4> outputs_;
@@ -1035,8 +1065,8 @@ class OpKernelRegistrar {
inline DataType OpKernelContext::input_dtype(int index) const {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_.inputs->size());
- const TensorValue& value((*params_.inputs)[index]);
+ DCHECK_LT(index, params_->inputs->size());
+ const TensorValue& value((*params_->inputs)[index]);
if (value.is_ref()) {
return MakeRefType(value->dtype());
} else {
@@ -1046,12 +1076,12 @@ inline DataType OpKernelContext::input_dtype(int index) const {
inline DataType OpKernelContext::expected_output_dtype(int index) const {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_.op_kernel->output_types().size());
- return params_.op_kernel->output_type(index);
+ DCHECK_LT(index, params_->op_kernel->output_types().size());
+ return params_->op_kernel->output_type(index);
}
inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
- if (params_.device->RequiresRecordingAccessedTensors()) {
+ if (params_->device->RequiresRecordingAccessedTensors()) {
mutex_lock l(mu_);
// Keep a reference to the underlying memory around.
referenced_tensors_.Add(tensor);
@@ -1066,25 +1096,25 @@ inline void OpKernelContext::retrieve_accessed_tensors(
inline const Tensor& OpKernelContext::input(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_.inputs->size());
- DCHECK(!(*params_.inputs)[index].is_ref());
- const Tensor& tensor = *((*params_.inputs)[index].tensor);
+ DCHECK_LT(index, params_->inputs->size());
+ DCHECK(!(*params_->inputs)[index].is_ref());
+ const Tensor& tensor = *((*params_->inputs)[index].tensor);
record_tensor_reference(tensor);
return tensor;
}
inline Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_.inputs->size());
- DCHECK((*params_.inputs)[index].is_ref());
+ DCHECK_LT(index, params_->inputs->size());
+ DCHECK((*params_->inputs)[index].is_ref());
// return a copy of the Ref acquired while holding the mutex
if (lock_held) {
- Tensor& tensor = *((*params_.inputs)[index].tensor);
+ Tensor& tensor = *((*params_->inputs)[index].tensor);
record_tensor_reference(tensor);
return tensor;
} else {
mutex_lock l(*input_ref_mutex(index));
- Tensor& tensor = *((*params_.inputs)[index].tensor);
+ Tensor& tensor = *((*params_->inputs)[index].tensor);
record_tensor_reference(tensor);
return tensor;
}
@@ -1093,14 +1123,14 @@ inline Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
inline void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
bool lock_held) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_.inputs->size());
- DCHECK((*params_.inputs)[index].is_ref());
+ DCHECK_LT(index, params_->inputs->size());
+ DCHECK((*params_->inputs)[index].is_ref());
// should only modify the tensor while holding the mutex
if (lock_held) {
- *(*params_.inputs)[index].tensor = tensor;
+ *(*params_->inputs)[index].tensor = tensor;
} else {
mutex_lock l(*input_ref_mutex(index));
- *(*params_.inputs)[index].tensor = tensor;
+ *(*params_->inputs)[index].tensor = tensor;
}
record_tensor_reference(tensor);
}
@@ -1108,37 +1138,37 @@ inline void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
inline void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
int output_index) {
DCHECK_GE(input_index, 0);
- DCHECK_LT(input_index, params_.inputs->size());
- DCHECK((*params_.inputs)[input_index].is_ref());
- set_output_ref(output_index, (*params_.inputs)[input_index].mutex_if_ref,
- (*params_.inputs)[input_index].tensor);
+ DCHECK_LT(input_index, params_->inputs->size());
+ DCHECK((*params_->inputs)[input_index].is_ref());
+ set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
+ (*params_->inputs)[input_index].tensor);
}
inline void OpKernelContext::delete_ref_input(int index, bool lock_held) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_.inputs->size());
- DCHECK((*params_.inputs)[index].is_ref());
+ DCHECK_LT(index, params_->inputs->size());
+ DCHECK((*params_->inputs)[index].is_ref());
// should only modify the tensor while holding the mutex
if (lock_held) {
- delete (*params_.inputs)[index].tensor;
+ delete (*params_->inputs)[index].tensor;
} else {
mutex_lock l(*input_ref_mutex(index));
- delete (*params_.inputs)[index].tensor;
+ delete (*params_->inputs)[index].tensor;
}
}
// no input if tensor == nullptr.
inline bool OpKernelContext::has_input(int index) const {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_.inputs->size());
- return (*params_.inputs)[index].tensor != nullptr;
+ DCHECK_LT(index, params_->inputs->size());
+ return (*params_->inputs)[index].tensor != nullptr;
}
inline mutex* OpKernelContext::input_ref_mutex(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_.inputs->size());
- DCHECK((*params_.inputs)[index].is_ref());
- return (*params_.inputs)[index].mutex_if_ref;
+ DCHECK_LT(index, params_->inputs->size());
+ DCHECK((*params_->inputs)[index].is_ref());
+ return (*params_->inputs)[index].mutex_if_ref;
}
inline Status OpKernelContext::allocate_output(int index,
@@ -1171,7 +1201,7 @@ inline Status OpKernelContext::allocate_output(int index,
AllocatorAttributes attr) {
DCHECK_GE(index, 0);
DCHECK_LT(index, outputs_.size());
- const DataType type = params_.op_kernel->output_type(index);
+ const DataType type = params_->op_kernel->output_type(index);
DCHECK(!IsRefType(type));
DCHECK(mutable_output(index) == nullptr);
Tensor* output_tensor = new Tensor();
@@ -1216,7 +1246,7 @@ inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
inline void OpKernelContext::set_output(int index, const Tensor& tensor) {
DCHECK_GE(index, 0);
DCHECK_LT(index, outputs_.size());
- DCHECK(!IsRefType(params_.op_kernel->output_type(index)));
+ DCHECK(!IsRefType(params_->op_kernel->output_type(index)));
DCHECK_EQ(mutable_output(index), nullptr);
record_tensor_reference(tensor);
outputs_[index] = TensorValue(new Tensor(tensor));
@@ -1226,7 +1256,7 @@ inline void OpKernelContext::set_output_ref(int index, mutex* mu,
Tensor* tensor_for_ref) {
DCHECK_GE(index, 0);
DCHECK_LT(index, outputs_.size());
- DCHECK(IsRefType(params_.op_kernel->output_type(index)));
+ DCHECK(IsRefType(params_->op_kernel->output_type(index)));
record_tensor_reference(*tensor_for_ref);
outputs_[index] = TensorValue(mu, tensor_for_ref);
}
@@ -1257,16 +1287,16 @@ T* OpKernelContext::op_device_context() {
template <typename T>
T* OpKernelContext::input_device_context(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_.input_device_contexts->size());
+ DCHECK_LT(index, params_->input_device_contexts->size());
static_assert(std::is_base_of<DeviceContext, T>::value,
"T is not a subclass of DeviceContext");
- return static_cast<T*>((*params_.input_device_contexts)[index]);
+ return static_cast<T*>((*params_->input_device_contexts)[index]);
}
inline DeviceContext* OpKernelContext::input_device_context(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_.input_device_contexts->size());
- return (*params_.input_device_contexts)[index];
+ DCHECK_LT(index, params_->input_device_contexts->size());
+ return (*params_->input_device_contexts)[index];
}
inline const Tensor& OpInputList::operator[](int i) const {
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
index 6e9eebd376..cf64a77902 100644
--- a/tensorflow/core/framework/op_kernel_test.cc
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -292,7 +292,7 @@ TEST_F(OpKernelTest, SaveTempFalse) {
TF_GRAPH_DEF_VERSION, &status));
EXPECT_TRUE(status.ok());
params.op_kernel = op.get();
- OpKernelContext* ctx = new OpKernelContext(params);
+ OpKernelContext* ctx = new OpKernelContext(&params);
Tensor t;
EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));
@@ -316,7 +316,7 @@ TEST_F(OpKernelTest, SaveTempTrue) {
TF_GRAPH_DEF_VERSION, &status));
EXPECT_TRUE(status.ok());
params.op_kernel = op.get();
- OpKernelContext* ctx = new OpKernelContext(params);
+ OpKernelContext* ctx = new OpKernelContext(&params);
Tensor t;
EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));
diff --git a/tensorflow/core/kernels/core_ops_test.cc b/tensorflow/core/kernels/core_ops_test.cc
index 007df30014..52b485f6be 100644
--- a/tensorflow/core/kernels/core_ops_test.cc
+++ b/tensorflow/core/kernels/core_ops_test.cc
@@ -447,7 +447,7 @@ static void BM_LRNFloat(int iters, int depth, int cols, int rows,
std::vector<AllocatorAttributes> attrs;
test::SetOutputAttrs(&params, &attrs);
- std::unique_ptr<OpKernelContext> context(new OpKernelContext(params));
+ std::unique_ptr<OpKernelContext> context(new OpKernelContext(&params));
op->Compute(context.get());
tensorflow::testing::StartTiming();
@@ -527,7 +527,8 @@ static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
std::vector<AllocatorAttributes> attrs;
test::SetOutputAttrs(&params, &attrs);
- std::unique_ptr<OpKernelContext> avgpool_context(new OpKernelContext(params));
+ std::unique_ptr<OpKernelContext> avgpool_context(
+ new OpKernelContext(&params));
op->Compute(avgpool_context.get());
tensorflow::testing::StartTiming();
@@ -631,7 +632,8 @@ static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
std::vector<AllocatorAttributes> attrs;
test::SetOutputAttrs(&params, &attrs);
- std::unique_ptr<OpKernelContext> avgpool_context(new OpKernelContext(params));
+ std::unique_ptr<OpKernelContext> avgpool_context(
+ new OpKernelContext(&params));
op->Compute(avgpool_context.get());
tensorflow::testing::StartTiming();
@@ -717,7 +719,8 @@ static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
std::vector<AllocatorAttributes> attrs;
test::SetOutputAttrs(&params, &attrs);
- std::unique_ptr<OpKernelContext> maxpool_context(new OpKernelContext(params));
+ std::unique_ptr<OpKernelContext> maxpool_context(
+ new OpKernelContext(&params));
op->Compute(maxpool_context.get());
tensorflow::testing::StartTiming();
@@ -891,7 +894,7 @@ static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
std::vector<AllocatorAttributes> attrs;
test::SetOutputAttrs(&params, &attrs);
- std::unique_ptr<OpKernelContext> relu_context(new OpKernelContext(params));
+ std::unique_ptr<OpKernelContext> relu_context(new OpKernelContext(&params));
op->Compute(relu_context.get());
tensorflow::testing::StartTiming();
@@ -959,7 +962,8 @@ static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth,
std::vector<AllocatorAttributes> attrs;
test::SetOutputAttrs(&params, &attrs);
- std::unique_ptr<OpKernelContext> softmax_context(new OpKernelContext(params));
+ std::unique_ptr<OpKernelContext> softmax_context(
+ new OpKernelContext(&params));
op->Compute(softmax_context.get());
tensorflow::testing::StartTiming();
diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h
index a0d51667b2..a90cd7218c 100644
--- a/tensorflow/core/kernels/ops_testutil.h
+++ b/tensorflow/core/kernels/ops_testutil.h
@@ -82,6 +82,7 @@ class OpsTestBase : public ::testing::Test {
~OpsTestBase() override {
gtl::STLDeleteElements(&tensors_);
context_.reset(nullptr);
+ params_.reset(nullptr);
}
void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }
@@ -150,17 +151,21 @@ class OpsTestBase : public ::testing::Test {
//
// Returns the context's status after running the operation.
Status RunOpKernel() {
- OpKernelContext::Params params;
- params.device = device_.get();
- params.frame_iter = FrameAndIter(0, 0);
- params.inputs = &inputs_;
- params.op_kernel = kernel_.get();
+ // Make sure the old OpKernelContext is deleted before the Params
+ // it was using.
+ context_.reset(nullptr);
+
+ params_.reset(new OpKernelContext::Params);
+ params_.get()->device = device_.get();
+ params_.get()->frame_iter = FrameAndIter(0, 0);
+ params_.get()->inputs = &inputs_;
+ params_.get()->op_kernel = kernel_.get();
std::vector<AllocatorAttributes> attrs;
- test::SetOutputAttrs(&params, &attrs);
+ test::SetOutputAttrs(params_.get(), &attrs);
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
- params.slice_reader_cache = &slice_reader_cache_wrapper;
+ params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
- context_.reset(new OpKernelContext(params));
+ context_.reset(new OpKernelContext(params_.get()));
device_->Compute(kernel_.get(), context_.get());
return context_->status();
}
@@ -206,6 +211,7 @@ class OpsTestBase : public ::testing::Test {
// Owns Tensors.
std::vector<Tensor*> tensors_;
+ std::unique_ptr<OpKernelContext::Params> params_;
std::unique_ptr<OpKernelContext> context_;
private:
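The explicit reset order in both the destructor and RunOpKernel exists because the context borrows the Params it was built from. A simplified fragment of the hazard being avoided (mirroring the fixture members above):

    std::unique_ptr<OpKernelContext::Params> params_;
    std::unique_ptr<OpKernelContext> context_;  // holds a raw Params*

    void TearDown() {
      context_.reset(nullptr);  // drop the borrower first...
      params_.reset(nullptr);   // ...then the Params it pointed into
    }
    // Resetting params_ first would leave context_ holding a dangling
    // pointer for the rest of its lifetime.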
diff --git a/tensorflow/core/kernels/restore_op_test.cc b/tensorflow/core/kernels/restore_op_test.cc
index 1dbd452843..9c5e52c574 100644
--- a/tensorflow/core/kernels/restore_op_test.cc
+++ b/tensorflow/core/kernels/restore_op_test.cc
@@ -168,7 +168,7 @@ TEST_F(RestoreOpTest, RestoreSimple) {
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
params.slice_reader_cache = &slice_reader_cache_wrapper;
- OpKernelContext ctx(params);
+ OpKernelContext ctx(&params);
op->Compute(&ctx);
EXPECT_OK(ctx.status());
}
@@ -392,7 +392,7 @@ TEST_F(RestoreSliceOpTest, RestoreInt) {
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
params.slice_reader_cache = &slice_reader_cache_wrapper;
- OpKernelContext ctx(params);
+ OpKernelContext ctx(&params);
op->Compute(&ctx);
EXPECT_OK(ctx.status());
}
diff --git a/tensorflow/core/kernels/segment_reduction_ops_test.cc b/tensorflow/core/kernels/segment_reduction_ops_test.cc
index 88ec897801..23f14376c8 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_test.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_test.cc
@@ -77,7 +77,7 @@ static void BM_SegmentReduction(int iters, string reduction, Index num_rows,
test::SetOutputAttrs(&params, &attrs);
std::unique_ptr<OpKernelContext> reduction_context(
- new OpKernelContext(params));
+ new OpKernelContext(&params));
reduction_op->Compute(reduction_context.get());
TF_CHECK_OK(reduction_context->status());
diff --git a/tensorflow/core/kernels/sparse_to_dense_op_test.cc b/tensorflow/core/kernels/sparse_to_dense_op_test.cc
index 39a532e053..ed7f199804 100644
--- a/tensorflow/core/kernels/sparse_to_dense_op_test.cc
+++ b/tensorflow/core/kernels/sparse_to_dense_op_test.cc
@@ -258,7 +258,7 @@ static void BM_SparseToDense(int iters, const int bm_arg) {
std::vector<AllocatorAttributes> attrs;
test::SetOutputAttrs(&params, &attrs);
- std::unique_ptr<OpKernelContext> sparse_context(new OpKernelContext(params));
+ std::unique_ptr<OpKernelContext> sparse_context(new OpKernelContext(&params));
op->Compute(sparse_context.get());
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {