// NOTE(review): removed a stray line-number gutter ("1".."120" and a "|")
// introduced by text extraction; it was not part of the original source.
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
/* static */
void BCast::Reverse(Vec* shape) {
  // Reverse `shape` in place by swapping symmetric pairs, walking the low
  // index up and the high index down until they meet in the middle.
  auto& dims = *shape;
  for (size_t lo = 0, hi = dims.size(); lo + 1 < hi; ++lo) {
    --hi;
    std::swap(dims[lo], dims[hi]);
  }
}
// Computes broadcast metadata for a binary elementwise op whose operands
// have shapes `sx` and `sy`. On return:
//   * valid_ is false iff some dimension pair is incompatible (neither equal
//     nor 1), in which case the remaining members are left partially filled;
//   * output_ is the broadcasted output shape;
//   * x_reshape_/x_bcast_ and y_reshape_/y_bcast_ are collapsed shapes where
//     consecutive dimensions with the same broadcast behavior are merged into
//     one dimension, with result_ the matching collapsed output shape
//     (invariant per dim: result == reshape * bcast);
//   * grad_x_reduce_idx_/grad_y_reduce_idx_ record the output dimension
//     indices along which x (resp. y) was broadcast — presumably so callers
//     can sum gradients over them; inferred from the member names, confirm
//     against bcast.h.
// NOTE(review): Vec, int64 and CHECK_GE come from the included TF headers;
// CHECK_GE aborts on negative dimensions.
BCast::BCast(const Vec& sx, const Vec& sy) {
  // Reverse the shape of x and y for convenience.
  // After the reverse, 0-th is the inner-most dimension.
  Vec x = sx;
  Reverse(&x);
  Vec y = sy;
  Reverse(&y);

  // 1-extend and align x and y so that they are the same size.
  if (x.size() > y.size()) {
    y.resize(x.size(), 1);
  } else {
    x.resize(y.size(), 1);
  }

  // Going through each dimension starting from the inner-most
  // dimension, compares dimension of x and y. They are compatible if
  // they are equal or either is 1.
  enum State {
    UNKNOWN,  // initial state, before any dimension is classified
    SAME,     // x_i == y_i: no broadcast on this dimension
    X_ONE,    // x_i == 1: x is broadcast to y's extent
    Y_ONE,    // y_i == 1: y is broadcast to x's extent
  };
  State prev = UNKNOWN;
  const int64 n = x.size();
  for (int i = 0; i < n; ++i) {
    // Output shape.
    State curr = UNKNOWN;
    const int64 x_i = x[i];  // i-th dimension of x.
    CHECK_GE(x_i, 0);
    const int64 y_i = y[i];  // i-th dimension of y.
    CHECK_GE(y_i, 0);
    int64 o_i;   // i-th dimension of the output.
    int64 bx_i;  // i-th broadcast for x.
    int64 by_i;  // i-th broadcast for y.
    // Invariant:
    //   o_i = x_i * bx_i = y_i * by_i
    if (x_i == y_i) {
      // No broadcast.
      o_i = x_i;
      bx_i = 1;
      by_i = 1;
      curr = SAME;
    } else if (x_i == 1) {
      // x broadcast to y on this dimension.
      o_i = y_i;
      bx_i = y_i;
      by_i = 1;
      // n - 1 - i converts back from reversed (inner-most-first) indexing to
      // the original outer-most-first dimension index.
      grad_x_reduce_idx_.push_back(n - 1 - i);
      curr = X_ONE;
    } else if (y_i == 1) {
      // y broadcast to x on this dimension.
      o_i = x_i;
      bx_i = 1;
      by_i = x_i;
      grad_y_reduce_idx_.push_back(n - 1 - i);
      curr = Y_ONE;
    } else {
      // Incompatible: dimensions differ and neither is 1.
      valid_ = false;
      return;
    }
    output_.push_back(o_i);
    // Reshape/broadcast.
    // Invariant:
    //  result[i] == x_reshape[i] * x_bcast[i] == y_reshape_[i] * y_bcast_[i]
    if (curr == SAME && x_i == 1) {
      // Both side are 1s.
      // A 1x1 dimension is dropped from the collapsed shapes entirely (its
      // factor of 1 changes nothing), but both gradients must still reduce
      // over it. Note the `continue` also skips `prev = curr`, so a run of
      // identical states merges across the skipped dimension.
      grad_x_reduce_idx_.push_back(n - 1 - i);
      grad_y_reduce_idx_.push_back(n - 1 - i);
      continue;
    } else if (prev == curr) {
      // It is a run of the same cases (no broadcast, x broadcast to
      // y, y broadcast to x). We can reshape the input so that fewer
      // dimensions are involved in the intermediate computation.
      result_.back() *= o_i;
      x_reshape_.back() *= x_i;
      x_bcast_.back() *= bx_i;
      y_reshape_.back() *= y_i;
      y_bcast_.back() *= by_i;
    } else {
      // State changed: start a new collapsed dimension.
      result_.push_back(o_i);
      x_reshape_.push_back(x_i);
      x_bcast_.push_back(bx_i);
      y_reshape_.push_back(y_i);
      y_bcast_.push_back(by_i);
    }
    prev = curr;
  }

  if (result_.empty()) {
    // Can happen when both x and y are effectively scalar.
    // Emit a single size-1 dimension so downstream code never sees an
    // empty collapsed shape.
    result_.push_back(1);
    x_reshape_.push_back(1);
    x_bcast_.push_back(1);
    y_reshape_.push_back(1);
    y_bcast_.push_back(1);
  }

  // Reverse all vectors since x and y were reversed at very
  // beginning.
  Reverse(&x_reshape_);
  Reverse(&x_bcast_);
  Reverse(&y_reshape_);
  Reverse(&y_bcast_);
  Reverse(&result_);
  Reverse(&output_);
  Reverse(&grad_x_reduce_idx_);
  Reverse(&grad_y_reduce_idx_);
}
} // end namespace tensorflow
// NOTE(review): removed a stray "|" extraction artifact after the namespace close.