diff options
author | 2018-08-28 15:12:30 -0700 | |
---|---|---|
committer | 2018-08-28 15:24:13 -0700 | |
commit | 730b530e7b41faadf0ff477b20bd74bf2b8e89fe (patch) | |
tree | 422e4ba644151edd7e1af65f5211f6ed76e0f7d7 /tensorflow/contrib/lite/tools | |
parent | 85631dce2b91585a3d44f7b78db85ed3eba55a48 (diff) |
Automated rollback of commit 683e21314a80ac6cb89eb959465ded41e381d23c
PiperOrigin-RevId: 210615521
Diffstat (limited to 'tensorflow/contrib/lite/tools')
-rw-r--r-- | tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py | 105 |
2 files changed, 106 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md index 3c6a0d85b3..9b3b99451d 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md @@ -47,7 +47,7 @@ category labels. The `validation_ground_truth.txt` can be converted by the follo ILSVRC_2012_DEVKIT_DIR=[set to path to ILSVRC 2012 devkit] VALIDATION_LABELS=[set to path to output] -python generate_validation_labels -- \ +python generate_validation_labels.py -- \ --ilsvrc_devkit_dir=${ILSVRC_2012_DEVKIT_DIR} \ --validation_labels_output=${VALIDATION_LABELS} ``` diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py new file mode 100644 index 0000000000..c32a41e50d --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py @@ -0,0 +1,105 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tool to convert ILSVRC devkit validation ground truth to synset labels.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +from os import path +import sys +import scipy.io + +_SYNSET_ARRAYS_RELATIVE_PATH = 'data/meta.mat' +_VALIDATION_FILE_RELATIVE_PATH = 'data/ILSVRC2012_validation_ground_truth.txt' + + +def _synset_to_word(filepath): + """Returns synset to word dictionary by reading sysnset arrays.""" + mat = scipy.io.loadmat(filepath) + entries = mat['synsets'] + # These fields are listed in devkit readme.txt + fields = [ + 'synset_id', 'WNID', 'words', 'gloss', 'num_children', 'children', + 'wordnet_height', 'num_train_images' + ] + synset_index = fields.index('synset_id') + words_index = fields.index('words') + synset_to_word = {} + for entry in entries: + entry = entry[0] + synset_id = int(entry[synset_index][0]) + first_word = entry[words_index][0].split(',')[0] + synset_to_word[synset_id] = first_word + return synset_to_word + + +def _validation_file_path(ilsvrc_dir): + return path.join(ilsvrc_dir, _VALIDATION_FILE_RELATIVE_PATH) + + +def _synset_array_path(ilsvrc_dir): + return path.join(ilsvrc_dir, _SYNSET_ARRAYS_RELATIVE_PATH) + + +def _generate_validation_labels(ilsvrc_dir, output_file): + synset_to_word = _synset_to_word(_synset_array_path(ilsvrc_dir)) + with open(_validation_file_path(ilsvrc_dir), 'r') as synset_id_file, open( + output_file, 'w') as output: + for synset_id in synset_id_file: + synset_id = int(synset_id) + output.write('%s\n' % synset_to_word[synset_id]) + + +def _check_arguments(args): + if not args.validation_labels_output: + raise ValueError('Invalid path to output file.') + ilsvrc_dir = args.ilsvrc_devkit_dir + if not ilsvrc_dir or not path.isdir(ilsvrc_dir): + raise ValueError('Invalid path to ilsvrc_dir') + if not path.exists(_validation_file_path(ilsvrc_dir)): + raise ValueError('Invalid path to ilsvrc_dir, cannot find validation file.') + if not path.exists(_synset_array_path(ilsvrc_dir)): + raise ValueError( + 'Invalid path to ilsvrc_dir, cannot find synset arrays file.') + + +def main(): + parser = argparse.ArgumentParser( + description='Converts ILSVRC devkit validation_ground_truth.txt to synset' + ' labels file that can be used by the accuracy script.') + parser.add_argument( + '--validation_labels_output', + type=str, + help='Full path for outputting validation labels.') + parser.add_argument( + '--ilsvrc_devkit_dir', + type=str, + help='Full path to ILSVRC 2012 devikit directory.') + args = parser.parse_args() + try: + _check_arguments(args) + except ValueError as e: + parser.print_usage() + file_name = path.basename(sys.argv[0]) + sys.stderr.write('{0}: error: {1}\n'.format(file_name, str(e))) + sys.exit(1) + _generate_validation_labels(args.ilsvrc_devkit_dir, + args.validation_labels_output) + + +if __name__ == '__main__': + main() |