Diffstat (limited to 'tensorflow/core/kernels/serialize_sparse_op.cc')
 tensorflow/core/kernels/serialize_sparse_op.cc | 273 +------------------
 1 file changed, 7 insertions(+), 266 deletions(-)
diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc
index 4ad653601a..577e327809 100644
--- a/tensorflow/core/kernels/serialize_sparse_op.cc
+++ b/tensorflow/core/kernels/serialize_sparse_op.cc
@@ -36,6 +36,8 @@ limitations under the License.
namespace tensorflow {
+namespace {
+
using sparse::SparseTensor;
template <typename T>
@@ -188,8 +190,10 @@ class SerializeManySparseOp : public SerializeManySparseOpBase<U> {
TensorShape tensor_input_shape(input_shape->vec<int64>());
gtl::InlinedVector<int64, 8> std_order(rank);
std::iota(std_order.begin(), std_order.end(), 0);
- SparseTensor input_st(*input_indices, *input_values, tensor_input_shape,
- std_order);
+ SparseTensor input_st;
+ OP_REQUIRES_OK(context, SparseTensor::Create(*input_indices, *input_values,
+ tensor_input_shape, std_order,
+ &input_st));
auto input_shape_t = input_shape->vec<int64>();
const int64 N = input_shape_t(0);
@@ -306,269 +310,6 @@ Status SerializeManySparseOpBase<Variant>::Serialize(const Tensor& input,
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
-template <typename T>
-class DeserializeSparseOp : public OpKernel {
- public:
- explicit DeserializeSparseOp(OpKernelConstruction* context)
- : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
- }
-
- void Compute(OpKernelContext* context) override {
- const Tensor& serialized_sparse = context->input(0);
- const int ndims = serialized_sparse.shape().dims();
-
- OP_REQUIRES(
- context, ndims > 0,
- errors::InvalidArgument("Serialized sparse should have non-zero rank ",
- serialized_sparse.shape().DebugString()));
-
- OP_REQUIRES(context, serialized_sparse.shape().dim_size(ndims - 1) == 3,
- errors::InvalidArgument(
- "Serialized sparse should have 3 as the last dimension ",
- serialized_sparse.shape().DebugString()));
-
- int num_sparse_tensors = 1;
- for (int i = 0; i < ndims - 1; ++i) {
- num_sparse_tensors *= serialized_sparse.shape().dim_size(i);
- }
-
- OP_REQUIRES(
- context, num_sparse_tensors > 0,
- errors::InvalidArgument(
- "Serialized sparse should have at least 1 serialized tensor, "
- "but has a zero dimension ",
- serialized_sparse.shape().DebugString()));
-
- if (num_sparse_tensors == 1 && ndims == 1) {
- // Special case with a single sparse tensor. We can avoid data
- // motion in the Concat and Reshape.
- const auto& serialized_sparse_t = serialized_sparse.vec<T>();
-
- Tensor output_indices;
- Tensor output_values;
- Tensor output_shape;
- OP_REQUIRES_OK(context,
- this->GetAndValidateSparseTensor(
- serialized_sparse_t(0), serialized_sparse_t(1),
- serialized_sparse_t(2), dtype_, 0 /* index */,
- &output_indices, &output_values, &output_shape));
- context->set_output(0, output_indices);
- context->set_output(1, output_values);
- context->set_output(2, output_shape);
- return;
- }
-
- std::vector<Tensor> indices;
- std::vector<Tensor> values;
- TensorShape shape;
- indices.reserve(num_sparse_tensors);
- values.reserve(num_sparse_tensors);
-
- const auto& serialized_sparse_t = serialized_sparse.flat_inner_dims<T, 2>();
- for (int i = 0; i < num_sparse_tensors; ++i) {
- Tensor output_indices;
- Tensor output_values;
- Tensor output_shape;
- OP_REQUIRES_OK(context,
- this->GetAndValidateSparseTensor(
- serialized_sparse_t(i, 0), serialized_sparse_t(i, 1),
- serialized_sparse_t(i, 2), dtype_, i, &output_indices,
- &output_values, &output_shape));
- int64 num_entries = output_indices.dim_size(0);
- int rank = output_indices.dim_size(1);
-
- // Now we expand each SparseTensors' indices and shape by
- // prefixing a dimension
- Tensor expanded_indices(DT_INT64, TensorShape({num_entries, 1 + rank}));
- const auto& output_indices_t = output_indices.matrix<int64>();
- auto expanded_indices_t = expanded_indices.matrix<int64>();
- expanded_indices_t.chip<1>(0).setZero();
- if (rank > 0) {
- Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
- Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
- expanded_indices_t.slice(indices_start, indices_sizes) =
- output_indices_t;
- }
- Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
- const auto& output_shape_t = output_shape.vec<int64>();
- auto expanded_shape_t = expanded_shape.vec<int64>();
- expanded_shape_t(0) = 1;
- std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
-
- TensorShape expanded_tensor_shape(expanded_shape.vec<int64>());
-
- indices.push_back(expanded_indices);
- values.push_back(output_values);
- if (i == 0) {
- shape = expanded_tensor_shape;
- } else {
- OP_REQUIRES(
- context, shape.dims() == expanded_tensor_shape.dims(),
- errors::InvalidArgument(
- "Inconsistent shape across SparseTensors: rank prior to "
- "SparseTensor[",
- i, "] was: ", shape.dims() - 1, " but rank of SparseTensor[", i,
- "] is: ", expanded_tensor_shape.dims() - 1));
- for (int j = 1; j < shape.dims(); ++j) {
- // NOTE(mrry): For compatibility with the implementations of
- // DeserializeManySparse, and many ops that generate
- // SparseTensors to batch that do not have a fixed
- // dense_shape (e.g. `tf.parse_single_example()`), we
- // compute the maximum in each dimension to find the
- // smallest dense_shape that bounds all of the input
- // SparseTensors.
- shape.set_dim(j, std::max(shape.dim_size(j),
- expanded_tensor_shape.dim_size(j)));
- }
- }
- }
-
- // Dimension 0 is the primary dimension.
- int rank = shape.dims();
- gtl::InlinedVector<int64, 8> std_order(rank);
- std::iota(std_order.begin(), std_order.end(), 0);
-
- std::vector<SparseTensor> tensors;
- tensors.reserve(num_sparse_tensors);
- for (int i = 0; i < num_sparse_tensors; ++i) {
- tensors.emplace_back(indices[i], values[i], shape, std_order);
- }
-
- gtl::optional<SparseTensor> maybe_output;
-#define HANDLE_TYPE(T) \
- case DataTypeToEnum<T>::value: { \
- maybe_output = SparseTensor::Concat<T>(tensors); \
- break; \
- }
-
- switch (dtype_) {
- TF_CALL_ALL_TYPES(HANDLE_TYPE);
- TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
-#undef HANDLE_TYPE
- default:
- OP_REQUIRES(context, false,
- errors::Unimplemented(
- "DeserializeSparse Unhandled data type: ", dtype_));
- }
- DCHECK(maybe_output);
- SparseTensor& output = maybe_output.value();
-
- // Compute the input shape for the reshape operation.
- Tensor input_shape(DT_INT64, TensorShape({output.dims()}));
- std::copy_n(output.shape().data(), output.dims(),
- input_shape.vec<int64>().data());
-
- // Compute the target shape for the reshape operation.
- Tensor target_shape(DT_INT64, TensorShape({ndims + output.dims() - 2}));
- for (int i = 0; i < ndims - 1; ++i) {
- target_shape.vec<int64>()(i) = serialized_sparse.shape().dim_size(i);
- }
- for (int i = 0; i < output.dims() - 1; ++i) {
- target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
- }
-
- Tensor output_indices;
- Tensor output_shape;
- Reshape(context, output.indices(), input_shape, target_shape,
- 0 /* output indices index */, 2 /* output shape index */);
- context->set_output(1, output.values());
- }
-
- protected:
- Status Deserialize(const T& serialized, Tensor* result);
-
- Status GetAndValidateSparseTensor(
- const T& serialized_indices, const T& serialized_values,
- const T& serialized_shape, DataType values_dtype, int index,
- Tensor* output_indices, Tensor* output_values, Tensor* output_shape) {
- // Deserialize and validate the indices.
- TF_RETURN_IF_ERROR(this->Deserialize(serialized_indices, output_indices));
- if (!TensorShapeUtils::IsMatrix(output_indices->shape())) {
- return errors::InvalidArgument(
- "Expected serialized_sparse[", index,
- ", 0] to represent an index matrix but received shape ",
- output_indices->shape().DebugString());
- }
- int64 num_entries = output_indices->dim_size(0);
- int rank = output_indices->dim_size(1);
-
- // Deserialize and validate the values.
- TF_RETURN_IF_ERROR(this->Deserialize(serialized_values, output_values));
- if (!TensorShapeUtils::IsVector(output_values->shape())) {
- return errors::InvalidArgument(
- "Expected serialized_sparse[", index,
- ", 1] to represent a values vector but received shape ",
- output_values->shape().DebugString());
- }
- if (values_dtype != output_values->dtype()) {
- return errors::InvalidArgument(
- "Requested SparseTensor of type ", DataTypeString(values_dtype),
- " but SparseTensor[", index,
- "].values.dtype() == ", DataTypeString(output_values->dtype()));
- }
- if (num_entries != output_values->dim_size(0)) {
- return errors::InvalidArgument(
- "Expected row counts of SparseTensor[", index,
- "].indices and SparseTensor[", index,
- "].values to match but they do not: ", num_entries, " vs. ",
- output_values->dim_size(0));
- }
-
- // Deserialize and validate the shape.
- TF_RETURN_IF_ERROR(this->Deserialize(serialized_shape, output_shape));
- if (!TensorShapeUtils::IsVector(output_shape->shape())) {
- return errors::InvalidArgument(
- "Expected serialized_sparse[", index,
- ", 1] to be a shape vector but its shape is ",
- output_shape->shape().DebugString());
- }
- if (rank != output_shape->dim_size(0)) {
- return errors::InvalidArgument("Expected column counts of SparseTensor[",
- index,
- "].indices to match size of SparseTensor[",
- index, "].shape but they do not: ", rank,
- " vs. ", output_shape->dim_size(0));
- }
- return Status::OK();
- }
-
- DataType dtype_;
-};
-
-template <>
-Status DeserializeSparseOp<string>::Deserialize(const string& serialized,
- Tensor* result) {
- TensorProto proto;
- if (!ParseProtoUnlimited(&proto, serialized)) {
- return errors::InvalidArgument("Could not parse serialized proto");
- }
- Tensor tensor;
- if (!tensor.FromProto(proto)) {
- return errors::InvalidArgument("Could not construct tensor from proto");
- }
- *result = tensor;
- return Status::OK();
-}
-
-REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
- .Device(DEVICE_CPU)
- .TypeConstraint<string>("Tserialized"),
- DeserializeSparseOp<string>)
-
-REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse").Device(DEVICE_CPU),
- DeserializeSparseOp<string>)
-
-template <>
-Status DeserializeSparseOp<Variant>::Deserialize(const Variant& serialized,
- Tensor* result) {
- *result = *serialized.get<Tensor>();
- return Status::OK();
-}
-
-REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
- .Device(DEVICE_CPU)
- .TypeConstraint<Variant>("Tserialized"),
- DeserializeSparseOp<Variant>)
+} // namespace
} // namespace tensorflow
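
For context, the substantive change in the surviving kernel is the switch from the SparseTensor constructor to the status-returning SparseTensor::Create factory, with the error propagated via OP_REQUIRES_OK instead of failing inside the constructor. A minimal sketch of that pattern is below; the helper function, its arguments, and the surrounding context are illustrative only and are not part of this diff.

```cpp
// Sketch only: assumes the TensorFlow headers below and a valid OpKernelContext.
#include <numeric>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

// Hypothetical helper showing the call-site pattern adopted by this change.
void BuildSparseTensor(OpKernelContext* context, const Tensor& indices,
                       const Tensor& values, const TensorShape& shape) {
  // Standard (row-major) dimension order, as in the kernel above.
  gtl::InlinedVector<int64, 8> std_order(shape.dims());
  std::iota(std_order.begin(), std_order.end(), 0);

  // Before: the constructor validated its arguments internally and had no way
  // to report a recoverable error to the caller.
  //   sparse::SparseTensor input_st(indices, values, shape, std_order);
  //
  // After: the factory returns a Status, so malformed inputs surface as an
  // op-level error via OP_REQUIRES_OK rather than a process-level failure.
  sparse::SparseTensor input_st;
  OP_REQUIRES_OK(context, sparse::SparseTensor::Create(indices, values, shape,
                                                       std_order, &input_st));
  // ... use input_st ...
}

}  // namespace tensorflow
```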