aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/framework/op_kernel.cc159
-rw-r--r--tensorflow/core/framework/op_kernel.h58
2 files changed, 107 insertions, 110 deletions
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 2b0488d944..186a0c104c 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -199,8 +199,8 @@ OpKernelContext::OpKernelContext(Params* params)
: OpKernelContext(
params, static_cast<int>(params->op_kernel->output_types().size())) {}
-OpKernelContext::OpKernelContext(Params* params, int noutputs)
- : params_(params), outputs_(noutputs) {
+OpKernelContext::OpKernelContext(Params* params, int num_outputs)
+ : params_(params), outputs_(num_outputs) {
Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
params_->ensure_eigen_gpu_device();
params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
@@ -258,9 +258,9 @@ Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
"' when single-valued input was "
"expected");
}
- if ((*params_->inputs)[start].is_ref()) {
+ if (input_is_ref(start)) {
return errors::InvalidArgument("OpKernel used ref input name '", name,
- "' when immutable input was expected");
+ "' when non-ref input was expected");
}
*tensor = (*params_->inputs)[start].tensor;
record_tensor_reference(**tensor);
@@ -299,8 +299,8 @@ Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
const Tensor& OpKernelContext::input(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
- DCHECK(!(*params_->inputs)[index].is_ref());
+ DCHECK_LT(index, num_inputs());
+ DCHECK(!input_is_ref(index));
const Tensor& tensor = *((*params_->inputs)[index].tensor);
record_tensor_reference(tensor);
return tensor;
@@ -308,8 +308,8 @@ const Tensor& OpKernelContext::input(int index) {
Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
- DCHECK((*params_->inputs)[index].is_ref());
+ DCHECK_LT(index, num_inputs());
+ DCHECK(input_is_ref(index));
// return a copy of the Ref acquired while holding the mutex
if (lock_held) {
Tensor& tensor = *((*params_->inputs)[index].tensor);
@@ -326,8 +326,8 @@ Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
bool lock_held) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
- DCHECK((*params_->inputs)[index].is_ref());
+ DCHECK_LT(index, num_inputs());
+ DCHECK(input_is_ref(index));
// should only modify the tensor while holding the mutex
if (lock_held) {
*(*params_->inputs)[index].tensor = tensor;
@@ -341,27 +341,34 @@ void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
int output_index) {
DCHECK_GE(input_index, 0);
- DCHECK_LT(input_index, params_->inputs->size());
- DCHECK((*params_->inputs)[input_index].is_ref());
+ DCHECK_LT(input_index, num_inputs());
+ DCHECK(input_is_ref(input_index));
set_output_ref(output_index, (*params_->inputs)[input_index].mutex_if_ref,
(*params_->inputs)[input_index].tensor);
}
-bool OpKernelContext::forward_input_to_output_with_same_shape(int input_index,
- int output_index,
- Tensor** output) {
- DCHECK_GE(input_index, 0);
- DCHECK_LT(input_index, params_->inputs->size());
- const TensorValue& input = (*params_->inputs)[input_index];
- if (input.tensor == nullptr) {
+bool OpKernelContext::forward_input_to_output_with_shape(
+ int input_index, int output_index, const TensorShape& output_shape,
+ Tensor** output) {
+ const auto output_attr = params_->output_attr_array == nullptr
+ ? AllocatorAttributes()
+ : output_alloc_attr(output_index);
+ std::unique_ptr<Tensor> new_tensor = forward_input(
+ input_index, expected_output_dtype(output_index), output_shape,
+ output_memory_type(output_index), output_attr);
+ if (new_tensor != nullptr) {
+ // Transfer ownership to the output slot in OpKernelContext.
+ outputs_[output_index] = TensorValue(new_tensor.release());
+ *output = outputs_[output_index].tensor;
+ return true;
+ } else {
return false;
}
- return forward_input_to_output_with_shape(input_index, output_index,
- input.tensor->shape(), output);
}
-Status OpKernelContext::forward_input_to_output_with_same_shape(
- StringPiece input_name, StringPiece output_name, Tensor** output) {
+Status OpKernelContext::forward_input_to_output_with_shape(
+ StringPiece input_name, StringPiece output_name,
+ const TensorShape& output_shape, Tensor** output) {
int input_index, output_index, stop;
TF_RETURN_IF_ERROR(
params_->op_kernel->InputRange(input_name, &input_index, &stop));
@@ -379,102 +386,68 @@ Status OpKernelContext::forward_input_to_output_with_same_shape(
"' when single-valued output was "
"expected");
}
- if (!forward_input_to_output_with_same_shape(input_index, output_index,
- output)) {
+ if (!forward_input_to_output_with_shape(input_index, output_index,
+ output_shape, output)) {
return errors::FailedPrecondition("OpKernel could not forward input '",
input_name, "' to output '", output_name);
}
return Status::OK();
}
-bool OpKernelContext::forward_input_to_output_with_shape(
- int input_index, int output_index, const TensorShape& output_shape,
- Tensor** output) {
+std::unique_ptr<Tensor> OpKernelContext::forward_input(
+ int input_index, DataType output_dtype, const TensorShape& output_shape,
+ MemoryType output_memory_type, const AllocatorAttributes& output_attr) {
+ // TODO(rmlarsen,zhengxq): Re-enable for GPU memory once kernels have been
+ // made forwarding aware or decorated to expose which inputs they rely on
+ // to access via the read-only texture cache.
+ // TODO(rmlarsen): Short term, move disabling logic into the kernels
+ // themselves for fine-grained control.
+ DCHECK(params_->device != nullptr);
+ if (output_memory_type == DEVICE_MEMORY &&
+ params_->device->attributes().device_type() == DEVICE_GPU) {
+ return nullptr;
+ }
+
DCHECK_GE(input_index, 0);
- DCHECK_LT(input_index, params_->inputs->size());
+ DCHECK_LT(input_index, num_inputs());
const TensorValue& input = (*params_->inputs)[input_index];
- // Check that input tensor exists, is not a ref, and have no other consumers.
+ // Check that input tensor exists, is not a ref, and has no other consumers.
if (input.tensor == nullptr || input.is_ref() || !input->RefCountIsOne()) {
- return false;
+ return nullptr;
}
- DCHECK_GE(output_index, 0);
- DCHECK_LT(output_index, num_outputs());
- // Check that input and output types match.
- if (expected_output_dtype(output_index) != input_dtype(input_index)) {
- return false;
+ // Check that input type matches.
+ if (input_dtype(input_index) != output_dtype) {
+ return nullptr;
}
// Check that the input and output sizes are compatible.
if (input.tensor->shape().num_elements() != output_shape.num_elements()) {
- return false;
+ return nullptr;
}
// Check that input and output memory types match, i.e.
// that they either both live in host or both live in device memmory.
- if (op_kernel().output_memory_types()[output_index] !=
- op_kernel().input_memory_types()[input_index]) {
- return false;
- }
-
- // TODO(rmlarsen,zhengxq): Re-enable for GPU memory once kernels have been
- // made forwarding aware or decorated to expose which inputs they rely on
- // to access via the read-only texture cache.
- // TODO(rmlarsen): Short term, move disabling logic into the kernels
- // themselves for fine-grained control.
- DCHECK(params_->device != nullptr);
- if (op_kernel().output_memory_types()[output_index] == DEVICE_MEMORY &&
- params_->device->attributes().device_type() == DEVICE_GPU) {
- return false;
+ if (input_memory_type(input_index) != output_memory_type) {
+ return nullptr;
}
-
// Check that output allocator attributes are not more restrictive than
// input allocator attributes.
const auto input_attr = params_->input_alloc_attrs == nullptr
? AllocatorAttributes()
: input_alloc_attr(input_index);
- const auto output_attr = params_->output_attr_array == nullptr
- ? AllocatorAttributes()
- : output_alloc_attr(output_index);
if (!output_attr.IsEqualOrLessRestrictiveThan(input_attr)) {
- return false;
+ return nullptr;
}
- Tensor* output_tensor = new Tensor();
+ // TODO(rmlarsen): Use MakeUnique here. There is already a copy in
+ // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of
+ // general cleanup of ownership in this code.
+ std::unique_ptr<Tensor> output_tensor(new Tensor());
CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
- outputs_[output_index] = TensorValue(output_tensor);
- *output = outputs_[output_index].tensor;
- return true;
-}
-
-Status OpKernelContext::forward_input_to_output_with_shape(
- StringPiece input_name, StringPiece output_name,
- const TensorShape& output_shape, Tensor** output) {
- int input_index, output_index, stop;
- TF_RETURN_IF_ERROR(
- params_->op_kernel->InputRange(input_name, &input_index, &stop));
- if (stop != input_index + 1) {
- return errors::InvalidArgument("OpKernel used list-valued input name '",
- input_name,
- "' when single-valued input was "
- "expected");
- }
- TF_RETURN_IF_ERROR(
- params_->op_kernel->OutputRange(output_name, &output_index, &stop));
- if (stop != output_index + 1) {
- return errors::InvalidArgument("OpKernel used list-valued output name '",
- output_name,
- "' when single-valued output was "
- "expected");
- }
- if (!forward_input_to_output_with_shape(input_index, output_index,
- output_shape, output)) {
- return errors::FailedPrecondition("OpKernel could not forward input '",
- input_name, "' to output '", output_name);
- }
- return Status::OK();
+ return output_tensor;
}
void OpKernelContext::delete_ref_input(int index, bool lock_held) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
- DCHECK((*params_->inputs)[index].is_ref());
+ DCHECK_LT(index, num_inputs());
+ DCHECK(input_is_ref(index));
// should only modify the tensor while holding the mutex
if (lock_held) {
delete (*params_->inputs)[index].tensor;
@@ -493,8 +466,8 @@ Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
name,
"' when single-valued input was expected");
}
- if (!(*params_->inputs)[start].is_ref()) {
- return errors::InvalidArgument("OpKernel used immutable input name '", name,
+ if (!input_is_ref(start)) {
+ return errors::InvalidArgument("OpKernel used non-ref input name '", name,
"' when ref input was expected");
}
// return a copy of the Ref acquired while holding the mutex
@@ -518,7 +491,7 @@ Status OpKernelContext::replace_ref_input(StringPiece name,
name,
"' when single-valued input was expected");
}
- if (!(*params_->inputs)[start].is_ref()) {
+ if (!input_is_ref(start)) {
return errors::InvalidArgument("OpKernel used immutable input name '", name,
"' when ref input was expected");
}
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index faafe86ade..b6e302c492 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -569,8 +569,11 @@ class OpKernelContext {
int num_inputs() const { return params_->inputs->size(); }
DataType input_dtype(int index) const;
Status input_dtype(StringPiece name, DataType* dtype) const;
+ MemoryType input_memory_type(int index) const;
+
int num_outputs() const { return outputs_.size(); }
DataType expected_output_dtype(int index) const;
+ MemoryType output_memory_type(int index) const;
// Input
@@ -669,16 +672,6 @@ class OpKernelContext {
// REQUIRES: IsRefType(output_dtype(output_index)).
void forward_ref_input_to_ref_output(int input_index, int output_index);
- // Returns true when an alias to input[input_index] that is safe to use for
- // in-place computation was written to *output. Returns false if
- // input[input_index] has a refcount greater than or if its type does not
- // match the expected output type of output[output_index].
- bool forward_input_to_output_with_same_shape(
- int input_index, int output_index, Tensor** output) TF_MUST_USE_RESULT;
- Status forward_input_to_output_with_same_shape(
- StringPiece input_name, StringPiece output_name,
- Tensor** output) TF_MUST_USE_RESULT;
-
// Returns true when an alias to input[input_index], reshaped to output_shape,
// which is is safe to use for in-place computation was written to *output.
// Returns false if input[input_index] has a refcount greater than one, or if
@@ -693,6 +686,19 @@ class OpKernelContext {
const TensorShape& output_shape,
Tensor** output) TF_MUST_USE_RESULT;
+ // Returns a pointer to a Tensor aliasing the underlying buffer backing
+ // input[input_index] iff
+ // * input[input_index] is not a ref,
+ // * the data type, shape, memory type, and allocator attributes of
+ // input[input_index] are compatible with those given in dtype, shape,
+ // memory_type, and attr,
+ // * refcount on the underlying buffer is one.
+ // Otherwise returns nullptr.
+ std::unique_ptr<Tensor> forward_input(
+ int input_index, DataType dtype, const TensorShape& shape,
+ MemoryType memory_type,
+ const AllocatorAttributes& attr) TF_MUST_USE_RESULT;
+
// Tries to forward one of the inputs given in input_indices to
// output[output_index]. If none of the given inputs can be forwarded, calls
// allocate_output() to allocate a new output buffer.
@@ -999,6 +1005,8 @@ class OpKernelContext {
TensorValue release_output(int index);
private:
+ bool input_is_ref(int index) const;
+
Allocator* get_allocator(AllocatorAttributes attr);
// Internal method to add a tensor's buffer to the list of buffers
@@ -1187,7 +1195,7 @@ Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
inline DataType OpKernelContext::input_dtype(int index) const {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
+ DCHECK_LT(index, num_inputs());
const TensorValue& value((*params_->inputs)[index]);
if (value.is_ref()) {
return MakeRefType(value->dtype());
@@ -1196,12 +1204,28 @@ inline DataType OpKernelContext::input_dtype(int index) const {
}
}
+inline MemoryType OpKernelContext::input_memory_type(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, num_inputs());
+ return op_kernel().input_memory_types()[index];
+}
+
inline DataType OpKernelContext::expected_output_dtype(int index) const {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->op_kernel->output_types().size());
+ DCHECK_LT(index, num_outputs());
return params_->op_kernel->output_type(index);
}
+inline MemoryType OpKernelContext::output_memory_type(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, num_outputs());
+ return op_kernel().output_memory_types()[index];
+}
+
+inline bool OpKernelContext::input_is_ref(int index) const {
+ return IsRefType(input_dtype(index));
+}
+
inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
params_->record_tensor_accesses);
@@ -1221,14 +1245,14 @@ inline void OpKernelContext::retrieve_accessed_tensors(
// no input if tensor == nullptr.
inline bool OpKernelContext::has_input(int index) const {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
+ DCHECK_LT(index, num_inputs());
return (*params_->inputs)[index].tensor != nullptr;
}
inline mutex* OpKernelContext::input_ref_mutex(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
- DCHECK((*params_->inputs)[index].is_ref());
+ DCHECK_LT(index, num_inputs());
+ DCHECK(input_is_ref(index));
return (*params_->inputs)[index].mutex_if_ref;
}
@@ -1240,7 +1264,7 @@ inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
inline Tensor* OpKernelContext::mutable_output(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, outputs_.size());
+ DCHECK_LT(index, num_outputs());
// No need to record_tensor_reference since the output must already
// have been set by a call that did so.
return outputs_[index].tensor;
@@ -1248,7 +1272,7 @@ inline Tensor* OpKernelContext::mutable_output(int index) {
inline TensorValue OpKernelContext::release_output(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, outputs_.size());
+ DCHECK_LT(index, num_outputs());
TensorValue value = outputs_[index];
outputs_[index] = TensorValue();
return value;