aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/common_shape_fns.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-14 13:22:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-14 13:43:16 -0800
commit0c2018764516add343acc9ee56c487715bc98c2a (patch)
tree4a61a5235b6c339bedf9d97fc048570fd72cec22 /tensorflow/core/framework/common_shape_fns.cc
parent750c98508c4ea51bf45694d32773b075eb2c8c8d (diff)
Move ValidateSparseTensor to common_shape_fns.h.
Change: 139112567
Diffstat (limited to 'tensorflow/core/framework/common_shape_fns.cc')
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc41
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 2434127acc..794f9c37cf 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -860,5 +860,46 @@ Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
+ ShapeHandle values_shape, ShapeHandle shape_shape) {
+ // Validate ranks.
+ ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape));
+
+ // Number of elements in indices and values must match.
+ DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0);
+ if (c->ValueKnown(num_index_elements_dim)) {
+ DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0);
+ if (c->ValueKnown(num_values_elements_dim)) {
+ int64 num_index_elements = c->Value(num_index_elements_dim);
+ int64 num_values_elements = c->Value(num_values_elements_dim);
+ if (num_index_elements != num_values_elements) {
+ return errors::InvalidArgument("Number of elements in index (",
+ num_index_elements, ") and values (",
+ num_values_elements, ") do not match.");
+ }
+ }
+ }
+
+ // Rank embedded in indices must match shape.
+ DimensionHandle index_rank_dim = c->Dim(indices_shape, 1);
+ if (c->ValueKnown(index_rank_dim)) {
+ DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0);
+ if (c->ValueKnown(shape_rank_dim)) {
+ int64 index_rank = c->Value(index_rank_dim);
+ int32 shape_rank = c->Value(shape_rank_dim);
+ if (index_rank != shape_rank) {
+ return errors::InvalidArgument("Index rank (", index_rank,
+ ") and shape rank (", shape_rank,
+ ") do not match.");
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace shape_inference
} // namespace tensorflow