aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-31 16:05:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-31 16:14:33 -0700
commit0fffc1bde7f85c3c5b985bf30500c53ace6c81eb (patch)
tree23631c6082ea8b32c7e302bd2a79b96da510288a
parent2681adc66ceb0530ea2f0f50c11938fba32caa3e (diff)
Fixing Any operator.
PiperOrigin-RevId: 211159438
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc46
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/model.h1
4 files changed, 37 insertions, 20 deletions
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 6fdf47dedc..b52a79282c 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1701,9 +1701,11 @@ void ConvertReduceOperator(const Model& model, const T& src_op,
*new_op->add_input() = src_op.inputs[0];
*new_op->add_input() = src_op.inputs[1];
- const tensorflow::DataType params_type =
- GetTensorFlowDataType(model, src_op.inputs[0]);
- (*new_op->mutable_attr())["T"].set_type(params_type);
+ if (src_op.type != OperatorType::kAny) {
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+ }
const tensorflow::DataType indices_type =
GetTensorFlowDataType(model, src_op.inputs[1]);
(*new_op->mutable_attr())["Tidx"].set_type(indices_type);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 28effc2a67..c25be078ff 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -561,26 +561,38 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
const bool keep_dims = KeepDims(*op);
if (op->inputs.size() == 2) {
// There is a reduction_indices input.
- const auto& reduction_array = model->GetArray(op->inputs[1]);
- if (!reduction_array.buffer) {
+ const auto& reduction_indices_array = model->GetArray(op->inputs[1]);
+ if (!reduction_indices_array.buffer) {
return;
}
- CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
- const auto& reduction_array_vals =
- reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
- auto& output_dims = *output_array.mutable_shape()->mutable_dims();
- output_dims.clear();
- for (int i = 0; i < input_shape.dimensions_count(); i++) {
- bool is_reduction_dim = false;
- for (int r : reduction_array_vals) {
- if (i == r) {
- is_reduction_dim = true;
- }
+ CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32);
+
+ int input_rank = input_shape.dimensions_count();
+ std::set<int32> true_indices;
+ const auto& reduction_indices =
+ reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < reduction_indices.size(); ++i) {
+ const int32 reduction_index = reduction_indices[i];
+ if (reduction_index < -input_rank || reduction_index >= input_rank) {
+ CHECK(false) << "Invalid reduction dimension " << reduction_index
+ << " for input with " << input_rank << " dimensions";
+ }
+ int32 wrapped_index = reduction_index;
+ if (wrapped_index < 0) {
+ wrapped_index += input_rank;
}
- if (!is_reduction_dim) {
- output_dims.push_back(input_shape.dims(i));
- } else if (keep_dims) {
- output_dims.push_back(1);
+ true_indices.insert(wrapped_index);
+ }
+
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->clear();
+ for (int i = 0; i < input_rank; ++i) {
+ if (true_indices.count(i) > 0) {
+ if (keep_dims) {
+ mutable_dims->emplace_back(1);
+ }
+ } else {
+ mutable_dims->emplace_back(input_shape.dims(i));
}
}
} else {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
index 7d456af2fb..73198ac7c0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
@@ -52,6 +52,8 @@ bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) {
return ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op));
case OperatorType::kReduceMax:
return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
+ case OperatorType::kAny:
+ return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
default:
return false;
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index fa1c459f0e..2e100e37f6 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -1768,6 +1768,7 @@ struct PowOperator : Operator {
//
// Inputs:
// Inputs[0]: required: A boolean input tensor.
+// Inputs[1]: required: reduction_indices.
//
// TensorFlow equivalent: tf.reduce_any.
struct TensorFlowAnyOperator : Operator {