aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/inputs
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-01-08 18:15:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-08 18:19:40 -0800
commit20db88eec824259764b2eafba377f93ea11776b0 (patch)
tree7abecb1f4df194bd44437fd7e150fc9b08804f7e /tensorflow/core/grappler/inputs
parent2cd288baa4a1c18c14b5572ef54fa29bc18dfce1 (diff)
Ignore nodes that are going to be swapped when computing max memory usage
PiperOrigin-RevId: 181248577
Diffstat (limited to 'tensorflow/core/grappler/inputs')
-rw-r--r--tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc14
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
index 6d25556770..ec54bd5c75 100644
--- a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
+++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
@@ -31,8 +31,6 @@ namespace {
GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
bool use_multiple_devices, bool insert_queue,
const std::vector<string>& device_names) {
- CHECK_GE(device_names.size(), width);
-
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -49,13 +47,17 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
std::vector<Output> this_stage;
for (int j = 0; j < width; j++) {
if (last_stage.size() == 1) {
- Output unary_op =
- Square(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
- last_stage[0]);
+ Output unary_op = Square(
+ s.WithDevice(
+ device_names[use_multiple_devices ? j % device_names.size()
+ : 0]),
+ last_stage[0]);
this_stage.push_back(unary_op);
} else {
Output combine =
- AddN(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
+ AddN(s.WithDevice(
+ device_names[use_multiple_devices ? j % device_names.size()
+ : 0]),
last_stage);
this_stage.push_back(combine);
}