-rw-r--r--  eigen.BUILD | 2
-rw-r--r--  tensorflow/contrib/BUILD | 8
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers_test.py | 12
-rw-r--r--  tensorflow/contrib/layers/python/ops/loss_ops.py | 213
-rw-r--r--  tensorflow/contrib/layers/python/ops/loss_ops_test.py | 281
-rw-r--r--  tensorflow/contrib/linear_optimizer/BUILD | 1
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py | 7
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py | 21
-rw-r--r--  tensorflow/core/BUILD | 3
-rw-r--r--  tensorflow/core/client/tensor_c_api.cc | 20
-rw-r--r--  tensorflow/core/common_runtime/allocator_retry.cc (renamed from tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc) | 6
-rw-r--r--  tensorflow/core/common_runtime/allocator_retry.h (renamed from tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h) | 12
-rw-r--r--  tensorflow/core/common_runtime/bfc_allocator.cc | 702
-rw-r--r--  tensorflow/core/common_runtime/bfc_allocator.h | 413
-rw-r--r--  tensorflow/core/common_runtime/function.cc | 35
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc | 4
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc | 689
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h | 396
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h | 2
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc | 24
-rw-r--r--  tensorflow/core/common_runtime/gpu/pool_allocator.h | 10
-rw-r--r--  tensorflow/core/common_runtime/gpu/process_state.cc | 14
-rw-r--r--  tensorflow/core/common_runtime/simple_placer.cc | 11
-rw-r--r--  tensorflow/core/common_runtime/visitable_allocator.h (renamed from tensorflow/core/common_runtime/gpu/visitable_allocator.h) | 6
-rw-r--r--  tensorflow/core/framework/allocator.h | 9
-rw-r--r--  tensorflow/core/framework/allocator_test.cc | 20
-rw-r--r--  tensorflow/core/graph/gradients.cc | 159
-rw-r--r--  tensorflow/core/graph/gradients.h | 37
-rw-r--r--  tensorflow/core/kernels/BUILD | 47
-rw-r--r--  tensorflow/core/kernels/batch_matmul_op.cc | 49
-rw-r--r--  tensorflow/core/kernels/decode_csv_op.cc | 8
-rw-r--r--  tensorflow/core/kernels/depthtospace_op.cc | 61
-rw-r--r--  tensorflow/core/kernels/depthtospace_op.h | 44
-rw-r--r--  tensorflow/core/kernels/depthtospace_op_gpu.cu.cc | 88
-rw-r--r--  tensorflow/core/kernels/image_resizer_state.h | 111
-rw-r--r--  tensorflow/core/kernels/nn_ops_test.cc | 19
-rw-r--r--  tensorflow/core/kernels/relu_op.cc | 141
-rw-r--r--  tensorflow/core/kernels/relu_op.h | 212
-rw-r--r--  tensorflow/core/kernels/relu_op_functor.h | 130
-rw-r--r--  tensorflow/core/kernels/relu_op_gpu.cu.cc | 2
-rw-r--r--  tensorflow/core/kernels/resize_area_op.cc | 72
-rw-r--r--  tensorflow/core/kernels/resize_bicubic_op.cc | 59
-rw-r--r--  tensorflow/core/kernels/resize_bilinear_op.cc | 62
-rw-r--r--  tensorflow/core/kernels/resize_nearest_neighbor_op.cc | 65
-rw-r--r--  tensorflow/core/kernels/softmax_op.cc | 23
-rw-r--r--  tensorflow/core/kernels/softmax_op.h | 99
-rw-r--r--  tensorflow/core/kernels/softmax_op_functor.h | 101
-rw-r--r--  tensorflow/core/kernels/softmax_op_gpu.cu.cc | 2
-rw-r--r--  tensorflow/core/kernels/spacetodepth_op.cc | 50
-rw-r--r--  tensorflow/core/kernels/spacetodepth_op.h | 44
-rw-r--r--  tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc | 89
-rw-r--r--  tensorflow/core/kernels/transpose_op.cc | 23
-rw-r--r--  tensorflow/core/public/session.h | 9
-rw-r--r--  tensorflow/core/public/tensor_c_api.h | 19
-rw-r--r--  tensorflow/core/util/work_sharder.cc | 6
-rw-r--r--  tensorflow/core/util/work_sharder_test.cc | 19
-rw-r--r--  tensorflow/examples/android/jni/jni_utils.cc | 14
-rw-r--r--  tensorflow/examples/android/jni/jni_utils.h | 3
-rw-r--r--  tensorflow/examples/android/jni/tensorflow_jni.cc | 38
-rw-r--r--  tensorflow/python/BUILD | 5
-rw-r--r--  tensorflow/python/__init__.py | 7
-rw-r--r--  tensorflow/python/client/session.py | 2
-rw-r--r--  tensorflow/python/client/session_test.py | 28
-rw-r--r--  tensorflow/python/framework/gen_docs_combined.py | 1
-rw-r--r--  tensorflow/python/framework/test_util.py | 5
-rw-r--r--  tensorflow/python/kernel_tests/benchmark_test.py | 158
-rw-r--r--  tensorflow/python/kernel_tests/depthtospace_op_test.py | 94
-rw-r--r--  tensorflow/python/kernel_tests/rnn_cell_test.py | 3
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py | 82
-rw-r--r--  tensorflow/python/kernel_tests/softmax_op_test.py | 7
-rw-r--r--  tensorflow/python/kernel_tests/spacetodepth_op_test.py | 95
-rw-r--r--  tensorflow/python/ops/array_ops.py | 1
-rw-r--r--  tensorflow/python/ops/data_flow_ops.py | 2
-rw-r--r--  tensorflow/python/ops/histogram_ops.py | 70
-rw-r--r--  tensorflow/python/ops/histogram_ops_test.py | 171
-rw-r--r--  tensorflow/python/ops/io_ops.py | 1
-rw-r--r--  tensorflow/python/ops/rnn_cell.py | 46
-rw-r--r--  tensorflow/python/ops/seq2seq.py | 12
-rw-r--r--  tensorflow/python/ops/sparse_ops.py | 7
-rw-r--r--  tensorflow/python/ops/standard_ops.py | 1
-rw-r--r--  tensorflow/python/platform/benchmark.py | 213
-rw-r--r--  tensorflow/python/platform/default/_app.py | 4
-rw-r--r--  tensorflow/python/platform/googletest.py | 13
-rw-r--r--  tensorflow/python/platform/test.py | 4
-rw-r--r--  tensorflow/python/training/coordinator.py | 22
-rw-r--r--  tensorflow/python/training/input.py | 203
-rw-r--r--  tensorflow/python/training/input_test.py | 54
-rw-r--r--  tensorflow/python/training/summary_io.py | 78
-rw-r--r--  tensorflow/python/training/summary_writer_test.py | 77
-rw-r--r--  tensorflow/python/training/supervisor.py | 4
-rw-r--r--  tensorflow/stream_executor/blas.h | 19
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.cc | 114
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.h | 2
-rw-r--r--  tensorflow/stream_executor/stream.cc | 74
-rw-r--r--  tensorflow/stream_executor/stream.h | 28
-rw-r--r--  tensorflow/tensorboard/components/tf-event-dashboard/tf-event-dashboard.html | 2
-rw-r--r--  tensorflow/tensorboard/components/tf-event-dashboard/tf-run-selector.html | 1
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts | 7
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts | 19
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/render.ts | 3
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts | 7
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-dashboard/tf-graph-dashboard.html | 2
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html | 4
-rw-r--r--  tensorflow/tensorboard/components/tf-histogram-dashboard/tf-histogram-dashboard.html | 2
-rw-r--r--  tensorflow/tensorboard/components/tf-image-dashboard/tf-image-dashboard.html | 2
-rw-r--r--  tensorflow/tensorboard/dist/tf-tensorboard.html | 10
-rw-r--r--  tensorflow/tensorboard/lib/js/colorScale/colorScale.ts | 148
-rw-r--r--  tensorflow/tensorboard/lib/js/colorScale/demo/index.html | 176
-rw-r--r--  tensorflow/tensorboard/lib/js/colorScale/demo/style.css | 74
-rw-r--r--  tensorflow/tensorboard/lib/js/colorScale/palettes.ts | 54
-rw-r--r--  tensorflow/tensorboard/lib/js/colorScale/test/colorScaleTests.ts | 99
-rw-r--r--  tensorflow/tensorboard/lib/js/colorScale/test/index.html | 28
-rw-r--r--  tensorflow/tools/dist_test/Dockerfile | 28
-rw-r--r--  tensorflow/tools/dist_test/Dockerfile.local | 20
-rw-r--r--  tensorflow/tools/dist_test/README.md | 76
-rwxr-xr-x  tensorflow/tools/dist_test/build_server.sh | 44
-rw-r--r--  tensorflow/tools/dist_test/local/Dockerfile | 20
-rwxr-xr-x  tensorflow/tools/dist_test/local/start_local_k8s_service.sh | 118
-rwxr-xr-x  tensorflow/tools/dist_test/local/start_tf_cluster_container.sh | 91
-rwxr-xr-x  tensorflow/tools/dist_test/local/test_local_tf_cluster.sh | 88
-rwxr-xr-x  tensorflow/tools/dist_test/local_test.sh | 152
-rwxr-xr-x  tensorflow/tools/dist_test/python/mnist_replica.py | 144
-rwxr-xr-x  tensorflow/tools/dist_test/remote_test.sh | 92
-rwxr-xr-x  tensorflow/tools/dist_test/scripts/create_tf_cluster.sh | 231
-rwxr-xr-x  tensorflow/tools/dist_test/scripts/delete_tf_cluster.sh | 87
-rwxr-xr-x  tensorflow/tools/dist_test/scripts/dist_mnist_test.sh | 137
-rwxr-xr-x  tensorflow/tools/dist_test/scripts/dist_test.sh | 118
-rwxr-xr-x  tensorflow/tools/dist_test/scripts/k8s_tensorflow.py | 245
-rw-r--r--  tensorflow/tools/dist_test/scripts/utils.sh | 56
-rw-r--r--  tensorflow/tools/dist_test/server/Dockerfile | 59
-rwxr-xr-x  tensorflow/tools/dist_test/server/grpc_tensorflow_server.py | 122
-rw-r--r--  tensorflow/user_ops/BUILD | 5
-rw-r--r--  tensorflow/workspace.bzl | 4
-rw-r--r--  third_party/eigen3/Eigen/Cholesky | 2
-rw-r--r--  third_party/eigen3/Eigen/Core | 2
-rw-r--r--  third_party/eigen3/Eigen/Eigenvalues | 2
-rw-r--r--  third_party/eigen3/Eigen/LU | 2
-rw-r--r--  third_party/eigen3/Eigen/QR | 2
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/Tensor | 2
139 files changed, 6584 insertions, 2536 deletions
diff --git a/eigen.BUILD b/eigen.BUILD
index 1a1467a7e5..85b4f11865 100644
--- a/eigen.BUILD
+++ b/eigen.BUILD
@@ -1,6 +1,6 @@
package(default_visibility = ["//visibility:public"])
-archive_dir = "eigen-eigen-db7b61411772"
+archive_dir = "eigen-eigen-0a13bf3e579d"
cc_library(
name = "eigen",
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 098dfe2752..708cfddefc 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -24,6 +24,14 @@ py_library(
],
)
+cc_library(
+ name = "contrib_kernels",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/linear_optimizer/kernels:sdca_ops",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 2c024b7bce..68200db076 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -211,6 +211,18 @@ class FullyConnectedTest(tf.test.TestCase):
tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
self.assertEqual(1, cnt[0])
+ def test_empty_x_results_in_empty_output(self):
+ # Empty x is common if someone masks their input with tf.boolean_mask in
+ # order to drop missing entries, and in a particular batch all entries are
+ # missing.
+ with self.test_session():
+ x = tf.constant([[]], shape=[0, 3])
+ self.assertEqual(0, tf.size(x).eval())
+ y = tf.contrib.layers.fully_connected(x, 2, activation_fn=tf.nn.softmax)
+ tf.initialize_all_variables().run()
+ expected_y = np.array([]).reshape(0,2)
+ np.testing.assert_array_equal(expected_y, y.eval())
+
class Convolution2dTest(tf.test.TestCase):
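For context on the new test above: a zero-row batch typically arises when tf.boolean_mask drops every example. The following is an editorial sketch, not part of the commit; it uses the same era of the API as the test (tf.contrib.layers.fully_connected, tf.initialize_all_variables), and the input and mask values are made up for illustration.

    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.zeros((4, 3), dtype=np.float32))  # batch of 4 examples, 3 features each
    keep = tf.constant([False, False, False, False])     # every example is "missing"
    masked = tf.boolean_mask(x, keep)                     # 0 rows at runtime, last dim stays 3
    y = tf.contrib.layers.fully_connected(masked, 2, activation_fn=tf.nn.softmax)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        print(sess.run(y).shape)                          # expected (0, 2), matching the test above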
diff --git a/tensorflow/contrib/layers/python/ops/loss_ops.py b/tensorflow/contrib/layers/python/ops/loss_ops.py
index 276d4cc541..c451fc81d4 100644
--- a/tensorflow/contrib/layers/python/ops/loss_ops.py
+++ b/tensorflow/contrib/layers/python/ops/loss_ops.py
@@ -22,16 +22,17 @@ These loss ops are, by design, minimal, enabling flexibility in how
their output can be used.
@@reduce_batch_sum
-@@reduce_batch_mean
@@absolute_loss
@@squared_loss
+@@logistic_loss
+@@sum_absolute_loss
@@sum_squared_loss
-@@mean_absolute_loss
-@@mean_squared_loss
-@@root_mean_squared_loss
+@@sum_logistic_loss
+@@scalar_absolute_loss
+@@scalar_squared_loss
@@scalar_logistic_loss
"""
@@ -39,14 +40,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.layers.python.framework import tensor_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
-__all__ = ["reduce_batch_sum", "reduce_batch_mean", "absolute_loss",
- "squared_loss", "sum_squared_loss", "mean_absolute_loss",
- "mean_squared_loss", "root_mean_squared_loss",
+__all__ = ["reduce_batch_sum", "absolute_loss", "squared_loss", "logistic_loss",
+ "sum_absolute_loss", "sum_squared_loss", "sum_logistic_loss",
+ "scalar_absolute_loss", "scalar_squared_loss",
"scalar_logistic_loss"]
@@ -120,31 +122,11 @@ def reduce_batch_sum(x, name=None):
return _reduce_batch(x, math_ops.reduce_sum, name)
-def reduce_batch_mean(x, name=None):
- """Given a tensor `x`, returns the mean across all dimensions except dim 0.
-
- Given a tensor with the number of dimensions > 1, reduce_batch_mean
- will calculate the mean across all dimensions except for dimension
- 0. This function is useful for calculating the mean loss (error)
- across all examples in a batch when training. As an example, given a
- tensor of shape [batch_size, d1, d2], this function will calculate
- the mean across dimensions d1 and d2, returning a tensor of shape
- [batch_size].
-
- Tensors of dimension 1 are returned as-is.
-
- Args:
- x: A `Tensor` with dimension > 0.
- name: A name for the operation (optional).
-
- Returns:
- A `Tensor` with values averaged across all dimensions > 0.
-
- Raises:
- ValueError: If `x` has dimension 0.
-
- """
- return _reduce_batch(x, math_ops.reduce_mean, name)
+def _validate_predicted_and_target(predicted, target):
+ # TODO(ptucker): Optionally add assert op for shape check, for cases when
+ # shape is not fully defined at graph construction time?
+ predicted.get_shape().assert_is_compatible_with(target.get_shape())
+ tensor_util.assert_same_float_dtype([predicted, target])
def absolute_loss(predicted, target, name=None):
@@ -172,12 +154,12 @@ def absolute_loss(predicted, target, name=None):
with ops.op_scope([predicted, target], name, "absolute_loss") as scope:
predicted = ops.convert_to_tensor(predicted, name="predicted")
target = ops.convert_to_tensor(target, name="target")
- predicted.get_shape().assert_is_compatible_with(target.get_shape())
+ _validate_predicted_and_target(predicted, target)
return math_ops.abs(target - predicted, name=scope)
def squared_loss(predicted, target, name=None):
- """Computes and returns the per-example squared loss.
+ """Computes and returns the per-example squared loss, divided by 2.
Computes the per-example squared difference between the target and
predicted tensors. The tensors must have the same shape.
@@ -200,27 +182,33 @@ def squared_loss(predicted, target, name=None):
with ops.op_scope([predicted, target], name, "squared_loss") as scope:
predicted = ops.convert_to_tensor(predicted, name="predicted")
target = ops.convert_to_tensor(target, name="target")
- predicted.get_shape().assert_is_compatible_with(target.get_shape())
- return math_ops.square(target - predicted, name=scope)
+ _validate_predicted_and_target(predicted, target)
+ return math_ops.div(math_ops.square(target - predicted), 2.0, name=scope)
-def sum_squared_loss(predicted, target, name=None):
- """Calculates 1/2 the sum of the squared loss across batches.
+def logistic_loss(logit, target, name=None):
+ """Calculates the logistic cross-entropy loss.
- Computes the squared difference between the target and predicted
- tensors, sums across all dimensions except dimension 0, and divides
- by 2:
+ **WARNING:** `logit` must be unscaled, while the `target` should be a
+ normalized probability prediction. See
+ `tf.nn.sigmoid_cross_entropy_with_logits` for more details.
- losses = reduce_batch_sum(squared_loss(predicted, target)) / 2.0
+ Args:
+ logit: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+ of predicted logit values.
+ target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+ target values. The shape of the target tensor should match the
+ `predicted` tensor.
+ name: A name for the operation (optional).
- where `losses` is a tensor with dimensions [batch_size].
+ Returns:
+ A `Tensor` of the logistic cross-entropy loss.
+ """
+ return nn.sigmoid_cross_entropy_with_logits(logit, target, name=name)
- The tensors must have the same shape.
- This function is equivalent to typical formulations of L2 loss, and
- similar to TensorFlow's l2_loss function. It differs from the
- l2_loss function by allowing the caller to specify both the
- predicted and target tensors.
+def _sum_loss(predicted, target, loss_fn, name="sum_loss"):
+ """Apply loss function, then sum across all non-batch dimensions.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
@@ -228,30 +216,23 @@ def sum_squared_loss(predicted, target, name=None):
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
+ loss_fn: Loss to apply, takes 2 tensors as parameters and returns a tensor.
name: A name for the operation (optional).
Returns:
- A `[batch_size]` tensor of squared losses summed across all dimensions
- except dimension 0, divided by 2.
-
- Raises:
- ValueError: If `predicted` and `target` shapes do not match.
-
+ A `[batch_size]` tensor of losses, summed across all dimensions except
+ dimension 0.
"""
- with ops.op_scope([predicted, target], name, "sum_squared_loss") as scope:
- return math_ops.div(
- reduce_batch_sum(squared_loss(predicted, target)),
- 2.0,
- name=scope)
+ return reduce_batch_sum(loss_fn(predicted, target), name=name)
-def mean_absolute_loss(predicted, target, name=None):
- """Calculates the mean absolute loss across batches.
+def sum_absolute_loss(predicted, target, name="sum_absolute_loss"):
+ """Calculates the sum of absolute losses across batches.
Computes the absolute difference between the target and predicted
tensors, averaged across all dimensions except dimension 0:
- losses = reduce_batch_mean(absolute_loss(predicted, target))
+ losses = reduce_batch_sum(absolute_loss(predicted, target))
where `losses` is a tensor with dimensions [batch_size].
@@ -275,22 +256,26 @@ def mean_absolute_loss(predicted, target, name=None):
ValueError: If `predicted` and `target` shapes do not match.
"""
- with ops.op_scope([predicted, target], name, "mean_absolute_loss") as scope:
- return reduce_batch_mean(absolute_loss(predicted, target), name=scope)
+ return _sum_loss(predicted, target, absolute_loss, name=name)
-def mean_squared_loss(predicted, target, name=None):
- """Calculates the mean squared loss across batches.
+def sum_squared_loss(predicted, target, name="sum_squared_loss"):
+ """Calculates the sum of the squared loss across batches.
Computes the squared difference between the target and predicted
- tensors, and averages across all dimensions except dimension 0:
+ tensors, sums across all dimensions except dimension 0.
- losses = reduce_batch_mean(squared_loss(predicted, target))
+ losses = reduce_batch_sum(squared_loss(predicted, target))
where `losses` is a tensor with dimensions [batch_size].
The tensors must have the same shape.
+ This function is equivalent to typical formulations of L2 loss, and
+ similar to TensorFlow's l2_loss function. It differs from the
+ l2_loss function by allowing the caller to specify both the
+ predicted and target tensors.
+
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
of predicted values.
@@ -300,29 +285,43 @@ def mean_squared_loss(predicted, target, name=None):
name: A name for the operation (optional).
Returns:
- A `[batch_size]` tensor of squared differences, averaged across
- all dimensions except dimension 0.
+ A `[batch_size]` tensor of squared losses summed across all dimensions
+ except dimension 0.
Raises:
ValueError: If `predicted` and `target` shapes do not match.
"""
- with ops.op_scope([predicted, target], name, "mean_squared_loss") as scope:
- return reduce_batch_mean(squared_loss(predicted, target), name=scope)
+ return _sum_loss(predicted, target, squared_loss, name=name)
-def root_mean_squared_loss(predicted, target, name=None):
- """Calculates the root mean squared loss across batches.
+def sum_logistic_loss(logit, target, name="sum_logistic_loss"):
+ """Calculates the sum of the logistic loss across batches.
- Computes the root mean squared loss between the target and predicted
- tensors, which is the square root of the mean squared differences
- between the predicted and target tensors:
+ Computes the logistic loss between the logit and target tensors, summed across all
+ dimensions except dimension 0.
- losses = sqrt(mean_squared_loss(predicted, target))
+ **WARNING:** `logit` must be unscaled, while the `target` should be a
+ normalized probability prediction. See
+ `tf.nn.sigmoid_cross_entropy_with_logits` for more details.
- where `losses` is a tensor with dimensions [batch_size].
+ Args:
+ logit: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+ of predicted logit values.
+ target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+ target values. The shape of the target tensor should match the
+ `predicted` tensor.
+ name: A name for the operation (optional).
- The tensors must have the same shape.
+ Returns:
+ A `[batch_size]` tensor of logistic losses summed across all dimensions
+ except dimension 0.
+ """
+ return _sum_loss(logit, target, logistic_loss, name=name)
+
+
+def _scalar_loss(predicted, target, loss_fn, name=None):
+ """Reduces losses to a scalar.
Args:
predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
@@ -330,23 +329,52 @@ def root_mean_squared_loss(predicted, target, name=None):
target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
target values. The shape of the target tensor should match the
`predicted` tensor.
+ loss_fn: Loss to apply, takes 2 tensors as parameters and returns a tensor.
name: A name for the operation (optional).
Returns:
- A `[batch_size]` tensor of the root mean squared differences.
+ Calculate sum of losses per example, then average across batch.
+ """
+ with ops.op_scope([predicted, target], name, "scalar_loss") as scope:
+ return math_ops.reduce_mean(
+ _sum_loss(predicted, target, loss_fn), name=scope)
- Raises:
- ValueError: If `predicted` and `target` shapes do not match.
+def scalar_absolute_loss(predicted, target, name="scalar_absolute_loss"):
+ """Reduces absolute losses to a scalar.
+
+ Args:
+ predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+ of predicted values.
+ target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+ target values. The shape of the target tensor should match the
+ `predicted` tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ Calculate sum of absolute losses per example, then average across batch.
"""
- with ops.op_scope([predicted, target],
- name,
- "root_mean_squared_loss") as scope:
- return math_ops.sqrt(mean_squared_loss(predicted, target),
- name=scope)
+ return _scalar_loss(predicted, target, loss_fn=absolute_loss, name=name)
+
+def scalar_squared_loss(predicted, target, name="scalar_squared_loss"):
+ """Reduces squared losses to a scalar.
-def scalar_logistic_loss(logit, target, name=None):
+ Args:
+ predicted: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]`
+ of predicted values.
+ target: A `Tensor` of shape `[batch_size, dim_1, ..., dim_n]` of
+ target values. The shape of the target tensor should match the
+ `predicted` tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ Calculate sum of squared losses per example, then average across batch.
+ """
+ return _scalar_loss(predicted, target, loss_fn=squared_loss, name=name)
+
+
+def scalar_logistic_loss(logit, target, name="scalar_logistic_loss"):
"""Calculates the logistic cross-entropy loss, averaged across batches.
**WARNING:** `logit` must be unscaled, while the `target` should be a
@@ -368,8 +396,5 @@ def scalar_logistic_loss(logit, target, name=None):
Raises:
ValueError: If `logit` and `target` shapes do not match.
"""
- with ops.op_scope([logit, target], name,
- "scalar_logistic_loss") as scope:
- batch_loss = reduce_batch_sum(nn.sigmoid_cross_entropy_with_logits(logit,
- target))
- return math_ops.reduce_mean(batch_loss, [0], name=scope)
+ return _scalar_loss(logit, target, loss_fn=logistic_loss, name=name)
+
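To make the reshaped loss API easier to follow, here is a rough usage sketch (editorial illustration, not part of the commit): the per-element losses compose with a per-example sum and a batch mean exactly as the new _sum_loss and _scalar_loss helpers do. The function names mirror the updated __all__ list and the test file below, and the values mirror the updated SquaredLossTest expectations.

    import tensorflow as tf

    predicted = tf.constant([[1.1, -0.2], [3.3, 1.6]])
    target = tf.constant([[1.0, 0.0], [3.0, 2.0]])

    per_element = tf.contrib.layers.squared_loss(predicted, target)      # (target - predicted)^2 / 2
    per_example = tf.contrib.layers.sum_squared_loss(predicted, target)  # summed over non-batch dims
    scalar = tf.contrib.layers.scalar_squared_loss(predicted, target)    # mean over the batch

    with tf.Session() as sess:
        print(sess.run(per_element))  # [[0.005, 0.02], [0.045, 0.08]]
        print(sess.run(per_example))  # [0.025, 0.125]
        print(sess.run(scalar))       # 0.075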
diff --git a/tensorflow/contrib/layers/python/ops/loss_ops_test.py b/tensorflow/contrib/layers/python/ops/loss_ops_test.py
index 48f49989cf..1453af5331 100644
--- a/tensorflow/contrib/layers/python/ops/loss_ops_test.py
+++ b/tensorflow/contrib/layers/python/ops/loss_ops_test.py
@@ -21,6 +21,10 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
+from tensorflow.contrib.layers.python.framework import tensor_util
+
+pi = 3.14
+indiana_pi = 3.2 # https://en.wikipedia.org/wiki/Indiana_Pi_Bill
class ReduceBatchSumTest(tf.test.TestCase):
@@ -89,72 +93,6 @@ class ReduceBatchSumTest(tf.test.TestCase):
self.assertAllClose(expected_result, actual_result.eval())
-class ReduceBatchMeanTest(tf.test.TestCase):
-
- def testDimensionNone(self):
- with self.test_session():
- input_array = np.array([
- [1.0, 2.0],
- [-1.0, -2.0]
- ], dtype=np.float32)
- placeholder_vec = tf.placeholder(tf.float32, name="placeholder_vec")
- expected_result = np.array([1.5, -1.5])
- actual_result = tf.contrib.layers.reduce_batch_mean(placeholder_vec)
- self.assertEqual(actual_result.get_shape().as_list(), [None])
- self.assertAllClose(expected_result, actual_result.eval(feed_dict={
- placeholder_vec: input_array
- }))
-
- def testDimension0(self):
- with self.test_session():
- input_vec = tf.constant(2.0)
- with self.assertRaises(ValueError):
- tf.contrib.layers.reduce_batch_mean(input_vec)
-
- def testDimension1(self):
- with self.test_session():
- input_vec = tf.constant([1.0, 2.0])
- expected_result = np.array([1.0, 2.0])
- actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
- self.assertAllClose(expected_result, actual_result.eval())
-
- def testDimension2(self):
- with self.test_session():
- input_vec = tf.constant([
- [1.0, 2.0],
- [-1.0, -2.0]
- ])
- expected_result = np.array([1.5, -1.5])
- actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
- self.assertAllClose(expected_result, actual_result.eval())
-
- def testReturnShape(self):
- with self.test_session():
- input_vec = tf.constant([
- [1.0, 2.0],
- [-1.0, -2.0]
- ])
- expected_result = np.array([3.0, -3.0])
- actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
- self.assertShapeEqual(expected_result, actual_result)
-
- def testDimensionN(self):
- with self.test_session():
- input_vec = tf.constant([
- [
- [1.0, 2.0],
- [3.0, 4.0]
- ],
- [
- [5.0, 6.0],
- [7.0, 8.0]
- ]
- ])
- expected_result = np.array([2.5, 6.5])
- actual_result = tf.contrib.layers.reduce_batch_mean(input_vec)
- self.assertAllClose(expected_result, actual_result.eval())
-
-
class AbsoluteLossTest(tf.test.TestCase):
def _getTestVectors(self):
@@ -191,7 +129,7 @@ class SquaredLossTest(tf.test.TestCase):
target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
predicted = tf.constant([1.1, -0.2, 3.3, 1.6], shape=[2, 2],
name="predicted")
- expected_loss = np.array([0.01, 0.04, 0.09, 0.16]).reshape(2, 2)
+ expected_loss = np.array([0.005, 0.02, 0.045, 0.08]).reshape(2, 2)
return target, predicted, expected_loss
def testSquaredLoss(self):
@@ -250,114 +188,108 @@ class SumSquaredLossTest(tf.test.TestCase):
tf.contrib.layers.sum_squared_loss(incompatible_shape, target)
-class MeanAbsoluteLossTest(tf.test.TestCase):
-
- def _getTestVectors(self):
- target = tf.constant([[0.0, 1.0, 2.0],
- [3.0, 2.0, 4.0]],
- shape=[2, 3],
- name="target")
- predicted = tf.constant([[3.0, -3.0, 0.0],
- [1.0, 2.0, 0.0]],
- shape=[2, 3],
- name="predicted")
- expected_loss = np.array([3.0, 2.0])
- return target, predicted, expected_loss
-
- def testMeanAbsoluteLoss(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.mean_absolute_loss(predicted, target)
- self.assertAllClose(expected_loss, result.eval())
-
- def testMeanAbsoluteLossReturnShape(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.mean_absolute_loss(predicted, target)
- self.assertShapeEqual(expected_loss, result)
+class ScalarAbsoluteLossTest(tf.test.TestCase):
- def testInvalidShapesValueError(self):
+ def testScalarAbsoluteLoss(self):
with self.test_session():
- target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
- incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
- name="incompatible_shape")
- with self.assertRaises(ValueError):
- tf.contrib.layers.mean_absolute_loss(incompatible_shape, target)
-
-
-class MeanSquaredLossTest(tf.test.TestCase):
-
- def _getTestVectors(self):
- target = tf.constant([[0.0, 1.0, 2.0],
- [3.0, 2.0, 4.0]],
- shape=[2, 3],
- name="target")
- predicted = tf.constant([[3.0, -3.0, 0.0],
- [1.0, 2.0, 0.0]],
- shape=[2, 3],
- name="predicted")
- expected_loss = np.array([9.666667, 6.666667])
- return target, predicted, expected_loss
-
- def testMeanSquaredLoss(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.mean_squared_loss(predicted, target)
- self.assertAllClose(expected_loss, result.eval())
-
- def testMeanSquaredLossReturnShape(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.mean_squared_loss(predicted, target)
- self.assertShapeEqual(expected_loss, result)
-
- def testInvalidShapesValueError(self):
+ actual = tf.constant([pi], name="pi")
+ actual_placeholder = tf.placeholder(tf.float32)
+ label = tf.constant([indiana_pi], name="lbl")
+ label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
+ expected_loss = abs(indiana_pi - pi)
+
+ # Both shapes are set.
+ both_shapes_loss = tf.contrib.layers.scalar_absolute_loss(actual, label)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ both_shapes_loss.eval(), expected_loss, decimal=6)
+
+ # No shape for 'actual' - check that the loss layer can be created.
+ no_actual_shape_loss = tf.contrib.layers.scalar_absolute_loss(
+ actual_placeholder, label)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_actual_shape_loss.eval({actual_placeholder: [pi]}),
+ expected_loss, decimal=6)
+
+ # No shape for 'label' - check that the loss layer can be created.
+ no_label_shape_loss = tf.contrib.layers.scalar_absolute_loss(
+ actual, label_placeholder)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
+ expected_loss, decimal=6)
+
+ # No shapes.
+ no_shape_loss = tf.contrib.layers.scalar_absolute_loss(
+ actual_placeholder, label_placeholder)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_shape_loss.eval({label_placeholder: [indiana_pi],
+ actual_placeholder: [pi]}),
+ expected_loss, decimal=6)
+
+ # Evaluate the previous one again, but this time with different
+ # (matching) shapes. This should still work.
+ np.testing.assert_almost_equal(
+ no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
+ actual_placeholder: [pi, pi]}),
+ expected_loss, decimal=6)
+
+
+class ScalarSquaredLossTest(tf.test.TestCase):
+
+ def testScalarSquaredLoss(self):
with self.test_session():
- target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
- incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
- name="incompatible_shape")
- with self.assertRaises(ValueError):
- tf.contrib.layers.mean_squared_loss(incompatible_shape, target)
-
-
-class RootMeanSquaredLossTest(tf.test.TestCase):
-
- def _getTestVectors(self):
- target = tf.constant([[0.0, 1.0, 2.0],
- [3.0, 2.0, 4.0]],
- shape=[2, 3],
- name="target")
- predicted = tf.constant([[3.0, -3.0, 0.0],
- [1.0, 2.0, 0.0]],
- shape=[2, 3],
- name="predicted")
- expected_loss = np.array([3.109126, 2.5819889])
- return target, predicted, expected_loss
-
- def testRootMeanSquaredLoss(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.root_mean_squared_loss(predicted, target)
- self.assertAllClose(expected_loss, result.eval())
-
- def testRootMeanSquaredLossReturnShape(self):
- with self.test_session():
- target, predicted, expected_loss = self._getTestVectors()
- result = tf.contrib.layers.root_mean_squared_loss(predicted, target)
- self.assertShapeEqual(expected_loss, result)
-
- def testInvalidShapesValueError(self):
- with self.test_session():
- target = tf.constant([1.0, 0.0, 3.0, 2.0], shape=[2, 2], name="target")
- incompatible_shape = tf.constant([0.0, 1.1], shape=[2],
- name="incompatible_shape")
- with self.assertRaises(ValueError):
- tf.contrib.layers.root_mean_squared_loss(incompatible_shape, target)
-
-
-class MeanScalarLogisticLossTest(tf.test.TestCase):
-
- def _get_mean_sigmoid_logistic_loss(self, logit, target):
+ actual = tf.constant([pi], name="pi")
+ actual_placeholder = tf.placeholder(tf.float32)
+ label = tf.constant([indiana_pi], name="lbl")
+ label_placeholder = tf.placeholder(tf.float32, name="lbl_ph")
+ expected_loss = (indiana_pi - pi) * (indiana_pi - pi) / 2
+
+ # Both shapes are set.
+ both_shapes_loss = tf.contrib.layers.scalar_squared_loss(actual, label)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ both_shapes_loss.eval(), expected_loss, decimal=6)
+
+ # No shape for 'actual' - check that the loss layer can be created.
+ no_actual_shape_loss = tf.contrib.layers.scalar_squared_loss(
+ actual_placeholder, label)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_actual_shape_loss.eval({actual_placeholder: [pi]}),
+ expected_loss, decimal=6)
+
+ # No shape for 'label' - check that the loss layer can be created.
+ no_label_shape_loss = tf.contrib.layers.scalar_squared_loss(
+ actual, label_placeholder)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_label_shape_loss.eval({label_placeholder: [indiana_pi]}),
+ expected_loss,
+ decimal=6)
+
+ # No shapes.
+ no_shape_loss = tf.contrib.layers.scalar_squared_loss(
+ actual_placeholder, label_placeholder)
+ tf.initialize_all_variables().run()
+ np.testing.assert_almost_equal(
+ no_shape_loss.eval({label_placeholder: [indiana_pi],
+ actual_placeholder: [pi]}),
+ expected_loss, decimal=6)
+
+ # Evaluate the previous one again, but this time with different
+ # (matching) shapes. This should still work.
+ np.testing.assert_almost_equal(
+ no_shape_loss.eval({label_placeholder: [indiana_pi, indiana_pi],
+ actual_placeholder: [pi, pi]}),
+ expected_loss, decimal=6)
+
+
+class ScalarLogisticLossTest(tf.test.TestCase):
+
+ def _expected_loss(self, logit, target):
sigmoid = 1.0 / (1.0 + np.exp(-logit))
logistic_loss = (target * -np.log(sigmoid)) - (
(1.0 - target) * np.log(1.0 - sigmoid))
@@ -365,14 +297,13 @@ class MeanScalarLogisticLossTest(tf.test.TestCase):
return np.sum(batch_losses) / len(batch_losses)
- def test_mean__scalar_logistic_loss(self):
+ def test_scalar_logistic_loss(self):
logit = np.array([[9.45, -42], [4.2, 1], [-0.6, 20]])
target = np.array([[0.8, 0.9], [0.45, 0.99999], [0.1, 0.0006]])
- expected_loss = self._get_mean_sigmoid_logistic_loss(logit, target)
with self.test_session():
result = tf.contrib.layers.scalar_logistic_loss(
tf.constant(logit), tf.constant(target))
- self.assertAllClose(expected_loss, result.eval())
+ self.assertAllClose(self._expected_loss(logit, target), result.eval())
if __name__ == "__main__":
diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD
index 234142757f..3d92123c28 100644
--- a/tensorflow/contrib/linear_optimizer/BUILD
+++ b/tensorflow/contrib/linear_optimizer/BUILD
@@ -36,6 +36,7 @@ py_test(
name = "sdca_ops_test",
srcs = ["python/kernel_tests/sdca_ops_test.py"],
srcs_version = "PY2AND3",
+ tags = ["noasan"], # doesn't pass ASAN for some reason
deps = [
":sdca_ops_py",
"//tensorflow:tensorflow_py",
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index e973a88bb7..a6da0ce5e9 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -112,12 +112,13 @@ def make_dense_variable_dict(num_dense_features, num_examples):
def get_binary_predictions_for_logistic(predictions, cutoff=0.5):
return tf.cast(
tf.greater_equal(predictions, tf.ones_like(predictions) * cutoff),
- tf.float32)
+ dtype=tf.float32)
def get_binary_predictions_for_hinge(predictions):
- all_ones = tf.ones_like(predictions)
- return tf.add(tf.sign(predictions), all_ones) / 2
+ return tf.cast(
+ tf.greater_equal(predictions, tf.zeros_like(predictions)),
+ dtype=tf.float32)
# Setup the single container shared across all tests. This is testing proper
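The rewritten hinge helper above changes behavior only at the decision boundary. A quick worked comparison (editorial note, values chosen for illustration) with predictions [-2.0, 0.0, 3.0]:

    old: (sign(p) + 1) / 2          -> [0.0, 0.5, 1.0]
    new: cast(p >= 0, tf.float32)   -> [0.0, 1.0, 1.0]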
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 8116ad00b0..5820794f35 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -28,9 +28,11 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework.load_library import load_op_library
from tensorflow.python.framework.ops import convert_to_tensor
from tensorflow.python.framework.ops import name_scope
+from tensorflow.python.framework.ops import op_scope
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as var_ops
from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
from tensorflow.python.platform import resource_loader
@@ -55,6 +57,7 @@ def _maybe_load_sdca_ops():
assert _sdca_ops, 'Could not load _sdca_ops.so'
+# TODO(rohananil): add op_scope to appropriate methods.
class SdcaModel(object):
"""Stochastic dual coordinate ascent solver for linear models.
@@ -255,13 +258,20 @@ class SdcaModel(object):
predictions = math_ops.sigmoid(predictions)
return predictions
- def minimize(self):
+ def minimize(self, global_step=None, name=None):
"""Add operations to train a linear model by minimizing the loss function.
+ Args:
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation.
+
Returns:
An Operation that updates the variables passed in the constructor.
"""
- with name_scope('sdca/minimize'):
+ # Technically, the op depends on a lot more than the variables,
+ # but we'll keep the list short.
+ with op_scope([], name, 'sdca/minimize'):
sparse_features_indices = []
sparse_features_values = []
for sf in self._examples['sparse_features']:
@@ -301,7 +311,7 @@ class SdcaModel(object):
assign_ops.append(var.assign(slot_var))
assign_group = control_flow_ops.group(*assign_ops)
with ops.control_dependencies([assign_group]):
- return _sdca_ops.sdca_shrink_l1(
+ shrink_l1 = _sdca_ops.sdca_shrink_l1(
self._convert_n_to_tensor(
self._variables['sparse_features_weights'],
as_ref=True),
@@ -310,6 +320,11 @@ class SdcaModel(object):
as_ref=True),
l1=self._options['symmetric_l1_regularization'],
l2=self._symmetric_l2_regularization())
+ if not global_step:
+ return shrink_l1
+ with ops.control_dependencies([shrink_l1]):
+ with ops.colocate_with(global_step):
+ return state_ops.assign_add(global_step, 1, name=name).op
def approximate_duality_gap(self):
"""Add operations to compute the approximate duality gap.
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index cea3b723f8..ef76ebdd8d 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -968,7 +968,6 @@ tf_cuda_library(
tf_cuda_library(
name = "gpu_runtime",
srcs = [
- "common_runtime/gpu/gpu_allocator_retry.cc",
"common_runtime/gpu/gpu_bfc_allocator.cc",
"common_runtime/gpu/gpu_debug_allocator.cc",
"common_runtime/gpu/gpu_device.cc",
@@ -982,7 +981,6 @@ tf_cuda_library(
"common_runtime/gpu_device_context.h",
],
hdrs = [
- "common_runtime/gpu/gpu_allocator_retry.h",
"common_runtime/gpu/gpu_bfc_allocator.h",
"common_runtime/gpu/gpu_debug_allocator.h",
"common_runtime/gpu/gpu_device.h",
@@ -991,7 +989,6 @@ tf_cuda_library(
"common_runtime/gpu/gpu_util.h",
"common_runtime/gpu/pool_allocator.h",
"common_runtime/gpu/process_state.h",
- "common_runtime/gpu/visitable_allocator.h",
],
copts = tf_copts(),
linkstatic = 1,
diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc
index fa2f1417d2..8abebcd811 100644
--- a/tensorflow/core/client/tensor_c_api.cc
+++ b/tensorflow/core/client/tensor_c_api.cc
@@ -420,18 +420,26 @@ void TF_Run_Helper(TF_Session* s, const char* handle,
run_options->length)) {
status->status =
tensorflow::errors::InvalidArgument("Unparseable RunOptions proto");
+ return;
+ }
+ if (run_outputs != nullptr && run_outputs->data != nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Passing non-empty run_outputs is invalid.");
+ return;
}
- RunOutputs run_outputs_proto;
+ RunOutputs run_outputs_proto;
result = s->session->Run(run_options_proto, inputs, output_tensor_names,
target_node_names, &outputs, &run_outputs_proto);
// Serialize back to upstream client, who now owns the new buffer
- int proto_size = run_outputs_proto.ByteSize();
- void* str_buf = reinterpret_cast<void*>(operator new(proto_size));
- run_outputs_proto.SerializeToArray(str_buf, proto_size);
- run_outputs->data = str_buf;
- run_outputs->length = proto_size;
+ if (run_outputs != nullptr) {
+ int proto_size = run_outputs_proto.ByteSize();
+ void* str_buf = reinterpret_cast<void*>(operator new(proto_size));
+ run_outputs_proto.SerializeToArray(str_buf, proto_size);
+ run_outputs->data = str_buf;
+ run_outputs->length = proto_size;
+ }
}
} else {
// NOTE(zongheng): PRun does not support RunOptions yet.
diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc b/tensorflow/core/common_runtime/allocator_retry.cc
index 4d97491f2e..8c3c45706f 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc
+++ b/tensorflow/core/common_runtime/allocator_retry.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/common_runtime/allocator_retry.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
@@ -21,9 +21,9 @@ limitations under the License.
namespace tensorflow {
-GPUAllocatorRetry::GPUAllocatorRetry() : env_(Env::Default()) {}
+AllocatorRetry::AllocatorRetry() : env_(Env::Default()) {}
-void* GPUAllocatorRetry::AllocateRaw(
+void* AllocatorRetry::AllocateRaw(
std::function<void*(size_t alignment, size_t num_bytes,
bool verbose_failure)>
alloc_func,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h b/tensorflow/core/common_runtime/allocator_retry.h
index aa4ac81998..613f19d41b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h
+++ b/tensorflow/core/common_runtime/allocator_retry.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_ALLOCATOR_RETRY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_ALLOCATOR_RETRY_H_
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
@@ -23,9 +23,9 @@ limitations under the License.
namespace tensorflow {
// A retrying wrapper for a memory allocator.
-class GPUAllocatorRetry {
+class AllocatorRetry {
public:
- GPUAllocatorRetry();
+ AllocatorRetry();
// Call 'alloc_func' to obtain memory. On first call,
// 'verbose_failure' will be false. If return value is nullptr,
@@ -50,11 +50,11 @@ class GPUAllocatorRetry {
};
// Implementation details below
-inline void GPUAllocatorRetry::NotifyDealloc() {
+inline void AllocatorRetry::NotifyDealloc() {
mutex_lock l(mu_);
memory_returned_.notify_all();
}
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_ALLOCATOR_RETRY_H_
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
new file mode 100644
index 0000000000..7a2ea91c9b
--- /dev/null
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -0,0 +1,702 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/bfc_allocator.h"
+
+#include "tensorflow/core/common_runtime/allocator_retry.h"
+#include "tensorflow/core/lib/core/bits.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+BFCAllocator::BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
+ bool allow_growth, const string& name)
+ : suballocator_(sub_allocator),
+ name_(name),
+ free_chunks_list_(kInvalidChunkHandle),
+ next_allocation_id_(1) {
+ if (allow_growth) {
+ // 1MiB smallest initial allocation, unless total memory available
+ // is less.
+ curr_region_allocation_bytes_ =
+ RoundedBytes(std::min(total_memory, size_t{1048576}));
+ } else {
+ curr_region_allocation_bytes_ = RoundedBytes(total_memory);
+ }
+
+ // Allocate the requested amount of memory.
+ memory_limit_ = total_memory;
+ stats_.bytes_limit = static_cast<int64>(total_memory);
+
+ // Create a bunch of bins of various good sizes.
+
+ // We create bins to fit all possible ranges that cover the
+ // memory_limit_ starting from allocations up to 256 bytes to
+ // allocations up to (and including) the memory limit.
+ for (BinNum b = 0; b < kNumBins; b++) {
+ size_t bin_size = BinNumToSize(b);
+ VLOG(1) << "Creating bin of max chunk size "
+ << strings::HumanReadableNumBytes(bin_size);
+ new (BinFromIndex(b)) Bin(this, bin_size);
+ CHECK_EQ(BinForSize(bin_size), BinFromIndex(b));
+ CHECK_EQ(BinForSize(bin_size + 255), BinFromIndex(b));
+ CHECK_EQ(BinForSize(bin_size * 2 - 1), BinFromIndex(b));
+ if (b + 1 < kNumBins) {
+ CHECK_NE(BinForSize(bin_size * 2), BinFromIndex(b));
+ }
+ }
+}
+
+BFCAllocator::~BFCAllocator() {
+ // Return memory back.
+ VLOG(2) << "Number of regions allocated: "
+ << region_manager_.regions().size();
+ for (const auto& region : region_manager_.regions()) {
+ suballocator_->Free(region.ptr(), region.memory_size());
+ }
+
+ for (BinNum b = 0; b < kNumBins; b++) {
+ BinFromIndex(b)->~Bin();
+ }
+}
+
+BFCAllocator::Chunk* BFCAllocator::ChunkFromHandle(ChunkHandle h) {
+ DCHECK_GE(h, 0);
+ DCHECK_LT(h, static_cast<int>(chunks_.size()));
+ return &(chunks_[h]);
+}
+
+bool BFCAllocator::Extend(size_t rounded_bytes) {
+ // Do we have enough space to handle the client's request?
+ // If not, fail immediately.
+ if (total_region_allocated_bytes_ + rounded_bytes > memory_limit_) {
+ return false;
+ }
+
+ // If curr_region_allocation_bytes_ is not enough to satisfy the
+ // allocation, keep multiplying by a power of two until that is
+ // sufficient.
+ bool increased_allocation = false;
+ while (rounded_bytes > curr_region_allocation_bytes_) {
+ curr_region_allocation_bytes_ *= 2;
+ increased_allocation = true;
+ }
+
+ // Try allocating.
+ size_t bytes = curr_region_allocation_bytes_;
+ void* mem_addr = suballocator_->Alloc(32, bytes);
+ if (mem_addr == nullptr && !started_backpedal_) {
+ // Only backpedal once.
+ started_backpedal_ = true;
+
+ static constexpr float kBackpedalFactor = 0.9;
+
+ // Try allocating less memory.
+ bytes = RoundedBytes(bytes * kBackpedalFactor);
+ while (mem_addr == nullptr && bytes > rounded_bytes) {
+ mem_addr = suballocator_->Alloc(32, bytes);
+ bytes = RoundedBytes(bytes * kBackpedalFactor);
+ }
+ }
+
+ if (mem_addr == nullptr) {
+ return false;
+ }
+
+ if (!increased_allocation) {
+ // Increase the region size of the next required allocation.
+ curr_region_allocation_bytes_ *= 2;
+ }
+
+ VLOG(1) << "Extending allocation by " << strings::HumanReadableNumBytes(bytes)
+ << " bytes.";
+
+ total_region_allocated_bytes_ += bytes;
+ VLOG(1) << "Total allocated bytes: "
+ << strings::HumanReadableNumBytes(total_region_allocated_bytes_);
+
+ VLOG(1) << "Allocated memory at " << mem_addr << " to "
+ << static_cast<void*>(static_cast<char*>(mem_addr) + bytes);
+ region_manager_.AddAllocationRegion(mem_addr, bytes);
+
+ // Create one large chunk for the whole memory space that will
+ // be chunked later.
+ ChunkHandle h = AllocateChunk();
+ BFCAllocator::Chunk* c = ChunkFromHandle(h);
+ c->ptr = mem_addr;
+ c->size = bytes;
+ c->allocation_id = -1;
+ c->prev = kInvalidChunkHandle;
+ c->next = kInvalidChunkHandle;
+
+ region_manager_.set_handle(c->ptr, h);
+
+ // TODO(vrv): Try to merge this new region with an existing region,
+ // if the address space is contiguous, to avoid fragmentation
+ // across regions.
+
+ // Insert the chunk into the right bin.
+ InsertFreeChunkIntoBin(h);
+
+ // Invoke visitors on newly allocated region.
+ for (auto visitor : region_visitors_) {
+ visitor(mem_addr, bytes);
+ }
+ return true;
+}
+
+BFCAllocator::ChunkHandle BFCAllocator::AllocateChunk() {
+ if (free_chunks_list_ != kInvalidChunkHandle) {
+ ChunkHandle h = free_chunks_list_;
+ Chunk* c = ChunkFromHandle(h);
+ free_chunks_list_ = c->next;
+ return h;
+ } else {
+ ChunkHandle h = chunks_.size();
+ chunks_.resize(h + 1);
+ return h;
+ }
+}
+
+void BFCAllocator::DeallocateChunk(ChunkHandle h) {
+ Chunk* c = ChunkFromHandle(h);
+ c->next = free_chunks_list_;
+ free_chunks_list_ = h;
+}
+
+void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes) {
+ // Fast path: Try once to allocate without getting the retry_helper_ involved
+ void* r = AllocateRawInternal(unused_alignment, num_bytes, false);
+ if (r != nullptr) {
+ return r;
+ } else {
+ static const int64 kMaxMillisToWait = 10000; // 10 seconds
+ return retry_helper_.AllocateRaw(
+ [this](size_t a, size_t nb, bool v) {
+ return AllocateRawInternal(a, nb, v);
+ },
+ kMaxMillisToWait, unused_alignment, num_bytes);
+ }
+}
+
+void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes,
+ const AllocationAttributes& allocation_attr) {
+ if (allocation_attr.no_retry_on_failure) {
+ // Return immediately upon the first failure if this is for allocating an
+ // optional scratch space.
+ void* result = AllocateRawInternal(unused_alignment, num_bytes, false);
+ if (result == nullptr) {
+ // The counter incrementing is not thread-safe. But we don't really care.
+ // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
+ // more general usage.
+ static int log_counter = 0;
+ if (log_counter < 10) {
+ log_counter++;
+ LOG(WARNING)
+ << "Ran out of memory trying to allocate "
+ << strings::HumanReadableNumBytes(num_bytes)
+ << ". The caller indicates that this is not a failure, but"
+ << " may mean that there could be performance gains if more"
+ << " memory is available.";
+ }
+ }
+ return result;
+ } else {
+ return AllocateRaw(unused_alignment, num_bytes);
+ }
+}
+
+// static
+size_t BFCAllocator::RoundedBytes(size_t bytes) {
+ size_t rounded_bytes =
+ (kMinAllocationSize *
+ ((bytes + kMinAllocationSize - 1) / kMinAllocationSize));
+ DCHECK_EQ(size_t{0}, rounded_bytes % kMinAllocationSize);
+ return rounded_bytes;
+}
+
+void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
+ size_t num_bytes,
+ bool dump_log_on_failure) {
+ if (num_bytes == 0) {
+ LOG(ERROR) << "tried to allocate 0 bytes";
+ return nullptr;
+ }
+ // First, always allocate memory of at least kMinAllocationSize
+ // bytes, and always allocate multiples of kMinAllocationSize bytes
+ // so all memory addresses are nicely byte aligned.
+ size_t rounded_bytes = RoundedBytes(num_bytes);
+
+ // The BFC allocator tries to find the best fit first.
+ BinNum bin_num = BinNumForSize(rounded_bytes);
+
+ mutex_lock l(lock_);
+ void* ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes);
+ if (ptr != nullptr) {
+ return ptr;
+ }
+
+ // Try to extend
+ if (Extend(rounded_bytes)) {
+ ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes);
+ if (ptr != nullptr) {
+ return ptr;
+ }
+ }
+
+ // We searched all bins for an existing free chunk to use and
+ // couldn't find one. This means we must have run out of memory,
+ // Dump the memory log for analysis.
+ if (dump_log_on_failure) {
+ DumpMemoryLog(rounded_bytes);
+ LOG(WARNING) << RenderOccupancy();
+ LOG(WARNING) << "Ran out of memory trying to allocate "
+ << strings::HumanReadableNumBytes(num_bytes)
+ << ". See logs for memory state.";
+ }
+ return nullptr;
+}
+
+void* BFCAllocator::FindChunkPtr(BinNum bin_num, size_t rounded_bytes,
+ size_t num_bytes) {
+ // First identify the first bin that could satisfy rounded_bytes.
+ for (; bin_num < kNumBins; bin_num++) {
+ // Start searching from the first bin for the smallest chunk that fits
+ // rounded_bytes.
+ Bin* b = BinFromIndex(bin_num);
+ for (auto citer = b->free_chunks.begin(); citer != b->free_chunks.end();
+ ++citer) {
+ const BFCAllocator::ChunkHandle h = (*citer);
+ BFCAllocator::Chunk* chunk = ChunkFromHandle(h);
+ DCHECK(!chunk->in_use());
+ if (chunk->size >= rounded_bytes) {
+ // We found an existing chunk that fits us that wasn't in use, so remove
+ // it from the free bin structure prior to using.
+ RemoveFreeChunkIterFromBin(&b->free_chunks, citer);
+
+ // If we can break the size of the chunk into two reasonably
+ // large pieces, do so.
+ //
+ // TODO(vrv): What should be the criteria when deciding when
+ // to split?
+ if (chunk->size >= rounded_bytes * 2) {
+ SplitChunk(h, rounded_bytes);
+ chunk = ChunkFromHandle(h); // Update chunk pointer in case it moved
+ }
+
+ // The requested size of the returned chunk is what the user
+ // has allocated.
+ chunk->requested_size = num_bytes;
+ // Assign a unique id and increment the id counter, marking the
+ // chunk as being in use.
+ chunk->allocation_id = next_allocation_id_++;
+
+ // Update stats.
+ ++stats_.num_allocs;
+ stats_.bytes_in_use += chunk->size;
+ stats_.max_bytes_in_use =
+ std::max(stats_.max_bytes_in_use, stats_.bytes_in_use);
+ stats_.max_alloc_size =
+ std::max<std::size_t>(stats_.max_alloc_size, chunk->size);
+
+ VLOG(4) << "Returning: " << chunk->ptr;
+ if (VLOG_IS_ON(4)) {
+ LOG(INFO) << "A: " << RenderOccupancy();
+ }
+ return chunk->ptr;
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+void BFCAllocator::SplitChunk(BFCAllocator::ChunkHandle h, size_t num_bytes) {
+ // Allocate the new chunk before we do any ChunkFromHandle
+ ChunkHandle h_new_chunk = AllocateChunk();
+
+ Chunk* c = ChunkFromHandle(h);
+ CHECK(!c->in_use() && (c->bin_num == kInvalidBinNum));
+
+ // Create a new chunk starting num_bytes after c
+ BFCAllocator::Chunk* new_chunk = ChunkFromHandle(h_new_chunk);
+ new_chunk->ptr = static_cast<void*>(static_cast<char*>(c->ptr) + num_bytes);
+ region_manager_.set_handle(new_chunk->ptr, h_new_chunk);
+
+ // Set the new sizes of the chunks.
+ new_chunk->size = c->size - num_bytes;
+ c->size = num_bytes;
+
+ // The new chunk is not in use.
+ new_chunk->allocation_id = -1;
+
+ // Maintain the pointers.
+ // c <-> c_neighbor becomes
+ // c <-> new_chunk <-> c_neighbor
+ BFCAllocator::ChunkHandle h_neighbor = c->next;
+ new_chunk->prev = h;
+ new_chunk->next = h_neighbor;
+ c->next = h_new_chunk;
+ if (h_neighbor != kInvalidChunkHandle) {
+ Chunk* c_neighbor = ChunkFromHandle(h_neighbor);
+ c_neighbor->prev = h_new_chunk;
+ }
+
+ // Add the newly free chunk to the free bin.
+ InsertFreeChunkIntoBin(h_new_chunk);
+}
+
+void BFCAllocator::DeallocateRaw(void* ptr) {
+ DeallocateRawInternal(ptr);
+ retry_helper_.NotifyDealloc();
+}
+
+void BFCAllocator::DeallocateRawInternal(void* ptr) {
+ if (ptr == nullptr) {
+ LOG(ERROR) << "tried to deallocate nullptr";
+ return;
+ }
+ mutex_lock l(lock_);
+
+ // Find the chunk from the ptr.
+ BFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
+ CHECK(h != kInvalidChunkHandle);
+
+ // Consider coalescing it.
+ FreeAndMaybeCoalesce(h);
+
+ if (VLOG_IS_ON(4)) {
+ LOG(INFO) << "F: " << RenderOccupancy();
+ }
+}
+
+// Merges h1 and h2 when Chunk(h1)->next is h2 and Chunk(h2)->prev is c1.
+// We merge Chunk(h2) into Chunk(h1).
+void BFCAllocator::Merge(BFCAllocator::ChunkHandle h1,
+ BFCAllocator::ChunkHandle h2) {
+ Chunk* c1 = ChunkFromHandle(h1);
+ Chunk* c2 = ChunkFromHandle(h2);
+ // We can only merge chunks that are not in use.
+ CHECK(!c1->in_use() && !c2->in_use());
+
+ // c1's prev doesn't change, still points to the same ptr, and is
+ // still not in use.
+
+ // Fix up neighbor pointers
+ //
+ // c1 <-> c2 <-> c3 should become
+ // c1 <-> c3
+
+ BFCAllocator::ChunkHandle h3 = c2->next;
+ c1->next = h3;
+ CHECK(c2->prev == h1);
+ if (h3 != kInvalidChunkHandle) {
+ BFCAllocator::Chunk* c3 = ChunkFromHandle(h3);
+ c3->prev = h1;
+ }
+
+ // Set the new size
+ c1->size += c2->size;
+
+ DeleteChunk(h2);
+}
+
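
Merge only ever combines chunks that are address-adjacent (h2 must be h1's 'next'), so the surviving chunk absorbs h2's bytes and inherits its successor. A small sketch of that pointer fix-up on a handle-indexed chunk table; the toy types here are illustrative, not the allocator's own.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy chunk table indexed by handle, mirroring the size/prev/next fields.
struct ToyChunk {
  std::size_t size = 0;
  int prev = -1;  // -1 plays the role of kInvalidChunkHandle.
  int next = -1;
};

// Merge chunks[h2] into chunks[h1]; requires that h2 directly follows h1.
void MergeToy(std::vector<ToyChunk>* chunks, int h1, int h2) {
  ToyChunk& c1 = (*chunks)[h1];
  ToyChunk& c2 = (*chunks)[h2];
  assert(c1.next == h2 && c2.prev == h1);
  // c1 <-> c2 <-> c3 becomes c1 <-> c3.
  int h3 = c2.next;
  c1.next = h3;
  if (h3 != -1) (*chunks)[h3].prev = h1;
  // c1 absorbs c2's bytes; c2's handle would then be recycled (DeleteChunk).
  c1.size += c2.size;
  c2 = ToyChunk{};
}

int main() {
  // Three address-adjacent chunks of 256, 512 and 1024 bytes.
  std::vector<ToyChunk> chunks = {{256, -1, 1}, {512, 0, 2}, {1024, 1, -1}};
  MergeToy(&chunks, 0, 1);
  assert(chunks[0].size == 768 && chunks[0].next == 2);
  assert(chunks[2].prev == 0);
  return 0;
}
```
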
+void BFCAllocator::DeleteChunk(ChunkHandle h) {
+ // Delete h and cleanup all state
+ Chunk* c = ChunkFromHandle(h);
+ // VLOG(4) << "Removing: " << c->ptr;
+ region_manager_.erase(c->ptr);
+ DeallocateChunk(h);
+}
+
+void BFCAllocator::InsertFreeChunkIntoBin(BFCAllocator::ChunkHandle h) {
+ Chunk* c = ChunkFromHandle(h);
+ CHECK(!c->in_use() && (c->bin_num == kInvalidBinNum));
+ BinNum bin_num = BinNumForSize(c->size);
+ Bin* new_bin = BinFromIndex(bin_num);
+ c->bin_num = bin_num;
+ new_bin->free_chunks.insert(h);
+}
+
+void BFCAllocator::RemoveFreeChunkIterFromBin(
+ BFCAllocator::Bin::FreeChunkSet* free_chunks,
+ const BFCAllocator::Bin::FreeChunkSet::iterator& citer) {
+ ChunkHandle h = *citer;
+ Chunk* c = ChunkFromHandle(h);
+ CHECK(!c->in_use() && (c->bin_num != kInvalidBinNum));
+ free_chunks->erase(citer);
+ c->bin_num = kInvalidBinNum;
+}
+
+void BFCAllocator::RemoveFreeChunkFromBin(BFCAllocator::ChunkHandle h) {
+ Chunk* c = ChunkFromHandle(h);
+ CHECK(!c->in_use() && (c->bin_num != kInvalidBinNum));
+ int count = BinFromIndex(c->bin_num)->free_chunks.erase(h);
+ CHECK(count > 0) << "Could not find chunk in bin";
+ c->bin_num = kInvalidBinNum;
+}
+
+void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) {
+ Chunk* c = ChunkFromHandle(h);
+ CHECK(c->in_use() && (c->bin_num == kInvalidBinNum));
+
+ // Mark the chunk as no longer in use
+ c->allocation_id = -1;
+
+ // Updates the stats.
+ stats_.bytes_in_use -= c->size;
+
+ // This chunk is no longer in-use, consider coalescing the chunk
+ // with adjacent chunks.
+ ChunkHandle chunk_to_reassign = h;
+
+ // If the next chunk is free, coalesce the two
+ if (c->next != kInvalidChunkHandle) {
+ Chunk* cnext = ChunkFromHandle(c->next);
+ if (!cnext->in_use()) {
+ // VLOG(8) << "Chunk at " << cnext->ptr << " merging with c " <<
+ // c->ptr;
+
+ chunk_to_reassign = h;
+
+ // Deletes c->next
+ RemoveFreeChunkFromBin(c->next);
+ Merge(h, ChunkFromHandle(h)->next);
+ }
+ }
+
+ // If the previous chunk is free, coalesce the two
+ c = ChunkFromHandle(h);
+ if (c->prev != kInvalidChunkHandle) {
+ Chunk* cprev = ChunkFromHandle(c->prev);
+ if (!cprev->in_use()) {
+ // VLOG(8) << "Chunk at " << c->ptr << " merging into c->prev "
+ // << cprev->ptr;
+
+ chunk_to_reassign = c->prev;
+
+ // Deletes c
+ RemoveFreeChunkFromBin(c->prev);
+ Merge(ChunkFromHandle(h)->prev, h);
+ c = ChunkFromHandle(h);
+ }
+ }
+
+ InsertFreeChunkIntoBin(chunk_to_reassign);
+}
+
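
FreeAndMaybeCoalesce reinserts whichever handle survives the merges, and because the surviving chunk is larger it may land in a higher bin than the freed chunk would have on its own. A quick sketch of that effect, using the same bin formula the header defines later in this patch (bins start at 256 bytes and double per index); the function and variable names here are illustrative only.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>

// Mirrors BinNumForSize from the header: bin i holds free chunks of at
// least (256 << i) bytes, so the bin index is floor(log2(bytes / 256)).
int ToyBinNumForSize(std::size_t bytes, int num_bins = 21) {
  std::uint64_t v = std::max<std::size_t>(bytes, 256) >> 8;
  int log2_floor = 0;
  while (v >>= 1) ++log2_floor;
  return std::min(num_bins - 1, log2_floor);
}

int main() {
  const std::size_t freed = 300;     // chunk being deallocated
  const std::size_t neighbor = 900;  // adjacent free chunk
  std::cout << "freed alone      -> bin " << ToyBinNumForSize(freed) << "\n";
  std::cout << "after coalescing -> bin " << ToyBinNumForSize(freed + neighbor)
            << "\n";
  // Prints bin 0 (the 256-byte bin) versus bin 2 (the 1024-byte bin):
  // coalescing makes the space reusable for much larger requests.
  return 0;
}
```
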
+void BFCAllocator::AddAllocVisitor(Visitor visitor) {
+ VLOG(1) << "AddVisitor";
+ mutex_lock l(lock_);
+ region_visitors_.push_back(visitor);
+ for (const auto& region : region_manager_.regions()) {
+ visitor(region.ptr(), region.memory_size());
+ }
+}
+
+bool BFCAllocator::TracksAllocationSizes() { return true; }
+
+size_t BFCAllocator::RequestedSize(void* ptr) {
+ mutex_lock l(lock_);
+ BFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
+ CHECK(h != kInvalidChunkHandle)
+ << "Asked for requested size of pointer we never allocated: " << ptr;
+ BFCAllocator::Chunk* c = ChunkFromHandle(h);
+ return c->requested_size;
+}
+
+size_t BFCAllocator::AllocatedSize(void* ptr) {
+ mutex_lock l(lock_);
+ BFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
+ CHECK(h != kInvalidChunkHandle)
+ << "Asked for allocated size of pointer we never allocated: " << ptr;
+ BFCAllocator::Chunk* c = ChunkFromHandle(h);
+ return c->size;
+}
+
+int64 BFCAllocator::AllocationId(void* ptr) {
+ mutex_lock l(lock_);
+ BFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
+ CHECK(h != kInvalidChunkHandle)
+ << "Asked for allocation id of pointer we never allocated: " << ptr;
+ BFCAllocator::Chunk* c = ChunkFromHandle(h);
+ return c->allocation_id;
+}
+
+namespace {
+
+void RenderRegion(char* rendered, const size_t resolution,
+ const size_t total_render_size, const size_t offset,
+ const void* base_ptr, const void* ptr, const size_t size,
+ const char c) {
+ const char* base_ptr_c = static_cast<const char*>(base_ptr);
+ const char* ptr_c = static_cast<const char*>(ptr);
+
+ size_t start_location =
+ ((ptr_c - base_ptr_c + offset) * resolution) / total_render_size;
+ CHECK_GE(start_location, 0);
+ CHECK_LT(start_location, resolution);
+ size_t end_location =
+ ((ptr_c + size - 1 - base_ptr_c + offset) * resolution) /
+ total_render_size;
+ CHECK_GE(end_location, 0);
+ CHECK_LT(end_location, resolution);
+
+ for (size_t i = start_location; i <= end_location; ++i) {
+ rendered[i] = c;
+ }
+}
+
+} // namespace
+
+string BFCAllocator::RenderOccupancy() {
+ // Make a buffer for the ASCII-art representation.
+ const size_t resolution = 100;
+ char rendered[resolution];
+
+ // Compute the total region size to render over
+ size_t total_region_size = 0;
+ for (const auto& region : region_manager_.regions()) {
+ total_region_size += region.memory_size();
+ }
+
+ // Start out with everything empty
+ RenderRegion(rendered, resolution, total_region_size, 0, nullptr, nullptr,
+ total_region_size, '_');
+
+ size_t region_offset = 0;
+ for (const auto& region : region_manager_.regions()) {
+ ChunkHandle h = region_manager_.get_handle(region.ptr());
+ // Then render each chunk left to right.
+ while (h != kInvalidChunkHandle) {
+ Chunk* c = ChunkFromHandle(h);
+ if (c->in_use()) {
+ // Render the wasted space
+ size_t wasted = c->size - c->requested_size;
+ if (wasted > 0) {
+ RenderRegion(rendered, resolution, total_region_size,
+ region_offset + c->requested_size, region.ptr(), c->ptr,
+ wasted, 'x');
+ }
+ // Then the occupied space
+ RenderRegion(rendered, resolution, total_region_size, region_offset,
+ region.ptr(), c->ptr, c->requested_size, '*');
+ }
+ h = c->next;
+ }
+ region_offset += region.memory_size();
+ }
+
+ return StringPiece(rendered, resolution).ToString();
+}
+
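
RenderRegion maps a byte range onto a fixed 100-cell row by scaling offsets by resolution / total_render_size, which is why small allocations can share a cell and the picture is only approximate. A standalone sketch of that projection, with toy names:

```cpp
#include <cstddef>
#include <iostream>
#include <string>

// Project a [offset, offset + size) byte range onto a 'resolution'-cell row,
// the same scaling RenderRegion applies to each chunk.
void PaintRange(std::string* row, std::size_t total_bytes, std::size_t offset,
                std::size_t size, char c) {
  const std::size_t resolution = row->size();
  std::size_t start = (offset * resolution) / total_bytes;
  std::size_t end = ((offset + size - 1) * resolution) / total_bytes;
  for (std::size_t i = start; i <= end && i < resolution; ++i) (*row)[i] = c;
}

int main() {
  const std::size_t total_bytes = 1 << 20;  // pretend one 1MiB region
  std::string row(100, '_');                // '_' marks free space
  PaintRange(&row, total_bytes, 0, 256 << 10, '*');         // 256KiB in use
  PaintRange(&row, total_bytes, 512 << 10, 64 << 10, '*');  // 64KiB in use
  std::cout << row << "\n";
  // The first quarter of the row and a short run past the middle show '*';
  // everything else stays '_', as in the "A:"/"F:" log lines above.
  return 0;
}
```
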
+void BFCAllocator::DumpMemoryLog(size_t num_bytes) {
+ // For each bin: tally up the total number of chunks and bytes.
+ // Note that bins hold only free chunks.
+ for (BinNum bin_num = 0; bin_num < kNumBins; bin_num++) {
+ Bin* b = BinFromIndex(bin_num);
+
+ size_t total_bytes_in_use = 0;
+ size_t total_bytes_in_bin = 0;
+ size_t total_requested_bytes_in_use = 0;
+ size_t total_requested_bytes_in_bin = 0;
+ size_t total_chunks_in_use = 0;
+ size_t total_chunks_in_bin = 0;
+ for (ChunkHandle h : b->free_chunks) {
+ Chunk* c = ChunkFromHandle(h);
+ total_bytes_in_bin += c->size;
+ total_requested_bytes_in_bin += c->requested_size;
+ ++total_chunks_in_bin;
+ if (c->in_use()) {
+ total_bytes_in_use += c->size;
+ total_requested_bytes_in_use += c->requested_size;
+ ++total_chunks_in_use;
+ }
+ }
+
+ LOG(INFO) << "Bin (" << b->bin_size
+ << "): \tTotal Chunks: " << total_chunks_in_bin
+ << ", Chunks in use: " << total_chunks_in_use << " "
+ << strings::HumanReadableNumBytes(total_bytes_in_bin)
+ << " allocated for chunks. "
+ << strings::HumanReadableNumBytes(total_requested_bytes_in_bin)
+ << " client-requested for chunks. "
+ << strings::HumanReadableNumBytes(total_bytes_in_use)
+ << " in use in bin. "
+ << strings::HumanReadableNumBytes(total_requested_bytes_in_use)
+ << " client-requested in use in bin.";
+ }
+
+ // Find the bin that we would have liked to allocate in, so we
+ // can get some further analysis about fragmentation.
+ Bin* b = BinForSize(num_bytes);
+
+ LOG(INFO) << "Bin for " << strings::HumanReadableNumBytes(num_bytes)
+ << " was " << strings::HumanReadableNumBytes(b->bin_size)
+ << ", Chunk State: ";
+
+ for (ChunkHandle h : b->free_chunks) {
+ Chunk* c = ChunkFromHandle(h);
+ LOG(INFO) << c->DebugString(this, true);
+ }
+
+ // Next show the chunks that are in use, and also summarize their
+ // number by size.
+ std::map<size_t, int> in_use_by_size;
+ for (const auto& region : region_manager_.regions()) {
+ ChunkHandle h = region_manager_.get_handle(region.ptr());
+ while (h != kInvalidChunkHandle) {
+ const Chunk* c = ChunkFromHandle(h);
+ if (c->in_use()) {
+ in_use_by_size[c->size]++;
+ LOG(INFO) << "Chunk at " << c->ptr << " of size " << c->size;
+ }
+ h = c->next;
+ }
+
+ h = region_manager_.get_handle(region.ptr());
+ while (h != kInvalidChunkHandle) {
+ const Chunk* c = ChunkFromHandle(h);
+ if (!c->in_use()) {
+ LOG(INFO) << "Free at " << c->ptr << " of size " << c->size;
+ }
+ h = c->next;
+ }
+ }
+
+ LOG(INFO) << " Summary of in-use Chunks by size: ";
+ size_t total_bytes = 0;
+ for (auto& it : in_use_by_size) {
+ LOG(INFO) << it.second << " Chunks of size " << it.first << " totalling "
+ << strings::HumanReadableNumBytes(it.first * it.second);
+ total_bytes += (it.first * it.second);
+ }
+ LOG(INFO) << "Sum Total of in-use chunks: "
+ << strings::HumanReadableNumBytes(total_bytes);
+ LOG(INFO) << "Stats: \n" << stats_.DebugString();
+}
+
+void BFCAllocator::GetStats(AllocatorStats* stats) {
+ mutex_lock l(lock_);
+ *stats = stats_;
+}
+
+} // namespace tensorflow
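
Calls like region_manager_.get_handle(ptr) above resolve a raw pointer to its AllocationRegion by binary-searching regions kept sorted by end_ptr, then index into that region's handle array. A hedged, self-contained sketch of the lookup (the real RegionManager is defined in the header that follows; the toy version also returns nullptr for a pointer in a gap, where the real code would fail a check):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for AllocationRegion: just a [base, base + size) span.
struct ToyRegion {
  std::uintptr_t base;
  std::uintptr_t size;
  std::uintptr_t end() const { return base + size; }
};

// Regions are kept sorted by end(); the first region whose end() is greater
// than p is the candidate containing p (mirrors RegionManager::RegionFor).
const ToyRegion* RegionFor(const std::vector<ToyRegion>& regions,
                           std::uintptr_t p) {
  auto it = std::upper_bound(
      regions.begin(), regions.end(), p,
      [](std::uintptr_t ptr, const ToyRegion& r) { return ptr < r.end(); });
  if (it == regions.end() || p < it->base) return nullptr;  // not ours
  return &(*it);
}

int main() {
  // Two discontiguous regions, sorted by end address.
  std::vector<ToyRegion> regions = {{0x10000, 0x4000}, {0x80000, 0x8000}};
  assert(RegionFor(regions, 0x12345)->base == 0x10000);
  assert(RegionFor(regions, 0x81000)->base == 0x80000);
  assert(RegionFor(regions, 0x50000) == nullptr);  // gap between regions
  return 0;
}
```
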
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
new file mode 100644
index 0000000000..1be804090a
--- /dev/null
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -0,0 +1,413 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMMON_RUNTIME_BFC_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_BFC_ALLOCATOR_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/allocator_retry.h"
+#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+
+// A memory allocator that implements a 'best-fit with coalescing'
+// algorithm. This is essentially a very simple version of Doug Lea's
+// malloc (dlmalloc).
+//
+// The goal of this allocator is to support defragmentation via
+// coalescing. One assumption we make is that the process using this
+// allocator owns pretty much all of the memory, and that nearly
+// all requests to allocate memory go through this interface.
+class BFCAllocator : public VisitableAllocator {
+ public:
+ // Takes ownership of sub_allocator.
+ BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
+ bool allow_growth, const string& name);
+ ~BFCAllocator() override;
+
+ string Name() override { return name_; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void* AllocateRaw(size_t alignment, size_t num_bytes,
+ const AllocationAttributes& allocation_attr) override;
+ void DeallocateRaw(void* ptr) override;
+
+ void AddAllocVisitor(Visitor visitor) override;
+
+ // Does nothing, because memory is never freed.
+ void AddFreeVisitor(Visitor visitor) override {}
+
+ bool TracksAllocationSizes() override;
+
+ size_t RequestedSize(void* ptr) override;
+
+ size_t AllocatedSize(void* ptr) override;
+
+ int64 AllocationId(void* ptr) override;
+
+ void GetStats(AllocatorStats* stats) override;
+
+ private:
+ struct Bin;
+
+ void* AllocateRawInternal(size_t alignment, size_t num_bytes,
+ bool dump_log_on_failure);
+ void DeallocateRawInternal(void* ptr);
+
+ // A ChunkHandle is an index into the chunks_ vector in BFCAllocator
+ // kInvalidChunkHandle means an invalid chunk
+ typedef int ChunkHandle;
+ static const int kInvalidChunkHandle = -1;
+
+ typedef int BinNum;
+ static const int kInvalidBinNum = -1;
+ static const int kNumBins = 21;
+
+ // Chunks point to memory. Their prev/next pointers form a
+ // doubly-linked list of addresses sorted by base address that
+ // must be contiguous. Chunks contain information about whether
+ // they are in use or whether they are free, and contain a pointer
+ // to the bin they are in.
+ struct Chunk {
+ size_t size = 0; // Full size of buffer.
+
+ // We sometimes give chunks that are larger than needed to reduce
+ // fragmentation. requested_size keeps track of what the client
+ // actually wanted so we can understand whether our splitting
+ // strategy is efficient.
+ size_t requested_size = 0;
+
+ // allocation_id is set to -1 when the chunk is not in use. It is assigned a
+ // value greater than zero before the chunk is returned from
+ // AllocateRaw, and this value is unique among values assigned by
+ // the parent allocator.
+ int64 allocation_id = -1;
+ void* ptr = nullptr; // pointer to granted subbuffer.
+
+ // If not kInvalidChunkHandle, the memory referred to by 'prev' is directly
+ // preceding the memory used by this chunk, i.e. it should start
+ // at 'ptr - prev->size'.
+ ChunkHandle prev = kInvalidChunkHandle;
+
+ // If not kInvalidChunkHandle, the memory referred to by 'next' is directly
+ // following the memory used by this chunk, i.e. it should be at
+ // 'ptr + size'.
+ ChunkHandle next = kInvalidChunkHandle;
+
+ // What bin are we in?
+ BinNum bin_num = kInvalidBinNum;
+
+ bool in_use() const { return allocation_id != -1; }
+
+ string DebugString(BFCAllocator* a, bool recurse) {
+ string dbg;
+ strings::StrAppend(&dbg, " Size: ", strings::HumanReadableNumBytes(size),
+ " | Requested Size: ",
+ strings::HumanReadableNumBytes(requested_size),
+ " | in_use: ", in_use());
+ if (recurse && prev != BFCAllocator::kInvalidChunkHandle) {
+ Chunk* p = a->ChunkFromHandle(prev);
+ strings::StrAppend(&dbg, ", prev: ", p->DebugString(a, false));
+ }
+ if (recurse && next != BFCAllocator::kInvalidChunkHandle) {
+ Chunk* n = a->ChunkFromHandle(next);
+ strings::StrAppend(&dbg, ", next: ", n->DebugString(a, false));
+ }
+ return dbg;
+ }
+ };
+
+ // A Bin is a collection of similar-sized free chunks.
+ struct Bin {
+ // All chunks in this bin have >= bin_size memory.
+ size_t bin_size = 0;
+
+ struct ChunkComparator {
+ explicit ChunkComparator(BFCAllocator* allocator)
+ : allocator_(allocator) {}
+ // Sort first by size and then use pointer address as a tie breaker.
+ bool operator()(const ChunkHandle ha, const ChunkHandle hb) const {
+ const Chunk* a = allocator_->ChunkFromHandle(ha);
+ const Chunk* b = allocator_->ChunkFromHandle(hb);
+ if (a->size != b->size) {
+ return a->size < b->size;
+ }
+ return a->ptr < b->ptr;
+ }
+
+ private:
+ BFCAllocator* allocator_; // The parent allocator
+ };
+
+ typedef std::set<ChunkHandle, ChunkComparator> FreeChunkSet;
+ // List of free chunks within the bin, sorted by chunk size.
+ // Chunk * not owned.
+ FreeChunkSet free_chunks;
+ Bin(BFCAllocator* allocator, size_t bs)
+ : bin_size(bs), free_chunks(ChunkComparator(allocator)) {}
+ };
+
+ static const size_t kMinAllocationBits = 8;
+ static const size_t kMinAllocationSize = 1 << kMinAllocationBits;
+
+ // AllocationRegion maps pointers to ChunkHandles for a single
+ // contiguous memory region.
+ //
+ // This class is thread-compatible.
+ class AllocationRegion {
+ public:
+ AllocationRegion(void* ptr, size_t memory_size)
+ : ptr_(ptr),
+ memory_size_(memory_size),
+ end_ptr_(
+ static_cast<void*>(static_cast<char*>(ptr_) + memory_size_)) {
+ DCHECK_EQ(0, memory_size % kMinAllocationSize);
+ const size_t n_handles =
+ (memory_size + kMinAllocationSize - 1) / kMinAllocationSize;
+ handles_ = new ChunkHandle[n_handles];
+ for (size_t i = 0; i < n_handles; i++) {
+ handles_[i] = kInvalidChunkHandle;
+ }
+ }
+
+ AllocationRegion() {}
+
+ ~AllocationRegion() { delete[] handles_; }
+
+ AllocationRegion(AllocationRegion&& other) { Swap(other); }
+
+ AllocationRegion& operator=(AllocationRegion&& other) {
+ Swap(other);
+ return *this;
+ }
+
+ void* ptr() const { return ptr_; }
+ void* end_ptr() const { return end_ptr_; }
+ size_t memory_size() const { return memory_size_; }
+ ChunkHandle get_handle(const void* p) const {
+ return handles_[IndexFor(p)];
+ }
+ void set_handle(const void* p, ChunkHandle h) { handles_[IndexFor(p)] = h; }
+ void erase(const void* p) { set_handle(p, kInvalidChunkHandle); }
+
+ private:
+ void Swap(AllocationRegion& other) {
+ std::swap(ptr_, other.ptr_);
+ std::swap(memory_size_, other.memory_size_);
+ std::swap(end_ptr_, other.end_ptr_);
+ std::swap(handles_, other.handles_);
+ }
+
+ int IndexFor(const void* p) const {
+ std::uintptr_t p_int = reinterpret_cast<std::uintptr_t>(p);
+ std::uintptr_t base_int = reinterpret_cast<std::uintptr_t>(ptr_);
+ DCHECK_GE(p_int, base_int);
+ DCHECK_LT(p_int, base_int + memory_size_);
+ return static_cast<int>(((p_int - base_int) >> kMinAllocationBits));
+ }
+
+ // Metadata about the allocation region.
+ void* ptr_ = nullptr;
+ size_t memory_size_ = 0;
+ void* end_ptr_ = nullptr;
+
+ // Array of size "memory_size / kMinAllocationSize". It is
+ // indexed by (p-base) / kMinAllocationSize, contains ChunkHandle
+ // for the memory allocation represented by "p"
+ ChunkHandle* handles_ = nullptr;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(AllocationRegion);
+ };
+
+ // RegionManager aggregates one or more "AllocationRegions" and provides
+ // a layer of indirection from pointers to the underlying ChunkHandle,
+ // allowing allocation across multiple discontiguous memory regions.
+ //
+ // This class is thread-compatible.
+ class RegionManager {
+ public:
+ RegionManager() {}
+ ~RegionManager() {}
+
+ void AddAllocationRegion(void* ptr, size_t memory_size) {
+ // Insert sorted by end_ptr
+ auto entry =
+ std::upper_bound(regions_.begin(), regions_.end(), ptr, &Comparator);
+ regions_.insert(entry, AllocationRegion(ptr, memory_size));
+ }
+
+ ChunkHandle get_handle(const void* p) const {
+ return RegionFor(p)->get_handle(p);
+ }
+
+ void set_handle(const void* p, ChunkHandle h) {
+ return MutableRegionFor(p)->set_handle(p, h);
+ }
+ void erase(const void* p) { return MutableRegionFor(p)->erase(p); }
+
+ const std::vector<AllocationRegion>& regions() const { return regions_; }
+
+ private:
+ static bool Comparator(const void* ptr, const AllocationRegion& other) {
+ return ptr < other.end_ptr();
+ }
+
+ AllocationRegion* MutableRegionFor(const void* p) {
+ return const_cast<AllocationRegion*>(RegionFor(p));
+ }
+
+ const AllocationRegion* RegionFor(const void* p) const {
+ auto entry =
+ std::upper_bound(regions_.begin(), regions_.end(), p, &Comparator);
+
+ if (entry != regions_.end()) {
+ return &(*entry);
+ }
+
+ LOG(FATAL) << "Could not find Region for " << p;
+ return nullptr;
+ }
+
+ private:
+ std::vector<AllocationRegion> regions_;
+ };
+
+ // Returns 'bytes' rounded up to the next multiple of kMinAllocationSize.
+ size_t RoundedBytes(size_t bytes);
+
+ // Try to add a new memory region that can satisfy an allocation of
+ // 'rounded_bytes' bytes. Returns true on success and false on
+ // failure.
+ bool Extend(size_t rounded_bytes) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ // Returns a pointer to an underlying allocated chunk of size
+ // 'rounded_bytes'.
+ void* FindChunkPtr(BinNum bin_num, size_t rounded_bytes, size_t num_bytes)
+ EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ // Splits the chunk specified by 'h' into two chunks, one at least
+ // of size 'num_bytes'.
+ void SplitChunk(ChunkHandle h, size_t num_bytes)
+ EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ // Merges the two chunk handles. Requires that the chunks are
+ // contiguous in their allocation.
+ void Merge(ChunkHandle h, ChunkHandle h2) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ // Frees the memory represented by 'h', coalescing the chunk if
+ // possible.
+ void FreeAndMaybeCoalesce(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ // Adds the chunk 'h' to the proper free bin.
+ void InsertFreeChunkIntoBin(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ // Removes the free chunk pointed to by 'c' from the set free_chunks.
+ void RemoveFreeChunkIterFromBin(Bin::FreeChunkSet* free_chunks,
+ const Bin::FreeChunkSet::iterator& c)
+ EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ // Removes a free chunk from the bin.
+ void RemoveFreeChunkFromBin(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ // Removes the chunk metadata represented by 'h'.
+ void DeleteChunk(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ string RenderOccupancy() EXCLUSIVE_LOCKS_REQUIRED(lock_);
+ void DumpMemoryLog(size_t num_bytes) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ ChunkHandle AllocateChunk() EXCLUSIVE_LOCKS_REQUIRED(lock_);
+ void DeallocateChunk(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ Chunk* ChunkFromHandle(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ AllocatorRetry retry_helper_;
+
+ // Structures immutable after construction
+ size_t memory_limit_ = 0;
+ inline int Log2FloorNonZero(uint64 n) {
+#if defined(__GNUC__)
+ return 63 ^ __builtin_clzll(n);
+#else
+ // Portable fallback: halve n until it reaches 1, which yields
+ // floor(log2(n)) and matches the __builtin_clzll branch above.
+ int r = 0;
+ while (n >>= 1) {
+ r++;
+ }
+ return r;
+#endif
+ }
+
+ // Helpers to map between bin indices, bin sizes, and Bins.
+ Bin* BinFromIndex(BinNum index) {
+ return reinterpret_cast<Bin*>(&(bins_space_[index * sizeof(Bin)]));
+ }
+ size_t BinNumToSize(BinNum index) {
+ return static_cast<size_t>(256) << index;
+ }
+ BinNum BinNumForSize(size_t bytes) {
+ uint64 v = std::max<size_t>(bytes, 256) >> kMinAllocationBits;
+ int b = std::min(kNumBins - 1, Log2FloorNonZero(v));
+ return b;
+ }
+ Bin* BinForSize(size_t bytes) { return BinFromIndex(BinNumForSize(bytes)); }
+
+ char bins_space_[sizeof(Bin) * kNumBins];
+
+ // The size of the current region allocation.
+ size_t curr_region_allocation_bytes_;
+
+ // The total number of allocated bytes by the allocator.
+ size_t total_region_allocated_bytes_ = 0;
+
+ // An indicator that expansion of a region has hit the limits
+ // of the available memory.
+ bool started_backpedal_ = false;
+
+ std::unique_ptr<SubAllocator> suballocator_;
+ string name_;
+
+ // Structures mutable after construction
+ mutable mutex lock_;
+ RegionManager region_manager_ GUARDED_BY(lock_);
+
+ std::vector<Chunk> chunks_;
+ ChunkHandle free_chunks_list_; // Ptr to head of linked list of free Chunks
+
+ // Called once on each region, ASAP.
+ std::vector<Visitor> region_visitors_;
+
+ // Counter containing the next unique identifier to assign to a
+ // newly-created chunk.
+ int64 next_allocation_id_ GUARDED_BY(lock_);
+
+ // Stats.
+ AllocatorStats stats_ GUARDED_BY(lock_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BFCAllocator);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_BFC_ALLOCATOR_H_
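
Two small pieces of arithmetic in the header above do a lot of work: RoundedBytes rounds every request up to a multiple of kMinAllocationSize (256 bytes), and AllocationRegion::IndexFor turns a pointer into an index into the per-region handles_ array by shifting the byte offset right by kMinAllocationBits. A quick standalone check of both; the constants are copied from the header, everything else is illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr std::size_t kMinAllocationBits = 8;
constexpr std::size_t kMinAllocationSize = 1 << kMinAllocationBits;  // 256

// Mirrors BFCAllocator::RoundedBytes: round up to a 256-byte multiple.
std::size_t RoundedBytesToy(std::size_t bytes) {
  return kMinAllocationSize *
         ((bytes + kMinAllocationSize - 1) / kMinAllocationSize);
}

// Mirrors AllocationRegion::IndexFor: every 256-byte slot in the region has
// one ChunkHandle entry, so the index is the byte offset divided by 256.
std::size_t IndexForToy(std::uintptr_t region_base, std::uintptr_t p) {
  return (p - region_base) >> kMinAllocationBits;
}

int main() {
  assert(RoundedBytesToy(1) == 256);
  assert(RoundedBytesToy(256) == 256);
  assert(RoundedBytesToy(257) == 512);
  // Because every chunk starts on a 256-byte boundary, each chunk's base
  // pointer maps to exactly one slot in handles_.
  assert(IndexForToy(0x10000, 0x10000) == 0);
  assert(IndexForToy(0x10000, 0x10000 + 3 * 256) == 3);
  return 0;
}
```
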
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 403cece230..47bd6c56ec 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1170,37 +1170,44 @@ FunctionBody* SymbolicGradientHelper::Compute() {
Copy();
Graph* g = gbody_->graph;
+
+ const int num_y = gbody_->ret_nodes.size();
+
+ // Populate 'y_node_outputs_' with node function body outputs.
// Populate 'y_grad_nodes' with initial gradient nodes for each return node of
// the original function body (these will be 'arg' nodes in the function
// gradient body).
- const int num_y = gbody_->ret_nodes.size();
- std::vector<Node*> y_grad_nodes;
- y_grad_nodes.reserve(num_y);
+ std::vector<NodeOut> y_node_outputs;
+ y_node_outputs.reserve(num_y);
+ std::vector<NodeOut> y_grad_node_outputs;
+ y_grad_node_outputs.reserve(num_y);
for (int i = 0; i < num_y; ++i) {
Node* y = gbody_->ret_nodes[i];
+ y_node_outputs.push_back({y, 0});
DCHECK_EQ(y->type_string(), kRetOp);
const DataType dtype = y->input_type(0);
const int index = gbody_->arg_nodes.size();
Node* dy = AddArg(g, dtype, index);
gbody_->arg_types.push_back(dtype);
gbody_->arg_nodes.push_back(dy);
- y_grad_nodes.push_back(dy);
+ y_grad_node_outputs.push_back({dy, 0});
}
- // Populate 'x_nodes' with function args (not including 'y_grad_nodes').
+ // Populate 'x_nodes' with function args (excluding 'y_grad_node_outputs').
const int num_x = fbody_->arg_nodes.size();
- std::vector<Node*> x_nodes;
- x_nodes.reserve(num_x);
+ std::vector<NodeOut> x_node_outputs;
+ x_node_outputs.reserve(num_x);
for (size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
- x_nodes.push_back(gbody_->arg_nodes[i]);
+ x_node_outputs.push_back({gbody_->arg_nodes[i], 0});
}
// Call AddSymbolicGradients which will add nodes to graph 'g' that
- // compute the function gradient (adding an entry in 'x_grad_nodes' for
- // each node in 'x_nodes').
- std::vector<GradNodeOutput> x_grad_nodes(x_nodes.size());
- TF_CHECK_OK(AddSymbolicGradients(gbody_->ret_nodes, x_nodes, y_grad_nodes,
- &x_grad_nodes, g));
+ // compute the function gradient (adding an entry in 'x_grad_node_outputs' for
+ // each node in 'x_node_outputs').
+ std::vector<NodeOut> x_grad_node_outputs;
+ TF_CHECK_OK(AddSymbolicGradients(y_node_outputs, x_node_outputs,
+ y_grad_node_outputs, &x_grad_node_outputs,
+ g));
// Remove the old return nodes from the function body.
for (Node* n : gbody_->ret_nodes) {
@@ -1211,7 +1218,7 @@ FunctionBody* SymbolicGradientHelper::Compute() {
// Add new return nodes to the function gradient body for each node
// in 'x_grad_nodes'.
for (size_t i = 0; i < fbody_->arg_types.size(); ++i) {
- Endpoint grad = {x_grad_nodes[i].node, x_grad_nodes[i].index};
+ Endpoint grad = {x_grad_node_outputs[i].node, x_grad_node_outputs[i].index};
Node* ret = AddRet(g, grad, i);
gbody_->ret_nodes.push_back(ret);
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc
index c701f80cec..a3ac2e1d67 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/common_runtime/allocator_retry.h"
#include <vector>
#include "tensorflow/core/lib/core/notification.h"
@@ -55,7 +55,7 @@ class FakeAllocator {
}
private:
- GPUAllocatorRetry retry_;
+ AllocatorRetry retry_;
void* good_ptr_ = reinterpret_cast<void*>(0xdeadbeef);
mutex mu_;
size_t memory_capacity_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
index 03507cd948..33496154ec 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -15,17 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
-#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
-#include "tensorflow/core/lib/core/bits.h"
-#include "tensorflow/core/lib/gtl/stl_util.h"
-#include "tensorflow/core/lib/strings/numbers.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/stream_executor.h"
-#include "tensorflow/core/platform/types.h"
namespace gpu = ::perftools::gputools;
@@ -36,680 +26,9 @@ GPUBFCAllocator::GPUBFCAllocator(int device_id, size_t total_memory)
GPUBFCAllocator::GPUBFCAllocator(int device_id, size_t total_memory,
const GPUOptions& gpu_options)
- : device_id_(device_id),
- free_chunks_list_(kInvalidChunkHandle),
- next_allocation_id_(1) {
- // Get a pointer to the stream_executor for this device
- stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
-
- if (gpu_options.allow_growth()) {
- // 1MiB smallest initial allocation, unless total memory available
- // is less.
- curr_region_allocation_bytes_ =
- RoundedBytes(std::min(total_memory, size_t{1048576}));
- } else {
- curr_region_allocation_bytes_ = RoundedBytes(total_memory);
- }
-
- // Allocate the requested amount of memory.
- gpu_memory_size_ = total_memory;
- stats_.bytes_limit = static_cast<int64>(total_memory);
-
- // Create a bunch of bins of various good sizes.
-
- // We create bins to fit all possible ranges that cover the
- // gpu_memory_size_ starting from allocations up to 256 bytes to
- // allocations up to (and including) the memory limit.
- for (BinNum b = 0; b < kNumBins; b++) {
- size_t bin_size = BinNumToSize(b);
- VLOG(1) << "Creating bin of max chunk size "
- << strings::HumanReadableNumBytes(bin_size);
- new (BinFromIndex(b)) Bin(this, bin_size);
- CHECK_EQ(BinForSize(bin_size), BinFromIndex(b));
- CHECK_EQ(BinForSize(bin_size + 255), BinFromIndex(b));
- CHECK_EQ(BinForSize(bin_size * 2 - 1), BinFromIndex(b));
- if (b + 1 < kNumBins) {
- CHECK_NE(BinForSize(bin_size * 2), BinFromIndex(b));
- }
- }
-}
-
-GPUBFCAllocator::~GPUBFCAllocator() {
- // Return memory back.
- VLOG(2) << "Number of regions allocated: "
- << region_manager_.regions().size();
- for (const auto& region : region_manager_.regions()) {
- gpu::DeviceMemoryBase gpu_ptr{region.ptr()};
- stream_exec_->Deallocate(&gpu_ptr);
- }
-
- for (BinNum b = 0; b < kNumBins; b++) {
- BinFromIndex(b)->~Bin();
- }
-}
-
-GPUBFCAllocator::Chunk* GPUBFCAllocator::ChunkFromHandle(ChunkHandle h) {
- DCHECK_GE(h, 0);
- DCHECK_LT(h, static_cast<int>(chunks_.size()));
- return &(chunks_[h]);
-}
-
-bool GPUBFCAllocator::Extend(size_t rounded_bytes) {
- // Do we have enough space to handle the client's request?
- // If not, fail immediately.
- if (total_region_allocated_bytes_ + rounded_bytes > gpu_memory_size_) {
- return false;
- }
-
- // If curr_region_allocation_bytes_ is not enough to satisfy the
- // allocation, keep multiplying by a power of two until that is
- // sufficient.
- bool increased_allocation = false;
- while (rounded_bytes > curr_region_allocation_bytes_) {
- curr_region_allocation_bytes_ *= 2;
- increased_allocation = true;
- }
-
- // Try allocating.
- size_t bytes = curr_region_allocation_bytes_;
- gpu::DeviceMemory<char> gpu_mem = stream_exec_->AllocateArray<char>(bytes);
- if (gpu_mem == nullptr && !started_backpedal_) {
- // Only backpedal once.
- started_backpedal_ = true;
-
- static constexpr float kBackpedalFactor = 0.9;
-
- // Try allocating less memory.
- bytes = RoundedBytes(bytes * kBackpedalFactor);
- while (gpu_mem == nullptr && bytes > rounded_bytes) {
- gpu_mem = stream_exec_->AllocateArray<char>(bytes);
- bytes = RoundedBytes(bytes * kBackpedalFactor);
- }
- }
-
- if (gpu_mem == nullptr) {
- return false;
- }
-
- if (!increased_allocation) {
- // Increase the region size of the next required allocation.
- curr_region_allocation_bytes_ *= 2;
- }
-
- VLOG(1) << "Extending allocation by " << strings::HumanReadableNumBytes(bytes)
- << " bytes.";
-
- total_region_allocated_bytes_ += bytes;
- VLOG(1) << "Total allocated bytes: "
- << strings::HumanReadableNumBytes(total_region_allocated_bytes_);
-
- void* gpu_mem_base = gpu_mem.opaque();
- VLOG(1) << "Allocated memory at " << gpu_mem_base << " to "
- << static_cast<void*>(static_cast<char*>(gpu_mem_base) + bytes);
- region_manager_.AddAllocationRegion(gpu_mem_base, bytes);
-
- // Create one large chunk for the whole memory space that will
- // be chunked later.
- ChunkHandle h = AllocateChunk();
- GPUBFCAllocator::Chunk* c = ChunkFromHandle(h);
- c->ptr = gpu_mem_base;
- c->size = bytes;
- c->allocation_id = -1;
- c->prev = kInvalidChunkHandle;
- c->next = kInvalidChunkHandle;
-
- region_manager_.set_handle(c->ptr, h);
-
- // TODO(vrv): Try to merge this new region with an existing region,
- // if the address space is contiguous, to avoid fragmentation
- // across regions.
-
- // Insert the chunk into the right bin.
- InsertFreeChunkIntoBin(h);
-
- // Invoke visitors on newly allocated region.
- for (auto visitor : region_visitors_) {
- visitor(gpu_mem_base, bytes);
- }
- return true;
-}
-
-GPUBFCAllocator::ChunkHandle GPUBFCAllocator::AllocateChunk() {
- if (free_chunks_list_ != kInvalidChunkHandle) {
- ChunkHandle h = free_chunks_list_;
- Chunk* c = ChunkFromHandle(h);
- free_chunks_list_ = c->next;
- return h;
- } else {
- ChunkHandle h = chunks_.size();
- chunks_.resize(h + 1);
- return h;
- }
-}
-
-void GPUBFCAllocator::DeallocateChunk(ChunkHandle h) {
- Chunk* c = ChunkFromHandle(h);
- c->next = free_chunks_list_;
- free_chunks_list_ = h;
-}
-
-void* GPUBFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes) {
- // Fast path: Try once to allocate without getting the retry_helper_ involved
- void* r = AllocateRawInternal(unused_alignment, num_bytes, false);
- if (r != nullptr) {
- return r;
- } else {
- static const int64 kMaxMillisToWait = 10000; // 10 seconds
- return retry_helper_.AllocateRaw(
- [this](size_t a, size_t nb, bool v) {
- return AllocateRawInternal(a, nb, v);
- },
- kMaxMillisToWait, unused_alignment, num_bytes);
- }
-}
-
-void* GPUBFCAllocator::AllocateRaw(
- size_t unused_alignment, size_t num_bytes,
- const AllocationAttributes& allocation_attr) {
- if (allocation_attr.no_retry_on_failure) {
- // Return immediately upon the first failure if this is for allocating an
- // optional scratch space.
- void* result = AllocateRawInternal(unused_alignment, num_bytes, false);
- if (result == nullptr) {
- // The counter incrementing is not thread-safe. But we don't really care.
- // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
- // more general usage.
- static int log_counter = 0;
- if (log_counter < 10) {
- log_counter++;
- LOG(WARNING)
- << "Ran out of memory trying to allocate "
- << strings::HumanReadableNumBytes(num_bytes)
- << ". The caller indicates that this is not a failure, but"
- << " may mean that there could be performance gains if more"
- << " memory is available.";
- }
- }
- return result;
- } else {
- return AllocateRaw(unused_alignment, num_bytes);
- }
-}
-
-// static
-size_t GPUBFCAllocator::RoundedBytes(size_t bytes) {
- size_t rounded_bytes =
- (kMinAllocationSize *
- ((bytes + kMinAllocationSize - 1) / kMinAllocationSize));
- DCHECK_EQ(size_t{0}, rounded_bytes % kMinAllocationSize);
- return rounded_bytes;
-}
-
-void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
- size_t num_bytes,
- bool dump_log_on_failure) {
- if (num_bytes == 0) {
- LOG(ERROR) << "tried to allocate 0 bytes";
- return nullptr;
- }
- // First, always allocate memory of at least kMinAllocationSize
- // bytes, and always allocate multiples of kMinAllocationSize bytes
- // so all memory addresses are nicely byte aligned.
- size_t rounded_bytes = RoundedBytes(num_bytes);
-
- // The BFC allocator tries to find the best fit first.
- BinNum bin_num = BinNumForSize(rounded_bytes);
-
- mutex_lock l(lock_);
- void* ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes);
- if (ptr != nullptr) {
- return ptr;
- }
-
- // Try to extend
- if (Extend(rounded_bytes)) {
- ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes);
- if (ptr != nullptr) {
- return ptr;
- }
- }
-
- // We searched all bins for an existing free chunk to use and
- // couldn't find one. This means we must have run out of memory,
- // Dump the memory log for analysis.
- if (dump_log_on_failure) {
- DumpMemoryLog(rounded_bytes);
- LOG(WARNING) << RenderOccupancy();
- LOG(WARNING) << "Ran out of memory trying to allocate "
- << strings::HumanReadableNumBytes(num_bytes)
- << ". See logs for memory state.";
- }
- return nullptr;
-}
-
-void* GPUBFCAllocator::FindChunkPtr(BinNum bin_num, size_t rounded_bytes,
- size_t num_bytes) {
- // First identify the first bin that could satisfy rounded_bytes.
- for (; bin_num < kNumBins; bin_num++) {
- // Start searching from the first bin for the smallest chunk that fits
- // rounded_bytes.
- Bin* b = BinFromIndex(bin_num);
- for (auto citer = b->free_chunks.begin(); citer != b->free_chunks.end();
- ++citer) {
- const GPUBFCAllocator::ChunkHandle h = (*citer);
- GPUBFCAllocator::Chunk* chunk = ChunkFromHandle(h);
- DCHECK(!chunk->in_use());
- if (chunk->size >= rounded_bytes) {
- // We found an existing chunk that fits us that wasn't in use, so remove
- // it from the free bin structure prior to using.
- RemoveFreeChunkIterFromBin(&b->free_chunks, citer);
-
- // If we can break the size of the chunk into two reasonably
- // large pieces, do so.
- //
- // TODO(vrv): What should be the criteria when deciding when
- // to split?
- if (chunk->size >= rounded_bytes * 2) {
- SplitChunk(h, rounded_bytes);
- chunk = ChunkFromHandle(h); // Update chunk pointer in case it moved
- }
-
- // The requested size of the returned chunk is what the user
- // has allocated.
- chunk->requested_size = num_bytes;
- // Assign a unique id and increment the id counter, marking the
- // chunk as being in use.
- chunk->allocation_id = next_allocation_id_++;
-
- // Update stats.
- ++stats_.num_allocs;
- stats_.bytes_in_use += chunk->size;
- stats_.max_bytes_in_use =
- std::max(stats_.max_bytes_in_use, stats_.bytes_in_use);
- stats_.max_alloc_size =
- std::max<std::size_t>(stats_.max_alloc_size, chunk->size);
-
- VLOG(4) << "Returning: " << chunk->ptr;
- if (VLOG_IS_ON(4)) {
- LOG(INFO) << "A: " << RenderOccupancy();
- }
- return chunk->ptr;
- }
- }
- }
-
- return nullptr;
-}
-
-void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::ChunkHandle h,
- size_t num_bytes) {
- // Allocate the new chunk before we do any ChunkFromHandle
- ChunkHandle h_new_chunk = AllocateChunk();
-
- Chunk* c = ChunkFromHandle(h);
- CHECK(!c->in_use() && (c->bin_num == kInvalidBinNum));
-
- // Create a new chunk starting num_bytes after c
- GPUBFCAllocator::Chunk* new_chunk = ChunkFromHandle(h_new_chunk);
- new_chunk->ptr = static_cast<void*>(static_cast<char*>(c->ptr) + num_bytes);
- region_manager_.set_handle(new_chunk->ptr, h_new_chunk);
-
- // Set the new sizes of the chunks.
- new_chunk->size = c->size - num_bytes;
- c->size = num_bytes;
-
- // The new chunk is not in use.
- new_chunk->allocation_id = -1;
-
- // Maintain the pointers.
- // c <-> c_neighbor becomes
- // c <-> new_chunk <-> c_neighbor
- GPUBFCAllocator::ChunkHandle h_neighbor = c->next;
- new_chunk->prev = h;
- new_chunk->next = h_neighbor;
- c->next = h_new_chunk;
- if (h_neighbor != kInvalidChunkHandle) {
- Chunk* c_neighbor = ChunkFromHandle(h_neighbor);
- c_neighbor->prev = h_new_chunk;
- }
-
- // Add the newly free chunk to the free bin.
- InsertFreeChunkIntoBin(h_new_chunk);
-}
-
-void GPUBFCAllocator::DeallocateRaw(void* ptr) {
- DeallocateRawInternal(ptr);
- retry_helper_.NotifyDealloc();
-}
-
-void GPUBFCAllocator::DeallocateRawInternal(void* ptr) {
- if (ptr == nullptr) {
- LOG(ERROR) << "tried to deallocate nullptr";
- return;
- }
- mutex_lock l(lock_);
-
- // Find the chunk from the ptr.
- GPUBFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
- CHECK(h != kInvalidChunkHandle);
-
- // Consider coalescing it.
- FreeAndMaybeCoalesce(h);
-
- if (VLOG_IS_ON(4)) {
- LOG(INFO) << "F: " << RenderOccupancy();
- }
-}
-
-// Merges h1 and h2 when Chunk(h1)->next is h2 and Chunk(h2)->prev is c1.
-// We merge Chunk(h2) into Chunk(h1).
-void GPUBFCAllocator::Merge(GPUBFCAllocator::ChunkHandle h1,
- GPUBFCAllocator::ChunkHandle h2) {
- Chunk* c1 = ChunkFromHandle(h1);
- Chunk* c2 = ChunkFromHandle(h2);
- // We can only merge chunks that are not in use.
- CHECK(!c1->in_use() && !c2->in_use());
-
- // c1's prev doesn't change, still points to the same ptr, and is
- // still not in use.
-
- // Fix up neighbor pointers
- //
- // c1 <-> c2 <-> c3 should become
- // c1 <-> c3
-
- GPUBFCAllocator::ChunkHandle h3 = c2->next;
- c1->next = h3;
- CHECK(c2->prev == h1);
- if (h3 != kInvalidChunkHandle) {
- GPUBFCAllocator::Chunk* c3 = ChunkFromHandle(h3);
- c3->prev = h1;
- }
-
- // Set the new size
- c1->size += c2->size;
-
- DeleteChunk(h2);
-}
-
-void GPUBFCAllocator::DeleteChunk(ChunkHandle h) {
- // Delete h and cleanup all state
- Chunk* c = ChunkFromHandle(h);
- // VLOG(4) << "Removing: " << c->ptr;
- region_manager_.erase(c->ptr);
- DeallocateChunk(h);
-}
-
-void GPUBFCAllocator::InsertFreeChunkIntoBin(GPUBFCAllocator::ChunkHandle h) {
- Chunk* c = ChunkFromHandle(h);
- CHECK(!c->in_use() && (c->bin_num == kInvalidBinNum));
- BinNum bin_num = BinNumForSize(c->size);
- Bin* new_bin = BinFromIndex(bin_num);
- c->bin_num = bin_num;
- new_bin->free_chunks.insert(h);
-}
-
-void GPUBFCAllocator::RemoveFreeChunkIterFromBin(
- GPUBFCAllocator::Bin::FreeChunkSet* free_chunks,
- const GPUBFCAllocator::Bin::FreeChunkSet::iterator& citer) {
- ChunkHandle h = *citer;
- Chunk* c = ChunkFromHandle(h);
- CHECK(!c->in_use() && (c->bin_num != kInvalidBinNum));
- free_chunks->erase(citer);
- c->bin_num = kInvalidBinNum;
-}
-
-void GPUBFCAllocator::RemoveFreeChunkFromBin(GPUBFCAllocator::ChunkHandle h) {
- Chunk* c = ChunkFromHandle(h);
- CHECK(!c->in_use() && (c->bin_num != kInvalidBinNum));
- int count = BinFromIndex(c->bin_num)->free_chunks.erase(h);
- CHECK(count > 0) << "Could not find chunk in bin";
- c->bin_num = kInvalidBinNum;
-}
-
-void GPUBFCAllocator::FreeAndMaybeCoalesce(GPUBFCAllocator::ChunkHandle h) {
- Chunk* c = ChunkFromHandle(h);
- CHECK(c->in_use() && (c->bin_num == kInvalidBinNum));
-
- // Mark the chunk as no longer in use
- c->allocation_id = -1;
-
- // Updates the stats.
- stats_.bytes_in_use -= c->size;
-
- // This chunk is no longer in-use, consider coalescing the chunk
- // with adjacent chunks.
- ChunkHandle chunk_to_reassign = h;
-
- // If the next chunk is free, coalesce the two
- if (c->next != kInvalidChunkHandle) {
- Chunk* cnext = ChunkFromHandle(c->next);
- if (!cnext->in_use()) {
- // VLOG(8) << "Chunk at " << cnext->ptr << " merging with c " <<
- // c->ptr;
-
- chunk_to_reassign = h;
-
- // Deletes c->next
- RemoveFreeChunkFromBin(c->next);
- Merge(h, ChunkFromHandle(h)->next);
- }
- }
-
- // If the previous chunk is free, coalesce the two
- c = ChunkFromHandle(h);
- if (c->prev != kInvalidChunkHandle) {
- Chunk* cprev = ChunkFromHandle(c->prev);
- if (!cprev->in_use()) {
- // VLOG(8) << "Chunk at " << c->ptr << " merging into c->prev "
- // << cprev->ptr;
-
- chunk_to_reassign = c->prev;
-
- // Deletes c
- RemoveFreeChunkFromBin(c->prev);
- Merge(ChunkFromHandle(h)->prev, h);
- c = ChunkFromHandle(h);
- }
- }
-
- InsertFreeChunkIntoBin(chunk_to_reassign);
-}
-
-void GPUBFCAllocator::AddAllocVisitor(Visitor visitor) {
- VLOG(1) << "AddVisitor";
- mutex_lock l(lock_);
- region_visitors_.push_back(visitor);
- for (const auto& region : region_manager_.regions()) {
- visitor(region.ptr(), region.memory_size());
- }
-}
-
-bool GPUBFCAllocator::TracksAllocationSizes() { return true; }
-
-size_t GPUBFCAllocator::RequestedSize(void* ptr) {
- mutex_lock l(lock_);
- GPUBFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
- CHECK(h != kInvalidChunkHandle)
- << "Asked for requested size of pointer we never allocated: " << ptr;
- GPUBFCAllocator::Chunk* c = ChunkFromHandle(h);
- return c->requested_size;
-}
-
-size_t GPUBFCAllocator::AllocatedSize(void* ptr) {
- mutex_lock l(lock_);
- GPUBFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
- CHECK(h != kInvalidChunkHandle)
- << "Asked for allocated size of pointer we never allocated: " << ptr;
- GPUBFCAllocator::Chunk* c = ChunkFromHandle(h);
- return c->size;
-}
-
-int64 GPUBFCAllocator::AllocationId(void* ptr) {
- mutex_lock l(lock_);
- GPUBFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
- CHECK(h != kInvalidChunkHandle)
- << "Asked for allocation id of pointer we never allocated: " << ptr;
- GPUBFCAllocator::Chunk* c = ChunkFromHandle(h);
- return c->allocation_id;
-}
-
-namespace {
-
-void RenderRegion(char* rendered, const size_t resolution,
- const size_t total_render_size, const size_t offset,
- const void* base_ptr, const void* ptr, const size_t size,
- const char c) {
- const char* base_ptr_c = static_cast<const char*>(base_ptr);
- const char* ptr_c = static_cast<const char*>(ptr);
-
- size_t start_location =
- ((ptr_c - base_ptr_c + offset) * resolution) / total_render_size;
- CHECK_GE(start_location, 0);
- CHECK_LT(start_location, resolution);
- size_t end_location =
- ((ptr_c + size - 1 - base_ptr_c + offset) * resolution) /
- total_render_size;
- CHECK_GE(end_location, 0);
- CHECK_LT(end_location, resolution);
-
- for (size_t i = start_location; i <= end_location; ++i) {
- rendered[i] = c;
- }
-}
-
-} // namespace
-
-string GPUBFCAllocator::RenderOccupancy() {
- // Make a buffer for the ASCII-art representation.
- const size_t resolution = 100;
- char rendered[resolution];
-
- // Compute the total region size to render over
- size_t total_region_size = 0;
- for (const auto& region : region_manager_.regions()) {
- total_region_size += region.memory_size();
- }
-
- // Start out with everything empty
- RenderRegion(rendered, resolution, total_region_size, 0, nullptr, nullptr,
- total_region_size, '_');
-
- size_t region_offset = 0;
- for (const auto& region : region_manager_.regions()) {
- ChunkHandle h = region_manager_.get_handle(region.ptr());
- // Then render each chunk left to right.
- while (h != kInvalidChunkHandle) {
- Chunk* c = ChunkFromHandle(h);
- if (c->in_use()) {
- // Render the wasted space
- size_t wasted = c->size - c->requested_size;
- if (wasted > 0) {
- RenderRegion(rendered, resolution, total_region_size,
- region_offset + c->requested_size, region.ptr(), c->ptr,
- wasted, 'x');
- }
- // Then the occupied space
- RenderRegion(rendered, resolution, total_region_size, region_offset,
- region.ptr(), c->ptr, c->requested_size, '*');
- }
- h = c->next;
- }
- region_offset += region.memory_size();
- }
-
- return StringPiece(rendered, resolution).ToString();
-}
-
-void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
- // For each bin: tally up the total number of chunks and bytes.
- // Note that bins hold only free chunks.
- for (BinNum bin_num = 0; bin_num < kNumBins; bin_num++) {
- Bin* b = BinFromIndex(bin_num);
-
- size_t total_bytes_in_use = 0;
- size_t total_bytes_in_bin = 0;
- size_t total_requested_bytes_in_use = 0;
- size_t total_requested_bytes_in_bin = 0;
- size_t total_chunks_in_use = 0;
- size_t total_chunks_in_bin = 0;
- for (ChunkHandle h : b->free_chunks) {
- Chunk* c = ChunkFromHandle(h);
- total_bytes_in_bin += c->size;
- total_requested_bytes_in_bin += c->requested_size;
- ++total_chunks_in_bin;
- if (c->in_use()) {
- total_bytes_in_use += c->size;
- total_requested_bytes_in_use += c->requested_size;
- ++total_chunks_in_use;
- }
- }
-
- LOG(INFO) << "Bin (" << b->bin_size
- << "): \tTotal Chunks: " << total_chunks_in_bin
- << ", Chunks in use: " << total_chunks_in_use << " "
- << strings::HumanReadableNumBytes(total_bytes_in_bin)
- << " allocated for chunks. "
- << strings::HumanReadableNumBytes(total_requested_bytes_in_bin)
- << " client-requested for chunks. "
- << strings::HumanReadableNumBytes(total_bytes_in_use)
- << " in use in bin. "
- << strings::HumanReadableNumBytes(total_requested_bytes_in_use)
- << " client-requested in use in bin.";
- }
-
- // Find the bin that we would have liked to allocate in, so we
- // can get some further analysis about fragmentation.
- Bin* b = BinForSize(num_bytes);
-
- LOG(INFO) << "Bin for " << strings::HumanReadableNumBytes(num_bytes)
- << " was " << strings::HumanReadableNumBytes(b->bin_size)
- << ", Chunk State: ";
-
- for (ChunkHandle h : b->free_chunks) {
- Chunk* c = ChunkFromHandle(h);
- LOG(INFO) << c->DebugString(this, true);
- }
-
- // Next show the chunks that are in use, and also summarize their
- // number by size.
- std::map<size_t, int> in_use_by_size;
- for (const auto& region : region_manager_.regions()) {
- ChunkHandle h = region_manager_.get_handle(region.ptr());
- while (h != kInvalidChunkHandle) {
- const Chunk* c = ChunkFromHandle(h);
- if (c->in_use()) {
- in_use_by_size[c->size]++;
- LOG(INFO) << "Chunk at " << c->ptr << " of size " << c->size;
- }
- h = c->next;
- }
-
- h = region_manager_.get_handle(region.ptr());
- while (h != kInvalidChunkHandle) {
- const Chunk* c = ChunkFromHandle(h);
- if (!c->in_use()) {
- LOG(INFO) << "Free at " << c->ptr << " of size " << c->size;
- }
- h = c->next;
- }
- }
-
- LOG(INFO) << " Summary of in-use Chunks by size: ";
- size_t total_bytes = 0;
- for (auto& it : in_use_by_size) {
- LOG(INFO) << it.second << " Chunks of size " << it.first << " totalling "
- << strings::HumanReadableNumBytes(it.first * it.second);
- total_bytes += (it.first * it.second);
- }
- LOG(INFO) << "Sum Total of in-use chunks: "
- << strings::HumanReadableNumBytes(total_bytes);
- LOG(INFO) << "Stats: \n" << stats_.DebugString();
-}
-
-void GPUBFCAllocator::GetStats(AllocatorStats* stats) {
- mutex_lock l(lock_);
- *stats = stats_;
-}
+ : BFCAllocator(
+ new GPUMemAllocator(
+ GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie()),
+ total_memory, gpu_options.allow_growth(), "gpu_bfc") {}
} // namespace tensorflow
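
With this change the GPU-specific constructor above simply hands a GPU-backed sub-allocator (GPUMemAllocator) to the shared BFCAllocator core, so all best-fit and coalescing logic lives in one place. The toy sketch below shows that delegation pattern only; it is not TensorFlow's SubAllocator API, and every name in it is hypothetical.

```cpp
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>

// Toy interface standing in for the role a sub-allocator plays here: it only
// knows how to obtain and release large raw regions for one device type,
// while the pooling logic stays in a shared, device-independent core.
class ToyRawRegionProvider {
 public:
  virtual ~ToyRawRegionProvider() = default;
  virtual void* AllocRegion(std::size_t bytes) = 0;
  virtual void FreeRegion(void* ptr) = 0;
  virtual std::string Name() const = 0;
};

class HostRegionProvider : public ToyRawRegionProvider {
 public:
  void* AllocRegion(std::size_t bytes) override { return std::malloc(bytes); }
  void FreeRegion(void* ptr) override { std::free(ptr); }
  std::string Name() const override { return "host"; }
};

// Toy core standing in for the device-independent BFC logic.
class ToyBFCCore {
 public:
  explicit ToyBFCCore(std::unique_ptr<ToyRawRegionProvider> provider)
      : provider_(std::move(provider)) {}
  void Extend(std::size_t bytes) {
    void* region = provider_->AllocRegion(bytes);
    std::cout << "extended " << provider_->Name() << " pool by " << bytes
              << " bytes at " << region << "\n";
    provider_->FreeRegion(region);  // toy: release immediately
  }

 private:
  std::unique_ptr<ToyRawRegionProvider> provider_;
};

int main() {
  // A device-specific allocator would pass its own provider here, much as
  // the GPU constructor above now passes a GPU-backed one to BFCAllocator.
  ToyBFCCore core(std::make_unique<HostRegionProvider>());
  core.Extend(1 << 20);
  return 0;
}
```
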
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index 2714fd3487..f94367cc98 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -21,394 +21,60 @@ limitations under the License.
#include <unordered_map>
#include <vector>
-#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
-#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
-#include "tensorflow/core/lib/gtl/stl_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/common_runtime/allocator_retry.h"
+#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
+namespace gpu = ::perftools::gputools;
+
namespace tensorflow {
// A GPU memory allocator that implements a 'best-fit with coalescing'
-// algorithm. This is essentially a very simple version of Doug Lea's
-// malloc (dlmalloc).
-//
-// The goal of this allocator is to support defragmentation via
-// coalescing. One assumption we make is that the process using this
-// allocator owns pretty much all of the GPU memory, and that nearly
-// all requests to allocate GPU memory go through this interface.
-class GPUBFCAllocator : public VisitableAllocator {
+// algorithm.
+class GPUBFCAllocator : public BFCAllocator {
public:
// 'device_id' refers to the StreamExecutor ID of the device within
// the process and must reference a valid ID in the process.
GPUBFCAllocator(int device_id, size_t total_memory);
GPUBFCAllocator(int device_id, size_t total_memory,
const GPUOptions& gpu_options);
- ~GPUBFCAllocator() override;
-
- string Name() override { return "gpu_bfc"; }
- void* AllocateRaw(size_t alignment, size_t num_bytes) override;
- void* AllocateRaw(size_t alignment, size_t num_bytes,
- const AllocationAttributes& allocation_attr) override;
- void DeallocateRaw(void* ptr) override;
-
- void AddAllocVisitor(Visitor visitor) override;
-
- // Does nothing, because gpu memory is never freed.
- void AddFreeVisitor(Visitor visitor) override {}
-
- bool TracksAllocationSizes() override;
-
- size_t RequestedSize(void* ptr) override;
-
- size_t AllocatedSize(void* ptr) override;
-
- int64 AllocationId(void* ptr) override;
-
- void GetStats(AllocatorStats* stats) override;
-
- private:
- struct Bin;
-
- void* AllocateRawInternal(size_t alignment, size_t num_bytes,
- bool dump_log_on_failure);
- void DeallocateRawInternal(void* ptr);
-
- // A ChunkHandle is an index into the chunks_ vector in GPUBFCAllocator
- // kInvalidChunkHandle means an invalid chunk
- typedef int ChunkHandle;
- static const int kInvalidChunkHandle = -1;
-
- typedef int BinNum;
- static const int kInvalidBinNum = -1;
- static const int kNumBins = 21;
-
- // Chunks point to GPU memory. Their prev/next pointers form a
- // doubly-linked list of addresses sorted by GPU base address that
- // must be contiguous. Chunks contain information about whether
- // they are in use or whether they are free, and contain a pointer
- // to the bin they are in.
- struct Chunk {
- size_t size = 0; // Full size of GPU buffer.
-
- // We sometimes give chunks that are larger than needed to reduce
- // fragmentation. requested_size keeps track of what the client
- // actually wanted so we can understand whether our splitting
- // strategy is efficient.
- size_t requested_size = 0;
-
- // allocation_id is set to -1 when the chunk is not in use. It is assigned a
- // value greater than zero before the chunk is returned from
- // AllocateRaw, and this value is unique among values assigned by
- // the parent allocator.
- int64 allocation_id = -1;
- void* ptr = nullptr; // pointer to granted GPU subbuffer.
-
- // If not kInvalidChunkHandle, the memory referred to by 'prev' is directly
- // preceding the memory used by this chunk. E.g., It should start
- // at 'ptr - prev->size'
- ChunkHandle prev = kInvalidChunkHandle;
-
- // If not kInvalidChunkHandle, the memory referred to by 'next' is directly
- // following the memory used by this chunk. E.g., It should be at
- // 'ptr + size'
- ChunkHandle next = kInvalidChunkHandle;
-
- // What bin are we in?
- BinNum bin_num = kInvalidBinNum;
-
- bool in_use() const { return allocation_id != -1; }
-
- string DebugString(GPUBFCAllocator* a, bool recurse) {
- string dbg;
- strings::StrAppend(&dbg, " Size: ", strings::HumanReadableNumBytes(size),
- " | Requested Size: ",
- strings::HumanReadableNumBytes(requested_size),
- " | in_use: ", in_use());
- if (recurse && prev != GPUBFCAllocator::kInvalidChunkHandle) {
- Chunk* p = a->ChunkFromHandle(prev);
- strings::StrAppend(&dbg, ", prev: ", p->DebugString(a, false));
- }
- if (recurse && next != GPUBFCAllocator::kInvalidChunkHandle) {
- Chunk* n = a->ChunkFromHandle(next);
- strings::StrAppend(&dbg, ", next: ", n->DebugString(a, false));
- }
- return dbg;
- }
- };
+ virtual ~GPUBFCAllocator() {}
- // A Bin is a collection of similar-sized free chunks.
- struct Bin {
- // All chunks in this bin have >= bin_size memory.
- size_t bin_size = 0;
-
- struct ChunkComparator {
- explicit ChunkComparator(GPUBFCAllocator* allocator)
- : allocator_(allocator) {}
- // Sort first by size and then use pointer address as a tie breaker.
- bool operator()(const ChunkHandle ha, const ChunkHandle hb) const {
- const Chunk* a = allocator_->ChunkFromHandle(ha);
- const Chunk* b = allocator_->ChunkFromHandle(hb);
- if (a->size != b->size) {
- return a->size < b->size;
- }
- return a->ptr < b->ptr;
- }
-
- private:
- GPUBFCAllocator* allocator_; // The parent allocator
- };
-
- typedef std::set<ChunkHandle, ChunkComparator> FreeChunkSet;
- // List of free chunks within the bin, sorted by chunk size.
- // Chunk * not owned.
- FreeChunkSet free_chunks;
- Bin(GPUBFCAllocator* allocator, size_t bs)
- : bin_size(bs), free_chunks(ChunkComparator(allocator)) {}
- };
-
- static const size_t kMinAllocationBits = 8;
- static const size_t kMinAllocationSize = 1 << kMinAllocationBits;
-
- // AllocationRegion maps pointers to ChunkHandles for a single
- // contiguous memory region.
- //
- // This class is thread-compatible.
- class AllocationRegion {
- public:
- AllocationRegion(void* ptr, size_t memory_size)
- : ptr_(ptr),
- memory_size_(memory_size),
- end_ptr_(
- static_cast<void*>(static_cast<char*>(ptr_) + memory_size_)) {
- DCHECK_EQ(0, memory_size % kMinAllocationSize);
- const size_t n_handles =
- (memory_size + kMinAllocationSize - 1) / kMinAllocationSize;
- handles_ = new ChunkHandle[n_handles];
- for (size_t i = 0; i < n_handles; i++) {
- handles_[i] = kInvalidChunkHandle;
- }
- }
-
- AllocationRegion() {}
-
- ~AllocationRegion() { delete[] handles_; }
-
- AllocationRegion(AllocationRegion&& other) { Swap(other); }
-
- AllocationRegion& operator=(AllocationRegion&& other) {
- Swap(other);
- return *this;
- }
-
- void* ptr() const { return ptr_; }
- void* end_ptr() const { return end_ptr_; }
- size_t memory_size() const { return memory_size_; }
- ChunkHandle get_handle(const void* p) const {
- return handles_[IndexFor(p)];
- }
- void set_handle(const void* p, ChunkHandle h) { handles_[IndexFor(p)] = h; }
- void erase(const void* p) { set_handle(p, kInvalidChunkHandle); }
-
- private:
- void Swap(AllocationRegion& other) {
- std::swap(ptr_, other.ptr_);
- std::swap(memory_size_, other.memory_size_);
- std::swap(end_ptr_, other.end_ptr_);
- std::swap(handles_, other.handles_);
- }
-
- int IndexFor(const void* p) const {
- std::uintptr_t p_int = reinterpret_cast<std::uintptr_t>(p);
- std::uintptr_t base_int = reinterpret_cast<std::uintptr_t>(ptr_);
- DCHECK_GE(p_int, base_int);
- DCHECK_LT(p_int, base_int + memory_size_);
- return static_cast<int>(((p_int - base_int) >> kMinAllocationBits));
- }
-
- // Metadata about the allocation region.
- void* ptr_ = nullptr;
- size_t memory_size_ = 0;
- void* end_ptr_ = nullptr;
-
- // Array of size "memory_size / kMinAllocationSize". It is
- // indexed by (p-base) / kMinAllocationSize, contains ChunkHandle
- // for the memory allocation represented by "p"
- ChunkHandle* handles_ = nullptr;
-
- TF_DISALLOW_COPY_AND_ASSIGN(AllocationRegion);
- };
-
- // RegionManager aggregates one or more "AllocationRegions" and provides
- // a layer of indirection from pointers to the underlying ChunkHandle,
- // allowing allocation across multiple discontiguous memory regions.
- //
- // This class is thread-compatible.
- class RegionManager {
- public:
- RegionManager() {}
- ~RegionManager() {}
-
- void AddAllocationRegion(void* ptr, size_t memory_size) {
- // Insert sorted by end_ptr
- auto entry =
- std::upper_bound(regions_.begin(), regions_.end(), ptr, &Comparator);
- regions_.insert(entry, AllocationRegion(ptr, memory_size));
- }
-
- ChunkHandle get_handle(const void* p) const {
- return RegionFor(p)->get_handle(p);
- }
-
- void set_handle(const void* p, ChunkHandle h) {
- return MutableRegionFor(p)->set_handle(p, h);
- }
- void erase(const void* p) { return MutableRegionFor(p)->erase(p); }
-
- const std::vector<AllocationRegion>& regions() const { return regions_; }
-
- private:
- static bool Comparator(const void* ptr, const AllocationRegion& other) {
- return ptr < other.end_ptr();
- }
-
- AllocationRegion* MutableRegionFor(const void* p) {
- return const_cast<AllocationRegion*>(RegionFor(p));
- }
-
- const AllocationRegion* RegionFor(const void* p) const {
- auto entry =
- std::upper_bound(regions_.begin(), regions_.end(), p, &Comparator);
-
- if (entry != regions_.end()) {
- return &(*entry);
- }
-
- LOG(FATAL) << "Could not find Region for " << p;
- return nullptr;
- }
-
- private:
- std::vector<AllocationRegion> regions_;
- };
-
- // Returns 'bytes' rounded up to the next highest kMinAllocationSize.
- size_t RoundedBytes(size_t bytes);
-
- // Try to add a new memory region that can satisfy an allocation of
- // 'rounded_bytes' bytes. Returns true on success and false on
- // failure.
- bool Extend(size_t rounded_bytes) EXCLUSIVE_LOCKS_REQUIRED(lock_);
-
- // Returns a pointer to an underlying allocated chunk of size
- // 'rounded_bytes'.
- void* FindChunkPtr(BinNum bin_num, size_t rounded_bytes, size_t num_bytes)
- EXCLUSIVE_LOCKS_REQUIRED(lock_);
-
- // Splits the chunk specified by 'h' into two chunks, one at least
- // of size 'num_bytes'.
- void SplitChunk(ChunkHandle h, size_t num_bytes)
- EXCLUSIVE_LOCKS_REQUIRED(lock_);
-
- // Merges the two chunk handles. Requires that the chunks are
- // contiguous in their allocation.
- void Merge(ChunkHandle h, ChunkHandle h2) EXCLUSIVE_LOCKS_REQUIRED(lock_);
-
- // Frees the memory represented by 'h', coalescing the chunk if
- // possible.
- void FreeAndMaybeCoalesce(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
-
- // Adds the chunk 'h' to the proper free bin.
- void InsertFreeChunkIntoBin(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
-
- // Removes the free chunk pointed to by 'c' from the set free_chunks.
- void RemoveFreeChunkIterFromBin(Bin::FreeChunkSet* free_chunks,
- const Bin::FreeChunkSet::iterator& c)
- EXCLUSIVE_LOCKS_REQUIRED(lock_);
-
- // Removes a free chunk from the bin.
- void RemoveFreeChunkFromBin(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
-
- // Removes the chunk metadata represented by 'h'.
- void DeleteChunk(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
-
- string RenderOccupancy() EXCLUSIVE_LOCKS_REQUIRED(lock_);
- void DumpMemoryLog(size_t num_bytes) EXCLUSIVE_LOCKS_REQUIRED(lock_);
-
- ChunkHandle AllocateChunk() EXCLUSIVE_LOCKS_REQUIRED(lock_);
- void DeallocateChunk(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
-
- Chunk* ChunkFromHandle(ChunkHandle h) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
+};
- GPUAllocatorRetry retry_helper_;
+// Suballocator for GPU memory.
+class GPUMemAllocator : public SubAllocator {
+ public:
+ // Note: stream_exec cannot be null.
+ explicit GPUMemAllocator(perftools::gputools::StreamExecutor* stream_exec)
+ : stream_exec_(stream_exec) {
+ CHECK(stream_exec_ != nullptr);
+ }
+ ~GPUMemAllocator() override {}
- // Structures immutable after construction
- const int device_id_;
- size_t gpu_memory_size_ = 0;
- inline int Log2FloorNonZero(uint64 n) {
-#if defined(__GNUC__)
- return 63 ^ __builtin_clzll(n);
-#else
- int r = 0;
- while (n > 0) {
- r++;
- n >>= 1;
+ void* Alloc(size_t alignment, size_t num_bytes) override {
+ void* ptr = nullptr;
+ if (num_bytes > 0) {
+ ptr = stream_exec_->AllocateArray<char>(num_bytes).opaque();
}
- return r;
-#endif
+ return ptr;
}
- // Map from bin size to Bin
- Bin* BinFromIndex(BinNum index) {
- return reinterpret_cast<Bin*>(&(bins_space_[index * sizeof(Bin)]));
- }
- size_t BinNumToSize(BinNum index) {
- return static_cast<size_t>(256) << index;
- }
- BinNum BinNumForSize(size_t bytes) {
- uint64 v = std::max<size_t>(bytes, 256) >> kMinAllocationBits;
- int b = std::min(kNumBins - 1, Log2FloorNonZero(v));
- return b;
+ void Free(void* ptr, size_t num_bytes) override {
+ if (ptr != nullptr) {
+ gpu::DeviceMemoryBase gpu_ptr(ptr);
+ stream_exec_->Deallocate(&gpu_ptr);
+ }
}
- Bin* BinForSize(size_t bytes) { return BinFromIndex(BinNumForSize(bytes)); }
- char bins_space_[sizeof(Bin) * kNumBins];
-
- perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
-
- // The size of the current region allocation.
- size_t curr_region_allocation_bytes_;
-
- // The total number of allocated bytes by the allocator.
- size_t total_region_allocated_bytes_ = 0;
-
- // An indicator that expansion of a region has hit the limits
- // of the available GPU memory.
- bool started_backpedal_ = false;
-
- // Structures mutable after construction
- mutable mutex lock_;
- RegionManager region_manager_ GUARDED_BY(lock_);
-
- std::vector<Chunk> chunks_;
- ChunkHandle free_chunks_list_; // Ptr to head of linked list of free Chunks
-
- // Called once on each region, ASAP.
- std::vector<Visitor> region_visitors_;
-
- // Counter containing the next unique identifier to assign to a
- // newly-created chunk.
- int64 next_allocation_id_ GUARDED_BY(lock_);
-
- // Stats.
- AllocatorStats stats_ GUARDED_BY(lock_);
+ private:
+ perftools::gputools::StreamExecutor* stream_exec_; // not owned, non-null
- TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUMemAllocator);
};
} // namespace tensorflow
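The refactoring above leaves GPUMemAllocator responsible only for handing out raw device memory, while the chunk/bin bookkeeping that used to live in this header now sits behind the SubAllocator interface. As orientation, here is a minimal consumer of that interface; SimpleBumpPool is a hypothetical sketch and not part of this change, and only the Alloc/Free calls mirror the real SubAllocator methods. The 256-byte granularity matches the kMinAllocationSize used by the removed code.

// Illustrative only: a bump-pointer pool layered on a SubAllocator.
class SimpleBumpPool {
 public:
  SimpleBumpPool(SubAllocator* sub, size_t region_bytes)
      : sub_(sub), region_bytes_(region_bytes) {
    base_ = static_cast<char*>(sub_->Alloc(256 /*alignment*/, region_bytes_));
    next_ = base_;
  }
  ~SimpleBumpPool() { sub_->Free(base_, region_bytes_); }

  // Returns nullptr when the single region is exhausted; a real allocator
  // such as the BFC allocator would instead extend with another region.
  void* Allocate(size_t bytes) {
    const size_t rounded = (bytes + 255) & ~size_t{255};
    if (base_ == nullptr || next_ + rounded > base_ + region_bytes_) {
      return nullptr;
    }
    void* result = next_;
    next_ += rounded;
    return result;
  }

 private:
  SubAllocator* sub_;  // not owned
  size_t region_bytes_;
  char* base_ = nullptr;
  char* next_ = nullptr;
};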
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
index 0dfa97bb63..58ea42ea1b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -20,7 +20,7 @@ limitations under the License.
#include <string>
#include <unordered_map>
-#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
+#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
index f52f0078b0..4e102e823f 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
@@ -226,30 +226,6 @@ TEST(EventMgr, ManySmallTensorsSeparateCallsFlushed) {
}
}
-// Running the polling loop should clear the queue, without an explict
-// poll call here, given a moderate delay.
-TEST(EventMgr, LongDelayedPolling) {
- auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
- EventMgr em(stream_exec, GPUOptions());
- TEST_EventMgrHelper th(&em);
- EXPECT_EQ(0, th.queue_size());
- EXPECT_EQ(0, th.free_size());
- std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
- CHECK(stream.get());
- stream->Init();
- for (int i = 0; i < 5; ++i) {
- TensorReferenceVector* v = new TensorReferenceVector;
- AddTensorReference(v, 100 * 1048576);
- th.QueueTensors(stream.get(), v);
- EXPECT_EQ(1 + i, th.queue_size());
- EXPECT_EQ(0, th.free_size());
- }
- th.StartPollingLoop();
- sleep(1);
- EXPECT_EQ(0, th.queue_size());
- EXPECT_EQ(5, th.free_size());
-}
-
// Deleting the EventMgr when events are still pending should shut
// down gracefully.
TEST(EventMgr, NonEmptyShutdown) {
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.h b/tensorflow/core/common_runtime/gpu/pool_allocator.h
index c3fc53ea62..d8838ab7f4 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator.h
@@ -24,7 +24,7 @@ limitations under the License.
#include <map>
#include <memory>
#include <vector>
-#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
+#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -35,14 +35,6 @@ limitations under the License.
namespace tensorflow {
-// Interface of an object that does the underlying alloc/free of memory.
-class SubAllocator {
- public:
- virtual ~SubAllocator() {}
- virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
- virtual void Free(void* ptr, size_t num_bytes) = 0;
-};
-
// Interface of an object that rounds up integers.
class RoundUpInterface {
public:
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
index e4f66c3b4b..67e10f7c05 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -187,9 +187,17 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
gpu::Platform* gpu_platform = GPUMachineManager();
gpu::StreamExecutor* se = gpu_platform->ExecutorForDevice(0).ValueOrDie();
CHECK(se);
- Allocator* allocator = new PoolAllocator(
- 100 /*pool_size_limit*/, true /*auto_resize*/,
- new CUDAHostAllocator(se), new Pow2Rounder, "cuda_host");
+ Allocator* allocator = nullptr;
+ static constexpr bool kCudaHostMemoryUseBFC = true;
+ if (kCudaHostMemoryUseBFC) {
+ allocator =
+ new BFCAllocator(new CUDAHostAllocator(se), 1LL << 36 /*64GB max*/,
+ true /*allow_growth*/, "cuda_host_bfc" /*name*/);
+ } else {
+ allocator = new PoolAllocator(
+ 100 /*pool_size_limit*/, true /*auto_resize*/,
+ new CUDAHostAllocator(se), new Pow2Rounder, "cuda_host");
+ }
if (LogMemory::IsEnabled()) {
// Wrap the allocator to track allocation ids for better logging
// at the cost of performance.
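For scale, the new cap passed to the BFCAllocator is 1LL << 36 bytes = 2^36 = 68,719,476,736 bytes, i.e. 64 GiB, which is what the "64GB max" comment refers to. Because allow_growth is passed as true, this should act as an upper bound the allocator grows toward on demand rather than a pinned-host-memory reservation made up front.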
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
index 5414b75fff..1f4ccf7096 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -315,11 +315,20 @@ class ColocationGraph {
device_set_->FindMatchingDevices(specified_device_name,
&devices_matching_nodedef);
if (devices_matching_nodedef.empty()) {
+      // The problem is often impossible to diagnose without the list of
+      // available devices, so include it in the error message.
+ std::vector<string> device_names;
+ for (const Device* device : device_set_->devices()) {
+ device_names.push_back(device->name());
+ }
+ std::sort(device_names.begin(), device_names.end());
+
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
node->def().device(),
"' because no devices matching that specification "
- "are registered in this process");
+ "are registered in this process; available devices: ",
+ str_util::Join(device_names, ", "));
} else if (specified_device_name.has_type) {
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
diff --git a/tensorflow/core/common_runtime/gpu/visitable_allocator.h b/tensorflow/core/common_runtime/visitable_allocator.h
index b0b5ec3bd9..17582a2915 100644
--- a/tensorflow/core/common_runtime/gpu/visitable_allocator.h
+++ b/tensorflow/core/common_runtime/visitable_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_
+#ifndef TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
#include <functional>
#include "tensorflow/core/framework/allocator.h"
@@ -42,4 +42,4 @@ class VisitableAllocator : public Allocator {
virtual void AddFreeVisitor(Visitor visitor) = 0;
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_
+#endif // TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 30c7c19102..9eddf53e3d 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -292,6 +292,15 @@ Allocator* cpu_allocator();
// AllocatorStats. By default, it's disabled.
void EnableCPUAllocatorStats(bool enable);
+// Abstract interface of an object that does the underlying suballoc/free of
+// memory for a higher-level allocator.
+class SubAllocator {
+ public:
+ virtual ~SubAllocator() {}
+ virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
+ virtual void Free(void* ptr, size_t num_bytes) = 0;
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
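A minimal sketch of what implementing the new interface looks like for plain host memory, assuming a POSIX environment; AlignedHostSubAllocator is hypothetical and not part of this change, and only the two overridden methods come from the interface above.

#include <cstddef>
#include <cstdlib>

// Hypothetical SubAllocator backed by posix_memalign/free.
class AlignedHostSubAllocator : public tensorflow::SubAllocator {
 public:
  void* Alloc(size_t alignment, size_t num_bytes) override {
    void* ptr = nullptr;
    if (num_bytes > 0) {
      // posix_memalign requires alignment to be a power of two and a
      // multiple of sizeof(void*).
      if (posix_memalign(&ptr, alignment, num_bytes) != 0) ptr = nullptr;
    }
    return ptr;
  }
  void Free(void* ptr, size_t num_bytes) override { free(ptr); }
};

GPUMemAllocator and CUDAHostAllocator elsewhere in this change appear to follow the same two-method shape, which is what lets the BFC machinery be reused across device memory and pinned host memory.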
diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc
index 2abe47cafd..f97a9b6dcf 100644
--- a/tensorflow/core/framework/allocator_test.cc
+++ b/tensorflow/core/framework/allocator_test.cc
@@ -38,6 +38,26 @@ static void CheckStats(Allocator* a, int64 num_allocs, int64 bytes_in_use,
#endif
}
+TEST(AllocatorAttributesTest, AllCombos) {
+ for (bool on_host : {false, true}) {
+ for (bool nic_compatible : {false, true}) {
+ for (bool gpu_compatible : {false, true}) {
+ for (bool track_sizes : {false, true}) {
+ AllocatorAttributes aa;
+ aa.set_on_host(on_host);
+ aa.set_nic_compatible(nic_compatible);
+ aa.set_gpu_compatible(gpu_compatible);
+ aa.set_track_sizes(track_sizes);
+ EXPECT_EQ(on_host, aa.on_host());
+ EXPECT_EQ(nic_compatible, aa.nic_compatible());
+ EXPECT_EQ(gpu_compatible, aa.gpu_compatible());
+ EXPECT_EQ(track_sizes, aa.track_sizes());
+ }
+ }
+ }
+ }
+}
+
TEST(CPUAllocatorTest, Simple) {
EnableCPUAllocatorStats(true);
Allocator* a = cpu_allocator();
diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc
index 2646370a76..1c902d29a0 100644
--- a/tensorflow/core/graph/gradients.cc
+++ b/tensorflow/core/graph/gradients.cc
@@ -40,37 +40,30 @@ static const char* const kRetOp = "_Retval";
static const char* const kGradientOp = "SymbolicGradient";
static const char* const kNodeLabel = "Func";
-// Represents the index-th output of a node.
-struct Endpoint {
- Node* node;
- int index;
-
- // Returns the string name represents this endpoint.
- string name() const {
- if (index == 0) {
- return node->name();
- } else {
- return strings::StrCat(node->name(), ":", index);
- }
+string NodeOut::name() const {
+ if (index == 0) {
+ return node->name();
+ } else {
+ return strings::StrCat(node->name(), ":", index);
}
+}
- DataType dtype() const { return node->output_type(index); }
-};
+DataType NodeOut::dtype() const { return node->output_type(index); }
-struct EndpointHash {
- uint64 operator()(const Endpoint& x) const {
+struct NodeOutHash {
+ uint64 operator()(const NodeOut& x) const {
return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
x.index);
}
};
-struct EndpointEq {
- bool operator()(const Endpoint& x, const Endpoint& y) const {
+struct NodeOutEq {
+ bool operator()(const NodeOut& x, const NodeOut& y) const {
return (x.node == y.node) && (x.index == y.index);
}
};
-static Node* AddZerosLike(Graph* g, Endpoint input) {
+static Node* AddZerosLike(Graph* g, NodeOut input) {
DCHECK_LT(0, input.dtype());
DCHECK_LT(input.dtype(), DT_FLOAT_REF);
NodeDef ndef;
@@ -85,7 +78,7 @@ static Node* AddZerosLike(Graph* g, Endpoint input) {
return ret;
}
-static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<Endpoint> grads) {
+static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {
const int num_x = n->num_inputs();
const int num_y = n->num_outputs();
CHECK_EQ(num_y, grads.size());
@@ -95,19 +88,19 @@ static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<Endpoint> grads) {
ndef.set_op(kGradientOp);
// The gradient node should have num_x + num_y inputs.
- std::vector<Endpoint> n_inputs(num_x);
+ std::vector<NodeOut> n_inputs(num_x);
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) continue;
n_inputs[e->dst_input()] = {e->src(), e->src_output()};
}
DataTypeVector in_types;
- for (const Endpoint& ep : n_inputs) {
- ndef.add_input(ep.name());
- in_types.push_back(ep.dtype());
+ for (const NodeOut& nout : n_inputs) {
+ ndef.add_input(nout.name());
+ in_types.push_back(nout.dtype());
}
- for (const Endpoint& ep : grads) {
- ndef.add_input(ep.name());
- in_types.push_back(ep.dtype());
+ for (const NodeOut& nout : grads) {
+ ndef.add_input(nout.name());
+ in_types.push_back(nout.dtype());
}
CHECK_EQ(ndef.input_size(), num_x + num_y);
@@ -128,34 +121,34 @@ static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<Endpoint> grads) {
class SymbolicGradientBuilder {
public:
- SymbolicGradientBuilder(gtl::ArraySlice<Node*> y_nodes,
- gtl::ArraySlice<Node*> x_nodes,
- gtl::ArraySlice<Node*> y_grad_nodes,
- std::vector<GradNodeOutput>* x_grad_nodes,
+ SymbolicGradientBuilder(gtl::ArraySlice<NodeOut> y_node_outputs,
+ gtl::ArraySlice<NodeOut> x_node_outputs,
+ gtl::ArraySlice<NodeOut> y_grad_node_outputs,
+ std::vector<NodeOut>* x_grad_node_outputs,
Graph* graph);
Status Compute();
private:
- gtl::ArraySlice<Node*> y_nodes_;
- gtl::ArraySlice<Node*> x_nodes_;
- gtl::ArraySlice<Node*> y_grad_nodes_;
- std::vector<GradNodeOutput>* x_grad_nodes_;
+ gtl::ArraySlice<NodeOut> y_node_outputs_;
+ gtl::ArraySlice<NodeOut> x_node_outputs_;
+ gtl::ArraySlice<NodeOut> y_grad_node_outputs_;
+ std::vector<NodeOut>* x_grad_node_outputs_;
Graph* graph_; // Not owned.
// A vector of output endpoints which represents backpropagated
// gradients
- typedef std::vector<Endpoint> BackpropedGradients;
+ typedef std::vector<NodeOut> BackpropedGradients;
- // backprops_ is a map from an output endpoint to its accumulated
- // gradients. When an output endpoint has accumulated all its
+ // backprops_ is a map from a node output to its accumulated
+ // gradients. When a node output has accumulated all its
// gradients, we add a node which sums them up.
- std::unordered_map<Endpoint, BackpropedGradients, EndpointHash, EndpointEq>
+ std::unordered_map<NodeOut, BackpropedGradients, NodeOutHash, NodeOutEq>
backprops_;
// pending[i] is count-down counter for i-th node's expected
// backprops. When pending[i] becomes zero, we collected all
- // backprop gradients for all output endpoint of the ith-node.
+ // backprop gradients for all outputs of the ith-node.
std::vector<int> pending_;
// 'ready' keeps track of nodes that have been completely
@@ -163,7 +156,8 @@ class SymbolicGradientBuilder {
// add dy as an input of the gradient function.
std::deque<Node*> ready_;
- // The set of nodes at which to stop backprop (and populate 'x_grad_nodes_').
+ // The set of nodes at which to stop backprop.
+ // Maps from node.id -> index of 'x_node_outputs_'
std::unordered_map<int, int> stop_nodes_;
// Initialize pending_ and ready_.
@@ -173,33 +167,35 @@ class SymbolicGradientBuilder {
// to 'dst', when the backprop algorithm constructs the node
// 'dst_grad' which computes the gradient, we need to propagate it
// to 'src'.
- void BackpropAlongEdge(const Endpoint& dst_grad, const Endpoint& src);
- void BackpropZerosAlongEdge(const Endpoint& src);
+ void BackpropAlongEdge(const NodeOut& dst_grad, const NodeOut& src);
+ void BackpropZerosAlongEdge(const NodeOut& src);
- Endpoint SumGradients(const Endpoint& src);
+ NodeOut SumGradients(const NodeOut& src);
TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder);
};
SymbolicGradientBuilder::SymbolicGradientBuilder(
- gtl::ArraySlice<Node*> y_nodes,
- gtl::ArraySlice<Node*> x_nodes,
- gtl::ArraySlice<Node*> y_grad_nodes,
- std::vector<GradNodeOutput>* x_grad_nodes,
- Graph* graph) : y_nodes_(y_nodes), x_nodes_(x_nodes),
- y_grad_nodes_(y_grad_nodes), x_grad_nodes_(x_grad_nodes),
- graph_(graph) {
- CHECK_EQ(y_nodes_.size(), y_grad_nodes.size());
- x_grad_nodes_->clear();
- x_grad_nodes_->resize(x_nodes_.size());
- stop_nodes_.reserve(x_nodes_.size());
- for (int i = 0; i < x_nodes_.size(); ++i) {
- stop_nodes_.insert(std::make_pair(x_nodes_[i]->id(), i));
+ gtl::ArraySlice<NodeOut> y_node_outputs,
+ gtl::ArraySlice<NodeOut> x_node_outputs,
+ gtl::ArraySlice<NodeOut> y_grad_node_outputs,
+ std::vector<NodeOut>* x_grad_node_outputs, Graph* graph)
+ : y_node_outputs_(y_node_outputs),
+ x_node_outputs_(x_node_outputs),
+ y_grad_node_outputs_(y_grad_node_outputs),
+ x_grad_node_outputs_(x_grad_node_outputs),
+ graph_(graph) {
+ CHECK_EQ(y_node_outputs_.size(), y_grad_node_outputs.size());
+ x_grad_node_outputs_->clear();
+ x_grad_node_outputs_->resize(x_node_outputs_.size());
+ stop_nodes_.reserve(x_node_outputs_.size());
+ for (int i = 0; i < x_node_outputs_.size(); ++i) {
+ stop_nodes_.insert(std::make_pair(x_node_outputs_[i].node->id(), i));
}
}
-void SymbolicGradientBuilder::BackpropAlongEdge(const Endpoint& dst_grad,
- const Endpoint& src) {
+void SymbolicGradientBuilder::BackpropAlongEdge(const NodeOut& dst_grad,
+ const NodeOut& src) {
CHECK_NOTNULL(src.node);
auto iter = backprops_.find(src);
if (iter != backprops_.end()) {
@@ -211,7 +207,7 @@ void SymbolicGradientBuilder::BackpropAlongEdge(const Endpoint& dst_grad,
}
}
-void SymbolicGradientBuilder::BackpropZerosAlongEdge(const Endpoint& src) {
+void SymbolicGradientBuilder::BackpropZerosAlongEdge(const NodeOut& src) {
CHECK_NOTNULL(src.node);
auto iter = backprops_.find(src);
if (iter != backprops_.end()) {
@@ -227,9 +223,9 @@ void SymbolicGradientBuilder::InitBackprop() {
backprops_.clear();
std::unordered_set<Node*> visited;
std::deque<Node*> queue;
- for (Node* n : x_nodes_) {
- queue.push_back(n);
- visited.insert(n);
+ for (const NodeOut& nout : x_node_outputs_) {
+ queue.push_back(nout.node);
+ visited.insert(nout.node);
}
// Going forward to figure out which endpoints need backprop-ed.
@@ -255,20 +251,19 @@ void SymbolicGradientBuilder::InitBackprop() {
}
{
- const int num_y = y_grad_nodes_.size();
+ const int num_y = y_grad_node_outputs_.size();
for (int i = 0; i < num_y; ++i) {
- Node* y = y_nodes_[i];
- Node* dy = y_grad_nodes_[i];
+ Node* y = y_node_outputs_[i].node;
for (const Edge* e : y->in_edges()) {
if (e->IsControlEdge()) continue;
- BackpropAlongEdge({dy, e->dst_input()}, {e->src(), e->src_output()});
+ BackpropAlongEdge(y_grad_node_outputs_[i], {e->src(), e->src_output()});
}
}
}
CHECK(!ready_.empty());
}
-Endpoint SymbolicGradientBuilder::SumGradients(const Endpoint& src) {
+NodeOut SymbolicGradientBuilder::SumGradients(const NodeOut& src) {
const DataType dtype = src.dtype();
auto iter = backprops_.find(src);
CHECK(iter != backprops_.end());
@@ -286,8 +281,8 @@ Endpoint SymbolicGradientBuilder::SumGradients(const Endpoint& src) {
NodeDef ndef;
ndef.set_name(graph_->NewName(kNodeLabel));
ndef.set_op("AddN"); // N-way Add
- for (const Endpoint& ep : grads) {
- ndef.add_input(ep.name());
+ for (const NodeOut& nout : grads) {
+ ndef.add_input(nout.name());
}
AddNodeAttr("N", static_cast<int64>(grads.size()), &ndef);
AddNodeAttr("T", dtype, &ndef);
@@ -295,8 +290,8 @@ Endpoint SymbolicGradientBuilder::SumGradients(const Endpoint& src) {
Node* add = graph_->AddNode(ndef, &s);
TF_CHECK_OK(s);
for (size_t i = 0; i < grads.size(); ++i) {
- const Endpoint& ep = grads[i];
- graph_->AddEdge(ep.node, ep.index, add, i);
+ const NodeOut& nout = grads[i];
+ graph_->AddEdge(nout.node, nout.index, add, i);
}
return {add, 0};
}
@@ -312,7 +307,7 @@ Status SymbolicGradientBuilder::Compute() {
InitBackprop();
// Backward propagation.
- gtl::InlinedVector<Endpoint, 8> dy;
+ gtl::InlinedVector<NodeOut, 8> dy;
while (!ready_.empty()) {
// n has collected all gradients.
Node* n = ready_.front();
@@ -324,11 +319,11 @@ Status SymbolicGradientBuilder::Compute() {
auto iter = stop_nodes_.find(n->id());
if (iter != stop_nodes_.end()) {
- // Stop backprop and add gradient sum to 'x_grad_nodes'.
+ // Stop backprop and add gradient sum to 'x_grad_node_outputs_'.
// TODO(andydavis) Support stop nodes with more than one output.
CHECK_EQ(1, num_y);
- Endpoint grad = SumGradients({n, 0});
- (*x_grad_nodes_)[iter->second] = {grad.node, grad.index};
+ const int index = iter->second;
+ (*x_grad_node_outputs_)[index] = SumGradients(x_node_outputs_[index]);
continue;
}
@@ -350,6 +345,7 @@ Status SymbolicGradientBuilder::Compute() {
// Adds a gradient node with num_x + num_y inputs and num_x
// outputs.
+ // TODO(andydavis) Support primitive gradient ops.
Node* grad = AddSymGrad(graph_, n, dy);
for (const Edge* e : n->in_edges()) {
if (e->IsControlEdge()) continue;
@@ -369,12 +365,13 @@ Status SymbolicGradientBuilder::Compute() {
return Status::OK();
}
-Status AddSymbolicGradients(gtl::ArraySlice<Node*> y_nodes,
- gtl::ArraySlice<Node*> x_nodes,
- gtl::ArraySlice<Node*> y_grad_nodes,
- std::vector<GradNodeOutput>* x_grad_nodes,
+Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs,
+ gtl::ArraySlice<NodeOut> x_node_outputs,
+ gtl::ArraySlice<NodeOut> y_grad_node_outputs,
+ std::vector<NodeOut>* x_grad_node_outputs,
Graph* graph) {
- SymbolicGradientBuilder builder(y_nodes, x_nodes, y_grad_nodes, x_grad_nodes,
+ SymbolicGradientBuilder builder(y_node_outputs, x_node_outputs,
+ y_grad_node_outputs, x_grad_node_outputs,
graph);
return builder.Compute();
}
diff --git a/tensorflow/core/graph/gradients.h b/tensorflow/core/graph/gradients.h
index bc18fd7cf2..a7d9613d79 100644
--- a/tensorflow/core/graph/gradients.h
+++ b/tensorflow/core/graph/gradients.h
@@ -16,40 +16,41 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
#define THIRD_PARTY_TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
-// GradNodeOutput represents a single gradient node output.
-struct GradNodeOutput {
+// Represents the output of 'node' at 'index'.
+struct NodeOut {
Node* node;
int index;
+
+ // Returns the string name that represents the output of this node.
+ string name() const;
+ // Returns the data type of the output of this node.
+ DataType dtype() const;
};
// NOTE: This API is a work in progress and will likely be changing frequently.
//
-// Given initial gradient nodes 'y_grad_nodes' (which compute the symbolic
-// partial derivatives of some loss function 'L' w.r.t the inputs of each
-// node in 'y_nodes'), adds gradient nodes to 'graph' that compute the sum
-// of all gradients flowing into the single output of each node in 'x_nodes'.
-// Note that gradient nodes will not be added to 'graph' which compute
-// the symbolic partial derivative of 'L' w.r.t. each node in 'x_nodes' (i.e.
-// backprop will stop at these nodes). This restriction will be lifted in
-// a subsequent CL.
+// Given initial gradient-node outputs 'y_grad_node_outputs' (which compute the
+// symbolic partial derivatives of some loss function 'L' w.r.t the node outputs
+// 'y_node_outputs'), adds gradient nodes to 'graph' that compute the symbolic
+// partial derivatives of 'L' w.r.t the node outputs 'x_node_outputs'.
//
-// REQUIRES: Each node in 'x_nodes' must have a single output (this
-// restriction will be removed in a subsequent change).
+// REQUIRES: Each node in 'x_node_outputs' must be unique, and hence have a
+// single output (this restriction will be removed in a subsequent change).
-// TODO(andydavis) Add support for returning 'x_node' gradients by endpoint
-// (i.e. {node, index}).
// TODO(andydavis) Add symbolic gradient support for general graphs (the current
// implementation only supports gradients for functions). In particular,
// the nodes in 'x_nodes' are currently restricted to have one output.
-Status AddSymbolicGradients(gtl::ArraySlice<Node*> y_nodes,
- gtl::ArraySlice<Node*> x_nodes,
- gtl::ArraySlice<Node*> y_grad_nodes,
- std::vector<GradNodeOutput>* x_grad_nodes,
+
+Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs,
+ gtl::ArraySlice<NodeOut> x_node_outputs,
+ gtl::ArraySlice<NodeOut> y_grad_node_outputs,
+ std::vector<NodeOut>* x_grad_node_outputs,
Graph* graph);
} // namespace tensorflow
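A hedged sketch of a call site for the NodeOut-based signature declared above; y_node, x_node, dy_node and graph are assumed to already exist as Node*/Graph* values, the code is assumed to live inside namespace tensorflow, and the surrounding function is assumed to return Status.

std::vector<NodeOut> y_outputs = {{y_node, 0}};
std::vector<NodeOut> x_outputs = {{x_node, 0}};
std::vector<NodeOut> y_grad_outputs = {{dy_node, 0}};  // dL/dy fed in
std::vector<NodeOut> x_grad_outputs;                   // dL/dx filled in
TF_RETURN_IF_ERROR(AddSymbolicGradients(y_outputs, x_outputs, y_grad_outputs,
                                        &x_grad_outputs, graph));
// x_grad_outputs[i] now names the {node, index} that computes the gradient
// with respect to x_outputs[i].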
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ebda2a2a6d..6b9e093baf 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -214,6 +214,21 @@ cc_header_only_library(
deps = [":bounds_check"],
)
+cc_library(
+ name = "image_resizer_state",
+ hdrs = ["image_resizer_state.h"],
+ visibility = ["//visibility:private"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_header_only_library(
+ name = "image_resizer_state_lib",
+ deps = [":image_resizer_state"],
+)
+
# OpKernel libraries ----------------------------------------------------------
tf_kernel_libraries(
@@ -221,7 +236,6 @@ tf_kernel_libraries(
prefixes = [
"bcast_ops",
"bitcast_op",
- "depthtospace_op",
"concat_op",
"constant_op",
"diag_op",
@@ -239,7 +253,6 @@ tf_kernel_libraries(
"reverse_sequence_op",
"shape_ops",
"slice_op",
- "spacetodepth_op",
"split_op",
"tile_ops",
"transpose_op",
@@ -250,6 +263,7 @@ tf_kernel_libraries(
deps = [
":bounds_check",
":concat_lib",
+ ":depth_space_ops",
":fill_functor",
":ops_util",
":split_lib",
@@ -545,6 +559,7 @@ tf_kernel_libraries(
"sample_distorted_bounding_box_op",
],
deps = [
+ ":image_resizer_state",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",
"//tensorflow/core:lib",
@@ -830,6 +845,31 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "depth_space_ops",
+ srcs = [
+ "depthtospace_op.cc",
+ "spacetodepth_op.cc",
+ ],
+ hdrs = [
+ "depthtospace_op.h",
+ "spacetodepth_op.h",
+ ],
+ gpu_srcs = [
+ "depthtospace_op.h",
+ "depthtospace_op_gpu.cu.cc",
+ "spacetodepth_op.h",
+ "spacetodepth_op_gpu.cu.cc",
+ ],
+ visibility = ["//visibility:private"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 0,
+)
+
tf_kernel_libraries(
name = "parsing",
prefixes = [
@@ -1062,6 +1102,7 @@ filegroup(
"slice_op.h",
"softmax_op.cc",
"softmax_op.h",
+ "softmax_op_functor.h",
"split_lib.h",
"split_lib_cpu.cc",
"split_op.cc",
@@ -1095,10 +1136,12 @@ filegroup(
"batch_norm_op.h",
"control_flow_ops.h",
"conv_2d.h",
+ "image_resizer_state.h",
"maxpooling_op.h",
"reduction_ops.h",
"reduction_ops_common.h",
"relu_op.h",
+ "relu_op_functor.h",
"save_restore_tensor.h",
"softplus_op.h",
"softsign_op.h",
diff --git a/tensorflow/core/kernels/batch_matmul_op.cc b/tensorflow/core/kernels/batch_matmul_op.cc
index 306ae6d38c..f5a64e1f46 100644
--- a/tensorflow/core/kernels/batch_matmul_op.cc
+++ b/tensorflow/core/kernels/batch_matmul_op.cc
@@ -113,6 +113,39 @@ perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
perftools::gputools::DeviceMemory<T> typed(wrapped);
return typed;
}
+
+class CublasScratchAllocator : public perftools::gputools::ScratchAllocator {
+ public:
+ using Stream = ::perftools::gputools::Stream;
+ using DeviceMemoryBytes = ::perftools::gputools::DeviceMemory<uint8>;
+
+ CublasScratchAllocator(OpKernelContext* context) : context_(context) {}
+
+ int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }
+
+ perftools::gputools::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
+ Stream* stream, int64 byte_size) override {
+ Tensor temporary_memory;
+
+ Status allocation_status(context_->allocate_temp(
+ DT_UINT8, TensorShape({byte_size}), &temporary_memory));
+ if (!allocation_status.ok()) {
+ return perftools::gputools::port::StatusOr<DeviceMemoryBytes>(
+ DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
+ }
+    // Hold references to the allocated tensors so the underlying buffers
+    // stay alive for the lifetime of this allocator.
+ allocated_tensors_.push_back(temporary_memory);
+ return perftools::gputools::port::StatusOr<DeviceMemoryBytes>(
+ DeviceMemoryBytes::MakeFromByteSize(
+ temporary_memory.flat<uint8>().data(),
+ temporary_memory.flat<uint8>().size()));
+ }
+
+ private:
+ OpKernelContext* context_;
+ std::vector<Tensor> allocated_tensors_;
+};
} // namespace
template <typename Scalar>
@@ -162,12 +195,14 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
// where A, B and C are assumed to be in column major.
// We want the output to be in row-major, so we can compute
// C' = B' x A' (' stands for transpose)
+ CublasScratchAllocator scratch_allocator(context);
bool blas_launch_status =
- stream->ThenBlasGemmBatched(blas_transpose_b, blas_transpose_a, n, m, k,
- static_cast<Scalar>(1.0), b_ptrs,
- adj_y ? k : n, a_ptrs, adj_x ? m : k,
- static_cast<Scalar>(0.0), c_ptrs, n,
- batch_size)
+ stream
+ ->ThenBlasGemmBatchedWithScratch(
+ blas_transpose_b, blas_transpose_a, n, m, k,
+ static_cast<Scalar>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
+ adj_x ? m : k, static_cast<Scalar>(0.0), c_ptrs, n, batch_size,
+ &scratch_allocator)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal(
@@ -265,9 +300,7 @@ REGISTER_CPU(int32);
REGISTER_CPU(complex64);
#ifdef GOOGLE_CUDA
-// TODO(kalakris): The GPU implementation is currently disabled due to issues
-// encountered in practice. See b/24534272.
-// REGISTER_GPU(float);
+REGISTER_GPU(float);
#endif // GOOGLE_CUDA
#undef REGISTER_CPU
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index 60f0474103..0e70bc31e8 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -45,7 +45,7 @@ class DecodeCSVOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input("records", &records));
OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults));
- for (int i = 0; i < record_defaults.size(); ++i) {
+ for (int64 i = 0; i < record_defaults.size(); ++i) {
OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2,
errors::InvalidArgument(
"There should only be 1 default per field but field ", i,
@@ -53,7 +53,7 @@ class DecodeCSVOp : public OpKernel {
}
auto records_t = records->flat<string>();
- int records_size = records_t.size();
+ int64 records_size = records_t.size();
OpOutputList output;
OP_REQUIRES_OK(ctx, ctx->output_list("output", &output));
@@ -63,7 +63,7 @@ class DecodeCSVOp : public OpKernel {
output.allocate(i, records->shape(), &out);
}
- for (int i = 0; i < records_size; ++i) {
+ for (int64 i = 0; i < records_size; ++i) {
const StringPiece record(records_t(i));
std::vector<string> fields;
ExtractFields(ctx, record, &fields);
@@ -165,7 +165,7 @@ class DecodeCSVOp : public OpKernel {
void ExtractFields(OpKernelContext* ctx, StringPiece input,
std::vector<string>* result) {
- int current_idx = 0;
+ int64 current_idx = 0;
if (!input.empty()) {
while (static_cast<size_t>(current_idx) < input.size()) {
if (input[current_idx] == '\n' || input[current_idx] == '\r') {
diff --git a/tensorflow/core/kernels/depthtospace_op.cc b/tensorflow/core/kernels/depthtospace_op.cc
index 01d5c479ae..4355bda960 100644
--- a/tensorflow/core/kernels/depthtospace_op.cc
+++ b/tensorflow/core/kernels/depthtospace_op.cc
@@ -21,6 +21,8 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/core/kernels/depthtospace_op.h"
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -60,8 +62,8 @@ class DepthToSpaceOp : public OpKernel {
"instead of: ", dims));
const int batch_size = input.dim_size(0);
- const int height = input.dim_size(1);
- const int width = input.dim_size(2);
+ const int input_height = input.dim_size(1);
+ const int input_width = input.dim_size(2);
const int input_depth = input.dim_size(3);
const int block_size_sq = block_size_ * block_size_;
@@ -73,40 +75,57 @@ class DepthToSpaceOp : public OpKernel {
"should be divisible by: ", block_size_sq));
const int output_depth = input_depth / block_size_sq;
- const int output_width = width * block_size_;
- const int output_height = height * block_size_;
+ const int output_width = input_width * block_size_;
+ const int output_height = input_height * block_size_;
// Allocate output tensor.
- Tensor* outputs_tensor = nullptr;
+ Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
0, TensorShape({batch_size, output_height,
output_width, output_depth}),
- &outputs_tensor));
+ &output));
+
+ typename TTypes<T, 4>::ConstTensor Tinput = input.tensor<T, 4>();
+ typename TTypes<T, 4>::Tensor Toutput = output->tensor<T, 4>();
+
+ functor::DepthToSpaceOpFunctor<Device, T> functor;
+ functor(context->eigen_device<Device>(), Tinput, block_size_, Toutput);
+ };
+
+ private:
+ int block_size_;
+};
- auto Toutput = outputs_tensor->tensor<T, 4>();
- auto Tinput = input.tensor<T, 4>();
+// Partial specialization of DepthToSpaceOpFunctor for a CPUDevice.
+namespace functor {
+template <typename T>
+struct DepthToSpaceOpFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
+ int block_size, typename TTypes<T, 4>::Tensor output) {
+ const int batch_size = output.dimension(0);
+ const int output_height = output.dimension(1);
+ const int output_width = output.dimension(2);
+ const int output_depth = output.dimension(3);
for (int b = 0; b < batch_size; ++b) {
for (int h = 0; h < output_height; ++h) {
- const int in_h = h / block_size_;
- const int offset_h = (h % block_size_);
+ const int in_h = h / block_size;
+ const int offset_h = (h % block_size);
for (int w = 0; w < output_width; ++w) {
- const int in_w = w / block_size_;
- const int offset_w = (w % block_size_);
+ const int in_w = w / block_size;
+ const int offset_w = (w % block_size);
const int offset_d =
- (offset_h * block_size_ + offset_w) * output_depth;
+ (offset_h * block_size + offset_w) * output_depth;
for (int d = 0; d < output_depth; ++d) {
const int in_d = d + offset_d;
- Toutput(b, h, w, d) = Tinput(b, in_h, in_w, in_d);
+ output(b, h, w, d) = input(b, in_h, in_w, in_d);
}
}
}
}
- };
-
- private:
- int block_size_;
+ }
};
+} // namespace functor
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
@@ -116,4 +135,10 @@ class DepthToSpaceOp : public OpKernel {
TF_CALL_ALL_TYPES(REGISTER);
#undef REGISTER
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(
+ Name("DepthToSpace").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ DepthToSpaceOp<GPUDevice, float>);
+#endif // GOOGLE_CUDA
+
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/depthtospace_op.h b/tensorflow/core/kernels/depthtospace_op.h
new file mode 100644
index 0000000000..60c347d985
--- /dev/null
+++ b/tensorflow/core/kernels/depthtospace_op.h
@@ -0,0 +1,44 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DEPTHTOSPACE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_DEPTHTOSPACE_OP_H_
+// Functor definition for DepthToSpaceOp, must be compilable by nvcc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by DepthToSpaceOp to do the computations.
+template <typename Device, typename T>
+struct DepthToSpaceOpFunctor {
+ // Implements the depth to space conversion.
+ //
+ // input: 4-D input tensor.
+ // block_size: block size for the conversion.
+ // output: 4-D output tensor.
+ //
+ // The dimensions of the tensors are guaranteed to be correct when the
+ // functor is called.
+ void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
+ int block_size, typename TTypes<T, 4>::Tensor output);
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DEPTHTOSPACE_OP_H_
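A small worked example of the mapping the functor implements (illustrative, not from the patch): with block_size = 2, an input of shape [1, 1, 1, 4] holding depth values {d0, d1, d2, d3} becomes an output of shape [1, 2, 2, 1] with

  output(0, 0, 0, 0) = d0    output(0, 0, 1, 0) = d1
  output(0, 1, 0, 0) = d2    output(0, 1, 1, 0) = d3

so each group of block_size^2 input channels is laid out as a block_size x block_size spatial tile. The SpaceToDepth functor declared in spacetodepth_op.h performs the inverse rearrangement.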
diff --git a/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc b/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc
new file mode 100644
index 0000000000..229222c305
--- /dev/null
+++ b/tensorflow/core/kernels/depthtospace_op_gpu.cu.cc
@@ -0,0 +1,88 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/depthtospace_op.h"
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename dtype>
+__global__ void D2S(const int32 nthreads, const dtype* input_ptr,
+ const int block_size, const int batch_size,
+ const int input_height, const int input_width,
+ const int input_depth, const int output_height,
+ const int output_width, const int output_depth,
+ dtype* output_ptr) {
+ CUDA_1D_KERNEL_LOOP(out_idx, nthreads) {
+ // out_idx = d + output_depth * (w + output_width * (h + output_height * b))
+ const int d = out_idx % output_depth;
+ const int out_idx2 = out_idx / output_depth;
+ const int w = out_idx2 % output_width;
+ const int out_idx3 = out_idx2 / output_width;
+ const int h = out_idx3 % output_height;
+ const int b = out_idx3 / output_height;
+
+ const int in_h = h / block_size;
+ const int offset_h = h % block_size;
+ const int in_w = w / block_size;
+ const int offset_w = w % block_size;
+ const int offset_d = (offset_h * block_size + offset_w) * output_depth;
+ const int in_d = d + offset_d;
+ const int inp_idx =
+ in_d + input_depth * (in_w + input_width * (in_h + input_height * b));
+ *(output_ptr + out_idx) = ldg(input_ptr + inp_idx);
+ }
+}
+
+// Specialization of DepthToSpaceOpFunctor for a GPUDevice.
+namespace functor {
+template <typename T>
+struct DepthToSpaceOpFunctor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
+ int block_size, typename TTypes<T, 4>::Tensor output) {
+ const int batch_size = output.dimension(0);
+ const int input_height = input.dimension(1);
+ const int input_width = input.dimension(2);
+ const int input_depth = input.dimension(3);
+ const int output_height = output.dimension(1);
+ const int output_width = output.dimension(2);
+ const int output_depth = output.dimension(3);
+
+ const int total_count =
+ batch_size * output_height * output_width * output_depth;
+ CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
+ D2S<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ config.virtual_thread_count, input.data(), block_size, batch_size,
+ input_height, input_width, input_depth, output_height, output_width,
+ output_depth, output.data());
+ }
+};
+} // end namespace functor
+
+// Instantiate the GPU implementation for float.
+template struct functor::DepthToSpaceOpFunctor<GPUDevice, float>;
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
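To make the flat-index decomposition in D2S concrete, one worked case (illustrative, not from the patch), with output_depth = 1, output_width = 2 and output_height = 2:

  out_idx = 5
  d = 5 % 1 = 0,  out_idx2 = 5 / 1 = 5
  w = 5 % 2 = 1,  out_idx3 = 5 / 2 = 2
  h = 2 % 2 = 0,  b        = 2 / 2 = 1

so the thread handling out_idx 5 writes output(b=1, h=0, w=1, d=0), matching the nesting order of the CPU loops in depthtospace_op.cc.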
diff --git a/tensorflow/core/kernels/image_resizer_state.h b/tensorflow/core/kernels/image_resizer_state.h
new file mode 100644
index 0000000000..776d4d56e1
--- /dev/null
+++ b/tensorflow/core/kernels/image_resizer_state.h
@@ -0,0 +1,111 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is a helper struct to package up the input and output
+// parameters of an image resizer (the heights, widths, etc.). To
+// reduce code duplication and ensure consistency across the different
+// resizers, it performs the input validation.
+
+#ifndef TENSORFLOW_KERNELS_IMAGE_RESIZER_STATE_H_
+#define TENSORFLOW_KERNELS_IMAGE_RESIZER_STATE_H_
+
+#define EIGEN_USE_THREADS
+
+#include <math.h>
+#include <algorithm>
+#include <array>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+
+namespace tensorflow {
+
+struct ImageResizerState {
+ explicit ImageResizerState(bool align_corners)
+ : align_corners_(align_corners) {}
+
+ // ValidateAndCreateOutput checks the bounds on the input tensors
+ // and requested size, sets up some of the resizing state such as the
+ // height_scale and width_scale, and allocates the output.
+ // If any of these operations fails, it sets an error status in
+ // the context, which the caller must check.
+ void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) {
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().DebugString()));
+ const Tensor& shape_t = context->input(1);
+ OP_REQUIRES(context, shape_t.dims() == 1,
+ errors::InvalidArgument("shape_t must be 1-dimensional",
+ shape_t.shape().DebugString()));
+ OP_REQUIRES(context, shape_t.NumElements() == 2,
+ errors::InvalidArgument("shape_t must have two elements",
+ shape_t.shape().DebugString()));
+ auto Svec = shape_t.vec<int32>();
+ batch_size = input.dim_size(0);
+ out_height = internal::SubtleMustCopy(Svec(0));
+ out_width = internal::SubtleMustCopy(Svec(1));
+ OP_REQUIRES(
+ context,
+ FastBoundsCheck(input.dim_size(1), std::numeric_limits<int32>::max()) &&
+ FastBoundsCheck(input.dim_size(2),
+ std::numeric_limits<int32>::max()),
+ errors::InvalidArgument("input sizes must be between 0 and max int32"));
+
+ in_height = static_cast<int32>(input.dim_size(1));
+ in_width = static_cast<int32>(input.dim_size(2));
+ channels = input.dim_size(3);
+ OP_REQUIRES(context, out_height > 0 && out_width > 0,
+ errors::InvalidArgument("output dimensions must be positive"));
+ OP_REQUIRES(
+ context, channels > 0,
+ errors::InvalidArgument("image must have at least one channel"));
+ OP_REQUIRES(
+ context, input.dim_size(1) > 0 && input.dim_size(2) > 0,
+ errors::InvalidArgument("input image must be of non-zero size"));
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({input.dim_size(0), out_height,
+ out_width, input.dim_size(3)}),
+ &output));
+
+ height_scale = (align_corners_ && out_height > 1)
+ ? (in_height - 1) / static_cast<float>(out_height - 1)
+ : in_height / static_cast<float>(out_height);
+ width_scale = (align_corners_ && out_width > 1)
+ ? (in_width - 1) / static_cast<float>(out_width - 1)
+ : in_width / static_cast<float>(out_width);
+ }
+
+ int64 batch_size;
+ int64 out_height;
+ int64 out_width;
+ int64 in_height;
+ int64 in_width;
+ int64 channels;
+ float height_scale;
+ float width_scale;
+ Tensor* output;
+
+ private:
+ bool align_corners_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_IMAGE_RESIZER_STATE_H_
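A hedged sketch of how a resize kernel is expected to use this struct; ResizeExampleOp and its interpolation body are illustrative only, and the sole assumption about the op is that it declares a bool "align_corners" attr, as the existing resize ops do.

class ResizeExampleOp : public OpKernel {
 public:
  explicit ResizeExampleOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("align_corners", &align_corners_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    ImageResizerState st(align_corners_);
    st.ValidateAndCreateOutput(context, input);
    if (!context->status().ok()) return;  // validation or allocation failed

    // st.output is already allocated with shape
    // {batch_size, out_height, out_width, channels}; the scales map output
    // coordinates back into the input, e.g.
    //   in_y = out_y * st.height_scale;
    //   in_x = out_x * st.width_scale;
    // ... interpolation over the input pixels elided ...
  }

 private:
  bool align_corners_;
};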
diff --git a/tensorflow/core/kernels/nn_ops_test.cc b/tensorflow/core/kernels/nn_ops_test.cc
index 05808840f9..dddb8bbb4b 100644
--- a/tensorflow/core/kernels/nn_ops_test.cc
+++ b/tensorflow/core/kernels/nn_ops_test.cc
@@ -492,6 +492,8 @@ static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols,
// OD: output_depth
// KR: kernel_rows
// KC: kernel_cols
+// STR: stride
+// PAD: padding
#define BM_ConvFloatDepthwiseFwd(BS, R, C, ID, DM, OD, KR, KC, STR, PAD, \
LABEL) \
@@ -509,12 +511,25 @@ static void BM_ConvFloatDepthwise(int iters, int batch, int rows, int cols,
strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
} \
+ static void BM_ConvFloatDepthwiseFwdGPU_##LABEL(int iters) { \
+ BM_ConvFloatDepthwise( \
+ iters, BS, R, C, ID, DM, OD, KR, KC, DEPTHWISE_CONV_OP_FWD, 1, STR, \
+ PAD, true, \
+ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", DM, "_", OD, "_", \
+ KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
+ } \
BENCHMARK(BM_ConvFloatDepthwiseFwdCPU1_##LABEL); \
- BENCHMARK(BM_ConvFloatDepthwiseFwdCPU4_##LABEL)
+ BENCHMARK(BM_ConvFloatDepthwiseFwdCPU4_##LABEL); \
+ BENCHMARK(BM_ConvFloatDepthwiseFwdGPU_##LABEL);
-// TODO(andydavis,jmchen) Add more benchmarks.
+// The configurations below are mostly from MobileNet models.
BM_ConvFloatDepthwiseFwd(32, 112, 112, 3, 8, 24, 3, 3, 1, SAME, conv0);
BM_ConvFloatDepthwiseFwd(32, 112, 112, 64, 1, 64, 3, 3, 1, SAME, conv1);
+BM_ConvFloatDepthwiseFwd(32, 56, 56, 128, 1, 128, 3, 3, 1, SAME, conv2);
+BM_ConvFloatDepthwiseFwd(32, 56, 56, 128, 1, 128, 3, 3, 2, SAME, conv3);
+BM_ConvFloatDepthwiseFwd(32, 28, 28, 128, 1, 128, 3, 3, 1, SAME, conv4);
+BM_ConvFloatDepthwiseFwd(32, 14, 14, 512, 1, 512, 3, 3, 1, SAME, conv5);
+BM_ConvFloatDepthwiseFwd(32, 7, 7, 1024, 1, 1024, 3, 3, 1, SAME, conv6);
static void BM_LRNFloat(int iters, int depth, int cols, int rows,
int batch_size, int range, int num_threads,
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index b70c9657b2..899011417f 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -30,147 +30,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-template <typename Device, typename T>
-class ReluOp : public UnaryElementWiseOp<T, ReluOp<Device, T>> {
- public:
- using UnaryElementWiseOp<T, ReluOp<Device, T>>::UnaryElementWiseOp;
-
- void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
- functor::Relu<Device, T> functor;
- functor(context->eigen_device<Device>(), input.flat<T>(),
- output->flat<T>());
- }
-};
-
-// Out of line check to save code space (we have this code once, rather
-// than once for every NDIMS * NumTypes * Num_different_relu_variants
-// functions.
-static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
- const Tensor& a) {
- OP_REQUIRES(context, a.IsSameSize(g),
- errors::InvalidArgument("g and a must be the same size"));
-}
-static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
- const Tensor& a) {
- ValidateSameSizeHelper(context, g, a);
- return context->status().ok();
-}
-
-template <typename Device, typename T>
-class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
- public:
- using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;
-
- void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
- const Tensor& a, Tensor* output);
-
- // INPUTS:
- // g (gradients): backpropagated gradients
- // a (inputs): either the inputs that were passed to ReluOp(), or its
- // outputs (using either one yields the same result here).
- // OUTPUT:
- // gradients to backprop
- template <int NDIMS>
- void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
- Tensor* output) {
- OperateNoTemplate(context, g, a, output);
- }
-};
-
-template <typename Device, typename T>
-void ReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
- const Tensor& g, const Tensor& a,
- Tensor* output) {
- if (!ValidateSameSize(context, g, a)) return;
- functor::ReluGrad<Device, T> functor;
- functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
- output->flat<T>());
-}
-
-template <typename Device, typename T>
-class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
- public:
- using UnaryElementWiseOp<T, Relu6Op<Device, T>>::UnaryElementWiseOp;
-
- void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
- functor::Relu6<Device, T> functor;
- functor(context->eigen_device<Device>(), input.flat<T>(),
- output->flat<T>());
- }
-};
-
-template <typename Device, typename T>
-class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
- public:
- using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;
-
- void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
- const Tensor& a, Tensor* output);
-
- // INPUTS:
- // g (gradients): backpropagated gradients
- // a (inputs): inputs that were passed to Relu6Op()
- // OUTPUT:
- // gradients to backprop
- template <int NDIMS>
- void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
- Tensor* output) {
- OperateNoTemplate(context, g, a, output);
- }
-};
-
-template <typename Device, typename T>
-void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
- const Tensor& g, const Tensor& a,
- Tensor* output) {
- if (!ValidateSameSize(context, g, a)) return;
- functor::Relu6Grad<Device, T> functor;
- functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
- output->flat<T>());
-}
-
-template <typename Device, typename T>
-class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
- public:
- using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
-
- void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
- functor::Elu<Device, T> functor;
- functor(context->eigen_device<Device>(), input.flat<T>(),
- output->flat<T>());
- }
-};
-
-template <typename Device, typename T>
-class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> {
- public:
- using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp;
-
- void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
- const Tensor& a, Tensor* output);
-
- // INPUTS:
- // g (gradients): backpropagated gradients
- // a (outputs): outputs of the EluOp()
- // OUTPUT:
- // gradients to backprop
- template <int NDIMS>
- void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
- Tensor* output) {
- OperateNoTemplate(context, g, a, output);
- }
-};
-
-template <typename Device, typename T>
-void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
- const Tensor& g, const Tensor& a,
- Tensor* output) {
- if (!ValidateSameSize(context, g, a)) return;
- functor::EluGrad<Device, T> functor;
- functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
- output->flat<T>());
-}
-
#define REGISTER_RELU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index b8431aeded..b41be2dfe3 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -13,118 +13,168 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// See docs in ../ops/nn_ops.cc.
+
#ifndef TENSORFLOW_KERNELS_RELU_OP_H_
#define TENSORFLOW_KERNELS_RELU_OP_H_
-// Functor definition for ReluOp and ReluGradOp, must be compilable by nvcc.
+
+#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/relu_op_functor.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
-namespace functor {
-// Functor used by ReluOp to do the computations.
template <typename Device, typename T>
-struct Relu {
- // Computes Relu activation.
- //
- // features: any shape.
- // activations: same shape as "features".
- void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
- typename TTypes<T>::Tensor activations) {
- activations.device(d) = features.cwiseMax(static_cast<T>(0));
+class ReluOp : public UnaryElementWiseOp<T, ReluOp<Device, T>> {
+ public:
+ using UnaryElementWiseOp<T, ReluOp<Device, T>>::UnaryElementWiseOp;
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::Relu<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(),
+ output->flat<T>());
+ }
+};
+
+// Out of line check to save code space (we have this code once, rather
+// than once for every NDIMS * NumTypes * Num_different_relu_variants
+// functions).
+struct ReluHelpers {
+ static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
+ const Tensor& a) {
+ OP_REQUIRES(context, a.IsSameSize(g),
+ errors::InvalidArgument("g and a must be the same size"));
+ }
+ static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
+ const Tensor& a) {
+ ValidateSameSizeHelper(context, g, a);
+ return context->status().ok();
}
};
-// Functor used by ReluGradOp to do the computations.
template <typename Device, typename T>
-struct ReluGrad {
- // Computes ReluGrad backprops.
- //
- // gradients: gradients backpropagated to the Relu op.
- // features: either the inputs that were passed to the Relu or, or its
- // outputs (using either one yields the same result here).
- // backprops: gradients to backpropagate to the Relu inputs.
- void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
- typename TTypes<T>::ConstTensor features,
- typename TTypes<T>::Tensor backprops) {
- // NOTE: When the activation is exactly zero, we do not propagate the
- // associated gradient value. This allows the output of the Relu to be used,
- // as well as its input.
- backprops.device(d) =
- gradients * (features > features.constant(static_cast<T>(0)));
+class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
+ public:
+ using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;
+
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, Tensor* output);
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): either the inputs that were passed to ReluOp(), or its
+ // outputs (using either one yields the same result here).
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OperateNoTemplate(context, g, a, output);
}
};
-// Functor used by Relu6Op to do the computations.
template <typename Device, typename T>
-struct Relu6 {
- // Computes Relu6 activation.
- //
- // features: any shape.
- // activations: same shape as "features".
- void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
- typename TTypes<T>::Tensor activations) {
- activations.device(d) =
- features.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(6));
+void ReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+ functor::ReluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+}
+
+template <typename Device, typename T>
+class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
+ public:
+ using UnaryElementWiseOp<T, Relu6Op<Device, T>>::UnaryElementWiseOp;
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::Relu6<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(),
+ output->flat<T>());
}
};
-// Functor used by ReluGradOp to do the computations.
template <typename Device, typename T>
-struct Relu6Grad {
- // Computes Relu6Grad backprops.
- //
- // gradients: gradients backpropagated to the Relu6 op.
- // features: inputs that where passed to the Relu6 op.
- // backprops: gradients to backpropagate to the Relu6 inputs.
- void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
- typename TTypes<T>::ConstTensor features,
- typename TTypes<T>::Tensor backprops) {
- // NOTE: When the activation is exactly zero or six, we
- // arbitrarily choose to not propagate the associated gradient
- // value.
- backprops.device(d) = gradients *
- (features > features.constant(static_cast<T>(0))) *
- (features < features.constant(static_cast<T>(6)));
+class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
+ public:
+ using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;
+
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, Tensor* output);
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): inputs that were passed to Relu6Op()
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OperateNoTemplate(context, g, a, output);
}
};
-// Functor used by EluOp to do the computations.
template <typename Device, typename T>
-struct Elu {
- // Computes Elu activation.
- //
- // features: any shape.
- // activations: same shape as "features".
- void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
- typename TTypes<T>::Tensor activations) {
- // features.constant(?)
- activations.device(d) =
- (features < static_cast<T>(0))
- .select(features.exp() - features.constant(static_cast<T>(1)),
- features);
+void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+ functor::Relu6Grad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+}
+
+template <typename Device, typename T>
+class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
+ public:
+ using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::Elu<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(),
+ output->flat<T>());
}
};
-// Functor used by EluGradOp to do the computations.
template <typename Device, typename T>
-struct EluGrad {
- // Computes EluGrad backprops.
- //
- // gradients: gradients backpropagated to the Elu op.
- // activations: outputs of the Elu op.
- // backprops: gradients to backpropagate to the Elu inputs.
- void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
- typename TTypes<T>::ConstTensor activations,
- typename TTypes<T>::Tensor backprops) {
- backprops.device(d) =
- (activations < static_cast<T>(0))
- .select((activations + static_cast<T>(1)) * gradients, gradients);
+class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> {
+ public:
+ using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp;
+
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, Tensor* output);
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (outputs): outputs of the EluOp()
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OperateNoTemplate(context, g, a, output);
}
};
-} // namespace functor
+template <typename Device, typename T>
+void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+ functor::EluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+}
+
} // namespace tensorflow
+#undef EIGEN_USE_THREADS
+
#endif // TENSORFLOW_KERNELS_RELU_OP_H_
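
With this split, relu_op.h keeps only thin OpKernel wrappers while the Eigen expressions move to relu_op_functor.h, which nvcc can also compile (relu_op_gpu.cu.cc now includes only the functor header). As a hedged sketch of the pattern, a hypothetical extra activation would slot in as below; the LeakyRelu name, its 0.2 slope, and both types are illustrative only and assume the same includes as relu_op.h above.

namespace tensorflow {
namespace functor {
// Hypothetical functor, shaped like functor::Relu above; it would live in the
// *_functor.h header so the GPU build can instantiate it too.
template <typename Device, typename T>
struct LeakyRelu {
  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
                  typename TTypes<T>::Tensor activations) {
    // max(x, 0.2 * x), evaluated on the given device.
    activations.device(d) = features.cwiseMax(features * static_cast<T>(0.2));
  }
};
}  // namespace functor

// Hypothetical kernel wrapper, shaped like ReluOp above; it would live in the
// op header and be registered with REGISTER_KERNEL_BUILDER as usual.
template <typename Device, typename T>
class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, LeakyReluOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::LeakyRelu<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};
}  // namespace tensorflow
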
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
new file mode 100644
index 0000000000..5d732a6141
--- /dev/null
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -0,0 +1,130 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_
+#define TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_
+// Functor definition for ReluOp and ReluGradOp, must be compilable by nvcc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by ReluOp to do the computations.
+template <typename Device, typename T>
+struct Relu {
+ // Computes Relu activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor activations) {
+ activations.device(d) = features.cwiseMax(static_cast<T>(0));
+ }
+};
+
+// Functor used by ReluGradOp to do the computations.
+template <typename Device, typename T>
+struct ReluGrad {
+ // Computes ReluGrad backprops.
+ //
+ // gradients: gradients backpropagated to the Relu op.
+  // features: either the inputs that were passed to the Relu op, or its
+ // outputs (using either one yields the same result here).
+ // backprops: gradients to backpropagate to the Relu inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor backprops) {
+ // NOTE: When the activation is exactly zero, we do not propagate the
+ // associated gradient value. This allows the output of the Relu to be used,
+ // as well as its input.
+ backprops.device(d) =
+ gradients * (features > features.constant(static_cast<T>(0)));
+ }
+};
+
+// Functor used by Relu6Op to do the computations.
+template <typename Device, typename T>
+struct Relu6 {
+ // Computes Relu6 activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor activations) {
+ activations.device(d) =
+ features.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(6));
+ }
+};
+
+// Functor used by Relu6GradOp to do the computations.
+template <typename Device, typename T>
+struct Relu6Grad {
+ // Computes Relu6Grad backprops.
+ //
+ // gradients: gradients backpropagated to the Relu6 op.
+  // features: inputs that were passed to the Relu6 op.
+ // backprops: gradients to backpropagate to the Relu6 inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor backprops) {
+ // NOTE: When the activation is exactly zero or six, we
+ // arbitrarily choose to not propagate the associated gradient
+ // value.
+ backprops.device(d) = gradients *
+ (features > features.constant(static_cast<T>(0))) *
+ (features < features.constant(static_cast<T>(6)));
+ }
+};
+
+// Functor used by EluOp to do the computations.
+template <typename Device, typename T>
+struct Elu {
+ // Computes Elu activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor activations) {
+ // features.constant(?)
+ activations.device(d) =
+ (features < static_cast<T>(0))
+ .select(features.exp() - features.constant(static_cast<T>(1)),
+ features);
+ }
+};
+
+// Functor used by EluGradOp to do the computations.
+template <typename Device, typename T>
+struct EluGrad {
+ // Computes EluGrad backprops.
+ //
+ // gradients: gradients backpropagated to the Elu op.
+ // activations: outputs of the Elu op.
+ // backprops: gradients to backpropagate to the Elu inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor activations,
+ typename TTypes<T>::Tensor backprops) {
+ backprops.device(d) =
+ (activations < static_cast<T>(0))
+ .select((activations + static_cast<T>(1)) * gradients, gradients);
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_
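
One detail worth spelling out: EluGrad receives the saved activations, not the original inputs, and reconstructs the derivative from them. For x < 0, d/dx elu(x) = exp(x) = elu(x) + 1, which is exactly the (activations + 1) * gradients branch above. A standalone numeric check of that identity (plain C++, independent of the functors):

#include <cassert>
#include <cmath>
#include <cstdio>

// elu(x) = x for x >= 0, exp(x) - 1 for x < 0.
static double Elu(double x) { return x < 0.0 ? std::exp(x) - 1.0 : x; }

// Analytic derivative: 1 for x >= 0, exp(x) for x < 0.
static double EluGradFromInput(double x) { return x < 0.0 ? std::exp(x) : 1.0; }

// Derivative recovered from the activation, as the functor does:
// (activation + 1) for activation < 0, 1 otherwise.
static double EluGradFromActivation(double a) { return a < 0.0 ? a + 1.0 : 1.0; }

int main() {
  for (double x = -3.0; x <= 3.0; x += 0.25) {
    const double a = Elu(x);
    assert(std::abs(EluGradFromInput(x) - EluGradFromActivation(a)) < 1e-12);
  }
  std::printf("EluGrad identity holds on the sampled range.\n");
  return 0;
}
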
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index 6451619768..0a12c854b8 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <stdio.h>
-#include "tensorflow/core/kernels/relu_op.h"
+#include "tensorflow/core/kernels/relu_op_functor.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
diff --git a/tensorflow/core/kernels/resize_area_op.cc b/tensorflow/core/kernels/resize_area_op.cc
index 5c4b1cdb12..5bf064f159 100644
--- a/tensorflow/core/kernels/resize_area_op.cc
+++ b/tensorflow/core/kernels/resize_area_op.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
@@ -40,49 +41,22 @@ class ResizeAreaOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- const Tensor& shape_t = context->input(1);
- OP_REQUIRES(context, shape_t.dims() == 1,
- errors::InvalidArgument("shape_t must be 1-dimensional",
- shape_t.shape().DebugString()));
- OP_REQUIRES(context, shape_t.NumElements() == 2,
- errors::InvalidArgument("shape_t must have two elements",
- shape_t.shape().DebugString()));
-
- auto Svec = shape_t.vec<int32>();
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- 0, TensorShape({input.dim_size(0), Svec(0),
- Svec(1), input.dim_size(3)}),
- &output));
- const int64 batch_size = input.dim_size(0);
- const int64 in_height = input.dim_size(1);
- const int64 in_width = input.dim_size(2);
- const int64 channels = input.dim_size(3);
- const int64 out_height = output->dim_size(1);
- const int64 out_width = output->dim_size(2);
+ ImageResizerState st(align_corners_);
+ st.ValidateAndCreateOutput(context, input);
+
+ if (!context->status().ok()) return;
typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
- typename TTypes<float, 4>::Tensor output_data = output->tensor<float, 4>();
+ typename TTypes<float, 4>::Tensor output_data =
+ st.output->tensor<float, 4>();
// A temporary tensor for computing the sum.
Tensor sum_tensor;
- OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<float>::value,
- TensorShape({channels}), &sum_tensor));
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<float>::value,
+ TensorShape({st.channels}),
+ &sum_tensor));
typename TTypes<float, 1>::Tensor sum_data = sum_tensor.vec<float>();
- const float height_scale =
- (align_corners_ && out_height > 1)
- ? (in_height - 1) / static_cast<float>(out_height - 1)
- : in_height / static_cast<float>(out_height);
- const float width_scale =
- (align_corners_ && out_width > 1)
- ? (in_width - 1) / static_cast<float>(out_width - 1)
- : in_width / static_cast<float>(out_width);
-
// When using this algorithm for downsizing, the target pixel value is the
// weighted average of all the source pixels. The weight is determined by
// the contribution percentage of the source pixel.
@@ -102,19 +76,19 @@ class ResizeAreaOp : public OpKernel {
// out[0] = (in[0] * 1.0 + in[1] * 1/3) * scale
    //   out[1] = (in[1] * 2/3 + in[2] * 2/3) * scale
    //   out[2] = (in[2] * 1/3 + in[3] * 1.0) * scale
- float scale = 1.0 / (height_scale * width_scale);
- for (int64 b = 0; b < batch_size; ++b) {
- for (int64 y = 0; y < out_height; ++y) {
- const float in_y = y * height_scale;
- const float in_y1 = (y + 1) * height_scale;
+ float scale = 1.0 / (st.height_scale * st.width_scale);
+ for (int64 b = 0; b < st.batch_size; ++b) {
+ for (int64 y = 0; y < st.out_height; ++y) {
+ const float in_y = y * st.height_scale;
+ const float in_y1 = (y + 1) * st.height_scale;
// The start and end height indices of all the cells that could
// contribute to the target cell.
int64 y_start = floor(in_y);
int64 y_end = ceil(in_y1);
- for (int64 x = 0; x < out_width; ++x) {
- const float in_x = x * width_scale;
- const float in_x1 = (x + 1) * width_scale;
+ for (int64 x = 0; x < st.out_width; ++x) {
+ const float in_x = x * st.width_scale;
+ const float in_x1 = (x + 1) * st.width_scale;
// The start and end width indices of all the cells that could
// contribute to the target cell.
int64 x_start = floor(in_x);
@@ -127,16 +101,16 @@ class ResizeAreaOp : public OpKernel {
for (int64 j = x_start; j < x_end; ++j) {
float scale_x =
j < in_x ? j + 1 - in_x : (j + 1 > in_x1 ? in_x1 - j : 1.0);
- for (int64 c = 0; c < channels; ++c) {
+ for (int64 c = 0; c < st.channels; ++c) {
#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val))))
- sum_data(c) +=
- input_data(b, BOUND(i, in_height), BOUND(j, in_width), c) *
- scale_y * scale_x * scale;
+ sum_data(c) += input_data(b, BOUND(i, st.in_height),
+ BOUND(j, st.in_width), c) *
+ scale_y * scale_x * scale;
#undef BOUND
}
}
}
- for (int64 c = 0; c < channels; ++c) {
+ for (int64 c = 0; c < st.channels; ++c) {
output_data(b, y, x, c) = sum_data(c);
}
}
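
ResizeAreaOp is the first of four kernels in this change that replace their hand-rolled validation and scale computation with the shared ImageResizerState helper from the new image_resizer_state.h. The scale rule itself is visible in the removed lines; a minimal standalone restatement of it (the CalculateScale name is ours, not the library's):

#include <cstdio>

// Mirrors the removed per-kernel scale computation: with align_corners the
// first and last source pixels map exactly onto the first and last output
// pixels, so the scale becomes (in - 1) / (out - 1).
static float CalculateScale(long long in_size, long long out_size,
                            bool align_corners) {
  return (align_corners && out_size > 1)
             ? (in_size - 1) / static_cast<float>(out_size - 1)
             : in_size / static_cast<float>(out_size);
}

int main() {
  // Upsampling 4 -> 7 pixels per axis.
  std::printf("%f\n", CalculateScale(4, 7, false));  // 0.571429
  std::printf("%f\n", CalculateScale(4, 7, true));   // 0.500000
  return 0;
}
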
diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc
index f81383984b..ce6c920bd1 100644
--- a/tensorflow/core/kernels/resize_bicubic_op.cc
+++ b/tensorflow/core/kernels/resize_bicubic_op.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
@@ -92,62 +93,28 @@ class ResizeBicubicOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- const Tensor& shape_t = context->input(1);
- OP_REQUIRES(context, shape_t.dims() == 1,
- errors::InvalidArgument("shape_t must be 1-dimensional",
- shape_t.shape().DebugString()));
- OP_REQUIRES(context, shape_t.NumElements() == 2,
- errors::InvalidArgument("shape_t must have two elements",
- shape_t.shape().DebugString()));
-
- auto Svec = shape_t.vec<int32>();
- // Initialize shape to the batch size of the input, then add
- // the rest of the dimensions
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- 0, TensorShape({input.dim_size(0), Svec(0),
- Svec(1), input.dim_size(3)}),
- &output));
- const int64 batch_size = input.dim_size(0);
- const int64 in_height = input.dim_size(1);
- const int64 in_width = input.dim_size(2);
- const int64 channels = input.dim_size(3);
- const int64 out_height = output->dim_size(1);
- const int64 out_width = output->dim_size(2);
- CHECK_GT(in_height, 0);
- CHECK_GT(in_width, 0);
- CHECK_GT(channels, 0);
- CHECK_GT(out_height, 0);
- CHECK_GT(out_width, 0);
+ ImageResizerState st(align_corners_);
+ st.ValidateAndCreateOutput(context, input);
- typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
- typename TTypes<float, 4>::Tensor output_data = output->tensor<float, 4>();
+ if (!context->status().ok()) return;
- const float height_scale =
- (align_corners_ && out_height > 1)
- ? (in_height - 1) / static_cast<float>(out_height - 1)
- : in_height / static_cast<float>(out_height);
- const float width_scale =
- (align_corners_ && out_width > 1)
- ? (in_width - 1) / static_cast<float>(out_width - 1)
- : in_width / static_cast<float>(out_width);
+ typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
+ typename TTypes<float, 4>::Tensor output_data =
+ st.output->tensor<float, 4>();
std::array<float, 4> coeff = {{0.0, 0.0, 0.0, 0.0}};
- for (int64 b = 0; b < batch_size; ++b) {
- for (int64 y = 0; y < out_height; ++y) {
+ for (int64 b = 0; b < st.batch_size; ++b) {
+ for (int64 y = 0; y < st.out_height; ++y) {
std::array<float, 4> y_weights;
std::array<int64, 4> y_indices;
- GetWeightsAndIndices(height_scale, y, in_height, &y_weights,
+ GetWeightsAndIndices(st.height_scale, y, st.in_height, &y_weights,
&y_indices);
- for (int64 x = 0; x < out_width; ++x) {
+ for (int64 x = 0; x < st.out_width; ++x) {
std::array<float, 4> x_weights;
std::array<int64, 4> x_indices;
- GetWeightsAndIndices(width_scale, x, in_width, &x_weights,
+ GetWeightsAndIndices(st.width_scale, x, st.in_width, &x_weights,
&x_indices);
- for (int64 c = 0; c < channels; ++c) {
+ for (int64 c = 0; c < st.channels; ++c) {
// Use a 4x4 patch to compute the interpolated output value at
// (b, y, x, c).
for (int64 i = 0; i < 4; ++i) {
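
For the bicubic kernel, GetWeightsAndIndices supplies four weights and four clamped source indices per axis, and the 4x4 patch around each output pixel is reduced with them. The loop body is cut off in this hunk, so the exact reduction order below is an assumption and the weights are placeholders:

#include <array>
#include <cstdio>

// Combine a 4x4 patch with per-axis weight vectors: collapse each row of the
// patch along x, then collapse the four row results along y.
static float Interpolate4x4(const std::array<std::array<float, 4>, 4>& patch,
                            const std::array<float, 4>& y_weights,
                            const std::array<float, 4>& x_weights) {
  float value = 0.0f;
  for (int i = 0; i < 4; ++i) {
    float row = 0.0f;
    for (int j = 0; j < 4; ++j) row += x_weights[j] * patch[i][j];
    value += y_weights[i] * row;
  }
  return value;
}

int main() {
  // Placeholder weights that sum to 1; a constant patch must interpolate to
  // the same constant regardless of the particular weights.
  const std::array<float, 4> w = {0.1f, 0.4f, 0.4f, 0.1f};
  std::array<std::array<float, 4>, 4> patch;
  for (auto& row : patch) row = {5.f, 5.f, 5.f, 5.f};
  std::printf("%f\n", Interpolate4x4(patch, w, w));  // 5.000000
  return 0;
}
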
diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc
index ebf9532d7b..bdf60f4c4f 100644
--- a/tensorflow/core/kernels/resize_bilinear_op.cc
+++ b/tensorflow/core/kernels/resize_bilinear_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
@@ -39,64 +40,29 @@ class ResizeBilinearOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- const Tensor& shape_t = context->input(1);
- OP_REQUIRES(context, shape_t.dims() == 1,
- errors::InvalidArgument("shape_t must be 1-dimensional",
- shape_t.shape().DebugString()));
- OP_REQUIRES(context, shape_t.NumElements() == 2,
- errors::InvalidArgument("shape_t must have two elements",
- shape_t.shape().DebugString()));
-
- auto Svec = shape_t.vec<int32>();
- // Initialize shape to the batch size of the input, then add
- // the rest of the dimensions
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- 0, TensorShape({input.dim_size(0), Svec(0),
- Svec(1), input.dim_size(3)}),
- &output));
+ ImageResizerState st(align_corners_);
+ st.ValidateAndCreateOutput(context, input);
- const int64 batch_size = input.dim_size(0);
- const int64 in_height = input.dim_size(1);
- const int64 in_width = input.dim_size(2);
- const int64 channels = input.dim_size(3);
- const int64 out_height = output->dim_size(1);
- const int64 out_width = output->dim_size(2);
- CHECK_GT(in_height, 0);
- CHECK_GT(in_width, 0);
- CHECK_GT(channels, 0);
- CHECK_GT(out_height, 0);
- CHECK_GT(out_width, 0);
+ if (!context->status().ok()) return;
typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
- typename TTypes<float, 4>::Tensor output_data = output->tensor<float, 4>();
+ typename TTypes<float, 4>::Tensor output_data =
+ st.output->tensor<float, 4>();
- const float height_scale =
- (align_corners_ && out_height > 1)
- ? (in_height - 1) / static_cast<float>(out_height - 1)
- : in_height / static_cast<float>(out_height);
- const float width_scale =
- (align_corners_ && out_width > 1)
- ? (in_width - 1) / static_cast<float>(out_width - 1)
- : in_width / static_cast<float>(out_width);
-
- for (int b = 0; b < batch_size; ++b) {
- for (int y = 0; y < out_height; ++y) {
- const float in_y = y * height_scale;
+ for (int b = 0; b < st.batch_size; ++b) {
+ for (int y = 0; y < st.out_height; ++y) {
+ const float in_y = y * st.height_scale;
const int top_y_index = static_cast<int>(floorf(in_y));
const int bottom_y_index =
- std::min(static_cast<int64>(ceilf(in_y)), (in_height - 1));
+ std::min(static_cast<int64>(ceilf(in_y)), (st.in_height - 1));
const float y_lerp = in_y - top_y_index;
- for (int x = 0; x < out_width; ++x) {
- const float in_x = x * width_scale;
+ for (int x = 0; x < st.out_width; ++x) {
+ const float in_x = x * st.width_scale;
const int left_x_index = static_cast<int>(floorf(in_x));
const int right_x_index =
- std::min(static_cast<int64>(ceilf(in_x)), (in_width - 1));
+ std::min(static_cast<int64>(ceilf(in_x)), (st.in_width - 1));
const float x_lerp = in_x - left_x_index;
- for (int c = 0; c < channels; ++c) {
+ for (int c = 0; c < st.channels; ++c) {
const float top_left = input_data(b, top_y_index, left_x_index, c);
const float top_right =
input_data(b, top_y_index, right_x_index, c);
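
The hunk ends just before the blend itself; x_lerp and y_lerp are the fractional offsets and the four fetched neighbours are combined with the standard bilinear formula. A minimal standalone version of that final step (the kernel plumbing around it is omitted):

#include <cstdio>

// Standard bilinear blend of the four neighbouring pixels; x_lerp and y_lerp
// are the fractional parts of the source coordinate, as computed above.
static float BilinearLerp(float top_left, float top_right, float bottom_left,
                          float bottom_right, float x_lerp, float y_lerp) {
  const float top = top_left + (top_right - top_left) * x_lerp;
  const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
  return top + (bottom - top) * y_lerp;
}

int main() {
  // Sampling exactly halfway between four pixels averages them.
  std::printf("%f\n", BilinearLerp(0.f, 1.f, 2.f, 3.f, 0.5f, 0.5f));  // 1.5
  return 0;
}
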
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
index 281e0feb39..61b89fb9a5 100644
--- a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
@@ -44,56 +45,28 @@ class ResizeNearestNeighborOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- const Tensor& shape_t = context->input(1);
- OP_REQUIRES(context, shape_t.dims() == 1,
- errors::InvalidArgument("shape_t must be 1-dimensional",
- shape_t.shape().DebugString()));
- OP_REQUIRES(context, shape_t.NumElements() == 2,
- errors::InvalidArgument("shape_t must have two elements",
- shape_t.shape().DebugString()));
+ ImageResizerState st(align_corners_);
+ st.ValidateAndCreateOutput(context, input);
- auto sizes = shape_t.vec<int32>();
- OP_REQUIRES(context, sizes(0) > 0 && sizes(1) > 0,
- errors::InvalidArgument("shape_t's elements must be positive"));
-
- // Initialize shape to the batch size of the input, then add
- // the rest of the dimensions
- Tensor* output = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output(0, TensorShape({input.dim_size(0), sizes(0),
- sizes(1), input.dim_size(3)}),
- &output));
+ if (!context->status().ok()) return;
- const int64 batch_size = input.dim_size(0);
- const int64 in_height = input.dim_size(1);
- const int64 in_width = input.dim_size(2);
- const int64 channels = input.dim_size(3);
- const int64 out_height = output->dim_size(1);
- const int64 out_width = output->dim_size(2);
+ OP_REQUIRES(context, st.in_height < (1 << 24) && st.in_width < (1 << 24),
+ errors::InvalidArgument("nearest neighbor requires max height "
+ "& width of 2^24"));
typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
- typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>();
-
- const float height_scale =
- (align_corners_ && out_height > 1)
- ? (in_height - 1) / static_cast<float>(out_height - 1)
- : in_height / static_cast<float>(out_height);
- const float width_scale =
- (align_corners_ && out_width > 1)
- ? (in_width - 1) / static_cast<float>(out_width - 1)
- : in_width / static_cast<float>(out_width);
-
- for (int b = 0; b < batch_size; ++b) {
- for (int y = 0; y < out_height; ++y) {
- const int in_y = std::min(static_cast<int64>(floorf(y * height_scale)),
- (in_height - 1));
- for (int x = 0; x < out_width; ++x) {
- const int in_x = std::min(static_cast<int64>(floorf(x * width_scale)),
- (in_width - 1));
- for (int c = 0; c < channels; ++c) {
+ typename TTypes<T, 4>::Tensor output_data = st.output->tensor<T, 4>();
+
+ for (int b = 0; b < st.batch_size; ++b) {
+ for (int y = 0; y < st.out_height; ++y) {
+ const int in_y =
+ std::min(static_cast<int64>(floorf(y * st.height_scale)),
+ (st.in_height - 1));
+ for (int x = 0; x < st.out_width; ++x) {
+ const int in_x =
+ std::min(static_cast<int64>(floorf(x * st.width_scale)),
+ (st.in_width - 1));
+ for (int c = 0; c < st.channels; ++c) {
output_data(b, y, x, c) = input_data(b, in_y, in_x, c);
}
}
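
The new OP_REQUIRES caps the input extent at 2^24, which is where float stops representing consecutive integers exactly; presumably the concern is that the coordinate math above goes through floorf on float products, so larger indices could collapse onto the same source pixel. A two-line demonstration of that float limit:

#include <cstdio>

int main() {
  const float below = 1 << 24;        // 16777216 is exactly representable.
  const float above = (1 << 24) + 1;  // 16777217 rounds back to 16777216.
  std::printf("%.1f %.1f equal=%d\n", below, above, below == above);  // equal=1
  return 0;
}
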
diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc
index 038efbe31a..305a91fecf 100644
--- a/tensorflow/core/kernels/softmax_op.cc
+++ b/tensorflow/core/kernels/softmax_op.cc
@@ -28,29 +28,6 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-template <typename Device, typename T>
-class SoftmaxOp : public OpKernel {
- public:
- explicit SoftmaxOp(OpKernelConstruction* context) : OpKernel(context) {
- log_ = StringPiece(name()).starts_with("Log");
- }
-
- void Compute(OpKernelContext* context) override {
- const Tensor& logits_in = context->input(0);
- OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
- errors::InvalidArgument("logits must be 2-dimensional"));
- Tensor* softmax_out = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output(0, logits_in.shape(), &softmax_out));
- functor::SoftmaxFunctor<Device, T> functor;
- functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
- softmax_out->matrix<T>(), log_);
- }
-
- private:
- bool log_;
-};
-
// Partial specialization for a CPUDevice, that uses the Eigen implementation
// from SoftmaxEigenImpl.
namespace functor {
diff --git a/tensorflow/core/kernels/softmax_op.h b/tensorflow/core/kernels/softmax_op.h
index 6e0064bd5b..df78f85cc2 100644
--- a/tensorflow/core/kernels/softmax_op.h
+++ b/tensorflow/core/kernels/softmax_op.h
@@ -13,89 +13,48 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// See docs in ../ops/nn_ops.cc.
+
#ifndef TENSORFLOW_KERNELS_SOFTMAX_OP_H_
#define TENSORFLOW_KERNELS_SOFTMAX_OP_H_
-// Functor definition for SoftmaxOp, must be compilable by nvcc.
+
+#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/softmax_op_functor.h"
namespace tensorflow {
-namespace functor {
-
-// Functor used by SoftmaxOp to do the computations.
-template <typename Device, typename T>
-struct SoftmaxFunctor {
- // Computes Softmax or LogSoftmax activation.
- //
- // logits: dim: batch_size, num_classes.
- // softmax: dims: batch_size, num_classes.
- // log: boolean
- void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
- typename TTypes<T>::Matrix softmax, const bool log);
-};
-// Eigen code implementing SoftmaxFunctor::operator() or
-// LogSoftmaxFunctor::operator().
-// This code works for both CPU and GPU and is used by the functor
-// specializations for both device types.
template <typename Device, typename T>
-struct SoftmaxEigenImpl {
- static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits,
- typename TTypes<T>::Matrix softmax, const bool log) {
- const int kBatchDim = 0;
- const int kClassDim = 1;
-
- const int batch_size = logits.dimension(kBatchDim);
- const int num_classes = logits.dimension(kClassDim);
+class SoftmaxOp : public OpKernel {
+ public:
+ explicit SoftmaxOp(OpKernelConstruction* context) : OpKernel(context) {
+ log_ = StringPiece(name()).starts_with("Log");
+ }
-// These arrays are used to reduce along the class dimension, and broadcast
-// the resulting value to all classes.
-#if !defined(EIGEN_HAS_INDEX_LIST)
- Eigen::DSizes<int, 1> along_class(kClassDim);
- Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
- Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-#else
- Eigen::IndexList<Eigen::type2index<kClassDim> > along_class;
- Eigen::IndexList<Eigen::type2index<1> > depth_dim;
- Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one;
- batch_by_one.set(0, batch_size);
- Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
- one_by_class.set(1, num_classes);
-#endif
- //shifted_logits = logits - max(logits along classes);
- auto shifted_logits = (logits - logits.maximum(along_class)
- .eval()
- .reshape(batch_by_one)
- .broadcast(one_by_class));
- if (log) {
- // Calculate the log of the softmax
- // softmax = logits - max(logits along classes);
- softmax.device(d) = shifted_logits;
- // softmax = softmax - log(sum(exp(softmax along classes)));
- softmax.device(d) = (softmax -
- softmax.exp().sum(along_class)
- .eval()
- .reshape(batch_by_one)
- .broadcast(one_by_class)
- .log());
- } else {
- // NOTE(touts): If you modify this implementation please run
- // the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc.
- //
- // softmax = exp(logits - max(logits along classes));
- softmax.device(d) = shifted_logits.exp();
- // softmax = softmax / sum(softmax along classes);
- softmax.device(d) = (softmax /
- softmax.sum(along_class)
- .eval()
- .reshape(batch_by_one)
- .broadcast(one_by_class));
+ void Compute(OpKernelContext* context) override {
+ const Tensor& logits_in = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
+ errors::InvalidArgument("logits must be 2-dimensional"));
+ Tensor* softmax_out = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_output(0, logits_in.shape(), &softmax_out));
+ if (logits_in.NumElements()) {
+ functor::SoftmaxFunctor<Device, T> functor;
+ functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
+ softmax_out->matrix<T>(), log_);
}
}
+
+ private:
+ bool log_;
};
-} // namespace functor
} // namespace tensorflow
+#undef EIGEN_USE_THREADS
+
#endif // TENSORFLOW_KERNELS_SOFTMAX_OP_H_
diff --git a/tensorflow/core/kernels/softmax_op_functor.h b/tensorflow/core/kernels/softmax_op_functor.h
new file mode 100644
index 0000000000..47bb9de411
--- /dev/null
+++ b/tensorflow/core/kernels/softmax_op_functor.h
@@ -0,0 +1,101 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_SOFTMAX_OP_FUNCTOR_H_
+#define TENSORFLOW_KERNELS_SOFTMAX_OP_FUNCTOR_H_
+// Functor definition for SoftmaxOp, must be compilable by nvcc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by SoftmaxOp to do the computations.
+template <typename Device, typename T>
+struct SoftmaxFunctor {
+ // Computes Softmax or LogSoftmax activation.
+ //
+ // logits: dim: batch_size, num_classes.
+ // softmax: dims: batch_size, num_classes.
+ // log: boolean
+ void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<T>::Matrix softmax, const bool log);
+};
+
+// Eigen code implementing SoftmaxFunctor::operator() or
+// LogSoftmaxFunctor::operator().
+// This code works for both CPU and GPU and is used by the functor
+// specializations for both device types.
+template <typename Device, typename T>
+struct SoftmaxEigenImpl {
+ static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<T>::Matrix softmax, const bool log) {
+ const int kBatchDim = 0;
+ const int kClassDim = 1;
+
+ const int batch_size = logits.dimension(kBatchDim);
+ const int num_classes = logits.dimension(kClassDim);
+
+// These arrays are used to reduce along the class dimension, and broadcast
+// the resulting value to all classes.
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::DSizes<int, 1> along_class(kClassDim);
+ Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+ Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+#else
+ Eigen::IndexList<Eigen::type2index<kClassDim> > along_class;
+ Eigen::IndexList<Eigen::type2index<1> > depth_dim;
+ Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one;
+ batch_by_one.set(0, batch_size);
+ Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
+ one_by_class.set(1, num_classes);
+#endif
+ //shifted_logits = logits - max(logits along classes);
+ auto shifted_logits = (logits - logits.maximum(along_class)
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class));
+ if (log) {
+ // Calculate the log of the softmax
+ // softmax = logits - max(logits along classes);
+ softmax.device(d) = shifted_logits;
+ // softmax = softmax - log(sum(exp(softmax along classes)));
+ softmax.device(d) = (softmax -
+ softmax.exp().sum(along_class)
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class)
+ .log());
+ } else {
+ // NOTE(touts): If you modify this implementation please run
+ // the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc.
+ //
+ // softmax = exp(logits - max(logits along classes));
+ softmax.device(d) = shifted_logits.exp();
+ // softmax = softmax / sum(softmax along classes);
+ softmax.device(d) = (softmax /
+ softmax.sum(along_class)
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class));
+ }
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_SOFTMAX_OP_FUNCTOR_H_
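
The shifted_logits step is the usual numerical-stability trick: subtracting the per-row maximum before exponentiating changes nothing mathematically, because the exp(max) factor cancels in the ratio, but it keeps exp() out of overflow territory. A scalar restatement of the same computation, outside Eigen:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Softmax over one row, with the same max-subtraction used by
// SoftmaxEigenImpl; without it, exp(1000) would overflow to inf.
static std::vector<double> Softmax(const std::vector<double>& logits) {
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<double> out(logits.size());
  double sum = 0.0;
  for (size_t i = 0; i < logits.size(); ++i) {
    out[i] = std::exp(logits[i] - max_logit);
    sum += out[i];
  }
  for (double& v : out) v /= sum;
  return out;
}

int main() {
  for (double p : Softmax({1000.0, 1001.0, 1002.0})) std::printf("%f ", p);
  std::printf("\n");  // 0.090031 0.244728 0.665241
  return 0;
}
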
diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
index 0bfc27d32b..e27fff9b92 100644
--- a/tensorflow/core/kernels/softmax_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "tensorflow/core/kernels/softmax_op.h"
+#include "tensorflow/core/kernels/softmax_op_functor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/spacetodepth_op.cc b/tensorflow/core/kernels/spacetodepth_op.cc
index 4f9a71ce90..9b6bb19ee8 100644
--- a/tensorflow/core/kernels/spacetodepth_op.cc
+++ b/tensorflow/core/kernels/spacetodepth_op.cc
@@ -21,6 +21,8 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/core/kernels/spacetodepth_op.h"
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -89,27 +91,43 @@ class SpaceToDepthOp : public OpKernel {
auto Toutput = outputs_tensor->tensor<T, 4>();
auto Tinput = input.tensor<T, 4>();
+ functor::SpaceToDepthOpFunctor<Device, T> functor;
+ functor(context->eigen_device<Device>(), Tinput, block_size_, Toutput);
+ };
+
+ private:
+ int block_size_;
+};
+
+// Partial specialization of SpaceToDepthOpFunctor for a CPUDevice.
+namespace functor {
+template <typename T>
+struct SpaceToDepthOpFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
+ int block_size, typename TTypes<T, 4>::Tensor output) {
+ const int batch_size = output.dimension(0);
+ const int input_height = input.dimension(1);
+ const int input_width = input.dimension(2);
+ const int input_depth = input.dimension(3);
+
for (int b = 0; b < batch_size; ++b) {
- for (int h = 0; h < height; ++h) {
- const int out_h = h / block_size_;
- const int offset_h = (h % block_size_);
- for (int w = 0; w < width; ++w) {
- const int out_w = w / block_size_;
- const int offset_w = (w % block_size_);
- const int offset_d =
- (offset_h * block_size_ + offset_w) * input_depth;
+ for (int h = 0; h < input_height; ++h) {
+ const int out_h = h / block_size;
+ const int offset_h = (h % block_size);
+ for (int w = 0; w < input_width; ++w) {
+ const int out_w = w / block_size;
+ const int offset_w = (w % block_size);
+ const int offset_d = (offset_h * block_size + offset_w) * input_depth;
for (int d = 0; d < input_depth; ++d) {
const int out_d = d + offset_d;
- Toutput(b, out_h, out_w, out_d) = Tinput(b, h, w, d);
+ output(b, out_h, out_w, out_d) = input(b, h, w, d);
}
}
}
}
- };
-
- private:
- int block_size_;
+ }
};
+} // namespace functor
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
@@ -119,4 +137,10 @@ class SpaceToDepthOp : public OpKernel {
TF_CALL_ALL_TYPES(REGISTER);
#undef REGISTER
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(
+ Name("SpaceToDepth").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ SpaceToDepthOp<GPUDevice, float>);
+#endif // GOOGLE_CUDA
+
} // end namespace tensorflow
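
The CPU specialization above is a direct loop over the mapping output(b, h/bs, w/bs, d + ((h%bs)*bs + w%bs)*depth) = input(b, h, w, d). A tiny standalone run of that index arithmetic, using a 2x2x1 spatial block with block_size 2 so the whole block folds into the depth dimension:

#include <cstdio>

int main() {
  // Input: 1 x 2 x 2 x 1 (NHWC), values 1..4 laid out row-major.
  const int input[2][2] = {{1, 2}, {3, 4}};
  const int block_size = 2, input_depth = 1;
  int output[4] = {0, 0, 0, 0};  // 1 x 1 x 1 x 4 after the transform.

  for (int h = 0; h < 2; ++h) {
    for (int w = 0; w < 2; ++w) {
      // Same offset computation as SpaceToDepthOpFunctor<CPUDevice, T>.
      const int offset_d =
          ((h % block_size) * block_size + (w % block_size)) * input_depth;
      output[offset_d] = input[h][w];
    }
  }
  std::printf("%d %d %d %d\n", output[0], output[1], output[2], output[3]);
  // Prints: 1 2 3 4 -- the 2x2 spatial block becomes the 4 output channels.
  return 0;
}
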
diff --git a/tensorflow/core/kernels/spacetodepth_op.h b/tensorflow/core/kernels/spacetodepth_op.h
new file mode 100644
index 0000000000..8d225c6cdb
--- /dev/null
+++ b/tensorflow/core/kernels/spacetodepth_op.h
@@ -0,0 +1,44 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SPACETODEPTH_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPACETODEPTH_OP_H_
+// Functor definition for SpaceToDepthOp, must be compilable by nvcc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by SpaceToDepthOp to do the computations.
+template <typename Device, typename T>
+struct SpaceToDepthOpFunctor {
+ // Implements the space to depth conversion.
+ //
+ // input: 4-D input tensor.
+ // block_size: block size for the conversion.
+ // output: 4-D output tensor.
+ //
+ // The dimensions of the tensors are guaranteed to be right when the
+ // functor is called.
+ void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
+ int block_size, typename TTypes<T, 4>::Tensor output);
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_SPACETODEPTH_OP_H_
diff --git a/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc b/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc
new file mode 100644
index 0000000000..d6678a22ed
--- /dev/null
+++ b/tensorflow/core/kernels/spacetodepth_op_gpu.cu.cc
@@ -0,0 +1,89 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/spacetodepth_op.h"
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename dtype>
+__global__ void S2D(const int32 nthreads, const dtype* input_ptr,
+ const int block_size, const int batch_size,
+ const int input_height, const int input_width,
+ const int input_depth, const int output_height,
+ const int output_width, const int output_depth,
+ dtype* output_ptr) {
+ CUDA_1D_KERNEL_LOOP(inp_idx, nthreads) {
+ // inp_idx = d + input_depth * (w + input_width * (h + input_height * b))
+ const int d = inp_idx % input_depth;
+ const int inp_idx2 = inp_idx / input_depth;
+ const int w = inp_idx2 % input_width;
+ const int inp_idx3 = inp_idx2 / input_width;
+ const int h = inp_idx3 % input_height;
+ const int b = inp_idx3 / input_height;
+
+ const int out_h = h / block_size;
+ const int offset_h = h % block_size;
+ const int out_w = w / block_size;
+ const int offset_w = w % block_size;
+ const int offset_d = (offset_h * block_size + offset_w) * input_depth;
+ const int out_d = d + offset_d;
+ const int out_idx =
+ out_d +
+ output_depth * (out_w + output_width * (out_h + output_height * b));
+ *(output_ptr + out_idx) = ldg(input_ptr + inp_idx);
+ }
+}
+
+// Specialization of SpaceToDepthOpFunctor for a GPUDevice.
+namespace functor {
+template <typename T>
+struct SpaceToDepthOpFunctor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
+ int block_size, typename TTypes<T, 4>::Tensor output) {
+ const int batch_size = output.dimension(0);
+ const int input_height = input.dimension(1);
+ const int input_width = input.dimension(2);
+ const int input_depth = input.dimension(3);
+ const int output_height = output.dimension(1);
+ const int output_width = output.dimension(2);
+ const int output_depth = output.dimension(3);
+
+ const int total_count =
+ batch_size * input_height * input_width * input_depth;
+ CudaLaunchConfig config = GetCudaLaunchConfig(total_count, d);
+ S2D<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ config.virtual_thread_count, input.data(), block_size, batch_size,
+ input_height, input_width, input_depth, output_height, output_width,
+ output_depth, output.data());
+ }
+};
+} // end namespace functor
+
+// Instantiate the GPU implementation for float.
+template struct functor::SpaceToDepthOpFunctor<GPUDevice, float>;
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
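
On the GPU each thread handles one flat input element and unpacks its index with the mod/div chain in S2D; the comment inp_idx = d + input_depth * (w + input_width * (h + input_height * b)) is the invariant being inverted. A host-side round-trip check of that unpacking (plain C++, no CUDA required):

#include <cassert>
#include <cstdio>

int main() {
  const int batch = 2, height = 3, width = 4, depth = 5;
  for (int b = 0; b < batch; ++b)
    for (int h = 0; h < height; ++h)
      for (int w = 0; w < width; ++w)
        for (int d = 0; d < depth; ++d) {
          // Flatten exactly as the kernel's comment describes...
          const int idx = d + depth * (w + width * (h + height * b));
          // ...and unpack with the same mod/div chain as S2D.
          const int d2 = idx % depth;
          const int r1 = idx / depth;
          const int w2 = r1 % width;
          const int r2 = r1 / width;
          const int h2 = r2 % height;
          const int b2 = r2 / height;
          assert(b2 == b && h2 == h && w2 == w && d2 == d);
        }
  std::printf("index round-trip OK\n");
  return 0;
}
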
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 5ecef9c6f9..52e792a399 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -55,8 +56,8 @@ class InvertPermutationOp : public OpKernel {
auto Tout = output->vec<int32>();
std::fill_n(Tout.data(), N, -1);
for (int i = 0; i < N; ++i) {
- const int32 d = Tin(i);
- OP_REQUIRES(context, 0 <= d && d < N,
+ const int32 d = internal::SubtleMustCopy(Tin(i));
+ OP_REQUIRES(context, FastBoundsCheck(d, N),
errors::InvalidArgument(d, " is not between 0 and ", N));
OP_REQUIRES(context, Tout(d) == -1,
errors::InvalidArgument(d, " is duplicated in the input."));
@@ -107,18 +108,26 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
errors::InvalidArgument(
"transpose expects a vector of size ", input.dims(),
". But input(1) is a vector of size ", Vperm.size()));
- gtl::ArraySlice<int32> permutation(
- reinterpret_cast<const int32*>(Vperm.data()), dims);
+  // using volatile reads instead of SubtleMustCopy here so that the copy
+  // into `permutation` below is the asynchrony boundary.
+ const volatile int32* perm_begin =
+ reinterpret_cast<const volatile int32*>(Vperm.data());
+ const std::vector<int32> permutation(perm_begin, perm_begin + dims);
TensorShape shape;
// Check whether permutation is a permutation of integers of [0 .. dims).
gtl::InlinedVector<bool, 8> bits(dims);
- for (const int32 d : permutation) {
+ bool is_identity = true;
+ for (int i = 0; i < dims; ++i) {
+ const int32 d = permutation[i];
OP_REQUIRES(
ctx, 0 <= d && d < dims,
errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
bits[d] = true;
shape.AddDim(input.dim_size(d));
+ if (d != i) {
+ is_identity = false;
+ }
}
for (int i = 0; i < dims; ++i) {
OP_REQUIRES(ctx, bits[i], errors::InvalidArgument(
@@ -126,8 +135,8 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
str_util::Join(permutation, ","), "}."));
}
- // 0-D and 1-D transposes do nothing
- if (dims <= 1) {
+ // 0-D, 1-D, and identity transposes do nothing.
+ if (dims <= 1 || is_identity) {
ctx->set_output(0, input);
return;
}
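
Besides hardening the permutation handling, the change adds an identity fast path: if every permutation entry equals its position, the op forwards its input without copying, the same shortcut already taken for 0-D and 1-D tensors. The detection is just the elementwise comparison; a standalone illustration with arbitrary example permutations:

#include <cstdio>
#include <vector>

// Mirrors the loop added to TransposeOp::Compute: the permutation is the
// identity exactly when permutation[i] == i for every i.
static bool IsIdentityPermutation(const std::vector<int>& permutation) {
  for (int i = 0; i < static_cast<int>(permutation.size()); ++i) {
    if (permutation[i] != i) return false;
  }
  return true;
}

int main() {
  std::printf("%d\n", IsIdentityPermutation({0, 1, 2}));  // 1: input reused.
  std::printf("%d\n", IsIdentityPermutation({0, 2, 1}));  // 0: real transpose.
  return 0;
}
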
diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h
index cf71dc7ecb..6fa02fd729 100644
--- a/tensorflow/core/public/session.h
+++ b/tensorflow/core/public/session.h
@@ -139,7 +139,8 @@ class Session {
/// \brief Like `Run`, but allows users to pass in a `RunOptions` proto and
/// to retrieve non-Tensor metadata output via a `RunOutputs` proto for this
- /// step.
+ /// step. `run_outputs` may be nullptr, in which case any metadata output is
+ /// discarded.
/// NOTE: This API is still experimental and may change.
virtual Status Run(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor> >& inputs,
@@ -148,8 +149,8 @@ class Session {
std::vector<Tensor>* outputs, RunOutputs* run_outputs);
/// \brief Sets up a graph for partial execution. All future feeds and
- /// fetches are specified by 'input_names' and 'output_names'. Returns
- /// 'handle' that can be used to perform a sequence of partial feeds and
+ /// fetches are specified by `input_names` and `output_names`. Returns
+ /// `handle` that can be used to perform a sequence of partial feeds and
/// fetches.
/// NOTE: This API is still experimental and may change.
virtual Status PRunSetup(const std::vector<string>& input_names,
@@ -157,7 +158,7 @@ class Session {
const std::vector<string>& target_nodes,
string* handle);
- /// \brief Continues the pending execution specified by 'handle' with the
+ /// \brief Continues the pending execution specified by `handle` with the
/// provided input tensors and fills `outputs` for the endpoints specified
/// in `output_names`.
/// NOTE: This API is still experimental and may change.
diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/core/public/tensor_c_api.h
index b7ac96b6b9..e846f5d0b6 100644
--- a/tensorflow/core/public/tensor_c_api.h
+++ b/tensorflow/core/public/tensor_c_api.h
@@ -268,15 +268,26 @@ extern void TF_ExtendGraph(TF_Session*, const void* proto, size_t proto_len,
// failure, inputs[] become the property of the implementation (the
// implementation will eventually call TF_DeleteTensor on each input).
//
-// The caller retains the ownership of both `run_options` and `run_outputs`, and
-// should manually call TF_DeleteBuffer on them.
+// Any NULL and non-NULL value combinations for (`run_options`,
+// `run_outputs`) are valid.
+//
+// - `run_options` may be NULL, in which case it will be ignored; or
+// non-NULL, in which case it must point to a `TF_Buffer` containing the
+// serialized representation of a `RunOptions` protocol buffer.
+// - `run_outputs` may be NULL, in which case it will be ignored; or non-NULL,
+// in which case it must point to an empty, freshly allocated `TF_Buffer`
+// that may be updated to contain the serialized representation of a
+//   `RunOutputs` protocol buffer.
+//
+// The caller retains the ownership of `run_options` and/or `run_outputs` (when
+// not NULL) and should manually call TF_DeleteBuffer on them.
//
// On success, the tensors corresponding to output_names[0,noutputs-1]
// are placed in outputs[], and these outputs[] become the property
// of the caller (the caller must eventually call TF_DeleteTensor on
// them).
//
-// On failure, outputs[] contains nulls.
+// On failure, outputs[] contains NULLs.
extern void TF_Run(TF_Session*,
// RunOptions
const TF_Buffer* run_options,
@@ -341,7 +352,7 @@ extern void TF_PRun(TF_Session*, const char* handle,
// On success, place OK in status and return the newly created library handle.
// The caller owns the library handle.
//
-// On failure, place an error status in status and return nullptr.
+// On failure, place an error status in status and return NULL.
extern TF_Library* TF_LoadLibrary(const char* library_filename,
TF_Status* status);
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index dfc2c04baa..046d69a939 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -39,8 +39,10 @@ void Shard(int num_workers, thread::ThreadPool* workers, int64 total,
// much. Let us assume each cost unit is 1ns, kMinCostPerShard=10000
// is 10us.
static const int64 kMinCostPerShard = 10000;
- const int num_shards = std::max(
- 1, std::min<int>(num_workers, total * cost_per_unit / kMinCostPerShard));
+ const int num_shards =
+ std::max<int>(1, std::min(static_cast<int64>(num_workers),
+ total * cost_per_unit / kMinCostPerShard));
+
// Each shard contains up to "block_size" units. [0, total) is sharded
// into:
// [0, block_size), [block_size, 2*block_size), ...
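
The sharder fix moves the point of narrowing: previously the int64 product total * cost_per_unit / kMinCostPerShard was squeezed through int inside std::min<int>, so a large enough workload could wrap and collapse every run onto a single shard, which is what the new Shard.OverflowTest below pins down. A small reproduction of the arithmetic (the truncation result shown is the typical two's-complement behaviour):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t total = 1LL << 32;  // as in Shard.OverflowTest
  const int64_t cost_per_unit = 10000;
  const int64_t kMinCostPerShard = 10000;
  const int num_workers = 3;

  const int64_t wanted = total * cost_per_unit / kMinCostPerShard;  // 2^32

  // Old form: std::min<int> narrows the int64 operand to int first; on common
  // platforms 2^32 truncates to 0, so the shard count collapses to 1.
  const int old_shards =
      std::max(1, std::min<int>(num_workers, static_cast<int>(wanted)));

  // New form: compare in int64, narrow only the (small) result.
  const int new_shards = std::max<int>(
      1, std::min(static_cast<int64_t>(num_workers), wanted));

  std::printf("old=%d new=%d\n", old_shards, new_shards);  // old=1 new=3
  return 0;
}
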
diff --git a/tensorflow/core/util/work_sharder_test.cc b/tensorflow/core/util/work_sharder_test.cc
index 3772bf9bca..c0d7267da9 100644
--- a/tensorflow/core/util/work_sharder_test.cc
+++ b/tensorflow/core/util/work_sharder_test.cc
@@ -59,6 +59,25 @@ TEST(Shard, Basic) {
}
}
+TEST(Shard, OverflowTest) {
+ thread::ThreadPool threads(Env::Default(), "test", 3);
+ mutex mu;
+ for (auto workers : {1, 2, 3}) {
+ const int64 total_elements = 1LL << 32;
+ const int64 cost_per_unit = 10000;
+ int num_shards = 0;
+ int64 num_elements = 0;
+ Shard(workers, &threads, total_elements, cost_per_unit,
+ [&mu, &num_shards, &num_elements](int64 start, int64 limit) {
+ mutex_lock l(mu);
+ ++num_shards;
+ num_elements += limit - start;
+ });
+ EXPECT_EQ(num_shards, workers);
+ EXPECT_EQ(num_elements, total_elements);
+ }
+}
+
void BM_Sharding(int iters, int arg) {
thread::ThreadPool threads(Env::Default(), "test", 16);
const int64 total = 1LL << 30;
diff --git a/tensorflow/examples/android/jni/jni_utils.cc b/tensorflow/examples/android/jni/jni_utils.cc
index 0a1f8adbd0..db0eedeb16 100644
--- a/tensorflow/examples/android/jni/jni_utils.cc
+++ b/tensorflow/examples/android/jni/jni_utils.cc
@@ -157,3 +157,17 @@ void ReadFileToVector(AAssetManager* const asset_manager,
VLOG(0) << "Read " << str_vector->size() << " values from " << filename;
}
+void WriteProtoToFile(const char* const filename,
+ const google::protobuf::MessageLite& message) {
+ std::fstream outfile;
+ outfile.open(filename, std::fstream::binary | std::fstream::out);
+ if (outfile.fail()) {
+ LOG(WARNING) << "Failed to write proto to " << filename;
+ return;
+ } else {
+ google::protobuf::io::OstreamOutputStream raw_out(&outfile);
+ google::protobuf::io::CodedOutputStream coded_out(&raw_out);
+ message.SerializeToCodedStream(&coded_out);
+ }
+ VLOG(0) << "Wrote proto to " << filename;
+}
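
A rough Python counterpart of the helper above, for illustration only (binary proto serialization via `SerializeToString` rather than the C++ coded output stream):

```python
def write_proto_to_file(filename, message):
  """Writes a protocol buffer message to `filename` in binary form."""
  try:
    with open(filename, 'wb') as outfile:
      outfile.write(message.SerializeToString())
  except IOError:
    print('Failed to write proto to %s' % filename)
```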
diff --git a/tensorflow/examples/android/jni/jni_utils.h b/tensorflow/examples/android/jni/jni_utils.h
index 4c1b140abf..c296744061 100644
--- a/tensorflow/examples/android/jni/jni_utils.h
+++ b/tensorflow/examples/android/jni/jni_utils.h
@@ -42,4 +42,7 @@ void ReadFileToString(AAssetManager* const asset_manager,
void ReadFileToVector(AAssetManager* const asset_manager,
const char* const filename, std::vector<std::string>* str_vector);
+void WriteProtoToFile(const char* const filename,
+ const google::protobuf::MessageLite& message);
+
#endif // ORG_TENSORFLOW_JNI_JNI_UTILS_H_
diff --git a/tensorflow/examples/android/jni/tensorflow_jni.cc b/tensorflow/examples/android/jni/tensorflow_jni.cc
index 2b0aa82777..e1060ab666 100644
--- a/tensorflow/examples/android/jni/tensorflow_jni.cc
+++ b/tensorflow/examples/android/jni/tensorflow_jni.cc
@@ -21,13 +21,16 @@ limitations under the License.
#include <jni.h>
#include <pthread.h>
+#include <sys/stat.h>
#include <unistd.h>
#include <queue>
#include <sstream>
#include <string>
+#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
@@ -51,6 +54,12 @@ static int g_image_mean; // The image mean.
static int g_num_runs = 0;
static int64 g_timing_total_us = 0;
+#ifdef SAVE_STEP_STATS
+static const bool kSaveStepStats = true;
+#else
+static const bool kSaveStepStats = false;
+#endif
+
inline static int64 CurrentThreadTimeUs() {
struct timeval tv;
gettimeofday(&tv, NULL);
@@ -199,11 +208,30 @@ static std::string ClassifyImage(const RGBA* const bitmap_src,
std::vector<tensorflow::Tensor> output_tensors;
std::vector<std::string> output_names({"output:0"});
- const int64 start_time = CurrentThreadTimeUs();
- tensorflow::Status s =
- session->Run(input_tensors, output_names, {}, &output_tensors);
- const int64 end_time = CurrentThreadTimeUs();
-
+ tensorflow::Status s;
+ int64 start_time, end_time;
+
+ if (kSaveStepStats) {
+ RunOptions run_options;
+ run_options.set_trace_level(RunOptions::FULL_TRACE);
+ RunOutputs run_outputs;
+ start_time = CurrentThreadTimeUs();
+ s = session->Run(run_options, input_tensors, output_names, {},
+ &output_tensors, &run_outputs);
+ end_time = CurrentThreadTimeUs();
+ assert(run_outputs.has_step_stats());
+
+ const StepStats& stats = run_outputs.step_stats();
+
+ mkdir("/sdcard/tf/", 0755);
+ const string filename =
+ strings::Printf("/sdcard/tf/stepstats%05d.pb", g_num_runs);
+ WriteProtoToFile(filename.c_str(), stats);
+ } else {
+ start_time = CurrentThreadTimeUs();
+ s = session->Run(input_tensors, output_names, {}, &output_tensors);
+ end_time = CurrentThreadTimeUs();
+ }
const int64 elapsed_time_inf = end_time - start_time;
g_timing_total_us += elapsed_time_inf;
VLOG(0) << "End computing. Ran in " << elapsed_time_inf / 1000 << "ms ("
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 3d8ddf639f..ddd50985db 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -40,6 +40,7 @@ py_library(
name = "platform",
srcs = glob(["platform/**/*.py"]),
srcs_version = "PY2AND3",
+ deps = ["//tensorflow/core:protos_all_py"],
)
py_library(
@@ -1006,6 +1007,7 @@ py_test(
name = "session_test",
srcs = ["client/session_test.py"],
srcs_version = "PY2AND3",
+ tags = ["noasan"],
deps = [
":framework",
":framework_test_lib",
@@ -1034,12 +1036,12 @@ cpu_only_kernel_test_list = glob([
"kernel_tests/attention_ops_test.py",
"kernel_tests/barrier_ops_test.py",
"kernel_tests/bcast_ops_test.py",
+ "kernel_tests/benchmark_test.py",
"kernel_tests/candidate_sampler_ops_test.py",
"kernel_tests/cholesky_op_test.py",
"kernel_tests/clip_ops_test.py",
"kernel_tests/decode_csv_op_test.py",
"kernel_tests/decode_raw_op_test.py",
- "kernel_tests/depthtospace_op_test.py",
"kernel_tests/determinant_op_test.py",
"kernel_tests/diag_op_test.py",
"kernel_tests/edit_distance_op_test.py",
@@ -1069,7 +1071,6 @@ cpu_only_kernel_test_list = glob([
"kernel_tests/sparse_reorder_op_test.py",
"kernel_tests/sparse_to_dense_op_test.py",
"kernel_tests/sparsemask_op_test.py",
- "kernel_tests/spacetodepth_op_test.py",
"kernel_tests/summary_ops_test.py",
"kernel_tests/template_test.py",
"kernel_tests/topk_op_test.py",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index c36cdfe30f..84adaca8da 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -59,7 +59,7 @@ from tensorflow.core.framework.attr_value_pb2 import *
from tensorflow.core.protobuf.config_pb2 import *
from tensorflow.core.util.event_pb2 import *
# Import things out of contrib
-from tensorflow import contrib
+import tensorflow.contrib as contrib
# Framework
from tensorflow.python.framework.framework_lib import *
@@ -101,6 +101,7 @@ from tensorflow.python.framework import framework_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import histogram_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
@@ -117,8 +118,8 @@ _whitelist = set([app, compat, contrib, errors, flags, gfile, image,
# strings of other modules.
__all__ = make_all(__name__,
[framework_lib, array_ops, client_lib, constant_op,
- control_flow_ops, io_ops, math_ops, nn, script_ops,
- sparse_ops, state_ops, train])
+ control_flow_ops, histogram_ops, io_ops, math_ops, nn,
+ script_ops, sparse_ops, state_ops, train])
# Symbols whitelisted for export without documentation.
# TODO(cwhipkey): review these and move to contrib, expose through
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index a77cdffda5..d739eb9ee5 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -294,7 +294,7 @@ class BaseSession(SessionInterface):
[`SparseTensorValue`](../../api_docs/python/sparse_ops.md#SparseTensorValue).
The optional `options` argument expects a [`RunOptions`] proto. The options
- allow controling the behavior of this particular step (e.g. turning tracing
+ allow controlling the behavior of this particular step (e.g. turning tracing
on).
The optional `run_outputs` argument expects a [`RunOutputs`] proto. When
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 491b293125..c82e9a96d0 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -25,7 +25,6 @@ import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
@@ -927,13 +926,32 @@ class SessionTest(test_util.TensorFlowTestCase):
sess.run(constant_op.constant(1.0),
options=run_options,
run_outputs=run_outputs)
+
self.assertTrue(run_outputs.HasField('step_stats'))
+ self.assertEquals(len(run_outputs.step_stats.dev_stats), 1)
+
+ def testRunOptionsRunOutputs(self):
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ run_outputs = config_pb2.RunOutputs()
+
+ with ops.device('/cpu:0'):
+ with session.Session() as sess:
+ # all combinations are valid
+ sess.run(constant_op.constant(1.0), options=None, run_outputs=None)
+ sess.run(constant_op.constant(1.0), options=None,
+ run_outputs=run_outputs)
+ self.assertTrue(not run_outputs.HasField('step_stats'))
- step_stats = step_stats_pb2.StepStats()
- self.assertEquals(len(step_stats.dev_stats), 0)
+ sess.run(constant_op.constant(1.0), options=run_options,
+ run_outputs=None)
+ self.assertTrue(not run_outputs.HasField('step_stats'))
- step_stats.CopyFrom(run_outputs.step_stats)
- self.assertEquals(len(step_stats.dev_stats), 1)
+ sess.run(constant_op.constant(1.0), options=run_options,
+ run_outputs=run_outputs)
+
+ self.assertTrue(run_outputs.HasField('step_stats'))
+ self.assertEquals(len(run_outputs.step_stats.dev_stats), 1)
def testFeedShapeCompatibility(self):
with session.Session() as sess:
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index 7180f7d77c..d19482a154 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -81,6 +81,7 @@ def all_libraries(module_to_name, members, documented):
exclude_symbols=["sparse_matmul", "arg_min", "arg_max",
"lin_space", "sparse_segment_mean_grad"],
prefix=PREFIX_TEXT),
+ library("histogram_ops", "Histograms"),
library("control_flow_ops", "Control Flow", prefix=PREFIX_TEXT),
library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"],
prefix=PREFIX_TEXT),
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 284db94d45..17f21d56af 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -165,9 +165,8 @@ class TensorFlowTestCase(googletest.TestCase):
text_format.Merge(expected_message_maybe_ascii, expected_message)
self._AssertProtoEquals(expected_message, message)
else:
- assert False, ("Can't compare protos of type " +
- type(expected_message_maybe_ascii) + " and " +
- type(message))
+ assert False, ("Can't compare protos of type %s and %s" %
+ (type(expected_message_maybe_ascii), type(message)))
def assertProtoEqualsVersion(
self, expected, actual, producer=versions.GRAPH_DEF_VERSION,
diff --git a/tensorflow/python/kernel_tests/benchmark_test.py b/tensorflow/python/kernel_tests/benchmark_test.py
new file mode 100644
index 0000000000..4a5d55fbff
--- /dev/null
+++ b/tensorflow/python/kernel_tests/benchmark_test.py
@@ -0,0 +1,158 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for tensorflow.python.framework.importer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import random
+
+import tensorflow as tf
+
+from google.protobuf import text_format
+from tensorflow.core.util import test_log_pb2
+from tensorflow.python.platform import benchmark
+
+
+# Used by SomeRandomBenchmark class below.
+_ran_somebenchmark_1 = [False]
+_ran_somebenchmark_2 = [False]
+_ran_somebenchmark_but_shouldnt = [False]
+
+
+class SomeRandomBenchmark(tf.test.Benchmark):
+ """This Benchmark should automatically be registered in the registry."""
+
+ def _dontRunThisBenchmark(self):
+ _ran_somebenchmark_but_shouldnt[0] = True
+
+ def notBenchmarkMethod(self):
+ _ran_somebenchmark_but_shouldnt[0] = True
+
+ def benchmark1(self):
+ _ran_somebenchmark_1[0] = True
+
+ def benchmark2(self):
+ _ran_somebenchmark_2[0] = True
+
+
+class TestReportingBenchmark(tf.test.Benchmark):
+ """This benchmark (maybe) reports some stuff."""
+
+ def benchmarkReport1(self):
+ self.report_benchmark(iters=1)
+
+ def benchmarkReport2(self):
+ self.report_benchmark(
+ iters=2, name="custom_benchmark_name",
+ extras={"number_key": 3, "other_key": "string"})
+
+
+class BenchmarkTest(tf.test.TestCase):
+
+ def testGlobalBenchmarkRegistry(self):
+ registry = list(benchmark.GLOBAL_BENCHMARK_REGISTRY)
+ self.assertEqual(len(registry), 2)
+ self.assertTrue(SomeRandomBenchmark in registry)
+ self.assertTrue(TestReportingBenchmark in registry)
+
+ def testRunSomeRandomBenchmark(self):
+ # Validate that SomeBenchmark has not run yet
+ self.assertFalse(_ran_somebenchmark_1[0])
+ self.assertFalse(_ran_somebenchmark_2[0])
+ self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
+
+ # Run other benchmarks, but this won't run the one we care about
+ benchmark._run_benchmarks("unrelated")
+
+ # Validate that SomeBenchmark has not run yet
+ self.assertFalse(_ran_somebenchmark_1[0])
+ self.assertFalse(_ran_somebenchmark_2[0])
+ self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
+
+ # Run all the benchmarks, avoid generating any reports
+ if benchmark.TEST_REPORTER_TEST_ENV in os.environ:
+ del os.environ[benchmark.TEST_REPORTER_TEST_ENV]
+ benchmark._run_benchmarks("SomeRandom")
+
+ # Validate that SomeRandomBenchmark ran correctly
+ self.assertTrue(_ran_somebenchmark_1[0])
+ self.assertTrue(_ran_somebenchmark_2[0])
+ self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
+
+ def testReportingBenchmark(self):
+ tempdir = tf.test.get_temp_dir()
+ try:
+ tf.gfile.MakeDirs(tempdir)
+ except OSError as e:
+ # It's OK if the directory already exists.
+ if " exists:" not in str(e):
+ raise e
+
+ prefix = os.path.join(
+ tempdir, "reporting_bench_%016x_" % random.getrandbits(64))
+ expected_output_file = "%s%s" % (
+ prefix, "TestReportingBenchmark.benchmarkReport1")
+ expected_output_file_2 = "%s%s" % (
+ prefix, "TestReportingBenchmark.custom_benchmark_name")
+ try:
+ self.assertFalse(tf.gfile.Exists(expected_output_file))
+ # Run benchmark but without env, shouldn't write anything
+ if benchmark.TEST_REPORTER_TEST_ENV in os.environ:
+ del os.environ[benchmark.TEST_REPORTER_TEST_ENV]
+ reporting = TestReportingBenchmark()
+ reporting.benchmarkReport1() # This should run without writing anything
+ self.assertFalse(tf.gfile.Exists(expected_output_file))
+
+ # Run benchmark with env, should write
+ os.environ[benchmark.TEST_REPORTER_TEST_ENV] = prefix
+
+ reporting = TestReportingBenchmark()
+ reporting.benchmarkReport1() # This should write
+ reporting.benchmarkReport2() # This should write
+
+ # Check the files were written
+ self.assertTrue(tf.gfile.Exists(expected_output_file))
+ self.assertTrue(tf.gfile.Exists(expected_output_file_2))
+
+ # Check the contents are correct
+ expected_1 = test_log_pb2.BenchmarkEntry()
+ expected_1.name = "TestReportingBenchmark.benchmarkReport1"
+ expected_1.iters = 1
+
+ expected_2 = test_log_pb2.BenchmarkEntry()
+ expected_2.name = "TestReportingBenchmark.custom_benchmark_name"
+ expected_2.iters = 2
+ expected_2.extras["number_key"].double_value = 3
+ expected_2.extras["other_key"].string_value = "string"
+
+ read_benchmark_1 = tf.gfile.GFile(expected_output_file, "r").read()
+ read_benchmark_1 = text_format.Merge(
+ read_benchmark_1, test_log_pb2.BenchmarkEntry())
+ self.assertProtoEquals(expected_1, read_benchmark_1)
+
+ read_benchmark_2 = tf.gfile.GFile(expected_output_file_2, "r").read()
+ read_benchmark_2 = text_format.Merge(
+ read_benchmark_2, test_log_pb2.BenchmarkEntry())
+ self.assertProtoEquals(expected_2, read_benchmark_2)
+
+ finally:
+ tf.gfile.DeleteRecursively(tempdir)
+
+
+if __name__ == "__main__":
+ tf.test.main()
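
The new test exercises the benchmark registry and reporting added in tensorflow/python/platform/benchmark.py later in this diff. A minimal usage sketch, assuming `tf.test.Benchmark` is exported as the tests above use it; the matmul workload is illustrative only:

```python
import time
import tensorflow as tf

class MyMatmulBenchmark(tf.test.Benchmark):
  """Subclassing tf.test.Benchmark registers the class automatically."""

  def benchmarkMatmul(self):  # only methods named benchmark* are run
    with tf.Session() as sess:
      a = tf.random_normal([256, 256])
      product = tf.matmul(a, a)
      start = time.time()
      for _ in range(10):
        sess.run(product)
      # Written out only when TEST_REPORT_FILE_PREFIX is set in the environment.
      self.report_benchmark(iters=10, wall_time=(time.time() - start) / 10)

if __name__ == "__main__":
  tf.test.main()
```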
diff --git a/tensorflow/python/kernel_tests/depthtospace_op_test.py b/tensorflow/python/kernel_tests/depthtospace_op_test.py
index 8dda8832b3..bace61b40f 100644
--- a/tensorflow/python/kernel_tests/depthtospace_op_test.py
+++ b/tensorflow/python/kernel_tests/depthtospace_op_test.py
@@ -25,12 +25,17 @@ import tensorflow as tf
class DepthToSpaceTest(tf.test.TestCase):
+ def _testOne(self, inputs, block_size, outputs):
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = tf.depth_to_space(tf.to_float(inputs), block_size)
+ self.assertAllEqual(x_tf.eval(), outputs)
+
def testBasic(self):
x_np = [[[[1, 2, 3, 4]]]]
- with self.test_session(use_gpu=False):
- block_size = 2
- x_tf = tf.depth_to_space(x_np, block_size)
- self.assertAllEqual(x_tf.eval(), [[[[1], [2]], [[3], [4]]]])
+ block_size = 2
+ x_out = [[[[1], [2]], [[3], [4]]]]
+ self._testOne(x_np, block_size, x_out)
# Tests for larger input dimensions. To make sure elements are
# correctly ordered spatially.
@@ -40,12 +45,28 @@ class DepthToSpaceTest(tf.test.TestCase):
[[9, 10, 11, 12],
[13, 14, 15, 16]]]]
block_size = 2
- with self.test_session(use_gpu=False):
- x_tf = tf.depth_to_space(x_np, block_size)
- self.assertAllEqual(x_tf.eval(), [[[[1], [2], [5], [6]],
- [[3], [4], [7], [8]],
- [[9], [10], [13], [14]],
- [[11], [12], [15], [16]]]])
+ x_out = [[[[1], [2], [5], [6]],
+ [[3], [4], [7], [8]],
+ [[9], [10], [13], [14]],
+ [[11], [12], [15], [16]]]]
+ self._testOne(x_np, block_size, x_out)
+
+ def testBlockSize2Batch10(self):
+ block_size = 2
+ def batch_input_elt(i):
+ return [[[1 * i, 2 * i, 3 * i, 4 * i],
+ [5 * i, 6 * i, 7 * i, 8 * i]],
+ [[9 * i, 10 * i, 11 * i, 12 * i],
+ [13 * i, 14 * i, 15 * i, 16 * i]]]
+ def batch_output_elt(i):
+ return [[[1 * i], [2 * i], [5 * i], [6 * i]],
+ [[3 * i], [4 * i], [7 * i], [8 * i]],
+ [[9 * i], [10 * i], [13 * i], [14 * i]],
+ [[11 * i], [12 * i], [15 * i], [16 * i]]]
+ batch_size = 10
+ x_np = [batch_input_elt(i) for i in xrange(batch_size)]
+ x_out = [batch_output_elt(i) for i in xrange(batch_size)]
+ self._testOne(x_np, block_size, x_out)
# Tests for different width and height.
def testNonSquare(self):
@@ -53,46 +74,42 @@ class DepthToSpaceTest(tf.test.TestCase):
[[5, 50, 6, 60, 7, 70, 8, 80]],
[[9, 90, 10, 100, 11, 110, 12, 120]]]]
block_size = 2
- with self.test_session(use_gpu=False):
- x_tf = tf.depth_to_space(x_np, block_size)
- self.assertAllEqual(x_tf.eval(), [[[[1, 10], [2, 20]],
- [[3, 30], [4, 40]],
- [[5, 50], [6, 60]],
- [[7, 70], [8, 80]],
- [[9, 90], [10, 100]],
- [[11, 110], [12, 120]]]])
+ x_out = [[[[1, 10], [2, 20]],
+ [[3, 30], [4, 40]],
+ [[5, 50], [6, 60]],
+ [[7, 70], [8, 80]],
+ [[9, 90], [10, 100]],
+ [[11, 110], [12, 120]]]]
+ self._testOne(x_np, block_size, x_out)
# Tests for larger input dimensions. To make sure elements are
# correctly ordered spatially.
def testBlockSize4FlatInput(self):
x_np = [[[[1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16]]]]
block_size = 4
- with self.test_session(use_gpu=False):
- x_tf = tf.depth_to_space(x_np, block_size)
- self.assertAllEqual(x_tf.eval(), [[[[1], [2], [5], [6]],
- [[3], [4], [7], [8]],
- [[9], [10], [13], [14]],
- [[11], [12], [15], [16]]]])
+ x_out = [[[[1], [2], [5], [6]],
+ [[3], [4], [7], [8]],
+ [[9], [10], [13], [14]],
+ [[11], [12], [15], [16]]]]
+ self._testOne(x_np, block_size, x_out)
# Tests for larger input depths.
# To make sure elements are properly interleaved in depth.
def testDepthInterleaved(self):
x_np = [[[[1, 10, 2, 20, 3, 30, 4, 40]]]]
block_size = 2
- with self.test_session(use_gpu=False):
- x_tf = tf.depth_to_space(x_np, block_size)
- self.assertAllEqual(x_tf.eval(), [[[[1, 10], [2, 20]],
- [[3, 30], [4, 40]]]])
+ x_out = [[[[1, 10], [2, 20]],
+ [[3, 30], [4, 40]]]]
+ self._testOne(x_np, block_size, x_out)
# Tests for larger input depths. Here an odd depth.
# To make sure elements are properly interleaved in depth.
def testDepthInterleavedDepth3(self):
x_np = [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]]
block_size = 2
- with self.test_session(use_gpu=False):
- x_tf = tf.depth_to_space(x_np, block_size)
- self.assertAllEqual(x_tf.eval(), [[[[1, 2, 3], [4, 5, 6]],
- [[7, 8, 9], [10, 11, 12]]]])
+ x_out = [[[[1, 2, 3], [4, 5, 6]],
+ [[7, 8, 9], [10, 11, 12]]]]
+ self._testOne(x_np, block_size, x_out)
# Tests for larger input depths.
# To make sure elements are properly interleaved in depth.
@@ -102,13 +119,11 @@ class DepthToSpaceTest(tf.test.TestCase):
[[9, 90, 10, 100, 11, 110, 12, 120],
[13, 130, 14, 140, 15, 150, 16, 160]]]]
block_size = 2
- with self.test_session(use_gpu=False):
- x_tf = tf.depth_to_space(x_np, block_size)
- self.assertAllEqual(x_tf.eval(),
- [[[[1, 10], [2, 20], [5, 50], [6, 60]],
- [[3, 30], [4, 40], [7, 70], [8, 80]],
- [[9, 90], [10, 100], [13, 130], [14, 140]],
- [[11, 110], [12, 120], [15, 150], [16, 160]]]])
+ x_out = [[[[1, 10], [2, 20], [5, 50], [6, 60]],
+ [[3, 30], [4, 40], [7, 70], [8, 80]],
+ [[9, 90], [10, 100], [13, 130], [14, 140]],
+ [[11, 110], [12, 120], [15, 150], [16, 160]]]]
+ self._testOne(x_np, block_size, x_out)
# Error handling:
@@ -205,5 +220,6 @@ class DepthToSpaceGradientTest(tf.test.TestCase):
block_size = 3
self._compare(1, 2, 3, 2, block_size)
+
if __name__ == "__main__":
tf.test.main()
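
For readers unfamiliar with the op, a NumPy sketch of the NHWC rearrangement that the updated tests assert; this is illustrative only and not the kernel implementation:

```python
import numpy as np

def depth_to_space(x, block_size):
  # x has shape [batch, height, width, depth]; depth must be divisible by block_size**2.
  n, h, w, c = x.shape
  b = block_size
  y = x.reshape(n, h, w, b, b, c // (b * b))
  y = y.transpose(0, 1, 3, 2, 4, 5)  # move the block rows/cols next to height/width
  return y.reshape(n, h * b, w * b, c // (b * b))

x = np.array([[[[1, 2, 3, 4]]]])
print(depth_to_space(x, 2).tolist())  # [[[[1], [2]], [[3], [4]]]], matching testBasic
```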
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index 88048cfa7a..5261af4aab 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -184,7 +184,8 @@ class RNNCellTest(tf.test.TestCase):
x = tf.zeros([1, 1], dtype=tf.int32)
m = tf.zeros([1, 2])
g, new_m = tf.nn.rnn_cell.EmbeddingWrapper(
- tf.nn.rnn_cell.GRUCell(2), 3)(x, m)
+ tf.nn.rnn_cell.GRUCell(2),
+ embedding_classes=3, embedding_size=2)(x, m)
sess.run([tf.initialize_all_variables()])
res = sess.run([g, new_m], {x.name: np.array([[1]]),
m.name: np.array([[0.1, 0.1]])})
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 82c432922c..41dae10210 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import sys
import time
import timeit
@@ -953,6 +952,7 @@ def graph_creation_static_vs_dynamic_rnn_benchmark(max_time):
print("%d \t %f \t %f \t %f" %
(max_time, delta_static, delta_dynamic, delta_dynamic/delta_static))
+ return delta_static, delta_dynamic
def _timer(sess, ops):
@@ -1013,6 +1013,8 @@ def static_vs_dynamic_rnn_benchmark(batch_size, max_time, num_units, use_gpu):
(batch_size, max_time, num_units, use_gpu, delta_static,
delta_dynamic, delta_dynamic/delta_static))
+ return delta_static, delta_dynamic
+
def _dynamic_rnn_swap_memory_benchmark(inputs_t, sequence_length,
swap_memory):
@@ -1061,6 +1063,7 @@ def dynamic_rnn_swap_memory_benchmark(batch_size, max_time, num_units):
print("%d \t %d \t %d \t %f \t %f \t %f" %
(batch_size, max_time, num_units, no_swap, swap, swap/no_swap))
+ return no_swap, swap
def rnn_long_sequence_benchmark(batch_size, seqlen, num_units,
@@ -1097,34 +1100,55 @@ def rnn_long_sequence_benchmark(batch_size, seqlen, num_units,
elapsed/seqlen))
-def main(_):
- print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM")
- print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)")
- for max_time in (1, 25, 50):
- graph_creation_static_vs_dynamic_rnn_benchmark(max_time)
-
- print("Calculation: Static Unroll with Dynamic Flow LSTM "
- "vs. Dynamic Unroll LSTM")
- print("batch \t max_t \t units \t gpu \t dt(static) \t dt(dynamic) "
- "\t dt(dynamic)/dt(static)")
- for batch_size in (256,):
- for max_time in (50,):
- for num_units in (512, 256, 128):
- for use_gpu in (False, True):
- static_vs_dynamic_rnn_benchmark(
- batch_size, max_time, num_units, use_gpu)
-
- print("Calculation: Dynamic LSTM No Memory Swap vs. Memory Swap")
- print("batch \t max_t \t units \t no_swap \t swap \t swap/no_swap")
- for batch_size in (256, 512):
- for max_time in (100,):
- for num_units in (512, 256, 128):
- dynamic_rnn_swap_memory_benchmark(batch_size, max_time, num_units)
+class BenchmarkRNN(tf.test.Benchmark):
+
+ def benchmarkGraphCreationStaticVsDynamicLSTM(self):
+ print("Graph Creation: Static Unroll vs. Dynamic Unroll LSTM")
+ print("max_t \t dt(static) \t dt(dynamic) \t dt(dynamic)/dt(static)")
+ for max_time in (1, 25, 50):
+ s_dt, d_dt = graph_creation_static_vs_dynamic_rnn_benchmark(max_time)
+ self.report_benchmark(name="graph_creation_time_static_T%02d" % max_time,
+ iters=5, wall_time=s_dt)
+ self.report_benchmark(name="graph_creation_time_dynamic_T%02d" % max_time,
+ iters=5, wall_time=d_dt)
+
+ def benchmarkStaticUnrollVsDynamicFlowLSTM(self):
+ print("Calculation: Static Unroll with Dynamic Flow LSTM "
+ "vs. Dynamic Unroll LSTM")
+ print("batch \t max_t \t units \t gpu \t dt(static) \t dt(dynamic) "
+ "\t dt(dynamic)/dt(static)")
+ for batch_size in (256,):
+ for max_time in (50,):
+ for num_units in (512, 256, 128):
+ for use_gpu in (False, True):
+ s_dt, d_dt = static_vs_dynamic_rnn_benchmark(
+ batch_size, max_time, num_units, use_gpu)
+ self.report_benchmark(
+ name="static_unroll_time_T%02d_B%03d_N%03d_gpu_%s"
+ % (max_time, batch_size, num_units, use_gpu),
+ iters=10, wall_time=s_dt)
+ self.report_benchmark(
+ name="dynamic_unroll_time_T%02d_B%03d_N%03d_gpu_%s"
+ % (max_time, batch_size, num_units, use_gpu),
+ iters=10, wall_time=d_dt)
+
+ def benchmarkDynamicLSTMNoMemorySwapVsMemorySwap(self):
+ print("Calculation: Dynamic LSTM No Memory Swap vs. Memory Swap")
+ print("batch \t max_t \t units \t no_swap \t swap \t swap/no_swap")
+ for batch_size in (256, 512):
+ for max_time in (100,):
+ for num_units in (512, 256, 128):
+ no_swap, swap = dynamic_rnn_swap_memory_benchmark(
+ batch_size, max_time, num_units)
+ self.report_benchmark(
+ name="dynamic_lstm_no_memory_swap_T%02d_B%03d_N%03d"
+ % (max_time, batch_size, num_units),
+ iters=10, wall_time=no_swap)
+ self.report_benchmark(
+ name="dynamic_lstm_with_memory_swap_T%02d_B%03d_N%03d"
+ % (max_time, batch_size, num_units),
+ iters=10, wall_time=swap)
if __name__ == "__main__":
- if "--benchmarks" in sys.argv:
- sys.argv.remove("--benchmarks")
- tf.app.run()
- else:
- tf.test.main()
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index 44d5d30fb3..91c389a2a2 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -121,6 +121,13 @@ class SoftmaxTest(tf.test.TestCase):
self._testOverflow(use_gpu=False)
+ def testEmpty(self):
+ with self.test_session():
+ x = tf.constant([[]], shape=[0, 3])
+ self.assertEqual(0, tf.size(x).eval())
+ expected_y = np.array([]).reshape(0, 3)
+ np.testing.assert_array_equal(expected_y, tf.nn.softmax(x).eval())
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/spacetodepth_op_test.py b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
index 8b8ef6158a..02ebdce768 100644
--- a/tensorflow/python/kernel_tests/spacetodepth_op_test.py
+++ b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
@@ -25,13 +25,18 @@ import tensorflow as tf
class SpaceToDepthTest(tf.test.TestCase):
+ def _testOne(self, inputs, block_size, outputs):
+ for use_gpu in [False, True]:
+ with self.test_session(use_gpu=use_gpu):
+ x_tf = tf.space_to_depth(tf.to_float(inputs), block_size)
+ self.assertAllEqual(x_tf.eval(), outputs)
+
def testBasic(self):
x_np = [[[[1], [2]],
[[3], [4]]]]
- with self.test_session(use_gpu=False):
- block_size = 2
- out_tf = tf.space_to_depth(x_np, block_size)
- self.assertAllEqual(out_tf.eval(), [[[[1, 2, 3, 4]]]])
+ block_size = 2
+ x_out = [[[[1, 2, 3, 4]]]]
+ self._testOne(x_np, block_size, x_out)
# Tests for larger input dimensions. To make sure elements are
# correctly ordered spatially.
@@ -40,14 +45,12 @@ class SpaceToDepthTest(tf.test.TestCase):
[[3], [4], [7], [8]],
[[9], [10], [13], [14]],
[[11], [12], [15], [16]]]]
-
- with self.test_session(use_gpu=False):
- block_size = 2
- out_tf = tf.space_to_depth(x_np, block_size)
- self.assertAllEqual(out_tf.eval(), [[[[1, 2, 3, 4],
- [5, 6, 7, 8]],
- [[9, 10, 11, 12],
- [13, 14, 15, 16]]]])
+ block_size = 2
+ x_out = [[[[1, 2, 3, 4],
+ [5, 6, 7, 8]],
+ [[9, 10, 11, 12],
+ [13, 14, 15, 16]]]]
+ self._testOne(x_np, block_size, x_out)
# Tests for larger input dimensions. To make sure elements are
# correctly ordered in depth. Here, larger block size.
@@ -56,34 +59,27 @@ class SpaceToDepthTest(tf.test.TestCase):
[[3], [4], [7], [8]],
[[9], [10], [13], [14]],
[[11], [12], [15], [16]]]]
-
- with self.test_session(use_gpu=False):
- block_size = 4
- out_tf = tf.space_to_depth(x_np, block_size)
- self.assertAllEqual(
- out_tf.eval(),
- [[[[1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16]]]])
+ block_size = 4
+ x_out = [[[[1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16]]]]
+ self._testOne(x_np, block_size, x_out)
# Tests for larger input depths.
# To make sure elements are properly interleaved in depth.
def testDepthInterleaved(self):
x_np = [[[[1, 10], [2, 20]],
[[3, 30], [4, 40]]]]
- with self.test_session(use_gpu=False):
- block_size = 2
- out_tf = tf.space_to_depth(x_np, block_size)
- self.assertAllEqual(out_tf.eval(), [[[[1, 10, 2, 20, 3, 30, 4, 40]]]])
+ block_size = 2
+ x_out = [[[[1, 10, 2, 20, 3, 30, 4, 40]]]]
+ self._testOne(x_np, block_size, x_out)
# Tests for larger input depths. Here an odd depth.
# To make sure elements are properly interleaved in depth.
def testDepthInterleavedDepth3(self):
x_np = [[[[1, 2, 3], [4, 5, 6]],
[[7, 8, 9], [10, 11, 12]]]]
- with self.test_session(use_gpu=False):
- block_size = 2
- out_tf = tf.space_to_depth(x_np, block_size)
- self.assertAllEqual(out_tf.eval(),
- [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]])
+ block_size = 2
+ x_out = [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]]
+ self._testOne(x_np, block_size, x_out)
# Tests for larger input dimensions AND for larger input depths.
# To make sure elements are properly interleaved in depth and ordered
@@ -93,14 +89,29 @@ class SpaceToDepthTest(tf.test.TestCase):
[[3, 30], [4, 40], [7, 70], [8, 80]],
[[9, 90], [10, 100], [13, 130], [14, 140]],
[[11, 110], [12, 120], [15, 150], [16, 160]]]]
- with self.test_session(use_gpu=False):
- block_size = 2
- out_tf = tf.space_to_depth(x_np, block_size)
- self.assertAllEqual(out_tf.eval(),
- [[[[1, 10, 2, 20, 3, 30, 4, 40],
- [5, 50, 6, 60, 7, 70, 8, 80]],
- [[9, 90, 10, 100, 11, 110, 12, 120],
- [13, 130, 14, 140, 15, 150, 16, 160]]]])
+ block_size = 2
+ x_out = [[[[1, 10, 2, 20, 3, 30, 4, 40],
+ [5, 50, 6, 60, 7, 70, 8, 80]],
+ [[9, 90, 10, 100, 11, 110, 12, 120],
+ [13, 130, 14, 140, 15, 150, 16, 160]]]]
+ self._testOne(x_np, block_size, x_out)
+
+ def testBlockSize2Batch10(self):
+ block_size = 2
+ def batch_input_elt(i):
+ return [[[1 * i], [2 * i], [5 * i], [6 * i]],
+ [[3 * i], [4 * i], [7 * i], [8 * i]],
+ [[9 * i], [10 * i], [13 * i], [14 * i]],
+ [[11 * i], [12 * i], [15 * i], [16 * i]]]
+ def batch_output_elt(i):
+ return [[[1 * i, 2 * i, 3 * i, 4 * i],
+ [5 * i, 6 * i, 7 * i, 8 * i]],
+ [[9 * i, 10 * i, 11 * i, 12 * i],
+ [13 * i, 14 * i, 15 * i, 16 * i]]]
+ batch_size = 10
+ x_np = [batch_input_elt(i) for i in xrange(batch_size)]
+ x_out = [batch_output_elt(i) for i in xrange(batch_size)]
+ self._testOne(x_np, block_size, x_out)
# Tests for different width and height.
def testNonSquare(self):
@@ -110,13 +121,11 @@ class SpaceToDepthTest(tf.test.TestCase):
[[7, 70], [8, 80]],
[[9, 90], [10, 100]],
[[11, 110], [12, 120]]]]
- with self.test_session(use_gpu=False):
- block_size = 2
- out_tf = tf.space_to_depth(x_np, block_size)
- self.assertAllEqual(out_tf.eval(),
- [[[[1, 10, 2, 20, 3, 30, 4, 40]],
- [[5, 50, 6, 60, 7, 70, 8, 80]],
- [[9, 90, 10, 100, 11, 110, 12, 120]]]])
+ block_size = 2
+ x_out = [[[[1, 10, 2, 20, 3, 30, 4, 40]],
+ [[5, 50, 6, 60, 7, 70, 8, 80]],
+ [[9, 90, 10, 100, 11, 110, 12, 120]]]]
+ self._testOne(x_np, block_size, x_out)
# Error handling:
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index ccf38d5be1..f759a0a1a0 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -405,6 +405,7 @@ def boolean_mask(tensor, mask, name="boolean_mask"):
ValueError: If shapes do not conform.
Examples:
+
```python
# 2-D example
a = [[1, 2], [3, 4], [5, 6]]
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index cba8daa368..0b6125dfd6 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -218,7 +218,7 @@ class QueueBase(object):
return gen_data_flow_ops._queue_enqueue(self._queue_ref, vals, name=scope)
def enqueue_many(self, vals, name=None):
- """Enqueues zero or elements to this queue.
+ """Enqueues zero or more elements to this queue.
This operation slices each component tensor along the 0th dimension to
make multiple queue elements. All of the tensors in `vals` must have the
diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py
index a24fb39eb5..12e2e4eb8b 100644
--- a/tensorflow/python/ops/histogram_ops.py
+++ b/tensorflow/python/ops/histogram_ops.py
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Operations for histograms."""
+# pylint: disable=g-short-docstring-punctuation
+"""## Histograms
+
+@@histogram_fixed_width
+"""
from __future__ import absolute_import
from __future__ import division
@@ -24,30 +28,34 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
-def histogram_fixed_width(hist,
- new_values,
+def histogram_fixed_width(values,
value_range,
- use_locking=False,
- name='histogram_fixed_width'):
- """Update histogram Variable with new values.
+ nbins=100,
+ use_locking=True,
+ dtype=dtypes.int32,
+ name=None):
+ """Return histogram of values.
- This Op fills histogram with counts of values falling within fixed-width,
- half-open bins.
+ Given the tensor `values`, this operation returns a rank 1 histogram counting
+ the number of entries in `values` that fell into every bin. The bins are
+ equal width and determined by the arguments `value_range` and `nbins`.
Args:
- hist: 1-D mutable `Tensor`, e.g. a `Variable`.
- new_values: Numeric `Tensor`.
+ values: Numeric `Tensor`.
value_range: Shape [2] `Tensor`. new_values <= value_range[0] will be
mapped to hist[0], values >= value_range[1] will be mapped to hist[-1].
Must be same dtype as new_values.
+ nbins: Integer number of bins in this histogram.
use_locking: Boolean.
If `True`, use locking during the operation (optional).
- name: A name for this operation (optional).
+ dtype: dtype for returned histogram.
+ name: A name for this operation (defaults to 'histogram_fixed_width').
Returns:
- An op that updates `hist` with `new_values` when evaluated.
+ A `Variable` holding histogram of values.
Examples:
```python
@@ -57,24 +65,21 @@ def histogram_fixed_width(hist,
new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
with tf.default_session() as sess:
- hist = variables.Variable(array_ops.zeros(nbins, dtype=tf.int32))
- hist_update = histogram_ops.histogram_fixed_width(hist, new_values,
- value_range)
+ hist = tf.histogram_fixed_width(new_values, value_range, nbins=5)
variables.initialize_all_variables().run()
- sess.run(hist_update) => [2, 1, 1, 0, 2]
+ sess.run(hist) => [2, 1, 1, 0, 2]
```
"""
- with ops.op_scope([hist, new_values, value_range], name) as scope:
- new_values = ops.convert_to_tensor(new_values, name='new_values')
- new_values = array_ops.reshape(new_values, [-1])
+ with variable_scope.variable_op_scope(
+ [values, value_range], name, 'histogram_fixed_width') as scope:
+ values = ops.convert_to_tensor(values, name='values')
+ values = array_ops.reshape(values, [-1])
value_range = ops.convert_to_tensor(value_range, name='value_range')
- dtype = hist.dtype
# Map tensor values that fall within value_range to [0, 1].
- scaled_values = math_ops.truediv(new_values - value_range[0],
+ scaled_values = math_ops.truediv(values - value_range[0],
value_range[1] - value_range[0],
name='scaled_values')
- nbins = math_ops.cast(hist.get_shape()[0], scaled_values.dtype)
# map tensor values within the open interval value_range to {0,.., nbins-1},
# values outside the open interval will be zero or less, or nbins or more.
@@ -87,9 +92,18 @@ def histogram_fixed_width(hist,
# Dummy vector to scatter.
# TODO(langmore) Replace non-ideal creation of large dummy vector once an
# alternative to scatter is available.
- updates = array_ops.ones([indices.get_shape()[0]], dtype=dtype)
- return state_ops.scatter_add(hist,
- indices,
- updates,
- use_locking=use_locking,
- name=scope)
+ updates = array_ops.ones_like(indices, dtype=dtype)
+
+ hist = variable_scope.get_variable('hist',
+ initializer=array_ops.zeros_initializer(
+ [nbins],
+ dtype=dtype),
+ trainable=False)
+ hist_assign_zero = hist.assign(array_ops.zeros_like(hist))
+
+ with ops.control_dependencies([hist_assign_zero]):
+ return state_ops.scatter_add(hist,
+ indices,
+ updates,
+ use_locking=use_locking,
+ name=scope.name)
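
A short usage sketch of the reworked API, mirroring the docstring example above (values outside `value_range` fall into the first and last bins):

```python
import tensorflow as tf

value_range = [0.0, 5.0]
new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]

with tf.Session() as sess:
  hist = tf.histogram_fixed_width(new_values, value_range, nbins=5)
  sess.run(tf.initialize_all_variables())
  print(sess.run(hist))  # => [2 1 1 0 2]
```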
diff --git a/tensorflow/python/ops/histogram_ops_test.py b/tensorflow/python/ops/histogram_ops_test.py
index 8358c2f1ea..514ba14e16 100644
--- a/tensorflow/python/ops/histogram_ops_test.py
+++ b/tensorflow/python/ops/histogram_ops_test.py
@@ -17,149 +17,132 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import histogram_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import googletest
-
import numpy as np
+import tensorflow as tf
-class HistogramFixedWidthTest(test_util.TensorFlowTestCase):
+class HistogramFixedWidthTest(tf.test.TestCase):
def setUp(self):
self.rng = np.random.RandomState(0)
+ def test_empty_input_gives_all_zero_counts(self):
+ # Bins will be:
+ # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+ value_range = [0.0, 5.0]
+ values = []
+ expected_bin_counts = [0, 0, 0, 0, 0]
+ with self.test_session():
+ hist = tf.histogram_fixed_width(values, value_range, nbins=5)
+ tf.initialize_all_variables().run()
+
+ # Hist should start "fresh" with every eval.
+ self.assertAllClose(expected_bin_counts, hist.eval())
+ self.assertAllClose(expected_bin_counts, hist.eval())
+
def test_one_update_on_constant_input(self):
# Bins will be:
# (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
- nbins = [5]
value_range = [0.0, 5.0]
- new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
+ values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
expected_bin_counts = [2, 1, 1, 0, 2]
- with self.test_session() as sess:
- hist = variables.Variable(array_ops.zeros(nbins, dtype=dtypes.int32))
- hist_update = histogram_ops.histogram_fixed_width(hist, new_values,
- value_range)
- variables.initialize_all_variables().run()
- self.assertTrue(hist.dtype.is_compatible_with(hist_update.dtype))
- updated_hist_array = sess.run(hist_update)
-
- # The new updated_hist_array is returned by the updating op.
- self.assertAllClose(expected_bin_counts, updated_hist_array)
+ with self.test_session():
+ hist = tf.histogram_fixed_width(values, value_range, nbins=5)
+ tf.initialize_all_variables().run()
- # hist should contain updated values, but eval() should not change it.
+ # Hist should start "fresh" with every eval.
self.assertAllClose(expected_bin_counts, hist.eval())
self.assertAllClose(expected_bin_counts, hist.eval())
def test_one_update_on_constant_2d_input(self):
# Bins will be:
# (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
- nbins = [5]
value_range = [0.0, 5.0]
- new_values = [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]]
+ values = [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]]
expected_bin_counts = [2, 1, 1, 0, 2]
- with self.test_session() as sess:
- hist = variables.Variable(array_ops.zeros(nbins, dtype=dtypes.int32))
- hist_update = histogram_ops.histogram_fixed_width(hist, new_values,
- value_range)
- variables.initialize_all_variables().run()
- self.assertTrue(hist.dtype.is_compatible_with(hist_update.dtype))
- updated_hist_array = sess.run(hist_update)
-
- # The new updated_hist_array is returned by the updating op.
- self.assertAllClose(expected_bin_counts, updated_hist_array)
+ with self.test_session():
+ hist = tf.histogram_fixed_width(values, value_range, nbins=5)
+ tf.initialize_all_variables().run()
- # hist should contain updated values, but eval() should not change it.
+ # Hist should start "fresh" with every eval.
self.assertAllClose(expected_bin_counts, hist.eval())
self.assertAllClose(expected_bin_counts, hist.eval())
def test_two_updates_on_constant_input(self):
# Bins will be:
# (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
- nbins = [5]
value_range = [0.0, 5.0]
- new_values_1 = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
- new_values_2 = [1.5, 4.5, 4.5, 4.5, 0.0, 0.0]
+ values_1 = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
+ values_2 = [1.5, 4.5, 4.5, 4.5, 0.0, 0.0]
expected_bin_counts_1 = [2, 1, 1, 0, 2]
- expected_bin_counts_2 = [4, 2, 1, 0, 5]
- with self.test_session() as sess:
- hist = variables.Variable(array_ops.zeros(nbins, dtype=dtypes.int32))
- new_values = array_ops.placeholder(dtypes.float32, shape=[6])
- hist_update = histogram_ops.histogram_fixed_width(hist, new_values,
- value_range)
- variables.initialize_all_variables().run()
- updated_hist_array = sess.run(hist_update,
- feed_dict={new_values: new_values_1})
-
- # The new updated_hist_array is returned by the updating op.
- # hist should contain the updated values.
- self.assertAllClose(expected_bin_counts_1, updated_hist_array)
- self.assertAllClose(expected_bin_counts_1, hist.eval())
-
- updated_hist_array = sess.run(hist_update,
- feed_dict={new_values: new_values_2})
- self.assertAllClose(expected_bin_counts_2, updated_hist_array)
- self.assertAllClose(expected_bin_counts_2, hist.eval())
+ expected_bin_counts_2 = [2, 1, 0, 0, 3]
+ with self.test_session():
+ values = tf.placeholder(tf.float32, shape=[6])
+ hist = tf.histogram_fixed_width(values, value_range, nbins=5)
+ tf.initialize_all_variables().run()
+
+ # The values in hist should depend on the current feed and nothing else.
+ self.assertAllClose(expected_bin_counts_1,
+ hist.eval(feed_dict={values: values_1}))
+ self.assertAllClose(expected_bin_counts_2,
+ hist.eval(feed_dict={values: values_2}))
+ self.assertAllClose(expected_bin_counts_1,
+ hist.eval(feed_dict={values: values_1}))
+ self.assertAllClose(expected_bin_counts_1,
+ hist.eval(feed_dict={values: values_1}))
def test_two_updates_on_scalar_input(self):
# Bins will be:
# (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
- nbins = [5]
value_range = [0.0, 5.0]
- new_values_1 = 1.5
- new_values_2 = 2.5
+ values_1 = 1.5
+ values_2 = 2.5
expected_bin_counts_1 = [0, 1, 0, 0, 0]
- expected_bin_counts_2 = [0, 1, 1, 0, 0]
- with self.test_session() as sess:
- hist = variables.Variable(array_ops.zeros(nbins, dtype=dtypes.int32))
- new_values = array_ops.placeholder(dtypes.float32, shape=[])
- hist_update = histogram_ops.histogram_fixed_width(hist, new_values,
- value_range)
- variables.initialize_all_variables().run()
-
- # The new updated_hist_array is returned by the updating op.
- # hist should contain the updated values.
- updated_hist_array = sess.run(hist_update,
- feed_dict={new_values: new_values_1})
- self.assertAllClose(expected_bin_counts_1, updated_hist_array)
- self.assertAllClose(expected_bin_counts_1, hist.eval())
-
- updated_hist_array = sess.run(hist_update,
- feed_dict={new_values: new_values_2})
- self.assertAllClose(expected_bin_counts_2, updated_hist_array)
- self.assertAllClose(expected_bin_counts_2, hist.eval())
-
- def test_multiple_random_3d_updates_results_in_right_dist(self):
- # Update with uniform 3-D rvs. Resultant
+ expected_bin_counts_2 = [0, 0, 1, 0, 0]
+ with self.test_session():
+ values = tf.placeholder(tf.float32, shape=[])
+ hist = tf.histogram_fixed_width(values, value_range, nbins=5)
+ tf.initialize_all_variables().run()
+
+ # The values in hist should depend on the current feed and nothing else.
+ self.assertAllClose(expected_bin_counts_2,
+ hist.eval(feed_dict={values: values_2}))
+ self.assertAllClose(expected_bin_counts_1,
+ hist.eval(feed_dict={values: values_1}))
+ self.assertAllClose(expected_bin_counts_1,
+ hist.eval(feed_dict={values: values_1}))
+ self.assertAllClose(expected_bin_counts_2,
+ hist.eval(feed_dict={values: values_2}))
+
+ def test_multiple_random_accumulating_updates_results_in_right_dist(self):
+ # Accumulate the updates in a new variable. Resultant
# histogram should be uniform. Use only 3 bins because with many bins it
# would be unlikely that all would be close to 1/n. If someone ever wants
# to test that, it would be better to check that the cdf was linear.
- nbins = [3]
value_range = [1.0, 4.14159]
with self.test_session() as sess:
- hist = variables.Variable(array_ops.zeros(nbins, dtype=dtypes.int32))
- new_values = array_ops.placeholder(dtypes.float32, shape=[4, 4, 4])
- hist_update = histogram_ops.histogram_fixed_width(hist, new_values,
- value_range)
- variables.initialize_all_variables().run()
+ values = tf.placeholder(tf.float32, shape=[4, 4, 4])
+ hist = tf.histogram_fixed_width(values,
+ value_range,
+ nbins=3,
+ dtype=tf.int64)
+
+ hist_accum = tf.Variable(tf.zeros_initializer([3], dtype=tf.int64))
+ hist_accum = hist_accum.assign_add(hist)
+
+ tf.initialize_all_variables().run()
for _ in range(100):
# Map the rv: U[0, 1] --> U[value_range[0], value_range[1]].
- new_values_arr = (
+ values_arr = (
value_range[0] +
(value_range[1] - value_range[0]) * self.rng.rand(4, 4, 4))
- # The new updated_hist_array is returned by the updating op.
- # hist should contain the updated values.
- updated_hist_array = sess.run(hist_update,
- feed_dict={new_values: new_values_arr})
+ hist_accum_arr = sess.run(hist_accum, feed_dict={values: values_arr})
- pmf = updated_hist_array / float(updated_hist_array.sum())
+ pmf = hist_accum_arr / float(hist_accum_arr.sum())
np.testing.assert_allclose(1 / 3, pmf, atol=0.02)
if __name__ == '__main__':
- googletest.main()
+ tf.test.main()
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 15cf4736a3..d622d5dc21 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -92,6 +92,7 @@ The "producer" functions add a queue to the graph and a corresponding
@@match_filenames_once
@@limit_epochs
+@@input_producer
@@range_input_producer
@@slice_input_producer
@@string_input_producer
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index 7d922b3ed7..ebdfdc113b 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -556,15 +556,13 @@ class EmbeddingWrapper(RNNCell):
feed into your RNN.
"""
- def __init__(self, cell, embedding_classes=0, embedding=None,
- initializer=None):
+ def __init__(self, cell, embedding_classes, embedding_size, initializer=None):
"""Create a cell with an added input embedding.
Args:
cell: an RNNCell, an embedding will be put before its inputs.
embedding_classes: integer, how many symbols will be embedded.
- embedding: Variable, the embedding to use; if None, a new embedding
- will be created; if set, then embedding_classes is not required.
+ embedding_size: integer, the size of the vectors we embed into.
initializer: an initializer to use when creating the embedding;
if None, the initializer from variable scope or a default one is used.
@@ -574,21 +572,12 @@ class EmbeddingWrapper(RNNCell):
"""
if not isinstance(cell, RNNCell):
raise TypeError("The parameter cell is not RNNCell.")
- if embedding_classes < 1 and embedding is None:
- raise ValueError("Pass embedding or embedding_classes must be > 0: %d."
- % embedding_classes)
- if embedding_classes > 0 and embedding is not None:
- if embedding.size[0] != embedding_classes:
- raise ValueError("You declared embedding_classes=%d but passed an "
- "embedding for %d classes." % (embedding.size[0],
- embedding_classes))
- if embedding.size[1] != cell.input_size:
- raise ValueError("You passed embedding with output size %d and a cell"
- " that accepts size %d." % (embedding.size[1],
- cell.input_size))
+ if embedding_classes <= 0 or embedding_size <= 0:
+ raise ValueError("Both embedding_classes and embedding_size must be > 0: "
+ "%d, %d." % (embedding_classes, embedding_size))
self._cell = cell
self._embedding_classes = embedding_classes
- self._embedding = embedding
+ self._embedding_size = embedding_size
self._initializer = initializer
@property
@@ -607,20 +596,17 @@ class EmbeddingWrapper(RNNCell):
"""Run the cell on embedded inputs."""
with vs.variable_scope(scope or type(self).__name__): # "EmbeddingWrapper"
with ops.device("/cpu:0"):
- if self._embedding:
- embedding = self._embedding
+ if self._initializer:
+ initializer = self._initializer
+ elif vs.get_variable_scope().initializer:
+ initializer = vs.get_variable_scope().initializer
else:
- if self._initializer:
- initializer = self._initializer
- elif vs.get_variable_scope().initializer:
- initializer = vs.get_variable_scope().initializer
- else:
- # Default initializer for embeddings should have variance=1.
- sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
- initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
- embedding = vs.get_variable("embedding", [self._embedding_classes,
- self._cell.input_size],
- initializer=initializer)
+ # Default initializer for embeddings should have variance=1.
+ sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
+ initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
+ embedding = vs.get_variable("embedding", [self._embedding_classes,
+ self._embedding_size],
+ initializer=initializer)
embedded = embedding_ops.embedding_lookup(
embedding, array_ops.reshape(inputs, [-1]))
return self._cell(embedded, state)
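
A minimal sketch of the new constructor signature, mirroring the updated rnn_cell_test above (`embedding_classes` and `embedding_size` are now both required):

```python
import numpy as np
import tensorflow as tf

with tf.Session() as sess:
  x = tf.zeros([1, 1], dtype=tf.int32)   # token ids
  m = tf.zeros([1, 2])                   # initial GRU state
  cell = tf.nn.rnn_cell.EmbeddingWrapper(
      tf.nn.rnn_cell.GRUCell(2), embedding_classes=3, embedding_size=2)
  g, new_m = cell(x, m)
  sess.run(tf.initialize_all_variables())
  print(sess.run([g, new_m], {x: np.array([[1]]), m: np.array([[0.1, 0.1]])}))
```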
diff --git a/tensorflow/python/ops/seq2seq.py b/tensorflow/python/ops/seq2seq.py
index 7df123ef70..6cbf70437c 100644
--- a/tensorflow/python/ops/seq2seq.py
+++ b/tensorflow/python/ops/seq2seq.py
@@ -311,7 +311,9 @@ def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
"""
with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq"):
# Encoder.
- encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
+ encoder_cell = rnn_cell.EmbeddingWrapper(
+ cell, embedding_classes=num_encoder_symbols,
+ embedding_size=cell.input_size)
_, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
# Decoder.
@@ -686,7 +688,9 @@ def embedding_attention_seq2seq(encoder_inputs, decoder_inputs, cell,
"""
with variable_scope.variable_scope(scope or "embedding_attention_seq2seq"):
# Encoder.
- encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
+ encoder_cell = rnn_cell.EmbeddingWrapper(
+ cell, embedding_classes=num_encoder_symbols,
+ embedding_size=cell.input_size)
encoder_outputs, encoder_state = rnn.rnn(
encoder_cell, encoder_inputs, dtype=dtype)
@@ -772,7 +776,9 @@ def one2many_rnn_seq2seq(encoder_inputs, decoder_inputs_dict, cell,
with variable_scope.variable_scope(scope or "one2many_rnn_seq2seq"):
# Encoder.
- encoder_cell = rnn_cell.EmbeddingWrapper(cell, num_encoder_symbols)
+ encoder_cell = rnn_cell.EmbeddingWrapper(
+ cell, embedding_classes=num_encoder_symbols,
+ embedding_size=cell.input_size)
_, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
# Decoder.
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index c7c4ceb083..e1fd5d0143 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -774,7 +774,7 @@ def _SerializeManySparseShape(op): # pylint: disable=invalid-name
return [tensor_shape.matrix(None, 3)]
-def deserialize_many_sparse(serialized_sparse, dtype, name=None):
+def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
"""Deserialize and concatenate `SparseTensors` from a serialized minibatch.
The input `serialized_sparse` must be a string matrix of shape `[N x 3]` where
@@ -823,6 +823,7 @@ def deserialize_many_sparse(serialized_sparse, dtype, name=None):
serialized_sparse: 2-D `Tensor` of type `string` of shape `[N, 3]`.
The serialized and packed `SparseTensor' objects.
dtype: The `dtype` of the serialized `SparseTensor` objects.
+ rank: (optional) Python int, the rank of the `SparseTensor` objects.
name: A name prefix for the returned tensors (optional)
Returns:
@@ -835,6 +836,10 @@ def deserialize_many_sparse(serialized_sparse, dtype, name=None):
gen_sparse_ops._deserialize_many_sparse(
serialized_sparse, dtype, name=name))
+ # Feed rank data back in, if available
+ output_indices.set_shape([None, rank])
+ output_shape.set_shape([rank])
+
return ops.SparseTensor(output_indices, output_values, output_shape)
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 13cb40cf31..90ac0a057b 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -42,6 +42,7 @@ from tensorflow.python.ops.control_flow_ops import foldr
from tensorflow.python.ops.control_flow_ops import map_fn
from tensorflow.python.ops.data_flow_ops import *
from tensorflow.python.ops.gradients import *
+from tensorflow.python.ops.histogram_ops import *
from tensorflow.python.ops.init_ops import *
from tensorflow.python.ops.io_ops import *
from tensorflow.python.ops.linalg_ops import *
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
new file mode 100644
index 0000000000..87f95bb2c7
--- /dev/null
+++ b/tensorflow/python/platform/benchmark.py
@@ -0,0 +1,213 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities to run benchmarks."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import inspect
+import numbers
+import os
+import re
+import sys
+
+import six # pylint: disable=unused-import
+
+from google.protobuf import text_format
+from tensorflow.core.util import test_log_pb2
+from tensorflow.python.platform import app
+from tensorflow.python.platform import gfile
+
+# When a subclass of the Benchmark class is created, it is added to
+# the registry automatically
+GLOBAL_BENCHMARK_REGISTRY = set()
+
+# Environment variable that determines whether benchmarks are written.
+# See also tensorflow/core/util/reporter.h TestReporter::kTestReporterEnv.
+TEST_REPORTER_TEST_ENV = "TEST_REPORT_FILE_PREFIX"
+
+
+def _global_report_benchmark(
+ name, iters=None, cpu_time=None, wall_time=None,
+ throughput=None, extras=None):
+ """Method for recording a benchmark directly.
+
+ Args:
+ name: The BenchmarkEntry name.
+ iters: (optional) How many iterations were run
+ cpu_time: (optional) Total cpu time in seconds
+ wall_time: (optional) Total wall time in seconds
+ throughput: (optional) Throughput (in MB/s)
+ extras: (optional) Dict mapping string keys to additional benchmark info.
+
+ Raises:
+ TypeError: if extras is not a dict.
+ IOError: if the benchmark output file already exists.
+ """
+ if extras is not None:
+ if not isinstance(extras, dict):
+ raise TypeError("extras must be a dict")
+
+ test_env = os.environ.get(TEST_REPORTER_TEST_ENV, None)
+ if test_env is None:
+ # Reporting was not requested
+ return
+
+ entry = test_log_pb2.BenchmarkEntry()
+ entry.name = name
+ if iters is not None:
+ entry.iters = iters
+ if cpu_time is not None:
+ entry.cpu_time = cpu_time
+ if wall_time is not None:
+ entry.wall_time = wall_time
+ if throughput is not None:
+ entry.throughput = throughput
+ if extras is not None:
+ for (k, v) in extras.items():
+ if isinstance(v, numbers.Number):
+ entry.extras[k].double_value = v
+ else:
+ entry.extras[k].string_value = str(v)
+
+ serialized_entry = text_format.MessageToString(entry)
+
+ mangled_name = name.replace("/", "__")
+ output_path = "%s%s" % (test_env, mangled_name)
+ if gfile.Exists(output_path):
+ raise IOError("File already exists: %s" % output_path)
+ with gfile.GFile(output_path, "w") as out:
+ out.write(serialized_entry)
+
+
+class _BenchmarkRegistrar(type):
+ """The Benchmark class registrar. Used by abstract Benchmark class."""
+
+ def __new__(mcs, clsname, base, attrs):
+ newclass = super(_BenchmarkRegistrar, mcs).__new__(
+ mcs, clsname, base, attrs)
+ if len(newclass.mro()) > 2:
+ # Only the base Benchmark abstract class has mro length 2.
+ # The rest subclass from it and are therefore registered.
+ GLOBAL_BENCHMARK_REGISTRY.add(newclass)
+ return newclass
+
+
+class Benchmark(object):
+ """Abstract class that provides helper functions for running benchmarks.
+
+ Any class subclassing this one is immediately registered in the global
+ benchmark registry.
+
+ Only methods whose names start with the word "benchmark" will be run during
+ benchmarking.
+ """
+ __metaclass__ = _BenchmarkRegistrar
+
+ def _get_name(self, overwrite_name):
+ """Returns full name of class and method calling report_benchmark."""
+
+ # Expect that the caller called report_benchmark, which called _get_name.
+ caller = inspect.stack()[2]
+ calling_class = caller[0].f_locals.get("self", None)
+ # Use the method name, or overwrite_name if provided.
+ name = overwrite_name if overwrite_name is not None else caller[3]
+ if calling_class is not None:
+ # Prefix the name with the class name.
+ class_name = type(calling_class).__name__
+ name = "%s.%s" % (class_name, name)
+ return name
+
+ def report_benchmark(
+ self,
+ iters=None,
+ cpu_time=None,
+ wall_time=None,
+ throughput=None,
+ extras=None,
+ name=None):
+ """Report a benchmark.
+
+ Args:
+ iters: (optional) How many iterations were run
+ cpu_time: (optional) Total cpu time in seconds
+ wall_time: (optional) Total wall time in seconds
+ throughput: (optional) Throughput (in MB/s)
+ extras: (optional) Dict mapping string keys to additional benchmark info.
+ name: (optional) Override the BenchmarkEntry name with `name`.
+ Otherwise it is inferred from the calling class and top-level
+ method name.
+ """
+ name = self._get_name(overwrite_name=name)
+ _global_report_benchmark(
+ name=name, iters=iters, cpu_time=cpu_time, wall_time=wall_time,
+ throughput=throughput, extras=extras)
+
+
+def _run_specific_benchmark(benchmark_class):
+ benchmark = benchmark_class()
+ attrs = dir(benchmark)
+ # Only run methods of this class whose names start with "benchmark"
+ for attr in attrs:
+ if not attr.startswith("benchmark"):
+ continue
+ benchmark_fn = getattr(benchmark, attr)
+ if not callable(benchmark_fn):
+ continue
+ # Call this benchmark method
+ benchmark_fn()
+
+
+def _run_benchmarks(regex):
+ """Run benchmarks that match regex `regex`.
+
+ This function goes through the global benchmark registry and matches
+ benchmark **class names** of the form "module.name.BenchmarkClass" to
+ the given regex. If a class matches, all of its benchmark methods
+ are run.
+
+ Args:
+ regex: The string regular expression to match Benchmark classes against.
+ """
+ registry = list(GLOBAL_BENCHMARK_REGISTRY)
+
+ # Match benchmarks in registry against regex
+ for benchmark in registry:
+ benchmark_name = "%s.%s" % (benchmark.__module__, benchmark.__name__)
+ if re.search(regex, benchmark_name):
+ # Found a match; run all of its benchmark methods.
+ _run_specific_benchmark(benchmark)
+
+
+def benchmarks_main(true_main=None):
+ """Run benchmarks as declared in args.
+
+ Args:
+ true_main: True main function to run if benchmarks are not requested.
+ """
+ argv = sys.argv
+ found_arg = [arg for arg in argv
+ if arg.startswith("--benchmarks=")
+ or arg.startswith("-benchmarks=")]
+ if found_arg:
+ # Remove --benchmarks arg from sys.argv
+ argv.remove(found_arg[0])
+
+ regex = found_arg[0].split("=", 1)[1]
+ app.run(lambda _: _run_benchmarks(regex))
+ else:
+ true_main()
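[Editor's note] For orientation, a minimal, hypothetical usage sketch of the benchmark support added above: a subclass of `Benchmark` is auto-registered by the metaclass, `report_benchmark` writes a `BenchmarkEntry` when `TEST_REPORT_FILE_PREFIX` is set, and `benchmarks_main` dispatches on a `--benchmarks=<regex>` flag. The class and method names are illustrative, and it assumes the `tf.test.Benchmark` alias exported later in this patch.

    import time
    import tensorflow as tf

    class MatmulBenchmark(tf.test.Benchmark):  # registered automatically on class creation

      def benchmarkMatmul(self):
        with tf.Graph().as_default(), tf.Session() as sess:
          a = tf.random_normal([256, 256])
          product = tf.matmul(a, a)
          start = time.time()
          for _ in range(10):
            sess.run(product)
          # Reported as "MatmulBenchmark.benchmarkMatmul" unless name= is given.
          self.report_benchmark(iters=10, wall_time=(time.time() - start) / 10)

    if __name__ == "__main__":
      tf.test.main()  # runs the benchmarks when invoked with --benchmarks=<regex>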
diff --git a/tensorflow/python/platform/default/_app.py b/tensorflow/python/platform/default/_app.py
index e700956f17..74fecfe7ef 100644
--- a/tensorflow/python/platform/default/_app.py
+++ b/tensorflow/python/platform/default/_app.py
@@ -23,8 +23,8 @@ import sys
from tensorflow.python.platform import flags
-def run():
+def run(main=None):
f = flags.FLAGS
f._parse_flags()
- main = sys.modules['__main__'].main
+ main = main or sys.modules['__main__'].main
sys.exit(main(sys.argv))
diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py
index 2049bd2b1d..76e15d7872 100644
--- a/tensorflow/python/platform/googletest.py
+++ b/tensorflow/python/platform/googletest.py
@@ -21,7 +21,20 @@ from __future__ import print_function
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
from . import control_imports
+from tensorflow.python.platform import benchmark
+
+# Import the Benchmark class
+Benchmark = benchmark.Benchmark # pylint: disable=invalid-name
+
if control_imports.USE_OSS and control_imports.OSS_GOOGLETEST:
from tensorflow.python.platform.default._googletest import *
+ from tensorflow.python.platform.default._googletest import main as g_main
else:
from tensorflow.python.platform.google._googletest import *
+ from tensorflow.python.platform.google._googletest import main as g_main
+
+
+# Redefine main to allow running benchmarks
+def main():
+ # benchmarks_main() runs the requested benchmarks, or falls back to the tests via g_main.
+ benchmark.benchmarks_main(true_main=g_main)
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index d2b9d1f974..6d78193233 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -72,6 +72,10 @@ from tensorflow.python.kernel_tests.gradient_checker import compute_gradient
# pylint: enable=unused-import
+# Import Benchmark class
+Benchmark = googletest.Benchmark # pylint: disable=invalid-name
+
+
def main():
"""Runs all unit tests."""
return googletest.main()
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
index 661bae7bc1..f018126bc8 100644
--- a/tensorflow/python/training/coordinator.py
+++ b/tensorflow/python/training/coordinator.py
@@ -131,6 +131,8 @@ class Coordinator(object):
# Event set when threads must stop.
self._stop_event = threading.Event()
# Python exc_info to report.
+ # If not None, it should hold the returned value of sys.exc_info(), which is
+ # a tuple containing the exception (type, value, traceback).
self._exc_info_to_raise = None
def request_stop(self, ex=None):
@@ -138,6 +140,10 @@ class Coordinator(object):
After this is called, calls to `should_stop()` will return `True`.
+ Note: If an exception is being passed in, it must be in the context of
+ handling the exception (i.e. `try: ... except Exception as ex: ...`) and not
+ a newly created one.
+
Args:
ex: Optional `Exception`, or Python `exc_info` tuple as returned by
`sys.exc_info()`. If this is the first call to `request_stop()` the
@@ -154,6 +160,22 @@ class Coordinator(object):
logging.info("Error reported to Coordinator: %s",
compat.as_str_any(ex))
self._exc_info_to_raise = sys.exc_info()
+ # self._exc_info_to_raise should contain a tuple containing exception
+ # (type, value, traceback)
+ if (len(self._exc_info_to_raise) != 3 or
+ not self._exc_info_to_raise[0] or
+ not self._exc_info_to_raise[1]):
+ # Raise, catch, and record the exception here so that the error happens
+ # where expected.
+ try:
+ raise ValueError(
+ "ex must be a tuple or sys.exc_info must return the current "
+ "exception: %s"
+ % self._exc_info_to_raise)
+ except ValueError:
+ # Record this error so it kills the coordinator properly.
+ self._exc_info_to_raise = sys.exc_info()
+
self._stop_event.set()
def clear_stop(self):
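[Editor's note] A small illustrative sketch of the contract documented above (the worker function and its exception are hypothetical): `request_stop` is handed the exception from inside the `except` block that is handling it, so the captured `sys.exc_info()` is current, and `join` later re-raises it in the main thread.

    import threading
    import tensorflow as tf

    coord = tf.train.Coordinator()

    def worker():
      try:
        raise RuntimeError("worker failed")
      except Exception as e:
        # Valid: we are inside the handler, so the exception context is live.
        coord.request_stop(e)

    t = threading.Thread(target=worker)
    t.start()
    try:
      coord.join([t])  # re-raises the recorded exception here
    except RuntimeError as e:
      print("coordinator stopped with:", e)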
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index ace7b49d97..ae2df782a4 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -84,20 +84,63 @@ def limit_epochs(tensor, num_epochs=None, name=None):
return array_ops.identity(tensor, name=name)
-def _input_producer(input_tensor, dtype, num_epochs, shuffle, seed, capacity,
- shared_name, name, summary_name):
- if shuffle:
- input_tensor = random_ops.random_shuffle(input_tensor, seed=seed)
- input_tensor = limit_epochs(input_tensor, num_epochs)
-
- q = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=[dtype], shapes=[[]],
- shared_name=shared_name, name=name)
- enq = q.enqueue_many([input_tensor])
- queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq]))
- logging_ops.scalar_summary("queue/%s/%s" % (q.name, summary_name),
- math_ops.cast(q.size(), dtypes.float32) *
- (1. / capacity))
- return q
+def input_producer(input_tensor, element_shape=None, num_epochs=None,
+ shuffle=True, seed=None, capacity=32, shared_name=None,
+ summary_name=None, name=None):
+ """Output the rows of `input_tensor` to a queue for an input pipeline.
+
+ Args:
+ input_tensor: A tensor with the rows to produce. Must be at least
+ one-dimensional. Must either have a fully-defined shape, or
+ `element_shape` must be defined.
+ element_shape: (Optional.) A `TensorShape` representing the shape of a
+ row of `input_tensor`, if it cannot be inferred.
+ num_epochs: (Optional.) An integer. If specified `input_producer` produces
+ each row of `input_tensor` `num_epochs` times before generating an
+ `OutOfRange` error. If not specified, `input_producer` can cycle through
+ the rows of `input_tensor` an unlimited number of times.
+ shuffle: (Optional.) A boolean. If true, the rows are randomly shuffled
+ within each epoch.
+ seed: (Optional.) An integer. The seed to use if `shuffle` is true.
+ capacity: (Optional.) The capacity of the queue to be used for buffering
+ the input.
+ shared_name: (Optional.) If set, this queue will be shared under the given
+ name across multiple sessions.
+ summary_name: (Optional.) If set, a scalar summary for the current queue
+ size will be generated, using this name as part of the tag.
+ name: (Optional.) A name for the queue.
+
+ Returns:
+ A queue with the output rows. A `QueueRunner` for the queue is
+ added to the current `QUEUE_RUNNER` collection of the current
+ graph.
+
+ Raises:
+ ValueError: If the shape of the input cannot be inferred from the arguments.
+ """
+ with ops.op_scope([input_tensor], name, "input_producer"):
+ input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor")
+ element_shape = input_tensor.get_shape()[1:].merge_with(element_shape)
+ if not element_shape.is_fully_defined():
+ raise ValueError("Either `input_tensor` must have a fully defined shape "
+ "or `element_shape` must be specified")
+
+ if shuffle:
+ input_tensor = random_ops.random_shuffle(input_tensor, seed=seed)
+
+ input_tensor = limit_epochs(input_tensor, num_epochs)
+
+ q = data_flow_ops.FIFOQueue(capacity=capacity,
+ dtypes=[input_tensor.dtype.base_dtype],
+ shapes=[element_shape],
+ shared_name=shared_name, name=name)
+ enq = q.enqueue_many([input_tensor])
+ queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq]))
+ if summary_name is not None:
+ logging_ops.scalar_summary("queue/%s/%s" % (q.name, summary_name),
+ math_ops.cast(q.size(), dtypes.float32) *
+ (1. / capacity))
+ return q
def string_input_producer(string_tensor, num_epochs=None, shuffle=True,
@@ -108,9 +151,9 @@ def string_input_producer(string_tensor, num_epochs=None, shuffle=True,
string_tensor: A 1-D string tensor with the strings to produce.
num_epochs: An integer (optional). If specified, `string_input_producer`
produces each string from `string_tensor` `num_epochs` times before
- generating an OutOfRange error. If not specified, `string_input_producer`
- can cycle through the strings in `string_tensor` an unlimited number of
- times.
+ generating an `OutOfRange` error. If not specified,
+ `string_input_producer` can cycle through the strings in `string_tensor`
+ an unlimited number of times.
shuffle: Boolean. If true, the strings are randomly shuffled within each
epoch.
seed: An integer (optional). Seed used if shuffle == True.
@@ -137,9 +180,9 @@ def string_input_producer(string_tensor, num_epochs=None, shuffle=True,
logging_ops.Assert(math_ops.greater(array_ops.size(string_tensor), 0),
[not_null_err])]):
string_tensor = array_ops.identity(string_tensor)
- return _input_producer(
+ return input_producer(
input_tensor=string_tensor,
- dtype=dtypes.string,
+ element_shape=[],
num_epochs=num_epochs,
shuffle=shuffle,
seed=seed,
@@ -173,8 +216,8 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
"""
with ops.op_scope([limit], name, "input_producer") as name:
range_tensor = math_ops.range(limit)
- return _input_producer(
- range_tensor, dtypes.int32, num_epochs, shuffle, seed, capacity,
+ return input_producer(
+ range_tensor, [], num_epochs, shuffle, seed, capacity,
shared_name, name, "fraction_of_%d_full" % capacity)
@@ -231,51 +274,104 @@ def _flatten(tensor_list_list):
return [tensor for tensor_list in tensor_list_list for tensor in tensor_list]
+class _SparseMetaData(object):
+ """Store information about the Tensor: Is it sparse?, dtype, and rank."""
+
+ def __init__(self, sparse, dtype, rank):
+ self._sparse = sparse
+ self._dtype = dtype
+ self._rank = rank
+
+ def __eq__(self, other):
+ if self.sparse != other.sparse:
+ return False
+ if not self.sparse:
+ return True
+ if self.dtype != other.dtype:
+ return False
+ if not self.rank.is_compatible_with(other.rank):
+ return False
+ return True
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __str__(self):
+ return "[SparseMetaData(%s, %s, %s)]" % (self.sparse, self.dtype, self.rank)
+
+ def merge_with(self, other):
+ if self != other:
+ raise ValueError("SparseMetaData objects are incompatible: %s vs. %s"
+ % (self, other))
+ if self.sparse:
+ self.rank.merge_with(other.rank)
+ return self
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ @property
+ def sparse(self):
+ return self._sparse
+
+ @property
+ def rank(self):
+ return self._rank
+
+
def _serialize_sparse_tensors(tensor_list, enqueue_many):
"""Serialize SparseTensors for feeding into batch, etc."""
- is_sparse_list = [isinstance(t, ops.SparseTensor) for t in tensor_list]
- sparse_dtypes_list = [
- t.dtype if isinstance(t, ops.SparseTensor) else None
+ sparse_info_list = [
+ _SparseMetaData(sparse=True,
+ dtype=t.dtype,
+ rank=t.shape.get_shape().with_rank(1)[0])
+ if isinstance(t, ops.SparseTensor)
+ else _SparseMetaData(False, None, None)
for t in tensor_list]
- def _maybe_serialize(t, is_sparse):
- if not is_sparse:
+ def _maybe_serialize(t, sparse):
+ if not sparse:
return t
return (sparse_ops.serialize_many_sparse(t) if enqueue_many
else sparse_ops.serialize_sparse(t))
+
serialized_list = [
- _maybe_serialize(t, is_sparse)
- for (t, is_sparse) in zip(tensor_list, is_sparse_list)]
- return serialized_list, is_sparse_list, sparse_dtypes_list
+ _maybe_serialize(t, info.sparse) for (t, info)
+ in zip(tensor_list, sparse_info_list)]
+
+ return serialized_list, sparse_info_list
def _serialize_sparse_tensors_join(tensor_list_list, enqueue_many):
"""Serialize SparseTensors for feeding into batch_join, etc."""
- (s0, is_sparse_list, sparse_dtypes_list) = _serialize_sparse_tensors(
+ (s0, sparse_info_list) = _serialize_sparse_tensors(
tensor_list_list[0], enqueue_many)
serialized_list_list = [s0]
for tensor_list in tensor_list_list[1:]:
- (s, is_sparse_candidate, sparse_dtypes_candidate) = (
- _serialize_sparse_tensors(tensor_list, enqueue_many))
- if is_sparse_candidate != is_sparse_list:
+ s, sparse_info_candidate = _serialize_sparse_tensors(
+ tensor_list, enqueue_many)
+ if sparse_info_list != sparse_info_candidate:
raise ValueError("Inconsistent SparseTensors list: %s vs. %s"
% (tensor_list_list[0], tensor_list))
- if sparse_dtypes_candidate != sparse_dtypes_list:
- raise ValueError("Inconsistent SparseTensor dtypes in list: %s vs. %s"
- % (tensor_list_list[0], tensor_list))
+ sparse_info_list = [
+ info.merge_with(candidate)
+ for (info, candidate) in zip(sparse_info_list, sparse_info_candidate)]
serialized_list_list.append(s)
- return (serialized_list_list, is_sparse_list, sparse_dtypes_list)
+
+ return (serialized_list_list, sparse_info_list)
-def _deserialize_sparse_tensors(serialized_list, is_sparse_list, sparse_dtypes):
+def _deserialize_sparse_tensors(serialized_list, sparse_info_list):
"""Deserialize SparseTensors after dequeue in batch, batch_join, etc."""
received_sequence = isinstance(serialized_list, collections.Sequence)
if not received_sequence:
serialized_list = (serialized_list,)
- tensors = [sparse_ops.deserialize_many_sparse(s, sparse_dtype) if is_sparse
- else s
- for (s, is_sparse, sparse_dtype)
- in zip(serialized_list, is_sparse_list, sparse_dtypes)]
+ tensors = [
+ sparse_ops.deserialize_many_sparse(s, info.dtype, info.rank.value)
+ if info.sparse else s
+ for (s, info)
+ in zip(serialized_list, sparse_info_list)]
return tensors if received_sequence else tensors[0]
@@ -345,7 +441,8 @@ def _enqueue(queue, tensor_list, threads, enqueue_many):
def batch(tensor_list, batch_size, num_threads=1, capacity=32,
- enqueue_many=False, shapes=None, shared_name=None, name=None):
+ enqueue_many=False, shapes=None,
+ shared_name=None, name=None):
"""Creates batches of tensors in `tensor_list`.
This function is implemented using a queue. A `QueueRunner` for the
@@ -394,7 +491,7 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
"""
with ops.op_scope(tensor_list, name, "batch") as name:
tensor_list = _validate(tensor_list)
- tensor_list, is_sparse, sparse_dtypes = _serialize_sparse_tensors(
+ (tensor_list, sparse_info) = _serialize_sparse_tensors(
tensor_list, enqueue_many)
types = _dtypes([tensor_list])
shapes = _shapes([tensor_list], shapes, enqueue_many)
@@ -407,7 +504,7 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity))
dequeued = queue.dequeue_many(batch_size, name=name)
- dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+ dequeued = _deserialize_sparse_tensors(dequeued, sparse_info)
return dequeued
@@ -478,8 +575,8 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
"""
with ops.op_scope(_flatten(tensor_list_list), name, "batch_join") as name:
tensor_list_list = _validate_join(tensor_list_list)
- tensor_list_list, is_sparse, sparse_dtypes = (
- _serialize_sparse_tensors_join(tensor_list_list, enqueue_many))
+ tensor_list_list, sparse_info = _serialize_sparse_tensors_join(
+ tensor_list_list, enqueue_many)
types = _dtypes(tensor_list_list)
shapes = _shapes(tensor_list_list, shapes, enqueue_many)
# TODO(josh11b,mrry): Switch to BatchQueue once it is written.
@@ -491,7 +588,7 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity))
dequeued = queue.dequeue_many(batch_size, name=name)
- dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+ dequeued = _deserialize_sparse_tensors(dequeued, sparse_info)
return dequeued
@@ -567,7 +664,7 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
"""
with ops.op_scope(tensor_list, name, "shuffle_batch") as name:
tensor_list = _validate(tensor_list)
- tensor_list, is_sparse, sparse_dtypes = _serialize_sparse_tensors(
+ tensor_list, sparse_info = _serialize_sparse_tensors(
tensor_list, enqueue_many)
types = _dtypes([tensor_list])
shapes = _shapes([tensor_list], shapes, enqueue_many)
@@ -586,7 +683,7 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
logging_ops.scalar_summary(summary_name, full)
dequeued = queue.dequeue_many(batch_size, name=name)
- dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+ dequeued = _deserialize_sparse_tensors(dequeued, sparse_info)
return dequeued
@@ -652,8 +749,8 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity,
with ops.op_scope(
_flatten(tensor_list_list), name, "shuffle_batch_join") as name:
tensor_list_list = _validate_join(tensor_list_list)
- tensor_list_list, is_sparse, sparse_dtypes = (
- _serialize_sparse_tensors_join(tensor_list_list, enqueue_many))
+ tensor_list_list, sparse_info = _serialize_sparse_tensors_join(
+ tensor_list_list, enqueue_many)
types = _dtypes(tensor_list_list)
shapes = _shapes(tensor_list_list, shapes, enqueue_many)
queue = data_flow_ops.RandomShuffleQueue(
@@ -671,5 +768,5 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity,
logging_ops.scalar_summary(summary_name, full)
dequeued = queue.dequeue_many(batch_size, name=name)
- dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+ dequeued = _deserialize_sparse_tensors(dequeued, sparse_info)
return dequeued
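[Editor's note] The tests added below exercise the new `input_producer`; for reference, a minimal standalone sketch, assuming the op is exposed as `tf.train.input_producer` as those tests imply (the input values are illustrative).

    import tensorflow as tf

    rows = [[1, 2, 3, 4], [5, 6, 7, 8]]
    queue = tf.train.input_producer(rows, num_epochs=1, shuffle=False)
    row = queue.dequeue()  # yields one row of `rows` per call

    with tf.Session() as sess:
      tf.initialize_all_variables().run()  # num_epochs adds a counter variable
      threads = tf.train.start_queue_runners(sess=sess)
      print(sess.run(row))
      for t in threads:
        t.join()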
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index 14c31442dd..b265c6e3c4 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -69,6 +69,60 @@ class LimitEpochsTest(tf.test.TestCase):
love_me_two_times.eval()
+class InputProducerTest(tf.test.TestCase):
+
+ def testNoShuffle(self):
+ with self.test_session():
+ input_tensor = [[1, 2, 3, 4],
+ [5, 6, 7, 8],
+ [9, 10, 11, 12]]
+ num_epochs = 2
+ queue = tf.train.input_producer(
+ input_tensor, num_epochs=num_epochs, shuffle=False)
+ dequeue_many = queue.dequeue_many(len(input_tensor) * num_epochs)
+ dequeue = queue.dequeue()
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ # No randomness, so just see repeated copies of the input.
+ self.assertAllEqual(input_tensor * num_epochs, dequeue_many.eval())
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ dequeue.eval()
+ for thread in threads:
+ thread.join()
+
+ def testNoShapeInference(self):
+ with self.test_session():
+ # Disable shape inference for the input.
+ input_value = [[1, 2, 3, 4],
+ [5, 6, 7, 8],
+ [9, 10, 11, 12]]
+ input_tensor = tf.placeholder_with_default(input_value, shape=None)
+ num_epochs = 2
+ queue = tf.train.input_producer(
+ input_tensor, element_shape=[4], num_epochs=num_epochs, shuffle=False)
+ dequeue_many = queue.dequeue_many(len(input_value) * num_epochs)
+ dequeue = queue.dequeue()
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ # No randomness, so just see repeated copies of the input.
+ self.assertAllEqual(input_value * num_epochs, dequeue_many.eval())
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ dequeue.eval()
+ for thread in threads:
+ thread.join()
+
+ def testShapeError(self):
+ input_tensor = tf.placeholder(tf.float32, None)
+ with self.assertRaisesRegexp(ValueError, "fully defined shape"):
+ _ = tf.train.input_producer(input_tensor)
+
+
class StringInputProducerTest(tf.test.TestCase):
def testNoShuffle(self):
diff --git a/tensorflow/python/training/summary_io.py b/tensorflow/python/training/summary_io.py
index 1257230df9..ff92008872 100644
--- a/tensorflow/python/training/summary_io.py
+++ b/tensorflow/python/training/summary_io.py
@@ -25,11 +25,14 @@ import time
import six
+from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import summary_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import ops
from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
from tensorflow.python.util import compat
@@ -53,7 +56,8 @@ class SummaryWriter(object):
@@close
"""
- def __init__(self, logdir, graph_def=None, max_queue=10, flush_secs=120):
+ def __init__(self, logdir, graph=None, max_queue=10, flush_secs=120,
+ graph_def=None):
"""Creates a `SummaryWriter` and an event file.
On construction the summary writer creates a new event file in `logdir`.
@@ -61,7 +65,7 @@ class SummaryWriter(object):
call one of the following functions: `add_summary()`, `add_session_log()`,
`add_event()`, or `add_graph()`.
- If you pass a `graph_def` protocol buffer to the constructor it is added to
+ If you pass a `Graph` to the constructor, it is added to
the event file. (This is equivalent to calling `add_graph()` later).
TensorBoard will pick the graph from the file and display it graphically so
@@ -72,8 +76,8 @@ class SummaryWriter(object):
...create a graph...
# Launch the graph in a session.
sess = tf.Session()
- # Create a summary writer, add the 'graph_def' to the event file.
- writer = tf.train.SummaryWriter(<some-directory>, sess.graph_def)
+ # Create a summary writer, add the 'graph' to the event file.
+ writer = tf.train.SummaryWriter(<some-directory>, sess.graph)
```
The other arguments to the constructor control the asynchronous writes to
@@ -86,10 +90,11 @@ class SummaryWriter(object):
Args:
logdir: A string. Directory where event file will be written.
- graph_def: A `GraphDef` protocol buffer.
+ graph: A `Graph` object, such as `sess.graph`.
max_queue: Integer. Size of the queue for pending events and summaries.
flush_secs: Number. How often, in seconds, to flush the
pending events and summaries to disk.
+ graph_def: DEPRECATED: Use the `graph` argument instead.
"""
self._logdir = logdir
if not gfile.IsDirectory(self._logdir):
@@ -100,8 +105,9 @@ class SummaryWriter(object):
self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
flush_secs)
self._worker.start()
- if graph_def is not None:
- self.add_graph(graph_def)
+ if graph is not None or graph_def is not None:
+ # Calling it with both graph and graph_def for backward compatibility.
+ self.add_graph(graph=graph, graph_def=graph_def)
def add_summary(self, summary, global_step=None):
"""Adds a `Summary` protocol buffer to the event file.
@@ -154,22 +160,64 @@ class SummaryWriter(object):
"""
self._event_queue.put(event)
- def add_graph(self, graph_def, global_step=None):
- """Adds a `GraphDef` protocol buffer to the event file.
+ def _add_graph_def(self, graph_def, global_step=None):
+ graph_bytes = graph_def.SerializeToString()
+ event = event_pb2.Event(wall_time=time.time(), graph_def=graph_bytes)
+ if global_step is not None:
+ event.step = int(global_step)
+ self._event_queue.put(event)
+
+ def add_graph(self, graph, global_step=None, graph_def=None):
+ """Adds a `Graph` to the event file.
The graph described by the protocol buffer will be displayed by
TensorBoard. Most users pass a graph in the constructor instead.
Args:
- graph_def: A `GraphDef` protocol buffer.
+ graph: A `Graph` object, such as `sess.graph`.
global_step: Number. Optional global step counter to record with the
graph.
+ graph_def: DEPRECATED. Use the `graph` parameter instead.
+
+ Raises:
+ ValueError: If both graph and graph_def are passed to the method.
"""
- graph_bytes = graph_def.SerializeToString()
- event = event_pb2.Event(wall_time=time.time(), graph_def=graph_bytes)
- if global_step is not None:
- event.step = int(global_step)
- self._event_queue.put(event)
+
+ if graph is not None and graph_def is not None:
+ raise ValueError("Please pass only graph, or graph_def (deprecated), "
+ "but not both.")
+
+ if isinstance(graph, ops.Graph) or isinstance(graph_def, ops.Graph):
+ # The user passed a `Graph`.
+
+ # Check if the user passed it via the graph or the graph_def argument and
+ # correct for that.
+ if not isinstance(graph, ops.Graph):
+ logging.warning("When passing a `Graph` object, please use the `graph`"
+ " named argument instead of `graph_def`.")
+ graph = graph_def
+
+ # Serialize the graph with additional info.
+ true_graph_def = graph.as_graph_def(add_shapes=True)
+ elif (isinstance(graph, graph_pb2.GraphDef)
+ or isinstance(graph_def, graph_pb2.GraphDef)):
+ # The user passed a `GraphDef`.
+ logging.warning("Passing a `GraphDef` to the SummaryWriter is deprecated."
+ " Pass a `Graph` object instead, such as `sess.graph`.")
+
+ # Check if the user passed it via the graph or the graph_def argument and
+ # correct for that.
+ if isinstance(graph, graph_pb2.GraphDef):
+ true_graph_def = graph
+ else:
+ true_graph_def = graph_def
+
+ else:
+ # The user passed neither `Graph`, nor `GraphDef`.
+ raise TypeError("The passed graph must be an instance of `Graph` "
+ "or the deprecated `GraphDef`")
+ # Finally, add the graph_def to the summary writer.
+ self._add_graph_def(true_graph_def, global_step)
def flush(self):
"""Flushes the event file to disk.
diff --git a/tensorflow/python/training/summary_writer_test.py b/tensorflow/python/training/summary_writer_test.py
index 3307c2da12..d1ff95f902 100644
--- a/tensorflow/python/training/summary_writer_test.py
+++ b/tensorflow/python/training/summary_writer_test.py
@@ -49,6 +49,25 @@ class SummaryWriterTestCase(tf.test.TestCase):
def _assertRecent(self, t):
self.assertTrue(abs(t - time.time()) < 5)
+ def _assertEventsWithGraph(self, test_dir, g, has_shapes):
+ rr = self._EventsReader(test_dir)
+
+ # The first event should list the file_version.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals("brain.Event:2", ev.file_version)
+
+ # The next event should have the graph.
+ ev = next(rr)
+ self._assertRecent(ev.wall_time)
+ self.assertEquals(0, ev.step)
+ ev_graph = tf.GraphDef()
+ ev_graph.ParseFromString(ev.graph_def)
+ self.assertProtoEquals(g.as_graph_def(add_shapes=has_shapes), ev_graph)
+
+ # We should be done.
+ self.assertRaises(StopIteration, lambda: next(rr))
+
def testAddingSummaryAndGraph(self):
test_dir = self._CleanTestDir("basics")
sw = tf.train.SummaryWriter(test_dir)
@@ -105,30 +124,54 @@ class SummaryWriterTestCase(tf.test.TestCase):
# We should be done.
self.assertRaises(StopIteration, lambda: next(rr))
- def testInitializingWithGraphDef(self):
- test_dir = self._CleanTestDir("basics_with_graph")
+ def testGraphAsNamed(self):
+ test_dir = self._CleanTestDir("basics_named_graph")
+ with tf.Graph().as_default() as g:
+ tf.constant([12], name="douze")
+ sw = tf.train.SummaryWriter(test_dir, graph=g)
+ sw.close()
+ self._assertEventsWithGraph(test_dir, g, True)
+
+ def testGraphAsPositional(self):
+ test_dir = self._CleanTestDir("basics_positional_graph")
+ with tf.Graph().as_default() as g:
+ tf.constant([12], name="douze")
+ sw = tf.train.SummaryWriter(test_dir, g)
+ sw.close()
+ self._assertEventsWithGraph(test_dir, g, True)
+
+ def testGraphDefAsNamed(self):
+ test_dir = self._CleanTestDir("basics_named_graph_def")
with tf.Graph().as_default() as g:
tf.constant([12], name="douze")
gd = g.as_graph_def()
sw = tf.train.SummaryWriter(test_dir, graph_def=gd)
sw.close()
- rr = self._EventsReader(test_dir)
-
- # The first event should list the file_version.
- ev = next(rr)
- self._assertRecent(ev.wall_time)
- self.assertEquals("brain.Event:2", ev.file_version)
+ self._assertEventsWithGraph(test_dir, g, False)
- # The next event should have the graph.
- ev = next(rr)
- self._assertRecent(ev.wall_time)
- self.assertEquals(0, ev.step)
- ev_graph = tf.GraphDef()
- ev_graph.ParseFromString(ev.graph_def)
- self.assertProtoEquals(gd, ev_graph)
+ def testGraphDefAsPositional(self):
+ test_dir = self._CleanTestDir("basics_positional_graph_def")
+ with tf.Graph().as_default() as g:
+ tf.constant([12], name="douze")
+ gd = g.as_graph_def()
+ sw = tf.train.SummaryWriter(test_dir, gd)
+ sw.close()
+ self._assertEventsWithGraph(test_dir, g, False)
+
+ def testGraphAndGraphDef(self):
+ with self.assertRaises(ValueError):
+ test_dir = self._CleanTestDir("basics_graph_and_graph_def")
+ with tf.Graph().as_default() as g:
+ tf.constant([12], name="douze")
+ gd = g.as_graph_def()
+ sw = tf.train.SummaryWriter(test_dir, graph=g, graph_def=gd)
+ sw.close()
- # We should be done.
- self.assertRaises(StopIteration, lambda: next(rr))
+ def testNeitherGraphNorGraphDef(self):
+ with self.assertRaises(TypeError):
+ test_dir = self._CleanTestDir("basics_string_instead_of_graph")
+ sw = tf.train.SummaryWriter(test_dir, "string instead of graph object")
+ sw.close()
# Checks that values returned from session Run() calls are added correctly to
# summaries. These are numpy types so we need to check they fit in the
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index 7df435fc65..7396627f7b 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -844,7 +844,7 @@ class SVSummaryThread(coordinator.LooperThread):
self._sess = sess
def run_loop(self):
- if self._sv.global_step:
+ if self._sv.global_step is not None:
summary_strs, global_step = self._sess.run([self._sv.summary_op,
self._sv.global_step])
else:
@@ -912,7 +912,7 @@ class SVTimerCheckpointThread(coordinator.LooperThread):
def run_loop(self):
self._sv.saver.save(self._sess, self._sv.save_path,
global_step=self._sv.global_step)
- if self._sv.summary_writer and self._sv.global_step:
+ if self._sv.summary_writer and self._sv.global_step is not None:
current_step = training_util.global_step(self._sess, self._sv.global_step)
self._sv.summary_writer.add_session_log(
SessionLog(status=SessionLog.CHECKPOINT,
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 94475817e0..1f1d427c45 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -50,6 +50,7 @@ namespace perftools {
namespace gputools {
class Stream;
+class ScratchAllocator;
template <typename ElemT>
class DeviceMemory;
@@ -880,14 +881,14 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
- int batch_count) = 0;
+ int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, double alpha,
const port::ArraySlice<DeviceMemory<double> *> &a, int lda,
const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta,
const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
- int batch_count) = 0;
+ int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, std::complex<float> alpha,
@@ -895,7 +896,7 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
std::complex<float> beta,
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
- int batch_count) = 0;
+ int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, std::complex<double> alpha,
@@ -903,7 +904,7 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
std::complex<double> beta,
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
- int batch_count) = 0;
+ int batch_count, ScratchAllocator *scratch_allocator) = 0;
// Computes a matrix-matrix product where one input matrix is Hermitian:
//
@@ -1140,7 +1141,7 @@ class BlasSupport {
// Macro used to quickly declare overrides for abstract virtuals in the
// BlasSupport base class.
-#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \
+#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \
bool DoBlasAsum(Stream *stream, uint64 elem_count, \
const DeviceMemory<float> &x, int incx, \
DeviceMemory<float> *result) override; \
@@ -1626,14 +1627,14 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<float> *> &a, int lda, \
const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, \
const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, \
- int batch_count) override; \
+ int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, double alpha, \
const port::ArraySlice<DeviceMemory<double> *> &a, int lda, \
const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \
const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, \
- int batch_count) override; \
+ int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
@@ -1641,7 +1642,7 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \
std::complex<float> beta, \
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \
- int batch_count) override; \
+ int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
@@ -1650,7 +1651,7 @@ class BlasSupport {
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, \
int ldb, std::complex<double> beta, \
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \
- int ldc, int batch_count) override; \
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \
bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \
uint64 m, uint64 n, std::complex<float> alpha, \
const DeviceMemory<std::complex<float>> &a, int lda, \
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 19ad12d28b..fb21baf9bf 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <complex>
+#include "third_party/gpus/cuda/include/cublas_v2.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
@@ -34,8 +35,8 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"
-#include "third_party/gpus/cuda/include/cublas_v2.h"
namespace perftools {
namespace gputools {
@@ -1707,37 +1708,64 @@ template <typename T, typename FuncT>
port::Status CUDABlas::DoBlasGemmBatchedInternal(
FuncT cublas_func, Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
- const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
- const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
- const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
- int batch_count) {
- std::vector<T *> a_ptr_vec, b_ptr_vec, c_ptr_vec;
+ const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
+ const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
+ T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
+ std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
for (int i = 0; i < batch_count; ++i) {
- a_ptr_vec.push_back(static_cast<T *>(a_array[i]->opaque()));
- b_ptr_vec.push_back(static_cast<T *>(b_array[i]->opaque()));
- c_ptr_vec.push_back(static_cast<T *>(c_array[i]->opaque()));
+ a_raw_ptrs.push_back(static_cast<T *>(a_ptrs_to_wrappers[i]->opaque()));
+ b_raw_ptrs.push_back(static_cast<T *>(b_ptrs_to_wrappers[i]->opaque()));
+ c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
}
typedef typename CUDAComplexT<T>::type CUDA_T;
- SE_ASSIGN_OR_RETURN(
- std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> a_ptr_array,
- stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
- SE_ASSIGN_OR_RETURN(
- std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> b_ptr_array,
- stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
- SE_ASSIGN_OR_RETURN(
- std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> c_ptr_array,
- stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
-
- if (!stream->ThenMemcpy(a_ptr_array->mutable_device_memory(),
- a_ptr_vec.data(), batch_count * sizeof(T *))
- .ok() ||
- !stream->ThenMemcpy(b_ptr_array->mutable_device_memory(),
- b_ptr_vec.data(), batch_count * sizeof(T *))
- .ok() ||
- !stream->ThenMemcpy(c_ptr_array->mutable_device_memory(),
- c_ptr_vec.data(), batch_count * sizeof(T *))
- .ok()) {
+
+ const size_t size = batch_count * sizeof(CUDA_T *);
+
+ // Device-side copy of pointers to matrices.
+ DeviceMemory<CUDA_T *> a;
+ DeviceMemory<CUDA_T *> b;
+ DeviceMemory<CUDA_T *> c;
+
+ // If temporary space is allocated for device-side copies of pointers to
+ // matrices, that temporary space should not be freed until this function
+ // returns. Although the values for these unique_ptrs are not set here, they
+ // are declared at this scope so they will be destroyed when the function
+ // returns.
+ //
+ // If a scratch allocator is provided, these pointers will not be used at all.
+ std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> a_temporary;
+ std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> b_temporary;
+ std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> c_temporary;
+
+ // Decide how to allocate device-side copy of pointers to matrices based on
+ // whether a scratch allocator was passed.
+ if (scratch_allocator != nullptr) {
+ SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> a_bytes,
+ scratch_allocator->AllocateBytes(stream, size));
+ SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> b_bytes,
+ scratch_allocator->AllocateBytes(stream, size));
+ SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> c_bytes,
+ scratch_allocator->AllocateBytes(stream, size));
+ a = DeviceMemory<CUDA_T *>(a_bytes);
+ b = DeviceMemory<CUDA_T *>(b_bytes);
+ c = DeviceMemory<CUDA_T *>(c_bytes);
+ } else {
+ SE_ASSIGN_OR_RETURN(a_temporary,
+ stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
+ SE_ASSIGN_OR_RETURN(b_temporary,
+ stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
+ SE_ASSIGN_OR_RETURN(c_temporary,
+ stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
+ a = DeviceMemory<CUDA_T *>(*a_temporary->mutable_device_memory());
+ b = DeviceMemory<CUDA_T *>(*b_temporary->mutable_device_memory());
+ c = DeviceMemory<CUDA_T *>(*c_temporary->mutable_device_memory());
+ }
+
+ if (!stream->ThenMemcpy(&a, a_raw_ptrs.data(), size).ok() ||
+ !stream->ThenMemcpy(&b, b_raw_ptrs.data(), size).ok() ||
+ !stream->ThenMemcpy(&c, c_raw_ptrs.data(), size).ok()) {
return port::Status(port::error::INTERNAL,
"failed to copy memory from host to device in "
"CUDABlas::DoBlasGemmBatched");
@@ -1746,13 +1774,9 @@ port::Status CUDABlas::DoBlasGemmBatchedInternal(
bool ok = DoBlasInternal(
cublas_func, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
- CUDAComplex(&alpha),
- const_cast<const CUDA_T **>(CUDAMemory(a_ptr_array->device_memory())),
- lda,
- const_cast<const CUDA_T **>(CUDAMemory(b_ptr_array->device_memory())),
- ldb, CUDAComplex(&beta),
- const_cast<CUDA_T **>(CUDAMemory(c_ptr_array->device_memory())), ldc,
- batch_count);
+ CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda,
+ const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta),
+ const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count);
if (ok) {
return port::Status::OK();
@@ -1767,10 +1791,11 @@ bool CUDABlas::DoBlasGemmBatched(
const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
- int batch_count) {
+ int batch_count, ScratchAllocator *scratch_allocator) {
SE_RETURN_STATUS_AS_BOOL(DoBlasGemmBatchedInternal(
dynload::cublasSgemmBatched, stream, transa, transb, m, n, k, alpha,
- a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count));
+ a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
+ scratch_allocator));
}
bool CUDABlas::DoBlasGemmBatched(
@@ -1779,10 +1804,11 @@ bool CUDABlas::DoBlasGemmBatched(
const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
- int ldc, int batch_count) {
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
SE_RETURN_STATUS_AS_BOOL(DoBlasGemmBatchedInternal(
dynload::cublasDgemmBatched, stream, transa, transb, m, n, k, alpha,
- a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count));
+ a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
+ scratch_allocator));
}
bool CUDABlas::DoBlasGemmBatched(
@@ -1793,10 +1819,11 @@ bool CUDABlas::DoBlasGemmBatched(
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
int ldb, std::complex<float> beta,
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
- int ldc, int batch_count) {
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
SE_RETURN_STATUS_AS_BOOL(DoBlasGemmBatchedInternal(
dynload::cublasCgemmBatched, stream, transa, transb, m, n, k, alpha,
- a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count));
+ a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
+ scratch_allocator));
}
bool CUDABlas::DoBlasGemmBatched(
@@ -1807,10 +1834,11 @@ bool CUDABlas::DoBlasGemmBatched(
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
int ldb, std::complex<double> beta,
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
- int ldc, int batch_count) {
+ int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
SE_RETURN_STATUS_AS_BOOL(DoBlasGemmBatchedInternal(
dynload::cublasZgemmBatched, stream, transa, transb, m, n, k, alpha,
- a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count));
+ a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count,
+ scratch_allocator));
}
bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 046b7253e4..d5b949f7d1 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -93,7 +93,7 @@ class CUDABlas : public blas::BlasSupport {
const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
- int batch_count);
+ int batch_count, ScratchAllocator *scratch_allocator);
// mutex that guards the cuBLAS handle for this device.
mutex mu_;
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 587896a2ab..cee781f77b 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -2986,6 +2986,17 @@ Stream &Stream::ThenBlasGemmBatched(
int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
int batch_count) {
+ return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
+ b, ldb, beta, c, ldc, batch_count,
+ nullptr);
+}
+
+Stream &Stream::ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
+ int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
@@ -2993,9 +3004,12 @@ Stream &Stream::ThenBlasGemmBatched(
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
const port::ArraySlice<DeviceMemory<float> *> &, int,
const port::ArraySlice<DeviceMemory<float> *> &, int, float,
- const port::ArraySlice<DeviceMemory<float> *> &, int, int> impl;
+ const port::ArraySlice<DeviceMemory<float> *> &, int, int,
+ ScratchAllocator *>
+ impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
- k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
+ k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
+ scratch_allocator);
}
Stream &Stream::ThenBlasGemmBatched(
@@ -3004,6 +3018,17 @@ Stream &Stream::ThenBlasGemmBatched(
int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
int batch_count) {
+ return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
+ b, ldb, beta, c, ldc, batch_count,
+ nullptr);
+}
+
+Stream &Stream::ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
+ int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
+ double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
@@ -3011,9 +3036,12 @@ Stream &Stream::ThenBlasGemmBatched(
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
const port::ArraySlice<DeviceMemory<double> *> &, int,
const port::ArraySlice<DeviceMemory<double> *> &, int, double,
- const port::ArraySlice<DeviceMemory<double> *> &, int, int> impl;
+ const port::ArraySlice<DeviceMemory<double> *> &, int, int,
+ ScratchAllocator *>
+ impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
- k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
+ k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
+ scratch_allocator);
}
Stream &Stream::ThenBlasGemmBatched(
@@ -3024,6 +3052,19 @@ Stream &Stream::ThenBlasGemmBatched(
std::complex<float> beta,
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
int batch_count) {
+ return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
+ b, ldb, beta, c, ldc, batch_count,
+ nullptr);
+}
+
+Stream &Stream::ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
+ std::complex<float> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
@@ -3035,9 +3076,11 @@ Stream &Stream::ThenBlasGemmBatched(
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
int, std::complex<float>,
const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
- int, int> impl;
+ int, int, ScratchAllocator *>
+ impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
- k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
+ k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
+ scratch_allocator);
}
Stream &Stream::ThenBlasGemmBatched(
@@ -3048,6 +3091,19 @@ Stream &Stream::ThenBlasGemmBatched(
std::complex<double> beta,
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
int batch_count) {
+ return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
+ b, ldb, beta, c, ldc, batch_count,
+ nullptr);
+}
+
+Stream &Stream::ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
+ std::complex<double> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
@@ -3059,9 +3115,11 @@ Stream &Stream::ThenBlasGemmBatched(
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
int, std::complex<double>,
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
- int, int> impl;
+ int, int, ScratchAllocator *>
+ impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
- k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count);
+ k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
+ scratch_allocator);
}
Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index d91c62ca26..599146f49b 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -944,6 +944,34 @@ class Stream {
std::complex<double> beta,
const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
int batch_count);
+ Stream &ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
+ int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
+ float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator);
+ Stream &ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
+ int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
+ double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator);
+ Stream &ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<float> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
+ std::complex<float> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator);
+ Stream &ThenBlasGemmBatchedWithScratch(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, std::complex<double> alpha,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
+ std::complex<double> beta,
+ const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
+ int batch_count, ScratchAllocator *scratch_allocator);
// See BlasSupport::DoBlasHemm.
Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/tf-event-dashboard.html b/tensorflow/tensorboard/components/tf-event-dashboard/tf-event-dashboard.html
index d77af121d7..6bf32a4ccd 100644
--- a/tensorflow/tensorboard/components/tf-event-dashboard/tf-event-dashboard.html
+++ b/tensorflow/tensorboard/components/tf-event-dashboard/tf-event-dashboard.html
@@ -115,7 +115,7 @@ The #center div contains tf-charts embedded inside tf-collapsable-panes.
<p>
Maybe data hasn't loaded yet, or maybe you need
to add some <code>tf.scalar_summary</code> ops to your graph, and
- serialize them using the <code>tf.training.summary_io.SummaryWriter</code>.
+ serialize them using the <code>tf.train.SummaryWriter</code>.
</p>
</div>
</template>
diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/tf-run-selector.html b/tensorflow/tensorboard/components/tf-event-dashboard/tf-run-selector.html
index 563b2dd194..c69446edc4 100644
--- a/tensorflow/tensorboard/components/tf-event-dashboard/tf-run-selector.html
+++ b/tensorflow/tensorboard/components/tf-event-dashboard/tf-run-selector.html
@@ -75,7 +75,6 @@ Properties out:
display: flex;
flex-grow: 1;
flex-shrink: 1;
- height: 0px; /* hackhack So the flex-grow takes over and gives it space */
}
.x-button {
font-size: 13px;
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
index 1dba760aae..8dca63c9ab 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
@@ -515,6 +515,13 @@ function addEdges(h: Hierarchy, graph: SlimGraph,
let sourceAncestorIndex = getPath(graph.nodes[baseEdge.v], sourcePath);
let destAncestorIndex = getPath(graph.nodes[baseEdge.w], destPath);
+ // If the hierarchical path cannot be found for either endpoint, then we
+ // cannot create the edge. This happens, for example, when a node has a
+ // control dependency on a summary node, which is embedded.
+ if (sourceAncestorIndex === -1 || destAncestorIndex === -1) {
+ return;
+ }
+
// Find the lowest shared ancestor between source and dest by looking for
// the highest nodes that differ between their ancestor paths.
while (sourcePath[sourceAncestorIndex] === destPath[destAncestorIndex]) {
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts
index 0d9e5b53bf..b2f4fd1d7f 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts
@@ -87,7 +87,7 @@ export const PARAMS = {
*/
labelHeight: 20,
/** X-space between each extracted node and the core graph. */
- extractXOffset: 50,
+ extractXOffset: 15,
/** Y-space between each extracted node. */
extractYOffset: 20
},
@@ -486,9 +486,24 @@ function layoutMetanode(renderNodeInfo: render.RenderGroupNodeInfo): void {
return height + yOffset + child.height;
}, 0);
+ // Compute the total padding between the core graph, in-extract and
+ // out-extract boxes.
+ let numParts = 0;
+ if (renderNodeInfo.isolatedInExtract.length > 0) {
+ numParts++;
+ }
+ if (renderNodeInfo.isolatedOutExtract.length > 0) {
+ numParts++;
+ }
+ if (renderNodeInfo.coreGraph.nodeCount() > 0) {
+ numParts++;
+ }
+ let offset = PARAMS.subscene.meta.extractXOffset;
+ let padding = numParts <= 1 ? 0 : (numParts <= 2 ? offset : 2 * offset);
+
// Add the in-extract and out-extract width to the core box width.
renderNodeInfo.coreBox.width += renderNodeInfo.inExtractBox.width +
- renderNodeInfo.outExtractBox.width;
+ renderNodeInfo.outExtractBox.width + padding;
renderNodeInfo.coreBox.height =
params.labelHeight +
Math.max(
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
index fa0ee99d19..dd43e650d3 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
@@ -964,8 +964,6 @@ export class RenderNodeInfo {
/** Label vertical offset from the center of node shape */
labelOffset: number;
- /** X-space between each extracted node and the core graph. */
- extractXOffset: number;
/** Rectangle radius (for making rounded rectangle) */
radius: number;
@@ -1027,7 +1025,6 @@ export class RenderNodeInfo {
// Params for node box.
this.labelOffset = 0;
- this.extractXOffset = 0;
this.radius = 0;
// Params for expanded node
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts
index b6eb3f7d81..1b6cb3a58c 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts
@@ -321,15 +321,18 @@ function position(sceneGroup, renderNode: render.RenderGroupNodeInfo) {
// in-extract
let hasInExtract = renderNode.isolatedInExtract.length > 0;
+ let hasOutExtract = renderNode.isolatedOutExtract.length > 0;
+
if (hasInExtract) {
+ let offset = layout.PARAMS.subscene.meta.extractXOffset;
let inExtractX = renderNode.coreBox.width -
- renderNode.inExtractBox.width / 2 - renderNode.outExtractBox.width;
+ renderNode.inExtractBox.width / 2 - renderNode.outExtractBox.width -
+ (hasOutExtract ? offset : 0);
translate(selectChild(sceneGroup, "g", Class.Scene.INEXTRACT),
inExtractX, yTranslate);
}
// out-extract
- let hasOutExtract = renderNode.isolatedOutExtract.length > 0;
if (hasOutExtract) {
let outExtractX = renderNode.coreBox.width -
renderNode.outExtractBox.width / 2;
diff --git a/tensorflow/tensorboard/components/tf-graph-dashboard/tf-graph-dashboard.html b/tensorflow/tensorboard/components/tf-graph-dashboard/tf-graph-dashboard.html
index d26bf2e8f4..e0c0184864 100644
--- a/tensorflow/tensorboard/components/tf-graph-dashboard/tf-graph-dashboard.html
+++ b/tensorflow/tensorboard/components/tf-graph-dashboard/tf-graph-dashboard.html
@@ -37,7 +37,7 @@ by default. The user can select a different run from a dropdown menu.
</p>
<p>
To store a graph, create a
- <code>tf.python.training.summary_io.SummaryWriter</code>
+ <code>tf.train.SummaryWriter</code>
and pass the graph either via the constructor, or by calling its
<code>add_graph()</code> method.
</p>
diff --git a/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html b/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html
index d715925d2c..c23e358bf2 100644
--- a/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html
+++ b/tensorflow/tensorboard/components/tf-graph-info/tf-node-info.html
@@ -195,7 +195,7 @@
</paper-icon-button>
Control dependencies
</div>
- <iron-collapse opened="{{_openedControlPred}}">
+ <iron-collapse opened="{{_openedControlPred}}" no-animation>
<template is="dom-if" if="{{_openedControlPred}}" restamp="true">
<iron-list class="sub-list" items="[[_predecessors.control]]">
<template>
@@ -246,7 +246,7 @@
</paper-icon-button>
Control dependencies
</div>
- <iron-collapse opened="{{_openedControlSucc}}">
+ <iron-collapse opened="{{_openedControlSucc}}" no-animation>
<template is="dom-if" if="{{_openedControlSucc}}" restamp="true">
<iron-list class="sub-list" items="[[_successors.control]]">
<template>
diff --git a/tensorflow/tensorboard/components/tf-histogram-dashboard/tf-histogram-dashboard.html b/tensorflow/tensorboard/components/tf-histogram-dashboard/tf-histogram-dashboard.html
index 8f68791c51..d7c631ebdb 100644
--- a/tensorflow/tensorboard/components/tf-histogram-dashboard/tf-histogram-dashboard.html
+++ b/tensorflow/tensorboard/components/tf-histogram-dashboard/tf-histogram-dashboard.html
@@ -109,7 +109,7 @@ The #center div contains tf-charts embedded inside tf-collapsable-panes.
<p>
Maybe data hasn't loaded yet, or maybe you need
to add some <code>tf.histogram_summary</code> ops to your graph, and
- serialize them using the <code>tf.training.summary_io.SummaryWriter</code>.
+ serialize them using the <code>tf.train.SummaryWriter</code>.
</p>
</div>
</template>
diff --git a/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-dashboard.html b/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-dashboard.html
index 57c41abe95..18a7b9f708 100644
--- a/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-dashboard.html
+++ b/tensorflow/tensorboard/components/tf-image-dashboard/tf-image-dashboard.html
@@ -43,7 +43,7 @@ mechanism for loading older images rather than always getting the most recent on
<p>
Maybe data hasn't loaded yet, or maybe you need
to add some <code>tf.image_summary</code> ops to your graph, and
- serialize them using the <code>tf.training.summary_io.SummaryWriter</code>.
+ serialize them using the <code>tf.train.SummaryWriter</code>.
</p>
</div>
</template>
diff --git a/tensorflow/tensorboard/dist/tf-tensorboard.html b/tensorflow/tensorboard/dist/tf-tensorboard.html
index 492a5d45ce..31c62345ac 100644
--- a/tensorflow/tensorboard/dist/tf-tensorboard.html
+++ b/tensorflow/tensorboard/dist/tf-tensorboard.html
@@ -2086,7 +2086,7 @@ var TF;
<p>
Maybe data hasn't loaded yet, or maybe you need
to add some <code>tf.scalar_summary</code> ops to your graph, and
- serialize them using the <code>tf.training.summary_io.SummaryWriter</code>.
+ serialize them using the <code>tf.train.SummaryWriter</code>.
</p>
</div>
</template>
@@ -2201,7 +2201,7 @@ var TF;
<p>
Maybe data hasn't loaded yet, or maybe you need
to add some <code>tf.histogram_summary</code> ops to your graph, and
- serialize them using the <code>tf.training.summary_io.SummaryWriter</code>.
+ serialize them using the <code>tf.train.SummaryWriter</code>.
</p>
</div>
</template>
@@ -2476,7 +2476,7 @@ var TF;
<p>
Maybe data hasn't loaded yet, or maybe you need
to add some <code>tf.image_summary</code> ops to your graph, and
- serialize them using the <code>tf.training.summary_io.SummaryWriter</code>.
+ serialize them using the <code>tf.train.SummaryWriter</code>.
</p>
</div>
</template>
@@ -11197,7 +11197,7 @@ function convertToHumanReadable(value, units, unitIndex) {
</p>
<p>
To store a graph, create a
- <code>tf.python.training.summary_io.SummaryWriter</code>
+ <code>tf.train.SummaryWriter</code>
and pass the graph either via the constructor, or by calling its
<code>add_graph()</code> method.
</p>
@@ -11401,4 +11401,4 @@ Polymer({
});
</script>
</dom-module>
-</body></html> \ No newline at end of file
+</body></html>
diff --git a/tensorflow/tensorboard/lib/js/colorScale/colorScale.ts b/tensorflow/tensorboard/lib/js/colorScale/colorScale.ts
new file mode 100644
index 0000000000..f57551dbe0
--- /dev/null
+++ b/tensorflow/tensorboard/lib/js/colorScale/colorScale.ts
@@ -0,0 +1,148 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Each color scale is initialized with a configurable number of base hues.
+// There are also several palettes available.
+// TF.palettes.googleStandard, TF.palettes.googleColorBlind,
+// TF.palettes.googleCool, TF.palettes.googleWarm, TF.palettes.constantValue
+// Each string is hashed to an integer,
+// then mapped to one of the base hues above.
+// If there is a collision, the color that is later in an alphabetical sort
+// gets nudged a little darker or lighter to disambiguate.
+// I would call it mostly stable, in that the same array of strings will
+// always return the same colors, but the same individual string may
+// shift a little depending on its peers.
+//
+// runs = ["train", "test", "test1", "test2"]
+// ccs = new TF.ColorScale(12, TF.palettes.googleStandard);
+// ccs.domain(runs);
+// ccs.getColor("train");
+// ccs.getColor("test1");
+
+module TF {
+ export class ColorScale {
+ public numColors: number;
+ public internalColorScale: d3.scale.Linear<string, string>;
+ private buckets: string[][];
+
+ /**
+ * The palette you provide defines your spectrum. The colorscale will
+ * always use the full spectrum you provide. When you define "numColors"
+ * it resamples at regular intervals along the full extent of the spectrum.
+ * Thus you get the maximum distance between hues for the "numColors"
+ * given. This allows the programmer to tweak the algorithm depending on
+ * how big your expected domain is. If you generally think you're going to
+ * have a small number of elements in the domain, then a small numColors
+ * will be serviceable. With large domains, a small numColors would produce
+ * too many hash collisions, so you'd want to bump it up to the threshold
+ * of human perception (probably around 14 or 18).
+ *
+ * @param {number} [numColors=12] - The number of base colors you want
+ *     in the palette. The smaller the number, the more hash collisions
+ *     you will have, but the more differentiable the base colors will be.
+ *
+ * @param {string[]} [palette=TF.palettes.googleColorBlind] - The color
+ * palette you want as an Array of hex strings. Note, the
+ * length of the array in this palette is independent of the
+ * param numColors above. The scale will interpolate to
+ * create the proper "numColors" given in the first param.
+ *
+ */
+ constructor(numColors = 12, palette: string[] = TF.palettes.googleColorBlind) {
+ this.numColors = numColors;
+ this.domain([]);
+
+ if (palette.length < 2) {
+ throw new Error("Not enough colors in palette. Must be more than one.")
+ }
+
+ var k = (this.numColors - 1) / (palette.length - 1);
+ this.internalColorScale = d3.scale.linear<string>()
+ .domain(d3.range(palette.length).map((i) => i * k))
+ .range(palette);
+ }
+
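+ // Hash a string to a bucket index in [0, numColors). This is a simple
+ // shift-and-add string hash (equivalent to hash * 31 + charCode, coerced
+ // to a 32-bit integer by "hash & hash"), reduced with Math.abs(...) %
+ // numColors. Collisions are expected and are resolved later by nudge().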
+ private hash(s: string): number {
+ function h(hash, str) {
+ hash = (hash << 5) - hash + str.charCodeAt(0);
+ return hash & hash;
+ }
+ return Math.abs(Array.prototype.reduce.call(s, h, 0)) % this.numColors;
+ }
+
+
+ /**
+ * Set the domain of strings so we can calculate collisions preemptively.
+ * Can be reset at any point.
+ *
+ * @param {string[]} strings - An array of strings to use as the domain
+ * for your scale.
+ */
+ public domain(strings: string[]) {
+ this.buckets = d3.range(this.numColors).map(() => []);
+ var sortedUniqueKeys = d3.set(strings).values().sort(function(a, b) { return a.localeCompare(b); });
+ sortedUniqueKeys.forEach((s) => this.addToDomain(s));
+ }
+
+ private getBucketForString(s: string) {
+ var bucketIdx = this.hash(s);
+ return this.buckets[bucketIdx];
+ }
+
+ private addToDomain(s: string) {
+ var bucketIdx = this.hash(s);
+ var bucket = this.buckets[bucketIdx];
+ if (bucket.indexOf(s) === -1) {
+ bucket.push(s);
+ }
+ }
+
+ private nudge(color: string, amount: number): any {
+ // If amount is zero, just give back same color
+ if (amount === 0) {
+ return color;
+
+ // For first tick, nudge lighter...
+ } else if (amount === 1) {
+ return d3.hcl(color).brighter(0.6);
+
+ // ..otherwise nudge darker. Darker will approach black, which is visible.
+ } else {
+ return d3.hcl(color).darker((amount - 1) / 2);
+ }
+ }
+
+ /**
+ * Use the color scale to transform an element in the domain into a color.
+ * If there was a hash conflict, the color will be "nudged" darker or lighter so that it is
+ * unique.
+ * @param {string} s The input string to map to a color.
+ * @return {string} The color corresponding to that input string.
+ * @throws Will error if input string is not in the scale's domain.
+ */
+
+ public getColor(s: string): string {
+ var bucket = this.getBucketForString(s);
+ var idx = bucket.indexOf(s);
+ if (idx === -1) {
+ throw new Error("String was not in the domain.");
+ }
+ var color = this.internalColorScale(this.hash(s));
+ return this.nudge(color, idx).toString();
+ }
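+
+ // Example usage (a sketch, assuming a domain has been set as in the module
+ // comment at the top of this file):
+ //   let scale = new TF.ColorScale();
+ //   scale.domain(["test", "train"]);
+ //   scale.getColor("train");    // a stable color string for "train"
+ //   scale.getColor("missing");  // throws: "String was not in the domain."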
+
+ }
+}
diff --git a/tensorflow/tensorboard/lib/js/colorScale/demo/index.html b/tensorflow/tensorboard/lib/js/colorScale/demo/index.html
new file mode 100644
index 0000000000..c3d94da539
--- /dev/null
+++ b/tensorflow/tensorboard/lib/js/colorScale/demo/index.html
@@ -0,0 +1,176 @@
+<!doctype html>
+<meta charset="utf-8">
+<script src="../../../../components/d3/d3.min.js"></script>
+<script src="../palettes.js"></script>
+<script src="../colorScale.js"></script>
+
+<link rel="stylesheet" href="style.css">
+
+<style>
+
+.color-swatch {
+ display: inline-block;
+ height: 20px;
+}
+.stage {
+ margin-top: 40px;
+ margin-bottom: 200px;
+ position: relative;
+}
+.color {
+ position: absolute;
+ margin: 0 0 4px 0;
+}
+.swatch {
+ border-radius: 2px;
+ float: left;
+ width: 15px;
+ height: 15px;
+ margin-right: 10px;
+ margin-top: 8px;
+}
+
+.label {
+ display: inline;
+}
+</style>
+
+<header>
+ <h1 class="trunk">Stable and Unique Colors for Category Labels</h1>
+ <p class="trunk">A method for defining a stable categorical color scale for real-time, changing data.</p>
+
+</header>
+<h3 class="trunk">Base colors</h3>
+<p class="trunk">Each color scale is initialized with a configurable number of base hues. There are 18 shown below. There are also several palettes available.</p>
+<p class="palettes trunk"></p>
+<h3 class="trunk">A sample list of categories</h3>
+<p class="trunk">
+ Each string is hashed to an integer, then mapped to one of the base hues above. If there is a collision, the color that is later in an alphabetical sort gets nudged a little darker or lighter to disambiguate. I would call it <i>mostly</i> stable, in that the same array of strings will always return the same colors, but the same individual string may shift a little depending on its peers.
+</p>
+<p class="stage trunk"></p>
+
+<script type>
+"use strict";
+
+var runs = [
+ "A Midsummer Night's Dream",
+ "All's Well That Ends Well",
+ "Antony and Cleopatra",
+ "As You Like It",
+ "Coriolanus",
+ "Cymbeline",
+ "Hamlet",
+ "Henry IV",
+ "Henry VIII",
+ "Julius Caesar",
+ "King John",
+ "King Lear",
+ "Love's Labour's Lost",
+ "Macbeth",
+ "Measure for Measure",
+ "Much Ado About Nothing",
+ "Othello",
+ "Pericles, Prince of Tyre",
+ "Richard II",
+ "Richard III",
+ "Romeo and Juliet",
+ "The Comedy of Errors",
+ "The Merchant of Venice",
+ "The Merry Wives of Windsor",
+ "The Taming of the Shrew",
+ "The Tempest",
+ "The Two Noble Kinsmen",
+ "The Winter's Tale",
+ "Timon of Athens ",
+ "Titus Andronicus",
+ "Troilus and Cressida ",
+ "Twelfth Night",
+ "Two Gentlemen of Verona"
+];
+
+var palettes = [
+ "googleColorBlind",
+ "googleStandard",
+ "constantValue",
+ "googleWarm",
+ "googleCool"
+];
+
+var stage = d3.select(".stage");
+
+var palettesStage = d3.select(".palettes");
+
+var palette = palettesStage.selectAll(".palette")
+ .data(palettes)
+ .enter().append("div")
+ .attr("class", "palette");
+
+palette.each(function(d) {
+ d3.select(this).append("div").text(d);
+ var ccs = new TF.ColorScale(17, TF.palettes[d]);
+ var colorSwatches = d3.select(this).selectAll(".color-swatch")
+ .data(d3.range(ccs.numColors))
+ .enter().append("div")
+ .attr("class", "color-swatch")
+ .style("width", 100 / ccs.numColors + "%")
+ .style("background-color", (d) => ccs.internalColorScale(d));
+});
+
+var previousRuns = runs.slice(0, 10).concat(["train", "test", "eval"]);
+function ping() {
+ d3.shuffle(previousRuns);
+ previousRuns = previousRuns.slice(0, -Math.ceil(Math.random() * 3));
+ previousRuns = previousRuns.concat(d3.shuffle(runs).slice(0, Math.floor(Math.random() * 6))).sort();
+ previousRuns = d3.set(previousRuns).values().sort();
+ var ccs = new TF.ColorScale();
+ ccs.domain(previousRuns);
+
+ var color = stage.selectAll(".color")
+ .data(previousRuns, (d) => d);
+
+ color
+ .style("opacity", 1)
+ .style("left", 0)
+ .transition()
+ .delay(200)
+ .duration(300)
+ .style("top", (d, i) => i * 25 + "px");
+
+ var colorEnter = color.enter().append("div")
+ .attr("class", "color")
+ .style("left", "-100px")
+ .style("opacity", 0)
+ .style("top", (d, i) => i * 25 + "px");
+
+ colorEnter
+ .transition()
+ .delay(400)
+ .duration(300)
+ .style("left", "0px")
+ .style("opacity", 1);
+
+ color.exit()
+ .transition()
+ .duration(300)
+ .style("left", "100px")
+ .style("opacity", 0)
+ .remove();
+
+ colorEnter.append("div")
+ .attr("class", "swatch");
+
+ color.select(".swatch").style("background-color", (d) => ccs.getColor(d));
+
+ colorEnter.append("div")
+ .attr("class", "label")
+ .text((d) => d);
+
+ stage.transition().duration(300)
+ .style("height", previousRuns.length * 25 + "px")
+}
+
+ping();
+setInterval(ping, 2000);
+
+
+</script>
diff --git a/tensorflow/tensorboard/lib/js/colorScale/demo/style.css b/tensorflow/tensorboard/lib/js/colorScale/demo/style.css
new file mode 100644
index 0000000000..d2fe1dc294
--- /dev/null
+++ b/tensorflow/tensorboard/lib/js/colorScale/demo/style.css
@@ -0,0 +1,74 @@
+body {
+ font-family: roboto, sans-serif;
+}
+header {
+ /*background-color: hsl(0, 0%, 95%);*/
+ border-bottom: solid 1px rgba(0, 0, 0, 0.1);
+ padding: 60px 0;
+ margin: 0 0 40px 0;
+ z-index: 10;
+ position: relative;
+ color: hsla(0, 0%, 0%, 0.7);
+}
+header h1 {
+ font-size: 36px;
+ font-weight: 700;
+ margin: 0 0 12px;
+ line-height: 1.2em;
+}
+header p {
+ font-size: 22px;
+ line-height: 1.6em;
+ font-weight: 300;
+ margin-bottom: 20px;
+ margin-top: 0;
+}
+.byline {
+ font-weight: 400;
+ font-size: 13px;
+ color: rgba(0, 0, 0, 0.5);
+ display: none;
+}
+.byline .date {
+ margin-left: 12px;
+ padding-left: 12px;
+ border-left: solid 1px #ddd;
+}
+/* Text Styles */
+h3 {
+ color: rgba(0, 0, 0, 0.7);
+ margin-top: 40px;
+}
+a {
+ color: black;
+ text-decoration: none;
+ border-bottom: solid 1px black;
+}
+p {
+ font-weight: 400;
+ font-size: 17px;
+ line-height: 1.8;
+ color: rgba(0, 0, 0, 0.7);
+}
+.trunk {
+ margin-left: auto;
+ margin-right: auto;
+ max-width: 600px;
+}
+.page {
+ margin-left: auto;
+ margin-right: auto;
+ max-width: 900px;
+}
+.screen {
+ margin-left: auto;
+ margin-right: auto;
+}
+
+.data-picker {
+ background: white;
+ padding: 5px 0;
+}
+
+.sticky-fixed .data-picker {
+}
diff --git a/tensorflow/tensorboard/lib/js/colorScale/palettes.ts b/tensorflow/tensorboard/lib/js/colorScale/palettes.ts
new file mode 100644
index 0000000000..76d69779a3
--- /dev/null
+++ b/tensorflow/tensorboard/lib/js/colorScale/palettes.ts
@@ -0,0 +1,54 @@
+module TF {
+ export const palettes = {
+ googleStandard: [
+ "#db4437", //google red 500
+ "#ff7043", //deep orange 400
+ "#f4b400", //google yellow 500
+ "#0f9d58", //google green 500
+ "#00796b", //teal 700
+ "#00acc1", //cyan 600
+ "#4285f4", //google blue 500
+ "#5c6bc0", //indigo 400
+ "#ab47bc" //purple 400
+ ],
+ googleCool: [
+ "#9e9d24", //lime 800
+ "#0f9d58", //google green 500
+ "#00796b", //teal 700
+ "#00acc1", //cyan 600
+ "#4285f4", //google blue 500
+ "#5c6bc0", //indigo 400
+ "#607d8b" //blue gray 500
+ ],
+ googleWarm: [
+ "#795548", //brown 500
+ "#ab47bc", //purple 400
+ "#f06292", //pink 300
+ "#c2185b", //pink 700
+ "#db4437", //google red 500
+ "#ff7043", //deep orange 400
+ "#f4b400" //google yellow 700
+ ],
+ googleColorBlind: [
+ "#c53929", //google red 700
+ "#ff7043", //deep orange 400
+ "#f7cb4d", //google yellow 300
+ "#0b8043", //google green 700
+ "#80deea", //cyan 200
+ "#4285f4", //google blue 500
+ "#5e35b1" //deep purple 600
+ ],
+ //This rainbow palette attempts to keep a constant brightness across hues.
+ constantValue: [
+ "#f44336",
+ "#ffa216",
+ "#c2d22d",
+ "#51b455",
+ "#1ca091",
+ "#505ec4",
+ "#a633ba"
+ ]
+ }
+}
+
+
diff --git a/tensorflow/tensorboard/lib/js/colorScale/test/colorScaleTests.ts b/tensorflow/tensorboard/lib/js/colorScale/test/colorScaleTests.ts
new file mode 100644
index 0000000000..b3709fd520
--- /dev/null
+++ b/tensorflow/tensorboard/lib/js/colorScale/test/colorScaleTests.ts
@@ -0,0 +1,99 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+module TF {
+ let assert = chai.assert;
+
+ describe("ColorScale", function() {
+ let ccs: ColorScale;
+
+ beforeEach(function() {
+ ccs = new ColorScale();
+ });
+
+ it("No collisions with train, eval and test", function() {
+ ccs.domain(["train"]);
+ var trainColor = ccs.getColor("train");
+ ccs.domain(["eval"]);
+ var evalColor = ccs.getColor("eval");
+ ccs.domain(["test"]);
+ var testColor = ccs.getColor("test");
+ assert.notEqual(trainColor, evalColor);
+ assert.notEqual(trainColor, testColor);
+ assert.notEqual(evalColor, testColor);
+ });
+
+ it("Returns consistent colors, given no hash collisions", function() {
+ //These three colors don't have hash collisions
+ ccs.domain(["red", "yellow"]);
+ var firstRedColor = ccs.getColor("red");
+ ccs.domain(["red", "yellow", "blue"]);
+ var secondRedColor = ccs.getColor("red");
+ assert.deepEqual(firstRedColor, secondRedColor);
+ });
+
+ it("A 2-color scale returns the first and last colors of the palette", function() {
+ var twoColorScale = new ColorScale(2, TF.palettes.googleStandard);
+ // No hash collisions with these.
+ twoColorScale.domain(["red", "blue"]);
+ assert.deepEqual(twoColorScale.getColor("blue"), TF.palettes.googleStandard[0]);
+ assert.deepEqual(twoColorScale.getColor("red"), TF.palettes.googleStandard[TF.palettes.googleStandard.length - 1]);
+ })
+
+ // This tests that when we reset the domain with new colors, the old
+ // domain doesn't influence the new color choices: we get a fresh slate,
+ // the internal bins are reset, and we aren't finding collisions with
+ // colors from the previous domain.
+ it("Colors don't nudge away from colors from an old domain.", function() {
+ // at 12 breaks, "orange" and "blue" collide.
+ ccs.domain(["red", "blue"]);
+ var firstBlue = ccs.getColor("blue");
+ ccs.domain(["red", "orange"]);
+ var firstOrange = ccs.getColor("orange");
+ assert.deepEqual(firstBlue, firstOrange);
+ });
+
+ it("Nudges all colors, given only one base color", function() {
+ var ccsWithOneColor = new ColorScale(1);
+ ccsWithOneColor.domain(["one", "two", "three"]);
+ assert.notEqual(ccsWithOneColor.getColor("one"), ccsWithOneColor.getColor("two"));
+ assert.notEqual(ccsWithOneColor.getColor("two"), ccsWithOneColor.getColor("three"));
+ assert.notEqual(ccsWithOneColor.getColor("one"), ccsWithOneColor.getColor("three"));
+ });
+
+ it("Nudges a color if it has a hash collision", function() {
+ // at 12 breaks, "orange" and "blue" collide.
+ ccs.domain(["red", "blue"]);
+ var firstBlue = ccs.getColor("blue");
+ ccs.domain(["red", "orange"]);
+ var firstOrange = ccs.getColor("orange");
+ ccs.domain(["red", "blue", "orange"]);
+ var secondBlue = ccs.getColor("blue");
+ var secondOrange = ccs.getColor("orange");
+ assert.deepEqual(firstBlue, secondBlue);
+ assert.deepEqual(firstBlue, firstOrange);
+ assert.notEqual(secondBlue, secondOrange);
+ });
+
+ it("Throws an error if string is not in the domain", function() {
+ ccs.domain(["red", "yellow", "green"]);
+ assert.throws(function() {
+ ccs.getColor("not in domain");
+ }, "String was not in the domain.");
+ });
+
+
+ });
+
+}
diff --git a/tensorflow/tensorboard/lib/js/colorScale/test/index.html b/tensorflow/tensorboard/lib/js/colorScale/test/index.html
new file mode 100644
index 0000000000..a6c3c04aa6
--- /dev/null
+++ b/tensorflow/tensorboard/lib/js/colorScale/test/index.html
@@ -0,0 +1,28 @@
+<!-- Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=============================================================================-->
+<!doctype html>
+<html>
+<head>
+ <meta charset="utf-8">
+ <script src="../../web-component-tester/browser.js"></script>
+ <script src="../../webcomponentsjs/webcomponents-lite.min.js"></script>
+ <link rel="import" href="../../tf-imports/d3.html">
+</head>
+<body>
+ <script src="../colorScale.js"></script>
+ <script src="../palettes.js"></script>
+ <script src="colorScaleTests.js"></script>
+</body>
+</html>
diff --git a/tensorflow/tools/dist_test/Dockerfile b/tensorflow/tools/dist_test/Dockerfile
new file mode 100644
index 0000000000..fba23af55d
--- /dev/null
+++ b/tensorflow/tools/dist_test/Dockerfile
@@ -0,0 +1,28 @@
+FROM ubuntu:14.04
+
+MAINTAINER Shanqing Cai <cais@google.com>
+
+RUN apt-get update
+RUN apt-get install -y \
+ bc \
+ curl \
+ python \
+ python-numpy \
+ python-pip
+
+# Install Google Cloud SDK
+RUN curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/install_google_cloud_sdk.bash
+RUN chmod +x install_google_cloud_sdk.bash
+RUN ./install_google_cloud_sdk.bash --disable-prompts --install-dir=/var/gcloud
+
+# Install kubectl
+RUN /var/gcloud/google-cloud-sdk/bin/gcloud components install kubectl
+
+# Install nightly TensorFlow pip
+# TODO(cais): Should we build it locally instead?
+RUN pip install \
+ http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.7.1-cp27-none-linux_x86_64.whl
+
+# Copy test files
+COPY scripts /var/tf-dist-test/scripts
+COPY python /var/tf-dist-test/python
diff --git a/tensorflow/tools/dist_test/Dockerfile.local b/tensorflow/tools/dist_test/Dockerfile.local
new file mode 100644
index 0000000000..4d82904707
--- /dev/null
+++ b/tensorflow/tools/dist_test/Dockerfile.local
@@ -0,0 +1,20 @@
+FROM jpetazzo/dind
+
+MAINTAINER Shanqing Cai <cais@google.com>
+
+RUN apt-get update
+
+RUN apt-get install -y \
+ bc \
+ build-essential \
+ dbus \
+ git \
+ software-properties-common
+
+# Install the latest golang
+RUN wget https://storage.googleapis.com/golang/go1.4.2.linux-amd64.tar.gz
+RUN tar -C /usr/local -xzf go1.4.2.linux-amd64.tar.gz
+RUN rm -f go1.4.2.linux-amd64.tar.gz
+RUN echo 'PATH=/usr/local/go/bin:${PATH}' >> /root/.bashrc
+
+ADD . /var/tf-k8s
diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md
new file mode 100644
index 0000000000..d986900bd6
--- /dev/null
+++ b/tensorflow/tools/dist_test/README.md
@@ -0,0 +1,76 @@
+# Testing Distributed Runtime in TensorFlow
+This folder contains tools and test suites for the GRPC-based distributed
+runtime in TensorFlow.
+
+There are three general modes of testing:
+
+**1) Launch a local Kubernetes (k8s) cluster and run the test suites on it**
+
+For example:
+
+ ./local_test.sh
+
+This option makes use of docker-in-docker (dind) containers. It requires
+the docker0 network interface on the host to be set to promiscuous mode:
+
+ sudo ip link set docker0 promisc on
+
+The environment variable "TF_DIST_SERVER_DOCKER_IMAGE" can be used to override
+the Docker image used to generate the TensorFlow GRPC server pods
+("tensorflow/tf_grpc_test_server"). For example:
+
+ export TF_DIST_SERVER_DOCKER_IMAGE=<docker_image_name>
+ ./local_test.sh
+
+**2) Launch a remote k8s cluster on Google Container Engine (GKE) and run the
+test suite on it**
+
+For example:
+
+ export TF_DIST_GCLOUD_PROJECT="tensorflow-testing"
+ export TF_DIST_GCLOUD_COMPUTE_ZONE="us-central1-f"
+ export TF_DIST_CONTAINER_CLUSTER="test-cluster-1"
+ export TF_DIST_GCLOUD_KEY_FILE_DIR="/tmp/gcloud-secrets"
+ ./remote_test.sh
+
+Here you specify the Google Compute Engine (GCE) project, compute zone and
+container cluster with the first three environment variables, in that order.
+The environment variable "TF_DIST_GCLOUD_KEY_FILE_DIR" points to a directory
+in which the JSON service account key file "tensorflow-testing.json" is located.
+You can use the flag "--setup-cluster-only" to perform only the cluster setup
+step and skip the testing step:
+
+ ./remote_test.sh --setup-cluster-only
+
+**3) Run the test suite on an existing k8s TensorFlow cluster**
+
+For example:
+
+ export TF_DIST_GRPC_SERVER_URL="grpc://11.22.33.44:2222"
+ ./remote_test.sh
+
+The IP address above is a dummy example. Such a cluster may have been set up
+using the command described at the end of the previous section.
+
+
+**Building the test server Docker image**
+
+To build the Docker image for a test server of the TensorFlow distributed runtime,
+run:
+
+ ./build_server.sh <docker_image_name>
+
+
+**Generating configuration file for TensorFlow k8s clusters**
+
+The script at "scripts/k8s_tensorflow.py" can be used to generate yaml
+configuration files for a TensorFlow k8s cluster consisting of a number of
+workers and parameter servers. For example:
+
+ scripts/k8s_tensorflow.py \
+ --num_workers 2 \
+ --num_parameter_servers 2 \
+ --grpc_port 2222 \
+ --request_load_balancer \
+ --docker_image "tensorflow/tf_grpc_test_server" \
+ > tf-k8s-with-lb.yaml
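+
+The resulting YAML file can then be applied to a running k8s cluster with
+kubectl, for example (assuming kubectl is already configured to point at the
+target cluster):
+
+ kubectl create -f tf-k8s-with-lb.yaml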
diff --git a/tensorflow/tools/dist_test/build_server.sh b/tensorflow/tools/dist_test/build_server.sh
new file mode 100755
index 0000000000..8679bde2dc
--- /dev/null
+++ b/tensorflow/tools/dist_test/build_server.sh
@@ -0,0 +1,44 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Builds the test server for distributed (GRPC) TensorFlow
+#
+# Usage: build_server.sh <docker_image_name>
+#
+# Note that the Dockerfile is located in ./server/, but the docker build uses
+# the directory containing this script as its context.
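+#
+# Example (using the default server image name mentioned in README.md):
+#   ./build_server.sh tensorflow/tf_grpc_test_server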
+
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+# Check arguments
+if [[ $# != 1 ]]; then
+ die "Usage: $0 <docker_image_name>"
+fi
+
+DOCKER_IMG_NAME=$1
+
+# Current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Call docker build
+docker build --no-cache -t "${DOCKER_IMG_NAME}" \
+ -f "${DIR}/server/Dockerfile" \
+ "${DIR}"
diff --git a/tensorflow/tools/dist_test/local/Dockerfile b/tensorflow/tools/dist_test/local/Dockerfile
new file mode 100644
index 0000000000..dece508c0d
--- /dev/null
+++ b/tensorflow/tools/dist_test/local/Dockerfile
@@ -0,0 +1,20 @@
+FROM jpetazzo/dind
+
+MAINTAINER Shanqing Cai <cais@google.com>
+
+RUN apt-get update
+
+RUN apt-get install -y \
+ build-essential \
+ git \
+ software-properties-common
+
+# Install the latest golang
+RUN wget https://storage.googleapis.com/golang/go1.4.2.linux-amd64.tar.gz
+RUN tar -C /usr/local -xzf go1.4.2.linux-amd64.tar.gz
+RUN rm -f go1.4.2.linux-amd64.tar.gz
+RUN echo 'PATH=/usr/local/go/bin:${PATH}' >> /root/.bashrc
+
+ADD start_local_k8s_cluster.sh /var/k8s/start_local_k8s_cluster.sh
+ADD ../scripts /var/k8s/dist_test/scripts
+ADD ../python /var/k8s/dist_test/python
diff --git a/tensorflow/tools/dist_test/local/start_local_k8s_service.sh b/tensorflow/tools/dist_test/local/start_local_k8s_service.sh
new file mode 100755
index 0000000000..51f4805ee8
--- /dev/null
+++ b/tensorflow/tools/dist_test/local/start_local_k8s_service.sh
@@ -0,0 +1,118 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Start a Kubernetes (k8s) cluster on the local machine.
+#
+# This script assumes that git, docker, and golang are installed and on
+# the path. It will attempt to install the version of etcd recommended by the
+# kubernetes source.
+#
+# Usage: start_local_k8s_service.sh
+#
+# This script obeys the following environment variables:
+# TF_DIST_K8S_SRC_DIR: Overrides the default directory for k8s source code.
+# TF_DIST_K8S_SRC_BRANCH: Overrides the default branch to run the local k8s
+# cluster with.
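+#
+# Example (a sketch; these values match the defaults set below):
+#   export TF_DIST_K8S_SRC_DIR=/local/kubernetes
+#   export TF_DIST_K8S_SRC_BRANCH=release-1.2
+#   ./start_local_k8s_service.sh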
+
+
+# Configurations
+K8S_SRC_REPO=https://github.com/kubernetes/kubernetes.git
+K8S_SRC_DIR=${TF_DIST_K8S_SRC_DIR:-/local/kubernetes}
+K8S_SRC_BRANCH=${TF_DIST_K8S_SRC_BRANCH:-release-1.2}
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+# Start docker service. If it fails to start, exit with code 23 so that the
+# calling script (start_tf_cluster_container.sh) can retry.
+COUNTER=0
+while true; do
+ ((COUNTER++))
+ service docker start
+ sleep 1
+
+ service docker status
+ if [[ $? == "0" ]]; then
+ echo "Docker service started successfully."
+ break;
+ else
+ echo "Docker service failed to start"
+
+ # 23 is the exit code to signal failure to start docker service in the dind
+ # container.
+ exit 23
+
+ fi
+done
+
+# Wait for docker0 net interface to appear
+echo "Waiting for docker0 network interface to appear..."
+while true; do
+ if [[ -z $(netstat -i | grep "^docker0") ]]; then
+ sleep 1
+ else
+ break
+ fi
+done
+echo "docker0 interface has appeared."
+
+# Set docker0 to promiscuous mode
+ip link set docker0 promisc on || \
+ die "FAILED to set docker0 to promiscuous"
+echo "Turned promisc on for docker0"
+
+# Check promiscuous mode of docker0
+netstat -i
+
+umask 000
+if [[ ! -d "${K8S_SRC_DIR}/.git" ]]; then
+ mkdir -p ${K8S_SRC_DIR}
+ git clone ${K8S_SRC_REPO} ${K8S_SRC_DIR} || \
+ die "FAILED to clone k8s source from GitHub from: ${K8S_SRC_REPO}"
+fi
+
+pushd ${K8S_SRC_DIR}
+git checkout ${K8S_SRC_BRANCH} || \
+ die "FAILED to checkout k8s source branch: ${K8S_SRC_BRANCH}"
+git pull origin ${K8S_SRC_BRANCH} || \
+ die "FAILED to pull from k8s source branch: ${K8S_SRC_BRANCH}"
+
+# Create kubectl binary
+
+# Install etcd
+hack/install-etcd.sh
+
+export PATH=$(pwd)/third_party/etcd:${PATH}
+
+# Setup golang
+export PATH=/usr/local/go/bin:${PATH}
+
+echo "etcd path: $(which etcd)"
+echo "go path: $(which go)"
+
+# Create shortcut to kubectl
+echo '#!/bin/bash' > /usr/local/bin/kubectl
+echo "$(pwd)/cluster/kubectl.sh \\" >> /usr/local/bin/kubectl
+echo ' $@' >> /usr/local/bin/kubectl
+chmod +x /usr/local/bin/kubectl
+
+# Bring up local cluster
+export KUBE_ENABLE_CLUSTER_DNS=true
+hack/local-up-cluster.sh
+
+popd
diff --git a/tensorflow/tools/dist_test/local/start_tf_cluster_container.sh b/tensorflow/tools/dist_test/local/start_tf_cluster_container.sh
new file mode 100755
index 0000000000..b8448624ef
--- /dev/null
+++ b/tensorflow/tools/dist_test/local/start_tf_cluster_container.sh
@@ -0,0 +1,91 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Starts a docker-in-docker (dind) container that is capable of running docker
+# service and Kubernetes (k8s) cluster inside.
+#
+# Usage: start_tf_cluster_container.sh <local_k8s_dir> <docker_img_name>
+#
+# local_k8s_dir: Kubernetes (k8s) source directory on the host
+# docker_img_name: Name of the docker image to start
+#
+# In addition, this script obeys the following environment variables:
+# TF_DIST_SERVER_DOCKER_IMAGE: overrides the default docker image to launch
+# TensorFlow (GRPC) servers with
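+#
+# Example (a sketch; the directory and image name follow the values used by
+# local_test.sh and are illustrative here):
+#   TF_DIST_SERVER_DOCKER_IMAGE=tensorflow/tf_grpc_test_server \
+#     ./start_tf_cluster_container.sh ${HOME}/kubernetes \
+#     tensorflow/tf-dist-test-local-cluster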
+
+# Parse input arguments
+if [[ $# != "2" ]]; then
+ echo "Usage: $0 <host_k8s_dir> <docker_img_name>"
+ exit 1
+fi
+
+HOST_K8S_DIR=$1
+DOCKER_IMG_NAME=$2
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+# Maximum number of tries to start the docker container with docker running
+# inside
+MAX_ATTEMPTS=100
+
+# Map environment variables into the docker-in-docker (dind) container
+DOCKER_ENV=""
+if [[ ! -z "${TF_DIST_SERVER_DOCKER_IMAGE}" ]]; then
+ DOCKER_ENV="-e TF_DIST_SERVER_DOCKER_IMAGE=${TF_DIST_SERVER_DOCKER_IMAGE}"
+fi
+
+# Verify that the promisc (promiscuous mode) flag is set on docker0 network
+# interface
+if [[ -z $(netstat -i | grep "^docker0" | awk '{print $NF}' | grep -o P) ]];
+then
+ die "FAILED: Cannot proceed with dind k8s container creation because "\
+"network interface 'docker0' is not set to promisc on the host."
+fi
+
+# Create cache for k8s source
+if [[ ! -d ${HOST_K8S_DIR} ]]; then
+ umask 000
+ mkdir -p ${HOST_K8S_DIR} || die "FAILED to create directory for k8s source"
+fi
+
+# Attempt to start docker service in docker container.
+# Try multiple times if necessary.
+COUNTER=1
+while true; do
+ ((COUNTER++))
+ docker run --net=host --privileged ${DOCKER_ENV} \
+ -v ${HOST_K8S_DIR}:/local/kubernetes \
+ ${DOCKER_IMG_NAME} \
+ /var/tf-k8s/local/start_local_k8s_service.sh
+
+ if [[ $? == "23" ]]; then
+ if [[ $(echo "${COUNTER}>=${MAX_ATTEMPTS}" | bc -l) == "1" ]]; then
+ echo "Reached maximum number of attempts (${MAX_ATTEMPTS}) "\
+"while attempting to start docker-in-docker for local k8s TensorFlow cluster"
+ exit 1
+ fi
+
+ echo "Docker service failed to start."
+ echo "Will make another attempt (#${COUNTER}) to start it..."
+ sleep 1
+ else
+ break
+ fi
+done
diff --git a/tensorflow/tools/dist_test/local/test_local_tf_cluster.sh b/tensorflow/tools/dist_test/local/test_local_tf_cluster.sh
new file mode 100755
index 0000000000..895a2fe24c
--- /dev/null
+++ b/tensorflow/tools/dist_test/local/test_local_tf_cluster.sh
@@ -0,0 +1,88 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Launch a Kubernetes (k8s) TensorFlow cluster on the local machine and run
+# the distributed test suite.
+#
+# This script assumes that a Kubernetes (k8s) cluster is already running on
+# the local machine and can be controlled by the "kubectl" binary.
+#
+# Usage: test_local_tf_cluster.sh
+#
+
+export GCLOUD_BIN=/usr/local/bin/gcloud
+export TF_DIST_LOCAL_CLUSTER=1
+
+# TODO(cais): Do not hard-code the numbers of workers and ps
+NUM_WORKERS=2
+NUM_PARAMETER_SERVERS=2
+
+# Get current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Get utility functions
+source "${DIR}/../scripts/utils.sh"
+
+# Wait for the kube-system pods to be running
+KUBECTL_BIN=$(which kubectl)
+if [[ -z ${KUBECTL_BIN} ]]; then
+ die "FAILED to find path to kubectl"
+fi
+
+echo "Waiting for kube-system pods to be all running..."
+echo ""
+
+MAX_ATTEMPTS=360
+COUNTER=0
+while true; do
+ sleep 1
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>${MAX_ATTEMPTS}" | bc -l) == "1" ]]; then
+ die "Reached maximum polling attempts while waiting for all pods in "\
+"kube-system to be running in local k8s TensorFlow cluster"
+ fi
+
+ if [[ $(are_all_pods_running "${KUBECTL_BIN}" "kube-system") == "1" ]]; then
+ break
+ fi
+done
+
+# Create the local k8s tf cluster
+${DIR}/../scripts/create_tf_cluster.sh \
+ ${NUM_WORKERS} ${NUM_PARAMETER_SERVERS} | \
+ tee /tmp/tf_cluster.log || \
+ die "FAILED to create local tf cluster"
+
+DOCKER_CONTAINER_ID=$(cat /tmp/tf_cluster.log | \
+ grep "Docker container ID" |
+ awk '{print $NF}')
+if [[ -z "${DOCKER_CONTAINER_ID}" ]]; then
+ die "FAILED to determine worker0 Docker container ID"
+fi
+
+export TF_DIST_GRPC_SERVER_URL="grpc://tf-worker0:2222"
+GRPC_ENV="TF_DIST_GRPC_SERVER_URL=${TF_DIST_GRPC_SERVER_URL}"
+
+docker exec \
+ ${DOCKER_CONTAINER_ID} \
+ /bin/bash -c \
+ "${GRPC_ENV} /var/tf-k8s/scripts/dist_test.sh"
+
+if [[ $? != "0" ]]; then
+ die "Test of local k8s TensorFlow cluster FAILED"
+else
+ echo "Test of local k8s TensorFlow cluster PASSED"
+fi
diff --git a/tensorflow/tools/dist_test/local_test.sh b/tensorflow/tools/dist_test/local_test.sh
new file mode 100755
index 0000000000..d47324cbc3
--- /dev/null
+++ b/tensorflow/tools/dist_test/local_test.sh
@@ -0,0 +1,152 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Tests distributed TensorFlow on a locally running TF GRPC cluster.
+#
+# This script performs the following steps:
+# 1) Build the docker-in-docker (dind) image capable of running docker and
+# Kubernetes (k8s) cluster inside.
+# 2) Run a container from the aforementioned image and start docker service
+# in it
+# 3) Call a script to launch a k8s TensorFlow GRPC cluster inside the container
+# and run the distributed test suite.
+#
+# Usage: local_test.sh [--leave-container-running]
+#
+# Arguments:
+# --leave-container-running: Do not stop the docker-in-docker container after
+# the termination of the tests, e.g., for debugging
+#
+# In addition, this script obeys the following environment variables:
+# TF_DIST_SERVER_DOCKER_IMAGE: overrides the default docker image to launch
+# TensorFlow (GRPC) servers with
+# TF_DIST_DOCKER_NO_CACHE: do not use cache when building docker images
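+#
+# Example (a sketch; rebuilds the images without cache and keeps the dind
+# container running afterwards for debugging):
+#   TF_DIST_DOCKER_NO_CACHE=1 ./local_test.sh --leave-container-running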
+
+
+# Configurations
+DOCKER_IMG_NAME="tensorflow/tf-dist-test-local-cluster"
+LOCAL_K8S_CACHE=${HOME}/kubernetes
+
+# Helper function
+get_container_id_by_image_name() {
+ # Get the id of a container by image name
+ # Usage: get_container_id_by_image_name <img_name>
+
+ echo $(docker ps | grep $1 | awk '{print $1}')
+}
+
+# Current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Get utility functions
+source ${DIR}/scripts/utils.sh
+
+
+# First, make sure that no docker-in-docker container of the same image
+# is already running
+if [[ ! -z $(get_container_id_by_image_name ${DOCKER_IMG_NAME}) ]]; then
+ die "It appears that there is already at least one Docker container "\
+"of image name ${DOCKER_IMG_NAME} running. Please stop it before trying again"
+fi
+
+# Build docker-in-docker image for local k8s cluster
+NO_CACHE_FLAG=""
+if [[ ! -z "${TF_DIST_DOCKER_NO_CACHE}" ]] &&
+ [[ "${TF_DIST_DOCKER_NO_CACHE}" != "0" ]]; then
+ NO_CACHE_FLAG="--no-cache"
+fi
+
+docker build ${NO_CACHE_FLAG} -t ${DOCKER_IMG_NAME} \
+ -f ${DIR}/Dockerfile.local ${DIR}
+
+
+# Attempt to start the docker container with docker, which will run the k8s
+# cluster inside.
+
+# Create a log file to capture the startup output of the cluster container
+CONTAINER_START_LOG=$(mktemp --suffix=.log)
+echo "Log file for starting cluster container: ${CONTAINER_START_LOG}"
+echo ""
+
+${DIR}/local/start_tf_cluster_container.sh \
+ ${LOCAL_K8S_CACHE} \
+ ${DOCKER_IMG_NAME} | \
+ tee ${CONTAINER_START_LOG} &
+
+# Poll start log until the k8s service is started properly or when maximum
+# attempt count is reached.
+MAX_SERVER_POLLING_ATTEMPTS=600
+
+echo "Waiting for docker-in-docker container for local k8s TensorFlow "\
+"cluster to start and launch Kubernetes..."
+
+COUNTER=0
+while true; do
+ sleep 1
+
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>=${MAX_SERVER_POLLING_ATTEMPTS}" | bc -l) == "1" ]]; then
+ die "Reached maximum number of attempts (${MAX_SERVER_POLLING_ATTEMPTS}) "\
+"while waiting for docker-in-docker for local k8s TensorFlow cluster to start"
+ fi
+
+ # Check for hitting max attempt while trying to start docker-in-docker
+ if [[ $(grep -i "Reached maximum number of attempts" \
+ "${CONTAINER_START_LOG}" | wc -l) == "1" ]]; then
+ die "Docker-in-docker container for local k8s TensorFlow cluster "\
+"FAILED to start"
+ fi
+
+ if [[ $(grep -i "Local Kubernetes cluster is running" \
+ "${CONTAINER_START_LOG}" | wc -l) == "1" ]]; then
+ break
+ fi
+done
+
+# Determine the id of the docker-in-docker container
+DIND_ID=$(get_container_id_by_image_name ${DOCKER_IMG_NAME})
+
+echo "Docker-in-docker container for local k8s TensorFlow cluster has been "\
+"started successfully."
+echo "Docker-in-docker container ID: ${DIND_ID}"
+echo "Launching k8s tf cluster and tests in container ${DIND_ID} ..."
+echo ""
+
+# Launch k8s tf cluster in the docker-in-docker container and perform tests
+docker exec ${DIND_ID} \
+ /var/tf-k8s/local/test_local_tf_cluster.sh
+TEST_RES=$?
+
+# Tear down: stop docker-in-docker container
+if [[ $1 != "--leave-container-running" ]]; then
+ echo ""
+ echo "Stopping docker-in-docker container ${DIND_ID}"
+
+ docker stop --time=1 ${DIND_ID} || \
+ echo "WARNING: Failed to stop container ${DIND_ID} !!"
+
+ echo ""
+else
+ echo "Will not terminate DIND container ${DIND_ID}"
+fi
+
+if [[ "${TEST_RES}" != "0" ]]; then
+ die "Test of distributed TensorFlow runtime on docker-in-docker local "\
+"k8s cluster FAILED"
+else
+ echo "Test of distributed TensorFlow runtime on docker-in-docker local "\
+"k8s cluster PASSED"
+fi
diff --git a/tensorflow/tools/dist_test/python/mnist_replica.py b/tensorflow/tools/dist_test/python/mnist_replica.py
new file mode 100755
index 0000000000..e40aae38c2
--- /dev/null
+++ b/tensorflow/tools/dist_test/python/mnist_replica.py
@@ -0,0 +1,144 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Distributed MNIST training and validation, with model replicas.
+
+A simple softmax model with one hidden layer is defined. The parameters
+(weights and biases) are located on two parameter servers (ps), while the
+ops are defined on a worker node. The TF sessions also run on the worker
+node.
+Multiple invocations of this script can be done in parallel, with different
+values for --worker_index. There should be exactly one invocation with
+--worker_index=0, which creates a master session that carries out the variable
+initialization. The other, non-master, sessions will wait for the master
+session to finish the initialization before proceeding to the training stage.
+
+The coordination between the multiple worker invocations occurs because the
+parameters are defined on the same ps devices. The parameter updates from one
+worker are visible to all other workers. As such, the workers can perform
+forward computation and gradient calculation in parallel, which should lead
+to increased training speed for the simple model.
+"""
+
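+# Example invocation (a sketch; grpc://tf-worker0:2222 matches the worker
+# naming used by the k8s test scripts in this directory, and tf-worker1 is
+# illustrative for a second, parallel worker):
+#
+#   python mnist_replica.py \
+#       --worker_grpc_url=grpc://tf-worker0:2222 --worker_index=0
+#   python mnist_replica.py \
+#       --worker_grpc_url=grpc://tf-worker1:2222 --worker_index=1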
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import sys
+import tempfile
+import time
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+import tensorflow as tf
+from tensorflow.examples.tutorials.mnist import input_data
+
+
+flags = tf.app.flags
+flags.DEFINE_string("data_dir", "/tmp/mnist-data",
+ "Directory for storing mnist data")
+flags.DEFINE_boolean("download_only", False,
+ """Only perform downloading of data; Do not proceed to
+ model definition or training""")
+flags.DEFINE_integer("worker_index", 0,
+ """Worker task index, should be >= 0. worker_index=0 is
+ the master worker task that performs the variable
+ initialization""")
+flags.DEFINE_integer("hidden_units", 100,
+ "Number of units in the hidden layer of the NN")
+flags.DEFINE_integer("train_steps", 50, "Number of training steps")
+flags.DEFINE_integer("batch_size", 100, "Training batch size")
+flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
+flags.DEFINE_string("worker_grpc_url", None,
+ "Worker GRPC URL (e.g., grpc://1.2.3.4:2222, or "
+ "grpc://tf-worker0:2222)")
+FLAGS = flags.FLAGS
+
+IMAGE_PIXELS = 28
+
+if __name__ == "__main__":
+ mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
+ if FLAGS.download_only:
+ sys.exit(0)
+
+ print("Worker GRPC URL: %s" % FLAGS.worker_grpc_url)
+ print("Worker index = %d" % FLAGS.worker_index)
+
+ with tf.Graph().as_default():
+ # Variables of the hidden layer
+ with tf.device("/job:ps/task:0"):
+ hid_w = tf.Variable(
+ tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
+ stddev=1.0 / IMAGE_PIXELS), name="hid_w")
+ hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
+
+ # Variables of the softmax layer
+ with tf.device("/job:ps/task:1"):
+ sm_w = tf.Variable(
+ tf.truncated_normal([FLAGS.hidden_units, 10],
+ stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
+ name="sm_w")
+ sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
+
+ # Ops: located on the worker specified with FLAGS.worker_index
+ with tf.device("/job:worker/task:%d" % FLAGS.worker_index):
+ x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
+ y_ = tf.placeholder(tf.float32, [None, 10])
+
+ hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
+ hid = tf.nn.relu(hid_lin)
+
+ y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
+ cross_entropy = -tf.reduce_sum(y_ *
+ tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
+ train_step = tf.train.AdamOptimizer(
+ FLAGS.learning_rate).minimize(cross_entropy)
+
+ train_dir = tempfile.mkdtemp()
+ print(FLAGS.worker_index)
+ sv = tf.train.Supervisor(logdir=train_dir,
+ is_chief=(FLAGS.worker_index == 0))
+
+ # The chief worker (worker_index==0) session will prepare the session,
+ # while the remaining workers will wait for the preparation to complete.
+ sess = sv.prepare_or_wait_for_session(FLAGS.worker_grpc_url)
+
+ # Perform training
+ time_begin = time.time()
+ print("Training begins @ %f" % time_begin)
+
+ # TODO(cais): terminate when a global step counter reaches FLAGS.train_steps
+ for i in xrange(FLAGS.train_steps):
+ # Training feed
+ batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
+ train_feed = {x: batch_xs,
+ y_: batch_ys}
+
+ sess.run(train_step, feed_dict=train_feed)
+
+ time_end = time.time()
+ print("Training ends @ %f" % time_end)
+ training_time = time_end - time_begin
+ print("Training elapsed time: %f s" % training_time)
+
+ # Validation feed
+ val_feed = {x: mnist.validation.images,
+ y_: mnist.validation.labels}
+ val_xent = sess.run(cross_entropy, feed_dict=val_feed)
+ print("After %d training step(s), validation cross entropy = %g" %
+ (FLAGS.train_steps, val_xent))
+
diff --git a/tensorflow/tools/dist_test/remote_test.sh b/tensorflow/tools/dist_test/remote_test.sh
new file mode 100755
index 0000000000..5f331c4cac
--- /dev/null
+++ b/tensorflow/tools/dist_test/remote_test.sh
@@ -0,0 +1,92 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This is the entry-point script for testing TensorFlow's distributed runtime.
+# It builds a docker image with the necessary gcloud and Kubernetes (k8s) tools
+# installed, and then executes k8s cluster preparation and distributed TensorFlow
+# runs from within a container based on the image.
+#
+# Usage:
+# remote_test.sh [--setup-cluster-only]
+# Arguments:
+# --setup-cluster-only:
+# Setup the TensorFlow k8s cluster only, and do not perform testing of
+# the distributed runtime.
+#
+#
+# If any of the following environment variables has a non-empty value, it will
+# be mapped into the docker container to override the corresponding default
+# value (see dist_test.sh):
+# TF_DIST_GRPC_SERVER_URL: URL to an existing Tensorflow GRPC server.
+# If set to any non-empty and valid value (e.g.,
+# grpc://1.2.3.4:2222), it will cause the test
+# to bypass the k8s cluster setup and
+#                              teardown process, and just use this URL
+# as the master session.
+# TF_DIST_GCLOUD_PROJECT: gcloud project in which the GKE cluster
+# will be created (takes effect only if
+# TF_DIST_GRPC_SERVER_URL is empty, same below)
+# TF_DIST_GCLOUD_COMPUTE_ZONE: gcloud compute zone.
+# TF_DIST_CONTAINER_CLUSTER: name of the GKE cluster
+# TF_DIST_GCLOUD_KEY_FILE_DIR: path to the host directory that contains
+#                              the gcloud service key file
+# "tensorflow-testing.json"
+# TF_DIST_GRPC_PORT: port on which to create the TensorFlow GRPC
+# servers
+# TF_DIST_DOCKER_NO_CACHE: do not use cache when building docker images
+
+DOCKER_IMG_NAME="tensorflow/tf-dist-test-client"
+
+# Get current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Prepare environment variables for the docker container
+DOCKER_ENV_FLAGS=""
+if [[ ! -z "$TF_DIST_GRPC_SERVER_URL" ]]; then
+ DOCKER_ENV_FLAGS="${DOCKER_ENV_FLAGS} "\
+"-e TF_DIST_GRPC_SERVER_URL=${TF_DIST_GRPC_SERVER_URL}"
+fi
+if [[ ! -z "$TF_DIST_GCLOUD_PROJECT" ]]; then
+ DOCKER_ENV_FLAGS="${DOCKER_ENV_FLAGS} "\
+"-e TF_DIST_GCLOUD_PROJECT=${TF_DIST_GCLOUD_PROJECT}"
+fi
+if [[ ! -z "$TF_DIST_GCLOUD_COMPUTE_ZONE" ]]; then
+ DOCKER_ENV_FLAGS="${DOCKER_ENV_FLAGS} "\
+"-e TF_DIST_GCLOUD_COMPUTE_ZONE=${TF_DIST_GCLOUD_COMPUTE_ZONE}"
+fi
+if [[ ! -z "$TF_DIST_CONTAINER_CLUSTER" ]]; then
+ DOCKER_ENV_FLAGS="${DOCKER_ENV_FLAGS} "\
+"-e TF_DIST_CONTAINER_CLUSTER=${TF_DIST_CONTAINER_CLUSTER}"
+fi
+if [[ ! -z "$TF_DIST_GRPC_PORT" ]]; then
+ DOCKER_ENV_FLAGS="${DOCKER_ENV_FLAGS} "\
+"-e TF_DIST_GRPC_PORT=${TF_DIST_GRPC_PORT}"
+fi
+
+NO_CACHE_FLAG=""
+if [[ ! -z "${TF_DIST_DOCKER_NO_CACHE}" ]] &&
+ [[ "${TF_DIST_DOCKER_NO_CACHE}" != "0" ]]; then
+ NO_CACHE_FLAG="--no-cache"
+fi
+
+docker build ${NO_CACHE_FLAG} \
+ -t ${DOCKER_IMG_NAME} -f "${DIR}/Dockerfile" "${DIR}"
+KEY_FILE_DIR=${TF_DIST_GCLOUD_KEY_FILE_DIR:-"${HOME}/gcloud-secrets"}
+
+docker run -v ${KEY_FILE_DIR}:/var/gcloud/secrets \
+ ${DOCKER_ENV_FLAGS} \
+ ${DOCKER_IMG_NAME} \
+ /var/tf-dist-test/scripts/dist_test.sh $@
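
remote_test.sh passes each non-empty TF_DIST_* variable into the test container as a docker -e flag, so the defaults in dist_test.sh can be overridden from the host. A minimal Python sketch of that forwarding logic (the helper name and list are illustrative, not part of the change):

    import os

    # Variables the script forwards into the container when non-empty.
    FORWARDED_VARS = [
        "TF_DIST_GRPC_SERVER_URL",
        "TF_DIST_GCLOUD_PROJECT",
        "TF_DIST_GCLOUD_COMPUTE_ZONE",
        "TF_DIST_CONTAINER_CLUSTER",
        "TF_DIST_GRPC_PORT",
    ]

    def docker_env_flags(environ=os.environ):
      """Build the -e KEY=VALUE arguments for `docker run`."""
      flags = []
      for name in FORWARDED_VARS:
        value = environ.get(name, "")
        if value:
          flags.extend(["-e", "%s=%s" % (name, value)])
      return flags
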
diff --git a/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh b/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
new file mode 100755
index 0000000000..22c0c43037
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
@@ -0,0 +1,231 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Create a Kubernetes (k8s) cluster of TensorFlow workers
+#
+# Usage:
+# create_tf_cluster.sh <num_workers> <num_parameter_servers>
+#
+# In addition, this script obeys values in the following environment variables:
+# TF_DIST_LOCAL_CLUSTER: create TensorFlow cluster on local machine
+# TF_DIST_SERVER_DOCKER_IMAGE: overrides the default docker image to launch
+# TensorFlow (GRPC) servers with
+# TF_DIST_GCLOUD_PROJECT: gcloud project in which the GKE cluster
+# will be created (valid only if aforementioned
+# TF_DIST_GRPC_SERVER_URL is empty).
+# TF_DIST_GCLOUD_COMPUTE_ZONE: gcloud compute zone.
+# TF_DIST_CONTAINER_CLUSTER: name of the GKE cluster
+# TF_DIST_GCLOUD_KEY_FILE: if non-empty, will override GCLOUD_KEY_FILE
+# TF_DIST_GRPC_PORT: overrides the default port (2222)
+# to run the GRPC servers on
+
+# Configurations
+# gcloud operation timeout (steps)
+GCLOUD_OP_MAX_STEPS=360
+
+GRPC_PORT=${TF_DIST_GRPC_PORT:-2222}
+
+DEFAULT_GCLOUD_BIN=/var/gcloud/google-cloud-sdk/bin/gcloud
+GCLOUD_KEY_FILE=${TF_DIST_GCLOUD_KEY_FILE:-\
+"/var/gcloud/secrets/tensorflow-testing.json"}
+GCLOUD_PROJECT=${TF_DIST_GCLOUD_PROJECT:-"tensorflow-testing"}
+
+GCLOUD_COMPUTE_ZONE=${TF_DIST_GCLOUD_COMPUTE_ZONE:-"us-central1-f"}
+CONTAINER_CLUSTER=${TF_DIST_CONTAINER_CLUSTER:-"test-cluster"}
+
+SERVER_DOCKER_IMAGE=${TF_DIST_SERVER_DOCKER_IMAGE:-\
+"tensorflow/tf_grpc_test_server"}
+
+# Get current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Get utility functions
+source "${DIR}/utils.sh"
+
+# Check input arguments
+if [[ $# != 2 ]]; then
+ die "Usage: $0 <num_workers> <num_parameter_servers>"
+fi
+
+NUM_WORKERS=$1
+NUM_PARAMETER_SERVERS=$2
+
+# Verify port string
+if [[ -z $(echo "${GRPC_PORT}" | grep -E "^[0-9]{1,5}") ]]; then
+ die "Invalid GRPC port: \"${GRPC_PORT}\""
+fi
+echo "GRPC port to be used when creating the k8s TensorFlow cluster: "\
+"${GRPC_PORT}"
+
+if [[ -z "${TF_DIST_LOCAL_CLUSTER}" ]] ||
+ [[ "${TF_DIST_LOCAL_CLUSTER}" == "0" ]]; then
+ IS_LOCAL_CLUSTER="0"
+else
+ IS_LOCAL_CLUSTER="1"
+fi
+
+if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
+ # Locate gcloud binary path
+ GCLOUD_BIN=$(which gcloud)
+ if [[ -z "${GCLOUD_BIN}" ]]; then
+ GCLOUD_BIN="${DEFAULT_GCLOUD_BIN}"
+ fi
+
+ if [[ ! -f "${GCLOUD_BIN}" ]]; then
+ die "gcloud binary cannot be found at: ${GCLOUD_BIN}"
+ fi
+ echo "Path to gcloud binary: ${GCLOUD_BIN}"
+
+ # Path to gcloud service key file
+ if [[ ! -f "${GCLOUD_KEY_FILE}" ]]; then
+ die "gcloud service account key file cannot be found at: ${GCLOUD_KEY_FILE}"
+ fi
+ echo "Path to gcloud key file: ${GCLOUD_KEY_FILE}"
+
+ echo "GCLOUD_PROJECT: ${GCLOUD_PROJECT}"
+  echo "GCLOUD_COMPUTE_ZONE: ${GCLOUD_COMPUTE_ZONE}"
+ echo "CONTAINER_CLUSTER: ${CONTAINER_CLUSTER}"
+
+ # Activate gcloud service account
+ "${GCLOUD_BIN}" auth activate-service-account --key-file "${GCLOUD_KEY_FILE}"
+
+ # Set gcloud project
+ "${GCLOUD_BIN}" config set project "${GCLOUD_PROJECT}"
+
+ # Set compute zone
+ "${GCLOUD_BIN}" config set compute/zone "${GCLOUD_COMPUTE_ZONE}"
+
+ # Set container cluster
+ "${GCLOUD_BIN}" config set container/cluster "${CONTAINER_CLUSTER}"
+
+ # Get container cluster credentials
+ "${GCLOUD_BIN}" container clusters get-credentials "${CONTAINER_CLUSTER}"
+ if [[ $? != "0" ]]; then
+ die "FAILED to get credentials for container cluster: ${CONTAINER_CLUSTER}"
+ fi
+
+ # If there is any existing tf k8s cluster, delete it first
+ "${DIR}/delete_tf_cluster.sh" "${GCLOUD_OP_MAX_STEPS}"
+fi
+
+# Path to kubectl binary
+KUBECTL_BIN=$(dirname "${GCLOUD_BIN:-${DEFAULT_GCLOUD_BIN}}")/kubectl
+if [[ ! -f "${KUBECTL_BIN}" ]]; then
+ die "kubectl binary cannot be found at: ${KUBECTL_BIN}"
+fi
+echo "Path to kubectl binary: ${KUBECTL_BIN}"
+
+# Create yaml file for k8s TensorFlow cluster creation
+# Path to the (Python) script for generating k8s yaml file
+K8S_GEN_TF_YAML="${DIR}/k8s_tensorflow.py"
+if [[ ! -f ${K8S_GEN_TF_YAML} ]]; then
+ die "FAILED to find yaml-generating script at: ${K8S_GEN_TF_YAML}"
+fi
+
+K8S_YAML="/tmp/k8s_tf_lb.yaml"
+rm -f "${K8S_YAML}"
+
+echo ""
+echo "Generating k8s cluster yaml config file with the following settings"
+echo " Server docker image: ${SERVER_DOCKER_IMAGE}"
+echo " Number of workers: ${NUM_WORKERS}"
+echo " Number of parameter servers: ${NUM_PARAMETER_SERVERS}"
+echo " GRPC port: ${GRPC_PORT}"
+echo ""
+
+${K8S_GEN_TF_YAML} \
+ --docker_image "${SERVER_DOCKER_IMAGE}" \
+ --num_workers "${NUM_WORKERS}" \
+ --num_parameter_servers "${NUM_PARAMETER_SERVERS}" \
+ --grpc_port "${GRPC_PORT}" \
+ --request_load_balancer=True \
+ > "${K8S_YAML}" || \
+ die "Generation of the yaml configuration file for k8s cluster FAILED"
+
+if [[ ! -f "${K8S_YAML}" ]]; then
+ die "FAILED to generate yaml file for TensorFlow k8s container cluster"
+else
+ echo "Generated yaml configuration file for k8s TensorFlow cluster: "\
+"${K8S_YAML}"
+fi
+
+# Create tf k8s container cluster
+"${KUBECTL_BIN}" create -f "${K8S_YAML}"
+
+# Wait for external IP of worker services to become available
+get_tf_worker_external_ip() {
+ echo $("${KUBECTL_BIN}" get svc | grep "^tf-worker0" | \
+ awk '{print $3}' | grep -E "[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+")
+}
+
+if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
+ echo "Waiting for external IP of tf-worker0 service to emerge..."
+ echo ""
+
+ COUNTER=0
+ while true; do
+ sleep 1
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>${GCLOUD_OP_MAX_STEPS}" | bc -l) == "1" ]]; then
+ die "Reached maximum polling steps while waiting for external IP "\
+"of tf-worker0 service to emerge"
+ fi
+
+ SVC_EXTERN_IP=$(get_tf_worker_external_ip)
+
+ if [[ ! -z "${SVC_EXTERN_IP}" ]]; then
+ break
+ fi
+ done
+
+ GRPC_SERVER_URL="grpc://${SVC_EXTERN_IP}:${GRPC_PORT}"
+ echo "GRPC URL of tf-worker0: ${GRPC_SERVER_URL}"
+
+else
+ echo "Waiting for tf pods to be all running..."
+ echo ""
+
+ COUNTER=0
+ while true; do
+ sleep 1
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>${GCLOUD_OP_MAX_STEPS}" | bc -l) == "1" ]]; then
+ die "Reached maximum polling steps while waiting for all tf pods to "\
+"be running in local k8s TensorFlow cluster"
+ fi
+
+ PODS_STAT=$(are_all_pods_running "${KUBECTL_BIN}")
+
+ if [[ ${PODS_STAT} == "2" ]]; then
+ # Error has occurred
+      die "Error(s) occurred while trying to launch tf k8s cluster. "\
+"One possible cause is that the Docker image used to launch the cluster is "\
+"invalid: \"${SERVER_DOCKER_IMAGE}\""
+ fi
+
+ if [[ ${PODS_STAT} == "1" ]]; then
+ break
+ fi
+ done
+
+ # Determine the tf-worker0 docker container id
+ WORKER0_ID=$(docker ps | grep "k8s_tf-worker0" | awk '{print $1}')
+ echo "WORKER0 Docker container ID: ${WORKER0_ID}"
+
+fi
+
+
+echo "Cluster setup complete."
diff --git a/tensorflow/tools/dist_test/scripts/delete_tf_cluster.sh b/tensorflow/tools/dist_test/scripts/delete_tf_cluster.sh
new file mode 100755
index 0000000000..0f96b4b57a
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/delete_tf_cluster.sh
@@ -0,0 +1,87 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script checks for any existing TensorFlow worker services, replication
+# controllers and pods in the Kubernetes (k8s) container cluster and delete
+# them if there are any.
+#
+# Usage: delete_tf_cluster [max_steps]
+#
+# max_steps: Maximum number of polling steps for kubectl operations
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+# Path to kubectl binary
+DEFAULT_KUBECTL_BIN=/var/gcloud/google-cloud-sdk/bin/kubectl
+KUBECTL_BIN=$(which kubectl)
+if [[ -z "${KUBECTL_BIN}" ]]; then
+ KUBECTL_BIN="${DEFAULT_KUBECTL_BIN}"
+fi
+if [[ ! -f "${KUBECTL_BIN}" ]]; then
+ die "kubectl binary cannot be found at: \"${KUBECTL_BIN}\""
+else
+ echo "Path to kubectl binary: ${KUBECTL_BIN}"
+fi
+
+MAX_STEPS=${1:-240}
+
+
+# Helper functions for kubectl workflow
+get_tf_svc_count() {
+ echo $("${KUBECTL_BIN}" get svc | grep "tf-" | wc -l)
+}
+
+get_tf_rc_count() {
+ echo $("${KUBECTL_BIN}" get rc | grep "tf-" | wc -l)
+}
+
+get_tf_pods_count() {
+ echo $("${KUBECTL_BIN}" get pods | grep "tf-" | wc -l)
+}
+
+
+# Delete all running services, replication-controllers and pods, in that order
+ITEMS_TO_DELETE="svc rc pods"
+for ITEM in ${ITEMS_TO_DELETE}; do
+ K8S_ITEM_COUNT=$(get_tf_${ITEM}_count)
+ if [[ ${K8S_ITEM_COUNT} != "0" ]]; then
+ echo "There are currently ${K8S_ITEM_COUNT} tf ${ITEM}(s) running. "
+ echo "Attempting to delete those..."
+
+ "${KUBECTL_BIN}" delete --all ${ITEM}
+
+ # Wait until all are deleted
+ # TODO(cais): Add time out
+ COUNTER=0
+ while true; do
+ sleep 1
+
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>${MAX_STEPS}" | bc -l) == "1" ]]; then
+ die "Reached maximum polling steps while trying to delete all tf ${ITEM}"
+ fi
+
+ if [[ $(get_tf_${ITEM}_count) == "0" ]]; then
+ break
+ fi
+ done
+ fi
+
+done
diff --git a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
new file mode 100755
index 0000000000..e0aad2b5c2
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
@@ -0,0 +1,137 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script invokes mnist_replica.py multiple times concurrently to test
+# TensorFlow's distributed runtime over a Kubernetes (k8s) cluster with the
+# grpc pods and service set up.
+#
+# Usage:
+# dist_mnist_test.sh <worker_grpc_url>
+#
+# worker_grpc_url is the GRPC URL of the worker for the main worker session,
+# e.g., grpc://1.2.3.4:2222
+
+
+# Configurations
+TIMEOUT=120 # Timeout for MNIST replica sessions
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+if [[ $# != 1 ]]; then
+ die "Usage: $0 <WORKER_GRPC_URL>"
+fi
+WORKER_GRPC_URL=$1
+
+# Verify the validity of the GRPC URL
+if [[ -z $(echo "${WORKER_GRPC_URL}" | \
+ grep -E "^grpc://.+:[0-9]+") ]]; then
+ die "Invalid worker GRPC URL: \"${WORKER_GRPC_URL}\""
+fi
+
+# Current working directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PY_DIR=$(dirname "${DIR}")/python
+
+MNIST_REPLICA="${PY_DIR}/mnist_replica.py"
+
+WKR_LOG_PREFIX="/tmp/worker"
+
+# First, download the data from a single process, to avoid race conditions
+# during data downloading
+timeout ${TIMEOUT} python "${MNIST_REPLICA}" \
+ --download_only=True || \
+ die "Download-only step of MNIST replica FAILED"
+
+# Run a number of workers in parallel
+N_WORKERS=2
+INDICES=""
+IDX=0
+while true; do
+ timeout ${TIMEOUT} \
+ python "${MNIST_REPLICA}" \
+ --worker_grpc_url="${WORKER_GRPC_URL}" \
+    --worker_index=${IDX} > \
+    "${WKR_LOG_PREFIX}${IDX}.log" 2>&1 &
+ # TODO(cais): have each trainer process contact a different worker once
+ # supervisor and sync_replicas etc. are all working in OSS TensorFlow.
+
+ INDICES="${INDICES} ${IDX}"
+
+ ((IDX++))
+ if [[ $(echo "${IDX}==${N_WORKERS}" | bc -l) == "1" ]]; then
+ break
+ fi
+done
+
+# Function for getting final validation cross entropy from worker log files
+get_final_val_xent() {
+ echo $(cat $1 | grep "^After.*validation cross entropy = " | \
+ awk '{print $NF}')
+}
+
+# Poll until all final validation cross entropy values become available or
+# operation times out
+COUNTER=0
+while true; do
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>${TIMEOUT}" | bc -l) == "1" ]]; then
+ die "Reached maximum polling steps while polling for final validation "\
+"cross entropies from all workers"
+ fi
+
+ N_AVAIL=0
+ VAL_XENTS=""
+ for N in ${INDICES}; do
+ VAL_XENT=$(get_final_val_xent "${WKR_LOG_PREFIX}${N}.log")
+ if [[ ! -z ${VAL_XENT} ]]; then
+ ((N_AVAIL++))
+ VAL_XENTS="${VAL_XENTS} ${VAL_XENT}"
+ fi
+ done
+
+  if [[ "${N_AVAIL}" == "${N_WORKERS}" ]]; then
+ # Print out the content of the log files
+ for M in ${INDICES}; do
+ echo "==================================================="
+ echo "=== Log file from worker ${M} ==="
+ cat "${WKR_LOG_PREFIX}${M}.log"
+ echo "==================================================="
+ echo ""
+ done
+
+ break
+ else
+ sleep 1
+ fi
+done
+
+# Sanity check on the validation cross entropy values
+# TODO(cais): In addition to this basic sanity check, we could run the training
+# with 1 and 2 workers, each for a few times, and use scipy.stats to do a t-test
+# to verify that the 2-worker training gives a significantly lower final cross
+# entropy
+VAL_XENTS=(${VAL_XENTS})
+for N in ${INDICES}; do
+ echo "Final validation cross entropy from worker${N}: ${VAL_XENTS[N]}"
+ if [[ $(echo "${VAL_XENTS[N]}>0" | bc -l) != "1" ]]; then
+ die "Sanity checks on the final validation cross entropy values FAILED"
+ fi
+
+done
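
dist_mnist_test.sh recovers each worker's final validation cross entropy by grepping its log for the line mnist_replica.py prints ("After N training step(s), validation cross entropy = X") and then checks that every value is positive. The same parsing and sanity check in Python (the /tmp/workerN.log naming follows WKR_LOG_PREFIX above):

    import re

    XENT_RE = re.compile(r"^After \d+ training step\(s\), "
                         r"validation cross entropy = ([0-9.eE+-]+)")

    def final_val_xent(log_path):
      """Return the last validation cross entropy reported in a worker log."""
      value = None
      with open(log_path) as f:
        for line in f:
          m = XENT_RE.match(line)
          if m:
            value = float(m.group(1))
      return value

    def sanity_check(num_workers=2, log_prefix="/tmp/worker"):
      for idx in range(num_workers):
        xent = final_val_xent("%s%d.log" % (log_prefix, idx))
        print("Final validation cross entropy from worker%d: %s" % (idx, xent))
        if xent is None or xent <= 0:
          raise RuntimeError("Sanity check FAILED for worker %d" % idx)
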
diff --git a/tensorflow/tools/dist_test/scripts/dist_test.sh b/tensorflow/tools/dist_test/scripts/dist_test.sh
new file mode 100755
index 0000000000..f8ade7eff8
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/dist_test.sh
@@ -0,0 +1,118 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Performs tests of TensorFlow's distributed runtime over a Kubernetes (k8s)
+# container cluster.
+#
+# This script tears down any existing TensorFlow cluster, consisting of
+# services, replication controllers and pods, before creating a new cluster.
+# The cluster contains a number of parameter server services and a number of
+# worker services. The parameter servers will hold parameters of the ML model,
+# e.g., weights and biases of the NN layers, while the workers will hold the
+# TensorFlow ops.
+#
+# Usage:
+# dist_test.sh [--setup-cluster-only]
+#
+# --setup-cluster-only lets the script only set up the k8s container network
+#
+# This script obeys values in the following environment variables:
+# TF_DIST_GRPC_SERVER_URL: If it is set to a valid grpc server url (e.g.,
+#                              grpc://1.2.3.4:2222), the script will bypass
+# the cluster setup and teardown processes and
+# just use this URL.
+
+
+# Configurations
+NUM_WORKERS=2 # Number of worker containers
+NUM_PARAMETER_SERVERS=2 # Number of parameter servers
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+# gcloud operation timeout (steps)
+GCLOUD_OP_MAX_STEPS=240
+
+GRPC_SERVER_URL=${TF_DIST_GRPC_SERVER_URL}
+
+# Report gcloud / GKE parameters
+echo "GRPC_SERVER_URL: ${GRPC_SERVER_URL}"
+
+# Get current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Determine whether to use an existing GRPC server URL or set up a new cluster
+TEARDOWN_WHEN_DONE=1
+if [[ ! -z "${GRPC_SERVER_URL}" ]]; then
+ TEARDOWN_WHEN_DONE=0
+ # Verify the validity of the GRPC URL
+ if [[ -z $(echo "${GRPC_SERVER_URL}" | \
+ grep -E "^grpc://.+:[0-9]+") ]]; then
+ die "Invalid GRPC_SERVER_URL: \"${GRPC_SERVER_URL}\""
+ else
+ echo "The preset GRPC_SERVER_URL appears to be valid: ${GRPC_SERVER_URL}"
+ echo "Will bypass the TensorFlow k8s cluster setup and teardown process"
+ echo ""
+ fi
+else
+ TMP=$(mktemp)
+ "${DIR}/create_tf_cluster.sh" ${NUM_WORKERS} ${NUM_PARAMETER_SERVERS} 2>&1 | \
+ tee "${TMP}" || \
+ die "Creation of TensorFlow k8s cluster FAILED"
+
+ GRPC_SERVER_URL=$(cat ${TMP} | grep "GRPC URL of tf-worker0: .*" | \
+ awk '{print $NF}')
+ if [[ -z "${GRPC_SERVER_URL}" ]]; then
+ die "FAILED to determine GRPC server URL"
+ fi
+ rm -f ${TMP}
+
+ if [[ $1 == "--setup-cluster-only" ]]; then
+ echo "Skipping testing of distributed runtime due to "\
+"option flag --setup-cluster-only"
+ exit 0
+ fi
+fi
+
+# Invoke script to perform distributed MNIST training
+MNIST_DIST_TEST_BIN="${DIR}/dist_mnist_test.sh"
+if [[ ! -f "${MNIST_DIST_TEST_BIN}" ]]; then
+ die "FAILED to find distributed mnist client test script at "\
+"${MNIST_DIST_TEST_BIN}"
+fi
+
+echo "Performing distributed MNIST training through grpc session @ "\
+"${GRPC_SERVER_URL}..."
+
+"${MNIST_DIST_TEST_BIN}" "${GRPC_SERVER_URL}"
+
+if [[ $? == "0" ]]; then
+ echo "MNIST-replica test PASSED"
+else
+ die "MNIST-replica test FAILED"
+fi
+
+# Tear down current k8s TensorFlow cluster
+if [[ "${TEARDOWN_WHEN_DONE}" == "1" ]]; then
+ echo "Tearing down k8s TensorFlow cluster..."
+ "${DIR}/delete_tf_cluster.sh" "${GCLOUD_OP_MAX_STEPS}" && \
+ echo "Cluster tear-down SUCCEEDED" || \
+ die "Cluster tear-down FAILED"
+fi
+echo "SUCCESS: Test of distributed TensorFlow runtime PASSED"
diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
new file mode 100755
index 0000000000..e3fde2180a
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
@@ -0,0 +1,245 @@
+#!/usr/bin/python
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Generates YAML configuration files for distributed Tensorflow workers.
+
+The workers will be run in a Kubernetes (k8s) container cluster.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+# Note: It is intentional that we do not import tensorflow in this script. The
+# machine that launches a TensorFlow k8s cluster does not have to have the
+# Python package of TensorFlow installed on it.
+
+
+DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server'
+DEFAULT_PORT = 2222
+
+# TODO(cais): Consider adding resource requests/limits to the pods.
+WORKER_RC = (
+ """apiVersion: v1
+kind: ReplicationController
+metadata:
+ name: tf-worker{worker_id}
+spec:
+ replicas: 1
+ template:
+ metadata:
+ labels:
+ tf-worker: "{worker_id}"
+ spec:
+ containers:
+ - name: tf-worker{worker_id}
+ image: {docker_image}
+ args:
+ - --cluster_spec={cluster_spec}
+ - --job_name=worker
+ - --task_id={worker_id}
+ ports:
+ - containerPort: {port}
+""")
+WORKER_SVC = (
+ """apiVersion: v1
+kind: Service
+metadata:
+ name: tf-worker{worker_id}
+ labels:
+ tf-worker: "{worker_id}"
+spec:
+ ports:
+ - port: {port}
+ targetPort: {port}
+ selector:
+ tf-worker: "{worker_id}"
+""")
+WORKER_LB_SVC = (
+ """apiVersion: v1
+kind: Service
+metadata:
+ name: tf-worker{worker_id}
+ labels:
+ tf-worker: "{worker_id}"
+spec:
+ type: LoadBalancer
+ ports:
+ - port: {port}
+ selector:
+ tf-worker: "{worker_id}"
+""")
+PARAM_SERVER_RC = (
+ """apiVersion: v1
+kind: ReplicationController
+metadata:
+ name: tf-ps{param_server_id}
+spec:
+ replicas: 1
+ template:
+ metadata:
+ labels:
+ tf-ps: "{param_server_id}"
+ spec:
+ containers:
+ - name: tf-ps{param_server_id}
+ image: {docker_image}
+ args:
+ - --cluster_spec={cluster_spec}
+ - --job_name=ps
+ - --task_id={param_server_id}
+ ports:
+ - containerPort: {port}
+""")
+PARAM_SERVER_SVC = (
+ """apiVersion: v1
+kind: Service
+metadata:
+ name: tf-ps{param_server_id}
+ labels:
+ tf-ps: "{param_server_id}"
+spec:
+ ports:
+ - port: {port}
+ selector:
+ tf-ps: "{param_server_id}"
+""")
+
+
+def main():
+  """Parse arguments and print the generated YAML config."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_workers',
+ type=int,
+ default=2,
+ help='How many worker pods to run')
+ parser.add_argument('--num_parameter_servers',
+ type=int,
+ default=1,
+                      help='How many parameter server pods to run')
+ parser.add_argument('--grpc_port',
+ type=int,
+ default=DEFAULT_PORT,
+ help='GRPC server port (Default: %d)' % DEFAULT_PORT)
+ parser.add_argument('--request_load_balancer',
+ type=bool,
+ default=False,
+ help='To request worker0 to be exposed on a public IP '
+ 'address via an external load balancer, enabling you to '
+ 'run client processes from outside the cluster')
+ parser.add_argument('--docker_image',
+ type=str,
+ default=DEFAULT_DOCKER_IMAGE,
+ help='Override default docker image for the TensorFlow '
+ 'GRPC server')
+ args = parser.parse_args()
+
+ if args.num_workers <= 0:
+ sys.stderr.write('--num_workers must be greater than 0; received %d\n'
+ % args.num_workers)
+ sys.exit(1)
+ if args.num_parameter_servers <= 0:
+ sys.stderr.write(
+ '--num_parameter_servers must be greater than 0; received %d\n'
+ % args.num_parameter_servers)
+ sys.exit(1)
+
+ # Generate contents of yaml config
+ yaml_config = GenerateConfig(args.num_workers,
+ args.num_parameter_servers,
+ args.grpc_port,
+ args.request_load_balancer,
+ args.docker_image)
+ print(yaml_config) # pylint: disable=superfluous-parens
+
+
+def GenerateConfig(num_workers,
+ num_param_servers,
+ port,
+ request_load_balancer,
+ docker_image):
+ """Generate configuration strings."""
+ config = ''
+ for worker in range(num_workers):
+ config += WORKER_RC.format(
+ port=port,
+ worker_id=worker,
+ docker_image=docker_image,
+ cluster_spec=WorkerClusterSpec(num_workers,
+ num_param_servers,
+ port))
+ config += '---\n'
+ if worker == 0 and request_load_balancer:
+ config += WORKER_LB_SVC.format(port=port,
+ worker_id=worker)
+ else:
+ config += WORKER_SVC.format(port=port,
+ worker_id=worker)
+ config += '---\n'
+
+ for param_server in range(num_param_servers):
+ config += PARAM_SERVER_RC.format(
+ port=port,
+ param_server_id=param_server,
+ docker_image=docker_image,
+ cluster_spec=ParamServerClusterSpec(num_workers,
+ num_param_servers,
+ port))
+ config += '---\n'
+ config += PARAM_SERVER_SVC.format(port=port,
+ param_server_id=param_server)
+ config += '---\n'
+
+ return config
+
+
+def WorkerClusterSpec(num_workers,
+ num_param_servers,
+ port):
+ """Generates worker cluster spec."""
+ return ClusterSpec(num_workers, num_param_servers, port)
+
+
+def ParamServerClusterSpec(num_workers,
+ num_param_servers,
+ port):
+ """Generates parameter server spec."""
+ return ClusterSpec(num_workers, num_param_servers, port)
+
+
+def ClusterSpec(num_workers,
+ num_param_servers,
+ port):
+ """Generates general cluster spec."""
+ spec = 'worker|'
+ for worker in range(num_workers):
+ spec += 'tf-worker%d:%d' % (worker, port)
+ if worker != num_workers-1:
+ spec += ';'
+
+ spec += ',ps|'
+ for param_server in range(num_param_servers):
+ spec += 'tf-ps%d:%d' % (param_server, port)
+ if param_server != num_param_servers-1:
+ spec += ';'
+
+ return spec
+
+
+if __name__ == '__main__':
+ main()
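
For the defaults dist_test.sh uses (2 workers, 2 parameter servers, port 2222), the ClusterSpec() helper above yields the --cluster_spec string that gets baked into every pod's container args. A quick worked example (this mirror of the helper is only for illustration):

    def cluster_spec(num_workers, num_param_servers, port):
      """Mirror of ClusterSpec() above, reproduced for a worked example."""
      workers = ";".join("tf-worker%d:%d" % (i, port) for i in range(num_workers))
      ps = ";".join("tf-ps%d:%d" % (i, port) for i in range(num_param_servers))
      return "worker|%s,ps|%s" % (workers, ps)

    print(cluster_spec(2, 2, 2222))
    # worker|tf-worker0:2222;tf-worker1:2222,ps|tf-ps0:2222;tf-ps1:2222
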
diff --git a/tensorflow/tools/dist_test/scripts/utils.sh b/tensorflow/tools/dist_test/scripts/utils.sh
new file mode 100644
index 0000000000..bc4485baf0
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/utils.sh
@@ -0,0 +1,56 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Utility functions for dist_test scripts
+
+
+# Print info and exit with code 1
+die() {
+ echo $@
+ exit 1
+}
+
+
+# Determine if all k8s pods in a namespace are all in the "Running" state
+are_all_pods_running() {
+ # Usage: are_all_pods_running <KUBECTL_BIN> [namespace]
+ KUBECTL_BIN=$1
+
+ if [[ -z "$2" ]]; then
+ NS_FLAG=""
+ else
+ NS_FLAG="--namespace=$2"
+ fi
+
+ sleep 1 # Wait for the status to settle
+  NPODS=$("${KUBECTL_BIN}" ${NS_FLAG} get pods | tail -n +2 | wc -l)
+  NRUNNING=$("${KUBECTL_BIN}" ${NS_FLAG} get pods | tail -n +2 | \
+             grep "Running" | wc -l)
+  NERR=$("${KUBECTL_BIN}" ${NS_FLAG} get pods | tail -n +2 | \
+         grep "Err" | wc -l)
+
+ if [[ ${NERR} != "0" ]]; then
+ # "2" signifies that error has occurred
+ echo "2"
+ elif [[ ${NPODS} == ${NRUNNING} ]]; then
+ # "1" signifies that all pods are in Running state
+ echo "1"
+ else
+ # "0" signifies that some pods have not entered Running state, but
+ # no error has occurred
+ echo "0"
+ fi
+}
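
are_all_pods_running collapses the `kubectl get pods` listing into a three-way code: "2" if any pod shows an error state, "1" if every pod is Running, and "0" otherwise. A Python sketch of the same classification, operating on captured kubectl output (the header-row and column conventions are assumptions carried over from the script):

    def classify_pods(kubectl_get_pods_output):
      """Return 2 on any Err'd pod, 1 if all pods are Running, else 0."""
      rows = [l for l in kubectl_get_pods_output.splitlines()[1:] if l.strip()]
      n_running = sum(1 for l in rows if "Running" in l)
      n_err = sum(1 for l in rows if "Err" in l)
      if n_err:
        return 2
      return 1 if len(rows) == n_running else 0
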
diff --git a/tensorflow/tools/dist_test/server/Dockerfile b/tensorflow/tools/dist_test/server/Dockerfile
new file mode 100644
index 0000000000..bf384413f1
--- /dev/null
+++ b/tensorflow/tools/dist_test/server/Dockerfile
@@ -0,0 +1,59 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Test server for TensorFlow GRPC server
+#
+# To build the image, use ../build_server.sh
+
+FROM ubuntu:14.04
+
+MAINTAINER Shanqing Cai <cais@google.com>
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y \
+ bc \
+ curl \
+ dnsutils \
+ python-numpy \
+ python-pip \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
+ python get-pip.py && \
+ rm get-pip.py
+
+# Install TensorFlow CPU version.
+RUN pip --no-cache-dir install \
+ http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.7.1-cp27-none-linux_x86_64.whl
+
+# Copy files, including the GRPC server binary at
+# server/grpc_tensorflow_server.py
+ADD . /var/tf-k8s
+
+# Download MNIST data for tests
+RUN mkdir -p /tmp/mnist-data
+RUN curl -o /tmp/mnist-data/train-labels-idx1-ubyte.gz \
+ http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
+RUN curl -o /tmp/mnist-data/train-images-idx3-ubyte.gz \
+ http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
+RUN curl -o /tmp/mnist-data/t10k-labels-idx1-ubyte.gz \
+ http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
+RUN curl -o /tmp/mnist-data/t10k-images-idx3-ubyte.gz \
+ http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
+
+# Container entry point
+ENTRYPOINT ["/var/tf-k8s/server/grpc_tensorflow_server.py"]
diff --git a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
new file mode 100755
index 0000000000..b9742112de
--- /dev/null
+++ b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
@@ -0,0 +1,122 @@
+#!/usr/bin/python
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Python-based TensorFlow GRPC server.
+
+Takes input arguments cluster_spec, job_name and task_id, and starts a blocking
+TensorFlow GRPC server.
+
+Usage:
+ grpc_tensorflow_server.py --cluster_spec=SPEC --job_name=NAME --task_id=ID
+
+Where:
+ SPEC is <JOB>(,<JOB>)*
+ JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*
+ NAME is a valid job name ([a-z][0-9a-z]*)
+ HOST is a hostname or IP address
+ PORT is a port number
+"""
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string("cluster_spec", "",
+ """Cluster spec: SPEC.
+                           SPEC is <JOB>(,<JOB>)*,
+                           JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*,
+                           NAME is a valid job name ([a-z][0-9a-z]*),
+                           HOST is a hostname or IP address,
+                           PORT is a port number.
+E.g., local|localhost:2222;localhost:2223,ps|ps0:2222;ps1:2222""")
+tf.app.flags.DEFINE_string("job_name", "", "Job name: e.g., local")
+tf.app.flags.DEFINE_integer("task_id", 0, "Task index, e.g., 0")
+tf.app.flags.DEFINE_boolean("verbose", False, "Verbose mode")
+
+
+def parse_cluster_spec(cluster_spec, cluster):
+ """Parse content of cluster_spec string and inject info into cluster protobuf.
+
+ Args:
+ cluster_spec: cluster specification string, e.g.,
+ "local|localhost:2222;localhost:2223"
+ cluster: cluster protobuf.
+
+ Raises:
+ ValueError: if the cluster_spec string is invalid.
+ """
+
+ job_strings = cluster_spec.split(",")
+
+ for job_string in job_strings:
+ job_def = cluster.job.add()
+
+ if job_string.count("|") != 1:
+ raise ValueError("Not exactly one instance of '|' in cluster_spec")
+
+ job_name = job_string.split("|")[0]
+
+ if not job_name:
+ raise ValueError("Empty job_name in cluster_spec")
+
+ job_def.name = job_name
+
+ if FLAGS.verbose:
+ print("Added job named \"%s\"" % job_name)
+
+ job_tasks = job_string.split("|")[1].split(";")
+ for i in range(len(job_tasks)):
+ if not job_tasks[i]:
+ raise ValueError("Empty job_task string at position %d" % i)
+
+ job_def.tasks[i] = job_tasks[i]
+
+ if FLAGS.verbose:
+ print(" Added task \"%s\" to job \"%s\"" % (job_tasks[i], job_name))
+
+
+def main(unused_args):
+ # Create Protobuf ServerDef
+ server_def = tf.ServerDef(protocol="grpc")
+
+ # Cluster info
+ parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster)
+
+ # Job name
+ if not FLAGS.job_name:
+ raise ValueError("Empty job_name")
+ server_def.job_name = FLAGS.job_name
+
+ # Task index
+ if FLAGS.task_id < 0:
+ raise ValueError("Invalid task_id: %d" % FLAGS.task_id)
+ server_def.task_index = FLAGS.task_id
+
+ # Create GrpcServer instance
+ server = tf.GrpcServer(server_def)
+
+ # join() is blocking, unlike start()
+ server.join()
+
+
+if __name__ == "__main__":
+ tf.app.run()
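
parse_cluster_spec accepts the grammar documented above: a comma-separated list of jobs, each written as '<name>|<host:port>(;<host:port>)*', which it copies into the ClusterDef protobuf. A standalone sketch of that parse without the protobuf dependency, returning a plain dict:

    def parse_cluster_spec(cluster_spec):
      """Parse 'name|host:port;host:port,...' into {job_name: [task_addresses]}."""
      cluster = {}
      for job_string in cluster_spec.split(","):
        if job_string.count("|") != 1:
          raise ValueError("Not exactly one instance of '|' in cluster_spec")
        job_name, tasks = job_string.split("|")
        if not job_name:
          raise ValueError("Empty job_name in cluster_spec")
        cluster[job_name] = tasks.split(";")
      return cluster

    print(parse_cluster_spec("local|localhost:2222;localhost:2223,ps|ps0:2222"))
    # {'local': ['localhost:2222', 'localhost:2223'], 'ps': ['ps0:2222']}
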
diff --git a/tensorflow/user_ops/BUILD b/tensorflow/user_ops/BUILD
index 5bd935000d..cb279e0fc4 100644
--- a/tensorflow/user_ops/BUILD
+++ b/tensorflow/user_ops/BUILD
@@ -30,7 +30,10 @@ py_tests(
name = "ackermann_test",
srcs = ["ackermann_test.py"],
data = [":ackermann_op.so"],
- tags = ["notsan"],
+ tags = [
+ "noasan",
+ "notsan",
+ ],
)
filegroup(
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index f0d0702a48..1e3ea6e0a6 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -13,8 +13,8 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
native.new_http_archive(
name = "eigen_archive",
- url = "https://bitbucket.org/eigen/eigen/get/db7b61411772.tar.gz",
- sha256 = "832e1e082b91d40ad909a079b98630ce52bd904d1ec0c3cb4cdcd2e24bcf95e6",
+ url = "https://bitbucket.org/eigen/eigen/get/0a13bf3e579d.tar.gz",
+ sha256 = "85c9075a51b56e4e20f3814020c726301b84c5df80fc6072d0056d512eb4bf30",
build_file = path_prefix + "eigen.BUILD",
)
diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky
index e9d61cc882..236fc00cd4 100644
--- a/third_party/eigen3/Eigen/Cholesky
+++ b/third_party/eigen3/Eigen/Cholesky
@@ -1 +1 @@
-#include "eigen-eigen-db7b61411772/Eigen/Cholesky"
+#include "eigen-eigen-0a13bf3e579d/Eigen/Cholesky"
diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core
index b4320a07f0..b106690ce8 100644
--- a/third_party/eigen3/Eigen/Core
+++ b/third_party/eigen3/Eigen/Core
@@ -1 +1 @@
-#include "eigen-eigen-db7b61411772/Eigen/Core"
+#include "eigen-eigen-0a13bf3e579d/Eigen/Core"
diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues
index f32af39fa6..be72e68a6e 100644
--- a/third_party/eigen3/Eigen/Eigenvalues
+++ b/third_party/eigen3/Eigen/Eigenvalues
@@ -1 +1 @@
-#include "eigen-eigen-db7b61411772/Eigen/Eigenvalues"
+#include "eigen-eigen-0a13bf3e579d/Eigen/Eigenvalues"
diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU
index 960cec6ad1..d925a388fb 100644
--- a/third_party/eigen3/Eigen/LU
+++ b/third_party/eigen3/Eigen/LU
@@ -1 +1 @@
-#include "eigen-eigen-db7b61411772/Eigen/LU"
+#include "eigen-eigen-0a13bf3e579d/Eigen/LU"
diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR
index fd5aa1c519..8198ac216c 100644
--- a/third_party/eigen3/Eigen/QR
+++ b/third_party/eigen3/Eigen/QR
@@ -1 +1 @@
-#include "eigen-eigen-db7b61411772/Eigen/QR"
+#include "eigen-eigen-0a13bf3e579d/Eigen/QR"
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
index f374207f41..51f7e7bddd 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
@@ -1 +1 @@
-#include "eigen-eigen-db7b61411772/unsupported/Eigen/CXX11/Tensor"
+#include "eigen-eigen-0a13bf3e579d/unsupported/Eigen/CXX11/Tensor"