aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-06-05 16:22:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-05 16:24:51 -0700
commit135a25971bfbac86b0aed2cf0433608966015c22 (patch)
treec4343e0d13592d463a43136f64702b2fdbde9d17 /tensorflow
parent677c83e6ba6fdc4d23f8c26bfc84209be4371631 (diff)
Support uint8, int32 and int64 for SpaceToDepth in TOCO.
PiperOrigin-RevId: 199376731
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py13
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc9
2 files changed, 14 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 9bb7a4600d..351187f520 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -58,10 +58,11 @@ from tensorflow.python.ops import rnn
parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
parser.add_argument("output_path",
help="Directory where the outputs will be go.")
-parser.add_argument("--zip_to_output",
- type=str,
- help="Particular zip to output.",
- required=False)
+parser.add_argument(
+ "--zip_to_output",
+ type=str,
+ help="Particular zip to output.",
+ required=True)
parser.add_argument("--toco",
type=str,
help="Path to toco tool.",
@@ -97,8 +98,6 @@ KNOWN_BUGS = {
r"fully_connected.*transpose_.=True": "67586970",
# Softmax graphs are too complex.
r"softmax.*dim=0": "67749831",
- # SpaceToDepth only supports float32.
- r"space_to_depth.*(float16|int32|uint8|int64)": "68018134",
# BatchToSpaceND only supports 4D tensors.
r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733",
# Div will use floordiv.
@@ -1621,7 +1620,7 @@ def make_space_to_depth_tests(zip_path):
"""Make a set of tests to do space_to_depth."""
test_parameters = [{
- "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64],
+ "dtype": [tf.float32, tf.int32, tf.uint8, tf.int64],
"input_shape": [[2, 12, 24, 1]],
"block_size": [2, 3, 4],
}]
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 0a57015d29..b9ebf66ff2 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -614,7 +614,14 @@ void ConvertSpaceToDepthOperator(const NodeDef& node,
CHECK_EQ(node.op(), "SpaceToDepth");
CheckInputsCount(node, tf_import_flags, 1);
- CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
+ if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
+ dtype != DT_INT64) {
+ const auto* enum_descriptor = tensorflow::DataType_descriptor();
+ LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:"
+ << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
+ << "T must be one of {DT_FLOAT, DT_INT8, DT_INT32, DT_INT64}.";
+ }
auto* op = new SpaceToDepthOperator;
op->inputs.push_back(node.input(0));
op->outputs.push_back(node.name());