aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-03-22 11:25:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-22 11:29:59 -0700
commit7c4cdb8bae0e8760ebe4793d49ea5aee68768655 (patch)
treed3adb4214eecc995845adf5d4f32331b60b8313a /tensorflow/contrib/lite/testing
parentcfdd61585769188789280e768fc43fdbba799619 (diff)
Supports PReLU in TFLite & Toco.
PiperOrigin-RevId: 190097557
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r--tensorflow/contrib/lite/testing/BUILD1
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py49
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc4
3 files changed, 54 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index f1b18ad30f..555ea90034 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -39,6 +39,7 @@ gen_zipped_test_files(
"mean.zip",
"mul.zip",
"pad.zip",
+ "prelu.zip",
"relu.zip",
"relu1.zip",
"relu6.zip",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 420bdb41f1..38de9dcf2c 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -617,6 +617,54 @@ def make_relu6_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_prelu_tests(zip_path):
+ """Make a set of tests to do PReLU."""
+
+ test_parameters = [{
+ # The canonical case for image processing is having a 4D `input` (NHWC)
+ # and `shared_axes`=[1, 2], so the alpha parameter is per channel.
+ "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]],
+ "shared_axes": [[1, 2], [1]],
+ }]
+
+ def build_graph(parameters):
+ """Build the graph for the test case."""
+
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"])
+ out = prelu(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ """Build the inputs for the test case."""
+
+ input_shape = parameters["input_shape"]
+ input_values = create_tensor_data(
+ np.float32, input_shape, min_value=-10, max_value=10)
+ shared_axes = parameters["shared_axes"]
+
+ alpha_shape = []
+ for dim in range(1, len(input_shape)):
+ alpha_shape.append(1 if dim in shared_axes else input_shape[dim])
+
+ alpha_values = create_tensor_data(np.float32, alpha_shape)
+
+ with tf.variable_scope("", reuse=True):
+ alpha = tf.get_variable("p_re_lu/alpha")
+ sess.run(alpha.assign(alpha_values))
+
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(
+ zip_path,
+ test_parameters,
+ build_graph,
+ build_inputs,
+ use_frozen_graph=True)
+
+
# This function tests various TensorFLow functions that generates Const op,
# including `tf.ones`, `tf.zeros` and random functions.
def make_constant_tests(zip_path):
@@ -1911,6 +1959,7 @@ def main(unused_args):
"relu.zip": make_relu_tests,
"relu1.zip": make_relu1_tests,
"relu6.zip": make_relu6_tests,
+ "prelu.zip": make_prelu_tests,
"l2_pool.zip": make_pool_tests(make_l2_pool),
"avg_pool.zip": make_pool_tests(tf.nn.avg_pool),
"max_pool.zip": make_pool_tests(tf.nn.max_pool),
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 5e76e7c510..ba2d259462 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -88,6 +88,9 @@ std::map<string, string> kBrokenTests = {
// Transpose only supports 1D-4D input tensors.
{R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
+
+ // PRelu only supports 4D input with (1, 1, channels) 3D alpha now.
+ {R"(^\/prelu.*shared_axes=\[1\])", "75975192"},
};
// Allows test data to be unzipped into a temporary directory and makes
@@ -253,6 +256,7 @@ INSTANTIATE_TESTS(mul)
INSTANTIATE_TESTS(pad)
INSTANTIATE_TESTS(relu)
INSTANTIATE_TESTS(relu1)
+INSTANTIATE_TESTS(prelu)
INSTANTIATE_TESTS(relu6)
INSTANTIATE_TESTS(reshape)
INSTANTIATE_TESTS(resize_bilinear)