// Source location (from the repository web view this file was copied from):
//   path: root/tensorflow/core/util/bcast.cc
//   blob: 4e70b78751a04df293c22aa12d12642892e27050
#include "tensorflow/core/util/bcast.h"

#include <algorithm>

#include "tensorflow/core/platform/logging.h"
namespace tensorflow {

/* static */
// Reverses, in place, the order of the entries of *shape, so the last
// dimension becomes the first and vice versa.
void BCast::Reverse(Vec* shape) {
  Vec& dims = *shape;
  std::reverse(dims.begin(), dims.end());
}

// Computes how two shapes `sx` and `sy` broadcast against each other.
// The shapes are aligned on their trailing dimensions, and the shorter one is
// padded with leading 1s. Each aligned pair of dimensions must be equal, or
// one of them must be 1.
//
// On success this fills:
//   output_            - the full broadcast output shape.
//   result_, x_reshape_, x_bcast_, y_reshape_, y_bcast_
//                      - a compacted description of the same broadcast.
//                        Runs of adjacent dimensions that broadcast the same
//                        way (no broadcast / x broadcast / y broadcast) are
//                        merged into a single dimension, and dimensions that
//                        are 1 on both sides are dropped. For every j:
//                        result_[j] == x_reshape_[j] * x_bcast_[j]
//                                   == y_reshape_[j] * y_bcast_[j].
//   grad_x_reduce_idx_ - indices into output_ of the dimensions where x has
//                        size 1 (including where both sides are 1), in
//                        increasing order.
//   grad_y_reduce_idx_ - same for y.
//
// If the shapes are incompatible, or either contains a negative dimension,
// valid_ is set to false and the constructor returns early. The member
// vectors may then be partially filled and must not be used.
BCast::BCast(const Vec& sx, const Vec& sy) {
  // Work on reversed copies so that index 0 is the inner-most dimension.
  // Every output vector is reversed back at the end.
  Vec x = sx;
  Reverse(&x);
  Vec y = sy;
  Reverse(&y);

  // 1-extend the shorter shape so both have the same rank. Because the
  // shapes are reversed, the padding 1s become the outer-most dimensions.
  const auto rank = std::max(x.size(), y.size());
  x.resize(rank, 1);
  y.resize(rank, 1);

  // Classification of one aligned (x_i, y_i) dimension pair.
  enum State {
    UNKNOWN,
    SAME,
    X_ONE,
    Y_ONE,
  };
  State prev = UNKNOWN;
  const int64 n = static_cast<int64>(rank);
  // Walk from the inner-most to the outer-most dimension.
  for (int64 i = 0; i < n; ++i) {
    const int64 x_i = x[i];  // i-th dimension of x.
    const int64 y_i = y[i];  // i-th dimension of y.
    // A negative dimension cannot take part in any broadcast. Report it
    // through valid_, like any other incompatibility, instead of
    // CHECK-failing, which would abort the entire process.
    if (x_i < 0 || y_i < 0) {
      valid_ = false;
      return;
    }
    // Position of this dimension in the (un-reversed) output shape.
    const int64 out_idx = n - 1 - i;

    State curr = UNKNOWN;
    int64 o_i;   // i-th dimension of the output.
    int64 bx_i;  // i-th broadcast for x.
    int64 by_i;  // i-th broadcast for y.
    // Invariant:
    //   o_i == x_i * bx_i == y_i * by_i
    if (x_i == y_i) {
      // No broadcast.
      o_i = x_i;
      bx_i = 1;
      by_i = 1;
      curr = SAME;
    } else if (x_i == 1) {
      // x broadcast to y on this dimension.
      o_i = y_i;
      bx_i = y_i;
      by_i = 1;
      grad_x_reduce_idx_.push_back(out_idx);
      curr = X_ONE;
    } else if (y_i == 1) {
      // y broadcast to x on this dimension.
      o_i = x_i;
      bx_i = 1;
      by_i = x_i;
      grad_y_reduce_idx_.push_back(out_idx);
      curr = Y_ONE;
    } else {
      // Different sizes and neither is 1: not broadcastable.
      valid_ = false;
      return;
    }
    output_.push_back(o_i);

    // Reshape/broadcast.
    // Invariant:
    //   result_[j] == x_reshape_[j] * x_bcast_[j] == y_reshape_[j] * y_bcast_[j]
    if (curr == SAME && x_i == 1) {
      // Both sides are 1. The dimension is recorded for both gradients but
      // contributes nothing to the compacted description. `prev` is left
      // untouched, so a run on either side can keep merging across it.
      grad_x_reduce_idx_.push_back(out_idx);
      grad_y_reduce_idx_.push_back(out_idx);
      continue;
    }
    if (prev == curr) {
      // Same case as the previous kept dimension: fold this dimension into
      // it so fewer dimensions are involved in the intermediate computation.
      // prev != UNKNOWN here, so the vectors are non-empty.
      result_.back() *= o_i;
      x_reshape_.back() *= x_i;
      x_bcast_.back() *= bx_i;
      y_reshape_.back() *= y_i;
      y_bcast_.back() *= by_i;
    } else {
      result_.push_back(o_i);
      x_reshape_.push_back(x_i);
      x_bcast_.push_back(bx_i);
      y_reshape_.push_back(y_i);
      y_bcast_.push_back(by_i);
    }
    prev = curr;
  }

  if (result_.empty()) {
    // Happens when both x and y are effectively scalar (rank 0, or every
    // dimension is 1 on both sides).
    result_.push_back(1);
    x_reshape_.push_back(1);
    x_bcast_.push_back(1);
    y_reshape_.push_back(1);
    y_bcast_.push_back(1);
  }

  // Undo the initial reversal so that every vector is outer-most first.
  Reverse(&x_reshape_);
  Reverse(&x_bcast_);
  Reverse(&y_reshape_);
  Reverse(&y_bcast_);
  Reverse(&result_);
  Reverse(&output_);
  Reverse(&grad_x_reduce_idx_);
  Reverse(&grad_y_reduce_idx_);
}

}  // end namespace tensorflow