aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/ops/xla_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/ops/xla_ops.cc')
-rw-r--r--tensorflow/compiler/jit/ops/xla_ops.cc19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index f2473d98ff..1a29c3caab 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+using shape_inference::InferenceContext;
+
REGISTER_OP("XlaLaunch")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
@@ -32,4 +36,19 @@ REGISTER_OP("XlaLaunch")
.SetIsStateful()
.Doc("XLA Launch Op. For use by the XLA JIT only.");
+REGISTER_OP("XlaClusterOutput")
+ .Input("input: T")
+ // Note: when replication is supported, this op will have N outputs.
+ .Output("outputs: T")
+ .Attr("T: type")
+ .SetShapeFn([](InferenceContext* c) {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->input(0));
+ }
+ return Status::OK();
+ })
+ .Doc(
+ "Operator that connects the output of an XLA computation to other "
+ "consumer graph nodes.");
+
} // namespace tensorflow