diff options
Diffstat (limited to 'tensorflow/core/kernels/in_topk_op.cc')
-rw-r--r-- | tensorflow/core/kernels/in_topk_op.cc | 52 |
1 files changed, 47 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc index 13890e5b7f..e2861ae090 100644 --- a/tensorflow/core/kernels/in_topk_op.cc +++ b/tensorflow/core/kernels/in_topk_op.cc @@ -17,11 +17,11 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { @@ -29,12 +29,29 @@ template <typename T, typename TARGET_T> class InTopK : public OpKernel { public: explicit InTopK(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); + if (context->num_inputs() == 2) { + OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); + } } void Compute(OpKernelContext* context) override { const auto& predictions_in = context->input(0); const auto& targets_in = context->input(1); + int64 k_val = k_; + if (context->num_inputs() == 3) { + const auto& k_in = context->input(2); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_in.shape()), + errors::InvalidArgument("k must be 0-D, got shape ", + k_in.shape().DebugString())); + + if (k_in.dtype() == DT_INT32) { + k_val = k_in.scalar<int32>()(); + } else { + k_val = k_in.scalar<int64>()(); + } + } + OP_REQUIRES(context, predictions_in.dims() == 2, errors::InvalidArgument("predictions must be 2-dimensional")); OP_REQUIRES(context, targets_in.dims() == 1, @@ -73,7 +90,7 @@ class InTopK : public OpKernel { } } } - out(b) = cannot_say ? false : (more_probable_classes < k_); + out(b) = cannot_say ? false : (more_probable_classes < k_val); } } @@ -82,10 +99,35 @@ class InTopK : public OpKernel { }; REGISTER_KERNEL_BUILDER( - Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int32>("T"), + Name("InTopK").Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("precision") + .TypeConstraint<int32>("T"), + InTopK<float, int32>); +REGISTER_KERNEL_BUILDER( + Name("InTopK").Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("precision") + .TypeConstraint<int64>("T"), + InTopK<float, int64>); + +REGISTER_KERNEL_BUILDER( + Name("InTopKV2").Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("k") + .HostMemory("precision") + .TypeConstraint<int32>("T"), InTopK<float, int32>); REGISTER_KERNEL_BUILDER( - Name("InTopK").Device(DEVICE_CPU).TypeConstraint<int64>("T"), + Name("InTopKV2").Device(DEVICE_CPU) + .HostMemory("predictions") + .HostMemory("targets") + .HostMemory("k") + .HostMemory("precision") + .TypeConstraint<int64>("T"), InTopK<float, int64>); } // namespace tensorflow |