aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/in_topk_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/in_topk_op.cc')
-rw-r--r--tensorflow/core/kernels/in_topk_op.cc52
1 files changed, 47 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc
index 13890e5b7f..e2861ae090 100644
--- a/tensorflow/core/kernels/in_topk_op.cc
+++ b/tensorflow/core/kernels/in_topk_op.cc
@@ -17,11 +17,11 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
@@ -29,12 +29,29 @@ template <typename T, typename TARGET_T>
class InTopK : public OpKernel {
public:
explicit InTopK(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
+ if (context->num_inputs() == 2) {
+ OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
+ }
}
void Compute(OpKernelContext* context) override {
const auto& predictions_in = context->input(0);
const auto& targets_in = context->input(1);
+ int64 k_val = k_;
+ if (context->num_inputs() == 3) {
+ const auto& k_in = context->input(2);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()),
+ errors::InvalidArgument("k must be 0-D, got shape ",
+ k_in.shape().DebugString()));
+
+ if (k_in.dtype() == DT_INT32) {
+ k_val = k_in.scalar<int32>()();
+ } else {
+ k_val = k_in.scalar<int64>()();
+ }
+ }
+
OP_REQUIRES(context, predictions_in.dims() == 2,
errors::InvalidArgument("predictions must be 2-dimensional"));
OP_REQUIRES(context, targets_in.dims() == 1,
@@ -73,7 +90,7 @@ class InTopK : public OpKernel {
}
}
}
- out(b) = cannot_say ? false : (more_probable_classes < k_);
+ out(b) = cannot_say ? false : (more_probable_classes < k_val);
}
}
@@ -82,10 +99,35 @@ class InTopK : public OpKernel {
};
REGISTER_KERNEL_BUILDER(
- Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int32>("T"),
+ Name("InTopK").Device(DEVICE_CPU)
+ .HostMemory("predictions")
+ .HostMemory("targets")
+ .HostMemory("precision")
+ .TypeConstraint<int32>("T"),
+ InTopK<float, int32>);
+REGISTER_KERNEL_BUILDER(
+ Name("InTopK").Device(DEVICE_CPU)
+ .HostMemory("predictions")
+ .HostMemory("targets")
+ .HostMemory("precision")
+ .TypeConstraint<int64>("T"),
+ InTopK<float, int64>);
+
+REGISTER_KERNEL_BUILDER(
+ Name("InTopKV2").Device(DEVICE_CPU)
+ .HostMemory("predictions")
+ .HostMemory("targets")
+ .HostMemory("k")
+ .HostMemory("precision")
+ .TypeConstraint<int32>("T"),
InTopK<float, int32>);
REGISTER_KERNEL_BUILDER(
- Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int64>("T"),
+ Name("InTopKV2").Device(DEVICE_CPU)
+ .HostMemory("predictions")
+ .HostMemory("targets")
+ .HostMemory("k")
+ .HostMemory("precision")
+ .TypeConstraint<int64>("T"),
InTopK<float, int64>);
} // namespace tensorflow