aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-11 11:12:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-11 12:23:53 -0700
commita11e669c5a4c855b0b507da97378bc7e03a08f86 (patch)
tree341ba99ba25ed30c2148dbf91d6a41e86d56ad82 /tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc
parent740053be5054027bc66c7995c045d40b07479943 (diff)
Support two data types as inputs in RemoteFusedGraphExecuteOp
Change: 152843285
Diffstat (limited to 'tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc')
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc38
1 files changed, 27 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc
index 580be4b7db..925af1f79e 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc
@@ -37,20 +37,34 @@ namespace tensorflow {
class RemoteFusedGraphExecuteTest : public OpsTestBase {};
-TEST_F(RemoteFusedGraphExecuteTest, ExecuteAddGraph) {
+TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithOneDataType) {
+ DataTypeVector input_types({DT_FLOAT, DT_FLOAT});
+ DataTypeVector output_types({DT_FLOAT});
TF_ASSERT_OK(
NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute")
.Input(FakeInput(2, DT_FLOAT))
- .Attr("M", 2)
- .Attr("N", 1)
- .Attr("T", DataTypeToEnum<float>::v())
- .Attr("U", DataTypeToEnum<float>::v())
- .Attr("serialized_graph_transfer_info", "")
+ .Attr("Tinputs", input_types)
+ .Attr("Toutputs", output_types)
+ .Attr("serialized_remote_fused_graph_execute_info", "")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
// TODO(satok): Add benchmark
}
+TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithWrongDataType) {
+ DataTypeVector input_types({DT_INT32, DT_INT32});
+ DataTypeVector output_types({DT_FLOAT});
+ ASSERT_FALSE(
+ NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute")
+ .Input(FakeInput(2, DT_FLOAT))
+ .Attr("Tinputs", input_types)
+ .Attr("Toutputs", output_types)
+ .Attr("serialized_remote_fused_graph_execute_info", "")
+ .Finalize(node_def())
+ .ok());
+ // TODO(satok): Add benchmark
+}
+
////////////////////////////
// End-to-end test: Begin //
////////////////////////////
@@ -94,13 +108,15 @@ static Output BuildRemoteFusedGraphExecuteOp(
CHECK(scope.ok());
auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");
+
+ DataTypeVector input_types{DT_FLOAT};
+ DataTypeVector output_types{DT_FLOAT};
+
auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
.Input(node_out_list)
- .Attr("M", static_cast<int64>(output_list.size()))
- .Attr("N", static_cast<int64>(output_node_count))
- .Attr("T", DT_FLOAT)
- .Attr("U", DT_FLOAT)
- .Attr("serialized_graph_transfer_info",
+ .Attr("Tinputs", input_types)
+ .Attr("Toutputs", output_types)
+ .Attr("serialized_remote_fused_graph_execute_info",
StringPiece(execute_info.SerializeAsString()));
CHECK(scope.ok());
scope.UpdateBuilder(&builder);