path: root/tensorflow/contrib
author    Yu-Cheng Ling <ycling@google.com>    2018-03-22 11:25:49 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-03-22 11:29:59 -0700
commit    7c4cdb8bae0e8760ebe4793d49ea5aee68768655 (patch)
tree      d3adb4214eecc995845adf5d4f32331b60b8313a /tensorflow/contrib
parent    cfdd61585769188789280e768fc43fdbba799619 (diff)
Supports PReLU in TFLite & Toco.
PiperOrigin-RevId: 190097557
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/lite/builtin_ops.h | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/activations.cc | 64
-rw-r--r--  tensorflow/contrib/lite/kernels/activations_test.cc | 43
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc | 2
-rw-r--r--  tensorflow/contrib/lite/model.cc | 1
-rw-r--r--  tensorflow/contrib/lite/nnapi_delegate.cc | 1
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs | 1
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h | 9
-rw-r--r--  tensorflow/contrib/lite/testing/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py | 49
-rw-r--r--  tensorflow/contrib/lite/testing/generated_examples_zip_test.cc | 4
-rw-r--r--  tensorflow/contrib/lite/toco/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h | 1
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc | 119
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 1
-rw-r--r--  tensorflow/contrib/lite/toco/model.h | 13
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc | 2
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc | 1
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc | 1
19 files changed, 312 insertions, 3 deletions
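
For reference, the PReLU activation introduced here computes f(x) = x for x >= 0 and
f(x) = alpha * x for x < 0, with alpha shared per channel. Below is a minimal NumPy
sketch of those semantics under the 4D-input / (1, 1, channels)-alpha restriction that
PreluPrepare enforces; the function name is illustrative and the sample values are
taken from the kernel test further down.

    import numpy as np

    def prelu_reference(x, alpha):
        # f(x) = x for x >= 0, alpha * x for x < 0; alpha broadcasts over N, H, W.
        assert x.ndim == 4 and alpha.shape == (1, 1, x.shape[3])
        return np.where(x >= 0, x, alpha * x)

    x = np.array([[[[0., 0., 0.], [1., 1., 1.]],
                   [[-1., -1., -1.], [-2., -2., -2.]]]])  # shape (1, 2, 2, 3)
    alpha = np.array([[[0., 1., 2.]]])                    # shape (1, 1, 3)
    print(prelu_reference(x, alpha))  # negative rows map to [0, -1, -2] and [0, -2, -4]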
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index e4652a3e70..d7993e60cc 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -78,6 +78,7 @@ typedef enum {
kTfLiteBuiltinDelegate = 51,
kTfLiteBuiltinBidirectionalSequenceLstm = 52,
kTfLiteBuiltinCast = 53,
+ kTfLiteBuiltinPrelu = 54,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 093761c43c..39a54c9396 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -150,6 +150,34 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayCopy(input->dims));
}
+TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ TfLiteTensor* alpha = GetInput(context, node, 1);
+
+ output->type = input->type;
+
+ // Currently only Float32 is supported
+ // TODO(ycling): Support other data types.
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, alpha->type, kTfLiteFloat32);
+
+  // Currently, only a 4D `input` and a 3D `alpha` with shape
+  // (1, 1, channels) are supported.
+ // TODO(impjdi): Support other cases where `alpha` is broadcastable
+ // to `input`.
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+ TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3);
+ TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1);
+ TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1);
+ TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], input->dims->data[3]);
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
@@ -388,6 +416,35 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
}
}
+TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* alpha = GetInput(context, node, 1);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+
+ if (input->type != kTfLiteFloat32) {
+ context->ReportError(context, "Only float32 supported currently.");
+ return kTfLiteError;
+ }
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 4);
+ const int batches = input->dims->data[0];
+ const int height = input->dims->data[1];
+ const int width = input->dims->data[2];
+ const int channels = input->dims->data[3];
+
+ TF_LITE_ENSURE_EQ(context, alpha->dims->size, 3);
+ TF_LITE_ENSURE_EQ(context, alpha->dims->data[0], 1);
+ TF_LITE_ENSURE_EQ(context, alpha->dims->data[1], 1);
+ TF_LITE_ENSURE_EQ(context, alpha->dims->data[2], channels);
+
+ const int n = batches * height * width * channels;
+ for (int i = 0; i < n; ++i) {
+ const float x = input->data.f[i];
+ output->data.f[i] = x >= 0.0f ? x : alpha->data.f[i % channels] * x;
+ }
+
+ return kTfLiteOk;
+}
+
} // namespace activations
TfLiteRegistration* Register_RELU() {
@@ -439,6 +496,13 @@ TfLiteRegistration* Register_LOG_SOFTMAX() {
return &r;
}
+TfLiteRegistration* Register_PRELU() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ activations::PreluPrepare,
+ activations::PreluEval};
+ return &r;
+}
+
} // namespace builtin
} // namespace ops
} // namespace tflite
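
The eval loop above walks the flattened NHWC buffer, so element i falls in channel
i % channels and the alpha entry repeats with that period. A line-for-line sketch of
the same indexing (hypothetical helper, assuming the float32 restriction checked in
PreluPrepare):

    def prelu_flat(input_flat, alpha, channels):
        # Mirrors PreluEval: alpha is indexed with i % channels over the flat buffer.
        return [x if x >= 0.0 else alpha[i % channels] * x
                for i, x in enumerate(input_flat)]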
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index b9a96e3f79..50a84edd47 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -383,6 +383,49 @@ TEST(FloatActivationsOpTest, LogSoftmax) {
})));
}
+class PReluOpModel : public SingleOpModel {
+ public:
+ PReluOpModel(const TensorData& input, const TensorData& alpha) {
+ input_ = AddInput(input);
+ alpha_ = AddInput(alpha);
+ output_ = AddOutput(input);
+ SetBuiltinOp(BuiltinOperator_PRELU, BuiltinOptions_NONE, 0);
+ BuildInterpreter({GetShape(input_), GetShape(alpha_)});
+ }
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ void SetAlpha(std::initializer_list<float> data) {
+ PopulateTensor(alpha_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int alpha_;
+ int output_;
+};
+
+TEST(FloatActivationsOpTest, PRelu) {
+ PReluOpModel m({TensorType_FLOAT32, {1, 2, 2, 3}},
+ {TensorType_FLOAT32, {1, 1, 3}});
+
+ m.SetInput({
+ 0.0f, 0.0f, 0.0f, // Row 1, Column 1
+ 1.0f, 1.0f, 1.0f, // Row 1, Column 2
+ -1.0f, -1.0f, -1.0f, // Row 2, Column 1
+ -2.0f, -2.0f, -2.0f, // Row 2, Column 2
+ });
+ m.SetAlpha({0.0f, 1.0f, 2.0f});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0.0f, 0.0f, 0.0f, // Row 1, Column 1
+ 1.0f, 1.0f, 1.0f, // Row 1, Column 2
+ 0.0f, -1.0f, -2.0f, // Row 2, Column 1
+ 0.0f, -2.0f, -4.0f, // Row 2, Column 2
+ }));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 369d3b9886..62045f0a4d 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -75,6 +75,7 @@ TfLiteRegistration* Register_TOPK_V2();
TfLiteRegistration* Register_LOG_SOFTMAX();
TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_DEQUANTIZE();
+TfLiteRegistration* Register_PRELU();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -131,6 +132,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE());
+ AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 9c619f88e0..b7ccdf070b 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -309,6 +309,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_LOG_SOFTMAX:
case BuiltinOperator_CAST:
case BuiltinOperator_DEQUANTIZE:
+ case BuiltinOperator_PRELU:
break;
case BuiltinOperator_LSH_PROJECTION: {
TfLiteLSHProjectionParams* params =
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 9d00d965d3..e31b7c03a5 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -349,6 +349,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_DEQUANTIZE:
case tflite::BuiltinOperator_DELEGATE:
case tflite::BuiltinOperator_CAST:
+ case tflite::BuiltinOperator_PRELU:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 04387fed33..e1075971e9 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -130,6 +130,7 @@ enum BuiltinOperator : byte {
DELEGATE = 51,
BIDIRECTIONAL_SEQUENCE_LSTM = 52,
CAST = 53,
+ PRELU = 54,
}
// Options for the builtin operators.
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index b922de2081..86daeaf5cc 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -254,11 +254,12 @@ enum BuiltinOperator {
BuiltinOperator_DELEGATE = 51,
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52,
BuiltinOperator_CAST = 53,
+ BuiltinOperator_PRELU = 54,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_CAST
+ BuiltinOperator_MAX = BuiltinOperator_PRELU
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[52] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[53] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -311,7 +312,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[52] {
BuiltinOperator_LOG_SOFTMAX,
BuiltinOperator_DELEGATE,
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
- BuiltinOperator_CAST
+ BuiltinOperator_CAST,
+ BuiltinOperator_PRELU
};
return values;
}
@@ -372,6 +374,7 @@ inline const char **EnumNamesBuiltinOperator() {
"DELEGATE",
"BIDIRECTIONAL_SEQUENCE_LSTM",
"CAST",
+ "PRELU",
nullptr
};
return names;
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index f1b18ad30f..555ea90034 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -39,6 +39,7 @@ gen_zipped_test_files(
"mean.zip",
"mul.zip",
"pad.zip",
+ "prelu.zip",
"relu.zip",
"relu1.zip",
"relu6.zip",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 420bdb41f1..38de9dcf2c 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -617,6 +617,54 @@ def make_relu6_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_prelu_tests(zip_path):
+ """Make a set of tests to do PReLU."""
+
+ test_parameters = [{
+ # The canonical case for image processing is having a 4D `input` (NHWC)
+ # and `shared_axes`=[1, 2], so the alpha parameter is per channel.
+ "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]],
+ "shared_axes": [[1, 2], [1]],
+ }]
+
+ def build_graph(parameters):
+ """Build the graph for the test case."""
+
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"])
+ out = prelu(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ """Build the inputs for the test case."""
+
+ input_shape = parameters["input_shape"]
+ input_values = create_tensor_data(
+ np.float32, input_shape, min_value=-10, max_value=10)
+ shared_axes = parameters["shared_axes"]
+
+ alpha_shape = []
+ for dim in range(1, len(input_shape)):
+ alpha_shape.append(1 if dim in shared_axes else input_shape[dim])
+
+ alpha_values = create_tensor_data(np.float32, alpha_shape)
+
+ with tf.variable_scope("", reuse=True):
+ alpha = tf.get_variable("p_re_lu/alpha")
+ sess.run(alpha.assign(alpha_values))
+
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(
+ zip_path,
+ test_parameters,
+ build_graph,
+ build_inputs,
+ use_frozen_graph=True)
+
+
# This function tests various TensorFlow functions that generate Const ops,
# including `tf.ones`, `tf.zeros` and random functions.
def make_constant_tests(zip_path):
@@ -1911,6 +1959,7 @@ def main(unused_args):
"relu.zip": make_relu_tests,
"relu1.zip": make_relu1_tests,
"relu6.zip": make_relu6_tests,
+ "prelu.zip": make_prelu_tests,
"l2_pool.zip": make_pool_tests(make_l2_pool),
"avg_pool.zip": make_pool_tests(tf.nn.avg_pool),
"max_pool.zip": make_pool_tests(tf.nn.max_pool),
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 5e76e7c510..ba2d259462 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -88,6 +88,9 @@ std::map<string, string> kBrokenTests = {
// Transpose only supports 1D-4D input tensors.
{R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
+
+ // For now, PRelu only supports a 4D input with a (1, 1, channels) 3D alpha.
+ {R"(^\/prelu.*shared_axes=\[1\])", "75975192"},
};
// Allows test data to be unzipped into a temporary directory and makes
@@ -253,6 +256,7 @@ INSTANTIATE_TESTS(mul)
INSTANTIATE_TESTS(pad)
INSTANTIATE_TESTS(relu)
INSTANTIATE_TESTS(relu1)
+INSTANTIATE_TESTS(prelu)
INSTANTIATE_TESTS(relu6)
INSTANTIATE_TESTS(reshape)
INSTANTIATE_TESTS(resize_bilinear)
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 395abc5326..486ff1edcd 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -193,6 +193,7 @@ cc_library(
"graph_transformations/identify_lstm.cc",
"graph_transformations/identify_lstm_merge_inputs.cc",
"graph_transformations/identify_lstm_split_inputs.cc",
+ "graph_transformations/identify_prelu.cc",
"graph_transformations/identify_relu1.cc",
"graph_transformations/lstm_utils.cc",
"graph_transformations/make_initial_dequantize_operator.cc",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 11e5e19f50..640afc7c74 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -129,6 +129,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs)
DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
new file mode 100644
index 0000000000..30be4ac0aa
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
@@ -0,0 +1,119 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+// This transformation rule tries to identify the PRelu structure generated by
+// Keras, and convert it to a single op.
+//
+// The formula of PReLU is:
+// f(x) = alpha * x for x < 0, f(x) = x for x >= 0.
+//
+// `x` is the input, and `alpha` is a trainable tensor which can be broadcasted
+// to the shape of `x`.
+//
+// There's no native PRelu op in TensorFlow, so Keras generates the following
+// structure which does the equivalent calculation:
+// f(x) = Relu(x) + (-alpha * Relu(-x))
+//
+// Practically, alpha is always a constant in the inference graph, and Toco has
+// other graph transformations that fold the activation functions into other ops.
+// Therefore, we're looking for the structure:
+//
+// f(x) = Relu(x) + (negative_alpha * Neg(x, activation=Relu))
+
+namespace toco {
+
+bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
+ const auto add_op_it = model->operators.begin() + op_index;
+ const auto* add_op = add_op_it->get();
+ if (add_op == nullptr || add_op->type != OperatorType::kAdd ||
+ add_op->inputs.size() != 2 ||
+ add_op->fused_activation_function != FusedActivationFunctionType::kNone) {
+ return false;
+ }
+
+ const auto* relu_input_op = GetOpWithOutput(*model, add_op->inputs[0]);
+ if (relu_input_op == nullptr || relu_input_op->type != OperatorType::kRelu ||
+ relu_input_op->inputs.size() != 1 ||
+ relu_input_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ return false;
+ }
+
+ // TODO(ycling): Both Add and Mul are commutative. Support the case where
+ // the positions of the operands are swapped.
+ const auto* mul_op = GetOpWithOutput(*model, add_op->inputs[1]);
+ if (mul_op == nullptr || mul_op->type != OperatorType::kMul ||
+ mul_op->inputs.size() != 2 ||
+ mul_op->fused_activation_function != FusedActivationFunctionType::kNone) {
+ return false;
+ }
+
+ const auto neg_alpha_tensor_name = mul_op->inputs[0];
+
+ const auto* relu_neg_input_op = GetOpWithOutput(*model, mul_op->inputs[1]);
+
+ if (relu_neg_input_op == nullptr ||
+ relu_neg_input_op->type != OperatorType::kNeg ||
+ relu_neg_input_op->fused_activation_function !=
+ FusedActivationFunctionType::kRelu ||
+ relu_neg_input_op->inputs.size() != 1) {
+ return false;
+ }
+
+ if (relu_input_op->inputs[0] != relu_neg_input_op->inputs[0]) {
+ return false;
+ }
+
+ const auto input_tensor_name = relu_input_op->inputs[0];
+ const auto output_tensor_name = add_op->outputs[0];
+
+ // Construct a tensor for positive alpha (double negative).
+ const auto alpha_tensor_name =
+ AvailableArrayName(*model, neg_alpha_tensor_name + "_neg");
+ model->GetOrCreateArray(alpha_tensor_name);
+
+ auto* neg_neg_alpha_op = new NegOperator;
+ neg_neg_alpha_op->inputs = {neg_alpha_tensor_name};
+ neg_neg_alpha_op->outputs = {alpha_tensor_name};
+ model->operators.emplace(add_op_it, neg_neg_alpha_op);
+
+ auto* prelu_op = new PReluOperator;
+ prelu_op->inputs = {input_tensor_name, alpha_tensor_name};
+ prelu_op->outputs = {output_tensor_name};
+ model->operators.emplace(add_op_it, prelu_op);
+ AddMessageF("Creating %s replacing equivalent subgraph", LogName(*prelu_op));
+
+ DeleteArrayIfUsedOnce(neg_alpha_tensor_name, model);
+ DeleteArrayIfUsedOnce(add_op->inputs[0], model);
+ DeleteArrayIfUsedOnce(add_op->inputs[1], model);
+ DeleteArrayIfUsedOnce(mul_op->inputs[1], model);
+ // Remove the existing Add op that outputs the final result. If the other
+ // intermediate tensors aren't used by other ops, those will be removed by
+ // other graph transformation rules.
+ model->operators.erase(FindOp(*model, add_op));
+
+ return true;
+}
+
+} // namespace toco
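
The rewrite above relies on the identity Relu(x) + neg_alpha * Relu(-x) = PReLU(x) with
neg_alpha = -alpha: for x < 0, Relu(-x) = -x, so neg_alpha * (-x) = alpha * x, and for
x >= 0 only the Relu(x) term survives. A small numerical spot-check (illustrative values):

    import numpy as np

    relu = lambda v: np.maximum(v, 0.0)
    x = np.linspace(-3.0, 3.0, 7)
    alpha = 0.25
    keras_pattern = relu(x) + (-alpha) * relu(-x)  # the structure IdentifyPRelu matches
    prelu = np.where(x >= 0, x, alpha * x)         # the single op it is replaced with
    assert np.allclose(keras_pattern, prelu)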
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 375848a7d4..676736cfc5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1467,6 +1467,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kRelu:
case OperatorType::kRelu1:
case OperatorType::kRelu6:
+ case OperatorType::kPRelu:
case OperatorType::kSoftmax:
case OperatorType::kLogSoftmax:
case OperatorType::kLogistic:
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 3fa0089cba..5199e292e1 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -65,6 +65,7 @@ enum class OperatorType {
kRelu,
kRelu1,
kRelu6,
+ kPRelu,
kSoftmax,
kLogSoftmax,
kSub,
@@ -566,6 +567,18 @@ struct Relu6Operator : Operator {
Relu6Operator() : Operator(OperatorType::kRelu6) {}
};
+// PRelu
+// f(x) = alpha * x for x < 0, f(x) = x for x >= 0.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: the alpha array
+//
+// Equivalent to keras.layers.PReLU.
+struct PReluOperator : Operator {
+ PReluOperator() : Operator(OperatorType::kPRelu) {}
+};
+
// Element-wise Logistic operator:
// x -> Logistic(x) = 1 / (1 + exp(-x))
//
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index f2cc4ef71f..f23249cfa1 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -854,6 +854,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1));
ops.emplace_back(
new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
+ ops.emplace_back(
+ new SimpleOperator<PReluOperator>("PRELU", OperatorType::kPRelu));
ops.emplace_back(new SimpleOperator<LogisticOperator>(
"LOGISTIC", OperatorType::kLogistic));
ops.emplace_back(
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index ca66110ba3..30dd6fab9e 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -94,6 +94,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new IdentifyL2Normalization);
transformations->Add(new IdentifyL2Pool);
transformations->Add(new IdentifyRelu1);
+ transformations->Add(new IdentifyPRelu);
transformations->Add(new RemoveTrivialBinaryOperator);
transformations->Add(new ReadFakeQuantMinMax);
transformations->Add(new ResolveSpaceToBatchNDAttributes);
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 2362206a14..ec1770c129 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -300,6 +300,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Relu)
HANDLE_OPERATORTYPENAME_CASE(Relu1)
HANDLE_OPERATORTYPENAME_CASE(Relu6)
+ HANDLE_OPERATORTYPENAME_CASE(PRelu)
HANDLE_OPERATORTYPENAME_CASE(ReorderAxes)
HANDLE_OPERATORTYPENAME_CASE(Softmax)
HANDLE_OPERATORTYPENAME_CASE(LogSoftmax)