diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-02-27 13:01:01 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-27 13:54:00 -0800 |
commit | 203a4d98d696c44214854df68b43f7bd7c89ca5f (patch) | |
tree | daa1b77c271ee64c204c25ee0aa4efa0a13c0cbf /tensorflow/core/framework/op_kernel.h | |
parent | 1c707ac780313f48a6733dc3beedf4b8a2b3df77 (diff) |
Refactor the buffer forwarding code:
* Moves the core forwarding logic to a new function forward_input that is not restricted to forwarding to an output slot, but also allows, e.g., reusing input buffers as temporaries or variables.
* Gets rid of forward_input_to_output_with_same_shape(), which is now unused.
* Adds convenience methods input_memory_type(), output_memory_type(), and (private) input_is_ref() to OpKernelContext.
* Misc. small cleanups.
Change: 148683484
Diffstat (limited to 'tensorflow/core/framework/op_kernel.h')
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 58 |
1 files changed, 41 insertions, 17 deletions
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index faafe86ade..b6e302c492 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -569,8 +569,11 @@ class OpKernelContext { int num_inputs() const { return params_->inputs->size(); } DataType input_dtype(int index) const; Status input_dtype(StringPiece name, DataType* dtype) const; + MemoryType input_memory_type(int index) const; + int num_outputs() const { return outputs_.size(); } DataType expected_output_dtype(int index) const; + MemoryType output_memory_type(int index) const; // Input @@ -669,16 +672,6 @@ class OpKernelContext { // REQUIRES: IsRefType(output_dtype(output_index)). void forward_ref_input_to_ref_output(int input_index, int output_index); - // Returns true when an alias to input[input_index] that is safe to use for - // in-place computation was written to *output. Returns false if - // input[input_index] has a refcount greater than or if its type does not - // match the expected output type of output[output_index]. - bool forward_input_to_output_with_same_shape( - int input_index, int output_index, Tensor** output) TF_MUST_USE_RESULT; - Status forward_input_to_output_with_same_shape( - StringPiece input_name, StringPiece output_name, - Tensor** output) TF_MUST_USE_RESULT; - // Returns true when an alias to input[input_index], reshaped to output_shape, // which is is safe to use for in-place computation was written to *output. 
// Returns false if input[input_index] has a refcount greater than one, or if @@ -693,6 +686,19 @@ class OpKernelContext { const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT; + // Returns a pointer to a Tensor aliasing the underlying buffer backing + // input[input_index] iff + // * input[input_index] is not a ref, + // * the data type, shape, memory type, and allocator attributes of + // input[input_index] are compatible with those given in dtype, shape, + // memory_type, and attr, + // * refcount on the underlying buffer is one. + // Otherwise returns nullptr. + std::unique_ptr<Tensor> forward_input( + int input_index, DataType dtype, const TensorShape& shape, + MemoryType memory_type, + const AllocatorAttributes& attr) TF_MUST_USE_RESULT; + // Tries to forward one of the inputs given in input_indices to // output[output_index]. If none of the given inputs can be forwarded, calls // allocate_output() to allocate a new output buffer. @@ -999,6 +1005,8 @@ class OpKernelContext { TensorValue release_output(int index); private: + bool input_is_ref(int index) const; + Allocator* get_allocator(AllocatorAttributes attr); // Internal method to add a tensor's buffer to the list of buffers @@ -1187,7 +1195,7 @@ Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const { inline DataType OpKernelContext::input_dtype(int index) const { DCHECK_GE(index, 0); - DCHECK_LT(index, params_->inputs->size()); + DCHECK_LT(index, num_inputs()); const TensorValue& value((*params_->inputs)[index]); if (value.is_ref()) { return MakeRefType(value->dtype()); @@ -1196,12 +1204,28 @@ inline DataType OpKernelContext::input_dtype(int index) const { } } +inline MemoryType OpKernelContext::input_memory_type(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + return op_kernel().input_memory_types()[index]; +} + inline DataType OpKernelContext::expected_output_dtype(int index) const { DCHECK_GE(index, 0); - DCHECK_LT(index, 
params_->op_kernel->output_types().size()); + DCHECK_LT(index, num_outputs()); return params_->op_kernel->output_type(index); } +inline MemoryType OpKernelContext::output_memory_type(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + return op_kernel().output_memory_types()[index]; +} + +inline bool OpKernelContext::input_is_ref(int index) const { + return IsRefType(input_dtype(index)); +} + inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) { DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(), params_->record_tensor_accesses); @@ -1221,14 +1245,14 @@ inline void OpKernelContext::retrieve_accessed_tensors( // no input if tensor == nullptr. inline bool OpKernelContext::has_input(int index) const { DCHECK_GE(index, 0); - DCHECK_LT(index, params_->inputs->size()); + DCHECK_LT(index, num_inputs()); return (*params_->inputs)[index].tensor != nullptr; } inline mutex* OpKernelContext::input_ref_mutex(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_->inputs->size()); - DCHECK((*params_->inputs)[index].is_ref()); + DCHECK_LT(index, num_inputs()); + DCHECK(input_is_ref(index)); return (*params_->inputs)[index].mutex_if_ref; } @@ -1240,7 +1264,7 @@ inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) { inline Tensor* OpKernelContext::mutable_output(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, outputs_.size()); + DCHECK_LT(index, num_outputs()); // No need to record_tensor_reference since the output must already // have been set by a call that did so. return outputs_[index].tensor; @@ -1248,7 +1272,7 @@ inline Tensor* OpKernelContext::mutable_output(int index) { inline TensorValue OpKernelContext::release_output(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, outputs_.size()); + DCHECK_LT(index, num_outputs()); TensorValue value = outputs_[index]; outputs_[index] = TensorValue(); return value; |