aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/models/smartreply/predictor.h
blob: d17323a3f9a0ea80ad5e215b0a4700e625d0c590 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/contrib/lite/model.h"

namespace tflite {
namespace custom {
namespace smartreply {

// Default value for SmartReplyConfig::num_response.
constexpr int kDefaultNumResponse = 10;
// Default value for SmartReplyConfig::backoff_confidence. The explicit cast
// keeps the value identical to converting the double literal 1e-4 to float.
constexpr float kDefaultBackoffConfidence = static_cast<float>(1e-4);

// Forward declarations: both types are defined further down in this header but
// are already needed by the GetSegmentPredictions declaration.
class PredictorResponse;
struct SmartReplyConfig;

// Predicts responses for the `input` strings using the TFLite SmartReply
// `model`, and writes them to `predictor_responses`.
//
// When config.backoff_responses is not empty, predictor_responses will also be
// filled with messages taken from those backoff responses (backoff responses
// are used when predicted responses cannot fill the list).
// NOTE(review): the exact fill policy (clearing vs. appending, truncation to
// config.num_response, ordering by score) lives in the implementation, which
// is not visible in this header — confirm there. `predictor_responses` is
// presumably required to be non-null — TODO confirm.
void GetSegmentPredictions(const std::vector<std::string>& input,
                           const ::tflite::FlatBufferModel& model,
                           const SmartReplyConfig& config,
                           std::vector<PredictorResponse>* predictor_responses);

// Data object used to hold a single predictor response: the response text and
// its prediction score (confidence).
class PredictorResponse {
 public:
  // Builds a response from its text and score. `response_text` is a sink
  // parameter: it is moved into the member, so callers passing a temporary
  // (or an explicitly moved string) avoid an extra copy. Members are set in
  // the initializer list rather than default-constructed and then assigned.
  PredictorResponse(std::string response_text, float score)
      : response_text_(std::move(response_text)), prediction_score_(score) {}

  // Accessor methods.
  // Returns the response text; the reference is owned by this object.
  const std::string& GetText() const { return response_text_; }
  // Returns the prediction score given at construction.
  float GetScore() const { return prediction_score_; }

 private:
  std::string response_text_;
  float prediction_score_;
};

// Configurations for SmartReply.
struct SmartReplyConfig {
  // Maximum responses to return.
  int num_response;
  // Default confidence for backoff responses.
  float backoff_confidence;
  // Backoff responses are used when predicted responses cannot fulfill the
  // list. Stored by value so the config owns its copy: a reference member
  // bound to the constructor's by-value parameter would dangle as soon as the
  // constructor returned.
  std::vector<std::string> backoff_responses;

  // Initializes num_response and backoff_confidence to kDefaultNumResponse and
  // kDefaultBackoffConfidence, and takes ownership of `backoff_responses`.
  // Left non-explicit so existing implicit conversions from a vector still
  // compile.
  SmartReplyConfig(std::vector<std::string> backoff_responses)
      : num_response(kDefaultNumResponse),
        backoff_confidence(kDefaultBackoffConfidence),
        backoff_responses(std::move(backoff_responses)) {}
};

}  // namespace smartreply
}  // namespace custom
}  // namespace tflite

#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODELS_SMARTREPLY_PREDICTOR_H_