aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-06-22 11:15:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-22 12:18:42 -0700
commit2729f5c661360ab31513056ccae285f8d3e584dd (patch)
tree2978877cd2fdaaa7ea8057a8bfd87910639b9f97
parent2f7ebd6baa34df19141368358f07694c5a72dff9 (diff)
Factors out the interface for control flow ops Enter, Exit, Switch, Merge and NextIteration. Registers the Gather op for quantized parameters.
Change: 125593594
-rw-r--r--tensorflow/core/kernels/control_flow_ops.cc184
-rw-r--r--tensorflow/core/kernels/control_flow_ops.h65
-rw-r--r--tensorflow/core/kernels/gather_op.cc1
3 files changed, 124 insertions, 126 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index bca686f046..ac81b29860 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -23,36 +23,21 @@ limitations under the License.
namespace tensorflow {
-// A switch op has two inputs and two outputs. It forwards the value of
-// Input:0 to the output specified by input:1. Input:1 is a boolean tensor.
-// Input:0 is forwarded to output:0 if input:1 is false, otherwise to
-// output:1.
-class SwitchOp : public OpKernel {
- public:
- explicit SwitchOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- const Tensor& outputPorts = context->input(1);
- OP_REQUIRES(context, TensorShapeUtils::IsScalar(outputPorts.shape()),
- errors::InvalidArgument("The second input must be a scalar, "
- "but it has shape ",
- outputPorts.shape().DebugString()));
-
- bool pred = outputPorts.scalar<bool>()();
- int port = (pred) ? 1 : 0;
- if (IsRefType(context->input_dtype(0))) {
- context->forward_ref_input_to_ref_output(0, port);
- } else {
- context->set_output(port, context->input(0));
- }
+void SwitchOp::Compute(OpKernelContext* context) {
+ const Tensor& outputPorts = context->input(1);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(outputPorts.shape()),
+ errors::InvalidArgument("The second input must be a scalar, "
+ "but it has shape ",
+ outputPorts.shape().DebugString()));
+
+ bool pred = outputPorts.scalar<bool>()();
+ int port = (pred) ? 1 : 0;
+ if (IsRefType(context->input_dtype(0))) {
+ context->forward_ref_input_to_ref_output(0, port);
+ } else {
+ context->set_output(port, context->input(0));
}
-
- bool IsExpensive() override { return false; }
-
- ~SwitchOp() override {}
-
- TF_DISALLOW_COPY_AND_ASSIGN(SwitchOp);
-};
+}
#define REGISTER_CPU_SWITCH(type) \
REGISTER_KERNEL_BUILDER(Name("Switch") \
@@ -167,48 +152,36 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT);
#undef REGISTER_CPU_REF_SWITCH
-// A merge op has n inputs and two outputs. It forwards the value of the
-// first input that becomes available to its first output, and the
-// index of the first input to its second output.
-class MergeOp : public OpKernel {
- public:
- explicit MergeOp(OpKernelConstruction* context) : OpKernel(context) {
- const DataType dt = context->input_type(0);
- const int num_in = context->num_inputs();
- OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt),
- {dt, DT_INT32}));
- }
+MergeOp::MergeOp(OpKernelConstruction* context) : OpKernel(context) {
+ const DataType dt = context->input_type(0);
+ const int num_in = context->num_inputs();
+ OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt),
+ {dt, DT_INT32}));
+}
+
+void MergeOp::Compute(OpKernelContext* context) {
+ bool input_seen = false;
+ for (int i = 0; i < context->num_inputs(); ++i) {
+ if (context->has_input(i)) {
+ if (input_seen) {
+ context->SetStatus(
+ errors::Internal("Merge can not have more than one valid input."));
+ return;
+ }
+ input_seen = true;
- void Compute(OpKernelContext* context) override {
- bool input_seen = false;
- for (int i = 0; i < context->num_inputs(); ++i) {
- if (context->has_input(i)) {
- if (input_seen) {
- context->SetStatus(errors::Internal(
- "Merge can not have more than one valid input."));
- return;
- }
- input_seen = true;
-
- if (IsRefType(context->input_dtype(i))) {
- context->forward_ref_input_to_ref_output(i, 0);
- } else {
- context->set_output(0, context->input(i));
- }
- Tensor* value_index = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}),
- &value_index));
- value_index->scalar<int32>()() = i;
+ if (IsRefType(context->input_dtype(i))) {
+ context->forward_ref_input_to_ref_output(i, 0);
+ } else {
+ context->set_output(0, context->input(i));
}
+ Tensor* value_index = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_output(1, TensorShape({}), &value_index));
+ value_index->scalar<int32>()() = i;
}
}
-
- bool IsExpensive() override { return false; }
-
- ~MergeOp() override {}
-
- TF_DISALLOW_COPY_AND_ASSIGN(MergeOp);
-};
+}
REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);
@@ -259,27 +232,13 @@ REGISTER_GPU_HOST_KERNEL(string);
#undef REGISTER_GPU_HOST_KERNEL
-// An enter op has one input and one output. It creates or finds
-// the child frame that is uniquely identified by the frame_name,
-// and makes its input available to the child frame.
-class EnterOp : public OpKernel {
- public:
- explicit EnterOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- if (IsRefType(context->input_dtype(0))) {
- context->forward_ref_input_to_ref_output(0, 0);
- } else {
- context->set_output(0, context->input(0));
- }
+void EnterOp::Compute(OpKernelContext* context) {
+ if (IsRefType(context->input_dtype(0))) {
+ context->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ context->set_output(0, context->input(0));
}
-
- bool IsExpensive() override { return false; }
-
- ~EnterOp() override {}
-
- TF_DISALLOW_COPY_AND_ASSIGN(EnterOp);
-};
+}
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_CPU), EnterOp);
REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);
@@ -326,27 +285,13 @@ REGISTER_GPU_HOST_REF_KERNEL(string);
#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL
-// An exit op has one input and one output. It exits the current
-// frame to its parent frame, and makes its input available to the
-// parent frame.
-class ExitOp : public OpKernel {
- public:
- explicit ExitOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- if (IsRefType(context->input_dtype(0))) {
- context->forward_ref_input_to_ref_output(0, 0);
- } else {
- context->set_output(0, context->input(0));
- }
+void ExitOp::Compute(OpKernelContext* context) {
+ if (IsRefType(context->input_dtype(0))) {
+ context->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ context->set_output(0, context->input(0));
}
-
- bool IsExpensive() override { return false; }
-
- ~ExitOp() override {}
-
- TF_DISALLOW_COPY_AND_ASSIGN(ExitOp);
-};
+}
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp);
REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);
@@ -386,26 +331,13 @@ REGISTER_GPU_HOST_KERNEL(string);
#undef REGISTER_GPU_HOST_KERNEL
-// A next_iteration op has one input and one output. It makes its input
-// available to the next iteration.
-class NextIterationOp : public OpKernel {
- public:
- explicit NextIterationOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- if (IsRefType(context->input_dtype(0))) {
- context->forward_ref_input_to_ref_output(0, 0);
- } else {
- context->set_output(0, context->input(0));
- }
+void NextIterationOp::Compute(OpKernelContext* context) {
+ if (IsRefType(context->input_dtype(0))) {
+ context->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ context->set_output(0, context->input(0));
}
-
- bool IsExpensive() override { return false; }
-
- ~NextIterationOp() override {}
-
- TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp);
-};
+}
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU),
NextIterationOp);
diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h
index 0a1c68992b..4838f2e2bf 100644
--- a/tensorflow/core/kernels/control_flow_ops.h
+++ b/tensorflow/core/kernels/control_flow_ops.h
@@ -32,6 +32,71 @@ class ControlTriggerOp : public OpKernel {
bool IsExpensive() override { return false; }
};
+// A switch op has two inputs and two outputs. It forwards the value of
+// Input:0 to the output specified by input:1. Input:1 is a boolean tensor.
+// Input:0 is forwarded to output:0 if input:1 is false, otherwise to
+// output:1.
+class SwitchOp : public OpKernel {
+ public:
+ explicit SwitchOp(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override;
+ bool IsExpensive() override { return false; }
+ ~SwitchOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SwitchOp);
+};
+
+// A merge op has n inputs and two outputs. It forwards the value of the
+// first input that becomes available to its first output, and the
+// index of the first input to its second output.
+class MergeOp : public OpKernel {
+ public:
+ explicit MergeOp(OpKernelConstruction* context);
+ void Compute(OpKernelContext* context) override;
+ bool IsExpensive() override { return false; }
+ ~MergeOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MergeOp);
+};
+
+// An enter op has one input and one output. It creates or finds
+// the child frame that is uniquely identified by the frame_name,
+// and makes its input available to the child frame.
+class EnterOp : public OpKernel {
+ public:
+ explicit EnterOp(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override;
+ bool IsExpensive() override { return false; }
+ ~EnterOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(EnterOp);
+};
+
+// An exit op has one input and one output. It exits the current
+// frame to its parent frame, and makes its input available to the
+// parent frame.
+class ExitOp : public OpKernel {
+ public:
+ explicit ExitOp(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override;
+ bool IsExpensive() override { return false; }
+ ~ExitOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExitOp);
+};
+
+// A next_iteration op has one input and one output. It makes its input
+// available to the next iteration.
+class NextIterationOp : public OpKernel {
+ public:
+ explicit NextIterationOp(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override;
+ bool IsExpensive() override { return false; }
+ ~NextIterationOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp);
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index f73761a0a7..edce9a3197 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -185,6 +185,7 @@ struct Gather<CPUDevice, T, Index> {
#define REGISTER_GATHER_CPU(type) REGISTER_GATHER_ALL_INDICES(CPU, type)
TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
+TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
#undef REGISTER_GATHER_CPU