aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/activations_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/activations_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/activations_test.cc43
1 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index b9a96e3f79..50a84edd47 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -383,6 +383,49 @@ TEST(FloatActivationsOpTest, LogSoftmax) {
})));
}
+class PReluOpModel : public SingleOpModel {
+ public:
+ PReluOpModel(const TensorData& input, const TensorData& alpha) {
+ input_ = AddInput(input);
+ alpha_ = AddInput(alpha);
+ output_ = AddOutput(input);
+ SetBuiltinOp(BuiltinOperator_PRELU, BuiltinOptions_NONE, 0);
+ BuildInterpreter({GetShape(input_), GetShape(alpha_)});
+ }
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ void SetAlpha(std::initializer_list<float> data) {
+ PopulateTensor(alpha_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int alpha_;
+ int output_;
+};
+
+TEST(FloatActivationsOpTest, PRelu) {
+ PReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}},
+ {TensorType_FLOAT32, {1, 1, 3}});
+
+ m.SetInput({
+ 0.0f, 0.0f, 0.0f, // Row 1, Column 1
+ 1.0f, 1.0f, 1.0f, // Row 1, Column 2
+ -1.0f, -1.0f, -1.0f, // Row 2, Column 1
+ -2.0f, -2.0f, -2.0f, // Row 1, Column 2
+ });
+ m.SetAlpha({0.0f, 1.0f, 2.0f});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0.0f, 0.0f, 0.0f, // Row 1, Column 1
+ 1.0f, 1.0f, 1.0f, // Row 1, Column 2
+ 0.0f, -1.0f, -2.0f, // Row 2, Column 1
+ 0.0f, -2.0f, -4.0f, // Row 1, Column 2
+ }));
+}
+
} // namespace
} // namespace tflite