diff options
author | 2016-06-22 11:15:23 -0800 | |
---|---|---|
committer | 2016-06-22 12:18:42 -0700 | |
commit | 2729f5c661360ab31513056ccae285f8d3e584dd (patch) | |
tree | 2978877cd2fdaaa7ea8057a8bfd87910639b9f97 | |
parent | 2f7ebd6baa34df19141368358f07694c5a72dff9 (diff) |
Factors out the interface for control flow ops Enter, Exit, Switch, Merge and NextIteration. Registers the Gather op for quantized parameters.
Change: 125593594
-rw-r--r-- | tensorflow/core/kernels/control_flow_ops.cc | 184 | ||||
-rw-r--r-- | tensorflow/core/kernels/control_flow_ops.h | 65 | ||||
-rw-r--r-- | tensorflow/core/kernels/gather_op.cc | 1 |
3 files changed, 124 insertions, 126 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index bca686f046..ac81b29860 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -23,36 +23,21 @@ limitations under the License. namespace tensorflow { -// A switch op has two inputs and two outputs. It forwards the value of -// Input:0 to the output specified by input:1. Input:1 is a boolean tensor. -// Input:0 is forwarded to output:0 if input:1 is false, otherwise to -// output:1. -class SwitchOp : public OpKernel { - public: - explicit SwitchOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - const Tensor& outputPorts = context->input(1); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(outputPorts.shape()), - errors::InvalidArgument("The second input must be a scalar, " - "but it has shape ", - outputPorts.shape().DebugString())); - - bool pred = outputPorts.scalar<bool>()(); - int port = (pred) ? 1 : 0; - if (IsRefType(context->input_dtype(0))) { - context->forward_ref_input_to_ref_output(0, port); - } else { - context->set_output(port, context->input(0)); - } +void SwitchOp::Compute(OpKernelContext* context) { + const Tensor& outputPorts = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(outputPorts.shape()), + errors::InvalidArgument("The second input must be a scalar, " + "but it has shape ", + outputPorts.shape().DebugString())); + + bool pred = outputPorts.scalar<bool>()(); + int port = (pred) ? 1 : 0; + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, port); + } else { + context->set_output(port, context->input(0)); } - - bool IsExpensive() override { return false; } - - ~SwitchOp() override {} - - TF_DISALLOW_COPY_AND_ASSIGN(SwitchOp); -}; +} #define REGISTER_CPU_SWITCH(type) \ REGISTER_KERNEL_BUILDER(Name("Switch") \ @@ -167,48 +152,36 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT); #undef REGISTER_CPU_REF_SWITCH -// A merge op has n inputs and two outputs. It forwards the value of the -// first input that becomes available to its first output, and the -// index of the first input to its second output. -class MergeOp : public OpKernel { - public: - explicit MergeOp(OpKernelConstruction* context) : OpKernel(context) { - const DataType dt = context->input_type(0); - const int num_in = context->num_inputs(); - OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt), - {dt, DT_INT32})); - } +MergeOp::MergeOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = context->input_type(0); + const int num_in = context->num_inputs(); + OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt), + {dt, DT_INT32})); +} + +void MergeOp::Compute(OpKernelContext* context) { + bool input_seen = false; + for (int i = 0; i < context->num_inputs(); ++i) { + if (context->has_input(i)) { + if (input_seen) { + context->SetStatus( + errors::Internal("Merge can not have more than one valid input.")); + return; + } + input_seen = true; - void Compute(OpKernelContext* context) override { - bool input_seen = false; - for (int i = 0; i < context->num_inputs(); ++i) { - if (context->has_input(i)) { - if (input_seen) { - context->SetStatus(errors::Internal( - "Merge can not have more than one valid input.")); - return; - } - input_seen = true; - - if (IsRefType(context->input_dtype(i))) { - context->forward_ref_input_to_ref_output(i, 0); - } else { - context->set_output(0, context->input(i)); - } - Tensor* value_index = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), - &value_index)); - value_index->scalar<int32>()() = i; + if (IsRefType(context->input_dtype(i))) { + context->forward_ref_input_to_ref_output(i, 0); + } else { + context->set_output(0, context->input(i)); } + Tensor* value_index = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output(1, TensorShape({}), &value_index)); + value_index->scalar<int32>()() = i; } } - - bool IsExpensive() override { return false; } - - ~MergeOp() override {} - - TF_DISALLOW_COPY_AND_ASSIGN(MergeOp); -}; +} REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp); REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp); @@ -259,27 +232,13 @@ REGISTER_GPU_HOST_KERNEL(string); #undef REGISTER_GPU_HOST_KERNEL -// An enter op has one input and one output. It creates or finds -// the child frame that is uniquely identified by the frame_name, -// and makes its input available to the child frame. -class EnterOp : public OpKernel { - public: - explicit EnterOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - if (IsRefType(context->input_dtype(0))) { - context->forward_ref_input_to_ref_output(0, 0); - } else { - context->set_output(0, context->input(0)); - } +void EnterOp::Compute(OpKernelContext* context) { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); } - - bool IsExpensive() override { return false; } - - ~EnterOp() override {} - - TF_DISALLOW_COPY_AND_ASSIGN(EnterOp); -}; +} REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_CPU), EnterOp); REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp); @@ -326,27 +285,13 @@ REGISTER_GPU_HOST_REF_KERNEL(string); #undef REGISTER_GPU_HOST_KERNEL #undef REGISTER_GPU_HOST_REF_KERNEL -// An exit op has one input and one output. It exits the current -// frame to its parent frame, and makes its input available to the -// parent frame. -class ExitOp : public OpKernel { - public: - explicit ExitOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - if (IsRefType(context->input_dtype(0))) { - context->forward_ref_input_to_ref_output(0, 0); - } else { - context->set_output(0, context->input(0)); - } +void ExitOp::Compute(OpKernelContext* context) { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); } - - bool IsExpensive() override { return false; } - - ~ExitOp() override {} - - TF_DISALLOW_COPY_AND_ASSIGN(ExitOp); -}; +} REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp); REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp); @@ -386,26 +331,13 @@ REGISTER_GPU_HOST_KERNEL(string); #undef REGISTER_GPU_HOST_KERNEL -// A next_iteration op has one input and one output. It makes its input -// available to the next iteration. -class NextIterationOp : public OpKernel { - public: - explicit NextIterationOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - if (IsRefType(context->input_dtype(0))) { - context->forward_ref_input_to_ref_output(0, 0); - } else { - context->set_output(0, context->input(0)); - } +void NextIterationOp::Compute(OpKernelContext* context) { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); } - - bool IsExpensive() override { return false; } - - ~NextIterationOp() override {} - - TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp); -}; +} REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU), NextIterationOp); diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h index 0a1c68992b..4838f2e2bf 100644 --- a/tensorflow/core/kernels/control_flow_ops.h +++ b/tensorflow/core/kernels/control_flow_ops.h @@ -32,6 +32,71 @@ class ControlTriggerOp : public OpKernel { bool IsExpensive() override { return false; } }; +// A switch op has two inputs and two outputs. It forwards the value of +// Input:0 to the output specified by input:1. Input:1 is a boolean tensor. +// Input:0 is forwarded to output:0 if input:1 is false, otherwise to +// output:1. +class SwitchOp : public OpKernel { + public: + explicit SwitchOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~SwitchOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(SwitchOp); +}; + +// A merge op has n inputs and two outputs. It forwards the value of the +// first input that becomes available to its first output, and the +// index of the first input to its second output. +class MergeOp : public OpKernel { + public: + explicit MergeOp(OpKernelConstruction* context); + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~MergeOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(MergeOp); +}; + +// An enter op has one input and one output. It creates or finds +// the child frame that is uniquely identified by the frame_name, +// and makes its input available to the child frame. +class EnterOp : public OpKernel { + public: + explicit EnterOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~EnterOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(EnterOp); +}; + +// An exit op has one input and one output. It exits the current +// frame to its parent frame, and makes its input available to the +// parent frame. +class ExitOp : public OpKernel { + public: + explicit ExitOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~ExitOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(ExitOp); +}; + +// A next_iteration op has one input and one output. It makes its input +// available to the next iteration. +class NextIterationOp : public OpKernel { + public: + explicit NextIterationOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~NextIterationOp() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp); +}; + } // namespace tensorflow #endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_ diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index f73761a0a7..edce9a3197 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -185,6 +185,7 @@ struct Gather<CPUDevice, T, Index> { #define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type) TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU); +TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU); #undef REGISTER_GATHER_CPU |