aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/bcast.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/bcast.h')
-rw-r--r--tensorflow/core/util/bcast.h99
1 files changed, 99 insertions, 0 deletions
diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h
new file mode 100644
index 0000000000..9f0233e415
--- /dev/null
+++ b/tensorflow/core/util/bcast.h
@@ -0,0 +1,99 @@
+#ifndef TENSORFLOW_UTIL_BCAST_H_
+#define TENSORFLOW_UTIL_BCAST_H_
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+
+// BCast is a helper for broadcasting binary tensor operation.
+// TensorFlow's broadcasting rule follows that of numpy (See
+// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
+//
+// The rule has the following properties:
+//
+// 1. suffix matching: the rule starts with the right-most
+// dimension, and works towards the left-most dimension. Since
+// TensorFlow is row-major, the right-most dimension (the last
+// element in the shape of a tensor) is the inner-most, a.k.a.
+// the fastest changing, dimension.
+//
+// 2. Two dimensions are compatible for broadcasting if both are the
+// same or either is 1.
+//
+// BCast takes the shape of two tensors and computes a few vectors of
+// int32 that are useful for the caller to reshape the tensors, apply
+// the right broadcasts to them, compute the broadcasted operation,
+// and possibly the gradients. In a nutshell, the caller is expected
+// to compute the broadcasted operation as following:
+//
+// BCast b(x.shape(), y.shape());
+// output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
+// _op_
+// y.reshape(b.y_reshape()).broadcast(b.y_bcast())
+//
+// For the gradient computation,
+// grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
+// .reshape(x.shape())
+// grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
+// .reshape(y.shape())
+// backprop_x and backprop_y are functionals of the binary function "op",
+// e.g.,
+// for +, backprop_x(x, y) = backprop_y(x, y) = 1;
+// for *, backprop_x(x, y) = y, backprop_y(x, y) = x;
+// for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
+//
+// The multiplication in the grad * backprop_x itself is also
+// broadcasting following the same rule.
+//
+// TODO(zhifengc): Adds support for n-ary (n >= 2).
+class BCast {
+ public:
+ // A vector of int32 representing the shape of tensor. The 0-th
+ // element is the outer-most dimension and the last element is the
+ // inner-most dimension. Note that we do not use TensorShape since
+ // it's more convenient to manipulate Vec directly for this module.
+ typedef std::vector<int64> Vec;
+
+ BCast(const Vec& x, const Vec& y);
+ ~BCast() {}
+
+ // Returns true iff two operands are compatible according to the
+ // broadcasting rule.
+ bool IsValid() const { return valid_; }
+
+ // If and only if IsValid(), the following fields can be used in
+ // implementing a broadcasted binary tensor operation according to
+ // the broadcasting rule.
+ const Vec& x_reshape() const { return x_reshape_; }
+ const Vec& x_bcast() const { return x_bcast_; }
+ const Vec& y_reshape() const { return y_reshape_; }
+ const Vec& y_bcast() const { return y_bcast_; }
+ const Vec& result_shape() const { return result_; }
+ const Vec& output_shape() const { return output_; }
+ const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; }
+ const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; }
+
+ private:
+ bool valid_ = true;
+ Vec x_reshape_;
+ Vec x_bcast_;
+ Vec y_reshape_;
+ Vec y_bcast_;
+ Vec result_;
+ Vec output_;
+ Vec grad_x_reduce_idx_;
+ Vec grad_y_reduce_idx_;
+
+ static void Reverse(Vec* shape);
+ static bool HasZero(const Vec& shape);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BCast);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_BCAST_H_