author    Jacques Pienaar <jpienaar@google.com>           2018-05-05 12:02:32 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-05-07 16:20:05 -0700
commit a5e1809ddedbff8d9f66b82f9cf0989976694050 (patch)
tree   0a85466999be93f08cb49df5ce4480b612ecb5f5
parent 62ed0aa37099e07720880a72a285304d34512cba (diff)
[TPU] Add option to only compile a replicated graph.
Useful when wanting to compile a computation but not run it. Returns a
serialized CompilationResult string with the error message.

PiperOrigin-RevId: 195547847
-rw-r--r--  tensorflow/contrib/tpu/BUILD                           |  1
-rw-r--r--  tensorflow/contrib/tpu/ops/replication_ops.cc          |  4
-rw-r--r--  tensorflow/contrib/tpu/proto/BUILD                     | 10
-rw-r--r--  tensorflow/contrib/tpu/proto/compilation_result.proto  | 13
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu.py               | 70
5 files changed, 88 insertions(+), 10 deletions(-)
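As context for the changes below, here is a minimal usage sketch of the new compile-only path. It is not taken from this commit: it assumes `split_compile_and_replicate` is importable from the module it is defined in (tensorflow/contrib/tpu/python/tpu/tpu.py), and it omits TPU system initialization and device placement.

# Hedged sketch (TF 1.x graph mode); the setup here is an assumption, not
# part of the commit.
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu

def computation(x):
  return tf.reduce_sum(x * x)

inputs = [[tf.constant([1.0, 2.0, 3.0])]]  # one replica, one input tensor
# New lower-level entry point: element 0 is the compile op, element 1 the
# per-replica output tensors.
compile_op, replica_outputs = tpu.split_compile_and_replicate(computation, inputs)

Fetching `compile_op` on its own compiles the program and returns a serialized `CompilationResultProto` (added below) instead of running the computation.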
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 0bdf6f64c9..f84ff1bfe9 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -181,6 +181,7 @@ py_library(
":datasets",
":profiler",
":tpu_py",
+ "//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
"//tensorflow/contrib/tpu/proto:topology_proto_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc
index 3bdf7c2f83..defed00537 100644
--- a/tensorflow/contrib/tpu/ops/replication_ops.cc
+++ b/tensorflow/contrib/tpu/ops/replication_ops.cc
@@ -64,6 +64,10 @@ REGISTER_OP("TPUReplicatedOutput")
"Operator that connects the output of an N-way replicated TPU "
"computation to N separate outputs.");
+REGISTER_OP("TPUCompilationResult")
+ .Output("output: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("TPUReplicate")
.Attr("computation: func")
.Attr("num_replicas: int >= 1")
diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD
index fcfbbe1a21..7ecb36852c 100644
--- a/tensorflow/contrib/tpu/proto/BUILD
+++ b/tensorflow/contrib/tpu/proto/BUILD
@@ -21,3 +21,13 @@ tf_proto_library(
cc_api_version = 2,
visibility = ["//visibility:public"],
)
+
+tf_proto_library(
+ name = "compilation_result_proto",
+ srcs = [
+ "compilation_result.proto",
+ ],
+ cc_api_version = 2,
+ protodeps = ["//tensorflow/core:protos_all"],
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/contrib/tpu/proto/compilation_result.proto b/tensorflow/contrib/tpu/proto/compilation_result.proto
new file mode 100644
index 0000000000..cf52897de3
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/compilation_result.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+package tensorflow.tpu;
+
+import "tensorflow/core/lib/core/error_codes.proto";
+
+// Describes the result of a TPU compilation.
+message CompilationResultProto {
+ // The error message, if any, returned during compilation.
+ error.Code status_code = 1;
+ string status_error_message = 2;
+}
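The string produced by the new TPUCompilationResult op is this proto in serialized form. A minimal decoding sketch, assuming the generated Python module follows the usual `_pb2` naming for compilation_result.proto and that `serialized` holds the bytes fetched from the compile op:

# Hedged sketch: decode the compile-op output; module names are assumptions
# based on standard proto codegen, not shown in this commit.
from tensorflow.contrib.tpu.proto import compilation_result_pb2
from tensorflow.core.lib.core import error_codes_pb2

result = compilation_result_pb2.CompilationResultProto()
result.ParseFromString(serialized)
if result.status_code != error_codes_pb2.OK:
  raise RuntimeError("TPU compilation failed: " + result.status_error_message)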
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 7b8786304c..c8f24ed01d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -58,6 +58,7 @@ _NOT_IMPLEMENTED_OPS = set([
_MAX_WARNING_LINES = 5
_TPU_REPLICATE_ATTR = "_tpu_replicate"
+_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
@@ -385,6 +386,45 @@ def replicate(computation,
ValueError: If the number of inputs per replica does not match
the number of formal parameters to `computation`.
"""
+ return split_compile_and_replicate(computation, inputs, infeed_queue,
+ device_assignment, name)[1]
+
+
+def split_compile_and_replicate(computation,
+ inputs=None,
+ infeed_queue=None,
+ device_assignment=None,
+ name=None):
+ """Builds graph operators that runs compilation and replicated computation.
+
+ This is a lower-level interface than `replicate` that returns a separate compile
+ and execute output tensor. In the generated graph the compile op feeds into
+ the execute op and no additional compilation is incurred when running the
+ compile op before the execute op. The compile op returns additional
+ information about the compilation but does not return the compiled program.
+
+ Args:
+ computation: A Python function that builds the computation to replicate.
+ inputs: A list of lists of input tensors or `None` (equivalent to
+ `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
+ have the same number of inputs.
+ infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
+ of arguments as inputs to computation.
+ device_assignment: If not `None`, a `DeviceAssignment` describing the
+ mapping between logical cores in the computation with physical cores in
+ the TPU topology. Uses a default device assignment if `None`. The
+ `DeviceAssignment` may be omitted if each replica of the computation uses
+ only one core, and there is either only one replica, or the number of
+ replicas is equal to the number of cores in the TPU system.
+ name: (Deprecated) Does nothing.
+ Returns:
+ A list of two elements: the first is the compile op, and the second a list
+ of output tensors, indexed by `[replica_num][output_num]`.
+ Raises:
+ ValueError: If all replicas do not have equal numbers of input tensors.
+ ValueError: If the number of inputs per replica does not match
+ the number of formal parameters to `computation`.
+ """
del name
inputs = [[]] if inputs is None else inputs
@@ -456,8 +496,8 @@ def replicate(computation,
computation_inputs.append(
tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
- context = TPUReplicateContext(
- name=graph.unique_name("cluster"), num_replicas=num_replicas)
+ cluster_name = graph.unique_name("cluster")
+ context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas)
try:
context.Enter()
@@ -516,8 +556,7 @@ def replicate(computation,
# Separates the returned Operations and Tensors.
output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
- output_tensors = [o for o in outputs
- if not isinstance(o, ops.Operation)]
+ output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
if outputs != output_tensors + output_operations:
raise ValueError(
@@ -550,22 +589,33 @@ def replicate(computation,
name="output{}".format(i))
for i in xrange(output_arity)]
+ with ops.control_dependencies([metadata]):
+ compile_status = tpu_ops.tpu_compilation_result()
+ op = compile_status.op
+ attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
+ op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access
+
with ops.control_dependencies(output_operations):
if output_arity == 0:
# Returns a list of NoOps dependent on the replication Op, indexed by
# [replica_num].
return [
- control_flow_ops.no_op(name="shard_%d" % i)
- for i in range(num_replicas)
+ compile_status, [
+ control_flow_ops.no_op(name="shard_%d" % i)
+ for i in range(num_replicas)
+ ]
]
else:
# Wraps the outputs in identity operators so the names of any possible
# `fetch` nodes are preserved by the replication rewrite.
return [
- [array_ops.identity(outputs[out][replica],
- name="output_%d_shard_%d" % (out, replica))
- for out in xrange(output_arity)]
- for replica in xrange(num_replicas)
+ compile_status, [[
+ array_ops.identity(
+ outputs[out][replica],
+ name="output_%d_shard_%d" % (out, replica))
+ for out in xrange(output_arity)
+ ]
+ for replica in xrange(num_replicas)]
]
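Putting the pieces together, a hedged session-side sketch continuing the earlier example (`check_compilation` is a hypothetical helper wrapping the proto check sketched above). Because the compile op feeds into the execute op in the generated graph, running it first does not cause a second compilation when the replica outputs are fetched later.

# Hedged sketch: compile first to surface errors, then execute.
with tf.Session() as sess:
  serialized = sess.run(compile_op)     # compiles the replicated computation
  check_compilation(serialized)         # hypothetical helper; see proto sketch
  results = sess.run(replica_outputs)   # reuses the already-compiled program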