about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/framework/memory_types.cc
diff options
context:
space:
mode:
author: Akshay Agrawal <akshayka@google.com> 2018-07-03 12:21:08 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-07-03 12:24:24 -0700
commit 962c639b27b40afafdc41d7bffca2ee1d2ccd1cf (patch)
tree d8e69f2b7707b5b32ffcea11844c61cc2c499007 /tensorflow/core/framework/memory_types.cc
parent 69fe072cb467c0723e8c5266c8b32288fb3104a8 (diff)
Make functions defined with tfe.defun respect devices when executing.
Modifies GraphModeFunction to emit PartitionedCall ops instead of Call ops so that the created functions can execute across devices. This should strictly increase the set of functions that tfe.defun can faithfully execute. Previous to this change, functions executed through tfe.defun would ignore device annotations and only run on a single device. It is not yet possible to execute a function across multiple processes. Specifically, this CL: (1) Adds a stateful version of PartitionedCall, (2) Modifies `defun` to emit PartitionedCall or StatefulPartitionedCall by default, (3) Makes `tf.gradients` aware of the existence of `(Stateful)PartitionedCall`, (4) Fixes bugs in PartitionedCallOp related to the placement of resource-touching ops / which args and retvals are always on host memory, and also removes the requirement for args/retvals to be passed through the host. PiperOrigin-RevId: 203164388
Diffstat (limited to 'tensorflow/core/framework/memory_types.cc')
-rw-r--r-- tensorflow/core/framework/memory_types.cc 11
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc
index 270118bb67..6dff6fe654 100644
--- a/tensorflow/core/framework/memory_types.cc
+++ b/tensorflow/core/framework/memory_types.cc
@@ -60,13 +60,18 @@ void MemoryTypesHelper(const NameRangeMap& name_map,
host_memory_args->resize(keep);
}
+bool IsFunctionCallOp(const string& op_type) {
+ return op_type == "SymbolicGradient" || op_type == "PartitionedCall" ||
+ op_type == "StatefulPartitionedCall";
+}
+
+} // namespace
+
MemoryType MTypeFromDType(const DataType dtype) {
return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY
: DEVICE_MEMORY;
}
-} // namespace
-
Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
const DeviceType& device_type, const NodeDef& ndef,
MemoryTypeVector* inp_mtypes,
@@ -94,7 +99,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
// TODO(zhifengc,phawkins): We should do type inference over function bodies
// to derive the correct input/output memory types. We should also split
// host-memory and non host-memory arguments into separate type lists.
- if (!status.ok() || ndef.op() == "SymbolicGradient") {
+ if (!status.ok() || IsFunctionCallOp(ndef.op())) {
for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
return Status::OK();