aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/eager/tape.h51
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py30
-rw-r--r--tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py18
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/__init__.py18
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/__init__.py18
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py7
-rw-r--r--tensorflow/contrib/boosted_trees/python/utils/__init__.py18
-rw-r--r--tensorflow/contrib/eager/python/metrics_impl.py12
-rw-r--r--tensorflow/contrib/eager/python/metrics_test.py13
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/helper.py36
-rw-r--r--tensorflow/core/framework/bfloat16.cc28
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc9
-rw-r--r--tensorflow/core/framework/fake_input.cc12
-rw-r--r--tensorflow/core/framework/function.cc4
-rw-r--r--tensorflow/core/framework/function.h2
-rw-r--r--tensorflow/core/framework/graph_def_util.cc10
-rw-r--r--tensorflow/core/framework/op_def_util.cc24
-rw-r--r--tensorflow/core/framework/op_def_util_test.cc30
-rw-r--r--tensorflow/core/framework/op_gen_lib.cc1
-rw-r--r--tensorflow/core/framework/op_gen_lib.h1
-rw-r--r--tensorflow/core/framework/op_kernel.cc3
-rw-r--r--tensorflow/core/framework/op_kernel_test.cc28
-rw-r--r--tensorflow/core/framework/reader_base.cc6
-rw-r--r--tensorflow/core/framework/register_types.h10
-rw-r--r--tensorflow/core/framework/register_types_traits.h6
-rw-r--r--tensorflow/core/framework/rendezvous_test.cc8
-rw-r--r--tensorflow/core/framework/shape_inference.h2
-rw-r--r--tensorflow/core/framework/shape_inference_test.cc5
-rw-r--r--tensorflow/core/framework/tensor_shape_test.cc3
-rw-r--r--tensorflow/core/framework/tensor_testutil.cc2
-rw-r--r--tensorflow/core/framework/tensor_types.h44
-rw-r--r--tensorflow/core/framework/types_test.cc4
-rw-r--r--tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc2
-rw-r--r--tensorflow/core/grappler/costs/cost_estimator.h9
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc58
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h2
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer.cc149
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer_test.cc3
-rw-r--r--tensorflow/core/kernels/mkl_tfconv_op.cc124
-rw-r--r--tensorflow/python/eager/backprop.py8
-rw-r--r--tensorflow/python/ops/math_grad.py26
-rw-r--r--tensorflow/python/training/optimizer.py10
-rw-r--r--tensorflow/python/util/deprecation.py5
44 files changed, 466 insertions, 394 deletions
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 2b65e38f54..bdb0815d6b 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -18,12 +18,12 @@ limitations under the License.
// Language-agnostic gradient tape. Does not perform backpropagation, just
// maintains the data structures required to do so.
-#include <unordered_map>
-#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -54,11 +54,11 @@ struct OpTapeEntry {
// Map from tensor_id to internally-defined operation-id of the operation which
// produced this tensor. A value of -1 means that the tensor was directly
// watched and not the result of any operation in the tape.
-using TensorTape = std::unordered_map<int64, int64>;
+using TensorTape = gtl::FlatMap<int64, int64>;
// Map from operation-id to tape entry.
template <typename BackwardFunction>
-using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>;
+using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
@@ -159,7 +159,7 @@ class GradientTape {
// Map from tensor id to number of remaining usages (i.e. how many entries in
// the tape refer to it); to aid in tape garbage collection.
- std::unordered_map<int64, int64> tensor_usage_;
+ gtl::FlatMap<int64, int64> tensor_usage_;
// If false, all activations are deleted in the first call to ComputeGradient.
// Else, only when this is destructed.
@@ -286,11 +286,11 @@ struct BackpropInitialState {
// Map from tensor ID to how many references still exist for this tensor in
// the tape.
- std::unordered_map<int64, int64> tensor_usage_counts;
+ gtl::FlatMap<int64, int64> tensor_usage_counts;
// Maps from op ID to how many output tensors of this op still need to have
// their gradients computed.
- std::unordered_map<int64, int64> op_missing_tensor;
+ gtl::FlatMap<int64, int64> op_missing_tensor;
};
// If `persistent_tape` is true, op_tape is not changed and none of the
@@ -301,8 +301,8 @@ struct BackpropInitialState {
template <typename BackwardFunction>
BackpropInitialState<BackwardFunction> PrepareBackprop(
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
- OpTape<BackwardFunction>* op_tape,
- const std::unordered_set<int64>& sources_set, bool persistent_tape) {
+ OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set,
+ bool persistent_tape) {
std::vector<int64> tensor_stack;
tensor_stack.reserve(target.size());
for (auto t : target) {
@@ -362,7 +362,7 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
template <typename BackwardFunction>
std::vector<int64> InitialStack(
const OpTape<BackwardFunction>& op_tape,
- const std::unordered_map<int64, int64>& op_missing_tensor) {
+ const gtl::FlatMap<int64, int64>& op_missing_tensor) {
std::vector<int64> result;
for (auto& op_entry : op_tape) {
if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
@@ -373,13 +373,13 @@ std::vector<int64> InitialStack(
}
template <typename Gradient, typename BackwardFunction>
-Status InitialGradients(
- const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
- const OpTape<BackwardFunction>& op_tape,
- const std::unordered_map<int64, int64>& tensor_usage_counts,
- std::unordered_map<int64, std::vector<Gradient*>>* result) {
+Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<Gradient*> output_gradients,
+ const TensorTape& tensor_tape,
+ const OpTape<BackwardFunction>& op_tape,
+ const gtl::FlatMap<int64, int64>& tensor_usage_counts,
+ gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
@@ -441,13 +441,13 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
gtl::ArraySlice<int64> source_tensor_ids,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
- std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
- source_tensor_ids.end());
+ gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
+ source_tensor_ids.end());
BackpropInitialState<BackwardFunction> state = PrepareBackprop(
target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
std::vector<int64> op_stack =
InitialStack(state.op_tape, state.op_missing_tensor);
- std::unordered_map<int64, std::vector<Gradient*>> gradients;
+ gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
tensor_tape_, state.op_tape,
state.tensor_usage_counts, &gradients);
@@ -463,7 +463,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
cleanup();
return s;
}
- std::unordered_map<int64, int64> gradients_size;
+ gtl::FlatMap<int64, int64> gradients_size;
// TODO(apassos) multiple threads could be dequeuing from op_stack at the same
// time, for better CPU backprop performance.
VLOG(1) << "Initial stack:";
@@ -472,11 +472,10 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
VLOG(1) << " " << t;
}
}
- std::unordered_map<string, std::unordered_set<int>>
- functions_accept_none_for_indices({
- {"SoftmaxCrossEntropyWithLogits", {1}},
- {"FusedBatchNorm", {1, 2, 3, 4}},
- });
+ gtl::FlatMap<string, gtl::FlatSet<int>> functions_accept_none_for_indices({
+ {"SoftmaxCrossEntropyWithLogits", {1}},
+ {"FusedBatchNorm", {1, 2, 3, 4}},
+ });
while (!op_stack.empty()) {
const int64 op = op_stack.back();
VLOG(1) << "Popped " << op;
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
index eefa7ef0dc..81f58de28c 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
@@ -183,11 +183,10 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
self.assertEqual(num_quantiles + 1, len(buckets))
self.assertAllEqual([2030, 2040, 2050, 2060], buckets)
- def _testStreamingQuantileBucketsHelper(self, inputs):
+ def _testStreamingQuantileBucketsHelper(
+ self, inputs, num_quantiles=3, expected_buckets=None):
"""Helper to test quantile buckets on different inputs."""
- # Use 3 quantiles, 4 boundaries for simplicity.
- num_quantiles = 3
# set generate_quantiles to True since the test will generate fewer
# boundaries otherwise.
with self.test_session() as sess:
@@ -213,7 +212,10 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
buckets, are_ready_flush = (sess.run(
[buckets, are_ready_flush]))
self.assertEqual(True, are_ready_flush)
+ # By default, use 3 quantiles, 4 boundaries for simplicity.
self.assertEqual(num_quantiles + 1, len(buckets))
+ if expected_buckets:
+ self.assertAllEqual(buckets, expected_buckets)
def testStreamingQuantileBucketsRepeatedSingleValue(self):
inputs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
@@ -231,6 +233,28 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
inputs = [5]
self._testStreamingQuantileBucketsHelper(inputs)
+ def testStreamingQuantileBucketsEqualDistributionInSequence(self):
+ # Input pattern is of the form [1, 1, 1, 2, 2, 2, 3, 3, 3, ...]
+ ones = 100 * [1]
+ inputs = []
+ for i in range(1, 101):
+ inputs += [i * k for k in ones]
+ # Expect 100 equally spaced buckets.
+ expected_buckets = range(1, 101)
+ self._testStreamingQuantileBucketsHelper(
+ inputs, num_quantiles=99, expected_buckets=expected_buckets)
+
+ def testStreamingQuantileBucketsEqualDistributionInterleaved(self):
+ # Input pattern is of the form [1, 2, 3, 1, 2, 3, 1, 2, 3, ...]
+ sequence = range(1, 101)
+ inputs = []
+ for _ in range(1, 101):
+ inputs += sequence
+ # Expect 100 equally spaced buckets.
+ expected_buckets = range(1, 101)
+ self._testStreamingQuantileBucketsHelper(
+ inputs, num_quantiles=99, expected_buckets=expected_buckets)
+
def testStreamingQuantileBuckets(self):
"""Sets up the quantile summary op test as follows.
diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
index b281a4c6d1..7a5f329b7a 100644
--- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
+++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py
@@ -81,32 +81,32 @@ def _scheduled_stamp_resource_op_runner(batch, stamp):
if not batch:
return
arg_keys = set(batch[0].args.keys())
- grouped_args = collections.defaultdict(list)
+ grouped_args = collections.OrderedDict()
resource_handles = []
# Check that the set of arguments is the same across all the scheduled ops.
for op in batch:
if set(op.args.keys()) != arg_keys:
raise ValueError("Mismatching arguments: %s, %s.", op.args, arg_keys)
for key in arg_keys:
- grouped_args[key].append(op.args[key])
+ grouped_args.setdefault(key, []).append(op.args[key])
resource_handles.append(op.resource_handle)
# Move all the inputs to the op device in one RPC.
- grouped_args = {
- k: _move_tensors(v, resource_handles[0].device)
- for k, v in grouped_args.items()
- }
+ grouped_args = collections.OrderedDict(
+ (k, _move_tensors(v, resource_handles[0].device))
+ for k, v in sorted(grouped_args.items()))
with ops.device(resource_handles[0].device):
return batch[0].op(resource_handles, stamp, **grouped_args)
def run_handler_scheduled_ops(per_handler_ops, stamp, worker_device):
"""Given a dictionary of ops for each handler, runs them in batch."""
- batched_ops = collections.defaultdict(list)
+ batched_ops = collections.OrderedDict()
# Group the ops by their batching_key. Ops that share the same batching key
# can be executed together.
- for handler in sorted(per_handler_ops.keys()):
+ for handler in per_handler_ops.keys():
for op in per_handler_ops[handler]:
- batched_ops[(op.batching_key(), op.batch_runner_fn())].append(op)
+ key = (op.batching_key(), op.batch_runner_fn())
+ batched_ops.setdefault(key, []).append(op)
op_results = {}
for batch in batched_ops.values():
# Run each of the batched ops using its runner.
diff --git a/tensorflow/contrib/boosted_trees/python/training/__init__.py b/tensorflow/contrib/boosted_trees/python/training/__init__.py
new file mode 100644
index 0000000000..b569ac5fdb
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/python/training/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""training module under boosted_trees."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/__init__.py b/tensorflow/contrib/boosted_trees/python/training/functions/__init__.py
new file mode 100644
index 0000000000..c1750117cd
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""functions module under boosted_trees."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index b95956dae2..f0b66dcbbe 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import copy
from tensorflow.contrib import learn
@@ -163,7 +164,7 @@ def extract_features(features, feature_columns):
scope = "gbdt"
with variable_scope.variable_scope(scope):
feature_columns = list(feature_columns)
- transformed_features = {}
+ transformed_features = collections.OrderedDict()
for fc in feature_columns:
# pylint: disable=protected-access
if isinstance(fc, feature_column_lib._EmbeddingColumn):
@@ -681,13 +682,13 @@ class GradientBoostedDecisionTreeModel(object):
control_flow_ops.no_op))
# Update handler stats.
- handler_reads = {}
+ handler_reads = collections.OrderedDict()
for handler in handlers:
handler_reads[handler] = handler.scheduled_reads()
handler_results = batch_ops_utils.run_handler_scheduled_ops(
handler_reads, ensemble_stamp, worker_device)
- per_handler_updates = {}
+ per_handler_updates = collections.OrderedDict()
# Two values per handler. First one is if the handler is active for the
# current layer. The second one is if the handler is going to be active
# for the next layer.
diff --git a/tensorflow/contrib/boosted_trees/python/utils/__init__.py b/tensorflow/contrib/boosted_trees/python/utils/__init__.py
new file mode 100644
index 0000000000..6ceb150c26
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/python/utils/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""utils module under boosted_trees."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index bf029ca5f9..ea8dbf2b46 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -291,6 +291,9 @@ class Mean(Metric):
Args:
values: Tensor with the per-example value.
weights: Optional weighting of each example. Defaults to 1.
+
+ Returns:
+ The arguments, for easy chaining.
"""
if weights is None:
self.denom.assign_add(
@@ -302,6 +305,9 @@ class Mean(Metric):
self.denom.assign_add(math_ops.reduce_sum(weights))
values = math_ops.cast(values, self.dtype) * weights
self.numer.assign_add(math_ops.reduce_sum(values))
+ if weights is None:
+ return values
+ return values, weights
def result(self):
t = self.numer / self.denom
@@ -329,7 +335,13 @@ class Accuracy(Mean):
per element of the Tensor.
predictions: Tensor with the predicted label for each example.
weights: Optional weighting of each example. Defaults to 1.
+
+ Returns:
+ The arguments, for easy chaining.
"""
matches = math_ops.equal(labels, predictions)
matches = math_ops.cast(matches, dtypes.float64)
super(Accuracy, self).call(matches, weights=weights)
+ if weights is None:
+ return labels, predictions
+ return labels, predictions, weights
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index 9cf34fd9b2..a9ecaa3f8b 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -180,6 +180,19 @@ class MetricsTest(test.TestCase):
m2 = metrics.Mean()
m2(2)
+ def testMetricsChain(self):
+ with context.graph_mode(), self.test_session():
+ m1 = metrics.Mean()
+ m2 = metrics.Mean(name="m2")
+ update_m2 = m2(3.0)
+ update_m2_2 = m2(m1(1.0))
+ m1.init_variables().run()
+ m2.init_variables().run()
+ update_m2.eval()
+ update_m2_2.eval()
+ self.assertAllEqual(m2.result().eval(), 2.0)
+ self.assertAllEqual(m1.result().eval(), 1.0)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py
index 6d8f786223..ef3722ee41 100644
--- a/tensorflow/contrib/seq2seq/python/ops/helper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/helper.py
@@ -73,14 +73,6 @@ class Helper(object):
raise NotImplementedError("batch_size has not been implemented")
@abc.abstractproperty
- def input_shape(self):
- """Shape of each input element in batch.
-
- Returns a `TensorShape`.
- """
- raise NotImplementedError("input_shape has not been implemented")
-
- @abc.abstractproperty
def sample_ids_shape(self):
"""Shape of tensor returned by `sample`, excluding the batch dimension.
@@ -135,7 +127,6 @@ class CustomHelper(Helper):
self._sample_fn = sample_fn
self._next_inputs_fn = next_inputs_fn
self._batch_size = None
- self._input_shape = None
self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or [])
self._sample_ids_dtype = sample_ids_dtype or dtypes.int32
@@ -158,8 +149,6 @@ class CustomHelper(Helper):
(finished, next_inputs) = self._initialize_fn()
if self._batch_size is None:
self._batch_size = array_ops.size(finished)
- if self._input_shape is None:
- self._input_shape = next_inputs.shape[1:]
return (finished, next_inputs)
def sample(self, time, outputs, state, name=None):
@@ -195,7 +184,6 @@ class TrainingHelper(Helper):
"""
with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
inputs = ops.convert_to_tensor(inputs, name="inputs")
- self._inputs = inputs
if not time_major:
inputs = nest.map_structure(_transpose_batch_time, inputs)
@@ -211,17 +199,12 @@ class TrainingHelper(Helper):
lambda inp: array_ops.zeros_like(inp[0, :]), inputs)
self._batch_size = array_ops.size(sequence_length)
- self._input_shape = inputs.shape[2:]
@property
def batch_size(self):
return self._batch_size
@property
- def input_shape(self):
- return self._input_shape
-
- @property
def sample_ids_shape(self):
return tensor_shape.TensorShape([])
@@ -229,14 +212,6 @@ class TrainingHelper(Helper):
def sample_ids_dtype(self):
return dtypes.int32
- @property
- def inputs(self):
- return self._inputs
-
- @property
- def sequence_length(self):
- return self._sequence_length
-
def initialize(self, name=None):
with ops.name_scope(name, "TrainingHelperInitialize"):
finished = math_ops.equal(0, self._sequence_length)
@@ -541,17 +516,12 @@ class GreedyEmbeddingHelper(Helper):
if self._end_token.get_shape().ndims != 0:
raise ValueError("end_token must be a scalar")
self._start_inputs = self._embedding_fn(self._start_tokens)
- self._input_shape = self._start_inputs.shape[1:]
@property
def batch_size(self):
return self._batch_size
@property
- def input_shape(self):
- return self._input_shape
-
- @property
def sample_ids_shape(self):
return tensor_shape.TensorShape([])
@@ -662,8 +632,6 @@ class InferenceHelper(Helper):
self._sample_dtype = sample_dtype
self._next_inputs_fn = next_inputs_fn
self._batch_size = array_ops.shape(start_inputs)[0]
- self._input_shape = start_inputs.shape[1:]
-
self._start_inputs = ops.convert_to_tensor(
start_inputs, name="start_inputs")
@@ -672,10 +640,6 @@ class InferenceHelper(Helper):
return self._batch_size
@property
- def input_shape(self):
- return self._input_shape
-
- @property
def sample_ids_shape(self):
return self._sample_shape
diff --git a/tensorflow/core/framework/bfloat16.cc b/tensorflow/core/framework/bfloat16.cc
index 0efe43fde2..6025be5170 100644
--- a/tensorflow/core/framework/bfloat16.cc
+++ b/tensorflow/core/framework/bfloat16.cc
@@ -21,13 +21,13 @@ void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
uint16_t* q = reinterpret_cast<uint16_t*>(dst);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- for (; size != 0; p += 2, q++, size--) {
- *q = p[0];
- }
+ for (; size != 0; p += 2, q++, size--) {
+ *q = p[0];
+ }
#else
- for (; size != 0; p += 2, q++, size--) {
- *q = p[1];
- }
+ for (; size != 0; p += 2, q++, size--) {
+ *q = p[1];
+ }
#endif
}
@@ -35,15 +35,15 @@ void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {
const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
uint16_t* q = reinterpret_cast<uint16_t*>(dst);
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
- for (; size != 0; p++, q += 2, size--) {
- q[0] = *p;
- q[1] = 0;
- }
+ for (; size != 0; p++, q += 2, size--) {
+ q[0] = *p;
+ q[1] = 0;
+ }
#else
- for (; size != 0; p++, q += 2, size--) {
- q[0] = 0;
- q[1] = *p;
- }
+ for (; size != 0; p++, q += 2, size--) {
+ q[0] = 0;
+ q[1] = *p;
+ }
#endif
}
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 7ab8e3ec18..8bb87483e1 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1356,10 +1356,11 @@ Status ScatterNdUpdateShape(InferenceContext* c) {
Status s = c->Merge(prefix_indices, prefix_updates, &unused);
if (!s.ok()) {
return errors::InvalidArgument(
- "The outer ", num_outer_dims, " dimensions of indices.shape=",
- c->DebugString(indices_shape), " must match the outer ",
- num_outer_dims, " dimensions of updates.shape=",
- c->DebugString(updates_shape), ": ", s.error_message());
+ "The outer ", num_outer_dims,
+ " dimensions of indices.shape=", c->DebugString(indices_shape),
+ " must match the outer ", num_outer_dims,
+ " dimensions of updates.shape=", c->DebugString(updates_shape),
+ ": ", s.error_message());
}
ShapeHandle input_suffix;
diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc
index ad301a8aa4..70d1e20a17 100644
--- a/tensorflow/core/framework/fake_input.cc
+++ b/tensorflow/core/framework/fake_input.cc
@@ -104,8 +104,8 @@ Status FakeInputImpl::AddInputToBuilder() {
Status status = GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts);
if (!status.ok()) {
return errors::InvalidArgument(
- "Could not infer list of types for input '", arg_->name(), "': ",
- status.error_message());
+ "Could not infer list of types for input '", arg_->name(),
+ "': ", status.error_message());
}
SourceList(dts);
return Status::OK();
@@ -131,8 +131,8 @@ Status FakeInputImpl::GetN(int* n) const {
Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n);
if (!status.ok()) {
return errors::InvalidArgument("Could not infer length of input '",
- arg_->name(), "': ",
- status.error_message());
+ arg_->name(),
+ "': ", status.error_message());
}
}
return Status::OK();
@@ -153,8 +153,8 @@ Status FakeInputImpl::GetDataType(DataType* dt) const {
*dt = attr->default_value().type();
} else {
return errors::InvalidArgument("Could not infer type for input '",
- arg_->name(), "': ",
- status.error_message());
+ arg_->name(),
+ "': ", status.error_message());
}
}
} else {
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 0224f25227..d6b576166c 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1264,8 +1264,8 @@ FunctionDef FunctionDefHelper::Define(const string& name,
}
for (const string& a : src.arg) {
const auto iter = ret_index.find(a);
- CHECK(iter != ret_index.end()) << "Node input '" << a << "' in '"
- << src.ret[0] << "' of " << name;
+ CHECK(iter != ret_index.end())
+ << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
n->add_input(iter->second);
}
for (const string& d : src.dep) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 3bb5638cdf..b933ee0b0e 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -656,7 +656,7 @@ bool RegisterOp(const string& op, Creator func);
// Returns OK the gradient creator for the "op" is found (may be
// nullptr if REGISTER_OP_NO_GRADIENT is used.
Status GetOpGradientCreator(const string& op, Creator* creator);
-};
+}; // namespace gradient
// Declare explicit instantiations of GetAttr
#define GET_ATTR(T) \
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc
index bd018b7243..1f670535d5 100644
--- a/tensorflow/core/framework/graph_def_util.cc
+++ b/tensorflow/core/framework/graph_def_util.cc
@@ -35,8 +35,8 @@ namespace tensorflow {
string SummarizeGraphDef(const GraphDef& graph_def) {
string ret;
- strings::StrAppend(&ret, "versions = ",
- ProtoShortDebugString(graph_def.versions()), ";\n");
+ strings::StrAppend(
+ &ret, "versions = ", ProtoShortDebugString(graph_def.versions()), ";\n");
for (const NodeDef& node : graph_def.node()) {
strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
}
@@ -90,9 +90,9 @@ static Status RemoveNewDefaultAttrsFromNodeDef(
FindAttr(attr.first, *producer_op_def);
if (producer_attr_def == nullptr) {
return errors::InvalidArgument(
- "Attr '", attr.first, "' missing in producer's OpDef: ",
- SummarizeOpDef(*producer_op_def), " but found in node: ",
- SummarizeNodeDef(*node_def));
+ "Attr '", attr.first,
+ "' missing in producer's OpDef: ", SummarizeOpDef(*producer_op_def),
+ " but found in node: ", SummarizeNodeDef(*node_def));
}
// ...and it has the same value as the default in producer,
if (producer_attr_def->has_default_value() &&
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index a4e8add6c4..2d035ab90d 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -170,20 +170,20 @@ const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
return nullptr;
}
-#define VALIDATE(EXPR, ...) \
- do { \
- if (!(EXPR)) { \
- return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ", \
- ProtoShortDebugString(op_def)); \
- } \
+#define VALIDATE(EXPR, ...) \
+ do { \
+ if (!(EXPR)) { \
+ return errors::InvalidArgument( \
+ __VA_ARGS__, "; in OpDef: ", ProtoShortDebugString(op_def)); \
+ } \
} while (false)
static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
bool output, std::set<string>* names) {
const string suffix = strings::StrCat(
output ? " for output '" : " for input '", arg.name(), "'");
- VALIDATE(gtl::InsertIfNotPresent(names, arg.name()), "Duplicate name: ",
- arg.name());
+ VALIDATE(gtl::InsertIfNotPresent(names, arg.name()),
+ "Duplicate name: ", arg.name());
VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);
if (!arg.number_attr().empty()) {
@@ -250,8 +250,8 @@ Status ValidateOpDef(const OpDef& op_def) {
std::set<string> names; // for detecting duplicate names
for (const auto& attr : op_def.attr()) {
// Validate name
- VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()), "Duplicate name: ",
- attr.name());
+ VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()),
+ "Duplicate name: ", attr.name());
DataType dt;
VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
attr.name(), " that matches a data type");
@@ -680,8 +680,8 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
if (!penultimate_attr.has_default_value() ||
!new_attr->has_default_value()) {
return errors::InvalidArgument("Missing default for attr '",
- penultimate_attr.name(), "' in op: ",
- SummarizeOpDef(new_op));
+ penultimate_attr.name(),
+ "' in op: ", SummarizeOpDef(new_op));
}
// Actually test that the attr's default value hasn't changed.
diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc
index 28809c11c5..2b9812d4fc 100644
--- a/tensorflow/core/framework/op_def_util_test.cc
+++ b/tensorflow/core/framework/op_def_util_test.cc
@@ -200,10 +200,11 @@ TEST_F(ValidateOpDefTest, BadAttrDefault) {
"default_value { list { s: ['foo'] } } }"),
"Length for attr 'a' of 1 must be at least minimum 2\n\t in Op "
"'BadAttrDef'");
- ExpectFailure(TestBuilder(OpDefBuilder("GoodAttrDef")
- .Attr("a: list(type) >=2 = [DT_STRING]")),
- "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op "
- "'GoodAttrDef'");
+ ExpectFailure(
+ TestBuilder(
+ OpDefBuilder("GoodAttrDef").Attr("a: list(type) >=2 = [DT_STRING]")),
+ "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op "
+ "'GoodAttrDef'");
}
TEST_F(ValidateOpDefTest, NoRefTypes) {
@@ -213,9 +214,10 @@ TEST_F(ValidateOpDefTest, NoRefTypes) {
ExpectFailure(
TestBuilder(OpDefBuilder("BadAttrDef").Attr("T: type = DT_INT32_REF")),
"AttrValue must not have reference type value of int32_ref");
- ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef")
- .Attr("T: list(type) = [DT_STRING_REF]")),
- "AttrValue must not have reference type value of string_ref");
+ ExpectFailure(
+ TestBuilder(
+ OpDefBuilder("BadAttrDef").Attr("T: list(type) = [DT_STRING_REF]")),
+ "AttrValue must not have reference type value of string_ref");
}
TEST_F(ValidateOpDefTest, BadAttrMin) {
@@ -245,9 +247,10 @@ TEST_F(ValidateOpDefTest, BadAttrAllowed) {
TF_EXPECT_OK(TestBuilder(
OpDefBuilder("GoodAttrtude").Attr("x: numbertype = DT_INT32")));
// Not in list of allowed types.
- ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude")
- .Attr("x: numbertype = DT_STRING")),
- "attr 'x' of string is not in the list of allowed values");
+ ExpectFailure(
+ TestBuilder(
+ OpDefBuilder("BadAttrtude").Attr("x: numbertype = DT_STRING")),
+ "attr 'x' of string is not in the list of allowed values");
ExpectFailure(
TestBuilder(OpDefBuilder("BadAttrtude")
.Attr("x: list(realnumbertype) = [DT_COMPLEX64]")),
@@ -260,9 +263,10 @@ TEST_F(ValidateOpDefTest, BadAttrAllowed) {
TF_EXPECT_OK(TestBuilder(
OpDefBuilder("GoodAttrtude").Attr("x: {'foo', 'bar'} = 'bar'")));
// Not in list of allowed strings.
- ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude")
- .Attr("x: {'foo', 'bar'} = 'baz'")),
- "attr 'x' of \"baz\" is not in the list of allowed values");
+ ExpectFailure(
+ TestBuilder(
+ OpDefBuilder("BadAttrtude").Attr("x: {'foo', 'bar'} = 'baz'")),
+ "attr 'x' of \"baz\" is not in the list of allowed values");
ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude")
.Attr("x: list({'foo', 'bar'}) = ['baz']")),
"attr 'x' of \"baz\" is not in the list of allowed values");
diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc
index 870bbb141b..5f2eb9d99a 100644
--- a/tensorflow/core/framework/op_gen_lib.cc
+++ b/tensorflow/core/framework/op_gen_lib.cc
@@ -296,7 +296,6 @@ static void RenameInDocs(const string& from, const string& to,
}
}
-
namespace {
// Initializes given ApiDef with data in OpDef.
diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h
index 94fe194a1a..ff38e4b221 100644
--- a/tensorflow/core/framework/op_gen_lib.h
+++ b/tensorflow/core/framework/op_gen_lib.h
@@ -47,7 +47,6 @@ string PBTxtToMultiline(StringPiece pbtxt,
const std::vector<string>& multi_line_fields);
string PBTxtFromMultiline(StringPiece multiline_pbtxt);
-
// Takes a list of files with ApiDefs text protos, and allows you to
// look up the specific ApiDef for any given op.
class ApiDefMap {
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 16bf5c256f..fd2d06be98 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -101,7 +101,8 @@ OpKernel::OpKernel(OpKernelConstruction* context)
// Kernels executing on GPU/SYCL tie very few resources on the CPU where the
// scheduler runs: we consider them as inexpensive.
- expensive_ = context->device_type() != DeviceType(DEVICE_GPU) && context->device_type() != DeviceType(DEVICE_SYCL);
+ expensive_ = context->device_type() != DeviceType(DEVICE_GPU) &&
+ context->device_type() != DeviceType(DEVICE_SYCL);
}
OpKernel::~OpKernel() {}
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
index 94a9d1335a..b53b877f28 100644
--- a/tensorflow/core/framework/op_kernel_test.cc
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -510,10 +510,9 @@ TEST_F(OpKernelBuilderTest, BuilderBoth) {
}
REGISTER_OP("BuildTypeAttr").Attr("T: type");
-REGISTER_KERNEL_BUILDER(Name("BuildTypeAttr")
- .Device(DEVICE_CPU)
- .TypeConstraint<float>("T"),
- DummyKernel);
+REGISTER_KERNEL_BUILDER(
+ Name("BuildTypeAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ DummyKernel);
TEST_F(OpKernelBuilderTest, BuilderTypeAttr) {
ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"});
@@ -525,10 +524,9 @@ TEST_F(OpKernelBuilderTest, BuilderTypeAttr) {
}
REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)");
-REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr")
- .Device(DEVICE_CPU)
- .TypeConstraint<bool>("T"),
- DummyKernel);
+REGISTER_KERNEL_BUILDER(
+ Name("BuildTypeListAttr").Device(DEVICE_CPU).TypeConstraint<bool>("T"),
+ DummyKernel);
TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
@@ -574,14 +572,12 @@ TEST_F(OpKernelBuilderTest, DuplicateKernel) {
}
REGISTER_OP("DuplicateKernelForT").Attr("T: type");
-REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT")
- .Device(DEVICE_CPU)
- .TypeConstraint<float>("T"),
- DummyKernel);
-REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT")
- .Device(DEVICE_CPU)
- .TypeConstraint<float>("T"),
- DummyKernel);
+REGISTER_KERNEL_BUILDER(
+ Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ DummyKernel);
+REGISTER_KERNEL_BUILDER(
+ Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ DummyKernel);
TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
const NodeDef ndef =
diff --git a/tensorflow/core/framework/reader_base.cc b/tensorflow/core/framework/reader_base.cc
index b8c771a0a1..f84ef0f953 100644
--- a/tensorflow/core/framework/reader_base.cc
+++ b/tensorflow/core/framework/reader_base.cc
@@ -178,9 +178,9 @@ void ReaderBase::Read(QueueInterface* queue, string* key, string* value,
" must set *at_end=true, *produced=true, or return an error.");
}
if (!status.ok() && produced) {
- status = errors::Internal("ReadLocked() for ", name(),
- " set *produced=true *and* returned an error: ",
- status.ToString());
+ status = errors::Internal(
+ "ReadLocked() for ", name(),
+ " set *produced=true *and* returned an error: ", status.ToString());
}
if (status.ok() && at_end) {
status = OnWorkFinishedLocked();
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index e062adffe8..17d16c9b8d 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -211,14 +211,12 @@ limitations under the License.
#define TF_CALL_SYCL_double(m)
#else // TENSORFLOW_SYCL_NO_DOUBLE
#define TF_CALL_SYCL_double(m) TF_CALL_double(m)
-#endif // TENSORFLOW_SYCL_NO_DOUBLE
+#endif // TENSORFLOW_SYCL_NO_DOUBLE
#ifdef __ANDROID_TYPES_SLIM__
-#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m)
+#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m)
#else // __ANDROID_TYPES_SLIM__
-#define TF_CALL_SYCL_NUMBER_TYPES(m) \
- TF_CALL_float(m) \
- TF_CALL_SYCL_double(m)
-#endif // __ANDROID_TYPES_SLIM__
+#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) TF_CALL_SYCL_double(m)
+#endif // __ANDROID_TYPES_SLIM__
#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
diff --git a/tensorflow/core/framework/register_types_traits.h b/tensorflow/core/framework/register_types_traits.h
index c1fe5517c6..ab35c2f095 100644
--- a/tensorflow/core/framework/register_types_traits.h
+++ b/tensorflow/core/framework/register_types_traits.h
@@ -23,7 +23,7 @@ typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h"
@@ -79,7 +79,7 @@ template <>
struct proxy_type_pod<SYCLDevice, 4> {
typedef float type;
};
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
/// If POD we use proxy_type_pod, otherwise this maps to identiy.
template <typename Device, typename T>
@@ -99,7 +99,7 @@ struct proxy_type {
#ifdef TENSORFLOW_USE_SYCL
#define TF_CALL_SYCL_PROXY_TYPES(m) \
TF_CALL_double(m) TF_CALL_float(m) TF_CALL_int32(m)
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc
index 32b8ad784d..de148f0bd3 100644
--- a/tensorflow/core/framework/rendezvous_test.cc
+++ b/tensorflow/core/framework/rendezvous_test.cc
@@ -69,9 +69,7 @@ class LocalRendezvousTest : public ::testing::Test {
rendez_ = NewLocalRendezvous();
}
- ~LocalRendezvousTest() override {
- rendez_->Unref();
- }
+ ~LocalRendezvousTest() override { rendez_->Unref(); }
void SchedClosure(std::function<void()> fn) {
threads_.Schedule(std::move(fn));
@@ -99,8 +97,8 @@ string V(const Tensor& tensor) {
Rendezvous::ParsedKey MakeKey(const string& name) {
string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
- "/job:mnist/replica:1/task:2/device:GPU:0", name,
- FrameAndIter(0, 0));
+ "/job:mnist/replica:1/task:2/device:GPU:0",
+ name, FrameAndIter(0, 0));
Rendezvous::ParsedKey k;
TF_EXPECT_OK(Rendezvous::ParseKey(s, &k));
return k;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index d552ec1693..e3cc848a16 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -32,7 +32,7 @@ class ShapeRefinerTest;
namespace grappler {
class GraphProperties;
class SymbolicShapeManager;
-}
+} // namespace grappler
namespace shape_inference {
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index a9b63ca60e..f48a7b9c47 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -760,7 +760,10 @@ TEST_F(ShapeInferenceTest, MergePrefix) {
NodeDef def;
InferenceContext c(kVersion, &def, MakeOpDef(4, 2),
{
- Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4}),
+ Unknown(),
+ S({-1, 2}),
+ S({1, -1, 3}),
+ S({2, 4}),
},
{}, {}, {});
diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc
index d8a9c0bac5..d7517bb311 100644
--- a/tensorflow/core/framework/tensor_shape_test.cc
+++ b/tensorflow/core/framework/tensor_shape_test.cc
@@ -582,7 +582,8 @@ TEST(TensorShapeTest, Large) {
TEST(TensorShapeTest, Overflow) {
int64 one = 1;
std::vector<std::vector<int64>> overflows = {
- {1 << 30, 1 << 30, 1 << 30}, {1 << 5, (one << 60) + 1},
+ {1 << 30, 1 << 30, 1 << 30},
+ {1 << 5, (one << 60) + 1},
};
for (const auto& overflow : overflows) {
TensorShapeProto proto;
diff --git a/tensorflow/core/framework/tensor_testutil.cc b/tensorflow/core/framework/tensor_testutil.cc
index a8d1412300..8f480d65f2 100644
--- a/tensorflow/core/framework/tensor_testutil.cc
+++ b/tensorflow/core/framework/tensor_testutil.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <cmath>
#include "tensorflow/core/framework/tensor_testutil.h"
+#include <cmath>
namespace tensorflow {
namespace test {
diff --git a/tensorflow/core/framework/tensor_types.h b/tensorflow/core/framework/tensor_types.h
index 921f88dc0b..a5c1a56bfc 100644
--- a/tensorflow/core/framework/tensor_types.h
+++ b/tensorflow/core/framework/tensor_types.h
@@ -25,7 +25,8 @@ template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
struct TTypes {
// Rank-<NDIMS> tensor of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
- Eigen::Aligned> Tensor;
+ Eigen::Aligned>
+ Tensor;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstTensor;
@@ -33,35 +34,42 @@ struct TTypes {
// Unaligned Rank-<NDIMS> tensor of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType> >
UnalignedTensor;
- typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor,
- IndexType> > UnalignedConstTensor;
+ typedef Eigen::TensorMap<
+ Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType> >
+ UnalignedConstTensor;
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>,
- Eigen::Aligned> Tensor32Bit;
+ Eigen::Aligned>
+ Tensor32Bit;
// Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
typedef Eigen::TensorMap<
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
- Eigen::Aligned> Scalar;
+ Eigen::Aligned>
+ Scalar;
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
Eigen::RowMajor, IndexType>,
- Eigen::Aligned> ConstScalar;
+ Eigen::Aligned>
+ ConstScalar;
// Unaligned Scalar tensor of scalar type T.
- typedef Eigen::TensorMap<Eigen::TensorFixedSize<
- T, Eigen::Sizes<>, Eigen::RowMajor, IndexType> > UnalignedScalar;
+ typedef Eigen::TensorMap<
+ Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType> >
+ UnalignedScalar;
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
Eigen::RowMajor, IndexType> >
UnalignedConstScalar;
// Rank-1 tensor (vector) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
- Eigen::Aligned> Flat;
+ Eigen::Aligned>
+ Flat;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstFlat;
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
- Eigen::Aligned> Vec;
+ Eigen::Aligned>
+ Vec;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstVec;
@@ -69,16 +77,19 @@ struct TTypes {
// Unaligned Rank-1 tensor (vector) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> >
UnalignedFlat;
- typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor,
- IndexType> > UnalignedConstFlat;
+ typedef Eigen::TensorMap<
+ Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> >
+ UnalignedConstFlat;
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType> >
UnalignedVec;
typedef Eigen::TensorMap<
- Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> > UnalignedConstVec;
+ Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType> >
+ UnalignedConstVec;
// Rank-2 tensor (matrix) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
- Eigen::Aligned> Matrix;
+ Eigen::Aligned>
+ Matrix;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstMatrix;
@@ -86,8 +97,9 @@ struct TTypes {
// Unaligned Rank-2 tensor (matrix) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType> >
UnalignedMatrix;
- typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor,
- IndexType> > UnalignedConstMatrix;
+ typedef Eigen::TensorMap<
+ Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType> >
+ UnalignedConstMatrix;
};
typedef typename TTypes<float, 1>::Tensor32Bit::Index Index32;
diff --git a/tensorflow/core/framework/types_test.cc b/tensorflow/core/framework/types_test.cc
index 5ddc986563..60f2b4135a 100644
--- a/tensorflow/core/framework/types_test.cc
+++ b/tensorflow/core/framework/types_test.cc
@@ -70,8 +70,8 @@ TEST(TypesTest, kDataTypeRefOffset) {
<< "Extra reference enum "
<< enum_descriptor->FindValueByNumber(e_ref)->name()
<< " without corresponding base enum with value " << e;
- ASSERT_LT(DataType_MAX, e_ref) << "Gap in reference types, missing value for "
- << e_ref;
+ ASSERT_LT(DataType_MAX, e_ref)
+ << "Gap in reference types, missing value for " << e_ref;
// Make sure there are no enums defined after the last regular type before
// the first reference type.
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc
index 1c2c171383..f241922471 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc
@@ -102,7 +102,7 @@ TEST_F(AnalyticalCostEstimatorTest, SimpleTest) {
Costs summary;
TF_ASSERT_OK(estimator.PredictCosts(item.graph, &cost_graph, &summary));
- EXPECT_EQ(Costs::NanoSeconds(9150), summary.execution_time);
+ EXPECT_EQ(Costs::NanoSeconds(9151), summary.execution_time);
// Make this estimate accurate:
// TODO(http://b/70031255): Accurate estimator for RandomUniform op needed
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h
index b7eaf8dc63..9e01ec5ff5 100644
--- a/tensorflow/core/grappler/costs/cost_estimator.h
+++ b/tensorflow/core/grappler/costs/cost_estimator.h
@@ -78,6 +78,9 @@ struct Costs {
MilliSeconds asMilliSeconds() const {
return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
}
+ static NanoSeconds infinity() {
+ return NanoSeconds(std::chrono::nanoseconds::max());
+ }
};
// We store all our times in nanoseconds. If needs be, we can always switch to
@@ -97,6 +100,8 @@ struct Costs {
// requirements of a graph. For example, it might assume that all activations
// are live for all of a graph's execution.
int64 max_memory; // Maximum main memory requirement in bytes over all ops.
+ int64 persistent_memory;
+ int64 temporary_memory;
// These fields are used for TPU-related estimations. They are per-op
// maximums, so each op is evaluated independently, but we want the maximum of
@@ -129,6 +134,8 @@ Costs::Costs() {
compute_time = Duration::zero();
memory_time = Duration::zero();
max_memory = kMemoryUnknown;
+ persistent_memory = kMemoryUnknown;
+ temporary_memory = kMemoryUnknown;
max_per_op_buffers = kMemoryUnknown;
max_per_op_streaming = kMemoryUnknown;
}
@@ -139,6 +146,8 @@ Costs Costs::ZeroCosts() {
costs.compute_time = Duration::zero();
costs.memory_time = Duration::zero();
costs.max_memory = kZeroMemory;
+ costs.persistent_memory = kZeroMemory;
+ costs.temporary_memory = kZeroMemory;
costs.max_per_op_buffers = kZeroMemory;
costs.max_per_op_streaming = kZeroMemory;
return costs;
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 6bc136a3f8..cf317374cf 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -47,6 +47,8 @@ constexpr char kSize[] = "Size";
constexpr char kStopGradient[] = "StopGradient";
constexpr char kPreventGradient[] = "PreventGradient";
+static const Costs::Duration kMinComputeTime(1);
+
namespace {
string GetDataFormat(const OpInfo& op_features) {
@@ -163,18 +165,20 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
- {kPlaceholder, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kRefIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kStopGradient, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kPreventGradient, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kSend, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kConst, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)},
- {kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)},
+
+ {kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kRefIdentity, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kStopGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kPreventGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kReshape, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},
+
+ {kConst, wrap(&OpLevelCostEstimator::PredictVariable)},
+ {kVariable, wrap(&OpLevelCostEstimator::PredictVariable)},
+ {kVariableV2, wrap(&OpLevelCostEstimator::PredictVariable)},
{kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
{kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
@@ -429,6 +433,7 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
costs.execution_time = compute_cost + memory_cost;
}
costs.inaccurate = found_unknown_shapes;
+ costs.max_memory = total_output_size;
return costs;
}
@@ -885,6 +890,30 @@ Costs OpLevelCostEstimator::PredictNoOp(const OpContext& op_context) const {
return Costs::ZeroCosts();
}
+Costs OpLevelCostEstimator::PredictIdentity(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
+ VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
+ Costs result = Costs::ZeroCosts();
+ result.max_memory = CalculateOutputSize(op_features, &result.inaccurate);
+ // Assign the minimum amount of time we can represent to the identity op since
+ // it tends to be really cheap.
+ result.compute_time = kMinComputeTime;
+ result.execution_time = result.compute_time;
+ return result;
+}
+
+Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const {
+ const auto& op_features = op_context.op_info;
+ VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
+ Costs result = Costs::ZeroCosts();
+ result.persistent_memory =
+ CalculateOutputSize(op_features, &result.inaccurate);
+
+ result.compute_time = kMinComputeTime;
+ result.execution_time = result.execution_time;
+ return result;
+}
+
Costs OpLevelCostEstimator::PredictBatchMatMul(
const OpContext& op_context) const {
const auto& op_features = op_context.op_info;
@@ -898,13 +927,12 @@ Costs OpLevelCostEstimator::PredictBatchMatMul(
Costs OpLevelCostEstimator::PredictMetadata(const OpContext& op_context) const {
const auto& op_features = op_context.op_info;
- Costs costs;
+ Costs costs = Costs::ZeroCosts();
costs.max_memory = CalculateOutputSize(op_features, &costs.inaccurate);
// Metadata operations are so cheap we assume they take the minimum amount of
// time we can represent (1 ns).
- costs.execution_time = 1;
- costs.compute_time = 1;
- costs.memory_time = 0;
+ costs.compute_time = kMinComputeTime;
+ costs.execution_time = costs.compute_time;
return costs;
}
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 5f541ccf04..a292e5e97f 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -132,6 +132,8 @@ class OpLevelCostEstimator {
Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
Costs PredictMatMul(const OpContext& op_context) const;
Costs PredictNoOp(const OpContext& op_context) const;
+ Costs PredictIdentity(const OpContext& op_context) const;
+ Costs PredictVariable(const OpContext& op_context) const;
Costs PredictBatchMatMul(const OpContext& op_context) const;
Costs PredictMetadata(const OpContext& op_context) const;
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 791ad34bbe..68de03e81c 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -285,6 +285,7 @@ cc_library(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/costs:graph_memory",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:topological_sort",
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index f537ecc41b..6f95a00fa3 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_view.h"
@@ -828,8 +829,7 @@ static NodeDef* FindSwapOutTrigger(
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& fanout =
view.GetFanout(generator);
NodeDef* trigger = nullptr;
- Costs::NanoSeconds earliest_fanout(
- static_cast<double>(std::numeric_limits<int64>::max() >> 2));
+ Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());
for (const auto& port : fanout) {
if (port.node == node) {
@@ -861,6 +861,15 @@ static bool IsSwappable(GraphView::InputPort input) {
return !IsRefType(dtype);
}
+struct MemInfo {
+ GraphView::OutputPort port;
+ int64 memory_used;
+ std::vector<GraphView::InputPort> uses_left;
+ double fitness;
+
+ bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
+};
+
static bool IdentifySwappingCandidates(
Cluster* cluster, GrapplerItem* item, std::unordered_set<string>* skip_list,
std::unordered_map<NodeDef*, SwapInfo>* nodes_to_swap) {
@@ -890,31 +899,56 @@ static bool IdentifySwappingCandidates(
continue;
}
int64 required_savings = mem_usage.used_memory - prop.memory_size();
- // TODO(bsteiner): sort the tensors by how long they're live.
- std::unordered_map<string, Costs::NanoSeconds> execution_times;
+ std::unordered_map<string, Costs::NanoSeconds> op_completion_times;
{
- std::unordered_map<const NodeDef*, Costs::NanoSeconds>
- tmp_execution_times;
- if (!EstimateEarliestExecutionTimes(*item, cluster, &tmp_execution_times)
- .ok()) {
+ VirtualCluster vcluster(cluster->GetDevices());
+ if (!vcluster.Provision().ok()) {
return false;
}
- for (const auto& exec_time : tmp_execution_times) {
- execution_times.emplace(exec_time.first->name(), exec_time.second);
+ if (!vcluster.Initialize(*item).ok()) {
+ return false;
+ }
+ RunMetadata metadata;
+ Status s = vcluster.Run(item->graph, item->feed, item->fetch, &metadata);
+ if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
+ return false;
+ }
+
+ for (const auto& dev_stats : metadata.step_stats().dev_stats()) {
+ for (const auto& node_stats : dev_stats.node_stats()) {
+ Costs::NanoSeconds exec_time =
+ Costs::NanoSeconds(1) +
+ Costs::MicroSeconds(node_stats.all_start_micros() +
+ node_stats.op_end_rel_micros());
+ op_completion_times.emplace(node_stats.node_name(), exec_time);
+ }
}
}
+ Costs::Duration peak_time = -1;
+ for (const auto& live_tensor : mem_usage.live_tensors) {
+ if (live_tensor.allocation_time > peak_time) {
+ peak_time = live_tensor.allocation_time;
+ }
+ }
+
+ std::vector<MemInfo> mem_state;
+
GraphView graph(&item->graph);
for (const auto& live_tensor : mem_usage.live_tensors) {
+ if (live_tensor.memory_used <= 1024) {
+ // Don't bother with small tensors.
+ continue;
+ }
if (live_tensor.deallocation_time - live_tensor.allocation_time <=
Costs::Duration(1e6)) {
// Not enough time to swap.
VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
continue;
}
- if (live_tensor.memory_used <= 1024) {
- // Don't bother with small tensors.
+
+ if (skip_list->find(live_tensor.node) != skip_list->end()) {
continue;
}
GraphView::OutputPort port =
@@ -922,56 +956,77 @@ static bool IdentifySwappingCandidates(
if (!IsSwappable(graph, port)) {
continue;
}
- Costs::NanoSeconds execution_time(-1);
- GraphView::InputPort fanout_to_swap;
+ MemInfo mem_info;
+ mem_info.port = port;
+ mem_info.memory_used = live_tensor.memory_used;
+ Costs::Duration allocation_time = live_tensor.allocation_time;
+ Costs::Duration earliest_use(Costs::Duration::infinity());
+ bool valid = true;
for (GraphView::InputPort input : graph.GetFanout(port)) {
- if (skip_list->find(input.node->name()) != skip_list->end()) {
+ // Get execution time.
+ auto it = op_completion_times.find(input.node->name());
+ if (it == op_completion_times.end()) {
+ valid = false;
+ break;
+ }
+ if (it->second <= peak_time) {
continue;
}
+
+ if (skip_list->find(input.node->name()) != skip_list->end()) {
+ valid = false;
+ break;
+ }
string input_name =
strings::StrCat(input.node->name(), ":", input.port_id);
if (skip_list->find(input_name) != skip_list->end()) {
- continue;
+ valid = false;
+ break;
}
if (!IsSwappable(input)) {
- continue;
- }
- auto it = execution_times.find(input.node->name());
- if (it != execution_times.end()) {
- if (it->second > execution_time) {
- fanout_to_swap = input;
- execution_time = it->second;
- }
+ valid = false;
+ break;
}
+
+ // Set earliest use time that's after peak.
+ mem_info.uses_left.emplace_back(input);
+ earliest_use = std::min(earliest_use, it->second);
}
- // Annotate the fanout to request the tensor to be swapped if it's not
- // already been done.
- bool found = false;
- if (!fanout_to_swap.node) {
- continue;
- }
- auto it = fanout_to_swap.node->attr().find("_swap_to_host");
- if (it != fanout_to_swap.node->attr().end()) {
- const AttrValue& val = it->second;
- for (int port_id : val.list().i()) {
- if (port_id == fanout_to_swap.port_id) {
- found = true;
- break;
- }
- }
+ if (valid && !mem_info.uses_left.empty()) {
+ // Compute the fitness: we need the tensor to be generated way away of
+ // the time of peak memory usage (to ensure there is enough time to swap
+ // it out). We also need to ensure it's used way after the peak time, to
+ // ensure that swapping the tensor back in won't recreate the memory
+ // bottleneck. Last but not least, we want the tensor to have as few
+ // remaining uses as possible.
+ mem_info.fitness = std::pow((earliest_use - peak_time).count(), 2);
+ mem_info.fitness /= std::pow(mem_info.uses_left.size(), 2);
+ mem_info.fitness += std::pow((allocation_time - peak_time).count(), 2);
+ mem_info.fitness = -mem_info.fitness;
+ mem_state.push_back(mem_info);
}
- if (!found) {
+ }
+
+ // Sort by fitness
+ std::sort(mem_state.begin(), mem_state.end());
+
+ for (const MemInfo& mem_info : mem_state) {
+ for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) {
+ VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
+ << fanout_to_swap.port_id << " of tensor "
+ << mem_info.port.node->name() << ":" << mem_info.port.port_id
+ << " of size " << mem_info.memory_used;
+
(*nodes_to_swap)[fanout_to_swap.node].inputs_to_swap.push_back(
fanout_to_swap.port_id);
- required_savings -= live_tensor.memory_used;
- updated_graph = true;
- if (required_savings < 0) {
- break;
- }
+ }
+ required_savings -= mem_info.memory_used;
+ updated_graph = true;
+ if (required_savings < 0) {
+ break;
}
}
}
-
return updated_graph;
}
@@ -1011,7 +1066,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
}
for (auto& swap : nodes_to_swap) {
const NodeDef* node = swap.first;
- std::vector<OpInfo::TensorProperties> props =
+ const std::vector<OpInfo::TensorProperties>& props =
properties.GetInputProperties(node->name());
SwapInfo& swap_info = swap.second;
int64 bytes_to_swap = 0;
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
index dd2d20d8d6..f5d9c87992 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -337,8 +337,9 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) {
for (const auto& node : output.node()) {
if (node.name() == "e") {
// The d node isn't swappable.
- EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ(5, node.input_size());
EXPECT_EQ("d", node.input(2));
+ EXPECT_EQ("^swap_out_d_2", node.input(4));
}
}
}
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.cc b/tensorflow/core/kernels/mkl_tfconv_op.cc
deleted file mode 100644
index c35f857cfe..0000000000
--- a/tensorflow/core/kernels/mkl_tfconv_op.cc
+++ /dev/null
@@ -1,124 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifdef INTEL_MKL
-
-#include <algorithm>
-#include <vector>
-#include "tensorflow/core/framework/numeric_op.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/kernels/ops_util.h"
-#include "tensorflow/core/platform/cpu_info.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/util/tensor_format.h"
-
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
-
-namespace tensorflow {
-typedef Eigen::ThreadPoolDevice CPUDevice;
-
-///////////////////////////////////////////////////////////
-// Op kernel
-///////////////////////////////////////////////////////////
-
-template <typename Device, typename T>
-class MklToTfOp : public OpKernel {
- public:
- explicit MklToTfOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
- OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
- has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
- }
-
- void Compute(OpKernelContext* context) override {
- // Check that input tensor is in MKL format.
- const Tensor& input_tensor = MklGetInput(context, 0);
- MklShape input_shape;
- GetMklShape(context, 0, &input_shape);
-
- // if input is already in Tf format, then just copy input tensor to output.
- if (!input_shape.IsMklTensor()) {
- context->set_output(0, input_tensor);
- VLOG(1) << "MKLToTFConversion: No conversion needed, "
- << "copying input to output";
- return;
- }
-
- // Check that input data type is same as operator data type and that it is
- // same as output data type.
- DataType input_data_type = input_type(0);
- DataType output_data_type = output_type(0);
- CHECK_EQ(op_data_type, input_data_type);
- CHECK_EQ(op_data_type, output_data_type);
-
- TensorShape output_shape;
- size_t ndims = input_shape.GetDimension();
- size_t* in_sizes = new size_t[ndims];
- for (size_t i = 0; i < ndims; i++) {
- // Outermost to innermost dimension
- output_shape.AddDim(input_shape.GetSizes()[input_shape.tf_dim_idx(i)]);
- in_sizes[i] = input_shape.GetSizes()[i];
- }
-
- // Allocate output tensor.
- Tensor* output_tensor = NULL;
- OP_REQUIRES_OK(context,
- context->allocate_output(0, output_shape, &output_tensor));
-
- dnnLayout_t output_layout =
- static_cast<dnnLayout_t>(input_shape.GetTfLayout());
- // Execute DNNConversion.
- void* input_buffer =
- static_cast<void*>(const_cast<T*>(input_tensor.flat<T>().data()));
- delete[] in_sizes;
- void* output_buffer =
- static_cast<void*>(const_cast<T*>(output_tensor->flat<T>().data()));
- input_shape.GetConvertedFlatData(output_layout, input_buffer,
- output_buffer);
- VLOG(1) << "MKLToTFConversion complete successfully.";
- }
-
- private:
- /// Data format of the operation
- string data_format_str;
-
- /// Data type of the operation
- DataType op_data_type;
-
- /// CPUIDInfo
- bool has_avx512f_ = false;
-};
-
-///////////////////////////////////////////////////////////
-// Register kernel
-///////////////////////////////////////////////////////////
-
-#define REGISTER_CPU(T) \
- REGISTER_KERNEL_BUILDER(Name("_MklToTf") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklToTfOp<CPUDevice, T>);
-
-TF_CALL_float(REGISTER_CPU);
-#undef REGISTER_CPU
-} // namespace tensorflow
-#endif /* INTEL_MKL */
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index a2a3e230bb..d79d1fc0a6 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -734,7 +734,7 @@ def _num_elements(grad):
raise ValueError("`grad` not a Tensor or IndexedSlices.")
-_last_shape_dtype = [None, None]
+_last_zero_shape_dtype = [None, None]
_last_zero = [None]
@@ -748,13 +748,15 @@ def _zeros(shape, dtype):
# TODO(apassos): need to save enough information about variant tensors to do
# a zeros
return None
- if [shape, dtype] != _last_shape_dtype:
- _last_shape_dtype[:] = [shape, dtype]
+ if [shape, dtype] != _last_zero_shape_dtype:
+ _last_zero_shape_dtype[:] = [shape, dtype]
_last_zero[0] = _fast_fill(0, shape, dtype)
return _last_zero[0]
def _ones(shape, dtype):
+ if shape == (): # pylint: disable=g-explicit-bool-comparison
+ return constant_op.constant(1, dtype=dtype)
return _fast_fill(1, shape, dtype)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index bca4c665d2..3cb71eba8c 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -40,15 +40,16 @@ def _SumGrad(op, grad):
"""Gradient for Sum."""
# Fast path for when reducing to a scalar and ndims is known: adds only
# Reshape and Tile ops (and possibly a Shape).
- if op.inputs[0].get_shape().ndims is not None:
+ input_0_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access
+ if input_0_shape is not None:
axes = tensor_util.constant_value(op.inputs[1])
if axes is not None:
- rank = op.inputs[0].get_shape().ndims
+ rank = len(input_0_shape)
if np.array_equal(axes, np.arange(rank)): # Reduce all dims.
grad = array_ops.reshape(grad, [1] * rank)
# If shape is not fully defined (but rank is), we use Shape.
- if op.inputs[0].get_shape().is_fully_defined():
- input_shape = op.inputs[0].get_shape().as_list()
+ if None not in input_0_shape:
+ input_shape = input_0_shape
else:
input_shape = array_ops.shape(op.inputs[0])
return [array_ops.tile(grad, input_shape), None]
@@ -96,9 +97,12 @@ def _MinGrad(op, grad):
def _MeanGrad(op, grad):
"""Gradient for Mean."""
sum_grad = _SumGrad(op, grad)[0]
- input_size = op.inputs[0].get_shape().num_elements()
- output_size = op.outputs[0].get_shape().num_elements()
- if input_size is not None and output_size is not None:
+ input_shape = op.inputs[0]._shape_tuple() # pylint: disable=protected-access
+ output_shape = op.outputs[0]._shape_tuple() # pylint: disable=protected-access
+ if (input_shape is not None and output_shape is not None and
+ None not in input_shape and None not in output_shape):
+ input_size = np.prod(input_shape)
+ output_size = np.prod(output_shape)
factor = input_size // max(output_size, 1)
factor = constant_op.constant(factor, dtype=sum_grad.dtype)
else:
@@ -106,7 +110,7 @@ def _MeanGrad(op, grad):
output_shape = array_ops.shape(op.outputs[0])
factor = _safe_shape_div(
math_ops.reduce_prod(input_shape), math_ops.reduce_prod(output_shape))
- return sum_grad / math_ops.cast(factor, sum_grad.dtype), None
+ return math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), None
@ops.RegisterGradient("Prod")
@@ -330,7 +334,7 @@ def _SquareGrad(op, grad):
# Added control dependencies to prevent 2*x from being computed too early.
with ops.control_dependencies([grad]):
x = math_ops.conj(x)
- return grad * (2.0 * x)
+ return math_ops.multiply(grad, math_ops.multiply(x, 2.0))
@ops.RegisterGradient("Sqrt")
@@ -756,8 +760,12 @@ def _AddGrad(op, grad):
@ops.RegisterGradient("Sub")
def _SubGrad(op, grad):
+ """Gradient for Sub."""
x = op.inputs[0]
y = op.inputs[1]
+ if (isinstance(grad, ops.Tensor) and
+ _ShapesFullySpecifiedAndEqual(x, y, grad)):
+ return grad, -grad
sx = array_ops.shape(x)
sy = array_ops.shape(y)
# pylint: disable=protected-access
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 719b83e5ca..a06b3eada6 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -533,7 +533,15 @@ class Optimizer(object):
else:
with ops.control_dependencies([self._finish(update_ops, "update")]):
with ops.colocate_with(global_step):
- apply_updates = state_ops.assign_add(global_step, 1, name=name)
+ if isinstance(global_step, resource_variable_ops.ResourceVariable):
+ # TODO(apassos): the implicit read in assign_add is slow; consider
+ # making it less so.
+ apply_updates = resource_variable_ops.assign_add_variable_op(
+ global_step.handle,
+ ops.convert_to_tensor(1, dtype=global_step.dtype),
+ name=name)
+ else:
+ apply_updates = state_ops.assign_add(global_step, 1, name=name)
if context.in_graph_mode():
if isinstance(apply_updates, ops.Tensor):
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index 8a66f0435a..2110fc64cf 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -22,6 +22,7 @@ import collections
import functools
import re
+from tensorflow.python.eager import context
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import decorator_utils
from tensorflow.python.util import tf_contextlib
@@ -284,7 +285,9 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples,
@functools.wraps(func)
def new_func(*args, **kwargs):
"""Deprecation wrapper."""
- if _PRINT_DEPRECATION_WARNINGS:
+ # TODO(apassos) figure out a way to have reasonable performance with
+ # deprecation warnings and eager mode.
+ if context.in_graph_mode() and _PRINT_DEPRECATION_WARNINGS:
invalid_args = []
named_args = tf_inspect.getcallargs(func, *args, **kwargs)
for arg_name, spec in iter(deprecated_positions.items()):