From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- tensorflow/core/kernels/gather_op.cc | 136 +++++++++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 tensorflow/core/kernels/gather_op.cc (limited to 'tensorflow/core/kernels/gather_op.cc') diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc new file mode 100644 index 0000000000..8bd48f26d6 --- /dev/null +++ b/tensorflow/core/kernels/gather_op.cc @@ -0,0 +1,136 @@ +// See docs in ../ops/array_ops.cc. + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +namespace { +template +void HandleCopies(const Tensor& Tparams, + typename TTypes::ConstVec& Tindices, int slice_elems, + typename TTypes::Matrix Tout) { + const int N = Tindices.dimension(0); + const auto& Tparams_flat = Tparams.flat_outer_dims(); + T* Tout_base = &Tout(0, 0); + const T* Tparams_base = &Tparams_flat(0, 0); + const size_t slice_bytes = slice_elems * sizeof(T); + if (static_slice_elems >= 0) { + // Give compiler static knowledge of the number of elements/bytes + CHECK_EQ(static_slice_elems, slice_elems); + slice_elems = static_slice_elems; + } + for (int i = 0; i < N; i++) { + int j = i + 1; + if (j < N) { + port::prefetch(&Tparams_flat(Tindices(j), 0)); + port::prefetch(&Tout(j, 0)); + } + memcpy(Tout_base + i * slice_elems, + Tparams_base + Tindices(i) * slice_elems, slice_bytes); + } +} + +} // anonymous namespace + +template +class GatherOp : public OpKernel { + public: + // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, + // etc. here for the type of the second input argument. Should + // we have the framework do some sort of integer promotion + // automatically, or should that be something that users have to + // do explicitly with a conversion operator in the graph? + explicit GatherOp(OpKernelConstruction* c) : OpKernel(c) { + const DataType dt = DataTypeToEnum::v(); + const DataType index_t = DataTypeToEnum::v(); + OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt})); + } + + void Compute(OpKernelContext* c) override { + const Tensor& Tparams = c->input(0); + const Tensor& Tindices = c->input(1); + OP_REQUIRES( + c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()), + errors::InvalidArgument("params must be at least 1 dimensional")); + const int64 N = Tindices.NumElements(); + const int64 first_dim_size = Tparams.dim_size(0); + + // Validate all the indices are in range + auto Tindices_vec = Tindices.flat(); + for (int64 i = 0; i < N; i++) { + const Index index = Tindices_vec(i); + OP_REQUIRES(c, index >= 0 && index < first_dim_size, + errors::InvalidArgument( + strings::StrCat("Index ", index, " at offset ", i, + " in Tindices is out of range"))); + } + + // The result shape is indices.shape + params.shape[1:]. + TensorShape result_shape = Tindices.shape(); + for (int i = 1; i < Tparams.dims(); i++) { + result_shape.AddDim(Tparams.dim_size(i)); + } + + Tensor* Tout = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &Tout)); + const auto& Tparams_flat = Tparams.flat_outer_dims(); + if (N > 0) { + auto Tindices_flat = Tindices.flat(); + auto Tout_flat = Tout->shaped({N, Tout->NumElements() / N}); + if (DataTypeCanUseMemcpy(DataTypeToEnum::v())) { + const int64 slice_size = Tout->NumElements() / N; +#define SPECIALIZE(elems) \ + do { \ + if (slice_size == elems) { \ + HandleCopies(Tparams, Tindices_flat, slice_size, \ + Tout_flat); \ + return; \ + } \ + } while (0) + + SPECIALIZE(10); + SPECIALIZE(20); + +#undef SPECIALIZE + + HandleCopies(Tparams, Tindices_flat, slice_size, + Tout_flat); + } else { + for (int i = 0; i < N; i++) { + int j = i + 1; + if (j < N) { + port::prefetch( + &Tparams_flat(Tindices_vec(j), 0)); + port::prefetch(&Tout_flat(j, 0)); + } + // Copy last Ndim-1 dimensions of Tparams[Tindices[i]] to Tout[i] + Tout_flat.template chip<0>(i) = + Tparams_flat.template chip<0>(Tindices_vec(i)); + } + } + } + } +}; + +#define REGISTER_GATHER(type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("Gather") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Tparams") \ + .TypeConstraint("Tindices"), \ + GatherOp) + +#define REGISTER_GATHER_INT32(type) REGISTER_GATHER(type, int32) +#define REGISTER_GATHER_INT64(type) REGISTER_GATHER(type, int64) + +TF_CALL_ALL_TYPES(REGISTER_GATHER_INT32); +TF_CALL_ALL_TYPES(REGISTER_GATHER_INT64); + +#undef REGISTER_GATHER_INT32 +#undef REGISTER_GATHER_INT64 +#undef REGISTER_GATHER + +} // namespace tensorflow -- cgit v1.2.3