aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-27 23:35:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-27 23:38:42 -0700
commit04e17da7ccd40c739d3a24daa2ad4d94bdd77dfe (patch)
treecf6f8e32e4ce29ae90718d702dea2001006ae392
parentd047a36a9d6d9cc7c0e15a01c4640a4177374827 (diff)
Fix kernel creation bug, due to constant folding always use CPU.
PiperOrigin-RevId: 194636076
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer_test.cc15
1 files changed, 10 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index fc87f69b8c..dad49cd74f 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -108,10 +108,8 @@ class LayoutOptimizerTest : public GrapplerTest {
TensorShape filter_shape(
{filter_size, filter_size, input_depth, filter_count});
- Tensor filter_data(DT_FLOAT, filter_shape);
- test::FillIota<float>(&filter_data, 1.0f);
Output filter =
- ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
+ ops::Variable(s->WithOpName("Filter"), filter_shape, DT_FLOAT);
int output_height = input_height;
int output_width = input_width;
@@ -143,6 +141,10 @@ class LayoutOptimizerTest : public GrapplerTest {
return tensor;
}
+ TensorShape GetAttrShape(const NodeDef& node) {
+ return TensorShape(node.attr().at({"shape"}).shape());
+ }
+
Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) {
int batch_size = 16;
int input_height = 8;
@@ -200,9 +202,12 @@ TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes);
if (gpu_available_) {
+ TensorShape filter_shape = GetAttrShape(*node_map.GetNode("Filter"));
+ Tensor filter_data = GenerateRandomTensor<DT_FLOAT>(filter_shape);
std::vector<string> fetch = {"Fetch"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors_expected =
+ EvaluateNodes(item.graph, fetch, {{"Filter", filter_data}});
+ auto tensors = EvaluateNodes(output, fetch, {{"Filter", filter_data}});
EXPECT_EQ(1, tensors_expected.size());
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);