aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/tools
diff options
context:
space:
mode:
authorGravatar Shashi Shekhar <shashishekhar@google.com>2018-08-28 15:12:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 15:24:13 -0700
commit730b530e7b41faadf0ff477b20bd74bf2b8e89fe (patch)
tree422e4ba644151edd7e1af65f5211f6ed76e0f7d7 /tensorflow/contrib/lite/tools
parent85631dce2b91585a3d44f7b78db85ed3eba55a48 (diff)
Automated rollback of commit 683e21314a80ac6cb89eb959465ded41e381d23c
PiperOrigin-RevId: 210615521
Diffstat (limited to 'tensorflow/contrib/lite/tools')
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md2
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py105
2 files changed, 106 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
index 3c6a0d85b3..9b3b99451d 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
@@ -47,7 +47,7 @@ category labels. The `validation_ground_truth.txt` can be converted by the follo
ILSVRC_2012_DEVKIT_DIR=[set to path to ILSVRC 2012 devkit]
VALIDATION_LABELS=[set to path to output]
-python generate_validation_labels -- \
+python generate_validation_labels.py -- \
--ilsvrc_devkit_dir=${ILSVRC_2012_DEVKIT_DIR} \
--validation_labels_output=${VALIDATION_LABELS}
```
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py
new file mode 100644
index 0000000000..c32a41e50d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py
@@ -0,0 +1,105 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tool to convert ILSVRC devkit validation ground truth to synset labels."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+from os import path
+import sys
+import scipy.io
+
+_SYNSET_ARRAYS_RELATIVE_PATH = 'data/meta.mat'
+_VALIDATION_FILE_RELATIVE_PATH = 'data/ILSVRC2012_validation_ground_truth.txt'
+
+
+def _synset_to_word(filepath):
+ """Returns synset to word dictionary by reading sysnset arrays."""
+ mat = scipy.io.loadmat(filepath)
+ entries = mat['synsets']
+ # These fields are listed in devkit readme.txt
+ fields = [
+ 'synset_id', 'WNID', 'words', 'gloss', 'num_children', 'children',
+ 'wordnet_height', 'num_train_images'
+ ]
+ synset_index = fields.index('synset_id')
+ words_index = fields.index('words')
+ synset_to_word = {}
+ for entry in entries:
+ entry = entry[0]
+ synset_id = int(entry[synset_index][0])
+ first_word = entry[words_index][0].split(',')[0]
+ synset_to_word[synset_id] = first_word
+ return synset_to_word
+
+
+def _validation_file_path(ilsvrc_dir):
+ return path.join(ilsvrc_dir, _VALIDATION_FILE_RELATIVE_PATH)
+
+
+def _synset_array_path(ilsvrc_dir):
+ return path.join(ilsvrc_dir, _SYNSET_ARRAYS_RELATIVE_PATH)
+
+
+def _generate_validation_labels(ilsvrc_dir, output_file):
+ synset_to_word = _synset_to_word(_synset_array_path(ilsvrc_dir))
+ with open(_validation_file_path(ilsvrc_dir), 'r') as synset_id_file, open(
+ output_file, 'w') as output:
+ for synset_id in synset_id_file:
+ synset_id = int(synset_id)
+ output.write('%s\n' % synset_to_word[synset_id])
+
+
+def _check_arguments(args):
+ if not args.validation_labels_output:
+ raise ValueError('Invalid path to output file.')
+ ilsvrc_dir = args.ilsvrc_devkit_dir
+ if not ilsvrc_dir or not path.isdir(ilsvrc_dir):
+ raise ValueError('Invalid path to ilsvrc_dir')
+ if not path.exists(_validation_file_path(ilsvrc_dir)):
+ raise ValueError('Invalid path to ilsvrc_dir, cannot find validation file.')
+ if not path.exists(_synset_array_path(ilsvrc_dir)):
+ raise ValueError(
+ 'Invalid path to ilsvrc_dir, cannot find synset arrays file.')
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Converts ILSVRC devkit validation_ground_truth.txt to synset'
+ ' labels file that can be used by the accuracy script.')
+ parser.add_argument(
+ '--validation_labels_output',
+ type=str,
+ help='Full path for outputting validation labels.')
+ parser.add_argument(
+ '--ilsvrc_devkit_dir',
+ type=str,
+ help='Full path to ILSVRC 2012 devikit directory.')
+ args = parser.parse_args()
+ try:
+ _check_arguments(args)
+ except ValueError as e:
+ parser.print_usage()
+ file_name = path.basename(sys.argv[0])
+ sys.stderr.write('{0}: error: {1}\n'.format(file_name, str(e)))
+ sys.exit(1)
+ _generate_validation_labels(args.ilsvrc_devkit_dir,
+ args.validation_labels_output)
+
+
+if __name__ == '__main__':
+ main()