diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-11-14 13:22:48 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-14 13:43:16 -0800 |
commit | 0c2018764516add343acc9ee56c487715bc98c2a (patch) | |
tree | 4a61a5235b6c339bedf9d97fc048570fd72cec22 /tensorflow/core/framework/common_shape_fns.cc | |
parent | 750c98508c4ea51bf45694d32773b075eb2c8c8d (diff) |
Move ValidateSparseTensor to common_shape_fns.h.
Change: 139112567
Diffstat (limited to 'tensorflow/core/framework/common_shape_fns.cc')
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 2434127acc..794f9c37cf 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -860,5 +860,46 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) { return Status::OK(); } +Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, + ShapeHandle values_shape, ShapeHandle shape_shape) { + // Validate ranks. + ShapeHandle unused_shape; + TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape)); + + // Number of elements in indices and values must match. + DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0); + if (c->ValueKnown(num_index_elements_dim)) { + DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0); + if (c->ValueKnown(num_values_elements_dim)) { + int64 num_index_elements = c->Value(num_index_elements_dim); + int64 num_values_elements = c->Value(num_values_elements_dim); + if (num_index_elements != num_values_elements) { + return errors::InvalidArgument("Number of elements in index (", + num_index_elements, ") and values (", + num_values_elements, ") do not match."); + } + } + } + + // Rank embedded in indices must match shape. + DimensionHandle index_rank_dim = c->Dim(indices_shape, 1); + if (c->ValueKnown(index_rank_dim)) { + DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0); + if (c->ValueKnown(shape_rank_dim)) { + int64 index_rank = c->Value(index_rank_dim); + int32 shape_rank = c->Value(shape_rank_dim); + if (index_rank != shape_rank) { + return errors::InvalidArgument("Index rank (", index_rank, + ") and shape rank (", shape_rank, + ") do not match."); + } + } + } + + return Status::OK(); +} + } // namespace shape_inference } // namespace tensorflow |