aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/speech_commands/accuracy_utils.h
blob: eea048365bc9ff53bdd767be436fb657b43793c7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_
#define TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_

#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Counts describing how well the words a model detected in an audio stream
// line up with the stream's ground-truth labels. CalculateAccuracyStats()
// writes its results into this struct, and PrintAccuracyStats() reports them.
//
// Every counter starts at zero. The per-field notes below are inferred from
// the field names -- NOTE(review): confirm against CalculateAccuracyStats().
struct StreamingAccuracyStats {
  // Presumably the number of ground-truth words taken into account.
  int32 how_many_ground_truth_words = 0;
  // Presumably the ground-truth words that were paired with a prediction.
  int32 how_many_ground_truth_matched = 0;
  // Presumably predictions that did not correspond to any ground-truth word.
  int32 how_many_false_positives = 0;
  // Presumably paired predictions whose label agreed with the ground truth.
  int32 how_many_correct_words = 0;
  // Presumably paired predictions whose label disagreed with the ground truth.
  int32 how_many_wrong_words = 0;
};

// Takes a file name, and loads a list of expected word labels and times from
// it, stored as comma-separated values. Each entry is returned through
// `result` as a (label, time) pair. The time is presumably in milliseconds,
// matching the `_ms` arguments of CalculateAccuracyStats() -- TODO confirm.
// Failures are reported through the returned Status.
Status ReadGroundTruthFile(const string& file_name,
                           std::vector<std::pair<string, int64>>* result);

// Given ground truth labels and corresponding predictions found by a model,
// figures out how many were correct. Takes a time limit, so that only
// predictions up to a point in time are considered, in case we're evaluating
// accuracy when the model has only been run on part of the stream.
//
//   ground_truth_list: expected (label, time) pairs, e.g. as loaded by
//                      ReadGroundTruthFile().
//   found_words:       (label, time) pairs predicted by the model.
//   up_to_time_ms:     time limit; only predictions up to this point are
//                      considered.
//   time_tolerance_ms: presumably the largest time difference allowed when
//                      pairing a prediction with a ground-truth word --
//                      TODO confirm against the implementation.
//   stats:             output; receives the computed counts.
void CalculateAccuracyStats(
    const std::vector<std::pair<string, int64>>& ground_truth_list,
    const std::vector<std::pair<string, int64>>& found_words,
    int64 up_to_time_ms, int64 time_tolerance_ms,
    StreamingAccuracyStats* stats);

// Writes a human-readable description of the statistics to stdout.
void PrintAccuracyStats(const StreamingAccuracyStats& stats);

}  // namespace tensorflow

#endif  // TENSORFLOW_EXAMPLES_SPEECH_COMMANDS_ACCURACY_UTILS_H_