aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-17 14:33:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-17 14:38:38 -0700
commit02f87fee25552e220c8295b58ab8e58b6fbe598b (patch)
tree964569d7ef4a8e369b36a7c7a1852015b67a2e93
parent4143410e1140a553621de5de09c1cad12a5eb4cb (diff)
CPU runtime: Improve the performance of matrix-vector and
vector-matrix products. This change makes the single threaded matrix-vector product explicit so that Eigen will always delegate to an optimized GEMV kernel. This is done by using an Eigen Matrix instead of the Eigen Tensor implementation. This is the same optimization done by TensorFlow's matmul op for GEMV. This is used even in the multi-threaded case because it appears to be faster than the multi-threaded version. This change also expands the scope of the CPU runtime test to test vec-mat and mat-vec on both single threaded and multi threaded modes. PiperOrigin-RevId: 165630063
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD15
-rw-r--r--tensorflow/compiler/xla/service/cpu/build_defs.bzl28
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc111
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_matmul.cc19
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_matvec.cc110
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_matvec.h45
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc17
7 files changed, 292 insertions, 53 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 54d0e1cf01..85fb1bcd11 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -14,7 +14,7 @@ package_group(
],
)
-load(":build_defs.bzl", "runtime_copts")
+load(":build_defs.bzl", "runtime_copts", "runtime_logging_deps")
# Filegroup used to collect source files for dependency checking.
filegroup(
@@ -382,12 +382,23 @@ cc_library(
)
cc_library(
+ name = "runtime_matvec",
+ srcs = ["runtime_matvec.cc"],
+ hdrs = ["runtime_matvec.h"],
+ copts = runtime_copts(),
+ deps = [
+ "//third_party/eigen3",
+ ] + runtime_logging_deps(),
+)
+
+cc_library(
name = "runtime_matmul",
srcs = ["runtime_matmul.cc"],
hdrs = ["runtime_matmul.h"],
copts = runtime_copts(),
visibility = ["//visibility:public"],
deps = [
+ ":runtime_matvec",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
@@ -417,6 +428,7 @@ cc_library(
copts = runtime_copts(),
visibility = ["//visibility:public"],
deps = [
+ ":runtime_matvec",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
@@ -428,6 +440,7 @@ cc_test(
deps = [
":cpu_runtime",
":runtime_matmul",
+ ":runtime_single_threaded_matmul",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
diff --git a/tensorflow/compiler/xla/service/cpu/build_defs.bzl b/tensorflow/compiler/xla/service/cpu/build_defs.bzl
index b4b5219751..1440e2c9a8 100644
--- a/tensorflow/compiler/xla/service/cpu/build_defs.bzl
+++ b/tensorflow/compiler/xla/service/cpu/build_defs.bzl
@@ -1,11 +1,25 @@
"""build_defs for service/cpu."""
+
def runtime_copts():
"""Returns copts used for CPU runtime libraries."""
- return (["-DEIGEN_AVOID_STL_ARRAY"] +
- select({
- "//tensorflow:android_arm": ["-mfpu=neon"],
- "//conditions:default": []}) +
- select({
- "//tensorflow:android": ["-O2"],
- "//conditions:default": []}))
+ return (["-DEIGEN_AVOID_STL_ARRAY"] + select({
+ "//tensorflow:android_arm": ["-mfpu=neon"],
+ "//conditions:default": []
+ }) + select({
+ "//tensorflow:android": ["-O2"],
+ "//conditions:default": []
+ }))
+
+
+def runtime_logging_deps():
+ """Returns deps for building CPU runtime libraries with logging functions."""
+ return select({
+ "//tensorflow:android": [
+ # This dependency is smaller than :android_tensorflow_lib
+ "//tensorflow/core:android_tensorflow_lib_selective_registration",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:lib",
+ ],
+ })
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
index 52eed7dbad..f8e260dd90 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include <tuple>
#define EIGEN_USE_THREADS
@@ -25,8 +26,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -75,14 +78,8 @@ void CheckMatrixMultiply(const Array2D<float>& a, const Array2D<float>& b,
std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
const Array2D<float>& b,
bool transpose_lhs,
- bool transpose_rhs) {
- tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
- 2);
- tensorflow::EigenThreadPoolWrapper tp(&pool);
- Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
- ExecutableRunOptions run_options;
- run_options.set_intra_op_thread_pool(&device);
-
+ bool transpose_rhs,
+ bool single_threaded) {
CHECK_EQ(a.width(), b.height());
int64 m = a.height();
int64 n = b.width();
@@ -98,41 +95,81 @@ std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
// Since we're going to transpose c before returning it. Swap the order of the
// dimension sizes to ensure the returned array is properly dimensioned.
auto c_transpose = MakeUnique<Array2D<float>>(n, m);
- __xla_cpu_runtime_EigenMatMulF32(&run_options, c_transpose->data(),
- a_transpose->data(), b_transpose->data(), m,
- n, k, transpose_lhs, transpose_rhs);
+ if (single_threaded) {
+ __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
+ nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
+ m, n, k, transpose_lhs, transpose_rhs);
+ } else {
+ tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
+ 2);
+ tensorflow::EigenThreadPoolWrapper tp(&pool);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+ ExecutableRunOptions run_options;
+ run_options.set_intra_op_thread_pool(&device);
+
+ __xla_cpu_runtime_EigenMatMulF32(&run_options, c_transpose->data(),
+ a_transpose->data(), b_transpose->data(),
+ m, n, k, transpose_lhs, transpose_rhs);
+ }
return MaybeTransposeArray2D(*c_transpose, true);
}
-TEST_F(CpuRuntimeTest, SmallEigenMatmul) {
- Array2D<float> a({{1.0f, 2.0f}, {3.0f, 4.0f}});
- Array2D<float> b({{5.0f, -1.0f, 3.0f}, {2.0f, 6.0f, 4.0f}});
-
- for (bool transpose_lhs : {false, true}) {
- for (bool transpose_rhs : {false, true}) {
- auto c = EigenMatrixMultiply(a, b, transpose_lhs, transpose_rhs);
-
- LOG(INFO) << "a = " << a.ToString();
- LOG(INFO) << "b = " << b.ToString();
- LOG(INFO) << "c = " << c->ToString();
-
- CheckMatrixMultiply(a, b, *c);
- }
+struct MatMulShape {
+ int64 m;
+ int64 k;
+ int64 n;
+};
+
+MatMulShape MatMulShapes[] = {
+ MatMulShape{2, 2, 3}, MatMulShape{256, 512, 1024},
+ MatMulShape{128, 128, 1}, MatMulShape{1, 128, 128},
+ MatMulShape{1, 32, 128}, MatMulShape{1, 32, 16},
+ MatMulShape{32, 16, 1}, MatMulShape{32, 128, 1},
+};
+
+// This takes 4 parameters:
+// * shape of the matmul
+// * transpose_lhs
+// * transpose_rhs
+// * single_threaded
+using EigenMatMulTestParam = std::tuple<MatMulShape, bool, bool, bool>;
+
+class EigenMatMulTest
+ : public CpuRuntimeTest,
+ public ::testing::WithParamInterface<EigenMatMulTestParam> {
+ public:
+ static string Name(
+ const ::testing::TestParamInfo<EigenMatMulTestParam>& info) {
+ MatMulShape shape = std::get<0>(info.param);
+ bool transpose_lhs = std::get<1>(info.param);
+ bool transpose_rhs = std::get<2>(info.param);
+ bool single_threaded = std::get<3>(info.param);
+
+ return tensorflow::strings::Printf(
+ "MatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
+ transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
+ single_threaded ? "single" : "multi");
}
+}; // namespace xla
+
+TEST_P(EigenMatMulTest, DoIt) {
+ MatMulShape shape = std::get<0>(GetParam());
+ bool transpose_lhs = std::get<1>(GetParam());
+ bool transpose_rhs = std::get<2>(GetParam());
+ bool single_threaded = std::get<3>(GetParam());
+
+ auto a = MakeLinspaceArray2D(0.0, 1.0, shape.m, shape.k);
+ auto b = MakeLinspaceArray2D(-2.0, 2.0, shape.k, shape.n);
+ auto c = EigenMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs,
+ single_threaded);
+ CheckMatrixMultiply(*a, *b, *c);
}
-TEST_F(CpuRuntimeTest, LargeEigenMatmul) {
- auto a = MakeLinspaceArray2D(0.0, 1.0, 256, 512);
- auto b = MakeLinspaceArray2D(-2.0, 2.0, 512, 1024);
-
- for (bool transpose_lhs : {false, true}) {
- for (bool transpose_rhs : {false, true}) {
- auto c = EigenMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs);
-
- CheckMatrixMultiply(*a, *b, *c);
- }
- }
-}
+INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest,
+ ::testing::Combine(::testing::ValuesIn(MatMulShapes),
+ ::testing::Bool(), ::testing::Bool(),
+ ::testing::Bool()),
+ EigenMatMulTest::Name);
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc
index ee772f5c39..bff57d33ae 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::int32;
@@ -68,14 +69,24 @@ void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out,
float* lhs, float* rhs, int64 m, int64 n,
int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
- MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
- transpose_rhs);
+ if (m == 1 || n == 1) {
+ // Despite being single threaded, this version of matrix * vector is faster.
+ xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+ } else {
+ MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+ transpose_rhs);
+ }
}
void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out,
double* lhs, double* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs,
int32 transpose_rhs) {
- MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
- transpose_rhs);
+ if (m == 1 || n == 1) {
+ // Despite being single threaded, this version of matrix * vector is faster.
+ xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+ } else {
+ MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+ transpose_rhs);
+ }
}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc b/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc
new file mode 100644
index 0000000000..a26553bd39
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc
@@ -0,0 +1,110 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::int32;
+using tensorflow::int64;
+
+namespace {
+
+// Does mat * x or mat^T * x.
+template <typename T>
+void MatVec(T* out_buf, T* mat_buf, T* x_buf, int64 rows, int64 cols,
+ int32 transpose) {
+ // Use an Eigen Matrix instead of a Tensor, as the GEMV from Matrix seems to
+ // be faster (b/30223679). See also: the matmul op kernel in TensorFlow,
+ // which implements the same optimization.
+ using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
+ using MatrixMap = Eigen::Map<Matrix>;
+
+ using Vector = Eigen::Matrix<T, Eigen::Dynamic, 1>;
+ using VectorMap = Eigen::Map<Vector>;
+
+ auto x = VectorMap(x_buf, cols);
+ auto out = VectorMap(out_buf, rows);
+
+ int64 mat_rows = rows;
+ int64 mat_cols = cols;
+
+ if (transpose) {
+ std::swap(mat_rows, mat_cols);
+ }
+
+ auto mat = MatrixMap(mat_buf, mat_rows, mat_cols);
+
+ if (transpose) {
+ out = mat.transpose() * x;
+ } else {
+ out = mat * x;
+ }
+}
+
+// Converts matmul-style args to matvec.
+template <typename T>
+void DispatchMatVec(T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k,
+ int32 transpose_lhs, int32 transpose_rhs) {
+ // If the input is in the form x * A, where x is the vector, then bring A back
+ // over to the left hand side. We make use of the identity
+ //
+ // (x * A)^T = A^T * x^T
+ //
+ // We do not need to take the transpose of x or of the result since taking
+ // the transpose of a vector does not change the memory layout.
+ const int64 cols = k;
+
+ T* mat;
+ T* vec;
+ int64 rows;
+ bool transpose_mat;
+
+ bool is_mat_vec = (n == 1);
+
+ if (is_mat_vec) {
+ mat = lhs;
+ vec = rhs;
+ rows = m;
+ transpose_mat = transpose_lhs;
+ } else {
+ mat = rhs;
+ vec = lhs;
+ rows = n;
+ transpose_mat = !transpose_rhs;
+ }
+
+ MatVec<T>(out, mat, vec, rows, cols, transpose_mat);
+}
+
+} // namespace
+
+namespace xla {
+
+void EigenMatVecF32(float* out, float* lhs, float* rhs, int64 m, int64 n,
+ int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+ DCHECK(m == 1 || n == 1) << "not a matrix-vector multiply";
+ DispatchMatVec<float>(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+}
+
+void EigenMatVecF64(double* out, double* lhs, double* rhs, int64 m, int64 n,
+ int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+ DCHECK(m == 1 || n == 1) << "not a matrix-vector multiply";
+ DispatchMatVec<double>(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h
new file mode 100644
index 0000000000..cb7e0a81f0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
+
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Performs a matrix-vector multiplication using Eigen. 'lhs' and 'rhs' are
+// pointers to buffers containing input matrices in column-major order. 'out' is
+// a pointer to a buffer sufficiently large to hold the result of the
+// operation. Following standard nomenclature: lhs is m x k, rhs is k x n, and
+// out is m x n.
+//
+// This requires that m = 1 or n = 1.
+//
+// TODO(b/64684907): Compare runtime performance of these functions with dot
+// simplification.
+void EigenMatVecF32(float* out, float* lhs, float* rhs, tensorflow::int64 m,
+ tensorflow::int64 n, tensorflow::int64 k,
+ tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs);
+
+void EigenMatVecF64(double* out, double* lhs, double* rhs, tensorflow::int64 m,
+ tensorflow::int64 n, tensorflow::int64 k,
+ tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs);
+
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc
index 6f1c97a233..ee8eb08155 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::int32;
@@ -61,13 +62,21 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
void __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
- MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
- transpose_rhs);
+ if (m == 1 || n == 1) {
+ xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+ } else {
+ MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+ transpose_rhs);
+ }
}
void __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
- MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
- transpose_rhs);
+ if (m == 1 || n == 1) {
+ xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+ } else {
+ MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+ transpose_rhs);
+ }
}