| author | Jacques Pienaar <jpienaar@google.com> | 2018-05-05 12:02:32 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-07 16:20:05 -0700 |
| commit | a5e1809ddedbff8d9f66b82f9cf0989976694050 | |
| tree | 0a85466999be93f08cb49df5ce4480b612ecb5f5 | |
| parent | 62ed0aa37099e07720880a72a285304d34512cba | |
[TPU] Add option to only compile a replicated graph.
Useful for compiling a computation without running it. Returns a serialized `CompilationResult` proto string containing the compilation status and error message, if any.
PiperOrigin-RevId: 195547847
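
To make the new path concrete, here is a minimal, hypothetical usage sketch in Python. The import paths, session setup, and the `computation` body are assumptions for illustration; only `split_compile_and_replicate` and its string-valued compile op come from this change, and a real run would still need the usual TPU system initialization.

```python
# Hypothetical sketch: compile a replicated computation without running it.
# Import paths and the computation body are assumptions, not part of this diff.
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu

def computation(x):
  return x * x  # any replicable computation; squaring is illustrative

inputs = [[tf.constant(2.0)]]  # indexed by [replica_num][input_num]
compile_op, replica_outputs = tpu.split_compile_and_replicate(
    computation, inputs)

with tf.Session() as sess:
  # Fetching only the compile op compiles the graph but does not execute it;
  # the fetched value is a serialized CompilationResultProto string.
  serialized = sess.run(compile_op)
```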
| -rw-r--r-- | tensorflow/contrib/tpu/BUILD | 1 |
| -rw-r--r-- | tensorflow/contrib/tpu/ops/replication_ops.cc | 4 |
| -rw-r--r-- | tensorflow/contrib/tpu/proto/BUILD | 10 |
| -rw-r--r-- | tensorflow/contrib/tpu/proto/compilation_result.proto | 13 |
| -rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu.py | 70 |

5 files changed, 88 insertions(+), 10 deletions(-)
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 0bdf6f64c9..f84ff1bfe9 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -181,6 +181,7 @@ py_library(
         ":datasets",
         ":profiler",
         ":tpu_py",
+        "//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
         "//tensorflow/contrib/tpu/proto:topology_proto_py",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc
index 3bdf7c2f83..defed00537 100644
--- a/tensorflow/contrib/tpu/ops/replication_ops.cc
+++ b/tensorflow/contrib/tpu/ops/replication_ops.cc
@@ -64,6 +64,10 @@ REGISTER_OP("TPUReplicatedOutput")
     "Operator that connects the output of an N-way replicated TPU "
     "computation to N separate outputs.");
 
+REGISTER_OP("TPUCompilationResult")
+    .Output("output: string")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("TPUReplicate")
     .Attr("computation: func")
     .Attr("num_replicas: int >= 1")
diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD
index fcfbbe1a21..7ecb36852c 100644
--- a/tensorflow/contrib/tpu/proto/BUILD
+++ b/tensorflow/contrib/tpu/proto/BUILD
@@ -21,3 +21,13 @@ tf_proto_library(
     cc_api_version = 2,
     visibility = ["//visibility:public"],
 )
+
+tf_proto_library(
+    name = "compilation_result_proto",
+    srcs = [
+        "compilation_result.proto",
+    ],
+    cc_api_version = 2,
+    protodeps = ["//tensorflow/core:protos_all"],
+    visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/contrib/tpu/proto/compilation_result.proto b/tensorflow/contrib/tpu/proto/compilation_result.proto
new file mode 100644
index 0000000000..cf52897de3
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/compilation_result.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+package tensorflow.tpu;
+
+import "tensorflow/core/lib/core/error_codes.proto";
+
+// Describes the result of a TPU compilation.
+message CompilationResultProto {
+  // The error message, if any, returned during compilation.
+  error.Code status_code = 1;
+  string status_error_message = 2;
+}
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 7b8786304c..c8f24ed01d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -58,6 +58,7 @@ _NOT_IMPLEMENTED_OPS = set([
 
 _MAX_WARNING_LINES = 5
 
 _TPU_REPLICATE_ATTR = "_tpu_replicate"
+_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
 _OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
 
@@ -385,6 +386,45 @@ def replicate(computation,
     ValueError: If the number of inputs per replica does not match
       the number of formal parameters to `computation`.
   """
+  return split_compile_and_replicate(computation, inputs, infeed_queue,
+                                     device_assignment, name)[1]
+
+
+def split_compile_and_replicate(computation,
+                                inputs=None,
+                                infeed_queue=None,
+                                device_assignment=None,
+                                name=None):
+  """Builds graph operators that runs compilation and replicated computation.
+
+  This is a lower level interface than replicate that returns a separate
+  compile and execute output tensor. In the generated graph the compile op
+  feeds into the execute op and no additional compilation is incurred when
+  running the compile op before the execute op. The compile op returns
+  additional information about the compilation but does not return the
+  compiled program.
+
+  Args:
+    computation: A Python function that builds the computation to replicate.
+    inputs: A list of lists of input tensors or `None` (equivalent to
+      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
+      have the same number of inputs.
+    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
+      of arguments as inputs to computation.
+    device_assignment: If not `None`, a `DeviceAssignment` describing the
+      mapping between logical cores in the computation with physical cores in
+      the TPU topology. Uses a default device assignment if `None`. The
+      `DeviceAssignment` may be omitted if each replica of the computation uses
+      only one core, and there is either only one replica, or the number of
+      replicas is equal to the number of cores in the TPU system.
+    name: (Deprecated) Does nothing.
+
+  Returns:
+    A list of lists with the first list corresponding to the compile op and the
+    second a list of output tensors, indexed by `[replica_num][output_num]`.
+
+  Raises:
+    ValueError: If all replicas do not have equal numbers of input tensors.
+    ValueError: If the number of inputs per replica does not match
+      the number of formal parameters to `computation`.
+  """
   del name
   inputs = [[]] if inputs is None else inputs
@@ -456,8 +496,8 @@ def replicate(computation,
       computation_inputs.append(
           tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
 
-  context = TPUReplicateContext(
-      name=graph.unique_name("cluster"), num_replicas=num_replicas)
+  cluster_name = graph.unique_name("cluster")
+  context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas)
 
   try:
     context.Enter()
@@ -516,8 +556,7 @@ def replicate(computation,
 
   # Separates the returned Operations and Tensors.
   output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
-  output_tensors = [o for o in outputs
-                    if not isinstance(o, ops.Operation)]
+  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
 
   if outputs != output_tensors + output_operations:
     raise ValueError(
@@ -550,22 +589,33 @@ def replicate(computation,
                                               name="output{}".format(i))
                for i in xrange(output_arity)]
 
+  with ops.control_dependencies([metadata]):
+    compile_status = tpu_ops.tpu_compilation_result()
+    op = compile_status.op
+    attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
+    op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
+
   with ops.control_dependencies(output_operations):
     if output_arity == 0:
       # Returns a list of NoOps dependent on the replication Op, indexed by
       # [replica_num].
       return [
-          control_flow_ops.no_op(name="shard_%d" % i)
-          for i in range(num_replicas)
+          compile_status, [
+              control_flow_ops.no_op(name="shard_%d" % i)
+              for i in range(num_replicas)
+          ]
       ]
     else:
       # Wraps the outputs in identity operators so the names of any possible
      # `fetch` nodes are preserved by the replication rewrite.
       return [
-          [array_ops.identity(outputs[out][replica],
-                              name="output_%d_shard_%d" % (out, replica))
-           for out in xrange(output_arity)]
-          for replica in xrange(num_replicas)
+          compile_status, [[
+              array_ops.identity(
+                  outputs[out][replica],
+                  name="output_%d_shard_%d" % (out, replica))
+              for out in xrange(output_arity)
+          ]
+                           for replica in xrange(num_replicas)]
       ]
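
Assuming the generated Python module for the new proto lands at the path implied by the `compilation_result_proto` BUILD rule above (an assumption; the diff does not show the generated module), the string fetched from the compile op could be decoded roughly as follows, continuing the sketch given after the commit message:

```python
# Hedged sketch: decode the compile op's result. The _pb2 import path is an
# assumption; the fields match compilation_result.proto in this change.
from tensorflow.contrib.tpu.proto import compilation_result_pb2

result = compilation_result_pb2.CompilationResultProto()
result.ParseFromString(serialized)  # `serialized` from the earlier sketch
if result.status_error_message:
  print("TPU compilation failed:", result.status_error_message)
else:
  print("TPU compilation succeeded")
```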