author    A. Unique TensorFlower <gardener@tensorflow.org> 2017-06-13 04:28:10 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-06-13 04:32:08 -0700
commit    30bea6a1eb7cfc68fa926a96a48f22d8fabb350f (patch)
tree      5f400e550007a0802eb39517d7a52879fc861251
parent    6ac3efd42902d48d45d59128926110e6d5121a08 (diff)
Minor modernizations; use unique_ptr interfaces, simplify template use, use pointer rather than reference arguments, etc.
PiperOrigin-RevId: 158830894
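
The commit message names three idioms that recur throughout the diff below: factory functions returning std::unique_ptr instead of owning raw pointers, alias templates replacing repeated verbose typedefs, and pointer arguments instead of non-const reference arguments for parameters that get mutated. The following standalone sketch shows each idiom side by side; the names (Counter, MakeCounter, Notify, Rows) are illustrative only and do not appear in the patch.

#include <memory>
#include <mutex>
#include <vector>

// (2) One alias template replaces several near-identical typedefs.
template <typename T>
using Rows = std::vector<std::vector<T>>;

class Counter {
 public:
  void Wait() {}
};

// (1) Ownership is explicit in the signature; callers no longer delete manually.
std::unique_ptr<Counter> MakeCounter() {
  return std::unique_ptr<Counter>(new Counter);
}

// (3) A pointer argument makes mutation visible at the call site: Notify(&lock).
void Notify(std::unique_lock<std::mutex>* lock) {
  lock->unlock();
}

int main() {
  std::unique_ptr<Counter> c = MakeCounter();
  c->Wait();
  std::mutex mu;
  std::unique_lock<std::mutex> lock(mu);
  Notify(&lock);
  Rows<float> rows;
  return 0;
}
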
-rw-r--r--  tensorflow/core/kernels/cwise_ops_test.cc    72
-rw-r--r--  tensorflow/core/kernels/sparse_matmul_op.cc  100
-rw-r--r--  tensorflow/core/kernels/stage_op.cc           61
3 files changed, 104 insertions, 129 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc
index 92018ec871..bca0f1004d 100644
--- a/tensorflow/core/kernels/cwise_ops_test.cc
+++ b/tensorflow/core/kernels/cwise_ops_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
+namespace {
// Creates a Graph which applies a unary "func" on a 3D tensor of
// type T with "num" elements.
@@ -35,14 +36,14 @@ static Graph* Unary(const string& func, int num, DataType dtype) {
return g;
}
-static int kRows = 100000;
+const int kRows = 100000;
-static int RowsAndColsArg(int r, int c) { return r * kRows + c; }
-static int RowsFromArg(int arg) { return (arg / kRows); }
-static int ColsFromArg(int arg) { return (arg % kRows); }
+int RowsAndColsArg(int r, int c) { return r * kRows + c; }
+int RowsFromArg(int arg) { return (arg / kRows); }
+int ColsFromArg(int arg) { return (arg % kRows); }
#define BM_UNARY(DEVICE, FUNC, T, TYPE) \
- static void BM_##DEVICE##_##FUNC##_##TYPE(int iters, int num) { \
+ void BM_##DEVICE##_##FUNC##_##TYPE(int iters, int num) { \
const int64 tot = static_cast<int64>(iters) * num; \
testing::ItemsProcessed(tot); \
testing::BytesProcessed(tot * sizeof(T)); \
@@ -85,7 +86,7 @@ BM_UNARY(gpu, Rint, float, DT_FLOAT);
#endif // GOOGLE_CUDA
// data func scalar.
-static Graph* BinaryScalar(int num, const string& func) {
+Graph* BinaryScalar(int num, const string& func) {
Graph* g = new Graph(OpRegistry::Global());
Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
lhs.flat<float>().setRandom();
@@ -97,7 +98,7 @@ static Graph* BinaryScalar(int num, const string& func) {
}
#define BM_BINARY_SCALAR(DEVICE, FUNC) \
- static void BM_##DEVICE##_##FUNC##_scalar(int iters, int num) { \
+ void BM_##DEVICE##_##FUNC##_scalar(int iters, int num) { \
const int64 tot = static_cast<int64>(iters) * num; \
testing::ItemsProcessed(tot); \
testing::BytesProcessed(tot * sizeof(float)); \
@@ -127,7 +128,7 @@ BM_BINARY_SCALAR(sycl, Add);
#undef BM_BINARY_SCALAR
template <class T>
-static Graph* BiasAdd(int rows, int cols, DataType type) {
+Graph* BiasAdd(int rows, int cols, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
Tensor lhs(type, TensorShape({rows, cols}));
lhs.template flat<T>().setRandom();
@@ -141,8 +142,7 @@ static Graph* BiasAdd(int rows, int cols, DataType type) {
}
#define BM_BIAS_ADD(DEVICE, C_TYPE, TF_TYPE, R, C) \
- static void BM_##DEVICE##_##C_TYPE##_BiasAdd_R##R##_C##C(int iters, \
- int arg) { \
+ void BM_##DEVICE##_##C_TYPE##_BiasAdd_R##R##_C##C(int iters, int arg) { \
const int rows = RowsFromArg(arg); \
const int cols = ColsFromArg(arg); \
const int64 tot = static_cast<int64>(iters) * rows * cols; \
@@ -172,8 +172,8 @@ BM_BIAS_ADD_ALL(gpu, half, DT_HALF);
#undef BM_BIAS_ADD
template <class T>
-static Graph* BiasAddGrad(int rows, int cols, int channels, DataType type,
- TensorFormat format) {
+Graph* BiasAddGrad(int rows, int cols, int channels, DataType type,
+ TensorFormat format) {
Graph* g = new Graph(OpRegistry::Global());
TensorShape lhs_shape;
if (format == FORMAT_NCHW) {
@@ -186,15 +186,14 @@ static Graph* BiasAddGrad(int rows, int cols, int channels, DataType type,
Node* n;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAddGrad")
.Attr("data_format", ToString(format))
- .Input(test::graph::Constant(g, lhs), /*index=*/0)
+ .Input(test::graph::Constant(g, lhs), /*src_index=*/0)
.Finalize(g, &n));
return g;
}
#define BM_BIAS_ADD_GRAD(DEVICE, FMT, C_TYPE, TF_TYPE, R, C, CH) \
- static void \
- BM_##DEVICE##_##FMT##_##C_TYPE##_BiasAddGrad_R##R##_C##C##_CH##CH( \
- int iters, int arg, int channels) { \
+ void BM_##DEVICE##_##FMT##_##C_TYPE##_BiasAddGrad_R##R##_C##C##_CH##CH( \
+ int iters, int arg, int channels) { \
const int rows = RowsFromArg(arg); \
const int cols = ColsFromArg(arg); \
const int64 tot = static_cast<int64>(iters) * rows * cols * channels; \
@@ -230,7 +229,7 @@ BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, half, DT_HALF);
#undef BM_BIAS_ADD_GRAD_ALL
#undef BM_BIAS_ADD_GRAD
-static Graph* BcastAdd(int rows, int cols, int dim) {
+Graph* BcastAdd(int rows, int cols, int dim) {
Graph* g = new Graph(OpRegistry::Global());
Tensor lhs(DT_FLOAT, TensorShape({rows, cols}));
lhs.flat<float>().setRandom();
@@ -247,15 +246,15 @@ static Graph* BcastAdd(int rows, int cols, int dim) {
return g;
}
-#define BM_BCAST_ADD_ROW(DEVICE, R, C) \
- static void BM_##DEVICE##_BcastAddRow_R##R##_C##C(int iters, int arg) { \
- const int rows = RowsFromArg(arg); \
- const int cols = ColsFromArg(arg); \
- const int64 tot = static_cast<int64>(iters) * rows * cols; \
- testing::ItemsProcessed(tot); \
- testing::BytesProcessed(tot * sizeof(float)); \
- test::Benchmark(#DEVICE, BcastAdd(rows, cols, 0)).Run(iters); \
- } \
+#define BM_BCAST_ADD_ROW(DEVICE, R, C) \
+ void BM_##DEVICE##_BcastAddRow_R##R##_C##C(int iters, int arg) { \
+ const int rows = RowsFromArg(arg); \
+ const int cols = ColsFromArg(arg); \
+ const int64 tot = static_cast<int64>(iters) * rows * cols; \
+ testing::ItemsProcessed(tot); \
+ testing::BytesProcessed(tot * sizeof(float)); \
+ test::Benchmark(#DEVICE, BcastAdd(rows, cols, 0)).Run(iters); \
+ } \
BENCHMARK(BM_##DEVICE##_BcastAddRow_R##R##_C##C)->Arg(RowsAndColsArg(R, C));
#define BM_BCAST_ADD_ROW_ALL(DEVICE) \
@@ -273,15 +272,15 @@ BM_BCAST_ADD_ROW_ALL(sycl);
#undef BM_BCAST_ADD_ROW_ALL
#undef BM_BCAST_ADD_ROW
-#define BM_BCAST_ADD_COL(DEVICE, R, C) \
- static void BM_##DEVICE##_BcastAddCol_R##R##_C##C(int iters, int arg) { \
- const int rows = RowsFromArg(arg); \
- const int cols = ColsFromArg(arg); \
- const int64 tot = static_cast<int64>(iters) * rows * cols; \
- testing::ItemsProcessed(tot); \
- testing::BytesProcessed(tot * sizeof(float)); \
- test::Benchmark(#DEVICE, BcastAdd(rows, cols, 1)).Run(iters); \
- } \
+#define BM_BCAST_ADD_COL(DEVICE, R, C) \
+ void BM_##DEVICE##_BcastAddCol_R##R##_C##C(int iters, int arg) { \
+ const int rows = RowsFromArg(arg); \
+ const int cols = ColsFromArg(arg); \
+ const int64 tot = static_cast<int64>(iters) * rows * cols; \
+ testing::ItemsProcessed(tot); \
+ testing::BytesProcessed(tot * sizeof(float)); \
+ test::Benchmark(#DEVICE, BcastAdd(rows, cols, 1)).Run(iters); \
+ } \
BENCHMARK(BM_##DEVICE##_BcastAddCol_R##R##_C##C)->Arg(RowsAndColsArg(R, C));
#define BM_BCAST_ADD_COL_ALL(DEVICE) \
@@ -299,4 +298,5 @@ BM_BCAST_ADD_COL_ALL(sycl);
#undef BM_BCAST_ADD_COL_ALL
#undef BM_BCAST_ADD_COL
-} // end namespace tensorflow
+} // namespace
+} // namespace tensorflow
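
The hunks above drop file-level `static` in favor of an unnamed namespace for internal linkage, and make the file-local constant `const`. A minimal sketch of that pattern, mirroring the declarations in cwise_ops_test.cc, is:

// Before: internal linkage via `static` on each declaration.
//   static int kRows = 100000;
//   static int RowsFromArg(int arg) { return arg / kRows; }

// After: one unnamed namespace gives every enclosed symbol internal linkage,
// and the constant is declared const since it is never reassigned.
namespace {

const int kRows = 100000;

int RowsFromArg(int arg) { return arg / kRows; }

}  // namespace
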
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index d109543494..0bbb52bc32 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -19,7 +19,10 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse_matmul_op.h"
+#include <map>
+#include <memory>
#include <vector>
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bfloat16.h"
@@ -42,26 +45,27 @@ limitations under the License.
#endif
namespace tensorflow {
-
namespace {
using Eigen::operator==;
-typedef Eigen::Tensor<float, 2, Eigen::RowMajor> Matrix;
-typedef Eigen::DSizes<Eigen::DenseIndex, 2> DSizes;
-typedef Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor>,
- Eigen::Aligned>
- ConstMatrixMap;
-typedef Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>,
- Eigen::Aligned>
- MatrixMap;
-typedef Eigen::ThreadPoolDevice CPUDevice;
+template <typename T>
+using BasicMatrix = Eigen::Tensor<T, 2, Eigen::RowMajor>;
+
+template <typename T>
+using BasicMatrixMap =
+ Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>;
+
+using Matrix = BasicMatrix<float>;
+using MatrixMap = BasicMatrixMap<float>;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using DSizes = Eigen::DSizes<Eigen::DenseIndex, 2>;
// Blocksizes
// TODO(agarwal): compute these sizes based on cache sizes.
-static const int K = 64;
-static const int M = 64;
-static const int N = 128;
+const int K = 64;
+const int M = 64;
+const int N = 128;
// This stores a sparse representation of a slice of a matrix with size
// (num_rows, num_cols). The slice is represented as a series of blocks of size
@@ -86,9 +90,7 @@ static const int N = 128;
// index_offset.
template <typename T>
struct SparseSlice {
- typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
- Eigen::Aligned>
- ConstMatrixMap;
+ using ConstMatrixMap = BasicMatrixMap<const T>;
public:
// Indices of three elements on the same row.
@@ -244,8 +246,8 @@ void SparseSlice<T>::Clear() {
data.clear();
}
-typedef Eigen::internal::packet_traits<float>::type Packet;
-static const int kNumOperands = (sizeof(Packet) / sizeof(float));
+using Packet = Eigen::internal::packet_traits<float>::type;
+const int kNumOperands = (sizeof(Packet) / sizeof(float));
#define LOAD(x) Eigen::internal::pload<Packet>(x);
#define EXPAND_BFLOAT_L(x, y) \
const auto y = Eigen::internal::pexpand_bf16_l<Packet>(x);
@@ -752,17 +754,11 @@ inline void GEPP(
template <typename TL, typename TR>
class SparseMatMul {
- typedef Eigen::Tensor<TL, 2, Eigen::RowMajor> MatrixL;
- typedef Eigen::Tensor<TR, 2, Eigen::RowMajor> MatrixR;
- typedef Eigen::TensorMap<Eigen::Tensor<const TL, 2, Eigen::RowMajor>,
- Eigen::Aligned>
- ConstMatrixMapL;
- typedef Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
- Eigen::Aligned>
- ConstMatrixMapR;
- typedef Eigen::TensorMap<Eigen::Tensor<TR, 2, Eigen::RowMajor>,
- Eigen::Aligned>
- MatrixMapR;
+ using MatrixL = BasicMatrix<TL>;
+ using MatrixR = BasicMatrix<TR>;
+ using ConstMatrixMapL = BasicMatrixMap<const TL>;
+ using ConstMatrixMapR = BasicMatrixMap<const TR>;
+ using MatrixMapR = BasicMatrixMap<TR>;
public:
// Not used; added to match interface of LibxsmmSparseMatMul
@@ -792,7 +788,7 @@ class SparseMatMul {
// "slice_num_cols", each grid element is converted into a SparseSlice and
// stored in mat_slices. "slice_block_size" is used to perform further column
// blocking of each slice.
- static inline BlockingCounter* CreateSparseSlices(
+ static inline std::unique_ptr<BlockingCounter> CreateSparseSlices(
const ConstMatrixMapL& mat, bool transpose, int slice_num_rows,
int slice_block_size, int slice_num_cols,
std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
@@ -802,7 +798,7 @@ class SparseMatMul {
// columns, and concatenates the pieces one after the other in "buffer". It
// returns the list of the pieces in "slices". It returns a BlockingCounter
// which should be used to wait for the shuffle operations to complete.
- static inline BlockingCounter* CreateDenseSlices(
+ static inline std::unique_ptr<BlockingCounter> CreateDenseSlices(
const ConstMatrixMapR& mat, int row_start, int num_rows, int col_start,
int num_cols, const DeviceBase::CpuWorkerThreads* thread_pool,
MatrixR* buffer, std::vector<ConstMatrixMapR*>* slices);
@@ -839,17 +835,11 @@ class SparseMatMul {
#ifdef TENSORFLOW_USE_LIBXSMM
template <typename TL, typename TR>
class LibxsmmSparseMatMul {
- typedef Eigen::Tensor<TL, 2, Eigen::RowMajor> MatrixL;
- typedef Eigen::Tensor<TR, 2, Eigen::RowMajor> MatrixR;
- typedef Eigen::TensorMap<Eigen::Tensor<const TL, 2, Eigen::RowMajor>,
- Eigen::Aligned>
- ConstMatrixMapL;
- typedef Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
- Eigen::Aligned>
- ConstMatrixMapR;
- typedef Eigen::TensorMap<Eigen::Tensor<TR, 2, Eigen::RowMajor>,
- Eigen::Aligned>
- MatrixMapR;
+ using MatrixL = BasicMatrix<TL>;
+ using MatrixR = BasicMatrix<TR>;
+ using ConstMatrixMapL = BasicMatrixMap<const TL>;
+ using ConstMatrixMapR = BasicMatrixMap<const TR>;
+ using MatrixMapR = BasicMatrixMap<TR>;
public:
// This structure contains a set of libxsmm kernels for sizes that have been
@@ -939,10 +929,8 @@ class LibxsmmSparseMatMul {
template <typename TL, typename TR,
template <typename TL2, typename TR2> class DoMatMul>
class SparseMatMulOp : public OpKernel {
- typedef Eigen::Tensor<TR, 2, Eigen::RowMajor> MatrixR;
- typedef Eigen::TensorMap<Eigen::Tensor<const TR, 2, Eigen::RowMajor>,
- Eigen::Aligned>
- ConstMatrixMapR;
+ using MatrixR = BasicMatrix<TR>;
+ using ConstMatrixMapR = BasicMatrixMap<const TR>;
public:
explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -1122,7 +1110,8 @@ inline void SparseMatMul<TL, TR>::ComputeOutputBlock(
}
template <typename TL, typename TR>
-inline BlockingCounter* SparseMatMul<TL, TR>::CreateSparseSlices(
+inline std::unique_ptr<BlockingCounter>
+SparseMatMul<TL, TR>::CreateSparseSlices(
const typename SparseMatMul<TL, TR>::ConstMatrixMapL& mat, bool transpose,
int slice_num_rows, int slice_block_size, int slice_num_cols,
std::vector<std::vector<SparseSlice<TL>*>>* mat_slices,
@@ -1170,7 +1159,7 @@ inline BlockingCounter* SparseMatMul<TL, TR>::CreateSparseSlices(
[=]() { work(sparse_slice, slice, slice_num_cols * j); });
}
}
- return counter;
+ return std::unique_ptr<BlockingCounter>(counter);
}
#define LOAD(x) Eigen::internal::ploadu<Packet>((x));
#define INTERLEAVE(x) Eigen::internal::pinterleave4x64<Packet>(x);
@@ -1286,13 +1275,13 @@ inline void SparseMatMul<TL, TR>::SliceMatrix(
}
template <typename TL, typename TR>
-inline BlockingCounter* SparseMatMul<TL, TR>::CreateDenseSlices(
+inline std::unique_ptr<BlockingCounter> SparseMatMul<TL, TR>::CreateDenseSlices(
const typename SparseMatMul<TL, TR>::ConstMatrixMapR& mat, int row_start,
int num_rows, int col_start, int num_cols,
const DeviceBase::CpuWorkerThreads* thread_pool, MatrixR* buffer,
std::vector<typename SparseMatMul<TL, TR>::ConstMatrixMapR*>* slices) {
- BlockingCounter* shuffle_counter = ShuffleMatrix(
- mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer);
+ std::unique_ptr<BlockingCounter> shuffle_counter(ShuffleMatrix(
+ mat, row_start, num_rows, col_start, num_cols, N, thread_pool, buffer));
const int num_slices = (num_cols + N - 1) / N;
SliceMatrix(*buffer, num_rows, num_slices, slices);
return shuffle_counter;
@@ -1554,10 +1543,9 @@ inline void SparseMatMul<TL, TR>::Compute(
&JB, &IB);
// Slice the left matrix
std::vector<std::vector<SparseSlice<TL>*>> left_slices;
- std::unique_ptr<BlockingCounter> sparse_slice_counter;
- sparse_slice_counter.reset(
+ std::unique_ptr<BlockingCounter> sparse_slice_counter =
CreateSparseSlices(ConstMatrixMapL(left.data(), left.dimensions()),
- transpose_left, M, K, KL, &left_slices, thread_pool));
+ transpose_left, M, K, KL, &left_slices, thread_pool);
const int num_left_slices = left_slices.size();
const int right_dim0 = right.dimension(0);
@@ -1583,9 +1571,9 @@ inline void SparseMatMul<TL, TR>::Compute(
for (int kb = 0; kb < num_k_blocks; ++kb) {
const int right_num_rows =
std::min(KR, static_cast<int>(right_dim0 - KR * kb));
- dense_slice_counter.reset(CreateDenseSlices(
+ dense_slice_counter = CreateDenseSlices(
right, kb * KR, right_num_rows, nb * NR, right_num_cols, thread_pool,
- &buffer, &right_slices));
+ &buffer, &right_slices);
const int num_right_slices = right_slices.size();
tasks.reserve(num_left_slices * num_right_slices);
for (int j_outer = 0; j_outer < num_right_slices; j_outer += JB) {
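
The CreateSparseSlices/CreateDenseSlices hunks above change owning raw BlockingCounter* returns into std::unique_ptr<BlockingCounter>, so call sites assign the result directly instead of calling reset(). A hedged sketch of that ownership change, using a stand-in type rather than the real tensorflow::BlockingCounter, is:

#include <memory>

// Stand-in for a blocking counter; only the shape of the API matters here.
class BlockingCounter {
 public:
  explicit BlockingCounter(int n) : n_(n) {}
  void Wait() { /* block until the count reaches zero (elided in this sketch) */ }

 private:
  int n_;
};

// Before: the factory returned an owning raw pointer and callers wrote
//   counter.reset(CreateSlices(...));
// After: ownership lives in the return type and callers assign directly.
std::unique_ptr<BlockingCounter> CreateSlices(int num_tasks) {
  BlockingCounter* counter = new BlockingCounter(num_tasks);
  // ... schedule num_tasks closures that each decrement the counter ...
  return std::unique_ptr<BlockingCounter>(counter);
}

int main() {
  std::unique_ptr<BlockingCounter> counter = CreateSlices(4);
  counter->Wait();
  return 0;
}
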
diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc
index 49352ff4d1..1717428adf 100644
--- a/tensorflow/core/kernels/stage_op.cc
+++ b/tensorflow/core/kernels/stage_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cstddef>
#include <deque>
#include <mutex>
#include <numeric>
@@ -27,13 +28,12 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
-
namespace {
class Buffer : public ResourceBase {
public:
// public types
- typedef std::vector<Tensor> Tuple;
+ using Tuple = std::vector<Tensor>;
private:
// private variables
@@ -45,32 +45,25 @@ class Buffer : public ResourceBase {
std::condition_variable full_cond_var_;
std::deque<Tuple> buf_;
-
private:
// private methods
// If the buffer is configured for bounded capacity, notify
// waiting inserters that space is now available
- void notify_inserters_if_bounded(std::unique_lock<std::mutex> & l)
- {
- if(IsBounded())
- {
- l.unlock();
+ void notify_inserters_if_bounded(std::unique_lock<std::mutex>* lock) {
+ if (IsBounded()) {
+ lock->unlock();
full_cond_var_.notify_one();
}
}
// Are there a limit number of elements or a memory limit
// configued on this buffer?
- bool IsBounded() {
- return capacity_ > 0 || memory_limit_ > 0;
- }
+ bool IsBounded() const { return capacity_ > 0 || memory_limit_ > 0; }
- bool IsCapacityFull() {
- return buf_.size() >= capacity_;
- }
+ bool IsCapacityFull() const { return buf_.size() >= capacity_; }
- bool WouldExceedMemoryLimit(std::size_t bytes) {
+ bool WouldExceedMemoryLimit(std::size_t bytes) const {
return bytes + current_bytes_ > memory_limit_;
}
@@ -84,14 +77,12 @@ class Buffer : public ResourceBase {
public:
// public methods
- explicit Buffer(std::size_t capacity, std::size_t memory_limit) :
- capacity_(capacity),
- memory_limit_(memory_limit),
- current_bytes_(0) {}
+ explicit Buffer(std::size_t capacity, std::size_t memory_limit)
+ : capacity_(capacity), memory_limit_(memory_limit), current_bytes_(0) {}
// the Buffer takes ownership of the Tuple
Status Put(Tuple* tuple) {
- std::unique_lock<std::mutex> l(mu_);
+ std::unique_lock<std::mutex> lock(mu_);
std::size_t tuple_bytes = GetTupleBytes(*tuple);
@@ -105,7 +96,7 @@ class Buffer : public ResourceBase {
// If buffer capacity is bounded wait until elements have been removed
if(IsBounded()) {
- full_cond_var_.wait(l, [tuple_bytes, this]() {
+ full_cond_var_.wait(lock, [tuple_bytes, this]() {
// If there's a memory limit, check if there's space for insertion
bool memory_limit_valid = memory_limit_ > 0 ?
!WouldExceedMemoryLimit(tuple_bytes) : true;
@@ -123,7 +114,7 @@ class Buffer : public ResourceBase {
// Store tuple
buf_.push_back(std::move(*tuple));
- l.unlock();
+ lock.unlock();
// maybe possible to optimize by reducing
// how often this signal is sent
non_empty_cond_var_.notify_one();
@@ -133,12 +124,10 @@ class Buffer : public ResourceBase {
// Get tuple at front of the buffer
void Get(Tuple* tuple) { // TODO(zhifengc): Support cancellation.
- std::unique_lock<std::mutex> l(mu_);
+ std::unique_lock<std::mutex> lock(mu_);
// Wait for data if the buffer is empty
- non_empty_cond_var_.wait(l, [this]() {
- return !buf_.empty();
- });
+ non_empty_cond_var_.wait(lock, [this]() { return !buf_.empty(); });
// Move data into the output tuple
*tuple = std::move(buf_.front());
@@ -147,20 +136,19 @@ class Buffer : public ResourceBase {
// Update bytes in the Staging Area
current_bytes_ -= GetTupleBytes(*tuple);
- notify_inserters_if_bounded(l);
+ notify_inserters_if_bounded(&lock);
}
// Return tuple at index
Status Peek(std::size_t index, Tuple* tuple) {
- std::unique_lock<std::mutex> l(mu_);
+ std::unique_lock<std::mutex> lock(mu_);
// Wait if the requested index is not available
- non_empty_cond_var_.wait(l, [index, this]() {
- return index < this->buf_.size();
- });
+ non_empty_cond_var_.wait(
+ lock, [index, this]() { return index < this->buf_.size(); });
// Place tensors in the output tuple
- for(const auto & tensor: buf_[index]) {
+ for (const auto& tensor : buf_[index]) {
tuple->push_back(tensor);
}
@@ -169,23 +157,22 @@ class Buffer : public ResourceBase {
// Buffer size
size_t Size() {
- std::unique_lock<std::mutex> l(mu_);
+ std::unique_lock<std::mutex> lock(mu_);
return buf_.size();
}
void Clear() {
- std::unique_lock<std::mutex> l(mu_);
+ std::unique_lock<std::mutex> lock(mu_);
buf_.clear();
current_bytes_ = 0;
- notify_inserters_if_bounded(l);
+ notify_inserters_if_bounded(&lock);
}
string DebugString() override {
- std::unique_lock<std::mutex> l(mu_);
+ std::unique_lock<std::mutex> lock(mu_);
return strings::StrCat("Staging size: ", buf_.size());
}
-
};
Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, Buffer** buf) {