aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-01 11:44:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 11:49:45 -0700
commit7cabc6be4e32dfb7f42c7f5e33549984bfdb68a3 (patch)
tree2554fc313a566f93aa221985baff8dace98e0d68 /tensorflow/compiler
parentf0f301f05fb1f1965c966ef57cc390e48d966f12 (diff)
Allow zero number of inputs in XRT execute operation.
PiperOrigin-RevId: 215252408
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xrt/ops/xrt_execute_op.cc2
-rw-r--r--tensorflow/compiler/xrt/tests/raw_api_test.cc41
2 files changed, 42 insertions, 1 deletions
diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
index fda4c31298..40ec1b0ba9 100644
--- a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
+++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
namespace tensorflow {
REGISTER_OP("XRTExecute")
- .Attr("Ninputs: int")
+ .Attr("Ninputs: int >= 0")
.Input("computation_handle: int64")
.Input("execution_config: string")
.Input("input_handles: Ninputs * int64")
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index 2952feb16a..f590fbf0d9 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -108,6 +108,14 @@ bool CompareLiteralToLiteralProto(const xla::Literal& a,
return equal;
}
+xla::XlaComputation OnePlusTwo() {
+ xla::XlaBuilder builder("OnePlusTwo");
+ auto c0 = xla::ConstantR0(&builder, 1.0f);
+ auto c1 = xla::ConstantR0(&builder, 2.0f);
+ xla::Add(c0, c1);
+ return builder.Build().ValueOrDie();
+}
+
xla::XlaComputation AddAndScale() {
xla::XlaBuilder builder("AddAndScale");
auto p0 = xla::Parameter(&builder, 0,
@@ -346,6 +354,39 @@ TEST(RawApiTest, CompileAndExecute) {
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
+TEST(RawApiTest, CompileAndExecuteZeroArg) {
+ xrt::XLAComputation c;
+ auto config = c.mutable_config();
+ auto shapes = config->mutable_program_shape();
+ *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {});
+
+ xrt::XRTExecutionConfig e;
+ e.set_release_input_handles(true);
+ e.set_release_compilation_handle(true);
+ StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot());
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto e_config =
+ ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+ auto computation =
+ ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+ auto c_handle = ops::XRTCompile(root, computation);
+ auto result = ops::XRTExecute(root, c_handle, e_config,
+ std::initializer_list<Input>({}));
+ auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ auto expected = xla::LiteralUtil::CreateR0<float>(3.0f);
+ EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
+}
+
TEST(RawApiTest, CompileAndExecuteReturnTuple) {
xrt::XLAAllocation p0;
p0.set_device_ordinal(0);