aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/decode_raw_op_test.py
blob: abd50a7527e247f785f5cbbdb886ad80a46c121a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""Tests for DecodeRaw op from parsing_ops."""

import tensorflow.python.platform

import tensorflow as tf


class DecodeRawOpTest(tf.test.TestCase):
  """Checks tf.decode_raw, which reinterprets the bytes of strings as numbers."""

  def testToUint8(self):
    """Each input string decodes to one uint8 per byte; lengths must match."""
    with self.test_session():
      raw = tf.placeholder(tf.string, shape=[2])
      as_bytes = tf.decode_raw(raw, out_type=tf.uint8)
      # The leading dimension comes from the placeholder shape; the per-string
      # dimension is not known statically, so it is reported as None.
      self.assertEqual([2, None], as_bytes.get_shape().as_list())

      # Every output row should be the byte values of its input string.
      for fed in (["A", "a"], ["wer", "XYZ"]):
        expected = [[ord(ch) for ch in s] for s in fed]
        self.assertAllEqual(expected, as_bytes.eval(feed_dict={raw: fed}))

      # "short" is 5 bytes and "longer" is 6, so the batch is ragged.
      # NOTE(review): the expected message says element 1 "has size 5" even
      # though element 1 ("longer") is 6 bytes; this mirrors the kernel's
      # wording (reference size first) -- confirm before changing either side.
      with self.assertRaisesOpError(
          "DecodeRaw requires input strings to all be the same size, but "
          "element 1 has size 5 != 6"):
        as_bytes.eval(feed_dict={raw: ["short", "longer"]})

  def testToInt16(self):
    """Pairs of bytes decode to int16 values; odd byte counts are rejected."""
    with self.test_session():
      raw = tf.placeholder(tf.string, shape=[None])
      as_shorts = tf.decode_raw(raw, out_type=tf.int16)
      # Neither dimension is known statically with a [None] placeholder.
      self.assertEqual([None, None], as_shorts.get_shape().as_list())

      def _low_first(lo, hi):
        # The expected values treat the first byte of each pair as the
        # low-order byte (i.e. a little-endian decode).
        return ord(lo) + ord(hi) * 256

      decoded = as_shorts.eval(feed_dict={raw: ["AaBC"]})
      self.assertAllEqual([[_low_first("A", "a"), _low_first("B", "C")]],
                          decoded)

      # A 3-byte string cannot be split evenly into 2-byte int16 values.
      with self.assertRaisesOpError(
          "Input to DecodeRaw has length 3 that is not a multiple of 2, the "
          "size of int16"):
        as_shorts.eval(feed_dict={raw: ["123", "456"]})

if __name__ == "__main__":
  # When this file is executed directly, hand control to TensorFlow's test
  # entry point (tf.test.main) to run the test cases defined above.
  tf.test.main()