aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/fusion_utils.h
blob: 41f13f6cb824eb9b7bd7800ec9b4cef94fe974e2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_

#include <functional>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
namespace grappler {
namespace fusion_utils {

// Callback type used to build the signature of the fused function.
//
// A function of this type is invoked with the signature of the first function
// and the signature of the second function, and should set the signature of
// the fused function through `fused_function_signature`.
using SetFunctionSignatureFn = std::function<void(
    const OpDef& first_function_signature,
    const OpDef& second_function_signature, OpDef* fused_function_signature)>;

// Small vector of strings (`gtl::InlinedVector` with N = 2).  In this header
// it carries the lists of function inputs and outputs that are passed to
// `SetInputFn` callbacks and to `ComposeInput`.
using StringCollection = gtl::InlinedVector<string, 2>;

// Callback type used to rewire the inputs of nodes taken from the second
// function.
//
// A function of this type is invoked for nodes from the second function that
// were previously taking function arguments as input.  `arg_num` tells which
// function argument the node was using as that input, e.g. for
//   node(arg_1, other_node, arg_4)
// it would be called on the first and the third input, with `arg_num` equal to
// 1 and 4 respectively.  It should set up the input (the returned string) based
// on the first function's inputs or outputs, or the second function's inputs.
//
// The parameter names inside the `std::function` signature are documentation
// only; they are named after the first/second function to match the wording
// used by the other declarations in this header.
using SetInputFn =
    std::function<string(const StringCollection& first_function_inputs,
                         const StringCollection& second_function_inputs,
                         const StringCollection& first_function_outputs,
                         int arg_num)>;

// Callback type used to set up the returns of the fused function.
//
// A function of this type is invoked with the `ret` map of the first function
// and the `ret` map of the second function, and writes the result into
// `fused_function`.  If the outputs of the first and the second function need
// to be combined, this is the right place to create new nodes.
//
// The parameter names inside the `std::function` signature are documentation
// only; they are named after the first/second function to match the wording
// used by the other declarations in this header.
using SetOutputFn =
    std::function<void(const protobuf::Map<string, string>& first_function_ret,
                       const protobuf::Map<string, string>& second_function_ret,
                       FunctionDef* fused_function)>;

// Returns true if the first and the second function, described by
// `first_signature` and `second_signature`, can be composed, i.e. fused as
// second_function(first_function(args...)).
// NOTE(review): the exact compatibility criteria live in the implementation,
// which is not part of this header.
bool CanCompose(const OpDef& first_signature, const OpDef& second_signature);

// Signature setup for function composition; its parameters have the same shape
// as `SetFunctionSignatureFn`.  Writes the resulting signature to
// `fused_signature`.
// NOTE(review): presumably the fused signature takes its inputs from
// `first_signature` and its outputs from `second_signature`, matching
// second_function(first_function(args...)) -- confirm against the
// implementation.
void ComposeSignature(const OpDef& first_signature,
                      const OpDef& second_signature, OpDef* fused_signature);

// Input setup for function composition; its parameters have the same shape as
// `SetInputFn`.  Returns the input to use for a second-function node that was
// reading function argument number `arg_num`.
// NOTE(review): presumably this maps the argument to the corresponding entry
// of `first_outputs`, so that the second function consumes the outputs of the
// first one -- confirm against the implementation.
string ComposeInput(const StringCollection& first_inputs,
                    const StringCollection& second_inputs,
                    const StringCollection& first_outputs, int arg_num);

// Output setup for function composition; its parameters have the same shape as
// `SetOutputFn`.  Sets the output of `fused_function` to the composition of
// the first and the second function:
//   second_function(first_function(args...)).
void ComposeOutput(const protobuf::Map<string, string>& first_ret,
                   const protobuf::Map<string, string>& second_ret,
                   FunctionDef* fused_function);

// Signature setup for combining two functions; its parameters have the same
// shape as `SetFunctionSignatureFn`.  Sets the input signature of
// `fused_signature` to the one of `first_signature`, and its output signature
// to the outputs of `first_signature` plus the outputs of `second_signature`.
void CombineSignature(const OpDef& first_signature,
                      const OpDef& second_signature, OpDef* fused_signature);

// Output setup for combining two functions; its parameters have the same shape
// as `SetOutputFn`.  In addition to the returns of the first function, the
// return values of the second function are added to `fused_function` as extra
// returns, like:
//   return *first_function(...), *second_function(...)
void CombineOutput(const protobuf::Map<string, string>& first_ret,
                   const protobuf::Map<string, string>& second_ret,
                   FunctionDef* fused_function);

// Fuses `first_function` with `second_function`, using `fused_name_prefix` as
// a name prefix.  The nodes from `first_function` are copied unmodified.  All
// of the setup functions (`set_signature`, `set_input` and `set_output`) are
// called with a copy of the second function whose names do not conflict with
// the first function.  This means that nodes copied from the second function
// can end up having different names.  For an explanation of the setup
// functions see the documentation of their types (`SetFunctionSignatureFn`,
// `SetInputFn` and `SetOutputFn`).
//
// NOTE(review): ownership of the returned pointer is not stated here.
// Presumably the fused function is added to `library`, which then owns it --
// confirm against the implementation before deleting or storing the result.
FunctionDef* FuseFunctions(const FunctionDef& first_function,
                           const FunctionDef& second_function,
                           StringPiece fused_name_prefix,
                           const SetFunctionSignatureFn& set_signature,
                           const SetInputFn& set_input,
                           const SetOutputFn& set_output,
                           FunctionDefLibrary* library);

}  // namespace fusion_utils
}  // namespace grappler
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_