diff options
author | 2018-03-22 11:25:49 -0700 | |
---|---|---|
committer | 2018-03-22 11:29:59 -0700 | |
commit | 7c4cdb8bae0e8760ebe4793d49ea5aee68768655 (patch) | |
tree | d3adb4214eecc995845adf5d4f32331b60b8313a /tensorflow/contrib/lite/testing | |
parent | cfdd61585769188789280e768fc43fdbba799619 (diff) |
Supports PReLU in TFLite & Toco.
PiperOrigin-RevId: 190097557
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r-- | tensorflow/contrib/lite/testing/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/generate_examples.py | 49 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/generated_examples_zip_test.cc | 4 |
3 files changed, 54 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index f1b18ad30f..555ea90034 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -39,6 +39,7 @@ gen_zipped_test_files( "mean.zip", "mul.zip", "pad.zip", + "prelu.zip", "relu.zip", "relu1.zip", "relu6.zip", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 420bdb41f1..38de9dcf2c 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -617,6 +617,54 @@ def make_relu6_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_prelu_tests(zip_path): + """Make a set of tests to do PReLU.""" + + test_parameters = [{ + # The canonical case for image processing is having a 4D `input` (NHWC) + # and `shared_axes`=[1, 2], so the alpha parameter is per channel. + "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]], + "shared_axes": [[1, 2], [1]], + }] + + def build_graph(parameters): + """Build the graph for the test case.""" + + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"]) + out = prelu(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Build the inputs for the test case.""" + + input_shape = parameters["input_shape"] + input_values = create_tensor_data( + np.float32, input_shape, min_value=-10, max_value=10) + shared_axes = parameters["shared_axes"] + + alpha_shape = [] + for dim in range(1, len(input_shape)): + alpha_shape.append(1 if dim in shared_axes else input_shape[dim]) + + alpha_values = create_tensor_data(np.float32, alpha_shape) + + with tf.variable_scope("", reuse=True): + alpha = tf.get_variable("p_re_lu/alpha") + sess.run(alpha.assign(alpha_values)) + + return [input_values], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_values]))) + + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + use_frozen_graph=True) + + # This function tests various TensorFLow functions that generates Const op, # including `tf.ones`, `tf.zeros` and random functions. def make_constant_tests(zip_path): @@ -1911,6 +1959,7 @@ def main(unused_args): "relu.zip": make_relu_tests, "relu1.zip": make_relu1_tests, "relu6.zip": make_relu6_tests, + "prelu.zip": make_prelu_tests, "l2_pool.zip": make_pool_tests(make_l2_pool), "avg_pool.zip": make_pool_tests(tf.nn.avg_pool), "max_pool.zip": make_pool_tests(tf.nn.max_pool), diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 5e76e7c510..ba2d259462 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -88,6 +88,9 @@ std::map<string, string> kBrokenTests = { // Transpose only supports 1D-4D input tensors. {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, + + // PRelu only supports 4D input with (1, 1, channels) 3D alpha now. + {R"(^\/prelu.*shared_axes=\[1\])", "75975192"}, }; // Allows test data to be unzipped into a temporary directory and makes @@ -253,6 +256,7 @@ INSTANTIATE_TESTS(mul) INSTANTIATE_TESTS(pad) INSTANTIATE_TESTS(relu) INSTANTIATE_TESTS(relu1) +INSTANTIATE_TESTS(prelu) INSTANTIATE_TESTS(relu6) INSTANTIATE_TESTS(reshape) INSTANTIATE_TESTS(resize_bilinear) |