diff options
author | 2017-04-11 11:12:15 -0800 | |
---|---|---|
committer | 2017-04-11 12:23:53 -0700 | |
commit | a11e669c5a4c855b0b507da97378bc7e03a08f86 (patch) | |
tree | 341ba99ba25ed30c2148dbf91d6a41e86d56ad82 /tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc | |
parent | 740053be5054027bc66c7995c045d40b07479943 (diff) |
Support two data types as inputs in RemoteFusedGraphExecuteOp
Change: 152843285
Diffstat (limited to 'tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc')
-rw-r--r-- | tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc | 38 |
1 files changed, 27 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc index 580be4b7db..925af1f79e 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc @@ -37,20 +37,34 @@ namespace tensorflow { class RemoteFusedGraphExecuteTest : public OpsTestBase {}; -TEST_F(RemoteFusedGraphExecuteTest, ExecuteAddGraph) { +TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithOneDataType) { + DataTypeVector input_types({DT_FLOAT, DT_FLOAT}); + DataTypeVector output_types({DT_FLOAT}); TF_ASSERT_OK( NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute") .Input(FakeInput(2, DT_FLOAT)) - .Attr("M", 2) - .Attr("N", 1) - .Attr("T", DataTypeToEnum<float>::v()) - .Attr("U", DataTypeToEnum<float>::v()) - .Attr("serialized_graph_transfer_info", "") + .Attr("Tinputs", input_types) + .Attr("Toutputs", output_types) + .Attr("serialized_remote_fused_graph_execute_info", "") .Finalize(node_def())); TF_ASSERT_OK(InitOp()); // TODO(satok): Add benchmark } +TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithWrongDataType) { + DataTypeVector input_types({DT_INT32, DT_INT32}); + DataTypeVector output_types({DT_FLOAT}); + ASSERT_FALSE( + NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute") + .Input(FakeInput(2, DT_FLOAT)) + .Attr("Tinputs", input_types) + .Attr("Toutputs", output_types) + .Attr("serialized_remote_fused_graph_execute_info", "") + .Finalize(node_def()) + .ok()); + // TODO(satok): Add benchmark +} + //////////////////////////// // End-to-end test: Begin // //////////////////////////// @@ -94,13 +108,15 @@ static Output BuildRemoteFusedGraphExecuteOp( CHECK(scope.ok()); auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list)); const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute"); + + DataTypeVector input_types{DT_FLOAT}; + DataTypeVector output_types{DT_FLOAT}; + auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute") .Input(node_out_list) - .Attr("M", static_cast<int64>(output_list.size())) - .Attr("N", static_cast<int64>(output_node_count)) - .Attr("T", DT_FLOAT) - .Attr("U", DT_FLOAT) - .Attr("serialized_graph_transfer_info", + .Attr("Tinputs", input_types) + .Attr("Toutputs", output_types) + .Attr("serialized_remote_fused_graph_execute_info", StringPiece(execute_info.SerializeAsString())); CHECK(scope.ok()); scope.UpdateBuilder(&builder); |