diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-26 12:33:17 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-26 12:37:28 -0800 |
commit | da492741630f62bfd4f8475fa532ef216f0d2bfd (patch) | |
tree | 8be78847bc98b3a4833b0f3280a6f1d1aca5b593 | |
parent | 1120deaf0bf5a51db5351c12b548994b35ba71c8 (diff) |
Maintain a cache of output dtypes of ops in TFE_Context.
PiperOrigin-RevId: 187062992
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 20 | ||||
-rw-r--r-- | tensorflow/c/eager/runtime.cc | 15 | ||||
-rw-r--r-- | tensorflow/c/eager/runtime.h | 6 |
3 files changed, 38 insertions, 3 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index c27a7129fa..bebb63c746 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" @@ -823,6 +824,25 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, delete kernel; return; } + // Update output_dtypes inside `kernel`. + const tensorflow::OpDef* op_def = nullptr; + const tensorflow::FunctionDef* function_def = + ctx->func_lib_def.Find(ndef.op()); + if (function_def != nullptr) { + op_def = &(function_def->signature()); + } + if (op_def == nullptr) { + status->status = OpDefForOp(ndef.op().c_str(), &op_def); + if (!status->status.ok()) { + return; + } + } + tensorflow::DataTypeVector input_dtypes; + status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes, + kernel->output_dtypes()); + if (!status->status.ok()) { + return; + } tensorflow::mutex_lock ml(ctx->cache_mu); tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); } diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index f77a937f1f..4bf24fec2c 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -41,17 +41,26 @@ const uint32 kIsList = 1U << 31; } // namespace +Status OpDefForOp(const char* op_name, const OpDef** op_def) { + const OpRegistrationData* op_reg_data = nullptr; + Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); + if (s.ok()) { + *op_def = &op_reg_data->op_def; + } + return s; +} + Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { mutex_lock l(g_op_name_to_attr_type_map_lock); *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name); if (*out != nullptr) return Status::OK(); - const OpRegistrationData* op_reg_data = nullptr; - Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); + const OpDef* op_def = nullptr; + Status s = OpDefForOp(op_name, &op_def); if (!s.ok()) return s; std::unique_ptr<AttrTypeMap> m(new AttrTypeMap); // TODO(agarwal): Avoid having to create this "registry" at runtime, // perhaps can be done at op registration time? - for (const auto& attr : op_reg_data->op_def.attr()) { + for (const auto& attr : op_def->attr()) { string type = attr.type(); const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0); if (is_list) { diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index 4d20b5244a..7fede4dae9 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -39,6 +39,9 @@ namespace tensorflow { // represent the TF_AttrType type of the values in the list. typedef std::unordered_map<string, uint32> AttrTypeMap; +// Look up OpDef for `op_name`. +Status OpDefForOp(const char* op_name, const OpDef** op_def); + // Returns the AttrTypeMap for the TensorFlow operation named op_name. Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out); @@ -180,12 +183,15 @@ class KernelAndDevice { const OpKernel* kernel() const { return kernel_.get(); } + DataTypeVector* output_dtypes() { return &output_dtypes_; } + private: std::unique_ptr<OpKernel> kernel_; Device* device_; FunctionLibraryRuntime* flib_; checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; Rendezvous* rendez_; + DataTypeVector output_dtypes_; }; } // namespace tensorflow |