diff options
Diffstat (limited to 'tensorflow/contrib/lite/delegates/eager/test_util.cc')
-rw-r--r-- | tensorflow/contrib/lite/delegates/eager/test_util.cc | 43 |
1 files changed, 23 insertions, 20 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc index b8c9e2652a..8584999ace 100644 --- a/tensorflow/contrib/lite/delegates/eager/test_util.cc +++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc @@ -25,19 +25,6 @@ namespace testing { bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; } -void EagerModelTest::SetValues(int tensor_index, - const std::vector<float>& values) { - float* v = interpreter_->typed_tensor<float>(tensor_index); - for (float f : values) { - *v++ = f; - } -} - -std::vector<float> EagerModelTest::GetValues(int tensor_index) { - TfLiteTensor* o = interpreter_->tensor(tensor_index); - return std::vector<float>(o->data.f, o->data.f + o->bytes / sizeof(float)); -} - void EagerModelTest::SetShape(int tensor_index, const std::vector<int>& values) { ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk); @@ -54,13 +41,21 @@ std::vector<int> EagerModelTest::GetShape(int tensor_index) { return result; } +TfLiteType EagerModelTest::GetType(int tensor_index) { + return interpreter_->tensor(tensor_index)->type; +} + void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs, const std::vector<int>& outputs, - const TfLiteType& type, - const std::vector<int>& dims) { + TfLiteType type, const std::vector<int>& dims) { interpreter_->AddTensors(num_tensors); for (int i = 0; i < num_tensors; ++i) { TfLiteQuantizationParams quant; + // Suppress explicit output type specification to ensure type inference + // works properly. + if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) { + type = kTfLiteFloat32; + } CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type, /*name=*/"", /*dims=*/dims, quant), @@ -101,18 +96,26 @@ void EagerModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs, return " attr{ key: '" + key + "' value {" + value + "}}"; }; + // Crude type attribution, will need fleshing out as more tests are added. + // TODO(b/113613439): Use nodedef string utilities to properly handle + // all types. + string type_attribute = attr("T", "type: DT_FLOAT"); + if (interpreter_->tensor(inputs[0])->type == kTfLiteInt32) { + type_attribute = attr("T", "type: DT_INT32"); + } + if (op == kUnpack) { - string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") + - attr("axis", "i: 0"); + string attributes = + type_attribute + attr("num", "i: 2") + attr("axis", "i: 0"); AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs); } else if (op == kIdentity) { - string attributes = attr("T", "type: DT_FLOAT"); + string attributes = type_attribute; AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs); } else if (op == kAdd) { - string attributes = attr("T", "type: DT_FLOAT"); + string attributes = type_attribute; AddTfOp("EagerAdd", "Add", attributes, inputs, outputs); } else if (op == kMul) { - string attributes = attr("T", "type: DT_FLOAT"); + string attributes = type_attribute; AddTfOp("EagerMul", "Mul", attributes, inputs, outputs); } else if (op == kNonExistent) { AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); |