From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- tensorflow/core/kernels/dynamic_stitch_op.cc | 158 +++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 tensorflow/core/kernels/dynamic_stitch_op.cc (limited to 'tensorflow/core/kernels/dynamic_stitch_op.cc') diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc new file mode 100644 index 0000000000..a5623685fb --- /dev/null +++ b/tensorflow/core/kernels/dynamic_stitch_op.cc @@ -0,0 +1,158 @@ +// See docs in ../ops/data_flow_ops.cc. + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +template +class DynamicStitchOp : public OpKernel { + public: + explicit DynamicStitchOp(OpKernelConstruction* c) : OpKernel(c) { + // Compute expected input signature + const DataType dt = DataTypeToEnum::v(); + const int n = c->num_inputs() / 2; + DataTypeVector expected; + for (int i = 0; i < n; i++) { + expected.push_back(DT_INT32); + } + for (int i = 0; i < n; i++) { + expected.push_back(dt); + } + OP_REQUIRES_OK(c, c->MatchSignature(expected, {dt})); + OP_REQUIRES( + c, c->num_inputs() > 0, + errors::InvalidArgument("DynamicStitchOp: Must have some inputs")); + OP_REQUIRES(c, c->num_inputs() % 2 == 0, + errors::InvalidArgument( + "DynamicStitchOp: Must have even number of arguments")); + } + + void Compute(OpKernelContext* c) override { + // Find maximum index in the indices vectors + OpInputList indices_inputs; + OP_REQUIRES_OK(c, c->input_list("indices", &indices_inputs)); + + int32 max_index = -1; + for (const Tensor& indices : indices_inputs) { + Eigen::Tensor m = + indices.flat().maximum(); + max_index = std::max(m(), max_index); + } + const int first_dim_size = max_index + 1; + + // Validate that data[i].shape = indices[i].shape + constant + OpInputList data_inputs; + OP_REQUIRES_OK(c, c->input_list("data", &data_inputs)); + const Tensor& data0 = data_inputs[0]; + const Tensor& indices0 = indices_inputs[0]; + for (int input_num = 0; input_num < indices_inputs.size(); input_num++) { + const Tensor& indices = indices_inputs[input_num]; + const Tensor& data = data_inputs[input_num]; + OP_REQUIRES( + c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()), + errors::InvalidArgument( + "data[", input_num, "].shape = ", data.shape().ShortDebugString(), + " does not start with indices[", input_num, "].shape = ", + indices.shape().ShortDebugString())); + OP_REQUIRES( + c, input_num == 0 || SameExtraShape(data0, indices0, data, indices), + errors::InvalidArgument( + "Need data[0].shape[", indices0.dims(), ":] = data[", input_num, + "].shape[", indices.dims(), ":], got data[0].shape = ", + data0.shape().ShortDebugString(), ", data[", input_num, + "].shape = ", data.shape().ShortDebugString(), + ", indices[0].shape = ", indices0.shape().ShortDebugString(), + ", indices[", input_num, "].shape = ", + indices.shape().ShortDebugString())); + } + + // Allocate result tensor of shape + // [first_dim_size] + data.shape[indices.dims:] + TensorShape result_shape; + result_shape.AddDim(first_dim_size); + for (int d = indices0.dims(); d < data0.dims(); d++) { + result_shape.AddDim(data0.dim_size(d)); + } + Tensor* merged = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &merged)); + + // TODO(jeff): Currently we leave uninitialized any portions of + // merged that aren't covered by an index in indices. What should we do? + if (first_dim_size > 0) { + auto merged_flat = merged->flat_outer_dims(); + const int slice_size = merged_flat.dimension(1); + for (int input_num = 0; input_num < indices_inputs.size(); input_num++) { + const Tensor& indices = indices_inputs[input_num]; + auto indices_vec = indices.flat(); + const Tensor& data = data_inputs[input_num]; + auto data_flat = + data.shaped({indices_vec.dimension(0), slice_size}); + + if (DataTypeCanUseMemcpy(DataTypeToEnum::v())) { + T* merged_base = &merged_flat(0, 0); + const T* data_base = &data_flat(0, 0); + const size_t slice_bytes = slice_size * sizeof(T); + for (int i = 0; i < indices_vec.size(); i++) { + memcpy(merged_base + indices_vec(i) * slice_size, + data_base + i * slice_size, slice_bytes); + } + } else { + Eigen::DSizes sizes(1, slice_size); + for (int i = 0; i < indices_vec.size(); i++) { + // Copy slice data[i] to merged[indices[i]] + Eigen::DSizes data_indices(i, 0); + Eigen::DSizes merged_indices(indices_vec(i), + 0); + merged_flat.slice(merged_indices, sizes) = + data_flat.slice(data_indices, sizes); + } + } + } + } + } + + private: + // Check if data0.shape[indices0.dims():] == data1.shape[indices1.dims():] + static bool SameExtraShape(const Tensor& data0, const Tensor& indices0, + const Tensor& data1, const Tensor& indices1) { + const int extra0 = data0.dims() - indices0.dims(); + const int extra1 = data1.dims() - indices1.dims(); + if (extra0 != extra1) return false; + for (int i = 0; i < extra0; i++) { + if (data0.dim_size(indices0.dims() + i) != + data1.dim_size(indices1.dims() + i)) { + return false; + } + } + return true; + } +}; + +#define REGISTER_DYNAMIC_STITCH(type) \ + REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("indices"), \ + DynamicStitchOp) + +TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_STITCH); +#undef REGISTER_DYNAMIC_STITCH + +#if GOOGLE_CUDA +#define REGISTER_DYNAMIC_STITCH_GPU(type) \ + REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("indices") \ + .HostMemory("data") \ + .HostMemory("merged"), \ + DynamicStitchOp) + +TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_STITCH_GPU); +#undef REGISTER_DYNAMIC_STITCH_GPU + +#endif // GOOGLE_CUDA + +} // namespace tensorflow -- cgit v1.2.3