diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/activations_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/activations_test.cc | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index b9a96e3f79..50a84edd47 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -383,6 +383,49 @@ TEST(FloatActivationsOpTest, LogSoftmax) { }))); } +class PReluOpModel : public SingleOpModel { + public: + PReluOpModel(const TensorData& input, const TensorData& alpha) { + input_ = AddInput(input); + alpha_ = AddInput(alpha); + output_ = AddOutput(input); + SetBuiltinOp(BuiltinOperator_PRELU, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_), GetShape(alpha_)}); + } + void SetInput(std::initializer_list<float> data) { + PopulateTensor(input_, data); + } + void SetAlpha(std::initializer_list<float> data) { + PopulateTensor(alpha_, data); + } + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + + protected: + int input_; + int alpha_; + int output_; +}; + +TEST(FloatActivationsOpTest, PRelu) { + PReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}}, + {TensorType_FLOAT32, {1, 1, 3}}); + + m.SetInput({ + 0.0f, 0.0f, 0.0f, // Row 1, Column 1 + 1.0f, 1.0f, 1.0f, // Row 1, Column 2 + -1.0f, -1.0f, -1.0f, // Row 2, Column 1 + -2.0f, -2.0f, -2.0f, // Row 1, Column 2 + }); + m.SetAlpha({0.0f, 1.0f, 2.0f}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0f, 0.0f, 0.0f, // Row 1, Column 1 + 1.0f, 1.0f, 1.0f, // Row 1, Column 2 + 0.0f, -1.0f, -2.0f, // Row 2, Column 1 + 0.0f, -2.0f, -4.0f, // Row 1, Column 2 + })); +} + } // namespace } // namespace tflite |