diff options
author | Scott Tseng <scott.tseng@kikatech.com> | 2018-04-26 17:30:08 +0800 |
---|---|---|
committer | Scott Tseng <scott.tseng@kikatech.com> | 2018-04-26 17:32:26 +0800 |
commit | afa4748488c55ba5d9dbc9b6c1586132d28d1759 (patch) | |
tree | 1b809a86fcc3c0f5c96c177e5417b1b6e104ae0c | |
parent | 968addadfd4e4f5688eedc31f92a9066329ff6a7 (diff) |
Fix some issues in official tf.nn.topk() in lite
6 files changed, 19 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc index 807e84609f..ad9b744f1a 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2.cc @@ -25,8 +25,8 @@ namespace builtin { namespace topk_v2 { constexpr int kInputTensor = 0; constexpr int kInputTopK = 1; -constexpr int kOutputIndexes = 0; -constexpr int kOutputValues = 1; +constexpr int kOutputValues = 0; +constexpr int kOutputIndexes = 1; namespace { TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc index 29f2a057cd..212f8acc76 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc @@ -31,8 +31,8 @@ class TopKV2OpModel : public SingleOpModel { int top_k) { input_ = AddInput(input_type); top_k_ = AddInput(TensorType_INT32); - output_indexes_ = AddOutput(TensorType_INT32); output_values_ = AddOutput(input_type); + output_indexes_ = AddOutput(TensorType_INT32); SetBuiltinOp(BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options, 0); BuildInterpreter({input_shape, {1}}); PopulateTensor<int32_t>(top_k_, {top_k}); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 89ad58f887..c1cf79f626 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -124,6 +124,15 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, rand_op->dtype); break; } + case OperatorType::kTopK_V2: { + // topk(values: T, k: int32) -> values: T, indices: int32 + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 2); + CHECK(model->GetArray(op->inputs[1]).data_type == ArrayDataType::kInt32); + model->GetArray(op->outputs[0]).data_type = model->GetArray(op->inputs[0]).data_type; + model->GetArray(op->outputs[1]).data_type = ArrayDataType ::kInt32; + break; + } case OperatorType::kTensorFlowUnsupported: { auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op); // Some output tensors from the op could be eliminated by optimization. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index ba244cf5ef..a53fc54533 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1110,8 +1110,8 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) { void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) { const auto& input_values = model->GetArray(op->inputs[0]); const auto& input_k = model->GetArray(op->inputs[1]); - auto& output_indexes = model->GetArray(op->outputs[0]); - auto& output_values = model->GetArray(op->outputs[1]); + auto& output_values = model->GetArray(op->outputs[0]); + auto& output_indexes = model->GetArray(op->outputs[1]); // Bail if we already know the output shape. if (output_indexes.has_shape()) { diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 155d890c9f..6495ee7257 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1970,7 +1970,7 @@ void ConvertTopKV2Operator(const NodeDef& node, op->inputs.push_back(node.input(1)); } // The op has two outputs. - op->outputs.push_back(node.name() + ":0"); + op->outputs.push_back(node.name()); op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op.release()); } diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index cf2cbeedc7..5e1e470f25 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -825,11 +825,6 @@ void FixNoOrphanedArray(Model* model) { void CheckEachArray(const Model& model) { for (const auto& array_entry : model.GetArrayMap()) { const auto& array = array_entry.second; - if (array->has_shape()) { - for (int d : array->shape().dims()) { - CHECK_GE(d, 1); - } - } // It's OK to have a buffer or an alloc, but not both. // (Since allocs are for transient arrays without a buffer). CHECK(!array->buffer || !array->alloc); @@ -839,6 +834,10 @@ void CheckEachArray(const Model& model) { // The presence of a fixed buffer should imply the presence of a fixed // shape. CHECK(array->has_shape()); + // Constant buffer should has a valid shape. + for (int d : array->shape().dims()) { + CHECK_GE(d, 1); + } // The shape flat-size should agree with the buffer length. CHECK_EQ(array->buffer->Length(), RequiredBufferSizeForShape(array->shape())); |