diff options
author | 2018-10-01 11:44:17 -0700 | |
---|---|---|
committer | 2018-10-01 11:49:45 -0700 | |
commit | 7cabc6be4e32dfb7f42c7f5e33549984bfdb68a3 (patch) | |
tree | 2554fc313a566f93aa221985baff8dace98e0d68 /tensorflow/compiler | |
parent | f0f301f05fb1f1965c966ef57cc390e48d966f12 (diff) |
Allow zero number of inputs in XRT execute operation.
PiperOrigin-RevId: 215252408
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xrt/ops/xrt_execute_op.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xrt/tests/raw_api_test.cc | 41 |
2 files changed, 42 insertions, 1 deletions
diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc index fda4c31298..40ec1b0ba9 100644 --- a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { REGISTER_OP("XRTExecute") - .Attr("Ninputs: int") + .Attr("Ninputs: int >= 0") .Input("computation_handle: int64") .Input("execution_config: string") .Input("input_handles: Ninputs * int64") diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 2952feb16a..f590fbf0d9 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -108,6 +108,14 @@ bool CompareLiteralToLiteralProto(const xla::Literal& a, return equal; } +xla::XlaComputation OnePlusTwo() { + xla::XlaBuilder builder("OnePlusTwo"); + auto c0 = xla::ConstantR0(&builder, 1.0f); + auto c1 = xla::ConstantR0(&builder, 2.0f); + xla::Add(c0, c1); + return builder.Build().ValueOrDie(); +} + xla::XlaComputation AddAndScale() { xla::XlaBuilder builder("AddAndScale"); auto p0 = xla::Parameter(&builder, 0, @@ -346,6 +354,39 @@ TEST(RawApiTest, CompileAndExecute) { EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } +TEST(RawApiTest, CompileAndExecuteZeroArg) { + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot()); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + auto e_config = + ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); + auto computation = + ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto result = ops::XRTExecute(root, c_handle, e_config, + std::initializer_list<Input>({})); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector<Tensor> outputs; + TF_EXPECT_OK(session.Run({read_back}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()())); + + auto expected = xla::LiteralUtil::CreateR0<float>(3.0f); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + TEST(RawApiTest, CompileAndExecuteReturnTuple) { xrt::XLAAllocation p0; p0.set_device_ordinal(0); |