diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-07-10 08:26:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-10 08:31:01 -0700 |
commit | b6becb44dd1a4494aa7c3acbb6961d23900f614b (patch) | |
tree | 97b5ee3ca6be56d3cfe402644f7071544006079a /tensorflow/contrib/tensor_forest/ops | |
parent | 5b9bff4a04a331aa551ba5fc5bd713d9cb9dc684 (diff) |
Fix tensorforest for using sparse-only data.
PiperOrigin-RevId: 161396592
Diffstat (limited to 'tensorflow/contrib/tensor_forest/ops')
-rw-r--r-- | tensorflow/contrib/tensor_forest/ops/model_ops.cc | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensor_forest/ops/model_ops.cc b/tensorflow/contrib/tensor_forest/ops/model_ops.cc index c9acc4a6ae..168f079f52 100644 --- a/tensorflow/contrib/tensor_forest/ops/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/ops/model_ops.cc @@ -20,6 +20,7 @@ namespace tensorflow { using shape_inference::DimensionHandle; using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; namespace tensorforest { @@ -93,7 +94,8 @@ REGISTER_OP("TreePredictionsV4") .SetShapeFn([](InferenceContext* c) { DimensionHandle num_points = c->UnknownDim(); - if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0) { + if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0 && + c->Value(c->Dim(c->input(1), 0)) > 0) { num_points = c->Dim(c->input(1), 0); } |