diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-27 23:35:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-27 23:38:42 -0700 |
commit | 04e17da7ccd40c739d3a24daa2ad4d94bdd77dfe (patch) | |
tree | cf6f8e32e4ce29ae90718d702dea2001006ae392 | |
parent | d047a36a9d6d9cc7c0e15a01c4640a4177374827 (diff) |
Fix kernel creation bug caused by constant folding always using the CPU.
PiperOrigin-RevId: 194636076
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer_test.cc | 15 |
1 file changed, 10 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index fc87f69b8c..dad49cd74f 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -108,10 +108,8 @@ class LayoutOptimizerTest : public GrapplerTest { TensorShape filter_shape( {filter_size, filter_size, input_depth, filter_count}); - Tensor filter_data(DT_FLOAT, filter_shape); - test::FillIota<float>(&filter_data, 1.0f); Output filter = - ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data)); + ops::Variable(s->WithOpName("Filter"), filter_shape, DT_FLOAT); int output_height = input_height; int output_width = input_width; @@ -143,6 +141,10 @@ class LayoutOptimizerTest : public GrapplerTest { return tensor; } + TensorShape GetAttrShape(const NodeDef& node) { + return TensorShape(node.attr().at({"shape"}).shape()); + } + Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) { int batch_size = 16; int input_height = 8; @@ -200,9 +202,12 @@ TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) { test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes); if (gpu_available_) { + TensorShape filter_shape = GetAttrShape(*node_map.GetNode("Filter")); + Tensor filter_data = GenerateRandomTensor<DT_FLOAT>(filter_shape); std::vector<string> fetch = {"Fetch"}; - auto tensors_expected = EvaluateNodes(item.graph, fetch); - auto tensors = EvaluateNodes(output, fetch); + auto tensors_expected = + EvaluateNodes(item.graph, fetch, {{"Filter", filter_data}}); + auto tensors = EvaluateNodes(output, fetch, {{"Filter", filter_data}}); EXPECT_EQ(1, tensors_expected.size()); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); |