diff options
author | 2016-07-21 16:23:48 -0800 | |
---|---|---|
committer | 2016-07-21 17:38:15 -0700 | |
commit | f3b61805997f2a08fab35bd02f0f6be18e6cd8cc (patch) | |
tree | 9542431121d251af4e7147ecba2464671a4a3c31 /tensorflow/core/ops/sparse_ops.cc | |
parent | 924b7d9d0ddcc94c2980b3febec3e646c5c788b0 (diff) |
Translate shape inference functions for sparse_ops to C++.
Change shape_inference::InferenceContext::MakeShapeFromShapetensor to handle
the case where the shape tensor is not known but the rank of the shape tensor
is.
Add shape_inference::NumElements().
Change: 128124939
Diffstat (limited to 'tensorflow/core/ops/sparse_ops.cc')
-rw-r--r-- | tensorflow/core/ops/sparse_ops.cc | 195 |
1 files changed, 195 insertions, 0 deletions
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc index a39f2f70cb..1ad9f7175f 100644 --- a/tensorflow/core/ops/sparse_ops.cc +++ b/tensorflow/core/ops/sparse_ops.cc @@ -13,10 +13,34 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +using shape_inference::Dimension; +using shape_inference::InferenceContext; +using shape_inference::Shape; + +namespace { + +Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) { + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); // a_shape + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &unused)); // b_indices + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused)); // b_values + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &unused)); // b_shape + c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); + c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); +} + +} // namespace + REGISTER_OP("SparseAddGrad") .Input("backprop_val_grad: T") .Input("a_indices: int64") @@ -25,6 +49,15 @@ REGISTER_OP("SparseAddGrad") .Output("a_val_grad: T") .Output("b_val_grad: T") .Attr("T: numbertype") + .SetShapeFn([](InferenceContext* c) { + const Shape* a_indices; + const Shape* b_indices; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &a_indices)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &b_indices)); + c->set_output(0, c->Vector(c->Dim(a_indices, 0))); + c->set_output(1, c->Vector(c->Dim(b_indices, 0))); + return Status::OK(); + }) .Doc(R"doc( The gradient operator for the SparseAdd op. @@ -58,6 +91,15 @@ REGISTER_OP("SparseAdd") .Output("sum_shape: int64") .Attr("T: numbertype") .Attr("Treal: realnumbertype") + .SetShapeFn([](InferenceContext* c) { + const Shape* a_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape)); + c->set_output( + 0, c->Matrix(InferenceContext::kUnknownDim, c->Dim(a_shape, 0))); + c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(2, a_shape); + return Status::OK(); + }) .Doc(R"doc( Adds two `SparseTensor` objects to produce another `SparseTensor`. @@ -94,6 +136,26 @@ REGISTER_OP("SparseTensorDenseMatMul") .Attr("T: type") .Attr("adjoint_a: bool = false") .Attr("adjoint_b: bool = false") + .SetShapeFn([](InferenceContext* c) { + const Dimension* unused_dim; + const Shape* unused; + const Shape* b; + const Shape* a_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(a_shape, 0), 2, &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &b)); + + bool adjoint_b; + TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b)); + + // TODO(zongheng): 1) incorporate adjoint_a. 2) When both attrs are + // considered, check the inner dimensions match. + const Dimension* output_right = c->Dim(b, adjoint_b ? 0 : 1); + c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, output_right)); + return Status::OK(); + }) .Doc(R"doc( Multiply SparseTensor (of rank 2) "A" by dense matrix "B". @@ -123,6 +185,14 @@ REGISTER_OP("SerializeSparse") .Input("sparse_shape: int64") .Attr("T: type") .Output("serialized_sparse: string") + .SetShapeFn([](InferenceContext* c) { + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + c->set_output(0, c->Vector(3)); + return Status::OK(); + }) .Doc(R"doc( Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object. @@ -137,6 +207,14 @@ REGISTER_OP("SerializeManySparse") .Input("sparse_shape: int64") .Attr("T: type") .Output("serialized_sparse: string") + .SetShapeFn([](InferenceContext* c) { + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 3)); + return Status::OK(); + }) .Doc(R"doc( Serialize an `N`-minibatch `SparseTensor` into an `[N, 3]` string `Tensor`. @@ -159,6 +237,20 @@ REGISTER_OP("DeserializeManySparse") .Output("sparse_indices: int64") .Output("sparse_values: dtype") .Output("sparse_shape: int64") + .SetShapeFn([](InferenceContext* c) { + // serialized sparse is [?,3] matrix. + const Shape* serialized_sparse; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &serialized_sparse)); + const Dimension* unused; + TF_RETURN_IF_ERROR( + c->WithValue(c->Dim(serialized_sparse, 1), 3, &unused)); + + c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); + c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); + }) .Doc(R"doc( Deserialize and concatenate `SparseTensors` from a serialized minibatch. @@ -218,6 +310,12 @@ REGISTER_OP("SparseToDense") .Attr("T: type") .Output("dense: T") .Attr("Tindices: {int32, int64}") + .SetShapeFn([](InferenceContext* c) { + const Shape* out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); + c->set_output(0, out); + return Status::OK(); + }) .Doc(R"doc( Converts a sparse representation into a dense tensor. @@ -263,6 +361,40 @@ REGISTER_OP("SparseConcat") .Attr("concat_dim: int >= 0") .Attr("N: int >= 2") .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + // These accumulates the sum. + const Dimension* output_row_count = c->MakeDim(0); + + // These are only merged. + const Dimension* output_ind_cols = c->UnknownDim(); + const Shape* output_shape = c->UnknownShape(); + + const int n = c->num_inputs() / 3; + for (int i = 0; i < n; i++) { + const Shape* ind; + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &ind)); + const Shape* val; + TF_RETURN_IF_ERROR(c->WithRank(c->input(i + n), 1, &val)); + const Shape* shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2 * n), 1, &shape)); + + // Add to output_ind_rows. + const Dimension* num_dim; + TF_RETURN_IF_ERROR(c->Merge(c->Dim(ind, 0), c->Dim(val, 0), &num_dim)); + TF_RETURN_IF_ERROR( + c->Add(output_row_count, num_dim, &output_row_count)); + + // Merge into output_ind_cols and output_shape. + TF_RETURN_IF_ERROR( + c->Merge(output_ind_cols, c->Dim(ind, 1), &output_ind_cols)); + TF_RETURN_IF_ERROR(c->Merge(output_shape, shape, &output_shape)); + } + + c->set_output(0, c->Matrix(output_row_count, output_ind_cols)); + c->set_output(1, c->Vector(output_row_count)); + c->set_output(2, output_shape); + return Status::OK(); + }) .Doc(R"doc( Concatenates a list of `SparseTensor` along the specified dimension. @@ -327,6 +459,24 @@ REGISTER_OP("SparseSplit") .Output("output_shape: num_split * int64") .Attr("num_split: int >= 1") .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + const Shape* input_shape = c->input(3); + const Shape* output_indices = + c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape)); + const Shape* output_values = c->Vector(InferenceContext::kUnknownDim); + const Shape* output_shape = input_shape; + + // Copy the outputs into the output ranges. + int num_splits = c->num_outputs() / 3; + int out_idx = 0; + for (int i = 0; i < num_splits; ++i) + c->set_output(out_idx++, output_indices); + for (int i = 0; i < num_splits; ++i) + c->set_output(out_idx++, output_values); + for (int i = 0; i < num_splits; ++i) + c->set_output(out_idx++, output_shape); + return Status::OK(); + }) .Doc(R"doc( Split a `SparseTensor` into `num_split` tensors along one dimension. @@ -369,6 +519,19 @@ REGISTER_OP("SparseReorder") .Output("output_indices: int64") .Output("output_values: T") .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + const Shape* indices; + const Shape* values; + const Shape* unused; + + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + + c->set_output(0, indices); + c->set_output(1, values); + return Status::OK(); + }) .Doc(R"doc( Reorders a SparseTensor into the canonical, row-major ordering. @@ -396,6 +559,19 @@ REGISTER_OP("SparseReshape") .Input("new_shape: int64") .Output("output_indices: int64") .Output("output_shape: int64") + .SetShapeFn([](InferenceContext* c) { + const Shape* indices; + const Shape* unused; + const Shape* new_shape; + + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &new_shape)); + + c->set_output(0, c->Matrix(c->Dim(indices, 0), c->Dim(new_shape, 0))); + c->set_output(1, new_shape); + return Status::OK(); + }) .Doc(R"doc( Reshapes a SparseTensor to represent values in a new dense shape. @@ -434,6 +610,10 @@ REGISTER_OP("SparseTensorDenseAdd") .Output("output: T") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(3)); + return Status::OK(); + }) .Doc(R"doc( Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. @@ -453,6 +633,10 @@ REGISTER_OP("SparseReduceSum") .Attr("keep_dims: bool = False") .Output("output: T") .Attr("T: numbertype") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }) .Doc(R"doc( Computes the sum of elements across dimensions of a SparseTensor. @@ -544,6 +728,15 @@ REGISTER_OP("SparseSoftmax") .Input("sp_shape: int64") .Output("output: T") .Attr("T: {float, double}") + .SetShapeFn([](InferenceContext* c) { + const Shape* unused; + const Shape* values; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // sp_indices + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values)); // sp_values + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + c->set_output(0, values); + return Status::OK(); + }) .Doc(R"doc( Applies softmax to a batched N-D `SparseTensor`. @@ -580,6 +773,7 @@ REGISTER_OP("SparseSparseMaximum") .Output("output_indices: int64") .Output("output_values: T") .Attr("T: realnumbertype") + .SetShapeFn(SparseSparseMinOrMaxShapeFn) .Doc(R"doc( Returns the element-wise max of two SparseTensors. @@ -607,6 +801,7 @@ REGISTER_OP("SparseSparseMinimum") .Output("output_indices: int64") .Output("output_values: T") .Attr("T: numbertype") + .SetShapeFn(SparseSparseMinOrMaxShapeFn) .Doc(R"doc( Returns the element-wise min of two SparseTensors. |