diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/attention_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/attention_ops_test.py | 166 |
1 files changed, 166 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/attention_ops_test.py b/tensorflow/python/kernel_tests/attention_ops_test.py new file mode 100644 index 0000000000..5541c541b2 --- /dev/null +++ b/tensorflow/python/kernel_tests/attention_ops_test.py @@ -0,0 +1,166 @@ +"""Tests for tensorflow.ops.attention_ops.""" + +import tensorflow.python.platform + +import tensorflow as tf + +from tensorflow.python.ops import attention_ops + + +class ExtractGlimpseTest(tf.test.TestCase): + + def _VerifyValues( + self, tensor_in_sizes, glimpse_sizes, offsets, expected_rows, + expected_cols): + """Verifies the output values of the glimpse extraction kernel. + + Args: + tensor_in_sizes: Input tensor dimensions in [input_rows, input_cols]. + glimpse_sizes: Dimensions of the glimpse in [glimpse_rows, glimpse_cols]. + offsets: Relative location of the center of the glimpse in the input + image expressed as [row_offset, col_offset]. + expected_rows: A list containing the expected row numbers (None for + out of bound entries that are expected to be replaced by uniform + random entries in [0,1) ). + expected_cols: Same as expected_rows, but for column numbers. + """ + + rows = tensor_in_sizes[0] + cols = tensor_in_sizes[1] + # Row Tensor with entries by row. + # [[ 1 1 1 ... ] + # [ 2 2 2 ... ] + # [ 3 3 3 ... ] + # [ ... + # ] + t_rows = tf.tile( + [[1.0 * r] for r in range(1, rows + 1)], [1, cols], + name='tile_rows') + + # Shuffle to switch to a convention of (batch_size, height, width, depth). + t_rows_4d = tf.transpose( + tf.expand_dims( + tf.expand_dims(t_rows, 0), 3), [0, 2, 1, 3]) + + # Column Tensor with entries by column. + # [[ 1 2 3 4 ... ] + # [ 1 2 3 4 ... ] + # [ 1 2 3 4 ... ] + # [ ... ] + # ] + t_cols = tf.tile( + [[1.0 * r for r in range(1, cols + 1)]], + [rows, 1], name='tile_cols') + + # Shuffle to switch to a convention of (batch_size, height, width, depth). + t_cols_4d = tf.transpose( + tf.expand_dims( + tf.expand_dims(t_cols, 0), 3), [0, 2, 1, 3]) + + # extract_glimpses from Row and Column Tensor, respectively. + # Switch order for glimpse_sizes and offsets to switch from (row, col) + # convention to tensorflows (height, width) convention. + t1 = tf.constant([glimpse_sizes[1], glimpse_sizes[0]], shape=[2]) + t2 = tf.constant([offsets[1], offsets[0]], shape=[1, 2]) + glimpse_rows = (tf.transpose( + attention_ops.extract_glimpse(t_rows_4d, t1, t2), [0, 2, 1, 3])) + glimpse_cols = (tf.transpose( + attention_ops.extract_glimpse(t_cols_4d, t1, t2), [0, 2, 1, 3])) + + # Evaluate the Tensorflow Graph. + with self.test_session() as sess: + value_rows, value_cols = sess.run([glimpse_rows, glimpse_cols]) + + # Check dimensions of returned glimpse. + self.assertEqual(value_rows.shape[1], glimpse_sizes[0]) + self.assertEqual(value_rows.shape[2], glimpse_sizes[1]) + self.assertEqual(value_cols.shape[1], glimpse_sizes[0]) + self.assertEqual(value_cols.shape[2], glimpse_sizes[1]) + + # Check entries. + min_random_val = 0 + max_random_val = max(rows, cols) + for i in range(0, glimpse_sizes[0]): + for j in range(0, glimpse_sizes[1]): + if expected_rows[i] is None or expected_cols[j] is None: + self.assertGreaterEqual(value_rows[0][i][j][0], min_random_val) + self.assertLessEqual(value_rows[0][i][j][0], max_random_val) + self.assertGreaterEqual(value_cols[0][i][j][0], min_random_val) + self.assertLessEqual(value_cols[0][i][j][0], max_random_val) + else: + self.assertEqual(value_rows[0][i][j][0], expected_rows[i]) + self.assertEqual(value_cols[0][i][j][0], expected_cols[j]) + + def testCenterGlimpse(self): + self._VerifyValues(tensor_in_sizes=[41, 61], + glimpse_sizes=[3, 5], + offsets=[0.0, 0.0], + expected_rows=[20, 21, 22], + expected_cols=[29, 30, 31, 32, 33]) + + def testLargeCenterGlimpse(self): + self._VerifyValues(tensor_in_sizes=[41, 61], + glimpse_sizes=[41, 61], + offsets=[0.0, 0.0], + expected_rows=range(1, 42), + expected_cols=range(1, 62)) + + def testTooLargeCenterGlimpse(self): + self._VerifyValues(tensor_in_sizes=[41, 61], + glimpse_sizes=[43, 63], + offsets=[0.0, 0.0], + expected_rows=[None] + range(1, 42) + [None], + expected_cols=[None] + range(1, 62) + [None]) + + def testGlimpseFullOverlap(self): + self._VerifyValues(tensor_in_sizes=[41, 61], + glimpse_sizes=[3, 5], + offsets=[0.1, 0.3], + expected_rows=[22, 23, 24], + expected_cols=[38, 39, 40, 41, 42]) + + def testGlimpseFullOverlap2(self): + self._VerifyValues(tensor_in_sizes=[41, 61], + glimpse_sizes=[11, 3], + offsets=[-0.7, -0.7], + expected_rows=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + expected_cols=[8, 9, 10]) + + def testGlimpseBeforeLeftMargin(self): + self._VerifyValues(tensor_in_sizes=[41, 61], + glimpse_sizes=[11, 5], + offsets=[-0.7, -0.9], + expected_rows=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + expected_cols=[1, 2, 3, 4, 5]) + + def testGlimpseLowerRightCorner(self): + self._VerifyValues(tensor_in_sizes=[41, 61], + glimpse_sizes=[7, 5], + offsets=[1.0, 1.0], + expected_rows=[38, 39, 40, 41, None, None, None], + expected_cols=[59, 60, 61, None, None]) + + def testGlimpseNoOverlap(self): + self._VerifyValues(tensor_in_sizes=[20, 30], + glimpse_sizes=[3, 3], + offsets=[-2.0, 2.0], + expected_rows=[None, None, None], + expected_cols=[None, None, None]) + + def testGlimpseOnLeftMargin(self): + self._VerifyValues(tensor_in_sizes=[41, 61], + glimpse_sizes=[11, 7], + offsets=[-0.7, -1.0], + expected_rows=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + expected_cols=[None, None, None, 1, 2, 3, 4]) + + def testGlimpseUpperMargin(self): + self._VerifyValues(tensor_in_sizes=[41, 61], + glimpse_sizes=[7, 5], + offsets=[-1, 0.9], + expected_rows=[None, None, None, 1, 2, 3, 4], + expected_cols=[56, 57, 58, 59, 60]) + + +if __name__ == '__main__': + tf.test.main() |