diff options
author | 2018-06-05 16:22:14 -0700 | |
---|---|---|
committer | 2018-06-05 16:24:51 -0700 | |
commit | 135a25971bfbac86b0aed2cf0433608966015c22 (patch) | |
tree | c4343e0d13592d463a43136f64702b2fdbde9d17 /tensorflow | |
parent | 677c83e6ba6fdc4d23f8c26bfc84209be4371631 (diff) |
Support uint8, int32 and int64 for SpaceToDepth in TOCO.
PiperOrigin-RevId: 199376731
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/lite/testing/generate_examples.py | 13 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 9 |
2 files changed, 14 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 9bb7a4600d..351187f520 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -58,10 +58,11 @@ from tensorflow.python.ops import rnn parser = argparse.ArgumentParser(description="Script to generate TFLite tests.") parser.add_argument("output_path", help="Directory where the outputs will be go.") -parser.add_argument("--zip_to_output", - type=str, - help="Particular zip to output.", - required=False) +parser.add_argument( + "--zip_to_output", + type=str, + help="Particular zip to output.", + required=True) parser.add_argument("--toco", type=str, help="Path to toco tool.", @@ -97,8 +98,6 @@ KNOWN_BUGS = { r"fully_connected.*transpose_.=True": "67586970", # Softmax graphs are too complex. r"softmax.*dim=0": "67749831", - # SpaceToDepth only supports float32. - r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", # BatchToSpaceND only supports 4D tensors. r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733", # Div will use floordiv. @@ -1621,7 +1620,7 @@ def make_space_to_depth_tests(zip_path): """Make a set of tests to do space_to_depth.""" test_parameters = [{ - "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64], + "dtype": [tf.float32, tf.int32, tf.uint8, tf.int64], "input_shape": [[2, 12, 24, 1]], "block_size": [2, 3, 4], }] diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 0a57015d29..b9ebf66ff2 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -614,7 +614,14 @@ void ConvertSpaceToDepthOperator(const NodeDef& node, CHECK_EQ(node.op(), "SpaceToDepth"); CheckInputsCount(node, tf_import_flags, 1); - CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + tensorflow::DataType dtype = GetDataTypeAttr(node, "T"); + if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 && + dtype != DT_INT64) { + const auto* enum_descriptor = tensorflow::DataType_descriptor(); + LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:" + << enum_descriptor->FindValueByNumber(dtype)->name() << ". " + << "T must be one of {DT_FLOAT, DT_INT8, DT_INT32, DT_INT64}."; + } auto* op = new SpaceToDepthOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); |