aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/model_ops.cc3
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/stats_ops.cc3
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc4
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_data.h6
-rw-r--r--tensorflow/contrib/tensor_forest/ops/model_ops.cc4
5 files changed, 14 insertions, 6 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
index 5b265a869d..0f92f05e2c 100644
--- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
@@ -165,9 +165,10 @@ class TreePredictionsV4Op : public OpKernel {
const Tensor& input_data = context->input(1);
const Tensor& sparse_input_indices = context->input(2);
const Tensor& sparse_input_values = context->input(3);
+ const Tensor& sparse_input_shape = context->input(4);
data_set_->set_input_tensors(input_data, sparse_input_indices,
- sparse_input_values);
+ sparse_input_values, sparse_input_shape);
DecisionTreeResource* decision_tree_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
index 1fd85fd81f..260e03df26 100644
--- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
@@ -258,11 +258,12 @@ class ProcessInputOp : public OpKernel {
const Tensor& input_data = context->input(2);
const Tensor& sparse_input_indices = context->input(3);
const Tensor& sparse_input_values = context->input(4);
+ const Tensor& sparse_input_shape = context->input(5);
const Tensor& input_labels = context->input(6);
const Tensor& input_weights = context->input(7);
data_set_->set_input_tensors(input_data, sparse_input_indices,
- sparse_input_values);
+ sparse_input_values, sparse_input_shape);
FertileStatsResource* fertile_stats_resource;
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
index f5f07bea5c..14cb19d36f 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
@@ -105,7 +105,8 @@ float TensorDataSet::GetExampleValue(int example, int32 feature_id) const {
void TensorDataSet::set_input_tensors(const Tensor& dense,
const Tensor& sparse_indices,
- const Tensor& sparse_values) {
+ const Tensor& sparse_values,
+ const Tensor& sparse_shape) {
if (dense.shape().dims() == 2) {
dense_data_.reset(new DenseStorageType(dense.tensor<float, 2>()));
}
@@ -114,6 +115,7 @@ void TensorDataSet::set_input_tensors(const Tensor& dense,
sparse_indices.tensor<int64, 2>()));
sparse_values_.reset(new SparseValuesStorageType(
sparse_values.tensor<float, 1>()));
+ sparse_batch_size_ = sparse_shape.tensor<int64, 1>()(0);
}
original_dense_tensor_ = dense;
}
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
index 261a1f2d5e..e3d4edbf8a 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
@@ -67,7 +67,8 @@ class TensorDataSet {
virtual ~TensorDataSet() {}
void set_input_tensors(const Tensor& dense, const Tensor& sparse_indices,
- const Tensor& sparse_values);
+ const Tensor& sparse_values,
+ const Tensor& sparse_shape);
float get_input_value(int offset, int col) {
return (*dense_data_)(offset, col);
@@ -77,7 +78,7 @@ class TensorDataSet {
if (dense_data_ != nullptr) {
return dense_data_->dimensions()[0];
} else if (sparse_indices_ != nullptr) {
- return sparse_indices_->dimensions()[0];
+ return sparse_batch_size_;
} else {
return 0;
}
@@ -109,6 +110,7 @@ class TensorDataSet {
std::unique_ptr<DenseStorageType> dense_data_;
std::unique_ptr<SparseIndicesStorageType> sparse_indices_;
std::unique_ptr<SparseValuesStorageType> sparse_values_;
+ int sparse_batch_size_;
Tensor original_dense_tensor_;
const tensorforest::TensorForestDataSpec input_spec_;
diff --git a/tensorflow/contrib/tensor_forest/ops/model_ops.cc b/tensorflow/contrib/tensor_forest/ops/model_ops.cc
index c9acc4a6ae..168f079f52 100644
--- a/tensorflow/contrib/tensor_forest/ops/model_ops.cc
+++ b/tensorflow/contrib/tensor_forest/ops/model_ops.cc
@@ -20,6 +20,7 @@
namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
namespace tensorforest {
@@ -93,7 +94,8 @@ REGISTER_OP("TreePredictionsV4")
.SetShapeFn([](InferenceContext* c) {
DimensionHandle num_points = c->UnknownDim();
- if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0) {
+ if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0 &&
+ c->Value(c->Dim(c->input(1), 0)) > 0) {
num_points = c->Dim(c->input(1), 0);
}