author    Scott Tseng <scott.tseng@kikatech.com>  2018-04-26 17:30:08 +0800
committer Scott Tseng <scott.tseng@kikatech.com>  2018-04-26 17:32:26 +0800
commit    afa4748488c55ba5d9dbc9b6c1586132d28d1759 (patch)
tree      1b809a86fcc3c0f5c96c177e5417b1b6e104ae0c
parent    968addadfd4e4f5688eedc31f92a9066329ff6a7 (diff)
Fix output ordering and type propagation issues in the official tf.nn.top_k() (TOPK_V2) support in TFLite
-rw-r--r--  tensorflow/contrib/lite/kernels/topk_v2.cc                                        4
-rw-r--r--  tensorflow/contrib/lite/kernels/topk_v2_test.cc                                   2
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc  9
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc       4
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc                                 2
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc                                      9
6 files changed, 19 insertions(+), 11 deletions(-)
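For context (not part of the commit itself): TensorFlow's tf.nn.top_k returns a (values, indices) pair with the values first, and this change makes the TFLite TOPK_V2 kernel and the toco converter use the same output order. Below is a minimal sketch of reading the outputs after this change, assuming a loaded TFLite model whose only op is TOPK_V2 with a float input; the function name, model setup, and types are assumptions for illustration.

// Sketch only, not code from the commit.
#include <cstdint>
#include <cstdio>

#include "tensorflow/contrib/lite/interpreter.h"

void PrintTopK(tflite::Interpreter* interpreter, int k) {
  // After this change, output 0 holds the top-k values and output 1 the
  // corresponding indices, matching tf.nn.top_k's (values, indices) order.
  const float* values = interpreter->typed_output_tensor<float>(0);
  const int32_t* indices = interpreter->typed_output_tensor<int32_t>(1);
  for (int i = 0; i < k; ++i) {
    printf("rank %d: value=%f index=%d\n", i, values[i], indices[i]);
  }
}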
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc
index 807e84609f..ad9b744f1a 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2.cc
@@ -25,8 +25,8 @@ namespace builtin {
namespace topk_v2 {
constexpr int kInputTensor = 0;
constexpr int kInputTopK = 1;
-constexpr int kOutputIndexes = 0;
-constexpr int kOutputValues = 1;
+constexpr int kOutputValues = 0;
+constexpr int kOutputIndexes = 1;
namespace {
TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
index 29f2a057cd..212f8acc76 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
@@ -31,8 +31,8 @@ class TopKV2OpModel : public SingleOpModel {
int top_k) {
input_ = AddInput(input_type);
top_k_ = AddInput(TensorType_INT32);
- output_indexes_ = AddOutput(TensorType_INT32);
output_values_ = AddOutput(input_type);
+ output_indexes_ = AddOutput(TensorType_INT32);
SetBuiltinOp(BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options, 0);
BuildInterpreter({input_shape, {1}});
PopulateTensor<int32_t>(top_k_, {top_k});
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 89ad58f887..c1cf79f626 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -124,6 +124,15 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
SetDataTypeForAllOutputs(model, op, rand_op->dtype);
break;
}
+ case OperatorType::kTopK_V2: {
+ // topk(values: T, k: int32) -> values: T, indices: int32
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 2);
+ CHECK(model->GetArray(op->inputs[1]).data_type == ArrayDataType::kInt32);
+ model->GetArray(op->outputs[0]).data_type = model->GetArray(op->inputs[0]).data_type;
+ model->GetArray(op->outputs[1]).data_type = ArrayDataType::kInt32;
+ break;
+ }
case OperatorType::kTensorFlowUnsupported: {
auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
// Some output tensors from the op could be eliminated by optimization.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index ba244cf5ef..a53fc54533 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1110,8 +1110,8 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) {
void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
const auto& input_values = model->GetArray(op->inputs[0]);
const auto& input_k = model->GetArray(op->inputs[1]);
- auto& output_indexes = model->GetArray(op->outputs[0]);
- auto& output_values = model->GetArray(op->outputs[1]);
+ auto& output_values = model->GetArray(op->outputs[0]);
+ auto& output_indexes = model->GetArray(op->outputs[1]);
// Bail if we already know the output shape.
if (output_indexes.has_shape()) {
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 155d890c9f..6495ee7257 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1970,7 +1970,7 @@ void ConvertTopKV2Operator(const NodeDef& node,
op->inputs.push_back(node.input(1));
}
// The op has two outputs.
- op->outputs.push_back(node.name() + ":0");
+ op->outputs.push_back(node.name());
op->outputs.push_back(node.name() + ":1");
model->operators.emplace_back(op.release());
}
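Side note on the naming change above (an illustration, not code from the commit): in a TensorFlow GraphDef, the first output of a node may be referenced by the bare node name, which is equivalent to "<name>:0", while later outputs always carry an explicit ":<index>" suffix. A hypothetical helper following the same convention:

#include <string>

// Hypothetical helper, not part of toco: output 0 uses the bare node name,
// later outputs get an explicit ":<index>" suffix.
std::string TopKOutputName(const std::string& node_name, int output_index) {
  return output_index == 0 ? node_name
                           : node_name + ":" + std::to_string(output_index);
}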
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index cf2cbeedc7..5e1e470f25 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -825,11 +825,6 @@ void FixNoOrphanedArray(Model* model) {
void CheckEachArray(const Model& model) {
for (const auto& array_entry : model.GetArrayMap()) {
const auto& array = array_entry.second;
- if (array->has_shape()) {
- for (int d : array->shape().dims()) {
- CHECK_GE(d, 1);
- }
- }
// It's OK to have a buffer or an alloc, but not both.
// (Since allocs are for transient arrays without a buffer).
CHECK(!array->buffer || !array->alloc);
@@ -839,6 +834,10 @@ void CheckEachArray(const Model& model) {
// The presence of a fixed buffer should imply the presence of a fixed
// shape.
CHECK(array->has_shape());
+ // A constant buffer should have a valid shape.
+ for (int d : array->shape().dims()) {
+ CHECK_GE(d, 1);
+ }
// The shape flat-size should agree with the buffer length.
CHECK_EQ(array->buffer->Length(),
RequiredBufferSizeForShape(array->shape()));
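To restate the relocated check (illustrative only, not toco code): after this change, only arrays backed by a constant buffer must have every dimension >= 1, and their buffer length must still match the shape's flat size; shaped arrays without a constant buffer are no longer subject to the dimension check. A self-contained sketch of that invariant:

#include <vector>

// Illustrative restatement of the invariant enforced for constant-buffer
// arrays: all dimensions are at least 1 and the flat size matches the
// buffer length.
bool ConstantBufferShapeIsValid(const std::vector<int>& dims,
                                int buffer_length) {
  int flat_size = 1;
  for (int d : dims) {
    if (d < 1) return false;  // every dimension must be at least 1
    flat_size *= d;
  }
  return flat_size == buffer_length;
}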