diff options
-rw-r--r-- | tensorflow/core/framework/op_kernel.cc | 159 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 58 |
2 files changed, 107 insertions, 110 deletions
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 2b0488d944..186a0c104c 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -199,8 +199,8 @@ OpKernelContext::OpKernelContext(Params* params) : OpKernelContext( params, static_cast<int>(params->op_kernel->output_types().size())) {} -OpKernelContext::OpKernelContext(Params* params, int noutputs) - : params_(params), outputs_(noutputs) { +OpKernelContext::OpKernelContext(Params* params, int num_outputs) + : params_(params), outputs_(num_outputs) { Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); params_->ensure_eigen_gpu_device(); params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device, @@ -258,9 +258,9 @@ Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { "' when single-valued input was " "expected"); } - if ((*params_->inputs)[start].is_ref()) { + if (input_is_ref(start)) { return errors::InvalidArgument("OpKernel used ref input name '", name, - "' when immutable input was expected"); + "' when non-ref input was expected"); } *tensor = (*params_->inputs)[start].tensor; record_tensor_reference(**tensor); @@ -299,8 +299,8 @@ Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) { const Tensor& OpKernelContext::input(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_->inputs->size()); - DCHECK(!(*params_->inputs)[index].is_ref()); + DCHECK_LT(index, num_inputs()); + DCHECK(!input_is_ref(index)); const Tensor& tensor = *((*params_->inputs)[index].tensor); record_tensor_reference(tensor); return tensor; @@ -308,8 +308,8 @@ const Tensor& OpKernelContext::input(int index) { Tensor OpKernelContext::mutable_input(int index, bool lock_held) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_->inputs->size()); - DCHECK((*params_->inputs)[index].is_ref()); + DCHECK_LT(index, num_inputs()); + DCHECK(input_is_ref(index)); // return a copy of the Ref acquired while holding the mutex if (lock_held) { Tensor& tensor = *((*params_->inputs)[index].tensor); @@ -326,8 +326,8 @@ Tensor OpKernelContext::mutable_input(int index, bool lock_held) { void OpKernelContext::replace_ref_input(int index, const Tensor& tensor, bool lock_held) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_->inputs->size()); - DCHECK((*params_->inputs)[index].is_ref()); + DCHECK_LT(index, num_inputs()); + DCHECK(input_is_ref(index)); // should only modify the tensor while holding the mutex if (lock_held) { *(*params_->inputs)[index].tensor = tensor; @@ -341,27 +341,34 @@ void OpKernelContext::replace_ref_input(int index, const Tensor& tensor, void OpKernelContext::forward_ref_input_to_ref_output(int input_index, int output_index) { DCHECK_GE(input_index, 0); - DCHECK_LT(input_index, params_->inputs->size()); - DCHECK((*params_->inputs)[input_index].is_ref()); + DCHECK_LT(input_index, num_inputs()); + DCHECK(input_is_ref(input_index)); set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref, (*params_->inputs)[input_index].tensor); } -bool OpKernelContext::forward_input_to_output_with_same_shape(int input_index, - int output_index, - Tensor** output) { - DCHECK_GE(input_index, 0); - DCHECK_LT(input_index, params_->inputs->size()); - const TensorValue& input = (*params_->inputs)[input_index]; - if (input.tensor == nullptr) { +bool OpKernelContext::forward_input_to_output_with_shape( + int input_index, int output_index, const TensorShape& output_shape, + Tensor** output) { + const auto output_attr = params_->output_attr_array == nullptr + ? AllocatorAttributes() + : output_alloc_attr(output_index); + std::unique_ptr<Tensor> new_tensor = forward_input( + input_index, expected_output_dtype(output_index), output_shape, + output_memory_type(output_index), output_attr); + if (new_tensor != nullptr) { + // Transfer ownership to the output slot in OpKernelContext. + outputs_[output_index] = TensorValue(new_tensor.release()); + *output = outputs_[output_index].tensor; + return true; + } else { return false; } - return forward_input_to_output_with_shape(input_index, output_index, - input.tensor->shape(), output); } -Status OpKernelContext::forward_input_to_output_with_same_shape( - StringPiece input_name, StringPiece output_name, Tensor** output) { +Status OpKernelContext::forward_input_to_output_with_shape( + StringPiece input_name, StringPiece output_name, + const TensorShape& output_shape, Tensor** output) { int input_index, output_index, stop; TF_RETURN_IF_ERROR( params_->op_kernel->InputRange(input_name, &input_index, &stop)); @@ -379,102 +386,68 @@ Status OpKernelContext::forward_input_to_output_with_same_shape( "' when single-valued output was " "expected"); } - if (!forward_input_to_output_with_same_shape(input_index, output_index, - output)) { + if (!forward_input_to_output_with_shape(input_index, output_index, + output_shape, output)) { return errors::FailedPrecondition("OpKernel could not forward input '", input_name, "' to output '", output_name); } return Status::OK(); } -bool OpKernelContext::forward_input_to_output_with_shape( - int input_index, int output_index, const TensorShape& output_shape, - Tensor** output) { +std::unique_ptr<Tensor> OpKernelContext::forward_input( + int input_index, DataType output_dtype, const TensorShape& output_shape, + MemoryType output_memory_type, const AllocatorAttributes& output_attr) { + // TODO(rmlarsen,zhengxq): Re-enable for GPU memory once kernels have been + // made forwarding aware or decorated to expose which inputs they rely on + // to access via the read-only texture cache. + // TODO(rmlarsen): Short term, move disabling logic into the kernels + // themselves for fine-grained control. + DCHECK(params_->device != nullptr); + if (output_memory_type == DEVICE_MEMORY && + params_->device->attributes().device_type() == DEVICE_GPU) { + return nullptr; + } + DCHECK_GE(input_index, 0); - DCHECK_LT(input_index, params_->inputs->size()); + DCHECK_LT(input_index, num_inputs()); const TensorValue& input = (*params_->inputs)[input_index]; - // Check that input tensor exists, is not a ref, and have no other consumers. + // Check that input tensor exists, is not a ref, and has no other consumers. if (input.tensor == nullptr || input.is_ref() || !input->RefCountIsOne()) { - return false; + return nullptr; } - DCHECK_GE(output_index, 0); - DCHECK_LT(output_index, num_outputs()); - // Check that input and output types match. - if (expected_output_dtype(output_index) != input_dtype(input_index)) { - return false; + // Check that input type matches. + if (input_dtype(input_index) != output_dtype) { + return nullptr; } // Check that the input and output sizes are compatible. if (input.tensor->shape().num_elements() != output_shape.num_elements()) { - return false; + return nullptr; } // Check that input and output memory types match, i.e. // that they either both live in host or both live in device memmory. - if (op_kernel().output_memory_types()[output_index] != - op_kernel().input_memory_types()[input_index]) { - return false; - } - - // TODO(rmlarsen,zhengxq): Re-enable for GPU memory once kernels have been - // made forwarding aware or decorated to expose which inputs they rely on - // to access via the read-only texture cache. - // TODO(rmlarsen): Short term, move disabling logic into the kernels - // themselves for fine-grained control. - DCHECK(params_->device != nullptr); - if (op_kernel().output_memory_types()[output_index] == DEVICE_MEMORY && - params_->device->attributes().device_type() == DEVICE_GPU) { - return false; + if (input_memory_type(input_index) != output_memory_type) { + return nullptr; } - // Check that output allocator attributes are not more restrictive than // input allocator attributes. const auto input_attr = params_->input_alloc_attrs == nullptr ? AllocatorAttributes() : input_alloc_attr(input_index); - const auto output_attr = params_->output_attr_array == nullptr - ? AllocatorAttributes() - : output_alloc_attr(output_index); if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) { - return false; + return nullptr; } - Tensor* output_tensor = new Tensor(); + // TODO(rmlarsen): Use MakeUnique here. There is already a copy in + // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of + // general cleanup of ownership in this code. + std::unique_ptr<Tensor> output_tensor(new Tensor()); CHECK(output_tensor->CopyFrom(*input.tensor, output_shape)); - outputs_[output_index] = TensorValue(output_tensor); - *output = outputs_[output_index].tensor; - return true; -} - -Status OpKernelContext::forward_input_to_output_with_shape( - StringPiece input_name, StringPiece output_name, - const TensorShape& output_shape, Tensor** output) { - int input_index, output_index, stop; - TF_RETURN_IF_ERROR( - params_->op_kernel->InputRange(input_name, &input_index, &stop)); - if (stop != input_index + 1) { - return errors::InvalidArgument("OpKernel used list-valued input name '", - input_name, - "' when single-valued input was " - "expected"); - } - TF_RETURN_IF_ERROR( - params_->op_kernel->OutputRange(output_name, &output_index, &stop)); - if (stop != output_index + 1) { - return errors::InvalidArgument("OpKernel used list-valued output name '", - output_name, - "' when single-valued output was " - "expected"); - } - if (!forward_input_to_output_with_shape(input_index, output_index, - output_shape, output)) { - return errors::FailedPrecondition("OpKernel could not forward input '", - input_name, "' to output '", output_name); - } - return Status::OK(); + return output_tensor; } void OpKernelContext::delete_ref_input(int index, bool lock_held) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_->inputs->size()); - DCHECK((*params_->inputs)[index].is_ref()); + DCHECK_LT(index, num_inputs()); + DCHECK(input_is_ref(index)); // should only modify the tensor while holding the mutex if (lock_held) { delete (*params_->inputs)[index].tensor; @@ -493,8 +466,8 @@ Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, name, "' when single-valued input was expected"); } - if (!(*params_->inputs)[start].is_ref()) { - return errors::InvalidArgument("OpKernel used immutable input name '", name, + if (!input_is_ref(start)) { + return errors::InvalidArgument("OpKernel used non-ref input name '", name, "' when ref input was expected"); } // return a copy of the Ref acquired while holding the mutex @@ -518,7 +491,7 @@ Status OpKernelContext::replace_ref_input(StringPiece name, name, "' when single-valued input was expected"); } - if (!(*params_->inputs)[start].is_ref()) { + if (!input_is_ref(start)) { return errors::InvalidArgument("OpKernel used immutable input name '", name, "' when ref input was expected"); } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index faafe86ade..b6e302c492 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -569,8 +569,11 @@ class OpKernelContext { int num_inputs() const { return params_->inputs->size(); } DataType input_dtype(int index) const; Status input_dtype(StringPiece name, DataType* dtype) const; + MemoryType input_memory_type(int index) const; + int num_outputs() const { return outputs_.size(); } DataType expected_output_dtype(int index) const; + MemoryType output_memory_type(int index) const; // Input @@ -669,16 +672,6 @@ class OpKernelContext { // REQUIRES: IsRefType(output_dtype(output_index)). void forward_ref_input_to_ref_output(int input_index, int output_index); - // Returns true when an alias to input[input_index] that is safe to use for - // in-place computation was written to *output. Returns false if - // input[input_index] has a refcount greater than or if its type does not - // match the expected output type of output[output_index]. - bool forward_input_to_output_with_same_shape( - int input_index, int output_index, Tensor** output) TF_MUST_USE_RESULT; - Status forward_input_to_output_with_same_shape( - StringPiece input_name, StringPiece output_name, - Tensor** output) TF_MUST_USE_RESULT; - // Returns true when an alias to input[input_index], reshaped to output_shape, // which is is safe to use for in-place computation was written to *output. // Returns false if input[input_index] has a refcount greater than one, or if @@ -693,6 +686,19 @@ class OpKernelContext { const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT; + // Returns a pointer to a Tensor aliasing the underlying buffer backing + // input[input_index] iff + // * input[input_index] is not a ref, + // * the data type, shape, memory type, and allocator attributes of + // input[input_index] are compatible with those given in dtype, shape, + // memory_type, and attr, + // * refcount on the underlying buffer is one. + // Otherwise returns nullptr. + std::unique_ptr<Tensor> forward_input( + int input_index, DataType dtype, const TensorShape& shape, + MemoryType memory_type, + const AllocatorAttributes& attr) TF_MUST_USE_RESULT; + // Tries to forward one of the inputs given in input_indices to // output[output_index]. If none of the given inputs can be forwarded, calls // allocate_output() to allocate a new output buffer. @@ -999,6 +1005,8 @@ class OpKernelContext { TensorValue release_output(int index); private: + bool input_is_ref(int index) const; + Allocator* get_allocator(AllocatorAttributes attr); // Internal method to add a tensor's buffer to the list of buffers @@ -1187,7 +1195,7 @@ Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const { inline DataType OpKernelContext::input_dtype(int index) const { DCHECK_GE(index, 0); - DCHECK_LT(index, params_->inputs->size()); + DCHECK_LT(index, num_inputs()); const TensorValue& value((*params_->inputs)[index]); if (value.is_ref()) { return MakeRefType(value->dtype()); @@ -1196,12 +1204,28 @@ inline DataType OpKernelContext::input_dtype(int index) const { } } +inline MemoryType OpKernelContext::input_memory_type(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + return op_kernel().input_memory_types()[index]; +} + inline DataType OpKernelContext::expected_output_dtype(int index) const { DCHECK_GE(index, 0); - DCHECK_LT(index, params_->op_kernel->output_types().size()); + DCHECK_LT(index, num_outputs()); return params_->op_kernel->output_type(index); } +inline MemoryType OpKernelContext::output_memory_type(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + return op_kernel().output_memory_types()[index]; +} + +inline bool OpKernelContext::input_is_ref(int index) const { + return IsRefType(input_dtype(index)); +} + inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) { DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(), params_->record_tensor_accesses); @@ -1221,14 +1245,14 @@ inline void OpKernelContext::retrieve_accessed_tensors( // no input if tensor == nullptr. inline bool OpKernelContext::has_input(int index) const { DCHECK_GE(index, 0); - DCHECK_LT(index, params_->inputs->size()); + DCHECK_LT(index, num_inputs()); return (*params_->inputs)[index].tensor != nullptr; } inline mutex* OpKernelContext::input_ref_mutex(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, params_->inputs->size()); - DCHECK((*params_->inputs)[index].is_ref()); + DCHECK_LT(index, num_inputs()); + DCHECK(input_is_ref(index)); return (*params_->inputs)[index].mutex_if_ref; } @@ -1240,7 +1264,7 @@ inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) { inline Tensor* OpKernelContext::mutable_output(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, outputs_.size()); + DCHECK_LT(index, num_outputs()); // No need to record_tensor_reference since the output must already // have been set by a call that did so. return outputs_[index].tensor; @@ -1248,7 +1272,7 @@ inline Tensor* OpKernelContext::mutable_output(int index) { inline TensorValue OpKernelContext::release_output(int index) { DCHECK_GE(index, 0); - DCHECK_LT(index, outputs_.size()); + DCHECK_LT(index, num_outputs()); TensorValue value = outputs_[index]; outputs_[index] = TensorValue(); return value; |