 tensorflow/compiler/jit/xla_device_ops.h      | 11 +-
 tensorflow/contrib/tpu/python/tpu/tpu.py      | 87 +++++++++++++++++----
 tensorflow/contrib/tpu/python/tpu/tpu_test.py |  4 +-
 tensorflow/core/kernels/control_flow_ops.cc   | 22 ++++-------
 tensorflow/core/kernels/control_flow_ops.h    | 16 ++++++
 5 files changed, 117 insertions(+), 23 deletions(-)
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index b27c32e9bc..0c49286acd 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -95,7 +95,16 @@ class XlaAssignVariableOp : public AsyncOpKernel {
REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \
SwitchOp); \
REGISTER_KERNEL_BUILDER( \
- Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp);
+ Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \
+ REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \
+ REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \
+ REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \
+ NextIterationOp); \
+ REGISTER_KERNEL_BUILDER(Name("LoopCond") \
+ .Device(DEVICE) \
+ .HostMemory("input") \
+ .HostMemory("output"), \
+ LoopCondOp);
} // namespace tensorflow
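
The registrations above are what make while loops constructible on an XLA
device: `tf.while_loop` lowers to exactly these ops (Enter, Merge, Switch,
LoopCond, NextIteration, Exit), so each needs a kernel for the device the
loop is placed on. A minimal TF 1.x sketch that would exercise them; the
XLA_CPU device string is illustrative:

    import tensorflow as tf

    with tf.Graph().as_default():
      with tf.device("/device:XLA_CPU:0"):  # illustrative device string
        i = tf.constant(0)
        n = tf.while_loop(lambda i: i < 10,  # condition feeds a LoopCond op
                          lambda i: i + 1,   # body feeds NextIteration
                          [i])
      with tf.Session() as sess:
        print(sess.run(n))  # 10
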
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 612cd0114b..71a5012691 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -126,7 +126,19 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
outside the replicated computation.
"""
- def __init__(self, name, num_replicas):
+ def __init__(self, name, num_replicas, pivot):
+ """Builds a new TPUReplicateContext.
+
+ Args:
+ name: a unique name for the context, used to populate the `_tpu_replicate`
+ attribute.
+ num_replicas: an integer that gives the number of replicas for the
+ computation.
+ pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any
+ inputs will have a control dependency on the pivot node. This ensures
+ that nodes are correctly included in any enclosing control flow
+ contexts.
+ """
super(TPUReplicateContext, self).__init__()
self._num_replicas = num_replicas
self._outer_device_function_stack = None
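
The pivot wiring gives input-less ops an anchor: without it, a constant
created inside the cluster would float free of any enclosing control flow
context (e.g. a surrounding while loop) and could run even when that context
should not. A sketch of the dependency being established; the names are
illustrative, and the context itself does this via
`op._add_control_input(pivot)` rather than a `control_dependencies` block:

    import tensorflow as tf

    pivot = tf.no_op(name="cluster/pivot")
    with tf.control_dependencies([pivot]):
      # No data inputs here; the control edge from the pivot ties this op
      # into whatever control flow encloses the cluster.
      c = tf.constant(42)
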
@@ -138,6 +150,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._host_compute_core = []
self._name = name
self._unsupported_ops = []
+ self._pivot = pivot
def report_unsupported_operations(self):
if self._unsupported_ops:
@@ -262,9 +275,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access
super(TPUReplicateContext, self).Enter()
- def Exit(self):
- super(TPUReplicateContext, self).Exit()
-
def HostComputeCore(self):
return self._host_compute_core
@@ -300,10 +310,64 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
op.graph.prevent_feeding(op)
op.graph.prevent_fetching(op)
+ # Remove any control edges from outer control flow contexts. These may cause
+ # mismatched frame errors.
+ control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
+
+ if not op.inputs:
+ # Add a control edge from the control pivot to this op.
+ if not control_inputs:
+ # pylint: disable=protected-access
+ op._add_control_input(self.GetControlPivot())
+ # pylint: enable=protected-access
+ else:
+ for index in xrange(len(op.inputs)):
+ x = op.inputs[index]
+ real_x = self.AddValue(x)
+ if real_x != x:
+ op._update_input(index, real_x) # pylint: disable=protected-access
+
+ if external_inputs:
+ # Use an identity to pull control inputs as data inputs. Note that we
+ # ignore ops which don't have outputs. TODO(phawkins): fix that.
+ with ops.control_dependencies(None):
+ self.Enter()
+ external_inputs = [
+ array_ops.identity(x.outputs[0]).op
+ for x in external_inputs
+ if x.outputs
+ ]
+ self.Exit()
+ # pylint: disable=protected-access
+ op._add_control_inputs(external_inputs)
+ # pylint: enable=protected-access
+
+ # Mark op's outputs as seen by this context and any outer contexts.
+ output_names = [x.name for x in op.outputs]
+ context = self
+ while context is not None:
+ # pylint: disable=protected-access
+ context._values.update(output_names)
+ context = context._outer_context
+ # pylint: enable=protected-access
+
+ if self._outer_context:
+ self._outer_context.AddInnerOp(op)
+
def AddValue(self, val):
+ if val.name in self._values:
+ # Use the real value if it comes from outer context.
+ result = self._external_values.get(val.name)
+ return val if result is None else result
+
result = val
+ self._values.add(val.name)
if self._outer_context:
result = self._outer_context.AddValue(val)
+ self._values.add(result.name)
+
+ self._external_values[val.name] = result
+
return result
def AddInnerOp(self, op):
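
Two things are worth reading out of the AddOp/AddValue changes. First, a
control edge from outside the cluster cannot be left in place (it can cause
mismatched frame errors), so AddOp re-enters the context and routes each
external op through an `Identity`, turning the control dependency into one
the context owns. Second, AddValue memoizes in `_values`/`_external_values`
so each outer tensor is mapped at most once. A standalone sketch of the
identity bridge, with `external` standing in for an op defined outside the
context:

    import tensorflow as tf

    external = tf.constant(1.0, name="external")  # defined outside the cluster
    # Inside the context: wrap the external op's first output in Identity,
    # then depend on the Identity rather than on the raw external op.
    bridge = tf.identity(external).op
    internal = tf.constant(2.0)
    internal.op._add_control_inputs([bridge])  # pylint: disable=protected-access
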
@@ -319,6 +383,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# grad_state should be as if this is the top-level gradient state.
return None
+ @property
+ def back_prop(self):
+ """Forwards to the enclosing while context, if any."""
+ if self.GetWhileContext():
+ return self.GetWhileContext().back_prop
+ return False
+
+ def GetControlPivot(self):
+ return self._pivot
+
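
`back_prop` is pure delegation: a TPUReplicateContext is not itself a while
context, so gradient construction asks the nearest enclosing while loop, and
gets False at top level. An equivalent sketch:

    def back_prop(ctx):
      while_ctx = ctx.GetWhileContext()
      return while_ctx.back_prop if while_ctx else False
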
def outside_compilation(computation, *args, **kwargs):
"""Builds part of a computation outside any current TPU replicate scope.
@@ -505,7 +579,9 @@ def split_compile_and_replicate(computation,
tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
cluster_name = graph.unique_name("cluster")
- context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas)
+ pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
+ context = TPUReplicateContext(
+ name=cluster_name, num_replicas=num_replicas, pivot=pivot)
try:
context.Enter()
@@ -582,6 +658,7 @@ def split_compile_and_replicate(computation,
with ops.device(t.device if t.device else core(0)):
new_output_tensors.append(array_ops.identity(t))
output_tensors = new_output_tensors
+ context.ExitResult(output_tensors)
finally:
context.report_unsupported_operations()
context.Exit()
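
Condensed, the new wiring in `split_compile_and_replicate` is: create the
pivot, hand it to the context, and register the outputs as exit values so
enclosing control flow contexts see them. A sketch (the `computation` call
is schematic):

    pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
    context = TPUReplicateContext(
        name=cluster_name, num_replicas=num_replicas, pivot=pivot)
    context.Enter()
    outputs = computation(*computation_inputs)  # replicated body
    context.ExitResult(outputs)  # mark outputs as exit values of the context
    context.Exit()
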
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py
index c3882b8a27..6bdaa528f9 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
@@ -37,7 +38,8 @@ class TPUContextTest(test.TestCase):
def testIsInContext(self):
"""Test that control_flow_util can check that we're in a TPU context."""
z1 = array_ops.identity(1)
- context = tpu.TPUReplicateContext(b"context", 1)
+ pivot = control_flow_ops.no_op()
+ context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
context.Enter()
z2 = array_ops.identity(1)
context.Exit()
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index 7d5d54e5be..ebf844d75f 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -587,24 +587,14 @@ REGISTER_SYCL_HOST_KERNEL(string);
#undef REGISTER_SYCL_HOST_KERNEL
#endif // TENSORFLOW_USE_SYCL
-// A LoopCond op has one input and one output. The input is a boolean
-// scalar representing the taken branches of the "pivot" Switch that
-// determines loop termination. As a contract, any high-level front-end
-// should always use port '0' of the "pivot" switches for loop exit.
-class LoopCondOp : public OpKernel {
- public:
- explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- context->set_output(0, context->input(0));
- }
+LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
+LoopCondOp::~LoopCondOp() = default;
- bool IsExpensive() override { return false; }
-
- ~LoopCondOp() override {}
+void LoopCondOp::Compute(OpKernelContext* context) {
+ context->set_output(0, context->input(0));
+}
- TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp);
-};
+bool LoopCondOp::IsExpensive() { return false; }
REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp);
REGISTER_KERNEL_BUILDER(Name("LoopCond")
diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h
index 4838f2e2bf..8edbcc9077 100644
--- a/tensorflow/core/kernels/control_flow_ops.h
+++ b/tensorflow/core/kernels/control_flow_ops.h
@@ -97,6 +97,22 @@ class NextIterationOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp);
};
+// A LoopCond op has one input and one output. The input is a boolean
+// scalar representing the taken branches of the "pivot" Switch that
+// determines loop termination. As a contract, any high-level front-end
+// should always use port '0' of the "pivot" switches for loop exit.
+class LoopCondOp : public OpKernel {
+ public:
+ explicit LoopCondOp(OpKernelConstruction* context);
+ ~LoopCondOp() override;
+
+ void Compute(OpKernelContext* context) override;
+
+ bool IsExpensive() override;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp);
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
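
Hoisting LoopCondOp into control_flow_ops.h is what lets xla_device_ops.h
name it in the registration macro above; the kernel's behavior is untouched
by the move. That behavior is a plain identity on a boolean scalar, as a
Python sketch makes explicit:

    def loop_cond(pred):
      # pred is the taken-branch value of the "pivot" Switch; by contract,
      # port 0 of the pivot switches is the loop exit.
      return pred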