diff options
author | 2018-07-03 12:21:08 -0700 | |
---|---|---|
committer | 2018-07-03 12:24:24 -0700 | |
commit | 962c639b27b40afafdc41d7bffca2ee1d2ccd1cf (patch) | |
tree | d8e69f2b7707b5b32ffcea11844c61cc2c499007 /tensorflow/core/framework/memory_types.cc | |
parent | 69fe072cb467c0723e8c5266c8b32288fb3104a8 (diff) |
Make functions defined with tfe.defun respect devices when executing.
Modifies GraphModeFunction to emit PartitionedCall ops instead of Call ops
so that the created functions can execute across devices. This should strictly
increase the set of functions that tfe.defun can faithfully execute.
Previous to this change, functions executed through tfe.defun would ignore
device annotations and only run on a single device. It is not yet possible to execute
a function across multiple processes.
Specifically, this CL:
(1) Adds a stateful version of PartitionedCall,
(2) Modifies `defun` to emit PartitionedCall or StatefulPartitionedCall by default,
(3) Makes `tf.gradients` aware of the existence of `(Stateful)PartitionedCall`,
(4) Fixes bugs in PartitionedCallOp related to the placement of
resource-touching ops / which args and retvals are always on host memory, and
also removes the requirement for args/retvals to be passed through the host.
PiperOrigin-RevId: 203164388
Diffstat (limited to 'tensorflow/core/framework/memory_types.cc')
-rw-r--r-- | tensorflow/core/framework/memory_types.cc | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 270118bb67..6dff6fe654 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -60,13 +60,18 @@ void MemoryTypesHelper(const NameRangeMap& name_map, host_memory_args->resize(keep); } +bool IsFunctionCallOp(const string& op_type) { + return op_type == "SymbolicGradient" || op_type == "PartitionedCall" || + op_type == "StatefulPartitionedCall"; +} + +} // namespace + MemoryType MTypeFromDType(const DataType dtype) { return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY : DEVICE_MEMORY; } -} // namespace - Status MemoryTypesForNode(const OpRegistryInterface* op_registry, const DeviceType& device_type, const NodeDef& ndef, MemoryTypeVector* inp_mtypes, @@ -94,7 +99,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, // TODO(zhifengc,phawkins): We should do type inference over function bodies // to derive the correct input/output memory types. We should also split // host-memory and non host-memory arguments into separate type lists. - if (!status.ok() || ndef.op() == "SymbolicGradient") { + if (!status.ok() || IsFunctionCallOp(ndef.op())) { for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t)); for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t)); return Status::OK(); |