aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/bcast.h
blob: 9f0233e415cc611bae18422d35c2f4d5fdd2338e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#ifndef TENSORFLOW_UTIL_BCAST_H_
#define TENSORFLOW_UTIL_BCAST_H_

#include <algorithm>
#include <vector>

#include "tensorflow/core/platform/port.h"

#include "tensorflow/core/platform/logging.h"
namespace tensorflow {

// BCast is a helper for broadcasting binary tensor operations.
// TensorFlow's broadcasting rule follows that of numpy (See
// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
//
// The rule has the following properties:
//
//   1. suffix matching: the rule starts with the right-most
//      dimension, and works towards the left-most dimension. Since
//      TensorFlow is row-major, the right-most dimension (the last
//      element in the shape of a tensor) is the inner-most, a.k.a.
//      the fastest changing, dimension.
//
//   2. Two dimensions are compatible for broadcasting if both are the
//      same or either is 1.
//
// BCast takes the shape of two tensors and computes a few vectors of
// int64 that are useful for the caller to reshape the tensors, apply
// the right broadcasts to them, compute the broadcasted operation,
// and possibly the gradients. In a nutshell, the caller is expected
// to compute the broadcasted operation as follows:
//
//   BCast b(x.shape(), y.shape());
//   output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
//            _op_
//            y.reshape(b.y_reshape()).broadcast(b.y_bcast())
//
// For the gradient computation,
//   grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
//            .reshape(x.shape())
//   grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
//            .reshape(y.shape())
// backprop_x and backprop_y are the partial derivatives of the binary
// function "op" with respect to x and y, respectively, e.g.,
//   for +, backprop_x(x, y) = backprop_y(x, y) = 1;
//   for *, backprop_x(x, y) = y, backprop_y(x, y) = x;
//   for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
//
// The multiplication grad * backprop_x is itself also broadcast
// following the same rule.
//
// TODO(zhifengc): Add support for n-ary (n >= 2).
class BCast {
 public:
  // A vector of int32 representing the shape of tensor. The 0-th
  // element is the outer-most dimension and the last element is the
  // inner-most dimension. Note that we do not use TensorShape since
  // it's more convenient to manipulate Vec directly for this module.
  typedef std::vector<int64> Vec;

  BCast(const Vec& x, const Vec& y);
  ~BCast() {}

  // Returns true iff two operands are compatible according to the
  // broadcasting rule.
  bool IsValid() const { return valid_; }

  // If and only if IsValid(), the following fields can be used in
  // implementing a broadcasted binary tensor operation according to
  // the broadcasting rule.
  const Vec& x_reshape() const { return x_reshape_; }
  const Vec& x_bcast() const { return x_bcast_; }
  const Vec& y_reshape() const { return y_reshape_; }
  const Vec& y_bcast() const { return y_bcast_; }
  const Vec& result_shape() const { return result_; }
  const Vec& output_shape() const { return output_; }
  const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; }
  const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; }

 private:
  bool valid_ = true;
  Vec x_reshape_;
  Vec x_bcast_;
  Vec y_reshape_;
  Vec y_bcast_;
  Vec result_;
  Vec output_;
  Vec grad_x_reduce_idx_;
  Vec grad_y_reduce_idx_;

  static void Reverse(Vec* shape);
  static bool HasZero(const Vec& shape);

  TF_DISALLOW_COPY_AND_ASSIGN(BCast);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_UTIL_BCAST_H_