From b6becb44dd1a4494aa7c3acbb6961d23900f614b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jul 2017 08:26:10 -0700 Subject: Fix tensorforest for using sparse-only data. PiperOrigin-RevId: 161396592 --- tensorflow/contrib/tensor_forest/ops/model_ops.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tensorflow/contrib/tensor_forest/ops') diff --git a/tensorflow/contrib/tensor_forest/ops/model_ops.cc b/tensorflow/contrib/tensor_forest/ops/model_ops.cc index c9acc4a6ae..168f079f52 100644 --- a/tensorflow/contrib/tensor_forest/ops/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/ops/model_ops.cc @@ -20,6 +20,7 @@ namespace tensorflow { using shape_inference::DimensionHandle; using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; namespace tensorforest { @@ -93,7 +94,8 @@ REGISTER_OP("TreePredictionsV4") .SetShapeFn([](InferenceContext* c) { DimensionHandle num_points = c->UnknownDim(); - if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0) { + if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0 && + c->Value(c->Dim(c->input(1), 0)) > 0) { num_points = c->Dim(c->input(1), 0); } -- cgit v1.2.3