aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/elementwise_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-07 19:31:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 19:34:14 -0700
commit99e6a86480bfb518dea59b4b25f7c9549b227587 (patch)
treefa4176f76f672b86874135dcd6e8e560067aa65f /tensorflow/contrib/lite/kernels/elementwise_test.cc
parenta9ddfe50eee83b2f18293241ab96f0a1e2b4b05b (diff)
Implement Log operator.
PiperOrigin-RevId: 199735191
Diffstat (limited to 'tensorflow/contrib/lite/kernels/elementwise_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise_test.cc18
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc
index 412ffb04b9..10e88d5a31 100644
--- a/tensorflow/contrib/lite/kernels/elementwise_test.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc
@@ -24,12 +24,13 @@ namespace {
using ::testing::ElementsAreArray;
-class SinOpModel : public SingleOpModel {
+class ElementWiseOpModel : public SingleOpModel {
public:
- SinOpModel(std::initializer_list<int> input_shape) {
+ ElementWiseOpModel(BuiltinOperator op,
+ std::initializer_list<int> input_shape) {
input_ = AddInput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(BuiltinOperator_SIN, BuiltinOptions_NONE, 0);
+ SetBuiltinOp(op, BuiltinOptions_NONE, 0);
BuildInterpreter({input_shape});
}
@@ -42,7 +43,7 @@ class SinOpModel : public SingleOpModel {
};
TEST(ElementWise, Sin) {
- SinOpModel m({1, 1, 4, 1});
+ ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1});
m.PopulateTensor<float>(m.input(), {0, 3.1415926, -3.1415926, 1});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
@@ -50,6 +51,15 @@ TEST(ElementWise, Sin) {
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
+TEST(ElementWise, Log) {
+ ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {1, 3.1415926, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({0, 1.14473, 0, 0})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
} // namespace
} // namespace tflite