aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-26 12:33:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 12:37:28 -0800
commitda492741630f62bfd4f8475fa532ef216f0d2bfd (patch)
tree8be78847bc98b3a4833b0f3280a6f1d1aca5b593
parent1120deaf0bf5a51db5351c12b548994b35ba71c8 (diff)
Maintain a cache of output dtypes of ops in TFE_Context.
PiperOrigin-RevId: 187062992
-rw-r--r--tensorflow/c/eager/c_api.cc20
-rw-r--r--tensorflow/c/eager/runtime.cc15
-rw-r--r--tensorflow/c/eager/runtime.h6
3 files changed, 38 insertions, 3 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index c27a7129fa..bebb63c746 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
@@ -823,6 +824,25 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
delete kernel;
return;
}
+ // Update output_dtypes inside `kernel`.
+ const tensorflow::OpDef* op_def = nullptr;
+ const tensorflow::FunctionDef* function_def =
+ ctx->func_lib_def.Find(ndef.op());
+ if (function_def != nullptr) {
+ op_def = &(function_def->signature());
+ }
+ if (op_def == nullptr) {
+ status->status = OpDefForOp(ndef.op().c_str(), &op_def);
+ if (!status->status.ok()) {
+ return;
+ }
+ }
+ tensorflow::DataTypeVector input_dtypes;
+ status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes,
+ kernel->output_dtypes());
+ if (!status->status.ok()) {
+ return;
+ }
tensorflow::mutex_lock ml(ctx->cache_mu);
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
}
diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc
index f77a937f1f..4bf24fec2c 100644
--- a/tensorflow/c/eager/runtime.cc
+++ b/tensorflow/c/eager/runtime.cc
@@ -41,17 +41,26 @@ const uint32 kIsList = 1U << 31;
} // namespace
+Status OpDefForOp(const char* op_name, const OpDef** op_def) {
+ const OpRegistrationData* op_reg_data = nullptr;
+ Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
+ if (s.ok()) {
+ *op_def = &op_reg_data->op_def;
+ }
+ return s;
+}
+
Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
mutex_lock l(g_op_name_to_attr_type_map_lock);
*out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
if (*out != nullptr) return Status::OK();
- const OpRegistrationData* op_reg_data = nullptr;
- Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
+ const OpDef* op_def = nullptr;
+ Status s = OpDefForOp(op_name, &op_def);
if (!s.ok()) return s;
std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
// TODO(agarwal): Avoid having to create this "registry" at runtime,
// perhaps can be done at op registration time?
- for (const auto& attr : op_reg_data->op_def.attr()) {
+ for (const auto& attr : op_def->attr()) {
string type = attr.type();
const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
if (is_list) {
diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h
index 4d20b5244a..7fede4dae9 100644
--- a/tensorflow/c/eager/runtime.h
+++ b/tensorflow/c/eager/runtime.h
@@ -39,6 +39,9 @@ namespace tensorflow {
// represent the TF_AttrType type of the values in the list.
typedef std::unordered_map<string, uint32> AttrTypeMap;
+// Look up OpDef for `op_name`.
+Status OpDefForOp(const char* op_name, const OpDef** op_def);
+
// Returns the AttrTypeMap for the TensorFlow operation named op_name.
Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);
@@ -180,12 +183,15 @@ class KernelAndDevice {
const OpKernel* kernel() const { return kernel_.get(); }
+ DataTypeVector* output_dtypes() { return &output_dtypes_; }
+
private:
std::unique_ptr<OpKernel> kernel_;
Device* device_;
FunctionLibraryRuntime* flib_;
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
Rendezvous* rendez_;
+ DataTypeVector output_dtypes_;
};
} // namespace tensorflow