diff options
5 files changed, 14 insertions, 6 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc index 5b265a869d..0f92f05e2c 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -165,9 +165,10 @@ class TreePredictionsV4Op : public OpKernel { const Tensor& input_data = context->input(1); const Tensor& sparse_input_indices = context->input(2); const Tensor& sparse_input_values = context->input(3); + const Tensor& sparse_input_shape = context->input(4); data_set_->set_input_tensors(input_data, sparse_input_indices, - sparse_input_values); + sparse_input_values, sparse_input_shape); DecisionTreeResource* decision_tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index 1fd85fd81f..260e03df26 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -258,11 +258,12 @@ class ProcessInputOp : public OpKernel { const Tensor& input_data = context->input(2); const Tensor& sparse_input_indices = context->input(3); const Tensor& sparse_input_values = context->input(4); + const Tensor& sparse_input_shape = context->input(5); const Tensor& input_labels = context->input(6); const Tensor& input_weights = context->input(7); data_set_->set_input_tensors(input_data, sparse_input_indices, - sparse_input_values); + sparse_input_values, sparse_input_shape); FertileStatsResource* fertile_stats_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc index f5f07bea5c..14cb19d36f 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc @@ -105,7 +105,8 @@ float TensorDataSet::GetExampleValue(int example, int32 feature_id) const { void TensorDataSet::set_input_tensors(const Tensor& dense, const Tensor& sparse_indices, - const Tensor& sparse_values) { + const Tensor& sparse_values, + const Tensor& sparse_shape) { if (dense.shape().dims() == 2) { dense_data_.reset(new DenseStorageType(dense.tensor<float, 2>())); } @@ -114,6 +115,7 @@ void TensorDataSet::set_input_tensors(const Tensor& dense, sparse_indices.tensor<int64, 2>())); sparse_values_.reset(new SparseValuesStorageType( sparse_values.tensor<float, 1>())); + sparse_batch_size_ = sparse_shape.tensor<int64, 1>()(0); } original_dense_tensor_ = dense; } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h index 261a1f2d5e..e3d4edbf8a 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h @@ -67,7 +67,8 @@ class TensorDataSet { virtual ~TensorDataSet() {} void set_input_tensors(const Tensor& dense, const Tensor& sparse_indices, - const Tensor& sparse_values); + const Tensor& sparse_values, + const Tensor& sparse_shape); float get_input_value(int offset, int col) { return (*dense_data_)(offset, col); @@ -77,7 +78,7 @@ class TensorDataSet { if (dense_data_ != nullptr) { return dense_data_->dimensions()[0]; } else if (sparse_indices_ != nullptr) { - return sparse_indices_->dimensions()[0]; + return sparse_batch_size_; } else { return 0; } @@ -109,6 +110,7 @@ class TensorDataSet { std::unique_ptr<DenseStorageType> dense_data_; std::unique_ptr<SparseIndicesStorageType> sparse_indices_; std::unique_ptr<SparseValuesStorageType> sparse_values_; + int sparse_batch_size_; Tensor original_dense_tensor_; const tensorforest::TensorForestDataSpec input_spec_; diff --git a/tensorflow/contrib/tensor_forest/ops/model_ops.cc b/tensorflow/contrib/tensor_forest/ops/model_ops.cc index c9acc4a6ae..168f079f52 100644 --- a/tensorflow/contrib/tensor_forest/ops/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/ops/model_ops.cc @@ -20,6 +20,7 @@ namespace tensorflow { using shape_inference::DimensionHandle; using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; namespace tensorforest { @@ -93,7 +94,8 @@ REGISTER_OP("TreePredictionsV4") .SetShapeFn([](InferenceContext* c) { DimensionHandle num_points = c->UnknownDim(); - if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0) { + if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0 && + c->Value(c->Dim(c->input(1), 0)) > 0) { num_points = c->Dim(c->input(1), 0); } |