aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rwxr-xr-xconfigure2
-rw-r--r--tensorflow/contrib/android/jni/tensorflow_inference_jni.cc2
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/special_math_test.py4
-rw-r--r--tensorflow/contrib/cmake/tf_core_kernels.cmake2
-rw-r--r--tensorflow/contrib/cmake/tf_python.cmake14
-rw-r--r--tensorflow/contrib/cmake/tf_tests.cmake6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py2
-rw-r--r--tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py2
-rw-r--r--tensorflow/contrib/ios_examples/README.md22
-rw-r--r--tensorflow/contrib/layers/BUILD2
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py4
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py2
-rw-r--r--tensorflow/core/framework/tensor.h9
-rw-r--r--tensorflow/core/kernels/BUILD32
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc4
-rw-r--r--tensorflow/core/kernels/conv_ops.cc12
-rw-r--r--tensorflow/core/kernels/linalg_ops_common.cc4
-rw-r--r--tensorflow/core/kernels/sparse_matmul_op.cc4
-rw-r--r--tensorflow/core/kernels/sparse_tensors_map_ops.cc2
-rw-r--r--tensorflow/core/kernels/stage_op.cc4
-rw-r--r--tensorflow/core/kernels/xsmm_conv2d.cc220
-rw-r--r--tensorflow/core/kernels/xsmm_conv2d.h5
-rw-r--r--tensorflow/core/kernels/xsmm_conv2d_test.cc328
-rw-r--r--tensorflow/core/lib/strings/proto_text_util.h2
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py4
-rw-r--r--tensorflow/g3doc/get_started/os_setup.md20
-rw-r--r--tensorflow/g3doc/tutorials/deep_cnn/index.md6
-rw-r--r--tensorflow/python/ops/gradients_impl.py55
-rw-r--r--tensorflow/python/ops/math_ops_test.py2
-rw-r--r--tensorflow/python/ops/sparse_ops.py5
-rw-r--r--tensorflow/python/training/basic_session_run_hooks_test.py26
-rw-r--r--tensorflow/stream_executor/kernel.h4
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu2
-rw-r--r--tensorflow/workspace.bzl4
-rw-r--r--third_party/libxsmm.BUILD3
35 files changed, 722 insertions, 99 deletions
diff --git a/configure b/configure
index c755ee1b75..a8e7bb7738 100755
--- a/configure
+++ b/configure
@@ -28,7 +28,7 @@ function is_macos() {
function is_windows() {
# On windows, the shell script is actually running in msys
- if [[ "${PLATFORM}" =~ msys_nt* ]]; then
+ if [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]]; then
true
else
false
diff --git a/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc b/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
index 844379232a..0a5d10e5c2 100644
--- a/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
+++ b/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
@@ -49,7 +49,7 @@ typedef std::map<std::string, std::pair<std::string, tensorflow::Tensor> >
struct SessionVariables {
std::unique_ptr<tensorflow::Session> session;
- long id = -1; // Copied from Java field for convenience.
+ int64 id = -1; // Copied from Java field for convenience.
int num_runs = 0;
int64 timing_total_us = 0;
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/special_math_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/special_math_test.py
index 615ef798dc..4c9c870894 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/special_math_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/special_math_test.py
@@ -211,8 +211,8 @@ class NdtrGradientTest(test.TestCase):
if self._use_log:
g = np.reshape(grad_eval, [-1])
half = np.ceil(len(g) / 2)
- self.assert_all_true(g[:half] > 0.)
- self.assert_all_true(g[half:] >= 0.)
+ self.assert_all_true(g[:int(half)] > 0.)
+ self.assert_all_true(g[int(half):] >= 0.)
else:
# The ndtr gradient will only be non-zero in the range [-14, 14] for
# float32 and [-38, 38] for float64.
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index 911e52604e..45126cc071 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -41,7 +41,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc"
- "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest"
+ "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/best_splits_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc"
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 93c0d028ea..7717cf7b71 100644
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -649,6 +649,20 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_BINARY_DIR}/tensorboard_external
${CMAKE_CURRENT_BINARY_DIR}/tf_python/external)
+# Copy datasets for tf.contrib.learn.
+add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/datasets/data/boston_house_prices.csv
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/learn/python/learn/datasets/data/)
+add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/datasets/data/iris.csv
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/learn/python/learn/datasets/data/)
+add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/datasets/data/text_test.csv
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/learn/python/learn/datasets/data/)
+add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/datasets/data/text_train.csv
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/learn/python/learn/datasets/data/)
+
if(${tensorflow_ENABLE_GPU})
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tensorflow_gpu
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 66c13c435f..356ee14ef0 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -120,6 +120,9 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/saved_model/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/training/*_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/*_test.py"
+ # NOTE: tensor_forest tests in tensor_forest/hybrid/... still don't pass.
+ "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/client/*_test.py"
+ "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/*_test.py"
)
# exclude the onces we don't want
@@ -163,6 +166,9 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# Broken TensorBoard tests due to different paths in windows
"${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py"
"${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_test.py"
+ # tensor_forest tests (also note that we exclude the hybrid tests for now)
+ "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
+ "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.
)
endif()
list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
index ff2e575b06..e85b678ccf 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py
@@ -1301,7 +1301,7 @@ class AffineBijectorTest(test.TestCase):
def _matrix_diag(self, d):
"""Batch version of np.diag."""
orig_shape = d.shape
- d = np.reshape(d, (np.prod(d.shape[:-1]), d.shape[-1]))
+ d = np.reshape(d, (int(np.prod(d.shape[:-1])), d.shape[-1]))
diag_list = []
for i in range(d.shape[0]):
diag_list.append(np.diag(d[i, ...]))
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py
index 7aa58b8021..f6d035a2c6 100644
--- a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py
+++ b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py
@@ -125,7 +125,7 @@ class NearestCentersLargeTest(test.TestCase):
# Tile points and expected results to reach requested size (num_points)
(self._points, self._expected_nearest_neighbor_indices,
self._expected_nearest_neighbor_squared_distances) = (
- np.tile(x, (num_points / points_per_tile, 1))
+ np.tile(x, (int(num_points / points_per_tile), 1))
for x in (points, expected_nearest_neighbor_indices,
expected_nearest_neighbor_squared_distances))
diff --git a/tensorflow/contrib/ios_examples/README.md b/tensorflow/contrib/ios_examples/README.md
index 00c13d9c7e..6bac33c0ec 100644
--- a/tensorflow/contrib/ios_examples/README.md
+++ b/tensorflow/contrib/ios_examples/README.md
@@ -31,27 +31,27 @@ cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/simple/data/
- You should see a single-screen app with a "Run Model" button. Tap that, and
you should see some debug output appear below indicating that the example
Grace Hopper image has been analyzed, with a military uniform recognized.
-
+
- Once you have success there, make sure you have a real device connected and
- open up the Xcode project in the camera subfolder. Once you build and run
+ open up the Xcode project in the `camera` subfolder. Once you build and run
that, you should get a live camera view that you can point at objects to get
real-time recognition results.
-
+
## Troubleshooting
If you're hitting problems, here's a checklist of common things to investigate:
- - Make sure that you've run the `build_all_ios.sh` script
+ - Make sure that you've run the `build_all_ios.sh` script.
This will run `download_dependencies.sh`,`compile_ios_protobuf.sh` and `compile_ios_tensorflow.sh`.
(check each one if they have run successful.)
-
+
- Check that you have version 7.3 of Xcode.
-
+
- If there's a complaint about no Sessions registered, that means that the C++
global constructors that TensorFlow relies on for registration haven't been
linked in properly. You'll have to make sure your project uses force_load, as
described below.
-
+
## Creating your Own App
You'll need to update various settings in your app to link against
@@ -62,11 +62,11 @@ rundown:
`tensorflow/contrib/makefile/gen/lib/libtensorflow-core.a`. You'll need to add
this to your linking build stage, and in Search Paths add
`tensorflow/contrib/makefile/gen/lib` to the Library Search Paths setting.
-
+
- You'll also need to add `libprotobuf.a` and `libprotobuf-lite.a` from
`tensorflow/contrib/makefile/gen/protobuf_ios/lib` to your _Build Stages_ and
_Library Search Paths_.
-
+
- The _Header Search_ paths needs to contain:
- the root folder of tensorflow,
- `tensorflow/contrib/makefile/downloads/protobuf/src`
@@ -83,10 +83,10 @@ rundown:
- You'll need to include the Accelerate framework in the "Link Binary with
Libraries" build phase of your project.
-
+
- C++11 support (or later) should be enabled by setting `C++ Language Dialect` to
`GNU++11` (or `GNU++14`), and `C++ Standard Library` to `libc++`.
-
+
- The library doesn't currently support bitcode, so you'll need to disable that
in your project settings.
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD
index df1f0ac133..015b4afb4a 100644
--- a/tensorflow/contrib/layers/BUILD
+++ b/tensorflow/contrib/layers/BUILD
@@ -264,7 +264,7 @@ py_test(
py_test(
name = "feature_column_ops_test",
- size = "small",
+ size = "medium",
srcs = ["python/layers/feature_column_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index ffa2e17aec..58054ff96a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -407,9 +407,9 @@ class EstimatorTest(test.TestCase):
right_labels = lambda: np.ones(shape=[7, 10], dtype=np.int32)
est.fit(right_features(), right_labels(), steps=1)
# TODO(wicke): This does not fail for np.int32 because of data_feeder magic.
- wrong_type_features = np.ones(shape=[7., 8.], dtype=np.int64)
+ wrong_type_features = np.ones(shape=[7, 8], dtype=np.int64)
wrong_size_features = np.ones(shape=[7, 10])
- wrong_type_labels = np.ones(shape=[7., 10.], dtype=np.float32)
+ wrong_type_labels = np.ones(shape=[7, 10], dtype=np.float32)
wrong_size_labels = np.ones(shape=[7, 11])
est.fit(x=right_features(), y=right_labels(), steps=1)
with self.assertRaises(ValueError):
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 1533b60854..63bceae3ab 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -106,7 +106,7 @@ class RNNCellTest(test.TestCase):
[2., 2., 2., 2.],
[3., 3., 3., 3.]]),
m.name:
- 0.1 * np.ones((batch_size, state_size * (num_shifts)))
+ 0.1 * np.ones((batch_size, int(state_size * (num_shifts))))
})
self.assertEqual(len(res), 2)
# The numbers in results were not calculated, this is mostly just a
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 43e44e7a96..d71dfdab9c 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -99,11 +99,12 @@ class Tensor {
/// for details.
explicit Tensor(DataType type);
- Tensor(const Tensor& other); /// Copy constructor.
+ /// Copy constructor.
+ Tensor(const Tensor& other);
- // Move constructor. After this call, <other> is safely destructible and can
- // be assigned to, but other calls on it (e.g. shape manipulation) are not
- // valid.
+ /// \brief Move constructor. After this call, <other> is safely destructible and can
+ /// be assigned to, but other calls on it (e.g. shape manipulation) are not
+ /// valid.
Tensor(Tensor&& other);
~Tensor();
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 026d4aa7b7..b5a2d329d1 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -798,11 +798,21 @@ tf_cc_test(
tf_cc_test(
name = "xsmm_conv2d_test",
size = "small",
- srcs = ["xsmm_conv2d_test.cc"],
+ srcs = select({
+ ":xsmm": ["xsmm_conv2d_test.cc"],
+ "//conditions:default": [],
+ }),
deps = [
":conv_ops",
+ ":ops_testutil",
+ ":ops_util",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
],
)
@@ -2117,8 +2127,20 @@ tf_kernel_library(
tf_kernel_library(
name = "matmul_op",
+ defines = select({
+ ":xsmm": [
+ "TENSORFLOW_USE_LIBXSMM",
+ "EIGEN_USE_LIBXSMM",
+ ],
+ "//conditions:default": [],
+ }),
prefix = "matmul_op",
- deps = MATH_DEPS,
+ deps = MATH_DEPS + select({
+ ":xsmm": [
+ "@libxsmm_archive//:xsmm_avx",
+ ],
+ "//conditions:default": [],
+ }),
)
tf_kernel_library(
@@ -2367,7 +2389,10 @@ tf_kernel_library(
"//conditions:default": [],
}),
defines = select({
- ":xsmm": ["TENSORFLOW_USE_LIBXSMM"],
+ ":xsmm": [
+ "TENSORFLOW_USE_LIBXSMM",
+ "EIGEN_USE_LIBXSMM",
+ ],
"//conditions:default": [],
}) + select({
":xsmm_backward": ["TENSORFLOW_USE_LIBXSMM_BACKWARD"],
@@ -2387,7 +2412,6 @@ tf_kernel_library(
"//tensorflow/core:nn_ops_op_lib",
] + select({
":xsmm": [
- "@libxsmm_archive//:libxsmm_headers",
"@libxsmm_archive//:xsmm_avx",
],
"//conditions:default": [],
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 821bb9fe71..139fb605df 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -162,6 +162,8 @@ struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
desc.S = filter_cols;
desc.u = row_stride;
desc.v = col_stride;
+ desc.pad_h = 0;
+ desc.pad_w = 0;
desc.pad_h_in = 0; // pad_rows; // ignored by libxsmm for now.
desc.pad_w_in = 0; // pad_cols; // ignored by libxsmm for now.
desc.pad_h_out = 0;
@@ -169,7 +171,7 @@ struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
desc.threads = num_threads;
desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
desc.buffer_format = LIBXSMM_DNN_CONV_FORMAT_NHWC;
- desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_RSCK;
+ desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_LIBXSMM;//LIBXSMM_DNN_CONV_FORMAT_RSCK;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 91cd1c4b9a..22e48e84d8 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -196,19 +196,25 @@ class LaunchXsmmConvOp<CPUDevice, float> {
desc.S = filter_cols;
desc.u = stride_rows;
desc.v = stride_cols;
- desc.pad_h_in = pad_rows; // ignored by libxsmm for now.
- desc.pad_w_in = pad_cols; // ignored by libxsmm for now.
+ desc.pad_h = pad_rows;
+ desc.pad_w = pad_cols;
+ desc.pad_h_in = pad_rows; // libxsmm supports only physical padding for now
+ desc.pad_w_in = pad_cols; // libxsmm supports only physical padding for now
desc.pad_h_out = 0;
desc.pad_w_out = 0;
desc.threads = num_threads;
desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
desc.buffer_format = LIBXSMM_DNN_CONV_FORMAT_NHWC;
- desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_RSCK;
+ desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_LIBXSMM;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+ if (!CanUseXsmmConv2D(desc, data_format)) {
+ return false;
+ }
+
auto input_ptr = input.template flat<float>().data();
auto filter_ptr = filter.template flat<float>().data();
auto output_ptr = output->template flat<float>().data();
diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc
index 287e8901db..bb3fe2828e 100644
--- a/tensorflow/core/kernels/linalg_ops_common.cc
+++ b/tensorflow/core/kernels/linalg_ops_common.cc
@@ -202,7 +202,7 @@ void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
const TensorShapes& output_matrix_shapes) {
ConstMatrixMaps matrix_inputs;
- for (int i = 0; i < inputs.size(); ++i) {
+ for (size_t i = 0; i < inputs.size(); ++i) {
// TODO(kalakris): Handle alignment if possible. Eigen::Map is
// unaligned by default.
matrix_inputs.push_back(
@@ -213,7 +213,7 @@ void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
}
MatrixMaps matrix_outputs;
- for (int i = 0; i < output_matrix_shapes.size(); ++i) {
+ for (size_t i = 0; i < output_matrix_shapes.size(); ++i) {
// The output matrix shape may not be a matrix.
int num_output_rows = output_matrix_shapes[i].dims() >= 1
? output_matrix_shapes[i].dim_size(0)
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
index 8a94508589..6a3f3dfc77 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -1412,7 +1412,7 @@ class PinnedToCurrentCPU {
int ret = 0;
ret = sched_getaffinity(0, sizeof(cpu_set_t), &old_cpu_set);
if (ret != 0) {
- PLOG(WARNING) << "sched_getaffinity";
+ VLOG(WARNING) << "sched_getaffinity";
return;
}
valid = true;
@@ -1421,7 +1421,7 @@ class PinnedToCurrentCPU {
CPU_SET(sched_getcpu(), &new_cpu_set);
ret = sched_setaffinity(0, sizeof(cpu_set_t), &new_cpu_set);
if (ret != 0) {
- PLOG(WARNING) << "sched_setaffinity";
+ VLOG(WARNING) << "sched_setaffinity";
}
}
~PinnedToCurrentCPU() {
diff --git a/tensorflow/core/kernels/sparse_tensors_map_ops.cc b/tensorflow/core/kernels/sparse_tensors_map_ops.cc
index 5673ab4ee5..8101d7ca84 100644
--- a/tensorflow/core/kernels/sparse_tensors_map_ops.cc
+++ b/tensorflow/core/kernels/sparse_tensors_map_ops.cc
@@ -343,7 +343,7 @@ class TakeManySparseFromTensorsMapOp : public SparseTensorAccessingOp {
: SparseTensorAccessingOp(context) {}
void Compute(OpKernelContext* context) override {
- SparseTensorsMap* map;
+ SparseTensorsMap* map = nullptr;
OP_REQUIRES_OK(context, GetMap(context, false /* is_writing */, &map));
const Tensor& sparse_handles = context->input(0);
diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc
index 34db850013..c18b992ea1 100644
--- a/tensorflow/core/kernels/stage_op.cc
+++ b/tensorflow/core/kernels/stage_op.cc
@@ -113,10 +113,10 @@ class UnstageOp : public OpKernel {
Buffer::Tuple tuple;
buf->Get(&tuple);
OP_REQUIRES(
- ctx, tuple.size() == ctx->num_outputs(),
+ ctx, tuple.size() == (size_t)ctx->num_outputs(),
errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
" vs. ", ctx->num_outputs()));
- for (int i = 0; i < tuple.size(); ++i) {
+ for (size_t i = 0; i < tuple.size(); ++i) {
ctx->set_output(i, tuple[i]);
}
}
diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc
index b1ebdf8410..0301ad49e7 100644
--- a/tensorflow/core/kernels/xsmm_conv2d.cc
+++ b/tensorflow/core/kernels/xsmm_conv2d.cc
@@ -32,11 +32,46 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(void);
#include "tensorflow/core/lib/core/threadpool.h"
#include "include/libxsmm_cpuid.h"
+#include "libxsmm_dnn_handle.h"
namespace tensorflow {
// Xsmm*Conv2D are wrappers for libxsmm direct convolutions.
+// Returns true if convolution can be computed efficiently by XsmmConv2D,
+// returns false otherwise.
+bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
+ TensorFormat data_format) {
+ int VECTOR_SIZE;
+ int arch = libxsmm_cpuid_x86();
+
+ if (arch == LIBXSMM_X86_AVX512_CORE) {
+ VECTOR_SIZE = 16;
+ } else if (arch == LIBXSMM_X86_AVX2) {
+ VECTOR_SIZE = 8;
+ } else {
+ VLOG(1) << "Cannot use XSMM convolutions: unsupported architecture!";
+ return false;
+ }
+
+ if (data_format != FORMAT_NHWC) {
+ VLOG(1) << "Cannot use XSMM convolutions: unsupported format!";
+ return false;
+ }
+ if (desc.pad_h_in != 0 || desc.pad_w_in != 0) {
+ VLOG(1) << "Cannot use XSMM convolutions: unsupported padding!";
+ return false;
+ }
+ if (desc.K % VECTOR_SIZE != 0) {
+ VLOG(1) << "Cannot use XSMM convolutions: output features count not"
+ " divisible by vector size!";
+ return false;
+ }
+ VLOG(2) << "Can use XSMM convolutions.";
+ return true;
+}
+
+
typedef Eigen::ThreadPoolDevice CPUDevice;
namespace functor {
@@ -47,29 +82,187 @@ static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) {
}
}
+LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float *kcrs, int R, int S, int C, int K,int blocksifm, int blocksofm, int ifmblock,int ofmblock, int start, int end)
+{
+ LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C,K);
+ LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm,R,S,ifmblock, ofmblock);
+ int r, s, k,c, v1,v2;
+
+ for (k = start; k < end ; k++ ) {
+ for(c = 0; c < blocksifm;c++){
+ for ( r = 0; r < R; r++ ) {
+ for ( s = 0; s < S; s++ ){
+ for ( v1 = c*ifmblock; v1 < std::min(C,(c+1)*ifmblock) ; v1++ ) {
+ for ( v2 = k*ofmblock; v2 < std::min(K, (k+1)*ofmblock); v2++ )
+ LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
+ for ( v2 = K; v2 < (k+1)*ofmblock ; v2++ )
+ LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = 0.0f;
+ }
+ for ( v1 = C; v1 < (c+1)*ifmblock ; v1++ ) {
+ for ( v2 = k*ofmblock; v2 < (k+1)*ofmblock; v2++ )
+ LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = 0.0f;
+ }
+ }
+ }
+ }
+ }
+}
+
+
+
+class libxsmm_dnn_conv_desc_wrap{
+ public:
+ const libxsmm_dnn_conv_desc d;
+
+ libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc &d_) : d(d_){
+ }
+ bool operator==(const libxsmm_dnn_conv_desc_wrap &w) const{
+ return( d.N == w.d.N &&
+ d.C == w.d.C &&
+ d.H == w.d.H &&
+ d.W == w.d.W &&
+ d.K == w.d.K &&
+ d.R == w.d.R &&
+ d.S == w.d.S &&
+ d.u == w.d.u &&
+ d.v == w.d.v &&
+ d.pad_h_in == w.d.pad_h_in &&
+ d.pad_w_in == w.d.pad_w_in
+ );
+ }
+};
+
+
+struct HashFunction{
+ std::size_t operator()(const libxsmm_dnn_conv_desc_wrap & w) const{
+ std::ostringstream N,C,H,W,K,R,S,u,v,padh,padw;
+
+ N << w.d.N; C << w.d.C;
+ H << w.d.H; W << w.d.W;
+ K << w.d.K; R << w.d.R;
+ S << w.d.S; u << w.d.u;
+ v << w.d.v; padh << w.d.pad_h_in;
+ padw << w.d.pad_w_in;
+
+
+ std::string out_ = N.str() + C.str()\
+ + H.str() + W.str()\
+ + K.str() + R.str()\
+ + S.str() + u.str()\
+ + v.str() + padh.str()\
+ + padw.str();
+
+ return ( std::hash<std::string>()(out_));
+ }
+};
+
+class handles{
+ public:
+ libxsmm_dnn_conv_handle* find( const libxsmm_dnn_conv_desc_wrap &w) {
+ std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction>::iterator i = libxsmm_handles.find(w);
+ if (i == libxsmm_handles.end()){
+ libxsmm_dnn_err_t status;
+ libxsmm_dnn_conv_handle* libxsmm_handle = libxsmm_dnn_create_conv_handle_check(w.d, &status);
+ chk_libxsmm_err(status, "Create handle");
+ libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
+ return libxsmm_handle;
+ }
+ else
+ return i->second;
+ }
+ ~handles(){
+ std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction>::iterator i;
+ for (i= libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
+ chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(i->second),
+ "Destroy handle");
+ }
+ private:
+
+ std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_conv_handle*, HashFunction> libxsmm_handles;
+
+};
+
+static handles libxsmm_handles;
+
template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
const libxsmm_dnn_conv_desc& desc,
libxsmm_dnn_conv_kind kind, InputPtr input,
FilterPtr filter, OutputPtr output) {
libxsmm_dnn_err_t status;
-
libxsmm_dnn_conv_handle* libxsmm_handle;
- libxsmm_handle = libxsmm_dnn_create_conv_handle_check(desc, &status);
- chk_libxsmm_err(status, "Create handle");
-
+ libxsmm_dnn_conv_desc_wrap w(desc);
+
+ if(kind == LIBXSMM_DNN_CONV_KIND_FWD)
+ libxsmm_handle = libxsmm_handles.find(w);
+ else{
+ libxsmm_handle = libxsmm_dnn_create_conv_handle_check(desc, &status);
+ chk_libxsmm_err(status, "Create handle");
+ }
+
status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
if (status == LIBXSMM_DNN_WARN_FALLBACK) {
chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle),
"Destroy handle");
return false; // Use non-libxsmm code
}
- // libxsmm_dnn_get_codegen_success can return real errors as well
chk_libxsmm_err(status, "Check codegen status");
libxsmm_dnn_buffer* libxsmm_input;
libxsmm_dnn_buffer* libxsmm_output;
libxsmm_dnn_filter* libxsmm_filter;
+
+ /*
+ const DeviceBase::CpuWorkerThreads* worker_threads =
+ ctx->device()->tensorflow_cpu_worker_threads();
+
+ int num_threads = worker_threads->num_threads;
+*/
+
+ int ifmblock = (libxsmm_handle->ifmblock);
+ int ofmblock = (libxsmm_handle->ofmblock);
+
+ int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1;
+ int blocksofm = desc.K%ofmblock ==0 ? desc.K/ofmblock :desc.K/ofmblock + 1;
+ float *native_filter = (float*)libxsmm_aligned_malloc( blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float), 2097152);
+
+
+
+ const DeviceBase::CpuWorkerThreads* worker_threads =
+ ctx->device()->tensorflow_cpu_worker_threads();
+
+ int num_threads = worker_threads->num_threads;
+
+
+ if(blocksofm > num_threads){
+ int work = blocksofm;
+ BlockingCounter count(num_threads);
+ for (int i = 0; i < num_threads; ++i) {
+ worker_threads->workers->Schedule([=, &count]() {
+ int start = work/num_threads*i;
+ int end = (start + work/num_threads) > work ? work: start + work/num_threads;
+ copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S,desc.C, desc.K,blocksifm,blocksofm,ifmblock,ofmblock,start, end);
+ count.DecrementCount();
+ });
+ }
+ count.Wait();
+ }
+ else{
+
+ int work = blocksofm;
+ int num_threads = work;
+
+ BlockingCounter count(num_threads);
+ for (int i = 0; i < num_threads; ++i) {
+ worker_threads->workers->Schedule([=, &count]() {
+ int start = i;
+ int end = i+1;
+ copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S,desc.C, desc.K,blocksifm,blocksofm,ifmblock,ofmblock, start, end);
+ count.DecrementCount();
+ });
+ }
+ count.Wait();
+ }
libxsmm_input = libxsmm_dnn_link_input_buffer_check(
libxsmm_handle, input, LIBXSMM_DNN_CONV_FORMAT_NHWC_PTR, &status);
@@ -78,7 +271,7 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
libxsmm_handle, output, LIBXSMM_DNN_CONV_FORMAT_NHWC_PTR, &status);
chk_libxsmm_err(status, "Link output buffer");
libxsmm_filter = libxsmm_dnn_link_filter_check(
- libxsmm_handle, filter, LIBXSMM_DNN_CONV_FORMAT_RSCK_PTR, &status);
+ libxsmm_handle, native_filter, LIBXSMM_DNN_CONV_FORMAT_LIBXSMM_PTR, &status);
chk_libxsmm_err(status, "Link filter");
chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output");
@@ -95,25 +288,26 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
libxsmm_dnn_transpose_filter(libxsmm_handle);
}
- // TODO(maciejd) We would prefer raw threads instead of threadpool.
- auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
- int num_threads = worker_threads.num_threads;
BlockingCounter counter(num_threads);
+
+
+
for (int i = 0; i < num_threads; ++i) {
- worker_threads.workers->Schedule([=, &counter]() {
+ worker_threads->workers->Schedule([=, &counter]() {
chk_libxsmm_err(libxsmm_dnn_convolve_st(libxsmm_handle, kind, 0, i),
"Worker");
counter.DecrementCount();
});
}
counter.Wait();
-
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_input), "Destroy input");
chk_libxsmm_err(libxsmm_dnn_destroy_buffer(libxsmm_output), "Destroy output");
chk_libxsmm_err(libxsmm_dnn_destroy_filter(libxsmm_filter), "Destroy filter");
- chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle),
+
+ if(kind != LIBXSMM_DNN_CONV_KIND_FWD)
+ chk_libxsmm_err(libxsmm_dnn_destroy_conv_handle(libxsmm_handle),
"Destroy handle");
-
+ libxsmm_free(native_filter);
return true; // Succeeded
}
diff --git a/tensorflow/core/kernels/xsmm_conv2d.h b/tensorflow/core/kernels/xsmm_conv2d.h
index acc50213b0..b439511dc7 100644
--- a/tensorflow/core/kernels/xsmm_conv2d.h
+++ b/tensorflow/core/kernels/xsmm_conv2d.h
@@ -28,6 +28,11 @@ class OpKernelContext;
// XsmmConv2D is a wrapper for libxsmm direct convolutions.
+// Returns true if convolution operation specified by function arguments
+// can use XsmmConv2D implementation, and false otherwise.
+bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
+ TensorFormat data_format);
+
namespace functor {
template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/xsmm_conv2d_test.cc b/tensorflow/core/kernels/xsmm_conv2d_test.cc
index d81368314c..f4ab6896ae 100644
--- a/tensorflow/core/kernels/xsmm_conv2d_test.cc
+++ b/tensorflow/core/kernels/xsmm_conv2d_test.cc
@@ -15,13 +15,339 @@ limitations under the License.
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "include/libxsmm.h"
+#include "tensorflow/core/framework/fake_input.h"
namespace tensorflow {
namespace {
+
+typedef struct {
+ int nImg;
+ int nIfm;
+ int nOfm;
+ int ifhp;
+ int ifwp;
+ int ifh;
+ int ifw;
+ int ofhp;
+ int ofwp;
+ int ofh;
+ int ofw;
+ int pad_h;
+ int pad_w;
+ int pad_h_in;
+ int pad_w_in;
+ int pad_h_out;
+ int pad_w_out;
+ int kh;
+ int kw;
+ int stride_h;
+ int stride_w;
+} naive_conv_t;
+
+
+LIBXSMM_INLINE void naive_copy_NCHW_to_NHWC(const float* nchw, Tensor &nhwc, int N, int H, int W, int C)
+{
+ LIBXSMM_VLA_DECL(4, const float, input, nchw, C, H, W);
+ int n, h, w, c;
+ auto output = nhwc.flat<float>();
+ for ( n = 0; n < N; n++ ) {
+ for ( h = 0; h < H; h++ ) {
+ for ( w = 0; w < W; w++ ) {
+ for ( c = 0; c < C; c++ ) {
+ output(n*H*W*C + h*W*C +w*C + c) =
+ LIBXSMM_VLA_ACCESS(4, input, n, c, h, w, C, H, W);
+ }
+ }
+ }
+ }
+}
+
+
+LIBXSMM_INLINE void naive_copy_KCRS_to_RSCK(const float* kcrs, Tensor &rsck, int R, int S, int C, int K)
+{
+ LIBXSMM_VLA_DECL(4, const float, input, kcrs, C, R, S);
+ int r, s, c, k;
+ auto output = rsck.flat<float>();
+
+ for ( r = 0; r < R; r++ ) {
+ for ( s = 0; s < S; s++ ) {
+ for ( c = 0; c < C; c++ ) {
+ for ( k = 0; k < K; k++ ) {
+ output(r*S*C*K + s*C*K + c*K + k) =
+ LIBXSMM_VLA_ACCESS(4, input, k, c, r, s, C, R, S);
+ }
+ }
+ }
+ }
+}
+
+
+
+LIBXSMM_INLINE void zero_buf(float* buf, long size) {
+ int i;
+ for (i = 0; i < size; ++i) {
+ buf[i] = 0.0f;
+ }
+}
+
+LIBXSMM_INLINE void copy_buf(Tensor &dst,float *src,long size) {
+ long i;
+ auto output = dst.flat<float>();
+ for (i = 0; i < size; ++i)
+ output(i) = src[i];
+}
+
+LIBXSMM_INLINE void init_buf(float* buf, long size, int initPos, int initOne)
+{
+ int i;
+ zero_buf(buf, size);
+ for (i = 0; i < size; ++i) {
+ buf[i] = (float)((initOne != 0) ? 1.0 : ((initPos != 0) ? drand48() : (0.05 - drand48()/10.0)));
+ }
+}
+
+
+
+LIBXSMM_INLINE void naive_conv_fp(naive_conv_t* param, const float* input, float* output, const float* filter)
+{
+ int nImg = param->nImg;
+ int nIfm = param->nIfm;
+ int nOfm = param->nOfm;
+ int ifhp = param->ifhp;
+ int ifwp = param->ifwp;
+ int ofhp = param->ofhp;
+ int ofwp = param->ofwp;
+ int ifh = param->ifh;
+ int ifw = param->ifw;
+ int ofh = param->ofh;
+ int ofw = param->ofw;
+ int pad_h = param->pad_h;
+ int pad_w = param->pad_w;
+ int pad_h_in = param->pad_h_in;
+ int pad_w_in = param->pad_w_in;
+ int pad_h_out = param->pad_h_out;
+ int pad_w_out = param->pad_w_out;
+ int kh = param->kh;
+ int kw = param->kw;
+ int stride_h = param->stride_h;
+ int stride_w = param->stride_w;
+ /* loop counters */
+ int img, ofm, ifm, oj, oi, ij, ii, kj, ki;
+
+ LIBXSMM_VLA_DECL(4, float, output_t, output + (pad_w_out * ofwp + pad_h_out), nOfm, ofhp, ofwp);
+ LIBXSMM_VLA_DECL(4, const float, input_t, input + (pad_w_in * ifwp + pad_h_in), nIfm, ifhp, ifwp);
+ LIBXSMM_VLA_DECL(4, const float, filter_t, filter, nIfm, kh, kw);
+
+ for (img = 0; img < nImg; ++img) {
+ for (ofm = 0; ofm < nOfm; ++ofm) {
+ for (ifm = 0; ifm < nIfm; ++ifm) {
+ for (oj = 0; oj < ofh; ++oj) {
+ ij = oj * stride_h - pad_h;
+ for (oi = 0; oi < ofw; ++oi) {
+ ii = oi * stride_w - pad_w;
+ for (kj = 0; kj < kh; ++kj) {
+ if(ij+kj < 0 || ij+kj >= ifh) continue;
+ for (ki = 0; ki < kw; ++ki) {
+ if(ii+ki < 0 || ii+ki >= ifw) continue;
+ LIBXSMM_VLA_ACCESS( 4, output_t, img, ofm, oj, oi, nOfm, ofhp, ofwp) +=
+ LIBXSMM_VLA_ACCESS(4, input_t, img, ifm, ij + kj, ii + ki, nIfm, ifhp, ifwp)
+ * LIBXSMM_VLA_ACCESS(4, filter_t, ofm, ifm, kj, ki, nIfm, kh, kw);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
void RunXsmmVsGeneric() {}
-TEST(XsmmConv2DTest, Basic) {}
+class XsmmConv2DTest : public OpsTestBase {
+ protected:
+ void MakeOp(int stride) {
+
+ TF_CHECK_OK(NodeDefBuilder("xsmm", "Conv2D")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("strides", {1, stride,stride, 1})
+ .Attr("padding", "VALID" )
+ .Finalize(node_def()));
+
+
+ TF_ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(XsmmConv2DTest, Basic) {
+ MakeOp(1);
+
+
+ int ifw = 14; /* input width, "W" */
+ int ifh = 14; /* input height, "H" */
+ int nImg = 32; /* mini-batch size, "N" */
+ int nIfm = 64; /* number of input feature maps, "C" */
+ int nOfm = 64; /* number of output feature maps, "K" */
+ int kh = 3; /* filter height, "R" */
+ int kw = 3; /* filter width, "S" */
+ int pad = 0; /* padding in output */
+ int stride = 1; /* stride when accessing inputs */
+
+
+ int stride_w = stride;
+ int stride_h = stride;
+ int pad_h = pad;
+ int pad_w = pad;
+
+ int pad_h_in = pad_h;
+ int pad_w_in = pad_w;
+
+ int pad_h_out = 0;
+ int pad_w_out = 0;
+
+ /* deriving some values for naive code */
+ int ofh = (ifh + 2 * pad_h - kh) / stride_h + 1;
+ int ofw = (ifw + 2 * pad_w - kw) / stride_w + 1;
+ int ifhp = ifh + 2 * pad_h_in;
+ int ifwp = ifw + 2 * pad_w_in;
+ int ofhp = ofh + 2 * pad_h_out;
+ int ofwp = ofw + 2 * pad_w_out;
+
+
+ //Initialization of Filter and Image
+
+ /* allocate data */
+ float *naive_input = (float*)libxsmm_aligned_malloc( nImg*nIfm*ifhp*ifwp*sizeof(float), 2097152);
+ float *naive_output = (float*)libxsmm_aligned_malloc( nImg*nOfm*ofhp*ofwp*sizeof(float), 2097152);
+ float *naive_filter = (float*)libxsmm_aligned_malloc( nOfm*nIfm*kh*kw* sizeof(float), 2097152);
+ /* initialize data */
+ init_buf(naive_input, nImg*nIfm*ifhp*ifwp, 0, 0);
+ zero_buf(naive_output, nImg*nOfm*ofhp*ofwp);
+ init_buf(naive_filter, nOfm*nIfm*kh*kw, 0, 0);
+
+
+ Tensor image(DT_FLOAT,
+ {nImg, ifhp, ifwp, nIfm});
+
+
+ Tensor filter(DT_FLOAT, {kh,kw,nIfm,nOfm});
+
+
+ naive_copy_NCHW_to_NHWC(naive_input, image, nImg, ifhp, ifwp, nIfm);
+ naive_copy_KCRS_to_RSCK(naive_filter, filter, kh, kw, nIfm, nOfm);
+
+
+ //Run naive convolution
+
+ naive_conv_t naive_param;
+
+ naive_param.nImg = nImg;
+ naive_param.nIfm = nIfm;
+ naive_param.nOfm = nOfm;
+ naive_param.ifhp = ifhp;
+ naive_param.ifwp = ifwp;
+ naive_param.ofhp = ofhp;
+ naive_param.ofwp = ofwp;
+ naive_param.ifh = ifh;
+ naive_param.ifw = ifw;
+ naive_param.ofh = ofh;
+ naive_param.ofw = ofw;
+ naive_param.pad_h = pad_h;
+ naive_param.pad_w = pad_w;
+ naive_param.pad_h_in = pad_h_in;
+ naive_param.pad_w_in = pad_w_in;
+ naive_param.pad_h_out = pad_h_out;
+ naive_param.pad_w_out = pad_w_out;
+ naive_param.kh = kh;
+ naive_param.kw = kw;
+ naive_param.stride_h = stride_h;
+ naive_param.stride_w = stride_w;
+
+
+ naive_conv_fp(&naive_param, naive_input, naive_output, naive_filter);
+
+
+
+ AddInputFromArray<float>(image.shape(), image.flat<float>());
+ AddInputFromArray<float>(filter.shape(), filter.flat<float>());
+
+
+
+ //Run Op (TF)
+ TF_ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(DT_FLOAT, {nImg,ofhp,ofwp, nOfm});
+ naive_copy_NCHW_to_NHWC(naive_output, expected, nImg, ofhp, ofwp, nOfm);
+
+
+ test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+ libxsmm_free(naive_input);
+ libxsmm_free(naive_output);
+ libxsmm_free(naive_filter);
+
+
+
+}
+
+/*
+
+
+TEST(XsmmConv2DTest, Basic) {
+
+ auto num_threads =
+ ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
+ // See libxsmm_dnn.h for this struct definition.
+ libxsmm_dnn_conv_desc desc;
+ desc.N = batch;
+ desc.C = in_depth;
+ desc.H = input_rows;
+ desc.W = input_cols;
+ desc.K = out_depth;
+ desc.R = filter_rows;
+ desc.S = filter_cols;
+ desc.u = stride_rows;
+ desc.v = stride_cols;
+ desc.pad_h = pad_rows;
+ desc.pad_w = pad_cols;
+ desc.pad_h_in = pad_rows; // libxsmm supports only physical padding for now
+ desc.pad_w_in = pad_cols; // libxsmm supports only physical padding for now
+ desc.pad_h_out = 0;
+ desc.pad_w_out = 0;
+ desc.threads = num_threads;
+ desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
+ desc.buffer_format = LIBXSMM_DNN_CONV_FORMAT_NHWC;
+ desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_LIBXSMM;//LIBXSMM_DNN_CONV_FORMAT_RSCK;
+ desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
+ desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
+ desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
+ desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+
+ if (!CanUseXsmmConv2D(desc, data_format)) {
+ return false;
+ }
+
+ auto input_ptr = input.template flat<float>().data();
+ auto filter_ptr = filter.template flat<float>().data();
+ auto output_ptr = output->template flat<float>().data();
+
+ bool success = functor::XsmmFwdConv2D<CPUDevice, float>()(
+ ctx, desc, input_ptr, filter_ptr, output_ptr);
+ return success;
+
+
+
+
+
+
+
+}
+*/
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/proto_text_util.h b/tensorflow/core/lib/strings/proto_text_util.h
index 5f7ee00f8c..3d0c6e4a37 100644
--- a/tensorflow/core/lib/strings/proto_text_util.h
+++ b/tensorflow/core/lib/strings/proto_text_util.h
@@ -164,7 +164,7 @@ bool ProtoParseNumericFromScanner(Scanner* scanner, T* value) {
// Special case to disallow multiple leading zeroes, to match proto parsing.
int leading_zero = 0;
- for (int i = 0; i < numeric_str.size(); ++i) {
+ for (size_t i = 0; i < numeric_str.size(); ++i) {
const char ch = numeric_str[i];
if (ch == '0') {
if (++leading_zero > 1) return false;
diff --git a/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py b/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py
index bf3f2fb015..eaff05913a 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py
@@ -80,8 +80,8 @@ def main(_):
options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
run_metadata=run_metadata)
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
- trace_file = open('timeline.ctf.json', 'w')
- trace_file.write(trace.generate_chrome_trace_format())
+ with open('timeline.ctf.json', 'w') as trace_file:
+ trace_file.write(trace.generate_chrome_trace_format())
else:
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index bc07c44218..c4c9bdfcaf 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -81,7 +81,7 @@ If the above commands do not work on your system, you can follow these instructi
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp27-none-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
-# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl
# Mac OS X, CPU only, Python 2.7:
@@ -94,14 +94,14 @@ $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/gpu/tensorf
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
-# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp35-cp35m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
-# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl
# Mac OS X, CPU only, Python 3.4 or 3.5:
@@ -215,7 +215,7 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp27-none-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
-# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl
# Mac OS X, CPU only, Python 2.7:
@@ -228,14 +228,14 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
-# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp35-cp35m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
-# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl
# Mac OS X, CPU only, Python 3.4 or 3.5:
@@ -367,7 +367,7 @@ select the correct binary to install:
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp27-none-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
-# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp27-none-linux_x86_64.whl
# Mac OS X, CPU only, Python 2.7:
@@ -380,14 +380,14 @@ select the correct binary to install:
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
-# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp34-cp34m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.12.1-cp35-cp35m-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
-# Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see "Installing from sources" below.
+# Requires CUDA toolkit 8.0 and CuDNN v5.1. For other versions, see "Installing from sources" below.
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-0.12.1-cp35-cp35m-linux_x86_64.whl
# Mac OS X, CPU only, Python 3.4 or 3.5:
@@ -635,7 +635,7 @@ toolkit is installed in `/usr/local/cuda`, run the following commands (edited
to reflect the cuDNN version you downloaded):
``` bash
-tar xvzf cudnn-8.0-linux-x64-v5.1-ga.tgz
+tar xvzf cudnn-8.0-linux-x64-v5.1.tgz
sudo cp -P cuda/include/cudnn.h /usr/local/cuda/include/
sudo cp -P cuda/lib64/libcudnn* /usr/local/cuda/lib64/
sudo chmod a+r /usr/local/cuda/include/cudnn.h /usr/local/cuda/lib64/libcudnn*
diff --git a/tensorflow/g3doc/tutorials/deep_cnn/index.md b/tensorflow/g3doc/tutorials/deep_cnn/index.md
index ec9d726b3a..8c3eeb40cb 100644
--- a/tensorflow/g3doc/tutorials/deep_cnn/index.md
+++ b/tensorflow/g3doc/tutorials/deep_cnn/index.md
@@ -136,7 +136,7 @@ artificially increase the data set size:
Please see the [Images](../../api_docs/python/image.md) page for the list of
available distortions. We also attach an
-[`image_summary`](../../api_docs/python/train.md#image_summary) to the images
+[`image`](../../api_docs/python/summary.md#image) to the images
so that we may visualize them in [TensorBoard](../../how_tos/summaries_and_tensorboard/index.md).
This is a good practice to verify that inputs are built correctly.
@@ -203,7 +203,7 @@ For regularization, we also apply the usual
variables. The objective function for the model is the sum of the cross entropy
loss and all these weight decay terms, as returned by the `loss()` function.
-We visualize it in TensorBoard with a [`scalar_summary`](../../api_docs/python/train.md#scalar_summary):
+We visualize it in TensorBoard with a [`scalar`](../../api_docs/python/summary.md#scalar):
![CIFAR-10 Loss](../../images/cifar_loss.png "CIFAR-10 Total Loss")
@@ -289,7 +289,7 @@ how the model is training. We want more insight into the model during training:
[TensorBoard](../../how_tos/summaries_and_tensorboard/index.md) provides this
functionality, displaying data exported periodically from `cifar10_train.py` via
a
-[`SummaryWriter`](../../api_docs/python/train.md#SummaryWriter).
+[`FileWriter`](../../api_docs/python/summary.md#FileWriter).
For instance, we can watch how the distribution of activations and degree of
sparsity in `local3` features evolve during training:
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 20474d483d..b6f8e84e41 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -212,7 +212,8 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
A list of gradients to use, without None.
Raises:
- ValueError: If one of the grad_ys is invalid.
+ ValueError: If sizes of gradients and inputs don't match
+ TypeError: If type of any gradient is not valid for its input.
"""
if len(grad_ys) != len(ys):
raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
@@ -225,12 +226,24 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
grad_ys[i] = array_ops.fill(
array_ops.shape(y), constant_op.constant(
1, dtype=y.dtype))
+ continue
+ if y.dtype.is_floating or y.dtype.is_integer:
+ if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
+ raise TypeError("Gradient type %s generated for real or "
+ "integer-valued tensor %s with type %s must be "
+ "real or integer" %
+ (dtypes.as_dtype(grad_y.dtype).name, y,
+ dtypes.as_dtype(y.dtype).name))
+ elif y.dtype.is_complex:
+ if not grad_y.dtype.is_complex:
+ raise TypeError("Gradient type %s generated for complex-valued "
+ "tensor %s with type %s must be real" %
+ (dtypes.as_dtype(grad_y.dtype).name, y,
+ dtypes.as_dtype(y.dtype).name))
else:
- if grad_y.dtype != y.dtype:
- raise ValueError("Y and ys_grad must be of the same type, "
- "not y: %s, ys_grad: %s " %
- (dtypes.as_dtype(y.dtype).name,
- dtypes.as_dtype(grad_y.dtype).name))
+ raise TypeError("Tensor %s with type %s must be numeric "
+ "to obtain a default gradient" %
+ (y, dtypes.as_dtype(y.dtype).name))
return grad_ys
@@ -248,18 +261,32 @@ def _VerifyGeneratedGradients(grads, op):
op: Operation for which the gradients where generated.
Raises:
- ValueError: if the gradients are invalid.
+ ValueError: if sizes of gradients and inputs don't match.
+ TypeError: if type of any gradient is not valid for its input.
"""
if len(grads) != len(op.inputs):
raise ValueError("Num gradients %d generated for op %s do not match num "
"inputs %d" % (len(grads), op.node_def, len(op.inputs)))
- for i in xrange(len(grads)):
- grad = grads[i]
- inp = op.inputs[i]
- if grad is not None:
- if not grad.dtype.is_compatible_with(inp.dtype):
- raise ValueError("Gradient type %s generated for op %s does "
- "not match input type %s" %
+ for i in xrange(len(grads)):
+ grad = grads[i]
+ inp = op.inputs[i]
+ if grad is None:
+ continue
+ if grad.dtype.is_floating:
+ if not inp.dtype.is_floating:
+ raise TypeError("Gradient type %s generated for real-valued op %s "
+ "with type %s must be real" %
+ (dtypes.as_dtype(grad.dtype).name, op.node_def,
+ dtypes.as_dtype(inp.dtype).name))
+ elif grad.dtype.is_complex:
+ if not inp.dtype.is_complex:
+ raise TypeError("Gradient type %s generated for complex-valued op %s"
+ " with type %s must be complex" %
+ (dtypes.as_dtype(grad.dtype).name, op.node_def,
+ dtypes.as_dtype(inp.dtype).name))
+ else:
+ raise TypeError("Gradient type %s generated for op %s "
+ "with type %s must be either real or complex" %
(dtypes.as_dtype(grad.dtype).name, op.node_def,
dtypes.as_dtype(inp.dtype).name))
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 4ea23b431c..37787eca63 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -323,7 +323,7 @@ class DivAndModTest(test_util.TensorFlowTestCase):
a = variables.Variable(2.)
b = variables.Variable(4.)
with self.test_session() as sess:
- sess.run(variables.initialize_all_variables())
+ sess.run(variables.global_variables_initializer())
c_grad = gradients.gradients(math_ops.divide(a, b), [a, b])
self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125])
c_grad = gradients.gradients(math_ops.div(a, b), [a, b])
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 0fff2d3ba5..e2b6ac0864 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -1242,9 +1242,10 @@ def sparse_tensor_dense_matmul(sp_a,
GPU: NVidia Tesla k40c
Compiled with:
- -c opt --config=cuda --copt=-mavx
+ `-c opt --config=cuda --copt=-mavx`
- ```tensorflow/python/sparse_tensor_dense_matmul_op_test --benchmarks
+ ```
+ tensorflow/python/sparse_tensor_dense_matmul_op_test --benchmarks
A sparse [m, k] with % nonzero values between 1% and 80%
B dense [k, n]
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index 6e82c87f79..f33232b5fb 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -346,28 +346,48 @@ class CheckpointSaverHookTest(test.TestCase):
'end': 1
}, listener.get_counts())
- def test_save_secs_saves_periodically(self):
+ @test.mock.patch('time.time')
+ def test_save_secs_saves_periodically(self, mock_time):
+ # Let's have a realistic start time
+ current_time = 1484695987.209386
+
with self.graph.as_default():
+ mock_time.return_value = current_time
hook = basic_session_run_hooks.CheckpointSaverHook(
self.model_dir, save_secs=2, scaffold=self.scaffold)
hook.begin()
self.scaffold.finalize()
+
with session_lib.Session() as sess:
sess.run(self.scaffold.init_op)
mon_sess = monitored_session._HookedSession(sess, [hook])
+
+ mock_time.return_value = current_time
mon_sess.run(self.train_op) # Saved.
+
+ mock_time.return_value = current_time + 0.5
mon_sess.run(self.train_op) # Not saved.
+
self.assertEqual(1,
checkpoint_utils.load_variable(self.model_dir,
self.global_step.name))
- time.sleep(2.5)
+
+ # Simulate 2.5 seconds of sleep.
+ mock_time.return_value = current_time + 2.5
mon_sess.run(self.train_op) # Saved.
+
+ mock_time.return_value = current_time + 2.6
mon_sess.run(self.train_op) # Not saved.
+
+ mock_time.return_value = current_time + 2.7
mon_sess.run(self.train_op) # Not saved.
+
self.assertEqual(3,
checkpoint_utils.load_variable(self.model_dir,
self.global_step.name))
- time.sleep(2.5)
+
+ # Simulate 7.5 more seconds of sleep (10 seconds from start.
+ mock_time.return_value = current_time + 10
mon_sess.run(self.train_op) # Saved.
self.assertEqual(6,
checkpoint_utils.load_variable(self.model_dir,
diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h
index 4291a7a632..bbe02e5112 100644
--- a/tensorflow/stream_executor/kernel.h
+++ b/tensorflow/stream_executor/kernel.h
@@ -322,8 +322,8 @@ class KernelArgIterator {
}
private:
- int arg_index_;
- int number_of_arguments_;
+ size_t arg_index_;
+ size_t number_of_arguments_;
const void *const *arg_address_iter_;
const size_t *arg_size_iter_;
const size_t *shmem_bytes_iter_;
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index a9530034fe..07c9013377 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -89,7 +89,7 @@ WORKDIR /tensorflow
ENV CI_BUILD_PYTHON python
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
ENV TF_NEED_CUDA 1
-ENV TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2
+ENV TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0,6.1
RUN tensorflow/tools/ci_build/builds/configured GPU \
bazel build -c opt --config=cuda tensorflow/tools/pip_package:build_pip_package && \
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 6d8b1165bb..787f7da2b3 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -76,7 +76,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
native.new_http_archive(
name = "libxsmm_archive",
urls = [
- # "http://bazel-mirror.storage.googleapis.com/github.com/hfp/libxsmm/archive/1.6.1.tar.gz",
+ "http://bazel-mirror.storage.googleapis.com/github.com/hfp/libxsmm/archive/1.6.5.tar.gz",
"https://github.com/hfp/libxsmm/archive/1.6.5.tar.gz",
],
sha256 = "5231419a8e13e7a6d286cf25d32a3aa75c443a625e5ea57024d36468bc3d5936",
@@ -139,7 +139,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
name = "nasm",
urls = [
"http://bazel-mirror.storage.googleapis.com/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2",
- "http://www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2",
+ "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.12.02.tar.bz2/d15843c3fb7db39af80571ee27ec6fad/nasm-2.12.02.tar.bz2",
],
sha256 = "00b0891c678c065446ca59bcee64719d0096d54d6886e6e472aeee2e170ae324",
strip_prefix = "nasm-2.12.02",
diff --git a/third_party/libxsmm.BUILD b/third_party/libxsmm.BUILD
index a0aab0f5b7..a85a2013b6 100644
--- a/third_party/libxsmm.BUILD
+++ b/third_party/libxsmm.BUILD
@@ -60,8 +60,6 @@ cc_library(
"src/libxsmm_dump.c",
"src/libxsmm_malloc.c",
"src/libxsmm_gemm.c",
- "src/libxsmm_gemm_diff.c",
- "src/libxsmm_hash.c",
"src/libxsmm_timer.c",
"src/libxsmm_trace.c",
"src/libxsmm_trans.c",
@@ -108,7 +106,6 @@ cc_library(
"src",
"src/template",
],
- linkopts = ["-ldl"],
visibility = ["//visibility:public"],
)