aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/delegates/eager/test_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/delegates/eager/test_util.cc')
-rw-r--r--tensorflow/contrib/lite/delegates/eager/test_util.cc43
1 files changed, 23 insertions, 20 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc
index b8c9e2652a..8584999ace 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc
@@ -25,19 +25,6 @@ namespace testing {
bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
-void EagerModelTest::SetValues(int tensor_index,
- const std::vector<float>& values) {
- float* v = interpreter_->typed_tensor<float>(tensor_index);
- for (float f : values) {
- *v++ = f;
- }
-}
-
-std::vector<float> EagerModelTest::GetValues(int tensor_index) {
- TfLiteTensor* o = interpreter_->tensor(tensor_index);
- return std::vector<float>(o->data.f, o->data.f + o->bytes / sizeof(float));
-}
-
void EagerModelTest::SetShape(int tensor_index,
const std::vector<int>& values) {
ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
@@ -54,13 +41,21 @@ std::vector<int> EagerModelTest::GetShape(int tensor_index) {
return result;
}
+TfLiteType EagerModelTest::GetType(int tensor_index) {
+ return interpreter_->tensor(tensor_index)->type;
+}
+
void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
const std::vector<int>& outputs,
- const TfLiteType& type,
- const std::vector<int>& dims) {
+ TfLiteType type, const std::vector<int>& dims) {
interpreter_->AddTensors(num_tensors);
for (int i = 0; i < num_tensors; ++i) {
TfLiteQuantizationParams quant;
+ // Suppress explicit output type specification to ensure type inference
+ // works properly.
+ if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) {
+ type = kTfLiteFloat32;
+ }
CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
/*name=*/"",
/*dims=*/dims, quant),
@@ -101,18 +96,26 @@ void EagerModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
return " attr{ key: '" + key + "' value {" + value + "}}";
};
+ // Crude type attribution, will need fleshing out as more tests are added.
+ // TODO(b/113613439): Use nodedef string utilities to properly handle
+ // all types.
+ string type_attribute = attr("T", "type: DT_FLOAT");
+ if (interpreter_->tensor(inputs[0])->type == kTfLiteInt32) {
+ type_attribute = attr("T", "type: DT_INT32");
+ }
+
if (op == kUnpack) {
- string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") +
- attr("axis", "i: 0");
+ string attributes =
+ type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs);
} else if (op == kIdentity) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs);
} else if (op == kAdd) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerAdd", "Add", attributes, inputs, outputs);
} else if (op == kMul) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerMul", "Mul", attributes, inputs, outputs);
} else if (op == kNonExistent) {
AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);