diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/util/bcast.h |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/util/bcast.h')
-rw-r--r-- | tensorflow/core/util/bcast.h | 99 |
1 files changed, 99 insertions, 0 deletions
diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h new file mode 100644 index 0000000000..9f0233e415 --- /dev/null +++ b/tensorflow/core/util/bcast.h @@ -0,0 +1,99 @@ +#ifndef TENSORFLOW_UTIL_BCAST_H_ +#define TENSORFLOW_UTIL_BCAST_H_ + +#include <algorithm> +#include <vector> + +#include "tensorflow/core/platform/port.h" + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { + +// BCast is a helper for broadcasting binary tensor operation. +// TensorFlow's broadcasting rule follows that of numpy (See +// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). +// +// The rule has the following properties: +// +// 1. suffix matching: the rule starts with the right-most +// dimension, and works towards the left-most dimension. Since +// TensorFlow is row-major, the right-most dimension (the last +// element in the shape of a tensor) is the inner-most, a.k.a. +// the fastest changing, dimension. +// +// 2. Two dimensions are compatible for broadcasting if both are the +// same or either is 1. +// +// BCast takes the shape of two tensors and computes a few vectors of +// int32 that are useful for the caller to reshape the tensors, apply +// the right broadcasts to them, compute the broadcasted operation, +// and possibly the gradients. In a nutshell, the caller is expected +// to compute the broadcasted operation as following: +// +// BCast b(x.shape(), y.shape()); +// output = x.reshape(b.x_reshape()).broadcast(b.x_bcast()) +// _op_ +// y.reshape(b.y_reshape()).broadcast(b.y_bcast()) +// +// For the gradient computation, +// grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx) +// .reshape(x.shape()) +// grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx) +// .reshape(y.shape()) +// backprop_x and backprop_y are functionals of the binary function "op", +// e.g., +// for +, backprop_x(x, y) = backprop_y(x, y) = 1; +// for *, backprop_x(x, y) = y, backprop_y(x, y) = x; +// for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2; +// +// The multiplication in the grad * backprop_x itself is also +// broadcasting following the same rule. +// +// TODO(zhifengc): Adds support for n-ary (n >= 2). +class BCast { + public: + // A vector of int32 representing the shape of tensor. The 0-th + // element is the outer-most dimension and the last element is the + // inner-most dimension. Note that we do not use TensorShape since + // it's more convenient to manipulate Vec directly for this module. + typedef std::vector<int64> Vec; + + BCast(const Vec& x, const Vec& y); + ~BCast() {} + + // Returns true iff two operands are compatible according to the + // broadcasting rule. + bool IsValid() const { return valid_; } + + // If and only if IsValid(), the following fields can be used in + // implementing a broadcasted binary tensor operation according to + // the broadcasting rule. + const Vec& x_reshape() const { return x_reshape_; } + const Vec& x_bcast() const { return x_bcast_; } + const Vec& y_reshape() const { return y_reshape_; } + const Vec& y_bcast() const { return y_bcast_; } + const Vec& result_shape() const { return result_; } + const Vec& output_shape() const { return output_; } + const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; } + const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; } + + private: + bool valid_ = true; + Vec x_reshape_; + Vec x_bcast_; + Vec y_reshape_; + Vec y_bcast_; + Vec result_; + Vec output_; + Vec grad_x_reduce_idx_; + Vec grad_y_reduce_idx_; + + static void Reverse(Vec* shape); + static bool HasZero(const Vec& shape); + + TF_DISALLOW_COPY_AND_ASSIGN(BCast); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_UTIL_BCAST_H_ |