-rw-r--r--  tensorflow/compiler/jit/build_xla_launch_ops_pass.cc         |  4
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.h         |  2
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc             |  5
-rw-r--r--  tensorflow/compiler/jit/ops/xla_ops.cc                       |  4
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.cc             |  3
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.h           |  5
-rw-r--r--  tensorflow/compiler/jit/xla_device_ops.h                     |  4
-rw-r--r--  tensorflow/compiler/tests/dense_layer_test.py                | 10
-rw-r--r--  tensorflow/compiler/tests/jit_test.py                        | 16
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.h                    |  2
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_registry.cc                |  4
-rw-r--r--  tensorflow/core/common_runtime/eager/execute.cc              | 14
-rw-r--r--  tensorflow/core/grappler/optimizers/dependency_optimizer.cc  |  6
-rw-r--r--  tensorflow/docs_src/performance/xla/jit.md                   |  4
14 files changed, 39 insertions(+), 44 deletions(-)
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
index 9a2bb00075..b17ff589e2 100644
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
@@ -40,7 +40,7 @@ static Status BuildLaunchNode(
Graph* graph, Node** node) {
NodeDef def;
def.set_name(graph->NewName(nodename));
- def.set_op("_XlaLaunch");
+ def.set_op("XlaLaunch");
def.set_device(device_name);
AddNodeAttr("Tconstants", constant_dtypes, &def);
AddNodeAttr("Targs", arg_dtypes, &def);
@@ -79,7 +79,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
node->input_types().begin() + num_constant_args,
node->input_types().begin() + num_constant_args + num_nonconst_args);
- // Build a _XlaLaunch operator to execute the function body.
+ // Build a XlaLaunch operator to execute the function body.
Node* launch_node;
TF_RETURN_IF_ERROR(BuildLaunchNode(
graph->NewName(node->name()), node->type_string(), node->def().attr(),
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index 34be4409a3..5fee36f022 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -80,7 +80,7 @@ Status EncapsulateSubgraphsInFunctions(
std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
// The attribute that marks function calls produced by the encapsulate
-// subgraphs pass and that should in turn be compiled via _XlaLaunch operators.
+// subgraphs pass and that should in turn be compiled via XlaLaunch operators.
extern const char* const kXlaCompiledKernelAttr;
// Does `node` have the kXlaCompiledKernelAttr attribute?
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 9d856346ec..27287e0f96 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -256,10 +256,9 @@ XlaLocalLaunchOp::~XlaLocalLaunchOp() {
VLOG(1) << "XlaLocalLaunchOp destroyed";
}
-REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU),
- XlaLocalLaunchOp);
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
-REGISTER_KERNEL_BUILDER(Name("_XlaLaunch")
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
.Device(DEVICE_GPU)
.HostMemory("constants")
.HostMemory("resources"),
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index 07320b43da..f2473d98ff 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -17,7 +17,7 @@ limitations under the License.
namespace tensorflow {
-REGISTER_OP("_XlaLaunch")
+REGISTER_OP("XlaLaunch")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
.Input("args: Targs")
@@ -28,7 +28,7 @@ REGISTER_OP("_XlaLaunch")
.Attr("Tresults: list(type) >= 0")
.Attr("function: func")
// XLA random-number generation ops are stateful.
- // TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch.
+ // TODO(phawkins): create stateful and non-stateful variants of XlaLaunch.
.SetIsStateful()
.Doc("XLA Launch Op. For use by the XLA JIT only.");
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 6430975335..7ed609c437 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -122,8 +122,7 @@ Status XlaCompilationCache::BuildSignature(
namespace {
-// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch
-// op.
+// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op.
Status BuildArguments(const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args,
OpKernelContext* ctx,
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h
index 23c6f3903f..7cc3d0e007 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h
@@ -29,11 +29,8 @@ limitations under the License.
namespace tensorflow {
// An OpKernel that compiles an op to an XLA computation and runs it. Unlike
-// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a
+// XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a
// vanilla TensorFlow op as long as the bridge supports it.
-//
-// Importantly _XlaLaunch assumes all input and output tensors are on the host,
-// whereas XlacompileOnDemandOp works with tensors in device memory.
class XlaCompileOnDemandOp : public OpKernel {
public:
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 65c0e8577f..9c00a0682c 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -33,7 +33,7 @@ namespace tensorflow {
// Dummy OpKernel, used for kernels assigned to an XLA device that should be
// compiled. Should never be called at runtime since such ops should be
-// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an
+// rewritten to a XlaLaunch op. If it is called, it means the placer placed an
// operator on an XLA device but the compiler did not compile it.
class XlaDeviceDummyOp : public OpKernel {
public:
@@ -42,7 +42,7 @@ class XlaDeviceDummyOp : public OpKernel {
};
#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
- REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") \
+ REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \
.Device(DEVICE) \
.HostMemory("constants") \
.HostMemory("resources"), \
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index b0bf1b79d6..865f60ccab 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -46,8 +46,8 @@ def InLabels(labels, substr):
def XlaLaunchOpCount(labels):
- """Count how many _XlaLaunch labels are present."""
- return sum("_XlaLaunch(" in x for x in labels)
+ """Count how many XlaLaunch labels are present."""
+ return sum("XlaLaunch(" in x for x in labels)
class DenseLayerTest(test.TestCase):
@@ -55,7 +55,7 @@ class DenseLayerTest(test.TestCase):
def testDenseLayerAutoJit(self):
"""Tests dense layer compilation in auto-jit mode.
- Dense layer should be compiled into a single _XlaLaunch op in auto-jit mode.
+ Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
"""
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit")
@@ -83,7 +83,7 @@ class DenseLayerTest(test.TestCase):
"""Tests that the dense layer node is properly compiled in jit scope.
Dense layer with static shape input tensor should be compiled into a single
- _XlaLaunch op by XLA.
+ XlaLaunch op by XLA.
"""
with self.test_session() as sess:
@@ -110,7 +110,7 @@ class DenseLayerTest(test.TestCase):
Dense layer uses shape op to get shape of input tensor if its shape is not
fully defined. XLA does not cluster shape op with other operators. But in
experimental_jit_scope, XLA is forced to compile shape op into its own
- cluster, causing dense layer to be split into TWO _XlaLaunch ops.
+ cluster, causing dense layer to be split into TWO XlaLaunch ops.
"""
with self.test_session() as sess:
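
The docstrings above describe when the dense layer ends up in one XlaLaunch op versus two. As a point of reference, here is a minimal sketch (not part of the patch) of the kind of graph the test builds under `experimental_jit_scope`; the contrib import path and the shapes are assumptions based on the TF 1.x API.

```python
# Minimal sketch (not from the patch): a dense layer under the manual JIT
# scope. With a partially defined input shape, the Shape op is forced into
# its own cluster, so the layer splits into two XlaLaunch ops.
import tensorflow as tf
from tensorflow.contrib.compiler import jit

with jit.experimental_jit_scope():
    x = tf.placeholder(tf.float32, shape=[None, 32])  # batch dim unknown
    y = tf.layers.dense(x, 16)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(y, feed_dict={x: [[0.0] * 32]})
```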
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 0310cdde66..4b0043b6b4 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -78,10 +78,10 @@ def InLabels(labels, substr):
def MetadataHasXlaLaunch(run_metadata):
- """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline."""
+ """Returns true if there is a XlaLaunch kernel in run_metadata's timeline."""
# TODO(phawkins): find a less hacky way to test whether a kernel ran.
- return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch")
+ return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch")
class JitLaunchTest(test.TestCase):
@@ -90,8 +90,8 @@ class JitLaunchTest(test.TestCase):
# Verifies that the outputs match and that XLA was invoked. 'fn' must take
# the same number of tensors as arguments that are in 'args', and must return
# a tuple of output tensors.
- # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node
- # actually ran. However, it is sometimes possible for _XlaLaunch ops to be
+ # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node
+ # actually ran. However, it is sometimes possible for XlaLaunch ops to be
# constant-folded away, so the check is optional.
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
@@ -441,14 +441,14 @@ class XlaCompilationTest(test.TestCase):
self.assertFalse(InLabels(labels, "Log"))
self.assertTrue(InLabels(labels, "Reciprocal"))
self.assertTrue(InLabels(labels, "Mul"))
- self.assertFalse(InLabels(labels, "_XlaLaunch"))
+ self.assertFalse(InLabels(labels, "XlaLaunch"))
- # Compile the backprop. One _XlaLaunch.
+ # Compile the backprop. One XlaLaunch.
labels = _Run(compiled=True)
self.assertFalse(InLabels(labels, "Log"))
self.assertFalse(InLabels(labels, "Reciprocal"))
self.assertFalse(InLabels(labels, "Mul"))
- self.assertTrue(InLabels(labels, "_XlaLaunch"))
+ self.assertTrue(InLabels(labels, "XlaLaunch"))
class ElementWiseFusionTest(test.TestCase):
@@ -482,7 +482,7 @@ class ElementWiseFusionTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = RunMetadataLabels(run_metadata)
- count = sum("_XlaLaunch(" in x for x in labels)
+ count = sum("XlaLaunch(" in x for x in labels)
return output, count
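
For orientation, the tests above lean on two small helpers, `RunMetadataLabels` and `InLabels`, defined earlier in jit_test.py: they flatten the per-device node stats in a `RunMetadata` proto into timeline labels and then do a substring check. A reconstructed sketch (not the verbatim definitions) of where the counted "XlaLaunch(" strings come from:

```python
# Reconstructed sketch of the jit_test.py helpers (not the verbatim source):
# collect the timeline label of every kernel that ran, then substring-match.
def RunMetadataLabels(run_metadata):
  """Returns the timeline labels of all kernels in run_metadata."""
  labels = []
  for dev_stats in run_metadata.step_stats.dev_stats:
    for node_stats in dev_stats.node_stats:
      labels.append(node_stats.timeline_label)
  return labels


def InLabels(labels, substr):
  """Returns True iff one of the labels contains substr, e.g. "XlaLaunch"."""
  return any(substr in x for x in labels)
```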
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 621fbc149a..bf496bd8bc 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -38,7 +38,7 @@ class XlaContext;
// It does a symbolic execution of the graph starting from specific input
// shapes, using a JIT device to convert operators into XLA computations.
//
-// XlaCompiler is typically invoked from an `_XlaLaunch` operator once the
+// XlaCompiler is typically invoked from an `XlaLaunch` operator once the
// shapes of all input parameters to the computation are known. This is
// because the symbolic execution requires known shapes for all operations.
//
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index e309cb1e34..4692038b61 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -39,10 +39,10 @@ const char* const DEVICE_XLA_GPU = "XLA_GPU";
static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) {
const OpDef* op_def;
- TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def));
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def));
NodeDef node_def;
node_def.set_name("_XlaLaunch-op");
- node_def.set_op("_XlaLaunch");
+ node_def.set_op("XlaLaunch");
string kernel_class_name;
TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr,
&kernel_class_name));
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 1df499675d..ce989f4b4e 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -186,14 +186,14 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
// primitive op (e.g. matmul).
//
// The wrapper function conforms to the function signature expected by
-// _XlaLaunchOp, with input params ordered by <constants, (variable) args and
+// XlaLaunch, with input params ordered by <constants, (variable) args and
// resources>. For example, if the op has input params <Const1, Arg2, Const3,
// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
// Resource4> as the input params to the synthesized function.
//
// It populates `const_input_types`, `arg_input_types` and
// `op_input_to_func_input` based on the reordering results, that the caller can
-// use them to build an _XlaLaunchOp. On error, it returns NULL, and sets
+// use them to build an XlaLaunch. On error, it returns NULL, and sets
// `status` accordingly.
const FunctionDef* OpToFunction(TFE_Op* op,
std::vector<TF_DataType>* const_input_types,
@@ -311,12 +311,12 @@ const FunctionDef* OpToFunction(TFE_Op* op,
return ret;
}
-// Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed
+// Builds an XlaLaunch as a wrapper over 'op', so that 'op' can be executed
// via XLA.
std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
- VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->operation.Name();
+ VLOG(1) << "Creating XlaLaunch for TFE_Op " << op->operation.Name();
auto launch_op = std::unique_ptr<TFE_Op>(
- TFE_NewOp(op->operation.ctx, "_XlaLaunch", status));
+ TFE_NewOp(op->operation.ctx, "XlaLaunch", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
if (op->operation.device) {
TFE_OpSetDevice(launch_op.get(), op->operation.device->name().c_str(),
@@ -331,7 +331,7 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
gtl::FlatMap<int, int> op_input_to_func_input;
if (fdef == nullptr) {
// See if this is a primitive op, and if so create a function for it, so
- // that _XlaLaunchOp can access it.
+ // that XlaLaunch can access it.
fdef = OpToFunction(op, &const_input_types, &arg_input_types,
&op_input_to_func_input, status);
if (!status.ok()) return nullptr;
@@ -423,7 +423,7 @@ Status EagerLocalExecute(EagerOperation* op,
if (!status.ok()) return status;
#ifdef TENSORFLOW_EAGER_USE_XLA
std::unique_ptr<TFE_Op> xla_launch_op;
- if (op->UseXla() && op->Name() != "_XlaLaunch") {
+ if (op->UseXla() && op->Name() != "XlaLaunch") {
xla_launch_op = BuildXlaLaunch(op, status);
if (!status.ok()) return status;
op = xla_launch_op.get();
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 7b7fd81155..200454b522 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -126,9 +126,9 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
return false;
}
const std::unordered_set<string> do_not_rewrite_ops{
- "Assert", "CheckNumerics", "_Retval",
- "_Arg", "_ParallelConcatUpdate", "_TPUExecute",
- "_TPUCompile", "ControlTrigger"};
+ "Assert", "CheckNumerics", "_Retval",
+ "_Arg", "_ParallelConcatUpdate", "TPUExecute",
+ "TPUCompile", "ControlTrigger"};
if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
return false;
}
diff --git a/tensorflow/docs_src/performance/xla/jit.md b/tensorflow/docs_src/performance/xla/jit.md
index d9a979ccbd..6724d1eaf8 100644
--- a/tensorflow/docs_src/performance/xla/jit.md
+++ b/tensorflow/docs_src/performance/xla/jit.md
@@ -137,12 +137,12 @@ TF_XLA_FLAGS=--xla_generate_hlo_graph=.* python mnist_softmax_xla.py
```
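
As a reminder of where `timeline.ctf.json` comes from: the sketch below is an illustrative stand-in (not the mnist_softmax_xla.py script itself) that runs one auto-jit step with full tracing and writes the Chrome trace the next paragraph asks you to open. It assumes a TF 1.x build with XLA enabled; the graph is a placeholder matmul rather than the MNIST model.

```python
# Illustrative sketch (not mnist_softmax_xla.py): run one auto-jit step with
# full tracing and dump the Chrome trace that shows the XlaLaunch bar.
import tensorflow as tf
from tensorflow.python.client import timeline

x = tf.random_normal([1024, 1024])
y = tf.matmul(x, x)

config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
    tf.OptimizerOptions.ON_1)

with tf.Session(config=config) as sess:
  run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  run_metadata = tf.RunMetadata()
  sess.run(y, options=run_options, run_metadata=run_metadata)

  trace = timeline.Timeline(step_stats=run_metadata.step_stats)
  with open("timeline.ctf.json", "w") as f:
    f.write(trace.generate_chrome_trace_format())
```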
Open the timeline file created (`timeline.ctf.json`). The rendered timeline
-should look similar to the picture below with one long bar labeled `_XlaLaunch`.
+should look similar to the picture below with one long bar labeled `XlaLaunch`.
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/jit_timeline_gpu_xla.png">
</div>
-To understand what is happening in `_XlaLaunch`, look at the console output for
+To understand what is happening in `XlaLaunch`, look at the console output for
statements similar to the following:
```shell