about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-19 12:08:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-19 12:13:53 -0800
commit83b751621439cc2b8a85450972414cf2f92a58cf (patch)
tree860f7f40a104388113121c79f7c6cb51cf7b1198 /tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
parent39ae44e9822ed76639bae3ccf800b36039d1da55 (diff)
Add support for time_major shape format to the sequential RNN Op in TF Lite.
This option, if set, changes the shape format of the inputs and outputs to [max_time, batch_size, depth]. If false, it uses [batch_size, max_time, depth]. By default, it is set to false. PiperOrigin-RevId: 182569507
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc102
1 file changed, 92 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
index a1c1eda160..82c680ec3d 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Unit test for TFLite RNN op.
+// Unit test for TFLite Sequential RNN op.
#include <vector>
#include <iomanip>
@@ -125,7 +125,8 @@ static float rnn_golden_output[] = {
class UnidirectionalRNNOpModel : public SingleOpModel {
public:
- UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size)
+ UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size,
+ bool time_major)
: batches_(batches),
sequence_len_(sequence_len),
units_(units),
@@ -136,13 +137,22 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
bias_ = AddInput(TensorType_FLOAT32);
hidden_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(
- BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_RNNOptions,
- CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
- BuildInterpreter({{batches_, sequence_len_, input_size_},
- {units_, input_size_},
- {units_, units_},
- {units_}});
+ SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
+ BuiltinOptions_SequenceRNNOptions,
+ CreateSequenceRNNOptions(builder_, time_major,
+ ActivationFunctionType_RELU)
+ .Union());
+ if (time_major) {
+ BuildInterpreter({{sequence_len_, batches_, input_size_},
+ {units_, input_size_},
+ {units_, units_},
+ {units_}});
+ } else {
+ BuildInterpreter({{batches_, sequence_len_, input_size_},
+ {units_, input_size_},
+ {units_, units_},
+ {units_}});
+ }
}
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
@@ -195,7 +205,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
// TODO(mirkov): add another test which directly compares to TF once TOCO
// supports the conversion from dynamic_rnn with BasicRNNCell.
TEST(FullyConnectedOpTest, BlackBoxTest) {
- UnidirectionalRNNOpModel rnn(2, 16, 16, 8);
+ UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*units=*/16, /*size=*/8, /*time_major=*/false);
rnn.SetWeights(
{0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
@@ -260,6 +271,77 @@ TEST(FullyConnectedOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}
+TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) {
+ UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*units=*/16, /*size=*/8, /*time_major=*/true);
+ rnn.SetWeights(
+ {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
+ 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
+ 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
+ -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
+ -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
+ -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
+ -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
+ 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
+ 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
+ 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
+ -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
+ 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
+ -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
+ -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
+ 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
+ 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
+ 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
+ -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
+ 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
+ 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
+ -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
+ 0.277308, 0.415818});
+
+ rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
+ -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
+ 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
+ -0.37609905});
+
+ rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1});
+
+ rnn.ResetHiddenState();
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* batch_start = rnn_input + i * rnn.input_size();
+ float* batch_end = batch_start + rnn.input_size();
+ // The two batches are identical.
+ rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
+ rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
+ }
+
+ rnn.Invoke();
+
+ std::vector<float> expected;
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* golden_batch_start = rnn_golden_output + i * rnn.num_units();
+ float* golden_batch_end = golden_batch_start + rnn.num_units();
+ expected.insert(expected.end(), golden_batch_start, golden_batch_end);
+ expected.insert(expected.end(), golden_batch_start, golden_batch_end);
+ }
+
+ EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+}
+
} // namespace
} // namespace tflite