aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/op_kernel.h
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-02-27 13:01:01 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-02-27 13:54:00 -0800
commit: 203a4d98d696c44214854df68b43f7bd7c89ca5f (patch)
tree: daa1b77c271ee64c204c25ee0aa4efa0a13c0cbf /tensorflow/core/framework/op_kernel.h
parent: 1c707ac780313f48a6733dc3beedf4b8a2b3df77 (diff)
Refactor the buffer forwarding code:
* Moves the core forwarding logic to a new function forward_input() that is not restricted to forwarding to an output slot, but also allows, e.g., reusing input buffers as temporaries or variables.
* Gets rid of forward_input_to_output_with_same_shape(), which is now unused.
* Adds convenience methods input_memory_type(), output_memory_type(), and (private) input_is_ref() to OpKernelContext.
* Misc. small cleanups.
Change: 148683484
Diffstat (limited to 'tensorflow/core/framework/op_kernel.h')
-rw-r--r-- tensorflow/core/framework/op_kernel.h 58
1 files changed, 41 insertions, 17 deletions
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index faafe86ade..b6e302c492 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -569,8 +569,11 @@ class OpKernelContext {
int num_inputs() const { return params_->inputs->size(); }
DataType input_dtype(int index) const;
Status input_dtype(StringPiece name, DataType* dtype) const;
+ MemoryType input_memory_type(int index) const;
+
int num_outputs() const { return outputs_.size(); }
DataType expected_output_dtype(int index) const;
+ MemoryType output_memory_type(int index) const;
// Input
@@ -669,16 +672,6 @@ class OpKernelContext {
// REQUIRES: IsRefType(output_dtype(output_index)).
void forward_ref_input_to_ref_output(int input_index, int output_index);
- // Returns true when an alias to input[input_index] that is safe to use for
- // in-place computation was written to *output. Returns false if
- // input[input_index] has a refcount greater than or if its type does not
- // match the expected output type of output[output_index].
- bool forward_input_to_output_with_same_shape(
- int input_index, int output_index, Tensor** output) TF_MUST_USE_RESULT;
- Status forward_input_to_output_with_same_shape(
- StringPiece input_name, StringPiece output_name,
- Tensor** output) TF_MUST_USE_RESULT;
-
// Returns true when an alias to input[input_index], reshaped to output_shape,
// which is is safe to use for in-place computation was written to *output.
// Returns false if input[input_index] has a refcount greater than one, or if
@@ -693,6 +686,19 @@ class OpKernelContext {
const TensorShape& output_shape,
Tensor** output) TF_MUST_USE_RESULT;
+ // Returns a pointer to a Tensor aliasing the underlying buffer backing
+ // input[input_index] iff
+ // * input[input_index] is not a ref,
+ // * the data type, shape, memory type, and allocator attributes of
+ // input[input_index] are compatible with those given in dtype, shape,
+ // memory_type, and attr,
+ // * refcount on the underlying buffer is one.
+ // Otherwise returns nullptr.
+ std::unique_ptr<Tensor> forward_input(
+ int input_index, DataType dtype, const TensorShape& shape,
+ MemoryType memory_type,
+ const AllocatorAttributes& attr) TF_MUST_USE_RESULT;
+
// Tries to forward one of the inputs given in input_indices to
// output[output_index]. If none of the given inputs can be forwarded, calls
// allocate_output() to allocate a new output buffer.
@@ -999,6 +1005,8 @@ class OpKernelContext {
TensorValue release_output(int index);
private:
+ bool input_is_ref(int index) const;
+
Allocator* get_allocator(AllocatorAttributes attr);
// Internal method to add a tensor's buffer to the list of buffers
@@ -1187,7 +1195,7 @@ Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
inline DataType OpKernelContext::input_dtype(int index) const {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
+ DCHECK_LT(index, num_inputs());
const TensorValue& value((*params_->inputs)[index]);
if (value.is_ref()) {
return MakeRefType(value->dtype());
@@ -1196,12 +1204,28 @@ inline DataType OpKernelContext::input_dtype(int index) const {
}
}
+inline MemoryType OpKernelContext::input_memory_type(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, num_inputs());
+ return op_kernel().input_memory_types()[index];
+}
+
inline DataType OpKernelContext::expected_output_dtype(int index) const {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->op_kernel->output_types().size());
+ DCHECK_LT(index, num_outputs());
return params_->op_kernel->output_type(index);
}
+inline MemoryType OpKernelContext::output_memory_type(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, num_outputs());
+ return op_kernel().output_memory_types()[index];
+}
+
+inline bool OpKernelContext::input_is_ref(int index) const {
+ return IsRefType(input_dtype(index));
+}
+
inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
params_->record_tensor_accesses);
@@ -1221,14 +1245,14 @@ inline void OpKernelContext::retrieve_accessed_tensors(
// no input if tensor == nullptr.
inline bool OpKernelContext::has_input(int index) const {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
+ DCHECK_LT(index, num_inputs());
return (*params_->inputs)[index].tensor != nullptr;
}
inline mutex* OpKernelContext::input_ref_mutex(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, params_->inputs->size());
- DCHECK((*params_->inputs)[index].is_ref());
+ DCHECK_LT(index, num_inputs());
+ DCHECK(input_is_ref(index));
return (*params_->inputs)[index].mutex_if_ref;
}
@@ -1240,7 +1264,7 @@ inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
inline Tensor* OpKernelContext::mutable_output(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, outputs_.size());
+ DCHECK_LT(index, num_outputs());
// No need to record_tensor_reference since the output must already
// have been set by a call that did so.
return outputs_[index].tensor;
@@ -1248,7 +1272,7 @@ inline Tensor* OpKernelContext::mutable_output(int index) {
inline TensorValue OpKernelContext::release_output(int index) {
DCHECK_GE(index, 0);
- DCHECK_LT(index, outputs_.size());
+ DCHECK_LT(index, num_outputs());
TensorValue value = outputs_[index];
outputs_[index] = TensorValue();
return value;