diff options
-rw-r--r--  tensorflow/compiler/jit/xla_device_ops.h      | 11
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu.py      | 87
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_test.py |  4
-rw-r--r--  tensorflow/core/kernels/control_flow_ops.cc   | 22
-rw-r--r--  tensorflow/core/kernels/control_flow_ops.h    | 16
5 files changed, 117 insertions, 23 deletions
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index b27c32e9bc..0c49286acd 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -95,7 +95,16 @@ class XlaAssignVariableOp : public AsyncOpKernel { REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ SwitchOp); \ REGISTER_KERNEL_BUILDER( \ - Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); + Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ + REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ + REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("LoopCond") \ + .Device(DEVICE) \ + .HostMemory("input") \ + .HostMemory("output"), \ + LoopCondOp); } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 612cd0114b..71a5012691 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -126,7 +126,19 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): outside the replicated computation. """ - def __init__(self, name, num_replicas): + def __init__(self, name, num_replicas, pivot): + """Builds a new TPUReplicateContext. + + Args: + name: a unique name for the context, used to populate the `_tpu_replicate` + attribute. + num_replicas: an integer that gives the number of replicas for the + computation. + pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any + inputs will have a control dependency on the pivot node. This ensures + that nodes are correctly included in any enclosing control flow + contexts. 
+ """ super(TPUReplicateContext, self).__init__() self._num_replicas = num_replicas self._outer_device_function_stack = None @@ -138,6 +150,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._host_compute_core = [] self._name = name self._unsupported_ops = [] + self._pivot = pivot def report_unsupported_operations(self): if self._unsupported_ops: @@ -262,9 +275,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access super(TPUReplicateContext, self).Enter() - def Exit(self): - super(TPUReplicateContext, self).Exit() - def HostComputeCore(self): return self._host_compute_core @@ -300,10 +310,64 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) + # Remove any control edges from outer control flow contexts. These may cause + # mismatched frame errors. + control_inputs, external_inputs = self._RemoveExternalControlEdges(op) + + if not op.inputs: + # Add a control edge from the control pivot to this op. + if not control_inputs: + # pylint: disable=protected-access + op._add_control_input(self.GetControlPivot()) + # pylint: enable=protected-access + else: + for index in xrange(len(op.inputs)): + x = op.inputs[index] + real_x = self.AddValue(x) + if real_x != x: + op._update_input(index, real_x) # pylint: disable=protected-access + + if external_inputs: + # Use an identity to pull control inputs as data inputs. Note that we + # ignore ops which don't have outputs. TODO(phawkins): fix that. + with ops.control_dependencies(None): + self.Enter() + external_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_inputs + if x.outputs + ] + self.Exit() + # pylint: disable=protected-access + op._add_control_inputs(external_inputs) + # pylint: enable=protected-access + + # Mark op's outputs as seen by this context and any outer contexts. 
+ output_names = [x.name for x in op.outputs] + context = self + while context is not None: + # pylint: disable=protected-access + context._values.update(output_names) + context = context._outer_context + # pylint: enable=protected-access + + if self._outer_context: + self._outer_context.AddInnerOp(op) + def AddValue(self, val): + if val.name in self._values: + # Use the real value if it comes from outer context. + result = self._external_values.get(val.name) + return val if result is None else result + result = val + self._values.add(val.name) if self._outer_context: result = self._outer_context.AddValue(val) + self._values.add(result.name) + + self._external_values[val.name] = result + return result def AddInnerOp(self, op): @@ -319,6 +383,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # grad_state should be as if this is the top-level gradient state. return None + @property + def back_prop(self): + """Forwards to the enclosing while context, if any.""" + if self.GetWhileContext(): + return self.GetWhileContext().back_prop + return False + + def GetControlPivot(self): + return self._pivot + def outside_compilation(computation, *args, **kwargs): """Builds part of a computation outside any current TPU replicate scope. 
@@ -505,7 +579,9 @@ def split_compile_and_replicate(computation, tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) cluster_name = graph.unique_name("cluster") - context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas) + pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") + context = TPUReplicateContext( + name=cluster_name, num_replicas=num_replicas, pivot=pivot) try: context.Enter() @@ -582,6 +658,7 @@ def split_compile_and_replicate(computation, with ops.device(t.device if t.device else core(0)): new_output_tensors.append(array_ops.identity(t)) output_tensors = new_output_tensors + context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py index c3882b8a27..6bdaa528f9 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.python.framework import dtypes from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops @@ -37,7 +38,8 @@ class TPUContextTest(test.TestCase): def testIsInContext(self): """Test that control_flow_util can check that we're in a TPU context.""" z1 = array_ops.identity(1) - context = tpu.TPUReplicateContext(b"context", 1) + pivot = control_flow_ops.no_op() + context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot) context.Enter() z2 = array_ops.identity(1) context.Exit() diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 7d5d54e5be..ebf844d75f 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -587,24 +587,14 @@ 
REGISTER_SYCL_HOST_KERNEL(string); #undef REGISTER_SYCL_HOST_KERNEL #endif // TENSORFLOW_USE_SYCL -// A LoopCond op has one input and one output. The input is a boolean -// scalar representing the taken branches of the "pivot" Switch that -// determines loop termination. As a contract, any high-level front-end -// should always use port '0' of the "pivot" switches for loop exit. -class LoopCondOp : public OpKernel { - public: - explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - context->set_output(0, context->input(0)); - } +LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {} +LoopCondOp::~LoopCondOp() = default; - bool IsExpensive() override { return false; } - - ~LoopCondOp() override {} +void LoopCondOp::Compute(OpKernelContext* context) { + context->set_output(0, context->input(0)); +} - TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); -}; +bool LoopCondOp::IsExpensive() { return false; } REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp); REGISTER_KERNEL_BUILDER(Name("LoopCond") diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h index 4838f2e2bf..8edbcc9077 100644 --- a/tensorflow/core/kernels/control_flow_ops.h +++ b/tensorflow/core/kernels/control_flow_ops.h @@ -97,6 +97,22 @@ class NextIterationOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp); }; +// A LoopCond op has one input and one output. The input is a boolean +// scalar representing the taken branches of the "pivot" Switch that +// determines loop termination. As a contract, any high-level front-end +// should always use port '0' of the "pivot" switches for loop exit. 
+class LoopCondOp : public OpKernel { + public: + explicit LoopCondOp(OpKernelConstruction* context); + ~LoopCondOp() override; + + void Compute(OpKernelContext* context) override; + + bool IsExpensive() override; + + TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp); +}; + } // namespace tensorflow #endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_