Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--  tensorflow/core/framework/allocation_description.proto | 15
-rw-r--r--  tensorflow/core/framework/allocator.cc | 25
-rw-r--r--  tensorflow/core/framework/allocator.h | 132
-rw-r--r--  tensorflow/core/framework/allocator_test.cc | 61
-rw-r--r--  tensorflow/core/framework/attr_value.proto | 57
-rw-r--r--  tensorflow/core/framework/attr_value_util.cc | 382
-rw-r--r--  tensorflow/core/framework/attr_value_util.h | 83
-rw-r--r--  tensorflow/core/framework/attr_value_util_test.cc | 91
-rw-r--r--  tensorflow/core/framework/bfloat16.cc | 22
-rw-r--r--  tensorflow/core/framework/bfloat16.h | 58
-rw-r--r--  tensorflow/core/framework/bfloat16_test.cc | 69
-rw-r--r--  tensorflow/core/framework/cancellation.cc | 79
-rw-r--r--  tensorflow/core/framework/cancellation.h | 121
-rw-r--r--  tensorflow/core/framework/cancellation_test.cc | 102
-rw-r--r--  tensorflow/core/framework/config.proto | 61
-rw-r--r--  tensorflow/core/framework/control_flow.h | 43
-rw-r--r--  tensorflow/core/framework/device_attributes.proto | 35
-rw-r--r--  tensorflow/core/framework/device_base.cc | 7
-rw-r--r--  tensorflow/core/framework/device_base.h | 172
-rw-r--r--  tensorflow/core/framework/fake_input.cc | 214
-rw-r--r--  tensorflow/core/framework/fake_input.h | 25
-rw-r--r--  tensorflow/core/framework/function.cc | 878
-rw-r--r--  tensorflow/core/framework/function.h | 376
-rw-r--r--  tensorflow/core/framework/function.proto | 68
-rw-r--r--  tensorflow/core/framework/function_test.cc | 634
-rw-r--r--  tensorflow/core/framework/function_testlib.cc | 146
-rw-r--r--  tensorflow/core/framework/function_testlib.h | 53
-rw-r--r--  tensorflow/core/framework/graph.proto | 103
-rw-r--r--  tensorflow/core/framework/graph_def_util.cc | 25
-rw-r--r--  tensorflow/core/framework/graph_def_util.h | 29
-rw-r--r--  tensorflow/core/framework/kernel_def.proto | 33
-rw-r--r--  tensorflow/core/framework/kernel_def_builder.cc | 47
-rw-r--r--  tensorflow/core/framework/kernel_def_builder.h | 77
-rw-r--r--  tensorflow/core/framework/kernel_def_builder_test.cc | 76
-rw-r--r--  tensorflow/core/framework/lookup_interface.cc | 45
-rw-r--r--  tensorflow/core/framework/lookup_interface.h | 65
-rw-r--r--  tensorflow/core/framework/node_def_builder.cc | 194
-rw-r--r--  tensorflow/core/framework/node_def_builder.h | 176
-rw-r--r--  tensorflow/core/framework/node_def_builder_test.cc | 1036
-rw-r--r--  tensorflow/core/framework/node_def_util.cc | 414
-rw-r--r--  tensorflow/core/framework/node_def_util.h | 157
-rw-r--r--  tensorflow/core/framework/node_def_util_test.cc | 442
-rw-r--r--  tensorflow/core/framework/numeric_op.h | 96
-rw-r--r--  tensorflow/core/framework/numeric_types.h | 15
-rw-r--r--  tensorflow/core/framework/op.cc | 135
-rw-r--r--  tensorflow/core/framework/op.h | 122
-rw-r--r--  tensorflow/core/framework/op_def.proto | 142
-rw-r--r--  tensorflow/core/framework/op_def_builder.cc | 447
-rw-r--r--  tensorflow/core/framework/op_def_builder.h | 109
-rw-r--r--  tensorflow/core/framework/op_def_builder_test.cc | 519
-rw-r--r--  tensorflow/core/framework/op_def_util.cc | 344
-rw-r--r--  tensorflow/core/framework/op_def_util.h | 32
-rw-r--r--  tensorflow/core/framework/op_def_util_test.cc | 330
-rw-r--r--  tensorflow/core/framework/op_gen_lib.cc | 55
-rw-r--r--  tensorflow/core/framework/op_gen_lib.h | 24
-rw-r--r--  tensorflow/core/framework/op_kernel.cc | 749
-rw-r--r--  tensorflow/core/framework/op_kernel.h | 1250
-rw-r--r--  tensorflow/core/framework/op_kernel_test.cc | 803
-rw-r--r--  tensorflow/core/framework/op_segment.cc | 86
-rw-r--r--  tensorflow/core/framework/op_segment.h | 67
-rw-r--r--  tensorflow/core/framework/op_segment_test.cc | 142
-rw-r--r--  tensorflow/core/framework/queue_interface.h | 77
-rw-r--r--  tensorflow/core/framework/reader_interface.h | 66
-rw-r--r--  tensorflow/core/framework/reader_op_kernel.cc | 39
-rw-r--r--  tensorflow/core/framework/reader_op_kernel.h | 42
-rw-r--r--  tensorflow/core/framework/register_types.h | 90
-rw-r--r--  tensorflow/core/framework/rendezvous.cc | 263
-rw-r--r--  tensorflow/core/framework/rendezvous.h | 102
-rw-r--r--  tensorflow/core/framework/rendezvous_test.cc | 314
-rw-r--r--  tensorflow/core/framework/resource_mgr.cc | 146
-rw-r--r--  tensorflow/core/framework/resource_mgr.h | 280
-rw-r--r--  tensorflow/core/framework/resource_mgr_test.cc | 173
-rw-r--r--  tensorflow/core/framework/step_stats.proto | 58
-rw-r--r--  tensorflow/core/framework/summary.proto | 67
-rw-r--r--  tensorflow/core/framework/tensor.cc | 570
-rw-r--r--  tensorflow/core/framework/tensor.proto | 57
-rw-r--r--  tensorflow/core/framework/tensor_description.proto | 19
-rw-r--r--  tensorflow/core/framework/tensor_shape.cc | 138
-rw-r--r--  tensorflow/core/framework/tensor_shape.proto | 29
-rw-r--r--  tensorflow/core/framework/tensor_shape_test.cc | 75
-rw-r--r--  tensorflow/core/framework/tensor_slice.cc | 226
-rw-r--r--  tensorflow/core/framework/tensor_slice.h | 189
-rw-r--r--  tensorflow/core/framework/tensor_slice.proto | 34
-rw-r--r--  tensorflow/core/framework/tensor_slice_test.cc | 246
-rw-r--r--  tensorflow/core/framework/tensor_test.cc | 551
-rw-r--r--  tensorflow/core/framework/tensor_testutil.cc | 43
-rw-r--r--  tensorflow/core/framework/tensor_testutil.h | 189
-rw-r--r--  tensorflow/core/framework/tensor_types.h | 92
-rw-r--r--  tensorflow/core/framework/tensor_util.cc | 28
-rw-r--r--  tensorflow/core/framework/tensor_util.h | 21
-rw-r--r--  tensorflow/core/framework/tensor_util_test.cc | 124
-rw-r--r--  tensorflow/core/framework/tracking_allocator.cc | 100
-rw-r--r--  tensorflow/core/framework/tracking_allocator.h | 80
-rw-r--r--  tensorflow/core/framework/tracking_allocator_test.cc | 115
-rw-r--r--  tensorflow/core/framework/type_traits.h | 69
-rw-r--r--  tensorflow/core/framework/types.cc | 210
-rw-r--r--  tensorflow/core/framework/types.h | 168
-rw-r--r--  tensorflow/core/framework/types.proto | 48
-rw-r--r--  tensorflow/core/framework/types_test.cc | 117
99 files changed, 17650 insertions, 0 deletions
diff --git a/tensorflow/core/framework/allocation_description.proto b/tensorflow/core/framework/allocation_description.proto
new file mode 100644
index 0000000000..f6f4bc0126
--- /dev/null
+++ b/tensorflow/core/framework/allocation_description.proto
@@ -0,0 +1,15 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+message AllocationDescription {
+ // Total number of bytes requested
+ int64 requested_bytes = 1;
+
+ // Total number of bytes allocated if known
+ int64 allocated_bytes = 2;
+
+ // Name of the allocator used
+ string allocator_name = 3;
+};
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
new file mode 100644
index 0000000000..93f68dcccb
--- /dev/null
+++ b/tensorflow/core/framework/allocator.cc
@@ -0,0 +1,25 @@
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+Allocator::~Allocator() {}
+
+class CPUAllocator : public Allocator {
+ public:
+ ~CPUAllocator() override {}
+
+ string Name() override { return "cpu"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+ return port::aligned_malloc(num_bytes, alignment);
+ }
+
+ void DeallocateRaw(void* ptr) override { port::aligned_free(ptr); }
+};
+
+Allocator* cpu_allocator() {
+ static CPUAllocator* cpu_alloc = new CPUAllocator;
+ return cpu_alloc;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
new file mode 100644
index 0000000000..6f162a608c
--- /dev/null
+++ b/tensorflow/core/framework/allocator.h
@@ -0,0 +1,132 @@
+#ifndef TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
+#define TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
+
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <limits>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// Allocator is an abstract interface for allocating and deallocating
+// device memory.
+class Allocator {
+ public:
+ virtual ~Allocator();
+
+ // Return a string identifying this allocator
+ virtual string Name() = 0;
+
+ // Return an uninitialized block of memory that is "num_bytes" bytes
+ // in size. The returned pointer is guaranteed to be aligned to a
+ // multiple of "alignment" bytes.
+ // REQUIRES: "alignment" is a power of 2.
+ virtual void* AllocateRaw(size_t alignment, size_t num_bytes) = 0;
+
+ // Deallocate a block of memory pointed to by "ptr"
+ // REQUIRES: "ptr" was previously returned by a call to AllocateRaw
+ virtual void DeallocateRaw(void* ptr) = 0;
+
+ // Convenience functions to do typed allocation. Note that these functions
+ // do not invoke C++ constructors or destructors. May return NULL if the
+ // tensor has too many elements to represent in a single allocation.
+ template <typename T>
+ T* Allocate(size_t num_elements) {
+ // TODO(jeff): Do we need to allow clients to pass in alignment
+ // requirements?
+
+ if (num_elements > (std::numeric_limits<size_t>::max() / sizeof(T))) {
+ return NULL;
+ }
+
+ void* p = AllocateRaw(32 /* align to 32 byte boundary */,
+ sizeof(T) * num_elements);
+ return reinterpret_cast<T*>(p);
+ }
+
+ template <typename T>
+ void Deallocate(T* ptr) {
+ DeallocateRaw(ptr);
+ }
+
+ // Returns true if this allocator tracks the sizes of allocations.
+ // RequestedSize and AllocatedSize must be overridden if
+ // TracksAllocationSizes is overridden to return true.
+ virtual bool TracksAllocationSizes() { return false; }
+
+ // Returns the user-requested size of the data allocated at
+ // 'ptr'. Note that the actual buffer allocated might be larger
+ // than requested, but this function returns the size requested by
+ // the user.
+ //
+ // REQUIRES: TracksAllocationSizes() is true.
+ //
+ // REQUIRES: 'ptr!=nullptr' and points to a buffer previously
+ // allocated by this allocator.
+ virtual size_t RequestedSize(void* ptr) {
+ CHECK(false) << "allocator doesn't track sizes";
+ }
+
+ // Returns the allocated size of the buffer at 'ptr' if known,
+ // otherwise returns RequestedSize(ptr). AllocatedSize(ptr) is
+ // guaranteed to be >= RequestedSize(ptr).
+ //
+ // REQUIRES: TracksAllocationSizes() is true.
+ //
+ // REQUIRES: 'ptr!=nullptr' and points to a buffer previously
+ // allocated by this allocator.
+ virtual size_t AllocatedSize(void* ptr) { return RequestedSize(ptr); }
+
+ // TODO(jeff): Maybe provide some interface to give info about
+ // current allocation state (total number of bytes available for
+ // allocation, number of bytes free on device, etc.)
+};
+
+// A tensorflow Op may need access to different kinds of memory that
+// are not simply a function of the device to which the Op has been
+// assigned. For example, an Op executing on a GPU may still need
+// to allocate CPU RAM for some purpose. Internal to the tensorflow
+// runtime we may choose to allocate CPU ram from special regions
+// that have been prepared for higher performance in some use
+// contexts, e.g. doing DMA with particular devices. For these
+// reasons, the Device interface does not expose just one memory
+// Allocator, but instead provides an accessor that takes a
+// specification of the desired memory attributes in order to select
+// an Allocator.
+//
+// NOTE: The upper 8 bits of the value are reserved for
+// device-specific uses. Implementors of a device can interpret these
+// upper 8 bits in device-specific ways, and ops implemented for those
+// devices are responsible for setting those 8 bits appropriately.
+//
+// Example use:
+// // Allocator for ordinary device memory:
+// Allocator* a = allocator(AllocatorAttributes());
+// ...
+// // Allocator for CPU RAM, regardless of where Op is executing:
+// AllocatorAttributes attr;
+// attr.set_on_host(true);
+// Allocator* a = allocator(attr);
+struct AllocatorAttributes {
+ void set_on_host(bool v) { value |= (static_cast<int>(v)); }
+ bool on_host() const { return value & 0x1; }
+ void set_nic_compatible(bool v) { value |= (static_cast<int>(v) << 1); }
+ bool nic_compatible() const { return value & (0x1 << 1); }
+ void set_gpu_compatible(bool v) { value |= (static_cast<int>(v) << 2); }
+ bool gpu_compatible() const { return value & (0x1 << 2); }
+
+ void Merge(AllocatorAttributes other) { value |= other.value; }
+
+ uint32 value = 0;
+};
+
+// Returns a trivial implementation of Allocator which uses the system
+// default malloc.
+Allocator* cpu_allocator();
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
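
The header above is the whole contract, so a short usage sketch may help. This is illustrative only and not part of the commit; it assumes nothing beyond the declarations in allocator.h (AllocatorUsageSketch is a made-up name).

// Sketch: typed allocation plus AllocatorAttributes composition.
#include "tensorflow/core/framework/allocator.h"

namespace tensorflow {

void AllocatorUsageSketch() {
  Allocator* a = cpu_allocator();

  // Allocate<T> runs no constructors and returns NULL when
  // num_elements * sizeof(T) would overflow size_t, so check the result.
  float* buf = a->Allocate<float>(1024);
  if (buf != nullptr) {
    buf[0] = 1.0f;
    a->Deallocate(buf);
  }

  // Attribute bits: on_host is bit 0, nic_compatible bit 1, gpu_compatible
  // bit 2; Merge() ORs two attribute sets together.
  AllocatorAttributes attr;
  attr.set_on_host(true);
  AllocatorAttributes nic;
  nic.set_nic_compatible(true);
  attr.Merge(nic);  // attr.on_host() and attr.nic_compatible() both true now
}

}  // namespace tensorflow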
diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc
new file mode 100644
index 0000000000..6b1e52cfc4
--- /dev/null
+++ b/tensorflow/core/framework/allocator_test.cc
@@ -0,0 +1,61 @@
+#include "tensorflow/core/framework/allocator.h"
+#include <algorithm>
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+namespace tensorflow {
+
+TEST(CPUAllocatorTest, Simple) {
+ Allocator* a = cpu_allocator();
+ std::vector<void*> ptrs;
+ for (int s = 1; s < 1024; s++) {
+ void* raw = a->AllocateRaw(1, s);
+ ptrs.push_back(raw);
+ }
+ std::sort(ptrs.begin(), ptrs.end());
+ for (size_t i = 0; i < ptrs.size(); i++) {
+ if (i > 0) {
+ CHECK_NE(ptrs[i], ptrs[i - 1]); // No dups
+ }
+ a->DeallocateRaw(ptrs[i]);
+ }
+ float* t1 = a->Allocate<float>(1024);
+ double* t2 = a->Allocate<double>(1048576);
+ a->Deallocate(t1);
+ a->Deallocate(t2);
+}
+
+// Define a struct that we will use to observe behavior in the unit tests
+struct TestStruct {
+ int x; // not used; just ensures sizeof(TestStruct) > 1
+};
+
+TEST(CPUAllocatorTest, CheckStructSize) { CHECK_GT(sizeof(TestStruct), 1); }
+
+TEST(CPUAllocatorTest, AllocateOverflowMaxSizeT) {
+ Allocator* a = cpu_allocator();
+
+ // The maximum size_t value will definitely overflow.
+ size_t count_to_allocate = std::numeric_limits<size_t>::max();
+ TestStruct* const test_pointer = a->Allocate<TestStruct>(count_to_allocate);
+
+ CHECK_EQ(test_pointer, reinterpret_cast<TestStruct*>(NULL));
+}
+
+TEST(CPUAllocatorTest, AllocateOverflowSmallest) {
+ Allocator* a = cpu_allocator();
+
+ // count_to_allocate is the smallest count that will cause overflow.
+ const size_t count_to_allocate =
+ (std::numeric_limits<size_t>::max() / sizeof(TestStruct)) + 1;
+ TestStruct* const test_pointer = a->Allocate<TestStruct>(count_to_allocate);
+
+ CHECK_EQ(test_pointer, reinterpret_cast<TestStruct*>(NULL));
+}
+
+TEST(CPUAllocatorTest, Sizes) {
+ Allocator* a = cpu_allocator();
+
+ EXPECT_EQ(false, a->TracksAllocationSizes());
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/attr_value.proto b/tensorflow/core/framework/attr_value.proto
new file mode 100644
index 0000000000..c6a9940815
--- /dev/null
+++ b/tensorflow/core/framework/attr_value.proto
@@ -0,0 +1,57 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+// Protocol buffer representing the value for an attr used to configure an Op.
+// Comment indicates the corresponding attr type. Only the field matching the
+// attr type may be filled.
+message AttrValue {
+ message ListValue {
+ repeated bytes s = 2; // "list(string)"
+ repeated int64 i = 3 [packed = true]; // "list(int)"
+ repeated float f = 4 [packed = true]; // "list(float)"
+ repeated bool b = 5 [packed = true]; // "list(bool)"
+ repeated DataType type = 6 [packed = true]; // "list(type)"
+ repeated TensorShapeProto shape = 7; // "list(shape)"
+ repeated TensorProto tensor = 8; // "list(tensor)"
+ // TODO(zhifengc/josh11b): implement list(func) if needed.
+ }
+
+ oneof value {
+ bytes s = 2; // "string"
+ int64 i = 3; // "int"
+ float f = 4; // "float"
+ bool b = 5; // "bool"
+ DataType type = 6; // "type"
+ TensorShapeProto shape = 7; // "shape"
+ TensorProto tensor = 8; // "tensor"
+ ListValue list = 1; // any "list(...)"
+
+ // "func" represents a function. func.name is a function's name or
+ // a primitive op's name. func.attr.first is the name of an attr
+ // defined for that function. func.attr.second is the value for
+ // that attr in the instantiation.
+ NameAttrList func = 10;
+
+ // This is a placeholder only used in nodes defined inside a
+ // function. It indicates the attr value will be supplied when
+ // the function is instantiated. For example, let us suppose a
+ // node "N" in function "FN". "N" has an attr "A" with value
+ // placeholder = "foo". When FN is instantiated with attr "foo"
+ // set to "bar", the instantiated node N's attr A will have been
+ // given the value "bar".
+ string placeholder = 9;
+ }
+}
+
+// A list of attr names and their values. The whole list is attached
+// with a string name. E.g., MatMul[T=float].
+message NameAttrList {
+ string name = 1;
+ map<string, AttrValue> attr = 2;
+}
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
new file mode 100644
index 0000000000..400ef118b8
--- /dev/null
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -0,0 +1,382 @@
+#include "tensorflow/core/framework/attr_value_util.h"
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+namespace {
+
+string SummarizeString(const string& str) {
+ return strings::StrCat("\"", str_util::CEscape(str), "\"");
+}
+
+string SummarizeShape(const TensorShapeProto& proto) {
+ TensorShape shape(proto);
+ return shape.ShortDebugString();
+}
+
+string SummarizeTensor(const TensorProto& tensor_proto) {
+ Tensor t;
+ if (!t.FromProto(tensor_proto)) {
+ return strings::StrCat("<Invalid TensorProto: ",
+ tensor_proto.ShortDebugString(), ">");
+ }
+ return t.DebugString();
+}
+
+} // namespace
+
+string SummarizeAttrValue(const AttrValue& attr_value) {
+ switch (attr_value.value_case()) {
+ case AttrValue::kS:
+ return SummarizeString(attr_value.s());
+ case AttrValue::kI:
+ return strings::StrCat(attr_value.i());
+ case AttrValue::kF:
+ return strings::StrCat(attr_value.f());
+ case AttrValue::kB:
+ return attr_value.b() ? "true" : "false";
+ case AttrValue::kType:
+ return DataType_Name(attr_value.type());
+ case AttrValue::kShape:
+ return SummarizeShape(attr_value.shape());
+ case AttrValue::kTensor:
+ return SummarizeTensor(attr_value.tensor());
+ case AttrValue::kList: {
+ string ret = "[";
+ if (attr_value.list().s_size() > 0) {
+ for (int i = 0; i < attr_value.list().s_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, SummarizeString(attr_value.list().s(i)));
+ }
+ } else if (attr_value.list().i_size() > 0) {
+ for (int i = 0; i < attr_value.list().i_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, attr_value.list().i(i));
+ }
+ } else if (attr_value.list().f_size() > 0) {
+ for (int i = 0; i < attr_value.list().f_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, attr_value.list().f(i));
+ }
+ } else if (attr_value.list().b_size() > 0) {
+ for (int i = 0; i < attr_value.list().b_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false");
+ }
+ } else if (attr_value.list().type_size() > 0) {
+ for (int i = 0; i < attr_value.list().type_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, DataType_Name(attr_value.list().type(i)));
+ }
+ } else if (attr_value.list().shape_size() > 0) {
+ for (int i = 0; i < attr_value.list().shape_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, SummarizeShape(attr_value.list().shape(i)));
+ }
+ } else if (attr_value.list().tensor_size() > 0) {
+ for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret,
+ SummarizeTensor(attr_value.list().tensor(i)));
+ }
+ }
+ strings::StrAppend(&ret, "]");
+ return ret;
+ }
+ case AttrValue::kFunc: {
+ std::vector<string> entries;
+ for (auto p : attr_value.func().attr()) {
+ entries.push_back(
+ strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
+ }
+ sort(entries.begin(), entries.end());
+ return strings::StrCat(attr_value.func().name(), "[",
+ str_util::Join(entries, ", "), "]");
+ }
+ case AttrValue::kPlaceholder:
+ return strings::StrCat("$", attr_value.placeholder());
+ case AttrValue::VALUE_NOT_SET:
+ return "<Unknown AttrValue type>";
+ }
+ return "<Unknown AttrValue type>"; // Prevent missing return warning
+}
+
+Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
+ int num_set = 0;
+
+#define VALIDATE_FIELD(name, type_string, oneof_case) \
+ do { \
+ if (attr_value.has_list()) { \
+ if (attr_value.list().name##_size() > 0) { \
+ if (type != "list(" type_string ")") { \
+ return errors::InvalidArgument( \
+ "AttrValue had value with type list(" type_string ") when ", \
+ type, " expected"); \
+ } \
+ ++num_set; \
+ } \
+ } else if (attr_value.value_case() == AttrValue::oneof_case) { \
+ if (type != type_string) { \
+ return errors::InvalidArgument( \
+ "AttrValue had value with type " type_string " when ", type, \
+ " expected"); \
+ } \
+ ++num_set; \
+ } \
+ } while (false)
+
+ VALIDATE_FIELD(s, "string", kS);
+ VALIDATE_FIELD(i, "int", kI);
+ VALIDATE_FIELD(f, "float", kF);
+ VALIDATE_FIELD(b, "bool", kB);
+ VALIDATE_FIELD(type, "type", kType);
+ VALIDATE_FIELD(shape, "shape", kShape);
+ VALIDATE_FIELD(tensor, "tensor", kTensor);
+
+#undef VALIDATE_FIELD
+
+ if (attr_value.value_case() == AttrValue::kFunc) {
+ if (type != "func") {
+ return errors::InvalidArgument(
+ "AttrValue had value with type 'func' when ", type, " expected");
+ }
+ ++num_set;
+ }
+
+ if (attr_value.value_case() == AttrValue::kPlaceholder) {
+ return errors::InvalidArgument(
+ "AttrValue had value with unexpected type 'placeholder");
+ }
+
+ // If the attr type is 'list', we expect attr_value.has_list() to be true.
+ // However, proto3's attr_value.has_list() can be false when the list is
+ // empty. So when has_list() is false, we flag an error only if some other
+ // field in attr_value is set; otherwise we treat it as an empty list.
+ if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) {
+ if (num_set) {
+ return errors::InvalidArgument(
+ "AttrValue missing value with expected type ", type);
+ } else {
+ // Indicate that we have a list, but an empty one.
+ ++num_set;
+ }
+ }
+
+ // Okay to have an empty list, but not to be missing a non-list value.
+ if (num_set == 0 && !StringPiece(type).starts_with("list(")) {
+ return errors::InvalidArgument(
+ "AttrValue missing value with expected type ", type);
+ }
+
+ // Ref types and DT_INVALID are illegal.
+ if (type == "type") {
+ if (IsRefType(attr_value.type())) {
+ return errors::InvalidArgument(
+ "AttrValue must not have reference type value of ",
+ DataTypeString(attr_value.type()));
+ }
+ if (attr_value.type() == DT_INVALID) {
+ return errors::InvalidArgument("AttrValue has invalid DataType");
+ }
+ } else if (type == "list(type)") {
+ for (auto as_int : attr_value.list().type()) {
+ const DataType dtype = static_cast<DataType>(as_int);
+ if (IsRefType(dtype)) {
+ return errors::InvalidArgument(
+ "AttrValue must not have reference type value of ",
+ DataTypeString(dtype));
+ }
+ if (dtype == DT_INVALID) {
+ return errors::InvalidArgument("AttrValue contains invalid DataType");
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
+ // Parse type.
+ string field_name;
+ bool is_list = type.Consume("list(");
+ if (type.Consume("string")) {
+ field_name = "s";
+ } else if (type.Consume("int")) {
+ field_name = "i";
+ } else if (type.Consume("float")) {
+ field_name = "f";
+ } else if (type.Consume("bool")) {
+ field_name = "b";
+ } else if (type.Consume("type")) {
+ field_name = "type";
+ } else if (type.Consume("shape")) {
+ field_name = "shape";
+ } else if (type.Consume("tensor")) {
+ field_name = "tensor";
+ } else if (type.Consume("func")) {
+ field_name = "func";
+ } else if (type.Consume("placeholder")) {
+ field_name = "placeholder";
+ } else {
+ return false;
+ }
+ if (is_list && !type.Consume(")")) {
+ return false;
+ }
+
+ // Construct a valid text proto message to parse.
+ string to_parse;
+ if (is_list) {
+ // TextFormat parser considers "i: 7" to be the same as "i: [7]",
+ // but we only want to allow list values with [].
+ if (!RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[.*\\]\\s*")) {
+ return false;
+ }
+ if (RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[\\s*\\]\\s*")) {
+ // User wrote "[]", so return empty list without invoking the TextFormat
+ // parse which returns an error for "i: []".
+ out->Clear();
+ out->mutable_list();
+ return true;
+ }
+ to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
+ } else {
+ to_parse = strings::StrCat(field_name, ": ", text);
+ }
+
+ // Parse if we can.
+ return protobuf::TextFormat::ParseFromString(to_parse, out);
+}
+
+#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
+ void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
+
+#define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD) \
+ void SetAttrValue(ARG_TYPE value, AttrValue* out) { \
+ out->mutable_list(); /* create list() even if value empty */ \
+ for (const auto& v : value) { \
+ out->mutable_list()->add_##FIELD(v); \
+ } \
+ }
+
+#define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
+ DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
+ DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)
+
+DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
+DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
+DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
+DEFINE_SET_ATTR_VALUE_BOTH(int64, i)
+DEFINE_SET_ATTR_VALUE_BOTH(int32, i)
+DEFINE_SET_ATTR_VALUE_BOTH(float, f)
+DEFINE_SET_ATTR_VALUE_BOTH(double, f)
+DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
+DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
+DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
+DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
+
+void SetAttrValue(StringPiece value, AttrValue* out) {
+ out->set_s(value.data(), value.size());
+}
+
+void SetAttrValue(const TensorShape& value, AttrValue* out) {
+ value.AsProto(out->mutable_shape());
+}
+
+void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
+ out->mutable_list(); // Create list() even if value empty.
+ for (const auto& v : value) {
+ v.AsProto(out->mutable_list()->add_shape());
+ }
+}
+
+void SetAttrValue(const Tensor& value, AttrValue* out) {
+ if (value.NumElements() > 1) {
+ value.AsProtoTensorContent(out->mutable_tensor());
+ } else {
+ value.AsProtoField(out->mutable_tensor());
+ }
+}
+
+void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
+ out->mutable_list(); // Create list() even if value empty.
+ for (const auto& v : value) {
+ if (v.NumElements() > 1) {
+ v.AsProtoTensorContent(out->mutable_list()->add_tensor());
+ } else {
+ v.AsProtoField(out->mutable_list()->add_tensor());
+ }
+ }
+}
+
+void SetAttrValue(const TensorProto& value, AttrValue* out) {
+ *out->mutable_tensor() = value;
+}
+
+void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
+ out->mutable_list(); // Create list() even if value empty.
+ for (const auto& v : value) {
+ *out->mutable_list()->add_tensor() = v;
+ }
+}
+
+void SetAttrValue(const NameAttrList& value, AttrValue* out) {
+ *out->mutable_func() = value;
+}
+
+bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
+ string a_str, b_str;
+ a.SerializeToString(&a_str);
+ b.SerializeToString(&b_str);
+ // Note: it should be safe to compare proto serializations of the attr
+ // values since at most one field should be set in each (indeed, it
+ // must be the same field if they are to compare equal).
+ // Exception: there are multiple equivalent representations of
+ // TensorProtos. So a return value of true implies a == b, but not the
+ // converse.
+ return a_str == b_str;
+}
+
+bool HasPlaceHolder(const AttrValue& val) {
+ switch (val.value_case()) {
+ case AttrValue::kFunc:
+ for (const auto& p : val.func().attr()) {
+ if (HasPlaceHolder(p.second)) {
+ return true;
+ }
+ }
+ break;
+ case AttrValue::kPlaceholder:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+bool SubstitutePlaceholders(SubstituteFunc substitute, AttrValue* value) {
+ switch (value->value_case()) {
+ case AttrValue::kFunc:
+ for (auto& p : *(value->mutable_func()->mutable_attr())) {
+ if (!SubstitutePlaceholders(substitute, &p.second)) {
+ return false;
+ }
+ }
+ break;
+ case AttrValue::kPlaceholder:
+ return substitute(value->placeholder(), value);
+ case AttrValue::VALUE_NOT_SET:
+ return false;
+ default:
+ break;
+ }
+ return true;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h
new file mode 100644
index 0000000000..1faf74a327
--- /dev/null
+++ b/tensorflow/core/framework/attr_value_util.h
@@ -0,0 +1,83 @@
+#ifndef TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+
+#include <string>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+// A human-readable rendering of attr_value that is more concise than a
+// text-format proto.
+string SummarizeAttrValue(const AttrValue& attr_value);
+
+// Generates an error if attr_value doesn't have the indicated attr type.
+Status AttrValueHasType(const AttrValue& attr_value, StringPiece type);
+
+// Converts a text proto value from "text" into the field of *out
+// indicated by "type" (e.g. from the type field of an AttrDef).
+// Examples:
+// * If type:"int" and text:"-14", then *out is set to "i: -14"
+// * If type:"list(string)" and text:"['foo', 'bar']",
+// then *out is set to "list { s: ['foo', 'bar'] }"
+// Returns true on success.
+bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out);
+
+// Sets *out based on the type of value.
+void SetAttrValue(const string& value, AttrValue* out);
+void SetAttrValue(const char* value, AttrValue* out);
+void SetAttrValue(StringPiece value, AttrValue* out);
+void SetAttrValue(int64 value, AttrValue* out);
+void SetAttrValue(int32 value, AttrValue* out);
+void SetAttrValue(float value, AttrValue* out);
+void SetAttrValue(double value, AttrValue* out);
+void SetAttrValue(bool value, AttrValue* out);
+void SetAttrValue(DataType value, AttrValue* out);
+void SetAttrValue(const TensorShape& value, AttrValue* out);
+void SetAttrValue(const Tensor& value, AttrValue* out);
+void SetAttrValue(const TensorProto& value, AttrValue* out);
+void SetAttrValue(const NameAttrList& value, AttrValue* out);
+
+void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<int64> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<int32> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<float> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<double> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<bool> value, AttrValue* out);
+void SetAttrValue(const std::vector<bool>& value, AttrValue* out);
+void SetAttrValue(std::initializer_list<bool> value, AttrValue* out);
+void SetAttrValue(DataTypeSlice value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<TensorShape> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out);
+
+inline void SetAttrValue(const AttrValue& value, AttrValue* out) {
+ *out = value;
+}
+
+// Returns true if a and b have the same value.
+// NOTE: May return false negatives for tensor values.
+bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b);
+
+// Returns true if "val" has a placeholder.
+bool HasPlaceHolder(const AttrValue& val);
+
+// SubstitutePlaceholders recursively replaces placeholders in 'value'
+// with an attr value by calling SubstituteFunc. Returns true iff all
+// placeholders in "value" are replaced with a value.
+//
+// SubstituteFunc is given a placeholder string. If the placeholder is
+// unknown, SubstituteFunc returns false. Otherwise, overwrites the
+// attr value and returns true.
+typedef std::function<bool(const string&, AttrValue*)> SubstituteFunc;
+bool SubstitutePlaceholders(SubstituteFunc substitute, AttrValue* value);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
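
Since SetAttrValue is heavily overloaded and ParseAttrValue's contract is easy to misread, here is a minimal sketch (not part of the commit; AttrValueUtilSketch is a made-up name) exercising the declarations above. The expected strings follow the comments in this header and the summarizer in attr_value_util.cc.

// Sketch: building, parsing and summarizing AttrValue protos.
#include "tensorflow/core/framework/attr_value_util.h"

namespace tensorflow {

void AttrValueUtilSketch() {
  AttrValue v;
  SetAttrValue(static_cast<int64>(-14), &v);
  // SummarizeAttrValue(v) == "-14" and AttrValueHasType(v, "int").ok().

  AttrValue list;
  if (ParseAttrValue("list(string)", "['foo', 'bar']", &list)) {
    // SummarizeAttrValue(list) == "[\"foo\", \"bar\"]".
  }
}

}  // namespace tensorflow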
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc
new file mode 100644
index 0000000000..bdfbf1707a
--- /dev/null
+++ b/tensorflow/core/framework/attr_value_util_test.cc
@@ -0,0 +1,91 @@
+#include "tensorflow/core/framework/attr_value_util.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+// A few helpers to construct AttrValue protos.
+template <typename T>
+AttrValue V(T value) {
+ AttrValue ret;
+ SetAttrValue(value, &ret);
+ return ret;
+}
+
+AttrValue P(const string& p) {
+ AttrValue ret;
+ ret.set_placeholder(p);
+ return ret;
+}
+
+AttrValue F(const string& name,
+ std::vector<std::pair<string, AttrValue> > pairs) {
+ AttrValue ret;
+ ret.mutable_func()->set_name(name);
+ ret.mutable_func()->mutable_attr()->insert(pairs.begin(), pairs.end());
+ return ret;
+}
+
+TEST(AttrValueUtil, HasType) {
+ // OK
+ EXPECT_TRUE(AttrValueHasType(V(123), "int").ok());
+ EXPECT_TRUE(AttrValueHasType(V(1.2), "float").ok());
+ EXPECT_TRUE(AttrValueHasType(V(DT_FLOAT), "type").ok());
+ EXPECT_TRUE(AttrValueHasType(F("f", {}), "func").ok());
+
+ // not OK.
+ EXPECT_FALSE(AttrValueHasType(V(123), "func").ok());
+ EXPECT_FALSE(AttrValueHasType(V(1.2), "int").ok());
+ EXPECT_FALSE(AttrValueHasType(V(DT_FLOAT), "shape").ok());
+ EXPECT_FALSE(AttrValueHasType(F("f", {}), "string").ok());
+ EXPECT_FALSE(AttrValueHasType(P("T"), "float").ok());
+}
+
+SubstituteFunc ReplaceTWith(const AttrValue& val) {
+ return [val](const string& placeholder, AttrValue* target) {
+ if (placeholder == "T") {
+ *target = val;
+ return true;
+ } else {
+ return false;
+ }
+ };
+}
+
+TEST(AttrValueUtil, Basic) {
+ auto v = F("MatMul", {{"dtype", P("T")},
+ {"transpose_a", V(false)},
+ {"transpose_b", V(true)},
+ {"use_cublas", V(true)}});
+ TF_CHECK_OK(AttrValueHasType(v, "func"));
+ EXPECT_TRUE(HasPlaceHolder(v));
+
+ EXPECT_EQ(
+ SummarizeAttrValue(v),
+ "MatMul[dtype=$T, transpose_a=false, transpose_b=true, use_cublas=true]");
+
+ SubstitutePlaceholders(ReplaceTWith(V(DT_FLOAT)), &v);
+ EXPECT_TRUE(!HasPlaceHolder(v));
+ EXPECT_EQ(SummarizeAttrValue(v),
+ "MatMul[dtype=DT_FLOAT, transpose_a=false, transpose_b=true, "
+ "use_cublas=true]");
+}
+
+TEST(AttrValueUtil, DeepAttr) {
+ auto v = F("f", {{"T", P("T")}});
+ TF_CHECK_OK(AttrValueHasType(v, "func"));
+ EXPECT_TRUE(HasPlaceHolder(v));
+
+ for (int i = 0; i < 3; ++i) {
+ v = F("f", {{"T", P("T")}, {"F", v}});
+ EXPECT_TRUE(HasPlaceHolder(v));
+ }
+ EXPECT_EQ(SummarizeAttrValue(v), "f[F=f[F=f[F=f[T=$T], T=$T], T=$T], T=$T]");
+
+ SubstitutePlaceholders(ReplaceTWith(F("x", {})), &v);
+ EXPECT_TRUE(!HasPlaceHolder(v));
+ EXPECT_EQ(SummarizeAttrValue(v),
+ "f[F=f[F=f[F=f[T=x[]], T=x[]], T=x[]], T=x[]]");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/bfloat16.cc b/tensorflow/core/framework/bfloat16.cc
new file mode 100644
index 0000000000..0068283367
--- /dev/null
+++ b/tensorflow/core/framework/bfloat16.cc
@@ -0,0 +1,22 @@
+#include "tensorflow/core/framework/bfloat16.h"
+
+namespace tensorflow {
+
+void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
+ const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
+ uint16_t* q = reinterpret_cast<uint16_t*>(dst);
+ for (; size; p += 2, q++, size--) {
+ *q = p[1];
+ }
+}
+
+void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {
+ const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
+ uint16_t* q = reinterpret_cast<uint16_t*>(dst);
+ for (; size; p++, q += 2, size--) {
+ q[0] = 0;
+ q[1] = *p;
+ }
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/bfloat16.h b/tensorflow/core/framework/bfloat16.h
new file mode 100644
index 0000000000..9cd260ee13
--- /dev/null
+++ b/tensorflow/core/framework/bfloat16.h
@@ -0,0 +1,58 @@
+#ifndef TENSORFLOW_FRAMEWORK_BFLOAT16_H_
+#define TENSORFLOW_FRAMEWORK_BFLOAT16_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+// Compact 16-bit encoding of floating point numbers. This representation uses
+// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. It
+// is assumed that floats are in IEEE 754 format so the representation is just
+// bits 16-31 of a single precision float.
+//
+// NOTE: The IEEE floating point standard defines a float16 format that
+// is different than this format (it has fewer bits of exponent and more
+// bits of mantissa). We don't use that format here because conversion
+// to/from 32-bit floats is more complex for that format, and the
+// conversion for this format is very simple.
+//
+// Because of the existing IEEE float16 type, we do not name our representation
+// "float16" but just use "uint16".
+//
+// <-----our 16bits float------->
+// s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f
+// <------------------------------float-------------------------->
+// 3 3             2 2             1 1                           0
+// 1 0             3 2             5 4                           0
+//
+//
+// This type only supports conversion back and forth with float.
+//
+// This file must be compilable by nvcc.
+
+namespace tensorflow {
+struct bfloat16 {
+ EIGEN_DEVICE_FUNC bfloat16() {}
+ EIGEN_DEVICE_FUNC explicit bfloat16(const uint16_t v) : value(v) {}
+
+ uint16_t value;
+};
+
+// Conversion routines between an array of float and bfloat16 of
+// "size".
+void FloatToBFloat16(const float* src, bfloat16* dst, int64 size);
+void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size);
+
+} // namespace tensorflow
+
+namespace Eigen {
+template <>
+struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {};
+
+EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a,
+ const tensorflow::bfloat16 b) {
+ return a.value == b.value;
+}
+
+} // namespace Eigen
+
+#endif // TENSORFLOW_FRAMEWORK_BFLOAT16_H_
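
The truncation described above is easiest to see on one value. A standalone sketch (not part of the commit), assuming IEEE 754 floats; note that the array routines in bfloat16.cc read p[1], which is the high-order half of a float only on little-endian hosts.

// Sketch: the bfloat16 truncation for f = 1.25f.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  const float f = 1.25f;  // bit pattern 0x3FA00000
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));

  // Keep the sign, the 8 exponent bits and the top 7 mantissa bits.
  const uint16_t bf = static_cast<uint16_t>(bits >> 16);  // 0x3FA0

  // Converting back zero-fills the dropped 16 mantissa bits. 1.25 needs
  // only two mantissa bits, so this particular round trip is exact.
  const uint32_t back = static_cast<uint32_t>(bf) << 16;
  float g;
  std::memcpy(&g, &back, sizeof(g));
  std::printf("0x%08X -> 0x%04X -> %f\n", bits, bf, g);
  return 0;
}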
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
new file mode 100644
index 0000000000..4fe791fdeb
--- /dev/null
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -0,0 +1,69 @@
+#include "tensorflow/core/framework/bfloat16.h"
+
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+TEST(Bfloat16Test, Simple) {
+ bfloat16 a(12);
+ EXPECT_EQ(12, a.value);
+}
+
+TEST(Bfloat16Test, Conversion) {
+ float a[100];
+ for (int i = 0; i < 100; ++i) {
+ a[i] = i + 1.25;
+ }
+ bfloat16 b[100];
+ float c[100];
+ FloatToBFloat16(a, b, 100);
+ BFloat16ToFloat(b, c, 100);
+ for (int i = 0; i < 100; ++i) {
+ // The relative error should be less than 1/(2^7) since bfloat16
+ // has a 7-bit mantissa.
+ EXPECT_LE(fabs(c[i] - a[i]) / a[i], 1.0 / 128);
+ }
+}
+
+static void BM_FloatToBFloat16(int iters) {
+ testing::StopTiming();
+ static const int N = 32 << 20;
+ const int64 tot = static_cast<int64>(iters) * N;
+ testing::ItemsProcessed(tot);
+ testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
+
+ float* inp = new float[N];
+ bfloat16* out = new bfloat16[N];
+
+ testing::StartTiming();
+ while (iters--) {
+ FloatToBFloat16(inp, out, N);
+ }
+ delete[] inp;
+ delete[] out;
+}
+BENCHMARK(BM_FloatToBFloat16);
+
+static void BM_BFloat16ToFloat(int iters) {
+ testing::StopTiming();
+ static const int N = 32 << 20;
+ const int64 tot = static_cast<int64>(iters) * N;
+ testing::ItemsProcessed(tot);
+ testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
+
+ bfloat16* inp = new bfloat16[N];
+ float* out = new float[N];
+
+ testing::StartTiming();
+ while (iters--) {
+ BFloat16ToFloat(inp, out, N);
+ }
+ delete[] inp;
+ delete[] out;
+}
+BENCHMARK(BM_BFloat16ToFloat);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
new file mode 100644
index 0000000000..51423792a8
--- /dev/null
+++ b/tensorflow/core/framework/cancellation.cc
@@ -0,0 +1,79 @@
+#include "tensorflow/core/framework/cancellation.h"
+
+#include <vector>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+const CancellationToken CancellationManager::kInvalidToken = -1;
+
+CancellationManager::CancellationManager()
+ : is_cancelling_(false), is_cancelled_(false), next_cancellation_token_(0) {}
+
+void CancellationManager::StartCancel() {
+ std::unordered_map<CancellationToken, CancelCallback> callbacks_to_run;
+ {
+ mutex_lock l(mu_);
+ if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
+ return;
+ }
+ is_cancelling_ = true;
+ std::swap(callbacks_, callbacks_to_run);
+ }
+ // We call these callbacks without holding mu_, so that concurrent
+ // calls to DeregisterCallback, which can happen asynchronously, do
+ // not block. The callbacks remain valid because any concurrent call
+ // to DeregisterCallback will block until the
+ // cancelled_notification_ is notified.
+ for (auto key_and_value : callbacks_to_run) {
+ key_and_value.second();
+ }
+ {
+ mutex_lock l(mu_);
+ is_cancelling_ = false;
+ is_cancelled_.store(true, std::memory_order_release);
+ }
+ cancelled_notification_.Notify();
+}
+
+CancellationToken CancellationManager::get_cancellation_token() {
+ mutex_lock l(mu_);
+ return next_cancellation_token_++;
+}
+
+bool CancellationManager::RegisterCallback(CancellationToken token,
+ CancelCallback callback) {
+ mutex_lock l(mu_);
+ CHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token";
+ bool should_register = !is_cancelled_ && !is_cancelling_;
+ if (should_register) {
+ std::swap(callbacks_[token], callback);
+ }
+ return should_register;
+}
+
+bool CancellationManager::DeregisterCallback(CancellationToken token) {
+ mu_.lock();
+ if (is_cancelled_) {
+ mu_.unlock();
+ return false;
+ } else if (is_cancelling_) {
+ mu_.unlock();
+ // Wait for all of the cancellation callbacks to be called. This
+ // wait ensures that the caller of DeregisterCallback does not
+ // return immediately and free objects that may be used in the
+ // execution of any currently pending callbacks in StartCancel.
+ cancelled_notification_.WaitForNotification();
+ return false;
+ } else {
+ callbacks_.erase(token);
+ mu_.unlock();
+ return true;
+ }
+}
+
+CancellationManager::~CancellationManager() { StartCancel(); }
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
new file mode 100644
index 0000000000..feda548e97
--- /dev/null
+++ b/tensorflow/core/framework/cancellation.h
@@ -0,0 +1,121 @@
+#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+
+#include <atomic>
+#include <functional>
+#include <unordered_map>
+
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// A token that can be used to register and deregister a
+// CancelCallback with a CancellationManager.
+//
+// CancellationToken values must be created by a call to
+// CancellationManager::get_cancellation_token.
+typedef int64 CancellationToken;
+
+// A callback that is invoked when a step is cancelled.
+//
+// NOTE(mrry): See caveats about CancelCallback implementations in the
+// comment for CancellationManager::RegisterCallback.
+typedef std::function<void()> CancelCallback;
+
+class CancellationManager {
+ public:
+ // A value that won't be returned by get_cancellation_token().
+ static const CancellationToken kInvalidToken;
+
+ CancellationManager();
+ ~CancellationManager();
+
+ // Run all callbacks associated with this manager.
+ void StartCancel();
+
+ // Returns true iff StartCancel() has been called.
+ bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
+
+ // Returns a token that must be used in calls to RegisterCallback
+ // and DeregisterCallback.
+ CancellationToken get_cancellation_token();
+
+ // Attempts to register the given callback to be invoked when this
+ // manager is cancelled. Returns true if the callback was
+ // registered; returns false if this manager was already cancelled,
+ // and the callback was not registered.
+ //
+ // If this method returns false, it is the caller's responsibility
+ // to perform any cancellation cleanup.
+ //
+ // This method is tricky to use correctly. The following usage pattern
+ // is recommended:
+ //
+ // class ObjectWithCancellableOperation {
+ // mutex mu_;
+ // void CancellableOperation(CancellationManager* cm,
+ // std::function<void(Status)> callback) {
+ // bool already_cancelled;
+ // CancellationToken token = cm->get_cancellation_token();
+ // {
+ //       mutex_lock l(mu_);
+ //       already_cancelled = !cm->RegisterCallback(
+ //           token, [this, token]() { Cancel(token); });
+ // if (!already_cancelled) {
+ // // Issue asynchronous operation. Associate the pending operation
+ // // with `token` in some object state, or provide another way for
+ // // the Cancel method to look up the operation for cancellation.
+ // // Ensure that `cm->DeregisterCallback(token)` is called without
+ // // holding `mu_`, before `callback` is invoked.
+ // // ...
+ // }
+ // }
+ // if (already_cancelled) {
+ // callback(errors::Cancelled("Operation was cancelled"));
+ // }
+ // }
+ //
+ // void Cancel(CancellationToken token) {
+ //     mutex_lock l(mu_);
+ // // Take action to cancel the operation with the given cancellation
+ // // token.
+ // }
+ //
+ // NOTE(mrry): The caller should take care that (i) the calling code
+ // is robust to `callback` being invoked asynchronously (e.g. from
+ // another thread), (ii) `callback` is deregistered by a call to
+ // this->DeregisterCallback(token) when the operation completes
+ // successfully, and (iii) `callback` does not invoke any method
+ // on this cancellation manager. Furthermore, it is important that
+ // the eventual caller of the complementary DeregisterCallback does not
+ // hold any mutexes that are required by `callback`.
+ bool RegisterCallback(CancellationToken token, CancelCallback callback);
+
+ // Deregister the callback that, when registered, was associated
+ // with the given cancellation token. Returns true iff the callback
+ // was deregistered and will not be invoked; otherwise returns false
+ // after the callback has been invoked, blocking if necessary.
+ //
+ // NOTE(mrry): This method may block if cancellation is in progress.
+ // The caller of this method must not hold any mutexes that are required
+ // to invoke any cancellation callback that has been registered with this
+ // cancellation manager.
+ bool DeregisterCallback(CancellationToken token);
+
+ private:
+ bool is_cancelling_;
+ std::atomic_bool is_cancelled_;
+
+ mutex mu_;
+ Notification cancelled_notification_;
+ CancellationToken next_cancellation_token_ GUARDED_BY(mu_);
+ std::unordered_map<CancellationToken, CancelCallback> callbacks_
+ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_CANCELLATION_H_
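
The comment on RegisterCallback sketches the general locking pattern; the simple case without extra locking looks like this (illustrative only, not part of the commit; CancellableWork is a made-up name).

// Sketch: the register / deregister protocol for one unit of work.
#include "tensorflow/core/framework/cancellation.h"

namespace tensorflow {

void CancellableWork(CancellationManager* cm) {
  const CancellationToken token = cm->get_cancellation_token();
  const bool registered = cm->RegisterCallback(token, []() {
    // Abort or unblock the pending work here.
  });
  if (!registered) {
    // The manager was already cancelled, so the callback will never run;
    // do the cancellation cleanup inline.
    return;
  }
  // ... do the cancellable work ...
  if (!cm->DeregisterCallback(token)) {
    // Cancellation won the race; the callback has already run
    // (DeregisterCallback blocked until it finished, if necessary).
  }
}

}  // namespace tensorflow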
diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc
new file mode 100644
index 0000000000..1925dd20cc
--- /dev/null
+++ b/tensorflow/core/framework/cancellation_test.cc
@@ -0,0 +1,102 @@
+#include "tensorflow/core/framework/cancellation.h"
+
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(Cancellation, SimpleNoCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ bool deregistered = manager->DeregisterCallback(token);
+ EXPECT_TRUE(deregistered);
+ delete manager;
+ EXPECT_FALSE(is_cancelled);
+}
+
+TEST(Cancellation, SimpleCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled);
+ delete manager;
+}
+
+TEST(Cancellation, CancelBeforeRegister) {
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ manager->StartCancel();
+ bool registered = manager->RegisterCallback(token, nullptr);
+ EXPECT_FALSE(registered);
+ delete manager;
+}
+
+TEST(Cancellation, DeregisterAfterCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled);
+ bool deregistered = manager->DeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+ delete manager;
+}
+
+TEST(Cancellation, CancelMultiple) {
+ bool is_cancelled_1 = false, is_cancelled_2 = false, is_cancelled_3 = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token_1 = manager->get_cancellation_token();
+ bool registered_1 = manager->RegisterCallback(
+ token_1, [&is_cancelled_1]() { is_cancelled_1 = true; });
+ EXPECT_TRUE(registered_1);
+ auto token_2 = manager->get_cancellation_token();
+ bool registered_2 = manager->RegisterCallback(
+ token_2, [&is_cancelled_2]() { is_cancelled_2 = true; });
+ EXPECT_TRUE(registered_2);
+ EXPECT_FALSE(is_cancelled_1);
+ EXPECT_FALSE(is_cancelled_2);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled_1);
+ EXPECT_TRUE(is_cancelled_2);
+ EXPECT_FALSE(is_cancelled_3);
+ auto token_3 = manager->get_cancellation_token();
+ bool registered_3 = manager->RegisterCallback(
+ token_3, [&is_cancelled_3]() { is_cancelled_3 = true; });
+ EXPECT_FALSE(registered_3);
+ EXPECT_FALSE(is_cancelled_3);
+ delete manager;
+}
+
+TEST(Cancellation, IsCancelled) {
+ CancellationManager* cm = new CancellationManager();
+ thread::ThreadPool w(Env::Default(), "test", 4);
+ std::vector<Notification> done(8);
+ for (size_t i = 0; i < done.size(); ++i) {
+ Notification* n = &done[i];
+ w.Schedule([n, cm]() {
+ while (!cm->IsCancelled()) {
+ }
+ n->Notify();
+ });
+ }
+ Env::Default()->SleepForMicroseconds(1000000 /* 1 second */);
+ cm->StartCancel();
+ for (size_t i = 0; i < done.size(); ++i) {
+ done[i].WaitForNotification();
+ }
+ delete cm;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/config.proto b/tensorflow/core/framework/config.proto
new file mode 100644
index 0000000000..f0def3d6d7
--- /dev/null
+++ b/tensorflow/core/framework/config.proto
@@ -0,0 +1,61 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+message GPUOptions {
+ // A value between 0 and 1 that indicates what fraction of the
+ // available GPU memory to pre-allocate for each process. 1 means
+ // to pre-allocate all of the GPU memory, 0.5 means the process
+ // allocates ~50% of the available GPU memory.
+ double per_process_gpu_memory_fraction = 1;
+};
+
+// Session configuration parameters.
+// The system picks appropriate values for fields that are not set.
+message ConfigProto {
+ // Map from device type name (e.g., "CPU" or "GPU" ) to maximum
+ // number of devices of that type to use. If a particular device
+ // type is not found in the map, the system picks an appropriate
+ // number.
+ map<string, int32> device_count = 1;
+
+ // The execution of an individual op (for some op types) can be
+ // parallelized on a pool of intra_op_parallelism_threads.
+ // 0 means the system picks an appropriate number.
+ int32 intra_op_parallelism_threads = 2;
+
+ // Nodes that perform blocking operations are enqueued on a pool of
+ // inter_op_parallelism_threads available in each process.
+ //
+ // 0 means the system picks an appropriate number.
+ //
+ // Note that the first Session created in the process sets the
+ // number of threads for all future sessions.
+ int32 inter_op_parallelism_threads = 5;
+
+ // Assignment of Nodes to Devices is recomputed every placement_period
+ // steps until the system warms up (at which point the recomputation
+ // typically slows down automatically).
+ int32 placement_period = 3;
+
+ // When any filters are present, sessions will ignore all devices that do not
+ // match the filters. Each filter can be partially specified, e.g. "/job:ps"
+ // "/job:worker/replica:3", etc.
+ repeated string device_filters = 4;
+
+ // Options that apply to all GPUs.
+ GPUOptions gpu_options = 6;
+
+ // Whether soft placement is allowed. If allow_soft_placement is true,
+ // an op will be placed on CPU if
+ // 1. there's no GPU implementation for the OP
+ // or
+ // 2. no GPU devices are known or registered
+ // or
+ // 3. need to co-locate with reftype input(s) which are from CPU.
+ bool allow_soft_placement = 7;
+
+ // Whether device placements should be logged.
+ bool log_device_placement = 8;
+};
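
For reference, filling in this message from C++ goes through the usual generated proto API. A minimal sketch (not part of the commit; ExampleConfig is a made-up name), with field names taken from the message above:

// Sketch: populating a ConfigProto.
#include "tensorflow/core/framework/config.pb.h"

namespace tensorflow {

ConfigProto ExampleConfig() {
  ConfigProto config;
  (*config.mutable_device_count())["GPU"] = 1;  // use at most one GPU
  config.set_intra_op_parallelism_threads(0);   // 0: system picks a number
  config.add_device_filters("/job:worker/replica:3");
  config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(0.5);
  config.set_log_device_placement(true);
  return config;
}

}  // namespace tensorflow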
diff --git a/tensorflow/core/framework/control_flow.h b/tensorflow/core/framework/control_flow.h
new file mode 100644
index 0000000000..f59e0f5310
--- /dev/null
+++ b/tensorflow/core/framework/control_flow.h
@@ -0,0 +1,43 @@
+#ifndef TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
+#define TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+
+const uint64 kIllegalFrameId = ~0uLL;
+const int64 kIllegalIterId = -1;
+
+// For the purpose of control flow, every tensor produced by TensorFlow is
+// conceptually tagged by a 'FrameAndIter'. FrameAndIter consists of a
+// 'frame_id' and an 'iter_id'. The tensor value it represents is produced
+// in the frame with frame_id at the iteration of iter_id.
+struct FrameAndIter {
+ uint64 frame_id = kIllegalFrameId;
+ int64 iter_id = kIllegalIterId;
+
+ FrameAndIter() {}
+
+ FrameAndIter(uint64 frame, int64 iter) {
+ frame_id = frame;
+ iter_id = iter;
+ }
+
+ bool operator==(const FrameAndIter& other) const {
+ return (frame_id == other.frame_id && iter_id == other.iter_id);
+ }
+};
+
+struct FrameAndIterHash {
+ size_t operator()(const FrameAndIter& key) const {
+ // Make sure there are no padding bytes: we hash the raw bytes below.
+ CHECK_EQ(sizeof(uint64) + sizeof(int64), sizeof(FrameAndIter));
+ return Hash64(reinterpret_cast<const char*>(&key), sizeof(FrameAndIter));
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
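
FrameAndIterHash exists so that FrameAndIter can key a hash map; a minimal sketch (not part of the commit; FrameAndIterSketch is a made-up name):

// Sketch: FrameAndIter as a hash-map key.
#include <string>
#include <unordered_map>
#include "tensorflow/core/framework/control_flow.h"

namespace tensorflow {

void FrameAndIterSketch() {
  std::unordered_map<FrameAndIter, std::string, FrameAndIterHash> tag;
  tag[FrameAndIter(0, 0)] = "frame 0, iteration 0";
  tag[FrameAndIter(0, 1)] = "frame 0, iteration 1";
  // operator== compares (frame_id, iter_id), so hashing and equality agree.
}

}  // namespace tensorflow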
diff --git a/tensorflow/core/framework/device_attributes.proto b/tensorflow/core/framework/device_attributes.proto
new file mode 100644
index 0000000000..7592215d1e
--- /dev/null
+++ b/tensorflow/core/framework/device_attributes.proto
@@ -0,0 +1,35 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+// BusAdjacency identifies the ability of a device to participate in
+// maximally efficient DMA operations within the local context of a
+// process.
+//
+// This is currently ignored.
+enum BusAdjacency {
+ BUS_0 = 0;
+ BUS_1 = 1;
+ BUS_ANY = 2;
+ BUS_NUM_ADJACENCIES = 3;
+};
+
+message DeviceAttributes {
+ string name = 1;
+
+ // String representation of device_type.
+ string device_type = 2;
+
+ // Memory capacity of device in bytes.
+ int64 memory_limit = 4;
+
+ BusAdjacency bus_adjacency = 5;
+
+ // A device is assigned a global unique number each time it is
+ // initialized. "incarnation" should never be 0.
+ fixed64 incarnation = 6;
+
+ // String representation of the physical device that this device maps to.
+ string physical_device_desc = 7;
+}
diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc
new file mode 100644
index 0000000000..83ad199062
--- /dev/null
+++ b/tensorflow/core/framework/device_base.cc
@@ -0,0 +1,7 @@
+#include "tensorflow/core/framework/device_base.h"
+
+namespace tensorflow {
+
+DeviceBase::~DeviceBase() {}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
new file mode 100644
index 0000000000..ed4ffc5d94
--- /dev/null
+++ b/tensorflow/core/framework/device_base.h
@@ -0,0 +1,172 @@
+#ifndef TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
+#define TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
+
+#include <memory>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/public/status.h"
+
+namespace Eigen {
+class ThreadPoolDevice;
+} // end namespace Eigen
+
+namespace perftools {
+namespace gputools {
+class Stream;
+} // namespace gputools
+} // namespace perftools
+
+namespace tensorflow {
+
+class Device;
+class Env;
+class EventMgr;
+
+namespace thread {
+class ThreadPool;
+}
+
+// A wrapper for an Eigen Gpu Device that includes per-op state
+class PerOpGpuDevice {
+ public:
+ virtual ~PerOpGpuDevice() {}
+ virtual const Eigen::GpuDevice& device() const = 0;
+};
+
+// A class that devices can subclass to pass around
+// Device-specific context to OpKernels.
+class DeviceContext : public core::RefCounted {
+ public:
+ ~DeviceContext() override {}
+ virtual perftools::gputools::Stream* stream() const { return nullptr; }
+ virtual void MaintainLifetimeOnStream(
+ const Tensor* t, perftools::gputools::Stream* stream) const {}
+
+ // "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into
+ // "device_tensor" which is on a GPU device "device". "device_tensor"
+ // must be allocated to be of the same size as "cpu_tensor".
+ virtual void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
+ Tensor* device_tensor,
+ StatusCallback done) const {
+ done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
+ }
+
+ // "device_tensor" is a tensor on a non-CPU device. Copies
+ // device_tensor into "cpu_tensor". "cpu_tensor" must be allocated
+ // to be of the same size as "device_tensor".
+ virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ const string& tensor_name, Device* device,
+ Tensor* cpu_tensor, StatusCallback done) {
+ done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
+ }
+};
+
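+// A DeviceContextMap is keyed by, e.g., a graph node id and gives the
+// DeviceContext each kernel should use (cf. FillContextMap mentioned
+// below).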
+typedef std::unordered_map<int, DeviceContext*> DeviceContextMap;
+
+class DeviceBase {
+ public:
+ explicit DeviceBase(Env* env) : env_(env) {}
+ virtual ~DeviceBase();
+
+ Env* env() const { return env_; }
+
+ // Override this to return true for devices that require an Op's
+ // compute method to save references to the temporary tensors it
+ // allocates until the Op execution completes
+ virtual bool SaveTemporaryTensors() const { return false; }
+
+ struct CpuWorkerThreads {
+ int num_threads = 0;
+ thread::ThreadPool* workers = nullptr;
+ };
+
+ // Does not take ownership.
+ void set_tensorflow_cpu_worker_threads(CpuWorkerThreads* t) {
+ cpu_worker_threads_ = t;
+ }
+
+ const CpuWorkerThreads* tensorflow_cpu_worker_threads() const {
+ CHECK(cpu_worker_threads_ != nullptr);
+ return cpu_worker_threads_;
+ }
+
+ // "stream" is used in special circumstances (such as the
+ // constructors of Ops) where there is no available OpKernelContext.
+ // "default_context" is used by OpKernelContext whenever a device does not
+ // supply a DeviceContext for an op in FillContextMap (e.g. when only
+ // using a single stream.)
+ // "event_mgr" is used to delay deallocation of temporary GPU buffers.
+ // TODO(pbar) Work out how to move this out of DeviceBase.
+ struct GpuDeviceInfo {
+ perftools::gputools::Stream* stream;
+ DeviceContext* default_context;
+ EventMgr* event_mgr;
+ };
+
+ // Does not take ownership.
+ void set_tensorflow_gpu_device_info(GpuDeviceInfo* g) {
+ gpu_device_info_ = g;
+ }
+
+ const GpuDeviceInfo* tensorflow_gpu_device_info() const {
+ return gpu_device_info_;
+ }
+
+ // Does not take ownership.
+ void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) {
+ eigen_cpu_device_ = d;
+ }
+
+ // Return the Allocator implementation to use based on the allocator
+ // attributes requested. See allocator.h for more details.
+ virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
+ LOG(FATAL) << "GetAllocator() is not implemented.";
+ }
+
+ const Eigen::ThreadPoolDevice* eigen_cpu_device() {
+ CHECK(eigen_cpu_device_ != nullptr);
+ return eigen_cpu_device_;
+ }
+
+  // The caller owns the returned device and is responsible for freeing
+  // it when it is no longer needed.
+ virtual const PerOpGpuDevice* MakeGpuDevice(DeviceContext* /*dc*/,
+ Allocator* /*allocator*/) {
+ // The OpKernelContext calls this even for devices that do not
+ // implement an eigen_gpu_device
+ return nullptr;
+ }
+
+ virtual const DeviceAttributes& attributes() const {
+ LOG(FATAL) << "Device does not implement attributes()";
+ }
+
+ // Materializes the given TensorProto into 'tensor' stored in Device
+ // memory. Most devices will want to override this.
+ //
+ // TODO(vrv): We should be able to put this function into
+ // OpKernelContext and handle the copies from device memory via send
+ // and receive nodes, instead of requiring that each device handle
+ // the copies here as well as in copy ops.
+ virtual Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) {
+ return errors::Internal("Device does not implement MakeTensorFromProto()");
+ }
+
+ private:
+ Env* const env_;
+ CpuWorkerThreads* cpu_worker_threads_ = nullptr;
+ GpuDeviceInfo* gpu_device_info_ = nullptr;
+ Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc
new file mode 100644
index 0000000000..493c35e05f
--- /dev/null
+++ b/tensorflow/core/framework/fake_input.cc
@@ -0,0 +1,214 @@
+#include "tensorflow/core/framework/fake_input.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace {
+
+class FakeInputImpl {
+ public:
+ FakeInputImpl(const OpDef* op_def, int in_index, const NodeDef* node_def,
+ NodeDefBuilder* builder);
+ void SetN(int n);
+ void SetDataType(DataType dt);
+ void SetTypeList(DataTypeSlice dts);
+ Status AddInputToBuilder();
+
+ private:
+ static string FakeNodeName(int in_index);
+ Status GetN(int* n) const;
+ Status GetDataType(DataType* dt) const;
+ void NSources(int n, DataType dt) const;
+ void SourceList(DataTypeSlice dts) const;
+
+ const OpDef* const op_def_;
+ const OpDef::ArgDef* const arg_;
+ const string in_node_;
+ const NodeDef* const node_def_;
+ NodeDefBuilder* const builder_;
+
+ bool n_specified_;
+ int n_;
+ bool dt_specified_;
+ DataType dt_;
+ bool dts_specified_;
+ DataTypeSlice dts_;
+};
+
+FakeInputImpl::FakeInputImpl(const OpDef* op_def, int in_index,
+ const NodeDef* node_def, NodeDefBuilder* builder)
+ : op_def_(op_def),
+ arg_(&op_def->input_arg(in_index)),
+ in_node_(FakeNodeName(in_index)),
+ node_def_(node_def),
+ builder_(builder),
+ n_specified_(false),
+ dt_specified_(false),
+ dts_specified_(false) {}
+
+void FakeInputImpl::SetN(int n) {
+ n_specified_ = true;
+ n_ = n;
+}
+
+void FakeInputImpl::SetDataType(DataType dt) {
+ dt_specified_ = true;
+ dt_ = dt;
+}
+
+void FakeInputImpl::SetTypeList(DataTypeSlice dts) {
+ dts_specified_ = true;
+ dts_ = dts;
+}
+
+Status FakeInputImpl::AddInputToBuilder() {
+ if (dts_specified_) {
+ SourceList(dts_);
+
+ } else if (n_specified_ || !arg_->number_attr().empty()) {
+ int n;
+ TF_RETURN_IF_ERROR(GetN(&n));
+
+ DataType dt;
+ if (n > 0) {
+ TF_RETURN_IF_ERROR(GetDataType(&dt));
+ } else {
+ dt = DT_FLOAT;
+ }
+
+ NSources(n, dt);
+ } else {
+ if (!dt_specified_ && !arg_->type_list_attr().empty()) {
+ DataTypeVector dts;
+ Status status =
+ GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts);
+ if (!status.ok()) {
+ return errors::InvalidArgument(
+ "Could not infer list of types for input '", arg_->name(), "': ",
+ status.error_message());
+ }
+ SourceList(dts);
+ return Status::OK();
+ }
+
+ DataType dt;
+ TF_RETURN_IF_ERROR(GetDataType(&dt));
+ builder_->Input(in_node_, 0, dt);
+ }
+ return Status::OK();
+}
+
+// static
+string FakeInputImpl::FakeNodeName(int in_index) {
+ char c = 'a' + (in_index % 26);
+ return string(&c, 1);
+}
+
+Status FakeInputImpl::GetN(int* n) const {
+ if (n_specified_) {
+ *n = n_;
+ } else {
+ Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n);
+ if (!status.ok()) {
+ return errors::InvalidArgument("Could not infer length of input '",
+ arg_->name(), "': ",
+ status.error_message());
+ }
+ }
+ return Status::OK();
+}
+
+Status FakeInputImpl::GetDataType(DataType* dt) const {
+ if (dt_specified_) {
+ *dt = dt_;
+ } else if (arg_->type() != DT_INVALID) {
+ *dt = arg_->type();
+ } else if (!arg_->type_attr().empty()) {
+ Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt);
+ if (!status.ok()) {
+ return errors::InvalidArgument("Could not infer type for input '",
+ arg_->name(), "': ",
+ status.error_message());
+ }
+ } else {
+ return errors::InvalidArgument("No type or type_attr field in arg '",
+ arg_->name(), "'");
+ }
+ return Status::OK();
+}
+
+void FakeInputImpl::NSources(int n, DataType dt) const {
+ std::vector<NodeDefBuilder::NodeOut> srcs;
+ srcs.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ srcs.emplace_back(in_node_, i, dt);
+ }
+ builder_->Input(srcs);
+}
+
+void FakeInputImpl::SourceList(DataTypeSlice dts) const {
+ std::vector<NodeDefBuilder::NodeOut> srcs;
+ srcs.reserve(dts.size());
+ for (size_t i = 0; i < dts.size(); ++i) {
+ srcs.emplace_back(in_node_, i, dts[i]);
+ }
+ builder_->Input(srcs);
+}
+
+} // namespace
+
+// Public interface ------------------------------------------------------------
+
+FakeInputFunctor FakeInput() {
+ return [](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(DataType dt) {
+ return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetDataType(dt);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(int n) {
+ return [n](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetN(n);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(int n, DataType dt) {
+ return [n, dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetN(n);
+ impl.SetDataType(dt);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(DataTypeSlice dts) {
+ // Make a copy to ensure the data will still be around when the lambda is
+ // called.
+ DataTypeVector dtv(dts.begin(), dts.end());
+ return [dtv](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetTypeList(dtv);
+ return impl.AddInputToBuilder();
+ };
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/fake_input.h b/tensorflow/core/framework/fake_input.h
new file mode 100644
index 0000000000..39b38e9a59
--- /dev/null
+++ b/tensorflow/core/framework/fake_input.h
@@ -0,0 +1,25 @@
+#ifndef TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
+#define TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
+
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+
+// These functions return values that may be passed to
+// NodeDefBuilder::Input() to add an input for a test. Use them when
+// you don't care about the node names/output indices providing the
+// input. They also allow you to omit the input types and/or
+// list length when they may be inferred.
+FakeInputFunctor FakeInput(); // Infer everything
+FakeInputFunctor FakeInput(DataType dt);
+FakeInputFunctor FakeInput(int n); // List of length n
+FakeInputFunctor FakeInput(int n, DataType dt);
+FakeInputFunctor FakeInput(DataTypeSlice dts);
+inline FakeInputFunctor FakeInput(std::initializer_list<DataType> dts) {
+ return FakeInput(DataTypeSlice(dts));
+}
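+
+// A sketch of typical test usage ("MyOp" and the inputs are illustrative):
+//   NodeDef node_def;
+//   TF_CHECK_OK(NodeDefBuilder("n", "MyOp")
+//                   .Input(FakeInput(DT_FLOAT))     // one float input
+//                   .Input(FakeInput(3, DT_INT32))  // list of three int32s
+//                   .Finalize(&node_def));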
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
new file mode 100644
index 0000000000..b73e1ab8a9
--- /dev/null
+++ b/tensorflow/core/framework/function.cc
@@ -0,0 +1,878 @@
+#include "tensorflow/core/framework/function.h"
+
+#include <unordered_set>
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+REGISTER_OP("_Arg")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("index: int >= 0")
+ .Doc(R"doc(
+A graph node which represents an argument to a function.
+
+output: The argument.
+index: This argument is the index-th argument of the function.
+)doc");
+
+REGISTER_OP("_Retval")
+ .Input("input: T")
+ .Attr("T: type")
+ .Attr("index: int >= 0")
+ .Doc(R"doc(
+A graph node which represents a return value of a function.
+
+input: The return value.
+index: This return value is the index-th return value of the function.
+)doc");
+
+REGISTER_OP("_ListToArray")
+ .Input("input: Tin")
+ .Output("output: N * T")
+ .Attr("Tin: list(type)")
+ .Attr("T: type")
+ .Attr("N: int >= 1")
+ .Doc(R"doc(
+Converts a list of tensors to an array of tensors.
+)doc");
+
+REGISTER_OP("_ArrayToList")
+ .Input("input: N * T")
+ .Output("output: out_types")
+ .Attr("T: type")
+ .Attr("N: int >= 1")
+ .Attr("out_types: list(type)")
+ .Doc(R"doc(
+Converts an array of tensors to a list of tensors.
+)doc");
+
+namespace {
+
+// Extracts the actual type from "attr_values" based on its definition
+// "arg_def".
+Status ArgNumType(const InstantiateAttrValueMap& attrs,
+ const OpDef::ArgDef& arg_def, int* num, DataType* dtype) {
+ if (!arg_def.type_list_attr().empty()) {
+ return errors::Unimplemented("type_list is not supported.");
+ }
+
+ if (arg_def.number_attr().empty()) {
+ *num = 1;
+ } else {
+ const AttrValue* v = gtl::FindOrNull(attrs, arg_def.number_attr());
+ if (v == nullptr) {
+ return errors::NotFound("type attr not found: ", arg_def.type_attr());
+ }
+ *num = v->i();
+ }
+
+ if (arg_def.type() != DT_INVALID) {
+ *dtype = arg_def.type();
+ } else if (arg_def.type_attr().empty()) {
+ *dtype = DT_INVALID;
+ } else {
+ const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_attr());
+ if (v == nullptr) {
+ return errors::NotFound("type attr not found: ", arg_def.type_attr());
+ }
+ *dtype = v->type();
+ }
+ return Status::OK();
+}
+
+string Name(int node_index) { return strings::StrCat("n", node_index); }
+
+string Name(int node_index, int output_index) {
+ if (output_index == 0) {
+ return Name(node_index);
+ } else {
+ return strings::StrCat("n", node_index, ":", output_index);
+ }
+}
+
+string Dep(int node_index) { return strings::StrCat("^", Name(node_index)); }
+
+template <typename T>
+void AddAttr(const string& name, const T& val, NodeDef* ndef) {
+ SetAttrValue(val, &((*ndef->mutable_attr())[name]));
+}
+
+Status ValidateSignatureWithAttrs(const OpDef& sig,
+ const InstantiateAttrValueMap& attr_values) {
+ // attr_values should specify all attrs defined in fdef.
+ for (const auto& a : sig.attr()) {
+ if (attr_values.find(a.name()) == attr_values.end()) {
+ return errors::NotFound("Attr ", a.name(), " is not found.");
+ }
+ }
+
+ for (const auto& p : attr_values) {
+ if (HasPlaceHolder(p.second)) {
+ return errors::InvalidArgument(p.first,
+ " in attr_values is still a placeholder.");
+ }
+ }
+
+ return Status::OK();
+}
+
+// We build a small index for all names that can be used as a node's
+// input arguments.
+//
+// If is_func_arg is true, the name is a function's argument. In
+// this case, the produced graph def has nodes gdef.node[nid ... nid +
+// num).
+//
+// Otherwise, the name is a function body's node return value. In
+// this case, the produced graph def has one node gdef.node[nid] and
+// the node's output index [idx ... idx + num) corresponds to the
+// named outputs.
+//
+// In all cases, "dtype" specifies the data type.
+struct NameInfoItem {
+ bool is_func_arg;
+ int nid;
+ int idx;
+ int num;
+ DataType dtype;
+};
+typedef std::unordered_map<string, NameInfoItem> NameInfoIndex;
+
+Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
+ const InstantiateAttrValueMap& attr_values,
+ NameInfoIndex* name_info,
+ InstantiationResult* result) {
+ int num;
+ DataType dtype;
+ TF_RETURN_IF_ERROR(ArgNumType(attr_values, arg_def, &num, &dtype));
+ CHECK_GE(num, 1);
+ GraphDef* gdef = &result->gdef;
+ int arg_index = gdef->node_size();
+ if (!name_info->insert({arg_def.name(), {true, arg_index, 0, num, dtype}})
+ .second) {
+ return errors::InvalidArgument("Duplicated arg name.");
+ }
+ // Creates "num" nodes in the gdef.
+ for (int i = 0; i < num; ++i) {
+ DCHECK_EQ(arg_index, gdef->node_size());
+ NodeDef* gnode = gdef->add_node();
+ gnode->set_name(Name(arg_index));
+ gnode->set_op("_Arg");
+ AddAttr("T", dtype, gnode);
+ AddAttr("index", arg_index, gnode);
+ result->arg_types.push_back(dtype);
+ ++arg_index;
+ }
+ return Status::OK();
+}
+
+Status BuildNodeOutputIndex(const FunctionDef::Node& node,
+ const InstantiateAttrValueMap& attrs,
+ GetFunctionSignature get_function,
+ const int arg_index, NameInfoIndex* name_info) {
+ const OpDef* node_sig = nullptr;
+ TF_RETURN_IF_ERROR(get_function(node.op(), &node_sig));
+ if (node_sig->output_arg_size() == 0) {
+ // This node produces no output.
+ if (node.ret_size() != 1) {
+ return errors::InvalidArgument("Expect one ret name.");
+ }
+ if (!name_info->insert({node.ret(0), {false, arg_index, 0, 0, DT_INVALID}})
+ .second) {
+ return errors::InvalidArgument("Duplicated ret name.");
+ }
+ return Status::OK();
+ }
+
+ // When the signature says the last return value is of list(type),
+ // i.e., it's variadic, we need to consult
+ // attrs[last_retval.type_list_attr] to determine for the last arg
+ // * the actual number of outputs;
+ // * the actual data type of outputs.
+ const int num_retval = node_sig->output_arg_size();
+ const OpDef::ArgDef& last_retval = node_sig->output_arg(num_retval - 1);
+ const bool last_retval_is_typelist = !last_retval.type_list_attr().empty();
+ if (!last_retval_is_typelist && (node.ret_size() != num_retval)) {
+ return errors::InvalidArgument("Malformed function node (#ret).");
+ }
+ int start = 0;
+ const int num_fixed_size_retval =
+ last_retval_is_typelist ? num_retval - 1 : num_retval;
+ for (int i = 0; i < num_fixed_size_retval; ++i) {
+ int num;
+ DataType dtype;
+ TF_RETURN_IF_ERROR(
+ ArgNumType(attrs, node_sig->output_arg(i), &num, &dtype));
+ if (!name_info->insert({node.ret(i), {false, arg_index, start, num, dtype}})
+ .second) {
+ return errors::InvalidArgument("Duplicated ret name.");
+ }
+ start += num;
+ }
+ if (last_retval_is_typelist) {
+ const AttrValue* typelist =
+ gtl::FindOrNull(attrs, last_retval.type_list_attr());
+ if (typelist == nullptr) {
+ return errors::InvalidArgument("Missing attr ",
+ last_retval.type_list_attr(), ".");
+ }
+ if (num_fixed_size_retval + typelist->list().type_size() !=
+ node.ret_size()) {
+ return errors::InvalidArgument("Wrong #ret: ", num_fixed_size_retval, " ",
+ typelist->list().type_size(), " ",
+ node.ret_size(), ".");
+ }
+ for (int i = 0; i < typelist->list().type_size(); ++i) {
+ if (!name_info->insert({node.ret(i),
+ {false, arg_index, start, 1,
+ typelist->list().type(i)}})
+ .second) {
+ return errors::InvalidArgument("Duplicated ret name.");
+ }
+ ++start;
+ }
+ }
+ return Status::OK();
+}
+
+Status InstantiateNode(const FunctionDef::Node& fnode,
+ const InstantiateAttrValueMap& attrs,
+ GetFunctionSignature get_function,
+ const NameInfoIndex& name_info, GraphDef* gdef) {
+ const OpDef* fnode_sig = nullptr;
+ TF_CHECK_OK(get_function(fnode.op(), &fnode_sig));
+ NodeDef* gnode = gdef->add_node();
+ gnode->set_name(Name(gdef->node_size() - 1));
+ gnode->set_op(fnode.op());
+
+ // Input
+ //
+ // When the signature says the last argument is of list(type),
+ // i.e., it's variadic, we need to consult
+ // attrs[last_arg.type_list_attr] to determine for the last arg
+ // * the number of arguments;
+ // * the data types of arguments.
+ const int num_arg = fnode_sig->input_arg_size();
+ bool last_arg_is_typelist = false;
+ if (num_arg > 0 &&
+ !fnode_sig->input_arg(num_arg - 1).type_list_attr().empty()) {
+ last_arg_is_typelist = true;
+ }
+ if (!last_arg_is_typelist && (fnode.arg_size() != num_arg)) {
+ return errors::InvalidArgument("arg.size != sig.arg.size.");
+ }
+ const int num_fixed_size_args = last_arg_is_typelist ? num_arg - 1 : num_arg;
+ for (int i = 0; i < num_fixed_size_args; ++i) {
+ int num;
+ DataType dtype;
+ TF_RETURN_IF_ERROR(
+ ArgNumType(attrs, fnode_sig->input_arg(i), &num, &dtype));
+ const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.arg(i));
+ if (item == nullptr) {
+ return errors::InvalidArgument("arg[", i, "] is not found: ",
+ fnode.ShortDebugString());
+ }
+ if (num != item->num || dtype != item->dtype) {
+ return errors::InvalidArgument("Invalid arg(", i, ") for function arg: ",
+ " ", num, "/", dtype, " vs. ", item->num,
+ "/", item->dtype, ".");
+ }
+ for (int j = 0; j < num; ++j) {
+ if (item->is_func_arg) {
+ gnode->add_input(Name(item->nid + j));
+ } else {
+ gnode->add_input(Name(item->nid, item->idx + j));
+ }
+ }
+ }
+ if (last_arg_is_typelist) {
+ AttrValue typelist;
+ for (int i = num_fixed_size_args; i < fnode.arg_size(); ++i) {
+ const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.arg(i));
+ if (item == nullptr) {
+ return errors::InvalidArgument("arg[", i, "] is not found.");
+ }
+ for (int j = 0; j < item->num; ++j) {
+ if (item->is_func_arg) {
+ gnode->add_input(Name(item->nid + j));
+ } else {
+ gnode->add_input(Name(item->nid, item->idx + j));
+ }
+ typelist.mutable_list()->add_type(item->dtype);
+ }
+ }
+
+ // 'typelist' is inferred from the inputs' data types.
+ const auto& last_arg = fnode_sig->input_arg(num_arg - 1);
+ gnode->mutable_attr()->insert({last_arg.type_list_attr(), typelist});
+ }
+
+ // Control deps.
+ for (int i = 0; i < fnode.dep_size(); ++i) {
+ const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.dep(i));
+ if (item == nullptr) {
+ return errors::InvalidArgument("dep[", i, "] is not found.");
+ }
+ gnode->add_input(Dep(item->nid));
+ }
+
+ // Attrs.
+ for (const auto& p : attrs) {
+ (*gnode->mutable_attr())[p.first] = p.second;
+ }
+
+ return Status::OK();
+}
+
+Status AddReturnNode(const OpDef::ArgDef& ret_def,
+ const InstantiateAttrValueMap& attrs,
+ const NameInfoIndex& name_info, int* ret_index,
+ InstantiationResult* result) {
+ int num;
+ DataType dtype;
+ TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &num, &dtype));
+ CHECK_GE(num, 1);
+ const NameInfoItem* item = gtl::FindOrNull(name_info, ret_def.name());
+ if (item == nullptr) {
+ return errors::InvalidArgument("ret is not found.");
+ }
+ if (num != item->num || dtype != item->dtype) {
+ return errors::InvalidArgument("Invalid ret name.");
+ }
+ GraphDef* gdef = &result->gdef;
+ for (int i = 0; i < num; ++i) {
+ NodeDef* gnode = gdef->add_node();
+ gnode->set_name(Name(gdef->node_size() - 1));
+ gnode->set_op("_Retval");
+ gnode->add_input(Name(item->nid, item->idx + i));
+ AddAttr("T", dtype, gnode);
+ AddAttr("index", (*ret_index)++, gnode);
+ result->ret_types.push_back(dtype);
+ }
+ return Status::OK();
+}
+
+// Various helpers Print(proto) to print relevant protos to ascii.
+string Print(const OpDef::ArgDef& arg) {
+ string out;
+ strings::StrAppend(&out, arg.name(), ":");
+ if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
+ if (!arg.number_attr().empty()) {
+ strings::StrAppend(&out, arg.number_attr(), "*");
+ }
+ if (arg.type() != DT_INVALID) {
+ strings::StrAppend(&out, DataTypeString(arg.type()));
+ } else {
+ strings::StrAppend(&out, arg.type_attr());
+ }
+ if (arg.is_ref()) strings::StrAppend(&out, ")");
+ return out;
+}
+
+string Print(const AttrValue& attr_value) {
+ if (attr_value.value_case() == AttrValue::kType) {
+ return DataTypeString(attr_value.type());
+ } else if ((attr_value.value_case() == AttrValue::kList) &&
+ (attr_value.list().type_size() > 0)) {
+ string ret = "{";
+ for (int i = 0; i < attr_value.list().type_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
+ }
+ strings::StrAppend(&ret, "}");
+ return ret;
+ } else if (attr_value.value_case() == AttrValue::kFunc) {
+ if (attr_value.func().attr_size() == 0) {
+ return attr_value.func().name();
+ }
+ std::vector<string> entries;
+ for (auto p : attr_value.func().attr()) {
+ entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
+ }
+ sort(entries.begin(), entries.end());
+ return strings::StrCat(attr_value.func().name(), "[",
+ str_util::Join(entries, ", "), "]");
+ }
+ return SummarizeAttrValue(attr_value);
+}
+
+string Print(const FunctionDef::Node& node) {
+ string out;
+ for (int i = 0; i < node.ret_size(); ++i) {
+ const auto& name = node.ret(i);
+ if (i > 0) strings::StrAppend(&out, ", ");
+ strings::StrAppend(&out, name);
+ }
+ strings::StrAppend(&out, " = ", node.op());
+ if (node.attr_size() > 0) {
+ std::vector<string> entries;
+ for (auto p : node.attr()) {
+ entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
+ }
+ sort(entries.begin(), entries.end());
+ strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
+ }
+ strings::StrAppend(&out, "(");
+ for (int i = 0; i < node.arg_size(); ++i) {
+ if (i > 0) strings::StrAppend(&out, ", ");
+ strings::StrAppend(&out, node.arg(i));
+ }
+ strings::StrAppend(&out, ")");
+ if (node.dep_size() > 0) {
+ strings::StrAppend(&out, " @ ");
+ for (int i = 0; i < node.dep_size(); ++i) {
+ if (i > 0) strings::StrAppend(&out, ", ");
+ strings::StrAppend(&out, node.dep(i));
+ }
+ }
+ return out;
+}
+
+string Print(const FunctionDef& fdef) {
+ string out;
+ const OpDef& sig = fdef.signature();
+ strings::StrAppend(&out, "\n", sig.name());
+ if (sig.attr_size() > 0) {
+ strings::StrAppend(&out, "[");
+ for (int i = 0; i < sig.attr_size(); ++i) {
+ const auto& a = sig.attr(i);
+ if (i > 0) strings::StrAppend(&out, ", ");
+ if (a.type() == "type") {
+ strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
+ } else {
+ strings::StrAppend(&out, a.name(), ":", a.type());
+ }
+ }
+ strings::StrAppend(&out, "]");
+ }
+ strings::StrAppend(&out, "(");
+ for (int i = 0; i < sig.input_arg_size(); ++i) {
+ if (i > 0) strings::StrAppend(&out, ", ");
+ strings::StrAppend(&out, Print(sig.input_arg(i)));
+ }
+ strings::StrAppend(&out, ") -> (");
+ for (int i = 0; i < sig.output_arg_size(); ++i) {
+ if (i > 0) strings::StrAppend(&out, ", ");
+ strings::StrAppend(&out, Print(sig.output_arg(i)));
+ }
+ strings::StrAppend(&out, ") {\n");
+ for (const auto& n : fdef.node()) {
+ strings::StrAppend(&out, " ", Print(n), "\n");
+ }
+ strings::StrAppend(&out, "}\n");
+ return out;
+}
+
+string Print(const NodeDef& n) {
+ string out;
+ strings::StrAppend(&out, n.name(), " = ", n.op());
+ if (n.attr_size() > 0) {
+ std::vector<string> entries;
+ for (auto& a : n.attr()) {
+ entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
+ }
+ sort(entries.begin(), entries.end());
+ strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
+ }
+ strings::StrAppend(&out, "(");
+ std::vector<StringPiece> dat;
+ std::vector<string> dep;
+ for (StringPiece s : n.input()) {
+ if (s.Consume("^")) {
+ dep.push_back(s.ToString());
+ } else {
+ dat.push_back(s);
+ }
+ }
+ strings::StrAppend(&out, str_util::Join(dat, ", "), ")");
+ if (!dep.empty()) {
+ strings::StrAppend(&out, " @ ", str_util::Join(dep, ", "));
+ }
+ return out;
+}
+
+string Print(const GraphDef& gdef) {
+ std::vector<const NodeDef*> arg;
+ std::vector<const NodeDef*> ret;
+ std::vector<const NodeDef*> body;
+ for (const NodeDef& n : gdef.node()) {
+ if (n.op() == "_Arg") {
+ arg.push_back(&n);
+ } else if (n.op() == "_Retval") {
+ ret.push_back(&n);
+ } else {
+ body.push_back(&n);
+ }
+ }
+ auto comp = [](const NodeDef* x, const NodeDef* y) {
+ int xi;
+ TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
+ int yi;
+ TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
+ return xi < yi;
+ };
+ sort(arg.begin(), arg.end(), comp);
+ sort(ret.begin(), ret.end(), comp);
+ string out;
+ strings::StrAppend(&out, "\n(");
+ auto get_type = [](const NodeDef& n) {
+ for (auto a : n.attr()) {
+ if (a.first == "T") {
+ return DataTypeString(a.second.type());
+ }
+ }
+ return DataTypeString(DT_INVALID);
+ };
+ for (size_t i = 0; i < arg.size(); ++i) {
+ const NodeDef* n = arg[i];
+ if (i > 0) strings::StrAppend(&out, ", ");
+ CHECK_EQ(2, n->attr_size());
+ strings::StrAppend(&out, n->name(), ":", get_type(*n));
+ }
+ strings::StrAppend(&out, ") -> (");
+ for (size_t i = 0; i < ret.size(); ++i) {
+ const NodeDef* n = ret[i];
+ if (i > 0) strings::StrAppend(&out, ", ");
+ CHECK_EQ(2, n->attr_size());
+ CHECK_EQ(1, n->input_size());
+ strings::StrAppend(&out, n->input(0), ":", get_type(*n));
+ }
+ strings::StrAppend(&out, ") {\n");
+ for (size_t i = 0; i < body.size(); ++i) {
+ strings::StrAppend(&out, " ", Print(*body[i]), "\n");
+ }
+ strings::StrAppend(&out, "}\n");
+ return out;
+}
+
+} // end namespace
+
+Status InstantiateFunction(const FunctionDef& fdef,
+ const InstantiateAttrValueMap& attr_values,
+ GetFunctionSignature get_function,
+ InstantiationResult* result) {
+ const OpDef& sig = fdef.signature();
+ GraphDef* gdef = &result->gdef;
+ gdef->Clear();
+
+ TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
+
+ auto substitute = [&attr_values](const string& name, AttrValue* val) {
+ auto iter = attr_values.find(name);
+ if (iter == attr_values.end()) {
+ return false;
+ } else {
+ *val = iter->second;
+ return true;
+ }
+ };
+
+ // Makes a copy of all attrs in fdef and substitutes placeholders.
+ // After this step, every attr is bound to a concrete value.
+ std::vector<InstantiateAttrValueMap> node_attrs;
+ node_attrs.resize(fdef.node_size());
+ for (int i = 0; i < fdef.node_size(); ++i) {
+ for (auto attr : fdef.node(i).attr()) {
+ if (!SubstitutePlaceholders(substitute, &attr.second)) {
+ return errors::InvalidArgument("Failed to bind all placeholders in ",
+ SummarizeAttrValue(attr.second));
+ }
+ CHECK(node_attrs[i].insert(attr).second);
+ }
+ }
+
+ NameInfoIndex name_info;
+ Status s;
+ for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
+ s = BuildInputArgIndex(arg_def, attr_values, &name_info, result);
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " In ", Print(arg_def));
+ return s;
+ }
+ }
+ for (int i = 0; i < fdef.node_size(); ++i) {
+ s = BuildNodeOutputIndex(fdef.node(i), node_attrs[i], get_function,
+ gdef->node_size() + i, &name_info);
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " In ", Print(fdef.node(i)));
+ return s;
+ }
+ }
+
+ // Emits one gdef.node for each fdef.node.
+ for (int i = 0; i < fdef.node_size(); ++i) {
+ s = InstantiateNode(fdef.node(i), node_attrs[i], get_function, name_info,
+ gdef);
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " In ", Print(fdef.node(i)));
+ return s;
+ }
+ }
+
+ // Emits nodes for the function's return values.
+ int ret_index = 0;
+ for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
+ s = AddReturnNode(ret_def, attr_values, name_info, &ret_index, result);
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " In ", Print(ret_def));
+ return s;
+ }
+ }
+
+ return Status::OK();
+}
+
+string DebugString(const FunctionDef& func_def) { return Print(func_def); }
+
+string DebugString(const GraphDef& instantiated_func_def) {
+ return Print(instantiated_func_def);
+}
+
+string DebugStringWhole(const GraphDef& gdef) {
+ string ret;
+ for (auto fdef : gdef.library().function()) {
+ strings::StrAppend(&ret, Print(fdef));
+ }
+ strings::StrAppend(&ret, "\n");
+ for (auto ndef : gdef.node()) {
+ strings::StrAppend(&ret, Print(ndef), "\n");
+ }
+ return ret;
+}
+
+string Canonicalize(const string& funcname,
+ const InstantiateAttrValueMap& attrs) {
+ std::vector<string> entries;
+ entries.reserve(attrs.size());
+ for (auto p : attrs) {
+ entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
+ }
+ sort(entries.begin(), entries.end());
+ return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
+}
+
+FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
+ DataTypeSlice ret_types)
+ : arg_types_(arg_types.begin(), arg_types.end()),
+ ret_types_(ret_types.begin(), ret_types.end()) {
+ args_.resize(arg_types_.size());
+ rets_.resize(ret_types_.size());
+}
+
+FunctionCallFrame::~FunctionCallFrame() {}
+
+Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
+ // Input type checks.
+ if (args.size() != arg_types_.size()) {
+ return errors::InvalidArgument("Expects ", arg_types_.size(),
+ " arguments, but ", args.size(),
+ " is provided");
+ }
+ for (size_t i = 0; i < args.size(); ++i) {
+ if (arg_types_[i] != args[i].dtype()) {
+ return errors::InvalidArgument(
+ "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
+ DataTypeString(args[i].dtype()), " is provided");
+ }
+ args_[i] = args[i];
+ }
+ return Status::OK();
+}
+
+Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
+ rets->clear();
+ rets->reserve(rets_.size());
+ for (size_t i = 0; i < rets_.size(); ++i) {
+ auto item = rets_[i];
+ if (item.has_val) {
+ rets->push_back(item.val);
+ } else {
+ return errors::Internal("Retval[", i, "] does not have value");
+ }
+ }
+ return Status::OK();
+}
+
+Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
+ if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
+ return errors::OutOfRange("GetArg ", index, " is not within [0, ",
+ args_.size(), ")");
+ }
+ *val = args_[index];
+ return Status::OK();
+}
+
+Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
+ if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
+ return errors::OutOfRange("SetRetval ", index, " is not within [0, ",
+ rets_.size(), ")");
+ }
+ if (val.dtype() != ret_types_[index]) {
+ return errors::InvalidArgument(
+ "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
+ ", but ", DataTypeString(val.dtype()), " is provided.");
+ }
+ Retval* item = &rets_[index];
+ if (!item->has_val) {
+ item->has_val = true;
+ item->val = val;
+ } else {
+ return errors::Internal("Retval[", index, "] has already been set.");
+ }
+ return Status::OK();
+}
+
+FunctionLibraryDefinition::FunctionLibraryDefinition(
+ const FunctionDefLibrary& def_lib)
+ : function_defs_(def_lib.function_size()) {
+ for (auto fdef : def_lib.function()) {
+ // The latter function definition wins.
+ function_defs_[fdef.signature().name()] = fdef;
+ }
+}
+
+FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
+
+const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
+ auto iter = function_defs_.find(name);
+ if (iter == function_defs_.end()) {
+ return nullptr;
+ } else {
+ return &iter->second;
+ }
+}
+
+const OpDef* FunctionLibraryDefinition::LookUp(const string& op,
+ Status* status) const {
+ auto fdef = Find(op);
+ if (fdef != nullptr) {
+ return &(fdef->signature());
+ }
+ return OpRegistry::Global()->LookUp(op, status);
+}
+
+Status InstantiateFunction(const FunctionDef& fdef,
+ InstantiateAttrValueSlice attr_values,
+ GetFunctionSignature get_function,
+ InstantiationResult* result) {
+ InstantiateAttrValueMap m;
+ for (const auto& aval : attr_values) {
+ m.insert({aval.first, aval.second.proto});
+ }
+ return InstantiateFunction(fdef, m, get_function, result);
+}
+
+string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs) {
+ InstantiateAttrValueMap m;
+ for (const auto& aval : attrs) {
+ m.insert({aval.first, aval.second.proto});
+ }
+ return Canonicalize(funcname, m);
+}
+
+Status FunctionLibraryRuntime::Instantiate(const string& function_name,
+ InstantiateAttrValueSlice attrs,
+ Handle* handle) {
+ InstantiateAttrValueMap m;
+ for (const auto& aval : attrs) {
+ m.insert({aval.first, aval.second.proto});
+ }
+ return Instantiate(function_name, m, handle);
+}
+
+void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
+ if (val.size() >= 2 && val[0] == '$') {
+ proto.set_placeholder(val.data() + 1, val.size() - 1);
+ } else {
+ SetAttrValue(val, &proto);
+ }
+}
+
+FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
+ const string& name,
+ gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
+ AttrValueWrapper ret;
+ ret.proto.mutable_func()->set_name(name);
+ for (const auto& a : attrs) {
+ ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
+ }
+ return ret;
+}
+
+FunctionDef::Node FunctionDefHelper::Node::ToProto() const {
+ FunctionDef::Node n;
+ for (const string& r : this->ret) {
+ n.add_ret(r);
+ }
+ n.set_op(this->op);
+ for (const string& a : arg) {
+ n.add_arg(a);
+ }
+ for (const auto& a : this->attr) {
+ n.mutable_attr()->insert({a.first, a.second.proto});
+ }
+ for (const string& d : dep) {
+ n.add_dep(d);
+ }
+ return n;
+}
+
+/* static */
+FunctionDef FunctionDefHelper::Define(const string& name,
+ gtl::ArraySlice<string> arg_def,
+ gtl::ArraySlice<string> ret_def,
+ gtl::ArraySlice<string> attr_def,
+ gtl::ArraySlice<Node> node_def) {
+ FunctionDef fdef;
+ OpDefBuilder b(name);
+ for (const auto& a : arg_def) b.Input(a);
+ for (const auto& r : ret_def) b.Output(r);
+ for (const auto& a : attr_def) b.Attr(a);
+ TF_CHECK_OK(b.Finalize(fdef.mutable_signature()));
+ for (const auto& n : node_def) {
+ *(fdef.add_node()) = n.ToProto();
+ }
+ return fdef;
+}
+
+FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
+ gtl::ArraySlice<string> ret_def,
+ gtl::ArraySlice<string> attr_def,
+ gtl::ArraySlice<Node> node_def) {
+ return Define("_", arg_def, ret_def, attr_def, node_def);
+}
+
+namespace gradient {
+
+typedef std::unordered_map<string, Creator> OpGradFactory;
+
+OpGradFactory* GetOpGradFactory() {
+ static OpGradFactory* factory = new OpGradFactory;
+ return factory;
+}
+
+bool RegisterOp(const string& op, Creator func) {
+ CHECK(GetOpGradFactory()->insert({op, func}).second)
+ << "Duplicated gradient for " << op;
+ return true;
+}
+
+Status GetOpGradientCreator(const string& op, Creator* creator) {
+ auto fac = GetOpGradFactory();
+ auto iter = fac->find(op);
+ if (iter == fac->end()) {
+ return errors::NotFound("No gradient defined for op: ", op);
+ }
+ *creator = iter->second;
+ return Status::OK();
+}
+
+} // end namespace gradient
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
new file mode 100644
index 0000000000..1ef93a0533
--- /dev/null
+++ b/tensorflow/core/framework/function.h
@@ -0,0 +1,376 @@
+#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_
+#define TENSORFLOW_FRAMEWORK_FUNCTION_H_
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+class CancellationManager;
+class Node;
+class OpKernel;
+
+// FunctionDefHelper::Define is a convenient helper to construct a
+// FunctionDef proto.
+//
+// E.g.,
+// FunctionDef my_func = FunctionDefHelper::Define(
+// "my_func_name",
+// {"x:T", "y:T" /* one string per argument */},
+// {"z:T" /* one string per return value */},
+// {"T: {float, double}" /* one string per attribute */},
+// {
+// {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
+// /* one entry per function node */
+// })
+//
+// NOTE: When we have a TFLang parser, we can add another helper:
+// FunctionDef FunctionDefHelper::Define(const string& tf_func);
+class FunctionDefHelper {
+ public:
+ // AttrValueWrapper has copy constructors for the type T so that
+ // it's easy to construct a simple AttrValue proto.
+ //
+  // If T is a string type (const char*, string, or StringPiece), and
+  // it starts with "$", we construct an AttrValue of "placeholder".
+  //
+  // E.g.,
+  //   std::pair<string, AttrValueWrapper> x = {"T", "$T"};
+  // is a named attr value placeholder.
+ struct AttrValueWrapper {
+ AttrValue proto;
+
+ AttrValueWrapper() {}
+
+ template <typename T>
+ AttrValueWrapper(T val) { // NOLINT(runtime/explicit)
+ SetAttrValue(val, &proto);
+ }
+
+ private:
+ void InitFromString(StringPiece val);
+ };
+
+ // Constructs an AttrValue.func given the "name" and "attrs".
+ static AttrValueWrapper FunctionRef(
+ const string& name,
+ gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
+ static AttrValueWrapper FunctionRef(const string& name) {
+ return FunctionRef(name, {});
+ }
+
+  // Node is used to construct FunctionDef.Node using initializer
+  // lists. E.g.,
+ // Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y
+ struct Node {
+ std::vector<string> ret;
+ string op;
+ std::vector<string> arg;
+ std::vector<std::pair<string, AttrValueWrapper>> attr;
+ std::vector<string> dep;
+
+ FunctionDef::Node ToProto() const;
+ };
+
+ static FunctionDef Define(const string& function_name,
+ gtl::ArraySlice<string> arg_def,
+ gtl::ArraySlice<string> ret_def,
+ gtl::ArraySlice<string> attr_def,
+ gtl::ArraySlice<Node> node_def);
+
+ // Defines an anonymous function. I.e., its name is not relevant.
+ static FunctionDef Define(gtl::ArraySlice<string> arg_def,
+ gtl::ArraySlice<string> ret_def,
+ gtl::ArraySlice<string> attr_def,
+ gtl::ArraySlice<Node> node_def);
+
+ // Helpers to construct a constant scalar.
+ template <typename T>
+ static Node Const(const string& name, const T& val) {
+ Node n = {{name}, "Const"};
+ const DataType dtype = DataTypeToEnum<T>::value;
+ n.attr.push_back({"dtype", dtype});
+ Tensor t(dtype, TensorShape({}));
+ t.scalar<T>()() = val;
+ n.attr.push_back({"value", t});
+ return n;
+ }
+
+ template <typename T>
+ static Node Const(const string& name, gtl::ArraySlice<T> vals) {
+ Node n = {{name}, "Const"};
+ const DataType dtype = DataTypeToEnum<T>::value;
+ n.attr.push_back({"dtype", dtype});
+ int64 num = vals.size();
+ Tensor t(dtype, TensorShape({num}));
+ for (int i = 0; i < vals.size(); ++i) {
+ t.flat<T>()(i) = vals[i];
+ }
+ n.attr.push_back({"value", t});
+ return n;
+ }
+};
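+
+// For illustration, the Const helper above can be used when listing
+// function-body nodes, e.g.:
+//   FunctionDefHelper::Node two = FunctionDefHelper::Const<int64>("two", 2);
+// yields a "Const" node named "two" with dtype=int64 and scalar value 2.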
+
+template <>
+inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
+ InitFromString(val);
+}
+
+template <>
+inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
+ const string& val) {
+ InitFromString(val);
+}
+
+template <>
+inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
+ InitFromString(val);
+}
+
+// Instantiate a function.
+//
+// "fdef" encodes a TF function with some attrs in fdef.signature.attr
+// containing placeholders. InstantiateFunction binds these
+// placeholders and produces an instantiated function encoded in
+// "result.gdef". The value to substitute a placeholder is given by
+// "attr_values", which is a map from a placeholder name to an attr
+// value.
+//
+// InstantiateFunction calls "get_function" to find signatures of other
+// functions and primitive ops.
+
+// Placeholders in "fdef" are substituted based on "attr_values" here.
+typedef ::tensorflow::protobuf::Map<string, AttrValue> InstantiateAttrValueMap;
+typedef gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
+ InstantiateAttrValueSlice;
+
+// GetFunctionSignature(func name, opdef) returns OK if the func name is found
+// and opdef is filled with a pointer to the corresponding signature
+// (an OpDef proto). Otherwise, returns an error.
+typedef std::function<Status(const string&, const OpDef**)>
+ GetFunctionSignature;
+
+struct InstantiationResult {
+ DataTypeVector arg_types;
+ DataTypeVector ret_types;
+ GraphDef gdef;
+};
+Status InstantiateFunction(const FunctionDef& fdef,
+ const InstantiateAttrValueMap& attr_values,
+ GetFunctionSignature get_function,
+ InstantiationResult* result);
+Status InstantiateFunction(const FunctionDef& fdef,
+ InstantiateAttrValueSlice attr_values,
+ GetFunctionSignature get_function,
+ InstantiationResult* result);
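+
+// A minimal calling sketch, mirroring the pattern used in the tests
+// (GetOpSig is a GetFunctionSignature, e.g. backed by OpRegistry::Global()):
+//   InstantiationResult result;
+//   TF_RETURN_IF_ERROR(
+//       InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));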
+
+// Returns a debug string for a function definition.
+//
+// The returned text is multiple-line. It is intended to be
+// human-readable rather than being friendly to parsers. It is _NOT_
+// intended to be the canonical string representation of "func_def".
+// Particularly, it may not include all information presented in
+// "func_def" (e.g., comments, description of the function arguments,
+// etc.)
+string DebugString(const FunctionDef& func_def);
+string DebugString(const GraphDef& instantiated_func_def);
+
+// Returns a debug string for a top level graph (the main program and
+// its supporting functions defined in its library).
+string DebugStringWhole(const GraphDef& gdef);
+
+// Returns a canonicalized string for the instantiation of the
+// function of the given "name" and attributes "attrs".
+//
+// The returned string is guaranteed to be stable within one address
+// space. But it may change as the implementation
+// evolves. Therefore, it should not be persisted or compared across
+// address spaces.
+string Canonicalize(const string& funcname,
+ const InstantiateAttrValueMap& attrs);
+string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs);
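+// For example, Canonicalize("Foo", {{"T", DT_FLOAT}}) produces a string of
+// the form "Foo[T=float]".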
+
+// Represents a function call frame. I.e., the data structure used to
+// pass arguments to a function and retrieve its results.
+//
+// The runtime must arrange accesses to one FunctionCallFrame such that:
+//   1. SetArgs() happens before any GetArg();
+//   2. GetRetvals() happens after all SetRetval() calls.
+class FunctionCallFrame {
+ public:
+ FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
+ ~FunctionCallFrame();
+
+ // Caller methods.
+ Status SetArgs(gtl::ArraySlice<Tensor> args);
+ Status GetRetvals(std::vector<Tensor>* rets) const;
+
+ // Callee methods.
+ Status GetArg(int index, Tensor* val) const;
+ Status SetRetval(int index, const Tensor& val);
+
+ private:
+ DataTypeVector arg_types_;
+ DataTypeVector ret_types_;
+ gtl::InlinedVector<Tensor, 4> args_;
+ struct Retval {
+ bool has_val = false;
+ Tensor val;
+ };
+ gtl::InlinedVector<Retval, 4> rets_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
+};
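+
+// Caller-side sketch (the callee side uses GetArg()/SetRetval()):
+//   FunctionCallFrame frame({DT_FLOAT}, {DT_FLOAT});
+//   TF_RETURN_IF_ERROR(frame.SetArgs({x}));
+//   ... run the instantiated function against &frame ...
+//   std::vector<Tensor> rets;
+//   TF_RETURN_IF_ERROR(frame.GetRetvals(&rets));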
+
+// Helper to maintain a map between function names in a given
+// FunctionDefLibrary and function definitions.
+class FunctionLibraryDefinition : public OpRegistryInterface {
+ public:
+ explicit FunctionLibraryDefinition(const FunctionDefLibrary& lib_def);
+ ~FunctionLibraryDefinition() override;
+
+ // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
+ // returns its definition proto.
+ const FunctionDef* Find(const string& func) const;
+
+ // OpRegistryInterface method. Useful for constructing a Graph.
+ //
+ // If "op" is defined in the library, returns its signature.
+ // Otherwise, assume "op" is a primitive op and returns its op
+ // signature.
+ const OpDef* LookUp(const string& op, Status* status) const override;
+
+ private:
+ std::unordered_map<string, FunctionDef> function_defs_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryDefinition);
+};
+
+// Forward declare. Defined in common_runtime/function.h
+struct FunctionBody;
+
+class FunctionLibraryRuntime {
+ public:
+ virtual ~FunctionLibraryRuntime() {}
+
+ // Instantiate a function with the given "attrs".
+ //
+ // Returns OK and fills in "handle" if the instantiation succeeds.
+ // Otherwise returns an error and "handle" is undefined.
+ typedef uint64 Handle;
+ virtual Status Instantiate(const string& function_name,
+ const InstantiateAttrValueMap& attrs,
+ Handle* handle) = 0;
+ Status Instantiate(const string& function_name,
+ InstantiateAttrValueSlice attrs, Handle* handle);
+
+ // Returns the function body for the instantiated function given its
+ // handle 'h'. Returns nullptr if "h" is not found.
+ //
+ // *this keeps the ownership of the returned object, which remains alive
+ // as long as *this.
+ virtual const FunctionBody* GetFunctionBody(Handle h) = 0;
+
+ // Asynchronously invokes the instantiated function identified by
+ // "handle".
+ //
+ // If function execution succeeds, "done" is called with OK and
+ // "*rets" is filled with the function's return values. Otheriwse,
+ // "done" is called with an error status.
+ //
+ // Does not take ownership of "rets".
+ struct Options {
+ CancellationManager* cancellation_manager = nullptr;
+ };
+ typedef std::function<void(const Status&)> DoneCallback;
+ virtual void Run(const Options& opts, Handle handle,
+ gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
+ DoneCallback done) = 0;
+
+ // Creates a "kernel" for the given node def "ndef".
+ //
+ // If succeeds, returns OK and the caller takes the ownership of the
+ // returned "*kernel". Otherwise, returns an error.
+ virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0;
+
+ // Return true iff 'function_name' is the name of a defined function.
+ virtual bool IsDefined(const string& function_name) = 0;
+};
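+
+// End-to-end sketch ("MyFunc" is illustrative; "lib" is a
+// FunctionLibraryRuntime*):
+//   FunctionLibraryRuntime::Handle handle;
+//   TF_RETURN_IF_ERROR(lib->Instantiate("MyFunc", {{"T", DT_FLOAT}}, &handle));
+//   FunctionLibraryRuntime::Options opts;
+//   std::vector<Tensor> rets;
+//   lib->Run(opts, handle, {x}, &rets,
+//            [](const Status& s) { TF_CHECK_OK(s); });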
+
+// To register a gradient function for a builtin op, one should use
+// REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
+//
+// Typically, the C++ grad factory is a plain function that can be
+// converted into ::tensorflow::gradient::Creator, which is
+// std::function<Status(const AttrSlice&, FunctionDef*)>.
+//
+// A ::tensorflow::gradient::Creator should populate the FunctionDef*
+// with a definition of a brain function which computes the gradient for
+// <op_name> when <op_name> is instantiated with the given attrs.
+//
+// E.g.,
+//
+// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
+// bool transpose_a;
+// TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
+// bool transpose_b;
+// TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
+// DataType dtype;
+// TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
+// if (!transpose_a && !transpose_b) {
+// *g = FunctionDefHelper::Define(
+// "MatMulGrad",
+// {"x:T ", "y:T", "dz:T"}, // Inputs to this function
+// {"dx:T", "dy:T"}, // Outputs from this function
+// {"T: {float, double}"}, // Attributes needed by this function
+// {
+// {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
+// {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
+// {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
+// {{"dy"}, "MatMul", {"x_", "dz"}, {{"T", "$T"}}},
+// });
+// } else {
+// ... ...
+// }
+// return Status::OK();
+// }
+//
+// NOTE: $T is substituted with the type variable "T" when the
+// gradient function MatMulGrad is instantiated.
+//
+// TODO(zhifengc): Better documentation somewhere.
+
+// Macros to define a gradient function factory for a primitive
+// operation.
+#define REGISTER_OP_GRADIENT(name, fn) \
+ REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)
+
+#define REGISTER_OP_NO_GRADIENT(name) \
+ REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)
+
+#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
+ REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)
+
+#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \
+ static bool unused_grad_##ctr = ::tensorflow::gradient::RegisterOp(name, fn)
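+
+// E.g., given the MatMulGrad factory sketched above:
+//   REGISTER_OP_GRADIENT("MatMul", MatMulGrad);
+// and, for an op that intentionally has no gradient:
+//   REGISTER_OP_NO_GRADIENT("Shape");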
+
+namespace gradient {
+// Register a gradient creator for the "op".
+typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
+bool RegisterOp(const string& op, Creator func);
+
+// Returns OK if the gradient creator for the "op" is found (may be
+// nullptr if REGISTER_OP_NO_GRADIENT is used).
+Status GetOpGradientCreator(const string& op, Creator* creator);
+}  // end namespace gradient
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_
diff --git a/tensorflow/core/framework/function.proto b/tensorflow/core/framework/function.proto
new file mode 100644
index 0000000000..4b8a26947c
--- /dev/null
+++ b/tensorflow/core/framework/function.proto
@@ -0,0 +1,68 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/attr_value.proto";
+import "tensorflow/core/framework/op_def.proto";
+
+// A library is a set of named functions.
+message FunctionDefLibrary {
+ repeated FunctionDef function = 1;
+}
+
+// A function can be instantiated when the runtime can bind every attr
+// with a value. When a GraphDef has a call to a function, it must
+// have binding for every attr defined in the signature.
+//
+// TODO(zhifengc):
+// * device spec, etc.
+message FunctionDef {
+ // The definition of the function's name, arguments, return values,
+ // attrs etc.
+ OpDef signature = 1;
+
+ // The body of the function.
+ repeated Node node = 2; // function.node.ret[*] are unique.
+
+ // A node is a multi-value assignment:
+ // (ret[0], ret[1], ...) = func(arg[0], arg[1], ...)
+ //
+ // By convention, "func" is resolved by consulting with a user-defined
+ // library first. If not resolved, "func" is assumed to be a builtin op.
+ message Node {
+ // This node produces multiple outputs. They are named ret[0],
+ // ret[1], ..., etc.
+ //
+ // REQUIRES: function.node.ret[*] are unique across all nodes.
+ // REQUIRES: ret.size == func/op def's number of output args.
+ repeated string ret = 1;
+
+ // The op/function name.
+ string op = 2;
+
+ // Arguments passed to this func/op.
+ //
+ // arg[i] must be either one of
+ // function.signature.input_args[*].name or one of
+ // function.node[*].ret[*].
+ //
+ // REQUIRES: arg.size == func/op def's number of input args.
+ repeated string arg = 3;
+
+ // Control dependencies.
+ //
+ // dep[i] must be one of function.node[*].ret[*] or one of
+ // function.signature.input_args[*].name.
+ repeated string dep = 4;
+
+ // Attrs.
+ //
+ // 'attr' maps names defined by 'func's attr defs to attr values.
+ // attr values may have placeholders which are substituted
+ // recursively by concrete values when this node is instantiated.
+    // These placeholders must name an attr listed in the FunctionDef's
+ // signature.
+ map<string, AttrValue> attr = 5;
+ }
+}
diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc
new file mode 100644
index 0000000000..c9483fad18
--- /dev/null
+++ b/tensorflow/core/framework/function_test.cc
@@ -0,0 +1,634 @@
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+typedef FunctionDefHelper FDH;
+
+Status GetOpSig(const string& op, const OpDef** sig) {
+ Status s;
+ *sig = OpRegistry::Global()->LookUp(op, &s);
+ return s;
+}
+
+REGISTER_OP("One")
+ .Output("y: T")
+ .Attr("T: {float, double, int32, int64}")
+ .Doc(R"doc(
+Returns a tensor with a single element (1) of type T.
+
+y: A scalar in type T.
+
+)doc");
+
+static InstantiateAttrValueMap kNoAttrs;
+
+TEST(TFunc, SquarePlusOne) {
+ RequireDefaultOps();
+ auto fdef = FDH::Define(
+ // Name
+ "SquarePlusOne",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attrs
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {// a = Square<T>(x)
+ {{"a"}, "Square", {"x"}, {{"T", "$T"}}},
+ // o = One<T>()
+ // NOTE: We can also have a Cast<Tin, Tout>(x) instead.
+ {{"o"}, "One", {}, {{"T", "$T"}}},
+ // y = Add<T>(a, o)
+ {{"y"}, "Add", {"a", "o"}, {{"T", "$T"}}}});
+
+ const char* e = R"P(
+SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
+ a = Square[T=$T](x)
+ o = One[T=$T]()
+ y = Add[T=$T](a, o)
+}
+)P";
+ EXPECT_EQ(DebugString(fdef), e);
+
+ // Instantiate one with T=float
+ InstantiationResult result;
+ TF_CHECK_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
+ const char* e2 = R"P(
+(n0:float) -> (n3:float) {
+ n1 = Square[T=float](n0)
+ n2 = One[T=float]()
+ n3 = Add[T=float](n1, n2)
+}
+)P";
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(DebugString(result.gdef), e2);
+}
+
+// NOTE: This is the simplest Map op. It takes an f: T->U.
+REGISTER_OP("Map")
+ .Input("x: N * T")
+ .Output("y: N * U")
+ .Attr("T: type")
+ .Attr("U: type")
+ .Attr("N: int >= 1")
+ // .Attr("func: func_name_with_attr")
+ .Doc(R"doc(
+Applies 'func' to every input. I.e.,
+
+y[i] = func<...>(x[i])
+
+x: N tensors, each of type T;
+y: N tensors, each of type U;
+
+)doc");
+
+TEST(TFunc, AddSquared) {
+ auto fdef = FDH::Define(
+ // Name
+ "AddSquared",
+ // Args
+ {"x: N*T"},
+ // Return values
+ {"y: T"},
+ // Attrs
+ {"N:int", "T:{float, double, int32, int64}"},
+ // Nodes
+ {// a = Map<func=Square<$T>,T=$T,U=$T,N=$N>(x)
+ {{"a"},
+ "Map",
+ {"x"},
+ {{"func", FDH::FunctionRef("Square", {{"T", "$T"}})},
+ {"T", "$T"},
+ {"U", "$T"},
+ {"N", "$N"}}},
+ // y = AddN<N=$N,T=$T>(a)
+ {{"y"}, "AddN", {"a"}, {{"N", "$N"}, {"T", "$T"}}}});
+
+ const char* e = R"P(
+AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
+ a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x)
+ y = AddN[N=$N, T=$T](a)
+}
+)P";
+ EXPECT_EQ(DebugString(fdef), e);
+
+ // Instantiate one with T=float
+ InstantiationResult result;
+ TF_CHECK_OK(InstantiateFunction(fdef, {{"N", 3}, {"T", DT_FLOAT}}, GetOpSig,
+ &result));
+ const char* e2 = R"P(
+(n0:float, n1:float, n2:float) -> (n4:float) {
+ n3 = Map[N=3, T=float, U=float, func=Square[T=float]](n0, n1, n2)
+ n4 = AddN[N=3, T=float](n3, n3:1, n3:2)
+}
+)P";
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT}));
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(DebugString(result.gdef), e2);
+}
+
+TEST(TFunc, ControlDeps) {
+ auto fdef = FDH::Define(
+ // Name
+ "ControlDeps",
+ // Args
+ {"x: float"},
+ // Return values
+ {},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"a"}, "One", {}, {{"T", DT_FLOAT}}, {"x"}},
+ {{"u"}, "NoOp", {}, {}, {"a"}},
+ {{"b"}, "One", {}, {{"T", DT_FLOAT}}, {"u"}},
+ {{"v"}, "NoOp", {}, {}, {"b"}},
+ {{"c"}, "One", {}, {{"T", DT_FLOAT}}, {"a", "v"}},
+ });
+ const char* e = R"P(
+ControlDeps(x:float) -> () {
+ a = One[T=float]() @ x
+ u = NoOp() @ a
+ b = One[T=float]() @ u
+ v = NoOp() @ b
+ c = One[T=float]() @ a, v
+}
+)P";
+ EXPECT_EQ(DebugString(fdef), e);
+
+ InstantiationResult result;
+ TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
+ const char* e2 = R"P(
+(n0:float) -> () {
+ n1 = One[T=float]() @ n0
+ n2 = NoOp() @ n1
+ n3 = One[T=float]() @ n2
+ n4 = NoOp() @ n3
+ n5 = One[T=float]() @ n1, n4
+}
+)P";
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(result.ret_types, DataTypeVector({}));
+ EXPECT_EQ(DebugString(result.gdef), e2);
+}
+
+TEST(TFunc, XTimesTwo) {
+ auto expect = R"P(
+XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
+ two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
+ scale = Cast[DstT=$T, SrcT=int64](two)
+ y = Mul[T=$T](x, scale)
+}
+)P";
+ EXPECT_EQ(expect, DebugString(test::function::XTimesTwo()));
+}
+
+TEST(TFunc, WXPlusB) {
+ auto expect = R"P(
+WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
+ mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
+ y = Add[T=$T](mm, b)
+}
+)P";
+ EXPECT_EQ(expect, DebugString(test::function::WXPlusB()));
+}
+
+TEST(TFunc, Body_TypeList) {
+ const Tensor kZero = test::AsScalar<int32>(0);
+ auto fdef = FDH::Define(
+ // Name
+ "Test",
+ // Args
+ {"i:float"},
+ // Return values
+ {"o:float"},
+ // Attrs
+ {},
+ // Nodes
+ {{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}},
+ {{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}},
+ {{"a", "b", "c", "d"},
+ "_ArrayToList",
+ {"s"},
+ {{"N", 4},
+ {"T", DT_FLOAT},
+ {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
+ {{"l"}, "Mul", {"a", "b"}, {{"T", DT_FLOAT}}},
+ {{"r"}, "Mul", {"c", "d"}, {{"T", DT_FLOAT}}},
+ {{"x"}, "_ListToArray", {"l", "r"}, {{"N", 2}, {"T", DT_FLOAT}}},
+ {{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}});
+
+ const char* e = R"P(
+Test(i:float) -> (o:float) {
+ zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
+ s = Split[T=float, num_split=4](zero, i)
+ a, b, c, d = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](s)
+ l = Mul[T=float](a, b)
+ r = Mul[T=float](c, d)
+ x = _ListToArray[N=2, T=float](l, r)
+ o = AddN[N=2, T=float](x)
+}
+)P";
+ EXPECT_EQ(DebugString(fdef), e);
+
+ InstantiationResult result;
+ TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
+ const char* e2 = R"P(
+(n0:float) -> (n7:float) {
+ n1 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
+ n2 = Split[T=float, num_split=4](n1, n0)
+ n3 = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](n2, n2:1, n2:2, n2:3)
+ n4 = Mul[T=float](n3, n3:1)
+ n5 = Mul[T=float](n3:2, n3:3)
+ n6 = _ListToArray[N=2, T=float, Tin={float, float}](n4, n5)
+ n7 = AddN[N=2, T=float](n6, n6:1)
+}
+)P";
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(DebugString(result.gdef), e2);
+}
+
+REGISTER_OP("Cond")
+ .Input("input: Tin")
+ .Output("output: out_types")
+ .Attr("Tin: list(type)")
+ .Attr("out_types: list(type)")
+ .Attr("cond: func")
+ .Attr("then_branch: func")
+ .Attr("else_branch: func")
+ .Doc(R"doc(
+output = Cond(input) ? then_branch(input) : else_branch(input)
+
+cond: A function that takes 'input' and returns a scalar.
+then_branch: A function that takes 'input' and returns 'output'.
+else_branch: A function that takes 'input' and returns 'output'.
+)doc");
+
+TEST(TFunc, Body_Array_List_Converter) {
+ auto fdef = FDH::Define(
+ // Name
+ "MySelect",
+ // Args
+ {"x:float"},
+ // Return values
+ {"z:float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"y"},
+ "Cond",
+ {"x"},
+ {{"Tin", DataTypeSlice{DT_FLOAT}},
+ {"out_types", DataTypeSlice{DT_FLOAT}},
+ {"cond", FDH::FunctionRef("MyCond")},
+ {"then_branch", FDH::FunctionRef("MyThen")},
+ {"else_branch", FDH::FunctionRef("MyElse")}}},
+ {{"z"},
+ "Cond",
+ {"y", "y"},
+ {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
+ {"out_types", DataTypeSlice{DT_FLOAT}},
+ {"cond", FDH::FunctionRef("MyCond2")},
+ {"then_branch", FDH::FunctionRef("MyThen2")},
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
+ });
+
+ const char* e = R"P(
+MySelect(x:float) -> (z:float) {
+ y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
+ z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y)
+}
+)P";
+ EXPECT_EQ(DebugString(fdef), e);
+
+ InstantiationResult result;
+ TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
+ const char* e2 = R"P(
+(n0:float) -> (n2:float) {
+ n1 = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](n0)
+ n2 = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](n1, n1)
+}
+)P";
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(DebugString(result.gdef), e2);
+}
+
+static void HasError(const Status& s, const string& substr) {
+ EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
+ << s << ", expected substring " << substr;
+}
+
+TEST(InstantiateErrors, Not_Sufficient_Attrs) {
+ auto fdef =
+ FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, {{"U", DT_FLOAT}}, GetOpSig, &result),
+ "T is not found");
+}
+
+TEST(InstantiateErrors, AttrValue_Value_Placeholder) {
+ auto fdef =
+ FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, {{"T", "$bad"}}, GetOpSig, &result),
+ "T in attr_values is still a placeholder");
+}
+
+TEST(InstantiateErrors, Unbounded_Attr) {
+ auto fdef = FDH::Define("test", {}, {}, {"T:{float, double, int32, int64}"},
+ {
+ {{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result),
+ "Failed to bind all placeholders");
+}
+
+TEST(InstantiateErrors, DupArgs) {
+ auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {});
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Duplicated arg name");
+}
+
+TEST(InstantiateErrors, Dup_Arg_Node_Name) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Duplicated ret name");
+}
+
+TEST(InstantiateErrors, Dup_Node_Names) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y"}, "One", {}, {{"T", DT_FLOAT}}},
+ {{"y"}, "One", {}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Duplicated ret name");
+}
+
+TEST(InstantiateErrors, Node_Signature_Mismatch_NoOp) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y", "z"}, "NoOp", {}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Expect one ret name");
+}
+
+TEST(InstantiateErrors, Node_Signature_Mismatch) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y", "z"}, "One", {}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Malformed function node (#ret)");
+}
+
+TEST(InstantiateErrors, Node_Arg_Notfound) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "arg[1] is not found");
+}
+
+TEST(InstantiateErrors, Node_Arg_Mismatch) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Invalid arg(0) for function arg");
+}
+
+TEST(InstantiateErrors, Node_Arg_ControlMissing) {
+ auto fdef =
+ FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "dep[0] is not found");
+}
+
+TEST(InstantiateErrors, FuncRet_Missing) {
+ auto fdef = FDH::Define("test", {}, {"y: float"}, {},
+ {
+ {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "ret is not found");
+}
+
+TEST(InstantiateErrors, FuncRet_Mismatch) {
+ auto fdef = FDH::Define("test", {}, {"y: float"}, {},
+ {
+ {{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Invalid ret name.\n\t In y");
+}
+
+TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
+ auto fdef = FDH::Define(
+ // Name
+ "MySelect",
+ // Args
+ {"x: float"},
+ // Return values
+ {"y: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"y"},
+ "Cond",
+ {"x", "x"},
+ {{"tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
+ {"cond", FDH::FunctionRef("MyCond2")},
+ {"then_branch", FDH::FunctionRef("MyThen2")},
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Missing attr out_types");
+}
+
+TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
+ auto fdef = FDH::Define(
+ // Name
+ "MySelect",
+ // Args
+ {"x: float"},
+ // Return values
+ {"y: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"y"},
+ "Cond",
+ {"x", "x"},
+ {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
+ {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
+ {"cond", FDH::FunctionRef("MyCond2")},
+ {"then_branch", FDH::FunctionRef("MyThen2")},
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Wrong #ret: 0 2 1");
+}
+
+TEST(InstantiateErrors, TypeList_Missing_Arg) {
+ auto fdef = FDH::Define(
+ // Name
+ "MySelect",
+ // Args
+ {"x: float"},
+ // Return values
+ {"y: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"y"},
+ "Cond",
+ {"x", "unknown"},
+ {{"out_types", DataTypeSlice{DT_FLOAT}},
+ {"cond", FDH::FunctionRef("MyCond2")},
+ {"then_branch", FDH::FunctionRef("MyThen2")},
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "arg[1] is not found");
+}
+
+TEST(FunctionCallFrame, Void_Void) {
+ FunctionCallFrame frame({}, {});
+ EXPECT_OK(frame.SetArgs({}));
+ auto a = test::AsTensor<float>({100});
+ HasError(frame.SetArgs({a}), "Invalid argument");
+ Tensor v;
+ HasError(frame.GetArg(0, &v), "Out of range");
+ HasError(frame.SetRetval(0, v), "Out of range");
+ std::vector<Tensor> rets;
+ EXPECT_OK(frame.GetRetvals(&rets));
+ EXPECT_EQ(rets.size(), 0);
+}
+
+TEST(FunctionCallFrame, Float_Float_Float) {
+ FunctionCallFrame frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
+ HasError(frame.SetArgs({}), "Invalid argument: Expects 2 arguments");
+ auto a = test::AsTensor<float>({100});
+ auto b = test::AsTensor<float>({200});
+ auto c = test::AsTensor<int64>({300});
+ HasError(frame.SetArgs({a, c}),
+ "Invalid argument: Expects arg[1] to be float");
+ EXPECT_OK(frame.SetArgs({a, b}));
+
+ Tensor v;
+ HasError(frame.GetArg(-1, &v), "Out of range");
+ HasError(frame.GetArg(2, &v), "Out of range");
+ EXPECT_OK(frame.GetArg(0, &v));
+ test::ExpectTensorEqual<float>(a, v);
+ EXPECT_OK(frame.GetArg(1, &v));
+ test::ExpectTensorEqual<float>(b, v);
+
+ v = test::AsTensor<float>({-100});
+ HasError(frame.SetRetval(-1, v), "Out of range");
+ HasError(frame.SetRetval(1, v), "Out of range");
+ HasError(frame.SetRetval(0, test::AsTensor<int64>({-100})),
+ "Invalid argument: Expects ret[0] to be float");
+
+ std::vector<Tensor> rets;
+ HasError(frame.GetRetvals(&rets), "does not have value");
+ EXPECT_OK(frame.SetRetval(0, v));
+ HasError(frame.SetRetval(0, v), "has already been set");
+
+ EXPECT_OK(frame.GetRetvals(&rets));
+ EXPECT_EQ(rets.size(), 1);
+ test::ExpectTensorEqual<float>(rets[0], v);
+}
+
+TEST(Canonicalize, Basic) {
+ EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT},
+ {"transpose_a", false},
+ {"transpose_b", false}}),
+ "MatMul[T=float,transpose_a=false,transpose_b=false]");
+ EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT},
+ {"transpose_b", false},
+ {"transpose_a", false}}),
+ "MatMul[T=float,transpose_a=false,transpose_b=false]");
+ EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_DOUBLE},
+ {"transpose_b", true},
+ {"transpose_a", false}}),
+ "MatMul[T=double,transpose_a=false,transpose_b=true]");
+}
+
+TEST(FunctionLibraryDefinitionTest, Find) {
+ FunctionDefLibrary proto;
+ *proto.add_function() = test::function::XTimesTwo();
+ FunctionLibraryDefinition lib_def(proto);
+
+ EXPECT_EQ(lib_def.Find("XTimes16"), nullptr);
+
+ auto expect = R"P(
+XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
+ two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
+ scale = Cast[DstT=$T, SrcT=int64](two)
+ y = Mul[T=$T](x, scale)
+}
+)P";
+ auto found = lib_def.Find("XTimesTwo");
+ ASSERT_NE(found, nullptr);
+ EXPECT_EQ(expect, DebugString(*found));
+}
+
+TEST(FunctionLibraryDefinitionTest, LookUp) {
+ FunctionDefLibrary proto;
+ *proto.add_function() = test::function::XTimesTwo();
+ FunctionLibraryDefinition lib_def(proto);
+
+ Status s;
+ EXPECT_EQ(lib_def.LookUp("XTimes16", &s), nullptr);
+
+ auto found = lib_def.LookUp("XTimesTwo", &s);
+ ASSERT_NE(found, nullptr);
+ EXPECT_EQ(found->DebugString(),
+ test::function::XTimesTwo().signature().DebugString());
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
new file mode 100644
index 0000000000..5ead947076
--- /dev/null
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -0,0 +1,146 @@
+#include "tensorflow/core/framework/function_testlib.h"
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+
+namespace tensorflow {
+namespace test {
+namespace function {
+
+typedef FunctionDefHelper FDH;
+
+GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
+ gtl::ArraySlice<FunctionDef> funcs) {
+ GraphDef g;
+ for (auto n : nodes) {
+ *(g.add_node()) = n;
+ }
+ auto lib = g.mutable_library();
+ for (auto f : funcs) {
+ *(lib->add_function()) = f;
+ }
+ return g;
+}
+
+// Helper to construct a NodeDef.
+NodeDef NDef(const string& name, const string& op,
+ gtl::ArraySlice<string> inputs,
+ gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
+ const string& device) {
+ NodeDef n;
+ n.set_name(name);
+ n.set_op(op);
+ for (auto in : inputs) n.add_input(in);
+ n.set_device(device);
+ for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
+ return n;
+}
+
+FunctionDef NonZero() {
+ return FDH::Define(
+ // Name
+ "NonZero",
+ // Args
+ {"x:T"},
+ // Return values
+ {"y:T"},
+ // Attr def
+ {"T:{float, double, int32, int64, string}"},
+ // Nodes
+ {
+ {{"y"}, "Identity", {"x"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef XTimesTwo() {
+ const Tensor kTwo = test::AsScalar<int64>(2);
+ return FDH::Define(
+ // Name
+ "XTimesTwo",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
+ {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef XTimesFour() {
+ return FDH::Define(
+ // Name
+ "XTimesFour",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
+ {{"y"}, "XTimesTwo", {"x2"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef XTimes16() {
+ return FDH::Define(
+ // Name
+ "XTimes16",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
+ {{"y"}, "XTimesFour", {"x4"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef WXPlusB() {
+ return FDH::Define(
+ // Name
+ "WXPlusB",
+ // Args
+ {"w: T", "x: T", "b: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double}"},
+ // Nodes
+ {{{"mm"},
+ "MatMul",
+ {"w", "x"},
+ {{"T", "$T"},
+ {"transpose_a", false},
+ {"transpose_b", false},
+ {"_kernel", "eigen"}}},
+ {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
+}
+
+FunctionDef Swap() {
+ return FDH::Define(
+ // Name
+ "Swap",
+ // Args
+ {"i0: T", "i1: T"},
+ // Return values
+ {"o0: T", "o1: T"},
+ // Attr def
+ {"T: {float, double}"},
+ // Nodes
+ {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}},
+ {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}});
+}
+
+} // end namespace function
+} // end namespace test
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
new file mode 100644
index 0000000000..ed0446ea85
--- /dev/null
+++ b/tensorflow/core/framework/function_testlib.h
@@ -0,0 +1,53 @@
+#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
+#define TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace test {
+namespace function {
+
+// Helper to construct a NodeDef.
+NodeDef NDef(
+ const string& name, const string& op, gtl::ArraySlice<string> inputs,
+ gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
+ attrs = {},
+ const string& device = "");
+
+// Helper to construct a GraphDef proto.
+GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
+ gtl::ArraySlice<FunctionDef> funcs = {});
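+
+// A minimal usage sketch (illustrative; assumes a "Placeholder" op with
+// a "dtype" attr is registered):
+//   GraphDef g =
+//       GDef({NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}})},
+//            {XTimesTwo()});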
+
+// For testing convenience, we provide a few simple functions that can
+// be easily executed and tested.
+
+// x:T -> x * 2.
+FunctionDef XTimesTwo();
+
+// x:T -> (x * 2) * 2.
+FunctionDef XTimesFour();
+
+// x:T -> ((x * 2) * 2) * 2.
+FunctionDef XTimes16();
+
+// w:T, x:T, b:T -> MatMul(w, x) + b
+FunctionDef WXPlusB();
+
+// x:T -> x:T, where T is a type that is automatically converted to a bool.
+FunctionDef NonZero();
+
+// x:T, y:T -> y:T, x:T
+FunctionDef Swap();
+
+} // end namespace function
+} // end namespace test
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
diff --git a/tensorflow/core/framework/graph.proto b/tensorflow/core/framework/graph.proto
new file mode 100644
index 0000000000..a9bc07e88c
--- /dev/null
+++ b/tensorflow/core/framework/graph.proto
@@ -0,0 +1,103 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/attr_value.proto";
+import "tensorflow/core/framework/function.proto";
+
+// Represents the graph of operations
+// TODO(sanjay): Also want to put the following somewhere:
+// * random_seed
+// * replicas: Do we stamp them out in python itself?
+// * where to load parameters
+// * optimizer info? does it go with the parameter layers/ops?
+message GraphDef {
+ repeated NodeDef node = 1;
+
+ // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
+ //
+ // "library" provides user-defined functions.
+ //
+ // Naming:
+ // * The function names (library.function.name) are in a flat namespace.
+ // NOTE: We may need to change it to be hierarchical to support
+ // different orgs. E.g.,
+ // { "/google/nn", { ... }},
+ // { "/google/vision", { ... }}
+ // { "/org_foo/module_bar", {...}}
+ // map<string, FunctionDefLib> named_lib;
+ // * If node[i].op is the name of a function in "library",
+ // node[i] is treated as a function call. Otherwise, node[i].op
+ // must be a primitive operation supported by the runtime.
+ //
+ //
+ // Function call semantics:
+ //
+ // * The callee may start execution as soon as some of its inputs
+ // are ready. The caller may want to use the Tuple() mechanism to
+ // ensure that all inputs are ready at the same time.
+ //
+ // * The consumer of return values may start executing as soon as
+ // the return values the consumer depends on are ready. The
+ // consumer may want to use the Tuple() mechanism to ensure that it
+ // does not start until all return values of the callee function
+ // are ready.
+ FunctionDefLibrary library = 2;
+};
+
+message NodeDef {
+ // The name given to this operator. Used for naming inputs,
+ // logging, visualization, etc. Unique within a single GraphDef.
+ // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
+ string name = 1;
+
+ // The operation name. There may be custom parameters in attrs.
+ // Op names starting with an underscore are reserved for internal use.
+ string op = 2;
+
+ // Each input is "node:src_output" with "node" being a string name and
+ // "src_output" indicating which output tensor to use from "node". If
+ // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
+ // may optionally be followed by control inputs that have the format
+ // "^node".
+ repeated string input = 3;
+
+ // A (possibly partial) specification for the device on which this
+ // node should be placed.
+ // The expected syntax for this string is as follows:
+ //
+ // DEVICE_SPEC ::= COLOCATED_NODE | PARTIAL_SPEC
+ //
+ // COLOCATED_NODE ::= "@" NODE_NAME // See NodeDef.name above.
+ // PARTIAL_SPEC ::= ("/" CONSTRAINT) *
+ // CONSTRAINT ::= ("job:" JOB_NAME)
+ // | ("replica:" [1-9][0-9]*)
+ // | ("task:" [1-9][0-9]*)
+ // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") )
+ //
+ // Valid values for this string include:
+ // * "@other/node" (colocate with "other/node")
+ // * "/job:worker/replica:0/task:1/gpu:3" (full specification)
+ // * "/job:worker/gpu:3" (partial specification)
+ // * "" (no specification)
+ //
+ // If the constraints do not resolve to a single device (or if this
+ // field is empty or not present), the runtime will attempt to
+ // choose a device automatically.
+ string device = 4;
+
+ // Operation-specific graph-construction-time configuration.
+ // Note that this should include all attrs defined in the
+ // corresponding OpDef, including those with a value matching
+ // the default -- this allows the default to change and makes
+ // NodeDefs easier to interpret on their own. However, if
+ // an attr with a default is not specified in this list, the
+ // default will be used.
+ // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
+ // one of the names from the corresponding OpDef's attr field).
+ // The values must have a type matching the corresponding OpDef
+ // attr's type field.
+ // TODO(josh11b): Add some examples here showing best practices.
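+ // For example (an illustrative sketch; the attrs are those of the
+ // MatMul op used in the tests), a float MatMul node might carry:
+ //   attr { key: "T" value { type: DT_FLOAT } }
+ //   attr { key: "transpose_a" value { b: false } }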
+ map<string, AttrValue> attr = 5;
+};
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc
new file mode 100644
index 0000000000..1e0d280126
--- /dev/null
+++ b/tensorflow/core/framework/graph_def_util.cc
@@ -0,0 +1,25 @@
+#include "tensorflow/core/framework/graph_def_util.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+string SummarizeGraphDef(const GraphDef& graph_def) {
+ string ret;
+ for (const NodeDef& node : graph_def.node()) {
+ strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
+ }
+ return ret;
+}
+
+Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
+ for (const NodeDef& node : graph_def.node()) {
+ TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h
new file mode 100644
index 0000000000..7a2ec9c7a7
--- /dev/null
+++ b/tensorflow/core/framework/graph_def_util.h
@@ -0,0 +1,29 @@
+#ifndef TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Produce a human-readable version of a GraphDef that is more concise
+// than a text-format proto.
+string SummarizeGraphDef(const GraphDef& graph_def);
+
+// Validates the syntax of a GraphDef provided externally.
+//
+// The following is an EBNF-style syntax for GraphDef objects. Note that
+// Node objects are actually specified as tensorflow::NodeDef protocol buffers,
+// which contain many other fields that are not (currently) validated.
+//
+// Graph = Node *
+// Node = NodeName, Inputs
+// Inputs = ( DataInput * ), ( ControlInput * )
+// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ?
+// ControlInput = "^", NodeName
+// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] *
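+//
+// For example (illustrative), "foo/bar:1" is a valid DataInput and
+// "^init" is a valid ControlInput under this grammar.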
+Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/kernel_def.proto b/tensorflow/core/framework/kernel_def.proto
new file mode 100644
index 0000000000..db7856a156
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def.proto
@@ -0,0 +1,33 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/attr_value.proto";
+
+message KernelDef {
+ // Must match the name of an Op.
+ string op = 1;
+
+ // Type of device this kernel runs on.
+ string device_type = 2;
+
+ message AttrConstraint {
+ // Name of an attr from the Op.
+ string name = 1;
+
+ // A list of values that this kernel supports for this attr.
+ // Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops.
+ AttrValue allowed_values = 2;
+ }
+ repeated AttrConstraint constraint = 3;
+
+ // Names of the Op's input_/output_args that reside in host memory
+ // instead of device memory.
+ repeated string host_memory_arg = 4;
+
+ // This allows experimental kernels to be registered for an op that
+ // won't be used unless the user specifies a "_kernel" attr with
+ // value matching this.
+ string label = 5;
+}
diff --git a/tensorflow/core/framework/kernel_def_builder.cc b/tensorflow/core/framework/kernel_def_builder.cc
new file mode 100644
index 0000000000..8fba883a16
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_builder.cc
@@ -0,0 +1,47 @@
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+
+KernelDefBuilder::KernelDefBuilder(const char* op_name) {
+ kernel_def_ = new KernelDef;
+ kernel_def_->set_op(op_name);
+}
+
+KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
+ kernel_def_->set_device_type(device_type);
+ return *this;
+}
+
+KernelDefBuilder& KernelDefBuilder::TypeConstraint(
+ const char* attr_name, gtl::ArraySlice<DataType> allowed) {
+ auto* constraint = kernel_def_->add_constraint();
+ constraint->set_name(attr_name);
+ auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
+ for (DataType dt : allowed) {
+ allowed_values->add_type(dt);
+ }
+ return *this;
+}
+
+KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name,
+ DataType allowed) {
+ auto* constraint = kernel_def_->add_constraint();
+ constraint->set_name(attr_name);
+ constraint->mutable_allowed_values()->mutable_list()->add_type(allowed);
+ return *this;
+}
+
+KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) {
+ kernel_def_->add_host_memory_arg(arg_name);
+ return *this;
+}
+
+KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
+ CHECK_EQ(kernel_def_->label(), "")
+ << "Trying to set a kernel's label a second time: '" << label
+ << "' in: " << kernel_def_->ShortDebugString();
+ kernel_def_->set_label(label);
+ return *this;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/kernel_def_builder.h b/tensorflow/core/framework/kernel_def_builder.h
new file mode 100644
index 0000000000..0c14d1e006
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_builder.h
@@ -0,0 +1,77 @@
+#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
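+//
+// A minimal usage sketch (mirrors the unit tests below; "MyOp" is an
+// illustrative op name, and the caller takes ownership of the result):
+//
+//   const KernelDef* def =
+//       KernelDefBuilder("MyOp").Device(DEVICE_CPU).Build();
+//   // ... register or inspect *def ...
+//   delete def;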
+class KernelDefBuilder {
+ public:
+ // Starts with just the name field set.
+ // Caller MUST call Build() and take ownership of the result.
+ explicit KernelDefBuilder(const char* op_name);
+
+ ~KernelDefBuilder() {
+ DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
+ }
+
+ // Required: specify the type of device this kernel supports.
+ // Returns *this.
+ KernelDefBuilder& Device(const char* device_type);
+ // KernelDefBuilder& Device(DeviceType device_type);
+
+ // Specify that this kernel supports a limited set of values for a
+ // particular type or list(type) attr (a further restriction than
+ // what the Op allows).
+ // Returns *this.
+ KernelDefBuilder& TypeConstraint(const char* attr_name,
+ gtl::ArraySlice<DataType> allowed);
+
+ // Like TypeConstraint but supports just a single type.
+ KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
+
+ // Like TypeConstraint, but (a) gets the type from a template parameter
+ // and (b) only supports a constraint to a single type.
+ template <class T>
+ KernelDefBuilder& TypeConstraint(const char* attr_name);
+ // TODO(josh11b): Support other types of attr constraints as needed.
+
+ // Specify that this kernel requires/provides an input/output arg
+ // in host memory (instead of the default, device memory).
+ // Returns *this.
+ KernelDefBuilder& HostMemory(const char* arg_name);
+
+ // Specify that this kernel requires a particular value for the
+ // "_kernel" attr. May only be specified once. Returns *this.
+ KernelDefBuilder& Label(const char* label);
+
+ // Returns a pointer to a KernelDef with fields set based on the
+ // above calls to this instance.
+ // Caller takes ownership of the result.
+ const KernelDef* Build() {
+ KernelDef* r = kernel_def_;
+ kernel_def_ = nullptr;
+ return r;
+ }
+
+ private:
+ KernelDef* kernel_def_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
+};
+
+// IMPLEMENTATION
+
+template <class T>
+inline KernelDefBuilder& KernelDefBuilder::TypeConstraint(
+ const char* attr_name) {
+ return this->TypeConstraint(attr_name, DataTypeToEnum<T>::v());
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
diff --git a/tensorflow/core/framework/kernel_def_builder_test.cc b/tensorflow/core/framework/kernel_def_builder_test.cc
new file mode 100644
index 0000000000..eba7144b59
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_builder_test.cc
@@ -0,0 +1,76 @@
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+TEST(KernelDefBuilderTest, Basic) {
+ const KernelDef* def = KernelDefBuilder("A").Device(DEVICE_CPU).Build();
+ KernelDef expected;
+ protobuf::TextFormat::ParseFromString("op: 'A' device_type: 'CPU'",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+}
+
+TEST(KernelDefBuilderTest, TypeConstraint) {
+ const KernelDef* def = KernelDefBuilder("B")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .Build();
+ KernelDef expected;
+ protobuf::TextFormat::ParseFromString(R"proto(
+ op: 'B' device_type: 'GPU'
+ constraint { name: 'T' allowed_values { list { type: DT_FLOAT } } } )proto",
+ &expected);
+
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+
+ def = KernelDefBuilder("C")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("U")
+ .TypeConstraint<bool>("V")
+ .Build();
+
+ protobuf::TextFormat::ParseFromString(R"proto(
+ op: 'C' device_type: 'GPU'
+ constraint { name: 'U' allowed_values { list { type: DT_INT32 } } }
+ constraint { name: 'V' allowed_values { list { type: DT_BOOL } } } )proto",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+
+ def = KernelDefBuilder("D")
+ .Device(DEVICE_CPU)
+ .TypeConstraint("W", {DT_DOUBLE, DT_STRING})
+ .Build();
+ protobuf::TextFormat::ParseFromString(R"proto(
+ op: 'D' device_type: 'CPU'
+ constraint { name: 'W'
+ allowed_values { list { type: [DT_DOUBLE, DT_STRING] } } } )proto",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+}
+
+TEST(KernelDefBuilderTest, HostMemory) {
+ const KernelDef* def = KernelDefBuilder("E")
+ .Device(DEVICE_GPU)
+ .HostMemory("in")
+ .HostMemory("out")
+ .Build();
+ KernelDef expected;
+ protobuf::TextFormat::ParseFromString(
+ "op: 'E' device_type: 'GPU' "
+ "host_memory_arg: ['in', 'out']",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc
new file mode 100644
index 0000000000..c660b84aa0
--- /dev/null
+++ b/tensorflow/core/framework/lookup_interface.cc
@@ -0,0 +1,45 @@
+#include "tensorflow/core/framework/lookup_interface.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace lookup {
+
+Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key,
+ const Tensor& value) {
+ if (key.dtype() != key_dtype()) {
+ return errors::InvalidArgument("Key must be type ", key_dtype(),
+ " but got ", key.dtype());
+ }
+ if (value.dtype() != value_dtype()) {
+ return errors::InvalidArgument("Value must be type ", value_dtype(),
+ " but got ", value.dtype());
+ }
+ if (key.NumElements() != value.NumElements()) {
+ return errors::InvalidArgument("Number of elements of key(",
+ key.NumElements(), ") and value(",
+ value.NumElements(), ") are different.");
+ }
+ if (!key.shape().IsSameSize(value.shape())) {
+ return errors::InvalidArgument("key and value have different shapes.");
+ }
+ return Status::OK();
+}
+
+Status LookupInterface::CheckFindArguments(const Tensor& key,
+ const Tensor& value,
+ const Tensor& default_value) {
+ TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(key, value));
+
+ if (default_value.dtype() != value_dtype()) {
+ return errors::InvalidArgument("Default value must be type ", value_dtype(),
+ " but got ", default_value.dtype());
+ }
+ if (!TensorShapeUtils::IsScalar(default_value.shape())) {
+ return errors::InvalidArgument("Default values must be scalar.");
+ }
+ return Status::OK();
+}
+
+} // namespace lookup
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h
new file mode 100644
index 0000000000..d4036d2019
--- /dev/null
+++ b/tensorflow/core/framework/lookup_interface.h
@@ -0,0 +1,65 @@
+#ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+#define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace lookup {
+
+// Lookup interface for batch lookups used by table lookup ops.
+class LookupInterface : public ResourceBase {
+ public:
+ // Performs batch lookups: for every element in the key tensor, Find
+ // returns the corresponding value in the values tensor. If an element
+ // is not present in the table, the given default value is used.
+
+ // For tables that require initialization, Find is available once the table
+ // is marked as initialized.
+
+ // Returns the following statuses:
+ // - OK: when the find finishes successfully.
+ // - FailedPrecondition: if the table is not initialized.
+ // - InvalidArgument: if any of the preconditions on the lookup key or value
+ // fails.
+ // - In addition, other implementations may provide another non-OK status
+ // specific to their failure modes.
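+ //
+ // For example (an illustrative sketch), with key_dtype() == DT_STRING
+ // and value_dtype() == DT_INT64, a keys tensor of shape [3] fills a
+ // values tensor of shape [3], using the scalar default_value for any
+ // key that is absent from the table.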
+ virtual Status Find(const Tensor& keys, Tensor* values,
+ const Tensor& default_value) = 0;
+
+ // Returns the number of elements in the table.
+ virtual size_t size() const = 0;
+
+ // Returns the data type of the key.
+ virtual DataType key_dtype() const = 0;
+
+ // Returns the data type of the value.
+ virtual DataType value_dtype() const = 0;
+
+ string DebugString() override { return "A lookup table"; }
+
+ protected:
+ virtual ~LookupInterface() = default;
+
+ // Checks the format of the key and value tensors.
+ // Returns OK if all of the following requirements are satisfied;
+ // otherwise returns InvalidArgument:
+ // - the DataType of the key tensor equals the table's key_dtype
+ // - the DataType of the value tensor equals the table's value_dtype
+ // - key and value have the same size and shape
+ Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values);
+
+ // Checks the arguments of a find operation. Returns OK if all of the
+ // following requirements are satisfied; otherwise returns
+ // InvalidArgument:
+ // - all requirements of CheckKeyAndValueTensors hold
+ // - the default_value type equals the table's value_dtype
+ // - default_value is a scalar
+ Status CheckFindArguments(const Tensor& keys, const Tensor& values,
+ const Tensor& default_value);
+};
+
+} // namespace lookup
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc
new file mode 100644
index 0000000000..12757f153a
--- /dev/null
+++ b/tensorflow/core/framework/node_def_builder.cc
@@ -0,0 +1,194 @@
+#include "tensorflow/core/framework/node_def_builder.h"
+
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name,
+ const OpRegistryInterface* op_registry) {
+ node_def_.set_name(name);
+ Status status;
+ op_def_ = op_registry->LookUp(op_name, &status);
+ if (op_def_ == nullptr) {
+ errors_.push_back(status.error_message());
+ inputs_specified_ = 0;
+ } else {
+ Initialize();
+ }
+}
+
+NodeDefBuilder::NodeDefBuilder(const string& name, const OpDef* op_def)
+ : op_def_(op_def) {
+ node_def_.set_name(name);
+ Initialize();
+}
+
+void NodeDefBuilder::Initialize() {
+ inputs_specified_ = 0;
+ node_def_.set_op(op_def_->name());
+}
+
+const OpDef::ArgDef* NodeDefBuilder::NextArgDef() {
+ if (!NextArgAvailable()) return nullptr;
+ return &op_def_->input_arg(inputs_specified_++);
+}
+
+bool NodeDefBuilder::NextArgAvailable() {
+ if (op_def_ == nullptr) {
+ return false;
+ } else if (inputs_specified_ >= op_def_->input_arg_size()) {
+ errors_.push_back(strings::StrCat("More Input() calls than the ",
+ op_def_->input_arg_size(),
+ " input_args"));
+ return false;
+ }
+ return true;
+}
+
+NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
+ if (NextArgAvailable()) {
+ Status status =
+ fake_input(*op_def_, inputs_specified_, node_def_, this);
+ if (!status.ok()) errors_.push_back(status.error_message());
+ }
+ return *this;
+}
+
+void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
+ const string& src_node, int src_index,
+ DataType dt) {
+ AddInput(src_node, src_index);
+
+ if (!input_arg->number_attr().empty() ||
+ !input_arg->type_list_attr().empty()) {
+ errors_.push_back(strings::StrCat("Single tensor passed to '",
+ input_arg->name(), "', expected list"));
+ return;
+ }
+
+ if (input_arg->type() != DT_INVALID) {
+ const DataType expected = MaybeAddRef(input_arg, input_arg->type());
+ VerifyInputType(input_arg, expected, dt);
+ } else {
+ VerifyInputRef(input_arg, dt);
+ Attr(input_arg->type_attr(), BaseType(dt));
+ }
+}
+
+void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg,
+ gtl::ArraySlice<NodeOut> src_list) {
+ for (const auto& node_out : src_list) {
+ AddInput(node_out.node, node_out.index);
+ }
+
+ if (!input_arg->number_attr().empty()) {
+ Attr(input_arg->number_attr(), static_cast<int64>(src_list.size()));
+ if (input_arg->type() != DT_INVALID) {
+ const DataType expected = MaybeAddRef(input_arg, input_arg->type());
+ for (const auto& node_out : src_list) {
+ VerifyInputType(input_arg, expected, node_out.data_type);
+ }
+ } else if (!src_list.empty()) {
+ const DataType base = BaseType(src_list[0].data_type);
+ Attr(input_arg->type_attr(), base);
+ const DataType expected = MaybeAddRef(input_arg, base);
+ for (const auto& node_out : src_list) {
+ VerifyInputType(input_arg, expected, node_out.data_type);
+ }
+ }
+ } else if (!input_arg->type_list_attr().empty()) {
+ DataTypeVector type_vec;
+ type_vec.reserve(src_list.size());
+ for (const auto& node_out : src_list) {
+ const DataType dt = node_out.data_type;
+ VerifyInputRef(input_arg, dt);
+ type_vec.push_back(BaseType(dt));
+ }
+ Attr(input_arg->type_list_attr(), type_vec);
+ } else {
+ errors_.push_back(strings::StrCat("List provided to input '",
+ input_arg->name(),
+ "' when single Tensor expected"));
+ }
+}
+
+void NodeDefBuilder::AddInput(const string& src_node, int src_index) {
+ if (src_node.empty()) {
+ errors_.push_back("Empty input node name");
+ } else if (src_node[0] == '^') {
+ errors_.push_back(
+ strings::StrCat("Non-control input starting with ^: ", src_node));
+ } else if (src_index > 0) {
+ node_def_.add_input(strings::StrCat(src_node, ":", src_index));
+ } else {
+ node_def_.add_input(src_node);
+ }
+}
+
+void NodeDefBuilder::VerifyInputType(const OpDef::ArgDef* input_arg,
+ DataType expected, DataType dt) {
+ if (!TypesCompatible(expected, dt)) {
+ errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
+ DataTypeString(dt), " expected ",
+ DataTypeString(expected)));
+ }
+}
+
+void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
+ DataType dt) {
+ if (input_arg->is_ref() && !IsRefType(dt)) {
+ errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
+ DataTypeString(dt),
+ " expected ref type"));
+ }
+}
+
+Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
+ const std::vector<string>* errors_ptr = &errors_;
+ std::vector<string> errors_storage;
+ if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) {
+ // Since this is a const method, to add an error, we have to make
+ // a copy of the existing errors.
+ errors_storage = errors_;
+ errors_storage.push_back(
+ strings::StrCat(inputs_specified_, " inputs specified of ",
+ op_def_->input_arg_size(), " inputs in Op"));
+ errors_ptr = &errors_storage;
+ }
+
+ if (!errors_ptr->empty()) {
+ if (errors_ptr->size() == 1) {
+ if (op_def_ == nullptr) {
+ return errors::InvalidArgument((*errors_ptr)[0],
+ " while building NodeDef '",
+ node_def_.name(), "'");
+ }
+ return errors::InvalidArgument(
+ (*errors_ptr)[0], " while building NodeDef '", node_def_.name(),
+ "' using ", SummarizeOpDef(*op_def_));
+ } else {
+ return errors::InvalidArgument(
+ errors_ptr->size(), " errors while building NodeDef '",
+ node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n",
+ str_util::Join(*errors_ptr, "\n"));
+ }
+ } else {
+ NodeDef node_def_backup;
+ if (node_def == nullptr) node_def = &node_def_backup;
+ *node_def = node_def_;
+
+ // Add control inputs after the regular inputs.
+ for (const auto& control_input : control_inputs_) {
+ node_def->add_input(strings::StrCat("^", control_input));
+ }
+
+ // Add default values for unspecified attrs.
+ AddDefaultsToNodeDef(*op_def_, node_def);
+
+ return Status::OK();
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h
new file mode 100644
index 0000000000..706f072608
--- /dev/null
+++ b/tensorflow/core/framework/node_def_builder.h
@@ -0,0 +1,176 @@
+#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
+#define TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
+
+#include <functional>
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+class NodeDefBuilder;
+typedef std::function<Status(const OpDef&, int, const NodeDef&,
+ NodeDefBuilder*)> FakeInputFunctor;
+
+// This is a helper for creating a NodeDef. Automatically sets attrs
+// that can be inferred from the inputs, and uses default values
+// (where they exist) for unspecified attrs. Example usage:
+//
+// NodeDef node_def;
+// Status status = NodeDefBuilder(node_name, op_name)
+// .Input(...)
+// .Attr(...)
+// .Finalize(&node_def);
+// if (!status.ok()) return status;
+// // Use node_def here.
+class NodeDefBuilder {
+ public:
+ // To specify an output to be consumed by one of the Input() methods below.
+ struct NodeOut {
+ NodeOut(const string& n, int i, DataType dt)
+ : node(n), index(i), data_type(dt) {}
+ NodeOut() {} // uninitialized, call Reset() before use.
+ void Reset(const string& n, int i, DataType dt) {
+ node = n;
+ index = i;
+ data_type = dt;
+ }
+ string node;
+ int index;
+ DataType data_type;
+ };
+
+ // Specify the name and the Op (either via an OpDef or the name of
+ // the Op plus a registry) for the NodeDef. Other fields are
+ // specified by calling the methods below.
+ // REQUIRES: The OpDef must satisfy ValidateOpDef().
+ NodeDefBuilder(const string& name, const string& op_name,
+ const OpRegistryInterface* op_registry = OpRegistry::Global());
+ // REQUIRES: in addition, *op_def must outlive *this.
+ NodeDefBuilder(const string& name, const OpDef* op_def);
+
+ // You must call one Input() function per input_arg in the Op,
+ // *and in the same order as the input_args appear in the OpDef.*
+
+ // For inputs that take a single tensor.
+ NodeDefBuilder& Input(const string& src_node, int src_index, DataType dt) {
+ const OpDef::ArgDef* arg = NextArgDef();
+ if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
+ return *this;
+ }
+ NodeDefBuilder& Input(const NodeOut& src) {
+ Input(src.node, src.index, src.data_type);
+ return *this;
+ }
+
+ // For inputs that take a list of tensors.
+ NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list) {
+ const OpDef::ArgDef* arg = NextArgDef();
+ if (arg != nullptr) ListInput(arg, src_list);
+ return *this;
+ }
+
+ // To create inputs in tests, see fake_input.h.
+ NodeDefBuilder& Input(FakeInputFunctor fake_input);
+
+ // Specify that this node must only run after src_node.
+ NodeDefBuilder& ControlInput(const string& src_node) {
+ control_inputs_.push_back(src_node);
+ return *this;
+ }
+
+ // Constrains what devices this node may be scheduled on.
+ NodeDefBuilder& Device(const string& device_spec) {
+ node_def_.set_device(device_spec);
+ return *this;
+ }
+
+ // Sets the attr, if not already set. If already set with a different
+ // value, an error will be returned from Finalize().
+ template <class T>
+ NodeDefBuilder& Attr(const string& attr_name, T&& value);
+ // Note: overload needed to allow {...} expressions for value.
+ template <class T>
+ NodeDefBuilder& Attr(const string& attr_name,
+ std::initializer_list<T> value) {
+ Attr<std::initializer_list<T>>(attr_name, std::move(value));
+ return *this;
+ }
+
+ // Finish building the NodeDef, returning any errors or setting
+ // *node_def if none.
+ // WARNING: Not all problems are detected! The resulting NodeDef may
+ // not be valid! Call ValidateNodeDef() from node_def_util to be sure.
+ Status Finalize(NodeDef* node_def) const;
+
+ // Accessor for the OpDef set in the constructor.
+ const OpDef& op_def() const { return *op_def_; }
+
+ private:
+ // Called in the constructors.
+ void Initialize();
+
+ // Get the current ArgDef and advance to the next one. Returns nullptr
+ // if no more inputs are available.
+ const OpDef::ArgDef* NextArgDef();
+
+ // Returns true if there is still an input_arg available in *op_def_,
+ // otherwise adds to errors_ and returns false.
+ bool NextArgAvailable();
+
+ // These do the main work of the Input() methods.
+ void SingleInput(const OpDef::ArgDef* input_arg, const string& src_node,
+ int src_index, DataType dt);
+ void ListInput(const OpDef::ArgDef* input_arg,
+ gtl::ArraySlice<NodeOut> src_list);
+
+ // Add "src_node:src_index" to the list of inputs in the node_def_.
+ void AddInput(const string& src_node, int src_index);
+
+ // Generates an error if dt does not match the expected type.
+ void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected,
+ DataType dt);
+
+ // If input_arg->is_ref() is true, generate an error if dt is not a ref.
+ void VerifyInputRef(const OpDef::ArgDef* input_arg, DataType dt);
+
+ // Makes dt a ref type if that is what the input_arg specifies.
+ DataType MaybeAddRef(const OpDef::ArgDef* input_arg, DataType dt) {
+ return input_arg->is_ref() ? MakeRefType(dt) : dt;
+ }
+
+ const OpDef* op_def_;
+ NodeDef node_def_;
+ int inputs_specified_;
+ std::vector<string> control_inputs_;
+ std::vector<string> errors_;
+};
+
+// IMPLEMENTATION -------------------------------------------------------------
+
+template <class T>
+NodeDefBuilder& NodeDefBuilder::Attr(const string& attr_name, T&& value) {
+ const AttrValue* found = AttrSlice(node_def_).Find(attr_name);
+ if (found == nullptr) {
+ AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
+ } else {
+ AttrValue attr_value;
+ SetAttrValue(std::forward<T>(value), &attr_value);
+ if (!AreAttrValuesEqual(*found, attr_value)) {
+ errors_.push_back(strings::StrCat(
+ "Inconsistent values for attr '", attr_name, "' ",
+ SummarizeAttrValue(*found), " vs. ", SummarizeAttrValue(attr_value)));
+ }
+ }
+ return *this;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc
new file mode 100644
index 0000000000..6fd4a8d1ed
--- /dev/null
+++ b/tensorflow/core/framework/node_def_builder_test.cc
@@ -0,0 +1,1036 @@
+#include "tensorflow/core/framework/node_def_builder.h"
+
+#include <memory>
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+class NodeDefBuilderTest : public ::testing::Test {
+ protected:
+ // Specify an OpDef via an OpDefBuilder.
+ void Op(const OpDefBuilder& op_def_builder) {
+ EXPECT_OK(op_def_builder.Finalize(&op_def_));
+ }
+
+ // Resets builder_ with a new NodeDefBuilder using the Op from the last call
+ // to Op() above.
+ NodeDefBuilder& Builder() {
+ EXPECT_FALSE(op_def_.name().empty()) << "Must call Op() before Builder()";
+ builder_.reset(new NodeDefBuilder("n", &op_def_));
+ return *builder_;
+ }
+
+ // Calls Finalize() and verifies it returns success and the result matches
+ // expectations.
+ void ExpectSuccess(const NodeDefBuilder& builder,
+ DataTypeSlice expected_in_types,
+ DataTypeSlice expected_out_types, StringPiece proto) {
+ NodeDef node_def;
+ Status status = builder.Finalize(&node_def);
+ EXPECT_OK(status);
+ if (!status.ok()) return;
+ NodeDef expected;
+ protobuf::TextFormat::ParseFromString(strings::StrCat("name: 'n' ", proto),
+ &expected);
+ EXPECT_EQ(node_def.DebugString(), expected.DebugString());
+
+ DataTypeVector in_types, out_types;
+ status =
+ InOutTypesForNode(node_def, builder.op_def(), &in_types, &out_types);
+ EXPECT_OK(status);
+ if (!status.ok()) return;
+ EXPECT_EQ(DataTypeSliceString(expected_in_types),
+ DataTypeVectorString(in_types));
+ EXPECT_EQ(DataTypeSliceString(expected_out_types),
+ DataTypeVectorString(out_types));
+
+ status = ValidateNodeDef(node_def, op_def_);
+ EXPECT_OK(status);
+ }
+
+ // Calls Finalize() and verifies it returns an error.
+ // Each message must appear as a substring of the error.
+ void ExpectFailures(const NodeDefBuilder& builder,
+ const std::vector<string>& messages) {
+ NodeDef node_def;
+ Status status = builder.Finalize(&node_def);
+ EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
+ if (status.ok()) return;
+ for (const string& message : messages) {
+ EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
+ << status << ", " << message;
+ }
+ }
+
+ // Calls Finalize() and verifies it returns an error.
+ // Message must appear as a substring of the error.
+ void ExpectFailure(const NodeDefBuilder& builder, const string& message) {
+ ExpectFailures(builder, {message});
+ }
+
+ // Like ExpectFailure(), except that the error can come from
+ // ValidateNodeDef().
+ void ExpectInvalid(const NodeDefBuilder& builder, const string& message) {
+ NodeDef node_def;
+ Status status = builder.Finalize(&node_def);
+ if (status.ok()) {
+ status = ValidateNodeDef(node_def, op_def_);
+ }
+ EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
+ if (status.ok()) return;
+ EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
+ << "Actual error: " << status.error_message()
+ << "\nDoes not contain: " << message;
+ }
+
+ OpDef op_def_;
+ std::unique_ptr<NodeDefBuilder> builder_;
+};
+
+TEST_F(NodeDefBuilderTest, Simple) {
+ Op(OpDefBuilder("Simple").Input("a: int32").Output("out: float"));
+
+ ExpectSuccess(Builder().Input("x", 0, DT_INT32), {DT_INT32}, {DT_FLOAT},
+ R"proto( op: "Simple" input: "x" )proto");
+
+ // Port != 0
+ ExpectSuccess(Builder().Input("y", 2, DT_INT32), {DT_INT32}, {DT_FLOAT},
+ R"proto( op: "Simple" input: "y:2" )proto");
+
+ // FakeInput
+ ExpectSuccess(Builder().Input(FakeInput()), {DT_INT32}, {DT_FLOAT}, R"proto(
+ op: "Simple" input: "a" )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_FLOAT},
+ R"proto( op: "Simple" input: "a" )proto");
+
+ // Ref input
+ ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32},
+ {DT_FLOAT}, R"proto( op: "Simple" input: "a" )proto");
+
+ // ControlInput
+ ExpectSuccess(
+ Builder().ControlInput("x").Input(FakeInput()).ControlInput("y"),
+ {DT_INT32}, {DT_FLOAT}, R"proto(
+ op: "Simple" input: ["a", "^x", "^y"] )proto");
+
+ // Device
+ ExpectSuccess(Builder().Input(FakeInput()).Device("ddd"), {DT_INT32},
+ {DT_FLOAT}, R"proto(
+ op: "Simple" input: "a" device: "ddd" )proto");
+
+ // Extra input
+ ExpectFailure(Builder().Input("x", 0, DT_INT32).Input("y", 0, DT_INT32),
+ "More Input() calls than the 1 input_args while building "
+ "NodeDef 'n' using Op<name=Simple; signature=a:int32 -> "
+ "out:float>");
+
+ // Missing input
+ ExpectFailure(Builder(), "0 inputs specified of 1 inputs in Op while");
+
+ { // Finalize() twice.
+ NodeDefBuilder& builder = Builder();
+ builder.Input(FakeInput()).Finalize(nullptr); // First call to Finalize()
+ // ExpectSuccess() also calls Finalize().
+ ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto(
+ op: "Simple" input: "a" )proto");
+ }
+
+ { // Input() after Finalize()
+ NodeDefBuilder& builder = Builder();
+ // Calling Finalize() before enough inputs -> error.
+ ExpectFailure(builder, "0 inputs specified of 1 inputs in Op while");
+ builder.Input(FakeInput());
+ // Calling Finalize() with enough inputs -> success
+ ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto(
+ op: "Simple" input: "a" )proto");
+ // Calling Finalize() with too many inputs -> error.
+ builder.Input(FakeInput(DT_INT32));
+ ExpectFailure(builder, "More Input() calls than the 1 input_args while");
+ }
+
+ // Wrong input type
+ ExpectFailure(Builder().Input("x", 0, DT_FLOAT),
+ "Input 'a' passed float expected int32 ");
+
+ ExpectFailure(Builder().Input("x", 0, DT_FLOAT_REF),
+ "Input 'a' passed float_ref expected int32 ");
+
+ // List input
+ ExpectFailure(Builder().Input(FakeInput(3, DT_FLOAT)),
+ "List provided to input 'a' when single Tensor expected while");
+
+ ExpectFailure(Builder().Input(FakeInput(3)),
+ "List provided to input 'a' when single Tensor expected while");
+
+ // Bad ControlInput
+ ExpectInvalid(Builder().Input(FakeInput()).ControlInput("z:2"),
+ "Control input '^z:2' must not have ':' in NodeDef:");
+
+ // Bad input name
+ ExpectFailure(Builder().Input("", 0, DT_INT32),
+ "Empty input node name while");
+
+ ExpectFailure(Builder().Input("^x", 0, DT_INT32),
+ "Non-control input starting with ^: ^x while");
+}
+
+TEST_F(NodeDefBuilderTest, OpDoesNotExist) {
+ NodeDefBuilder builder("n", "Op Does Not Exist");
+ builder.Input(FakeInput())
+ .Input(FakeInput(12))
+ .ControlInput("y")
+ .Attr("foo", 12)
+ .Device("device");
+ ExpectFailure(
+ builder,
+ "Op type not registered 'Op Does Not Exist' while building NodeDef 'n'");
+}
+
+TEST_F(NodeDefBuilderTest, Polymorphic) {
+ Op(OpDefBuilder("Polymorphic")
+ .Input("v: T")
+ .Output("out: T")
+ .Attr("T: type"));
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_INT32},
+ R"proto(
+ op: "Polymorphic" input: "a"
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_FLOAT)), {DT_FLOAT}, {DT_FLOAT},
+ R"proto(
+ op: "Polymorphic" input: "a"
+ attr { key: "T" value { type: DT_FLOAT } } )proto");
+
+ // Redundant Attr()
+ ExpectSuccess(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_BOOL),
+ {DT_BOOL}, {DT_BOOL}, R"proto(
+ op: "Polymorphic" input: "a"
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ // Conflicting Attr()
+ ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_STRING),
+ "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while");
+
+ ExpectFailure(Builder().Attr("T", DT_STRING).Input(FakeInput(DT_BOOL)),
+ "Inconsistent values for attr 'T' DT_STRING vs. DT_BOOL while");
+
+ ExpectFailure(Builder().Attr("T", 12).Input(FakeInput(DT_BOOL)),
+ "Inconsistent values for attr 'T' 12 vs. DT_BOOL while");
+}
+
+TEST_F(NodeDefBuilderTest, PolymorphicOut) {
+ Op(OpDefBuilder("PolymorphicOut").Output("out: T").Attr("T: type"));
+
+ ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32}, R"proto(
+ op: "PolymorphicOut"
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+
+ ExpectSuccess(Builder().Attr("T", DT_FLOAT), {}, {DT_FLOAT}, R"proto(
+ op: "PolymorphicOut"
+ attr { key: "T" value { type: DT_FLOAT } } )proto");
+
+ // Redundant attr
+ ExpectSuccess(Builder().Attr("T", DT_FLOAT).Attr("T", DT_FLOAT), {},
+ {DT_FLOAT}, R"proto(
+ op: "PolymorphicOut"
+ attr { key: "T" value { type: DT_FLOAT } } )proto");
+
+ // Conflicting attr
+ ExpectFailure(Builder().Attr("T", DT_BOOL).Attr("T", DT_FLOAT),
+ "Inconsistent values for attr 'T' DT_BOOL vs. DT_FLOAT while");
+
+ // Missing attr
+ ExpectInvalid(Builder(), "NodeDef missing attr 'T' from");
+
+ // Attr has the wrong type
+ ExpectInvalid(Builder().Attr("T", {DT_INT32, DT_BOOL}),
+ "AttrValue had value with type list(type) when type expected");
+
+ ExpectInvalid(Builder().Attr("T", 12),
+ "AttrValue had value with type int when type expected");
+}
+
+TEST_F(NodeDefBuilderTest, PolymorphicDefaultOut) {
+ Op(OpDefBuilder("PolymorphicDefaultOut")
+ .Output("out: T")
+ .Attr("T: type = DT_STRING"));
+
+ ExpectSuccess(Builder(), {}, {DT_STRING}, R"proto(
+ op: "PolymorphicDefaultOut"
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ ExpectSuccess(Builder().Attr("T", DT_BOOL), {}, {DT_BOOL}, R"proto(
+ op: "PolymorphicDefaultOut"
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, Binary) {
+ Op(OpDefBuilder("Binary").Input("a: T").Input("b: T").Output("out: T").Attr(
+ "T: type"));
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_INT32)).Input(FakeInput(DT_INT32)),
+ {DT_INT32, DT_INT32}, {DT_INT32}, R"proto(
+ op: "Binary" input: "a" input: "b"
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_STRING)).Input(FakeInput()),
+ {DT_STRING, DT_STRING}, {DT_STRING}, R"proto(
+ op: "Binary" input: "a" input: "b"
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ // Type mismatch
+ ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Input(FakeInput(DT_STRING)),
+ "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while");
+}
+
+TEST_F(NodeDefBuilderTest, Restrict) {
+ Op(OpDefBuilder("Restrict")
+ .Input("a: T")
+ .Output("out: T")
+ .Attr("T: {string, bool}"));
+ ExpectSuccess(Builder().Input(FakeInput(DT_STRING)), {DT_STRING}, {DT_STRING},
+ R"proto(
+ op: "Restrict" input: "a"
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ ExpectInvalid(Builder().Input(FakeInput(DT_INT32)),
+ "Value for attr 'T' of int32 is not in the list of allowed "
+ "values: string, bool");
+}
+
+TEST_F(NodeDefBuilderTest, TypeList) {
+ Op(OpDefBuilder("TypeList").Input("a: T").Attr("T: list(type)"));
+
+ ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_INT32})),
+ {DT_STRING, DT_INT32}, {}, R"proto(
+ op: "TypeList" input: ["a", "a:1"]
+ attr { key: "T" value { list { type: [DT_STRING, DT_INT32] } } }
+ )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(3, DT_BOOL)),
+ {DT_BOOL, DT_BOOL, DT_BOOL}, {}, R"proto(
+ op: "TypeList" input: ["a", "a:1", "a:2"]
+ attr { key: "T" value { list { type: [DT_BOOL, DT_BOOL, DT_BOOL] } } }
+ )proto");
+
+ ExpectInvalid(Builder().Input(FakeInput(0)),
+ "Length for attr 'T' of 0 must be at least minimum 1");
+
+ ExpectInvalid(Builder().Input(FakeInput({})),
+ "Length for attr 'T' of 0 must be at least minimum 1");
+
+ ExpectInvalid(Builder().Input(FakeInput(DT_BOOL)),
+ "Single tensor passed to 'a', expected list while");
+
+ ExpectFailures(Builder().Input(FakeInput()),
+ {"2 errors while building NodeDef",
+ "Could not infer list of types for input 'a': "
+ "No attr named 'T' in NodeDef:",
+ "0 inputs specified of 1 inputs in Op"});
+}
+
+TEST_F(NodeDefBuilderTest, TypeListNoMin) {
+ Op(OpDefBuilder("TypeListNoMin").Input("a: T").Attr("T: list(type) >= 0"));
+
+ ExpectSuccess(Builder().Input(FakeInput(0)), {}, {}, R"proto(
+ op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(DataTypeVector())), {}, {}, R"proto(
+ op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput({})), {}, {}, R"proto(
+ op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput({DT_BOOL})), {DT_BOOL}, {}, R"proto(
+ op: "TypeListNoMin" input: "a"
+ attr { key: "T" value { list { type: DT_BOOL } } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, TypeListTwice) {
+ Op(OpDefBuilder("TypeListTwice")
+ .Input("a: T")
+ .Input("b: T")
+ .Attr("T: list(type) >= 0"));
+
+ ExpectSuccess(Builder()
+ .Input(FakeInput({DT_INT32, DT_BOOL}))
+ .Input(FakeInput({DT_INT32, DT_BOOL})),
+ {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto(
+ op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"]
+ attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto");
+
+ ExpectSuccess(
+ Builder().Input(FakeInput({DT_INT32, DT_BOOL})).Input(FakeInput()),
+ {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto(
+ op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"]
+ attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(0)), {}, {},
+ R"proto(
+ op: "TypeListTwice" attr { key: "T" value { list { } } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {},
+ R"proto(
+ op: "TypeListTwice" attr { key: "T" value { list { } } } )proto");
+
+ ExpectFailure(Builder()
+ .Input(FakeInput({DT_INT32, DT_BOOL}))
+ .Input(FakeInput({DT_INT32, DT_STRING})),
+ "Inconsistent values for attr 'T' [DT_INT32, DT_BOOL] vs. "
+ "[DT_INT32, DT_STRING] while");
+}
+
+TEST_F(NodeDefBuilderTest, OutTypeList) {
+ Op(OpDefBuilder("OutTypeList").Output("out: T").Attr("T: list(type) >= 0"));
+
+ ExpectSuccess(Builder().Attr("T", {DT_FLOAT}), {}, {DT_FLOAT}, R"proto(
+ op: "OutTypeList"
+ attr { key: "T" value { list { type: DT_FLOAT } } } )proto");
+
+ ExpectSuccess(Builder().Attr("T", {DT_STRING, DT_BOOL}), {},
+ {DT_STRING, DT_BOOL}, R"proto(
+ op: "OutTypeList"
+ attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto");
+
+ ExpectSuccess(Builder().Attr("T", DataTypeVector()), {}, {}, R"proto(
+ op: "OutTypeList"
+ attr { key: "T" value { list { } } } )proto");
+
+ ExpectInvalid(Builder().Attr("T", DT_FLOAT),
+ "AttrValue had value with type type when list(type) expected");
+}
+
+TEST_F(NodeDefBuilderTest, TypeListRestrict) {
+ Op(OpDefBuilder("TypeListRestrict")
+ .Input("a: T")
+ .Attr("T: list({string, bool}) >= 0"));
+
+ ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_BOOL})),
+ {DT_STRING, DT_BOOL}, {}, R"proto(
+ op: "TypeListRestrict" input: ["a", "a:1"]
+ attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto");
+
+ ExpectInvalid(Builder().Input(FakeInput({DT_STRING, DT_INT32})),
+ "Value for attr 'T' of int32 is not in the list of allowed "
+ "values: string, bool");
+}
+
+TEST_F(NodeDefBuilderTest, OutTypeListRestrict) {
+ Op(OpDefBuilder("OutTypeListRestrict")
+ .Output("out: t")
+ .Attr("t: list({string, bool}) >= 0"));
+
+ ExpectSuccess(Builder().Attr("t", {DT_BOOL, DT_STRING}), {},
+ {DT_BOOL, DT_STRING}, R"proto(
+ op: "OutTypeListRestrict"
+ attr { key: "t" value { list { type: [DT_BOOL, DT_STRING] } } } )proto");
+
+ ExpectInvalid(Builder().Attr("t", {DT_STRING, DT_INT32}),
+ "Value for attr 't' of int32 is not in the list of allowed "
+ "values: string, bool");
+}
+
+TEST_F(NodeDefBuilderTest, Attr) {
+ Op(OpDefBuilder("Attr").Attr("a: int"));
+
+ ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto(
+ op: "Attr" attr { key: "a" value { i: 12 } } )proto");
+
+ // Attr has wrong type
+ ExpectInvalid(Builder().Attr("a", "bad"),
+ "AttrValue had value with type string when int expected");
+
+ ExpectInvalid(Builder().Attr("a", {12}),
+ "AttrValue had value with type list(int) when int expected");
+
+ // Missing attr
+ ExpectInvalid(Builder(), "NodeDef missing attr 'a' from Op<");
+
+ // Wrong attr
+ ExpectInvalid(Builder().Attr("b", 12),
+ "NodeDef mentions attr 'b' not in Op<");
+
+ // Extra attr
+ ExpectInvalid(Builder().Attr("a", 12).Attr("extra", 12),
+ "NodeDef mentions attr 'extra' not in Op<");
+}
+
+TEST_F(NodeDefBuilderTest, AttrFloat) {
+ Op(OpDefBuilder("AttrFloat").Attr("a: float"));
+
+ ExpectSuccess(Builder().Attr("a", 1.2f /* float */), {}, {}, R"proto(
+ op: "AttrFloat" attr { key: "a" value { f: 1.2 } }
+ )proto");
+
+ ExpectSuccess(Builder().Attr("a", 1.2 /* double */), {}, {}, R"proto(
+ op: "AttrFloat" attr { key: "a" value { f: 1.2 } }
+ )proto");
+
+ // Won't automatically cast int to float
+ ExpectInvalid(Builder().Attr("a", 12),
+ "AttrValue had value with type int when float expected");
+}
+
+TEST_F(NodeDefBuilderTest, AttrBoolList) {
+ Op(OpDefBuilder("AttrBoolList").Attr("a: list(bool)"));
+
+ ExpectSuccess(Builder().Attr("a", {true, false, true}), {}, {}, R"proto(
+ op: "AttrBoolList"
+ attr { key: "a" value { list { b: [true, false, true] } } }
+ )proto");
+
+ ExpectSuccess(Builder().Attr("a", std::vector<bool>()), {}, {}, R"proto(
+ op: "AttrBoolList" attr { key: "a" value { list { } } }
+ )proto");
+
+ // Won't cast int -> bool.
+ ExpectInvalid(Builder().Attr("a", {0}),
+ "AttrValue had value with type list(int) when list(bool) "
+ "expected");
+}
+
+TEST_F(NodeDefBuilderTest, AttrMin) {
+ Op(OpDefBuilder("AttrMin").Attr("a: int >= 5"));
+
+ ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto(
+ op: "AttrMin" attr { key: "a" value { i: 12 } } )proto");
+
+ ExpectInvalid(Builder().Attr("a", 2),
+ "Value for attr 'a' of 2 must be at least minimum 5");
+}
+
+TEST_F(NodeDefBuilderTest, AttrListMin) {
+ Op(OpDefBuilder("AttrListMin").Attr("a: list(int) >= 2"));
+
+ ExpectSuccess(Builder().Attr("a", {1, 2}), {}, {}, R"proto(
+ op: "AttrListMin"
+ attr { key: "a" value { list { i: [1, 2] } } } )proto");
+
+ ExpectInvalid(Builder().Attr("a", {17}),
+ "Length for attr 'a' of 1 must be at least minimum 2");
+}
+
+TEST_F(NodeDefBuilderTest, AttrEnum) {
+ Op(OpDefBuilder("AttrEnum").Attr("a: {'apples', 'oranges'}"));
+
+ ExpectSuccess(Builder().Attr("a", "oranges"), {}, {}, R"proto(
+ op: "AttrEnum"
+ attr { key: "a" value { s: "oranges" } } )proto");
+
+ ExpectInvalid(
+ Builder().Attr("a", "invalid"),
+ "Value for attr 'a' of \"invalid\" is not in the list of allowed values: "
+ "\"apples\", \"oranges\"");
+}
+
+TEST_F(NodeDefBuilderTest, AttrEnumList) {
+ Op(OpDefBuilder("AttrEnumList").Attr("a: list({'apples', 'oranges'})"));
+
+ ExpectSuccess(Builder().Attr("a", {"oranges", "apples"}), {}, {}, R"proto(
+ op: "AttrEnumList"
+ attr { key: "a" value { list { s: ["oranges", "apples"] } } } )proto");
+
+ ExpectInvalid(
+ Builder().Attr("a", {"apples", "invalid", "oranges"}),
+ "Value for attr 'a' of \"invalid\" is not in the list of allowed values: "
+ "\"apples\", \"oranges\"");
+}
+
+TEST_F(NodeDefBuilderTest, AttrShape) {
+ Op(OpDefBuilder("AttrShape").Attr("a: shape"));
+
+ ExpectSuccess(Builder().Attr("a", TensorShape({5})), {}, {}, R"proto(
+ op: "AttrShape"
+ attr { key: "a" value { shape { dim { size: 5 } } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", TensorShape({4, 3, 2})), {}, {}, R"proto(
+ op: "AttrShape"
+ attr { key: "a" value { shape {
+ dim { size: 4 } dim { size: 3 } dim { size: 2 } } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", TensorShape({3, 2})), {}, {},
+ R"proto(
+ op: "AttrShape"
+ attr { key: "a" value { shape {
+ dim { size: 3 } dim { size: 2 } } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", TensorShape()), {}, {}, R"proto(
+ op: "AttrShape"
+ attr { key: "a" value { shape { } } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, AttrDefault) {
+ Op(OpDefBuilder("AttrDefault").Attr("a: string = 'banana'"));
+
+ ExpectSuccess(Builder(), {}, {}, R"proto(
+ op: "AttrDefault"
+ attr { key: "a" value { s: "banana" } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", "kiwi"), {}, {}, R"proto(
+ op: "AttrDefault"
+ attr { key: "a" value { s: "kiwi" } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, AttrManyDefault) {
+ Op(OpDefBuilder("AttrManyDefault")
+ .Attr("a: string = 'banana'")
+ .Attr("b: string = 'kiwi'"));
+
+ ExpectSuccess(Builder(), {}, {}, R"proto(
+ op: "AttrManyDefault"
+ attr { key: "a" value { s: "banana" } }
+ attr { key: "b" value { s: "kiwi" } } )proto");
+
+ Op(OpDefBuilder("AttrManyDefaultWithMandatory")
+ .Attr("a: string = 'banana'")
+ .Attr("b: string = 'kiwi'")
+ .Attr("c: string"));
+
+ ExpectSuccess(Builder().Attr("c", "strawberry"), {}, {}, R"proto(
+ op: "AttrManyDefaultWithMandatory"
+ attr { key: "c" value { s: "strawberry" } }
+ attr { key: "a" value { s: "banana" } }
+ attr { key: "b" value { s: "kiwi" } } )proto");
+
+ Op(OpDefBuilder("AttrManyDefaultAndInferred")
+ .Input("input: T")
+ .Attr("T: {float, double}")
+ .Attr("a: string")
+ .Attr("b: list(string) >= 1")
+ .Attr("c: bool = true")
+ .Attr("d: float = 0.3")
+ .Attr("e: string")
+ .Attr("f: float = 0.25"));
+
+ ExpectSuccess(Builder()
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("a", "foo")
+ .Attr("e", "foo")
+ .Attr("b", std::vector<string>({"bar", "baz"}))
+ .Attr("f", 1.0f),
+ {DT_FLOAT}, {}, R"proto(
+ op: "AttrManyDefaultAndInferred"
+ input: "a"
+ attr { key: "T" value { type: DT_FLOAT } }
+ attr { key: "a" value { s: "foo" } }
+ attr { key: "e" value { s: "foo" } }
+ attr { key: "b" value { list { s: "bar" s: "baz" } } }
+ attr { key: "f" value { f: 1.0 } }
+ attr { key: "c" value { b: true } }
+ attr { key: "d" value { f: 0.3 } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, AttrListDefault) {
+ Op(OpDefBuilder("AttrListDefault").Attr("a: list(int) = [5, 15]"));
+
+ ExpectSuccess(Builder(), {}, {}, R"proto(
+ op: "AttrListDefault"
+ attr { key: "a" value { list { i: [5, 15] } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto(
+ op: "AttrListDefault"
+ attr { key: "a" value { list { i: 3 } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto(
+ op: "AttrListDefault"
+ attr { key: "a" value { list { } } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, AttrEmptyListDefault) {
+ Op(OpDefBuilder("AttrEmptyListDefault").Attr("a: list(int) = []"));
+
+ ExpectSuccess(Builder(), {}, {}, R"proto(
+ op: "AttrEmptyListDefault"
+ attr { key: "a" value { list { } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto(
+ op: "AttrEmptyListDefault"
+ attr { key: "a" value { list { i: 3 } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto(
+ op: "AttrEmptyListDefault"
+ attr { key: "a" value { list { } } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, NIntsIn) {
+ Op(OpDefBuilder("NIntsIn").Input("a: N*int32").Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Input(FakeInput(2)), {DT_INT32, DT_INT32}, {},
+ R"proto(
+ op: "NIntsIn" input: ["a", "a:1"]
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(5, DT_INT32)),
+ {DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
+ op: "NIntsIn"
+ input: ["a", "a:1", "a:2", "a:3", "a:4"]
+ attr { key: "N" value { i: 5 } } )proto");
+
+ ExpectFailures(Builder().Input(FakeInput(2, DT_STRING)),
+ {"2 errors while building NodeDef",
+ "Input 'a' passed string expected int32"});
+
+ ExpectInvalid(Builder().Input(FakeInput(1)),
+ "Value for attr 'N' of 1 must be at least minimum 2");
+
+ ExpectFailures(
+ Builder().Input(FakeInput(DT_INT32)),
+ {"2 errors while building NodeDef",
+ "Could not infer length of input 'a': No attr named 'N' in NodeDef:",
+ "0 inputs specified of 1 inputs in Op"});
+
+ ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}),
+ "Input 'a' passed string expected int32 while");
+
+ ExpectFailures(
+ Builder().Input(FakeInput()),
+ {"2 errors while building NodeDef",
+ "Could not infer length of input 'a': No attr named 'N' in NodeDef:",
+ "0 inputs specified of 1 inputs in Op"});
+}
+
+TEST_F(NodeDefBuilderTest, NPolymorphicIn) {
+ Op(OpDefBuilder("NPolymorphicIn")
+ .Input("a: N*T")
+ .Attr("T: type")
+ .Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)), {DT_INT32, DT_INT32},
+ {}, R"proto(
+ op: "NPolymorphicIn" input: ["a", "a:1"]
+ attr { key: "N" value { i: 2 } }
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)),
+ {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto(
+ op: "NPolymorphicIn"
+ input: ["a", "a:1", "a:2"]
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ ExpectFailures(
+ Builder().Input(FakeInput(2)),
+ {"2 errors while building NodeDef",
+ "Could not infer type for input 'a': No attr named 'T' in NodeDef:",
+ "0 inputs specified of 1 inputs in Op"});
+
+ ExpectFailure(Builder().Input(FakeInput({DT_INT32, DT_STRING})),
+ "Input 'a' passed string expected int32 while");
+
+ ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}),
+ "Input 'a' passed string expected int32 while");
+
+ ExpectInvalid(Builder().Input(FakeInput(1, DT_INT32)),
+ "Value for attr 'N' of 1 must be at least minimum 2");
+
+ ExpectFailure(Builder().Input("in", 0, DT_INT32),
+ "Single tensor passed to 'a', expected list while");
+}
+
+TEST_F(NodeDefBuilderTest, NPolymorphicRestrictIn) {
+ Op(OpDefBuilder("NPolymorphicRestrictIn")
+ .Input("a: N*T")
+ .Attr("T: {string, bool}")
+ .Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Input(FakeInput(2, DT_BOOL)), {DT_BOOL, DT_BOOL}, {},
+ R"proto(
+ op: "NPolymorphicRestrictIn" input: ["a", "a:1"]
+ attr { key: "N" value { i: 2 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)),
+ {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto(
+ op: "NPolymorphicRestrictIn"
+ input: ["a", "a:1", "a:2"]
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ ExpectInvalid(Builder().Input(FakeInput(2, DT_INT32)),
+ "Value for attr 'T' of int32 is not in the list of allowed "
+ "values: string, bool");
+}
+
+TEST_F(NodeDefBuilderTest, NInTwice) {
+ Op(OpDefBuilder("NInTwice")
+ .Input("a: N*int32")
+ .Input("b: N*string")
+ .Attr("N: int >= 0"));
+
+ ExpectSuccess(Builder().Input(FakeInput(2)).Input(FakeInput(2)),
+ {DT_INT32, DT_INT32, DT_STRING, DT_STRING}, {}, R"proto(
+ op: "NInTwice"
+ input: ["a", "a:1", "b", "b:1"]
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {},
+ R"proto(
+ op: "NInTwice" attr { key: "N" value { i: 0 } } )proto");
+
+ ExpectFailure(Builder().Input(FakeInput(3)).Input(FakeInput(1)),
+ "Inconsistent values for attr 'N' 3 vs. 1 while");
+}
+
+TEST_F(NodeDefBuilderTest, NInPolymorphicTwice) {
+ Op(OpDefBuilder("NInPolymorphicTwice")
+ .Input("a: N*T")
+ .Input("b: N*T")
+ .Attr("T: type")
+ .Attr("N: int >= 0"));
+
+ ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput()),
+ {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
+ op: "NInPolymorphicTwice"
+ input: ["a", "a:1", "b", "b:1"]
+ attr { key: "N" value { i: 2 } }
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+
+ ExpectFailure(
+ Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_INT32)),
+ "Inconsistent values for attr 'N' 3 vs. 1 while");
+
+ ExpectFailure(Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1)),
+ "Inconsistent values for attr 'N' 3 vs. 1 while");
+
+ ExpectFailure(
+ Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)),
+ "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");
+
+ ExpectFailure(
+ Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_STRING)),
+ "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");
+}
+
+TEST_F(NodeDefBuilderTest, NInTwoTypeVariables) {
+ Op(OpDefBuilder("NInTwoTypeVariables")
+ .Input("a: N*S")
+ .Input("b: N*T")
+ .Attr("S: type")
+ .Attr("T: type")
+ .Attr("N: int >= 0"));
+
+ ExpectSuccess(
+ Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_BOOL)),
+ {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto(
+ op: "NInTwoTypeVariables"
+ input: ["a", "a:1", "b", "b:1"]
+ attr { key: "N" value { i: 2 } }
+ attr { key: "S" value { type: DT_INT32 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectSuccess(
+ Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_BOOL)),
+ {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto(
+ op: "NInTwoTypeVariables"
+ input: ["a", "a:1", "b", "b:1"]
+ attr { key: "N" value { i: 2 } }
+ attr { key: "S" value { type: DT_INT32 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectFailure(
+ Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_STRING)),
+ "Inconsistent values for attr 'N' 3 vs. 1 while");
+}
+
+TEST_F(NodeDefBuilderTest, InPolymorphicTwice) {
+ Op(OpDefBuilder("InPolymorphicTwice")
+ .Input("a: N*T")
+ .Input("b: M*T")
+ .Attr("T: type")
+ .Attr("N: int >= 0")
+ .Attr("M: int >= 0"));
+
+ ExpectSuccess(
+ Builder().Input(FakeInput(1, DT_INT32)).Input(FakeInput(3, DT_INT32)),
+ {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
+ op: "InPolymorphicTwice"
+ input: ["a", "b", "b:1", "b:2"]
+ attr { key: "N" value { i: 1 } }
+ attr { key: "T" value { type: DT_INT32 } }
+ attr { key: "M" value { i: 3 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(1, DT_BOOL)).Input(FakeInput(0)),
+ {DT_BOOL}, {}, R"proto(
+ op: "InPolymorphicTwice" input: "a"
+ attr { key: "N" value { i: 1 } }
+ attr { key: "T" value { type: DT_BOOL } }
+ attr { key: "M" value { i: 0 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(1, DT_BOOL)),
+ {DT_BOOL}, {}, R"proto(
+ op: "InPolymorphicTwice" input: "b"
+ attr { key: "N" value { i: 0 } }
+ attr { key: "M" value { i: 1 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectFailure(
+ Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)),
+ "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");
+}
+
+TEST_F(NodeDefBuilderTest, NIntsOut) {
+ Op(OpDefBuilder("NIntsOut").Output("a: N*int32").Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto(
+ op: "NIntsOut"
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Attr("N", 3), {}, {DT_INT32, DT_INT32, DT_INT32},
+ R"proto(
+ op: "NIntsOut"
+ attr { key: "N" value { i: 3 } } )proto");
+
+ ExpectInvalid(Builder().Attr("N", 1),
+ "Value for attr 'N' of 1 must be at least minimum 2");
+
+ ExpectInvalid(Builder().Attr("N", {3}),
+ "AttrValue had value with type list(int) when int expected");
+
+ ExpectInvalid(Builder(), "NodeDef missing attr 'N' from");
+}
+
+TEST_F(NodeDefBuilderTest, NIntsOutDefault) {
+ Op(OpDefBuilder("NIntsOutDefault")
+ .Output("a: N*int32")
+ .Attr("N: int >= 2 = 3"));
+
+ ExpectSuccess(Builder(), {}, {DT_INT32, DT_INT32, DT_INT32}, R"proto(
+ op: "NIntsOutDefault"
+ attr { key: "N" value { i: 3 } } )proto");
+
+ ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto(
+ op: "NIntsOutDefault"
+ attr { key: "N" value { i: 2 } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, NPolymorphicOut) {
+ Op(OpDefBuilder("NPolymorphicOut")
+ .Output("a: N*T")
+ .Attr("T: type")
+ .Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Attr("T", DT_INT32).Attr("N", 2), {},
+ {DT_INT32, DT_INT32}, R"proto(
+ op: "NPolymorphicOut"
+ attr { key: "T" value { type: DT_INT32 } }
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_STRING), {},
+ {DT_STRING, DT_STRING, DT_STRING}, R"proto(
+ op: "NPolymorphicOut"
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ ExpectInvalid(Builder().Attr("N", 1).Attr("T", DT_STRING),
+ "Value for attr 'N' of 1 must be at least minimum 2");
+
+ ExpectInvalid(Builder().Attr("N", 3).Attr("T", {DT_STRING}),
+ "AttrValue had value with type list(type) when type expected");
+}
+
+TEST_F(NodeDefBuilderTest, NPolymorphicOutDefault) {
+ Op(OpDefBuilder("NPolymorphicOutDefault")
+ .Output("a: N*T")
+ .Attr("T: type = DT_BOOL")
+ .Attr("N: int >= 2 = 2"));
+
+ ExpectSuccess(Builder(), {}, {DT_BOOL, DT_BOOL}, R"proto(
+ op: "NPolymorphicOutDefault"
+ attr { key: "T" value { type: DT_BOOL } }
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Attr("N", 3), {}, {DT_BOOL, DT_BOOL, DT_BOOL},
+ R"proto(
+ op: "NPolymorphicOutDefault"
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32, DT_INT32},
+ R"proto(
+ op: "NPolymorphicOutDefault"
+ attr { key: "T" value { type: DT_INT32 } }
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_INT32), {},
+ {DT_INT32, DT_INT32, DT_INT32}, R"proto(
+ op: "NPolymorphicOutDefault"
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, NPolymorphicRestrictOut) {
+ Op(OpDefBuilder("NPolymorphicRestrictOut")
+ .Output("a: N*T")
+ .Attr("T: {string, bool}")
+ .Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_BOOL), {},
+ {DT_BOOL, DT_BOOL, DT_BOOL}, R"proto(
+ op: "NPolymorphicRestrictOut"
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectInvalid(Builder().Attr("N", 3).Attr("T", DT_INT32),
+ "Value for attr 'T' of int32 is not in the list of allowed "
+ "values: string, bool");
+}
+
+TEST_F(NodeDefBuilderTest, RefIn) {
+ Op(OpDefBuilder("RefIn").Input("a: Ref(int32)"));
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32_REF}, {},
+ R"proto(
+ op: "RefIn" input: "a" )proto");
+
+ ExpectFailure(Builder().Input(FakeInput(DT_BOOL_REF)),
+ "Input 'a' passed bool_ref expected int32_ref while");
+
+ ExpectFailure(Builder().Input(FakeInput(DT_INT32)),
+ "Input 'a' passed int32 expected int32_ref while");
+}
+
+TEST_F(NodeDefBuilderTest, PolymorphicRefIn) {
+ Op(OpDefBuilder("PolymorphicRefIn").Input("a: Ref(T)").Attr("T: type"));
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_BOOL_REF)), {DT_BOOL_REF}, {},
+ R"proto(
+ op: "PolymorphicRefIn" input: "a"
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectFailure(Builder().Input(FakeInput(DT_BOOL)),
+ "Input 'a' passed bool expected ref type while");
+}
+
+TEST_F(NodeDefBuilderTest, RefOut) {
+ Op(OpDefBuilder("RefOut").Output("a: Ref(string)"));
+
+ ExpectSuccess(Builder(), {}, {DT_STRING_REF}, R"proto(
+ op: "RefOut" )proto");
+}
+
+TEST_F(NodeDefBuilderTest, PolymorphicRefOut) {
+ Op(OpDefBuilder("PolymorphicRefOut").Output("a: Ref(t)").Attr("t: type"));
+
+ ExpectSuccess(Builder().Attr("t", DT_BOOL), {}, {DT_BOOL_REF}, R"proto(
+ op: "PolymorphicRefOut"
+ attr { key: "t" value { type: DT_BOOL } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, SpecifyDevice) {
+ Op(OpDefBuilder("SpecifyDevice"));
+
+ ExpectSuccess(Builder().Device("ADevice"), {}, {}, R"proto(
+ op: "SpecifyDevice" device: "ADevice" )proto");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
new file mode 100644
index 0000000000..aefd416187
--- /dev/null
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -0,0 +1,414 @@
+#include "tensorflow/core/framework/node_def_util.h"
+
+#include <algorithm>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+string SummarizeNodeDef(const NodeDef& node_def) {
+ string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
+
+ // We sort the attrs so the output is deterministic.
+ std::vector<string> attr_names;
+ attr_names.reserve(node_def.attr().size());
+ for (const auto& attr : node_def.attr()) {
+ attr_names.push_back(attr.first);
+ }
+ std::sort(attr_names.begin(), attr_names.end());
+ bool first = true;
+ for (const string& attr_name : attr_names) {
+ if (!first) strings::StrAppend(&ret, ", ");
+ first = false;
+ auto iter = node_def.attr().find(attr_name);
+ strings::StrAppend(&ret, attr_name, "=", SummarizeAttrValue(iter->second));
+ }
+
+ // Consider the device to be a final attr with name "_device".
+ if (!node_def.device().empty()) {
+ if (!first) strings::StrAppend(&ret, ", ");
+ first = false;
+ strings::StrAppend(&ret, "_device=\"", node_def.device(), "\"");
+ }
+ strings::StrAppend(&ret, "](");
+
+ // Output inputs, including control inputs, verbatim.
+ first = true;
+ for (const string& input : node_def.input()) {
+ if (!first) strings::StrAppend(&ret, ", ");
+ first = false;
+ strings::StrAppend(&ret, input);
+ }
+ strings::StrAppend(&ret, ")");
+ return ret;
+}
+
+const AttrValue* AttrSlice::Find(const string& attr_name) const {
+ auto iter = attrs_->find(attr_name);
+ if (iter == attrs_->end()) return nullptr;
+ return &iter->second;
+}
+
+Status AttrSlice::Find(const string& attr_name,
+ const AttrValue** attr_value) const {
+ *attr_value = Find(attr_name);
+ if (*attr_value != nullptr) {
+ return Status::OK();
+ }
+ Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
+ if (ndef_) {
+ s = AttachDef(s, *ndef_);
+ }
+ return s;
+}
+
+// The trailing ... lets the caller inject extra value-validation code. Pass
+// just ; if no additional validation is needed.
+#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
+ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \
+ TYPE* value) { \
+ const AttrValue* attr_value; \
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE)); \
+ const auto& v = attr_value->FIELD(); \
+ __VA_ARGS__; \
+ *value = CAST; \
+ return Status::OK(); \
+ } \
+ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \
+ std::vector<TYPE>* value) { \
+ const AttrValue* attr_value; \
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
+ for (const auto& v : attr_value->list().FIELD()) { \
+ __VA_ARGS__; \
+ value->APPEND_OP(CAST); \
+ } \
+ return Status::OK(); \
+ }
+
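+// Each DEFINE_GET_ATTR invocation below defines both the scalar and the
+// list overload of GetNodeAttr() declared in node_def_util.h; e.g. the
+// int64 instantiation yields GetNodeAttr(..., int64*) and
+// GetNodeAttr(..., std::vector<int64>*).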
+DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
+DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
+DEFINE_GET_ATTR(int32, i, "int", emplace_back, static_cast<int32>(v),
+ if (static_cast<int64>(static_cast<int32>(v)) != v) {
+ return errors::InvalidArgument("Attr ", attr_name,
+ " has value ", v,
+ " out of range for an int32");
+ })
+DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
+// The std::vector<bool> specialization does not have emplace_back until
+// C++14, so we have to use push_back instead (see
+// http://en.cppreference.com/w/cpp/container/vector/emplace_back).
+DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
+DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
+ ;)
+DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
+DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v), ;)
+DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t;
+ if (!t.FromProto(v)) {
+ return errors::InvalidArgument(
+ "Attr ", attr_name, " has value ", v.ShortDebugString(),
+ " that can't be converted to a Tensor");
+ })
+
+#undef DEFINE_GET_ATTR
+
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ DataTypeVector* value) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
+ for (const auto& v : attr_value->list().type()) {
+ value->push_back(static_cast<DataType>(v));
+ }
+ return Status::OK();
+}
+
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const TensorProto** value) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
+ *value = &attr_value->tensor();
+ return Status::OK();
+}
+
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const NameAttrList** value) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
+ *value = &attr_value->func();
+ return Status::OK();
+}
+
+namespace { // Helper for InOutTypesForNode().
+
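+// Appends the concrete DataTypes for one input or output arg to *sig.
+// For example (illustrative), an arg "a: N*T" on a node with attrs N=3 and
+// T=DT_FLOAT contributes [DT_FLOAT, DT_FLOAT, DT_FLOAT].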
+Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
+ DataTypeVector* sig) {
+ const int original_size = sig->size();
+ if (!arg_def.number_attr().empty()) {
+ // Same type repeated "repeats" times.
+ int32 repeats = -1;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.number_attr(), &repeats));
+ if (repeats < 0) {
+ return errors::InvalidArgument("Value for number_attr() ", repeats,
+ " < 0");
+ }
+
+ if (!arg_def.type_attr().empty()) {
+ DataType dtype;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.type_attr(), &dtype));
+ for (int i = 0; i < repeats; ++i) {
+ sig->push_back(dtype);
+ }
+ } else if (arg_def.type() != DT_INVALID) {
+ for (int i = 0; i < repeats; ++i) {
+ sig->push_back(arg_def.type());
+ }
+ } else {
+ return errors::InvalidArgument("Missing type or type_attr field in ",
+ arg_def.ShortDebugString());
+ }
+ } else if (!arg_def.type_attr().empty()) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(
+ AttrSlice(node_def).Find(arg_def.type_attr(), &attr_value));
+ sig->push_back(attr_value->type());
+ } else if (!arg_def.type_list_attr().empty()) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(
+ AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
+ for (int dtype : attr_value->list().type()) {
+ sig->push_back(static_cast<DataType>(dtype));
+ }
+ } else if (arg_def.type() != DT_INVALID) {
+ sig->push_back(arg_def.type());
+ } else {
+ return errors::InvalidArgument("No type fields in ",
+ arg_def.ShortDebugString());
+ }
+ if (arg_def.is_ref()) {
+ // For all types that were added by this function call, make them refs.
+ for (size_t i = original_size; i < sig->size(); ++i) {
+ (*sig)[i] = MakeRefType((*sig)[i]);
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs, DataTypeVector* outputs) {
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
+ }
+ for (const auto& arg : op_def.output_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
+ }
+ return Status::OK();
+}
+
+Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
+ if (node_def.op() != op_def.name()) {
+ return errors::InvalidArgument("NodeDef op '", node_def.op(),
+ "' does not match ", SummarizeOpDef(op_def),
+ "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+
+ bool seen_control = false;
+ size_t num_inputs = 0;
+ // TODO(josh11b): Unify the input field validation.
+ for (const string& input : node_def.input()) {
+ if (StringPiece(input).starts_with("^")) {
+ seen_control = true;
+ if (input.find(':') != string::npos) {
+ return errors::InvalidArgument("Control input '", input,
+ "' must not have ':' in NodeDef: ",
+ SummarizeNodeDef(node_def));
+ }
+ } else if (seen_control) {
+ return errors::InvalidArgument("Non-control input '", input,
+ "' after control input in NodeDef: ",
+ SummarizeNodeDef(node_def));
+ } else {
+ ++num_inputs;
+ }
+ }
+
+ std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
+ for (const auto& attr : op_def.attr()) {
+ if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
+ return errors::InvalidArgument("OpDef has duplicate attr name '",
+ attr.name(), "': ",
+ SummarizeOpDef(op_def));
+ }
+ }
+ for (const auto& attr : node_def.attr()) {
+ // Allow internal optional attributes with names starting with "_".
+ if (StringPiece(attr.first).starts_with("_")) {
+ continue;
+ }
+ auto iter = op_attrs.find(attr.first);
+ if (iter == op_attrs.end()) {
+ return errors::InvalidArgument("NodeDef mentions attr '", attr.first,
+ "' not in ", SummarizeOpDef(op_def),
+ "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ ValidateAttrValue(attr.second, *iter->second), "; NodeDef: ",
+ SummarizeNodeDef(node_def), "; ", SummarizeOpDef(op_def));
+ // Keep track of which attr names have (not) been found in the NodeDef.
+ op_attrs.erase(iter);
+ }
+
+ // Were all attrs in the OpDef found in the NodeDef?
+ if (!op_attrs.empty()) {
+ string attrs;
+ for (const auto& attr_pair : op_attrs) {
+ if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
+ strings::StrAppend(&attrs, attr_pair.first);
+ }
+ return errors::InvalidArgument("NodeDef missing attr",
+ op_attrs.size() == 1 ? " '" : "s '", attrs,
+ "' from ", SummarizeOpDef(op_def),
+ "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+
+ // Validate the number of inputs.
+ DataTypeVector inputs, outputs;
+ TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
+
+ if (num_inputs != inputs.size()) {
+ return errors::InvalidArgument(
+ "NodeDef expected inputs '", DataTypeVectorString(inputs),
+ "' do not match ", num_inputs, " inputs specified; ",
+ SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+
+ return Status::OK();
+}
+
+namespace { // Helpers for NameRangesForNode()
+
+Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
+ const OpDef& op_def, int* num) {
+ if (!arg_def.number_attr().empty()) {
+ // Same type repeated "num" times.
+ return GetNodeAttr(node_def, arg_def.number_attr(), num);
+ } else if (!arg_def.type_list_attr().empty()) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(
+ AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
+ *num = attr_value->list().type_size();
+ } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
+ *num = 1;
+ } else {
+ return errors::InvalidArgument("Argument '", arg_def.name(),
+ "' incorrectly specified in op definition: ",
+ SummarizeOpDef(op_def));
+ }
+ return Status::OK();
+}
+
+Status NameRangesHelper(const NodeDef& node_def,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
+ const OpDef& op_def, NameRangeMap* result) {
+ int start = 0;
+ int num;
+ for (const auto& arg : args) {
+ TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num));
+ (*result)[arg.name()] = std::make_pair(start, start + num);
+ start += num;
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
+ NameRangeMap* inputs, NameRangeMap* outputs) {
+ TF_RETURN_IF_ERROR(
+ NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs));
+ return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs);
+}
+
+void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
+ for (const auto& attr_def : op_def.attr()) {
+ AttrSlice attrs(*node_def);
+ if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
+ AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
+ }
+ }
+}
+
+namespace {
+
+static RE2* valid_op_name_pattern = new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
+static RE2* valid_data_input_pattern =
+ new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*(\\:(0|([1-9][0-9]*)))?");
+static RE2* valid_control_input_pattern =
+ new RE2("\\^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
+
+} // namespace
+
+Status ValidateOpInput(const string& input_name, bool* is_control_input) {
+ *is_control_input = false;
+ if (RE2::FullMatch(input_name, *valid_data_input_pattern)) {
+ return Status::OK();
+ } else if (RE2::FullMatch(input_name, *valid_control_input_pattern)) {
+ *is_control_input = true;
+ return Status::OK();
+ } else {
+ return errors::InvalidArgument("Illegal op input name '", input_name, "'");
+ }
+}
+
+Status ValidateOpName(const string& op_name) {
+ if (RE2::FullMatch(op_name, *valid_op_name_pattern)) {
+ return Status::OK();
+ } else {
+ return errors::InvalidArgument("Illegal op name '", op_name, "'");
+ }
+}
+
+Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
+ Status s = ValidateOpName(node_def.name());
+ if (!s.ok()) {
+ return AttachDef(s, node_def);
+ }
+ bool in_control_inputs = false;
+ for (const string& input_name : node_def.input()) {
+ bool is_control_input;
+ s = ValidateOpInput(input_name, &is_control_input);
+ if (!s.ok()) {
+ return AttachDef(s, node_def);
+ }
+
+ if (in_control_inputs && !is_control_input) {
+ return AttachDef(errors::InvalidArgument(
+ "All control inputs must follow all data inputs"),
+ node_def);
+ }
+ in_control_inputs = is_control_input;
+ }
+ return Status::OK();
+}
+
+Status AttachDef(const Status& status, const NodeDef& node_def) {
+ Status ret = status;
+ errors::AppendToMessage(
+ &ret, strings::StrCat(" [[Node: ", SummarizeNodeDef(node_def), "]]"));
+ return ret;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
new file mode 100644
index 0000000000..fce6fd2433
--- /dev/null
+++ b/tensorflow/core/framework/node_def_util.h
@@ -0,0 +1,157 @@
+#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// Produce a human-readable version of a NodeDef that is more concise
+// than a text-format proto.
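+//
+// For example (illustrative, mirroring the unit tests), a node "n" calling
+// op "SameIn" with attrs N=2 and T=DT_DOUBLE on inputs "a" and "b"
+// summarizes as:
+//   n = SameIn[N=2, T=DT_DOUBLE](a, b)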
+string SummarizeNodeDef(const NodeDef& node_def);
+
+typedef protobuf::Map<string, AttrValue> AttrValueMap;
+
+// Adds an attr with name <name> and value <value> to *node_def.
+// The type of the attr is based on the type of value.
+template <class T>
+void AddNodeAttr(const string& name, T&& value, NodeDef* node_def) {
+ AttrValue attr_value;
+ SetAttrValue(std::forward<T>(value), &attr_value);
+ node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+}
+
+// Version to work around C++'s "perfect" forwarding not being able to
+// forward {...} initialization.
+template <class T>
+void AddNodeAttr(const string& name, std::initializer_list<T> value,
+ NodeDef* node_def) {
+ AttrValue attr_value;
+ SetAttrValue(value, &attr_value);
+ node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+}
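+
+// Example usage (a minimal sketch; assumes an op declaring an int attr "N"
+// and a list(type) attr "T"):
+//   NodeDef node_def;
+//   AddNodeAttr("N", 3, &node_def);
+//   AddNodeAttr("T", {DT_INT32, DT_BOOL}, &node_def);  // {...} overload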
+
+class AttrSlice {
+ public:
+ AttrSlice(const NodeDef& node_def) // NOLINT(runtime/explicit)
+ : ndef_(&node_def),
+ attrs_(&ndef_->attr()) {}
+
+ explicit AttrSlice(const AttrValueMap* a) : attrs_(a) {}
+
+ // Returns the attr with attr_name if found. Otherwise, returns
+ // nullptr.
+ const AttrValue* Find(const string& attr_name) const;
+
+ // Returns the attr_value for attr_name if found. Otherwise, returns a
+ // NotFound status.
+ Status Find(const string& attr_name, const AttrValue** attr_value) const;
+
+ private:
+ const NodeDef* ndef_ = nullptr;
+ const AttrValueMap* attrs_;
+};
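+
+// Example (illustrative):
+//   AttrSlice attrs(node_def);
+//   const AttrValue* value = attrs.Find("T");  // nullptr if "T" is absent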
+
+// Look up the attr with name attr_name and set *value to its value. If no
+// attr with attr_name is found in node_def, or the attr does not have
+// a matching type, a non-ok status will be returned.
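+//
+// For example (a minimal sketch; assumes the node carries an int attr "N"
+// and a type attr "T"):
+//   int64 n;
+//   DataType dtype;
+//   TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "N", &n));
+//   TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "T", &dtype));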
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ string* value); // type: "string"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ int64* value); // type: "int"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ int32* value); // type: "int"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ float* value); // type: "float"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ bool* value); // type: "bool"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ DataType* value); // type: "type"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ TensorShapeProto* value); // type: "shape"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ TensorShape* value); // type: "shape"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ Tensor* value); // type: "tensor"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<string>* value); // type "list(string)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<int64>* value); // type "list(int)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<int32>* value); // type "list(int)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<float>* value); // type "list(float)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<bool>* value); // type "list(bool)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<DataType>* value); // type "list(type)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ DataTypeVector* value); // type "list(type)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<TensorShapeProto>* value); // type "list(shape)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<TensorShape>* value); // type "list(shape)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<Tensor>* value); // type: "list(tensor)"
+
+// This version avoids copying the TensorProto.
+// REQUIRES: Must not use *value beyond the lifetime of the underlying
+// NodeDef.
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const TensorProto** value); // type: "tensor"
+
+// This version avoids copying the NameAttrList.
+// REQUIRES: Must not use *value beyond the lifetime of the underlying
+// NodeDef.
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const NameAttrList** value); // type: "func"
+
+// Computes the input and output types for a specific node, for
+// attr-style ops.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs, DataTypeVector* outputs);
+
+// Validates that the NodeDef:
+// * Defines all expected attrs from the OpDef.
+// * All attrs satisfy their constraints from the OpDef.
+// * Has a signature matching InOutTypesForNode().
+// etc.
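+//
+// For example (illustrative, mirroring the unit tests): a NodeDef that
+// omits a required attr fails with "NodeDef missing attr ... from Op<...>",
+// and one that mentions an attr absent from the OpDef fails with
+// "NodeDef mentions attr ... not in Op<...>".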
+Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
+
+// Computes the mapping from input/output argument name to the
+// corresponding input/output index range. For example,
+// input "foo" coresponds to input indices
+// [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
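+//
+// For example (illustrative), given an op with inputs "foo: N*float" and
+// "bar: int32" and a node with attr N=2:
+//   (*inputs)["foo"] == std::make_pair(0, 2)
+//   (*inputs)["bar"] == std::make_pair(2, 3)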
+typedef std::unordered_map<string, std::pair<int, int>> NameRangeMap;
+Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
+ NameRangeMap* inputs, NameRangeMap* outputs);
+
+// Adds default values to *node_def for unspecified attrs from op_def.
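+//
+// For example (illustrative), given an op attr "a: string = 'banana'", a
+// NodeDef that does not set "a" gains attr a="banana"; attrs already set
+// are left untouched.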
+void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
+
+// Validates the syntax of a NodeDef provided externally.
+//
+// The following is an EBNF-style syntax for NodeDef objects. Note that
+// Node objects are actually specified as tensorflow::NodeDef protocol buffers,
+// which contain many other fields that are not (currently) validated.
+//
+// Node = NodeName, Inputs
+// Inputs = ( DataInput * ), ( ControlInput * )
+// DataInput = NodeName, ( ":", ( "0" | [1-9], [0-9] * ) ) ?
+// ControlInput = "^", NodeName
+// NodeName = [A-Za-z0-9.], [A-Za-z0-9_.\-/] *
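+//
+// For example (illustrative), data inputs "x", "x:0" and "x:3" and control
+// input "^ctrl" are accepted, while "^ctrl:0" is rejected, as is any data
+// input appearing after a control input.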
+Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
+
+// Returns "status" with kernel's NodeDef attached as additional text
+// in the error message.
+Status AttachDef(const Status& status, const NodeDef& node_def);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
new file mode 100644
index 0000000000..71f1760a09
--- /dev/null
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -0,0 +1,442 @@
+#include "tensorflow/core/framework/node_def_util.h"
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+OpDef ToOpDef(const OpDefBuilder& builder) {
+ OpDef op_def;
+ EXPECT_OK(builder.Finalize(&op_def));
+ return op_def;
+}
+
+NodeDef ToNodeDef(const string& text) {
+ NodeDef node_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
+ return node_def;
+}
+
+NodeDef ToNodeDef(const NodeDefBuilder& builder) {
+ NodeDef node_def;
+ EXPECT_OK(builder.Finalize(&node_def));
+ return node_def;
+}
+
+void ExpectSuccess(const NodeDef& good, const OpDef& op_def) {
+ EXPECT_EQ(Status::OK(), ValidateNodeDef(good, op_def))
+ << "NodeDef: " << SummarizeNodeDef(good)
+ << "; OpDef: " << SummarizeOpDef(op_def);
+}
+
+void ExpectFailure(const NodeDef& bad, const OpDef& op_def,
+ const string& message) {
+ Status status = ValidateNodeDef(bad, op_def);
+
+ EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad)
+ << "; OpDef: " << SummarizeOpDef(op_def);
+ if (status.ok()) return;
+
+ EXPECT_TRUE(errors::IsInvalidArgument(status))
+ << status << "; NodeDef: " << SummarizeNodeDef(bad)
+ << "; OpDef: " << SummarizeOpDef(op_def);
+
+ LOG(INFO) << "Message: " << status.error_message();
+ EXPECT_TRUE(StringPiece(status.ToString()).contains(message))
+ << "NodeDef: " << SummarizeNodeDef(bad)
+ << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
+ << "\nDoes not contain: " << message;
+}
+
+TEST(NodeDefUtilTest, In) {
+ const OpDef op = ToOpDef(OpDefBuilder("In").Input("i: T").Attr("T: type"));
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'In' input:'a' attr { key:'T' value { type:DT_FLOAT } }
+ )proto");
+ ExpectSuccess(node_def, op);
+
+ EXPECT_EQ("n = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def));
+
+ // Mismatching Op names.
+ NodeDef bad = node_def;
+ bad.set_op("Wrong");
+ ExpectFailure(bad, op, "NodeDef op 'Wrong' does not match Op<name=In;");
+
+ // Missing attr
+ bad = node_def;
+ bad.clear_attr();
+ ExpectFailure(bad, op, "NodeDef missing attr 'T' from Op<name=In;");
+
+ // Extra attr
+ bad = node_def;
+ AddNodeAttr("EXTRA", 17, &bad);
+ ExpectFailure(bad, op, "NodeDef mentions attr 'EXTRA' not in Op<name=In;");
+
+ // Attr has wrong type
+ bad = node_def;
+ bad.clear_attr();
+ AddNodeAttr("T", 17, &bad);
+ ExpectFailure(
+ bad, op,
+ "AttrValue had value with type int when type expected\n\t for attr "
+ "'T'\n\t; NodeDef: ");
+
+ // Wrong number of inputs
+ bad = node_def;
+ bad.add_input("b");
+ ExpectFailure(
+ bad, op,
+ "NodeDef expected inputs 'float' do not match 2 inputs specified;");
+
+ bad = node_def;
+ bad.clear_input();
+ ExpectFailure(
+ bad, op,
+ "NodeDef expected inputs 'float' do not match 0 inputs specified;");
+
+ // Control inputs must appear after data inputs
+ NodeDef good = node_def;
+ good.add_input("^b");
+ ExpectSuccess(good, op);
+
+ bad = node_def;
+ bad.clear_input();
+ bad.add_input("^b");
+ bad.add_input("a");
+ ExpectFailure(bad, op,
+ "Invalid argument: Non-control input 'a' after control input "
+ "in NodeDef:");
+
+ bad = node_def;
+ bad.add_input("^b:0");
+ ExpectFailure(bad, op, "Control input '^b:0' must not have ':' in NodeDef:");
+}
+
+TEST(NodeDefUtilTest, Out) {
+ const OpDef op =
+ ToOpDef(OpDefBuilder("Out").Output("o: T").Attr("T: numbertype"));
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'Out' attr { key:'T' value { type:DT_INT32 } }
+ )proto");
+ ExpectSuccess(node_def, op);
+
+ EXPECT_EQ("n = Out[T=DT_INT32]()", SummarizeNodeDef(node_def));
+
+ // Non-number type.
+ NodeDef bad = node_def;
+ bad.clear_attr();
+ AddNodeAttr("T", DT_STRING, &bad);
+ ExpectFailure(bad, op,
+ "Value for attr 'T' of string is not in the list of allowed "
+ "values: float, double, int64, int32, uint8, int16, int8, "
+ "complex64, qint8, quint8, qint32");
+}
+
+TEST(NodeDefUtilTest, Enum) {
+ const OpDef op = ToOpDef(OpDefBuilder("Enum").Attr("e: {'apple','orange'}"));
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'Enum' attr { key:'e' value { s:'apple' } }
+ )proto");
+ ExpectSuccess(node_def, op);
+
+ EXPECT_EQ("n = Enum[e=\"apple\"]()", SummarizeNodeDef(node_def));
+
+ NodeDef good = node_def;
+ good.clear_attr();
+ AddNodeAttr("e", "orange", &good);
+ ExpectSuccess(good, op);
+
+ // Non-allowed value.
+ NodeDef bad = node_def;
+ bad.clear_attr();
+ AddNodeAttr("e", "foo", &bad);
+ ExpectFailure(bad, op,
+ "Value for attr 'e' of \"foo\" is not in the list of allowed "
+ "values: \"apple\", \"orange\"");
+}
+
+TEST(NodeDefUtilTest, SameIn) {
+ const OpDef op = ToOpDef(OpDefBuilder("SameIn")
+ .Input("i: N * T")
+ .Attr("N: int >= 2")
+ .Attr("T: {float,double}"));
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'SameIn' input:'a' input:'b'
+ attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_DOUBLE } }
+ )proto");
+ ExpectSuccess(node_def, op);
+
+ EXPECT_EQ("n = SameIn[N=2, T=DT_DOUBLE](a, b)", SummarizeNodeDef(node_def));
+
+ // Illegal type
+ NodeDef bad = ToNodeDef(R"proto(
+ name:'n' op:'SameIn' input:'a' input:'b'
+ attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_STRING } }
+ )proto");
+ ExpectFailure(bad, op,
+ "Value for attr 'T' of string is not in the list of allowed "
+ "values: float, double");
+
+ // Too few inputs
+ bad = ToNodeDef(R"proto(
+ name:'n' op:'SameIn' input:'a' input:'b'
+ attr { key:'N' value { i:1 } } attr { key:'T' value { type:DT_FLOAT } }
+ )proto");
+ ExpectFailure(bad, op, "Value for attr 'N' of 1 must be at least minimum 2");
+}
+
+TEST(NodeDefUtilTest, AnyIn) {
+ const OpDef op =
+ ToOpDef(OpDefBuilder("AnyIn").Input("i: T").Attr("T: list(type) >= 1"));
+
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectSuccess(node_def, op);
+
+ EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a, b)",
+ SummarizeNodeDef(node_def));
+
+ const NodeDef bad = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' attr { key:'T' value { list { } } }
+ )proto");
+ ExpectFailure(bad, op, "Length for attr 'T' of 0 must be at least minimum 1");
+
+  // With proto3 semantics, an empty value {} is indistinguishable from a value
+  // containing an empty list. So we simply expect the same message complaining
+  // about an empty list for value {}.
+ const NodeDef bad2 = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' attr { key:'T' value { } }
+ )proto");
+ ExpectFailure(bad2, op,
+ "Length for attr 'T' of 0 must be at least minimum 1");
+}
+
+TEST(NodeDefUtilTest, Device) {
+ const OpDef op_def1 = ToOpDef(OpDefBuilder("None"));
+ const NodeDef node_def1 =
+ ToNodeDef(NodeDefBuilder("d", &op_def1).Device("/cpu:17"));
+ ExpectSuccess(node_def1, op_def1);
+ EXPECT_EQ("d = None[_device=\"/cpu:17\"]()", SummarizeNodeDef(node_def1));
+
+ const OpDef op_def2 = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int"));
+ const NodeDef node_def2 =
+ ToNodeDef(NodeDefBuilder("d", &op_def2).Attr("v", 7).Device("/cpu:5"));
+ ExpectSuccess(node_def2, op_def2);
+ EXPECT_EQ("d = WithAttr[v=7, _device=\"/cpu:5\"]()",
+ SummarizeNodeDef(node_def2));
+}
+
+void ExpectValidSyntax(const NodeDef& good) {
+ EXPECT_EQ(Status::OK(), ValidateExternalNodeDefSyntax(good))
+ << "NodeDef: " << SummarizeNodeDef(good);
+}
+
+void ExpectInvalidSyntax(const NodeDef& bad, const string& message) {
+ Status status = ValidateExternalNodeDefSyntax(bad);
+
+ ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad);
+
+ EXPECT_TRUE(errors::IsInvalidArgument(status))
+ << status << "; NodeDef: " << SummarizeNodeDef(bad);
+
+ EXPECT_TRUE(StringPiece(status.ToString()).contains(message))
+ << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", "
+ << message;
+}
+
+TEST(NodeDefUtilTest, ValidSyntax) {
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectValidSyntax(node_def);
+
+ const NodeDef node_def_explicit_inputs = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a:0' input:'b:123'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectValidSyntax(node_def_explicit_inputs);
+
+ EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
+ SummarizeNodeDef(node_def_explicit_inputs));
+
+ const NodeDef node_def_control_input = ToNodeDef(R"proto(
+ name:'n-' op:'AnyIn' input:'a' input:'^b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectValidSyntax(node_def_control_input);
+
+ const NodeDef node_def_invalid_name = ToNodeDef(R"proto(
+ name:'n:0' op:'AnyIn' input:'a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_invalid_name, "Illegal op name 'n:0'");
+
+ const NodeDef node_def_internal_name = ToNodeDef(R"proto(
+ name:'_n' op:'AnyIn' input:'a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'");
+
+ const NodeDef node_def_internal_input_name = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'_a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_internal_input_name,
+ "Illegal op input name '_a'");
+
+ const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' input:'^b:0'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_invalid_control_input_name,
+ "Illegal op input name '^b:0'");
+
+ const NodeDef node_def_data_input_after_control = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'^a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_data_input_after_control,
+ "All control inputs must follow all data inputs");
+}
+
+TEST(NameRangesForNodeTest, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ NameRangeMap inputs, outputs;
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), outputs);
+
+ EXPECT_EQ("simple = Simple[](a, b)", SummarizeNodeDef(node_def));
+
+ OpDef bad_op_def = op_def;
+ bad_op_def.mutable_input_arg(0)->clear_type();
+ EXPECT_FALSE(NameRangesForNode(node_def, bad_op_def, &inputs, &outputs).ok());
+}
+
+TEST(NameRangesForNodeTest, Polymorphic) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Polymorphic")
+ .Input("a: T")
+ .Input("b: T")
+ .Output("c: T")
+ .Attr("T: type"));
+ NameRangeMap inputs, outputs;
+ const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("poly", &op_def)
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_INT32)));
+ EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
+ EXPECT_EQ("poly = Polymorphic[T=DT_INT32](a, b)",
+ SummarizeNodeDef(node_def1));
+
+ const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("poly", &op_def)
+ .Input(FakeInput(DT_BOOL))
+ .Input(FakeInput(DT_BOOL)));
+ EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
+ EXPECT_EQ("poly = Polymorphic[T=DT_BOOL](a, b)", SummarizeNodeDef(node_def2));
+}
+
+TEST(NameRangesForNodeTest, NRepeats) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("NRepeats")
+ .Input("a: N * int32")
+ .Input("b: N * T")
+ .Output("c: T")
+ .Output("d: N * string")
+ .Output("e: M * bool")
+ .Attr("N: int")
+ .Attr("M: int")
+ .Attr("T: type"));
+ NameRangeMap inputs, outputs;
+ const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("nr", &op_def)
+ .Input(FakeInput(4, DT_INT32))
+ .Input(FakeInput(4, DT_FLOAT))
+ .Attr("M", 3));
+ EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}),
+ outputs);
+ EXPECT_EQ(
+ "nr = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, b:2, b:3)",
+ SummarizeNodeDef(node_def1));
+
+ const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("nr", &op_def)
+ .Input(FakeInput(2, DT_INT32))
+ .Input(FakeInput(2, DT_DOUBLE))
+ .Attr("M", 7));
+ EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
+ outputs);
+ EXPECT_EQ("nr = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)",
+ SummarizeNodeDef(node_def2));
+
+ NodeDef bad_node_def = node_def2;
+ bad_node_def.clear_attr();
+ EXPECT_FALSE(NameRangesForNode(bad_node_def, op_def, &inputs, &outputs).ok());
+}
+
+TEST(NameRangesForNodeTest, TypeList) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("TypeList")
+ .Input("a: T1")
+ .Input("b: T2")
+ .Output("c: T2")
+ .Output("d: T3")
+ .Output("e: T1")
+ .Attr("T1: list(type)")
+ .Attr("T2: list(type)")
+ .Attr("T3: list(type)"));
+ NameRangeMap inputs, outputs;
+ const NodeDef node_def1 =
+ ToNodeDef(NodeDefBuilder("tl", &op_def)
+ .Input(FakeInput({DT_BOOL, DT_FLOAT}))
+ .Input(FakeInput(4, DT_FLOAT))
+ .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING}));
+ EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}),
+ outputs);
+ EXPECT_EQ(
+ "tl = TypeList[T1=[DT_BOOL, DT_FLOAT],"
+ " T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT],"
+ " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)",
+ SummarizeNodeDef(node_def1));
+
+ const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("tl", &op_def)
+ .Input(FakeInput(7, DT_INT32))
+ .Input(FakeInput({DT_DOUBLE}))
+ .Attr("T3", {DT_DOUBLE, DT_STRING}));
+ EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
+ outputs);
+ EXPECT_EQ(
+ "tl = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32,"
+ " DT_INT32, DT_INT32], T2=[DT_DOUBLE], T3=[DT_DOUBLE, DT_STRING]]"
+ "(a, a:1, a:2, a:3, a:4, a:5, a:6, b)",
+ SummarizeNodeDef(node_def2));
+
+ NodeDef bad_node_def = node_def2;
+ bad_node_def.clear_attr();
+ EXPECT_FALSE(NameRangesForNode(bad_node_def, op_def, &inputs, &outputs).ok());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/numeric_op.h b/tensorflow/core/framework/numeric_op.h
new file mode 100644
index 0000000000..8413d18f33
--- /dev/null
+++ b/tensorflow/core/framework/numeric_op.h
@@ -0,0 +1,96 @@
+#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
+#define TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// One input and one output, both the same type.
+template <class T>
+class UnaryOp : public OpKernel {
+ public:
+ explicit UnaryOp(OpKernelConstruction* context) : OpKernel(context) {
+ const DataType dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt}));
+ }
+};
+
+// Two inputs and one output, all the same type.
+template <class T>
+class BinaryOp : public OpKernel {
+ public:
+ explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) {
+ const DataType dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt}));
+ }
+};
+
+// For operations where the input and output are the same shape.
+//
+// For usage, see ../framework/elementwise_ops.cc.
+template <class T, class CHILD>
+class UnaryElementWiseOp : public UnaryOp<T> {
+ public:
+ using UnaryOp<T>::UnaryOp;
+
+ void Compute(OpKernelContext* context) override {
+ // Output shape is the same as input shape.
+ const Tensor& input = context->input(0);
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ static_cast<CHILD*>(this)->Operate(context, input, output);
+ }
+};
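+
+// A minimal sketch (hypothetical kernel, shown for illustration only; real
+// usages live in the file referenced above) of a CRTP subclass. Operate() is
+// the hook that Compute() dispatches to:
+//
+//   class NegOp : public UnaryElementWiseOp<float, NegOp> {
+//    public:
+//     using UnaryElementWiseOp<float, NegOp>::UnaryElementWiseOp;
+//     void Operate(OpKernelContext* context, const Tensor& input,
+//                  Tensor* output) {
+//       // Elementwise negation; input and output have the same shape.
+//       output->flat<float>() = -input.flat<float>();
+//     }
+//   };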
+
+// For binary elementwise operations.
+template <class T, class CHILD>
+class BinaryElementWiseOp : public BinaryOp<T> {
+ public:
+ using BinaryOp<T>::BinaryOp;
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& a = context->input(0);
+ const Tensor& b = context->input(1);
+
+ if (!context->ValidateInputsAreSameShape(this)) {
+ return;
+ }
+
+ Tensor* output;
+ OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(), &output));
+
+ // Dispatch to the descendant's Operate() function.
+ switch (a.dims()) {
+#define NDIM_CASE(NDIMS) \
+ case NDIMS: { \
+ static_cast<CHILD*>(this)->template Operate<NDIMS>(context, a, b, output); \
+ break; \
+ }
+
+ NDIM_CASE(1);
+ NDIM_CASE(2);
+ NDIM_CASE(3);
+ NDIM_CASE(4);
+ NDIM_CASE(5);
+ NDIM_CASE(6);
+ NDIM_CASE(7);
+ NDIM_CASE(8);
+#undef NDIM_CASE
+
+ default:
+ context->SetStatus(errors::OutOfRange(
+ "We only handle up to Tensor::dims() up to 8, not ", a.dims()));
+ break;
+ }
+ }
+};
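+
+// A sketch (hypothetical kernel, illustration only) of a subclass. Note that
+// Operate() here is a template, parameterized on the rank NDIMS selected by
+// the dispatch switch above:
+//
+//   class AddOp : public BinaryElementWiseOp<float, AddOp> {
+//    public:
+//     using BinaryElementWiseOp<float, AddOp>::BinaryElementWiseOp;
+//     template <int NDIMS>
+//     void Operate(OpKernelContext* context, const Tensor& a, const Tensor& b,
+//                  Tensor* output) {
+//       // Elementwise sum; a, b, and output all have the same shape.
+//       output->flat<float>() = a.flat<float>() + b.flat<float>();
+//     }
+//   };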
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
new file mode 100644
index 0000000000..366f00ae03
--- /dev/null
+++ b/tensorflow/core/framework/numeric_types.h
@@ -0,0 +1,15 @@
+#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
+#define TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
+
+#include <complex>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Single precision complex.
+typedef std::complex<float> complex64;
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc
new file mode 100644
index 0000000000..15b7eab4da
--- /dev/null
+++ b/tensorflow/core/framework/op.cc
@@ -0,0 +1,135 @@
+#include "tensorflow/core/framework/op.h"
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// OpRegistry -----------------------------------------------------------------
+
+OpRegistryInterface::~OpRegistryInterface() {}
+
+OpRegistry::OpRegistry() : initialized_(false) {}
+
+void OpRegistry::Register(std::function<OpDef(void)> func) {
+ mutex_lock lock(mu_);
+ if (initialized_) {
+ OpDef def = func();
+ TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: "
+ << SummarizeOpDef(def);
+ } else {
+ deferred_.push_back(func);
+ }
+}
+
+const OpDef* OpRegistry::LookUp(const string& op_type_name,
+ Status* status) const {
+ const OpDef* op_def = nullptr;
+ bool first_call = false;
+ { // Scope for lock.
+ mutex_lock lock(mu_);
+ first_call = CallDeferred();
+ op_def = gtl::FindWithDefault(registry_, op_type_name, nullptr);
+ // Note: Can't hold mu_ while calling Export() below.
+ }
+ if (first_call) {
+ TF_QCHECK_OK(ValidateKernelRegistrations(this));
+ }
+ if (op_def == nullptr) {
+ status->Update(
+ errors::NotFound("Op type not registered '", op_type_name, "'"));
+ static bool first = true;
+ if (first) {
+ OpList op_list;
+ Export(true, &op_list);
+ LOG(INFO) << "All registered Ops:";
+ for (const auto& op : op_list.op()) {
+ LOG(INFO) << SummarizeOpDef(op);
+ }
+ first = false;
+ }
+ }
+ return op_def;
+}
+
+void OpRegistry::Export(bool include_internal, OpList* ops) const {
+ mutex_lock lock(mu_);
+ CallDeferred();
+
+ std::vector<std::pair<string, const OpDef*>> sorted(registry_.begin(),
+ registry_.end());
+ std::sort(sorted.begin(), sorted.end());
+
+ auto out = ops->mutable_op();
+ out->Clear();
+ out->Reserve(sorted.size());
+
+ for (const auto& item : sorted) {
+ if (include_internal || !StringPiece(item.first).starts_with("_")) {
+ *out->Add() = *item.second;
+ }
+ }
+}
+
+string OpRegistry::DebugString(bool include_internal) const {
+ OpList op_list;
+ Export(include_internal, &op_list);
+ string ret;
+ for (const auto& op : op_list.op()) {
+ strings::StrAppend(&ret, SummarizeOpDef(op), "\n");
+ }
+ return ret;
+}
+
+bool OpRegistry::CallDeferred() const {
+ if (initialized_) return false;
+ initialized_ = true;
+ for (const auto& fn : deferred_) {
+ OpDef def = fn();
+ TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: "
+ << SummarizeOpDef(def);
+ }
+ deferred_.clear();
+ return true;
+}
+
+Status OpRegistry::RegisterAlreadyLocked(const OpDef& def) const {
+ TF_RETURN_IF_ERROR(ValidateOpDef(def));
+
+ std::unique_ptr<OpDef> copy(new OpDef(def));
+ if (gtl::InsertIfNotPresent(&registry_, def.name(), copy.get())) {
+ copy.release(); // Ownership transferred to op_registry
+ return Status::OK();
+ } else {
+ return errors::AlreadyExists("Op with name ", def.name());
+ }
+}
+
+// static
+OpRegistry* OpRegistry::Global() {
+ static OpRegistry* global_op_registry = new OpRegistry;
+ return global_op_registry;
+}
+
+namespace register_op {
+OpDefBuilder& RegisterOp(StringPiece name) {
+ VLOG(1) << "RegisterOp: " << name;
+ OpDefBuilder* b = new OpDefBuilder(name);
+ OpRegistry::Global()->Register([b]() -> ::tensorflow::OpDef {
+ OpDef op_def;
+ TF_QCHECK_OK(b->Finalize(&op_def));
+ delete b;
+ return op_def;
+ });
+ return *b;
+}
+} // namespace register_op
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
new file mode 100644
index 0000000000..95ad32df35
--- /dev/null
+++ b/tensorflow/core/framework/op.h
@@ -0,0 +1,122 @@
+#ifndef TENSORFLOW_FRAMEWORK_OP_H_
+#define TENSORFLOW_FRAMEWORK_OP_H_
+
+#include <functional>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Users that want to look up an OpDef by type name should take an
+// OpRegistryInterface. Functions accepting a
+// (const) OpRegistryInterface* may call LookUp() from multiple threads.
+class OpRegistryInterface {
+ public:
+ virtual ~OpRegistryInterface();
+
+ // Returns nullptr and sets *status if no OpDef is registered under that
+ // name, otherwise returns the registered OpDef.
+ // Caller must not delete the returned pointer.
+ virtual const OpDef* LookUp(const string& op_type_name,
+ Status* status) const = 0;
+};
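+
+// A typical lookup (an illustrative sketch; 'registry' stands for any
+// OpRegistryInterface* and "ExampleOp" is a placeholder name):
+//
+//   Status status;
+//   const OpDef* op_def = registry->LookUp("ExampleOp", &status);
+//   if (op_def == nullptr) { /* status holds the NotFound error */ }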
+
+// The standard implementation of OpRegistryInterface, along with a
+// global singleton used for registering OpDefs via the REGISTER
+// macros below. Thread-safe.
+//
+// Example registration:
+// OpRegistry::Global()->Register([]()->OpDef{
+// OpDef def;
+// // Populate def here.
+// return def;
+// });
+class OpRegistry : public OpRegistryInterface {
+ public:
+ OpRegistry();
+ ~OpRegistry() override {}
+
+ // Calls func() and registers the returned OpDef. Since Register()
+ // is normally called during program initialization (before main()),
+ // we defer calling func() until the first call to LookUp() or
+ // Export() (if one of those has already been called, func() is
+ // called immediately).
+ void Register(std::function<OpDef(void)> func);
+
+ const OpDef* LookUp(const string& op_type_name,
+ Status* status) const override;
+
+  // Fills *ops with all registered OpDefs (except those with names
+ // starting with '_' if include_internal == false).
+ void Export(bool include_internal, OpList* ops) const;
+
+ // Returns ASCII-format OpList for all registered OpDefs (except
+ // those with names starting with '_' if include_internal == false).
+ string DebugString(bool include_internal) const;
+
+ // A singleton available at startup.
+ static OpRegistry* Global();
+
+ private:
+  // Ensures that all the functions in deferred_ get called and their OpDefs
+  // registered, leaving deferred_ empty. Returns true the first time it is
+  // called.
+ bool CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Add 'def' to the registry. On failure, or if there is already an
+ // OpDef with that name registered, returns a non-okay status.
+ Status RegisterAlreadyLocked(const OpDef& def) const
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ mutable mutex mu_;
+ // Functions in deferred_ may only be called with mu_ held.
+ mutable std::vector<std::function<OpDef(void)>> deferred_ GUARDED_BY(mu_);
+ mutable std::unordered_map<string, OpDef*> registry_ GUARDED_BY(mu_);
+ mutable bool initialized_ GUARDED_BY(mu_);
+};
+
+// Support for defining the OpDef (specifying the semantics of the Op and how
+// it should be created) and registering it in the OpRegistry::Global()
+// registry. Usage:
+//
+// REGISTER_OP("my_op_name")
+// .Attr("<name>:<type>")
+// .Attr("<name>:<type>=<default>")
+// .Input("<name>:<type-expr>")
+// .Input("<name>:Ref(<type-expr>)")
+// .Output("<name>:<type-expr>")
+// .Doc(R"(
+// <1-line summary>
+// <rest of the description (potentially many lines)>
+// <name-of-attr-input-or-output>: <description of name>
+// <name-of-attr-input-or-output>: <description of name;
+// if long, indent the description on subsequent lines>
+// )");
+//
+// Note: .Doc() should be last.
+// For details, see the OpDefBuilder class in op_def_builder.h.
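+//
+// For example (a hypothetical op, shown for illustration only):
+//
+// REGISTER_OP("PassThrough")
+//     .Input("x: T")
+//     .Output("y: T")
+//     .Attr("T: type")
+//     .Doc(R"(
+// Returns its input tensor unchanged.
+// x: the tensor to forward.
+// y: a tensor identical to x.
+// )");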
+
+namespace register_op {
+// To call OpRegistry::Global()->Register(...), used by the
+// REGISTER_OP macro below.
+OpDefBuilder& RegisterOp(StringPiece name);
+} // namespace register_op
+
+#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
+#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
+#define REGISTER_OP_UNIQ(ctr, name) \
+ static ::tensorflow::OpDefBuilder& register_op##ctr TF_ATTRIBUTE_UNUSED = \
+ ::tensorflow::register_op::RegisterOp(name)
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_H_
diff --git a/tensorflow/core/framework/op_def.proto b/tensorflow/core/framework/op_def.proto
new file mode 100644
index 0000000000..4a2e90b1b9
--- /dev/null
+++ b/tensorflow/core/framework/op_def.proto
@@ -0,0 +1,142 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/attr_value.proto";
+import "tensorflow/core/framework/types.proto";
+
+// Defines an operation. A NodeDef in a GraphDef specifies an Op by
+// using the "op" field which should match the name of a OpDef.
+message OpDef {
+ // Op names starting with an underscore are reserved for internal use.
+ // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
+ string name = 1;
+
+ // For describing inputs and outputs.
+ message ArgDef {
+ // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
+ string name = 1;
+
+ // Human readable description.
+ string description = 2;
+
+ // Describes the type of one or more tensors that are accepted/produced
+ // by this input/output arg. The only legal combinations are:
+ // * For a single tensor: either the "type" field is set or the
+ // "type_attr" field is set to the name of an attr with type "type".
+ // * For a sequence of tensors with the same type: the "number_attr"
+ // field will be set to the name of an attr with type "int", and
+ // either the "type" or "type_attr" field will be set as for
+ // single tensors.
+ // * For a sequence of tensors, the "type_list_attr" field will be set
+ // to the name of an attr with type "list(type)".
+ DataType type = 3;
+ string type_attr = 4; // if specified, attr must have type "type"
+ string number_attr = 5; // if specified, attr must have type "int"
+ // If specified, attr must have type "list(type)", and none of
+ // type, type_attr, and number_attr may be specified.
+ string type_list_attr = 6;
+
+ // For inputs: if true, the inputs are required to be refs.
+ // By default, inputs can be either refs or non-refs.
+ // For outputs: if true, outputs are refs, otherwise they are not.
+ bool is_ref = 16;
+ };
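+
+  // For example (illustrative only; the names are placeholders), an input
+  // declared as "values: N * T" in the OpDefBuilder syntax is encoded as:
+  //   input_arg { name: "values" type_attr: "T" number_attr: "N" }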
+
+ // Description of the input(s).
+ repeated ArgDef input_arg = 2;
+
+ // Description of the output(s).
+ repeated ArgDef output_arg = 3;
+
+ // Description of the graph-construction-time configuration of this
+ // Op. That is to say, this describes the attr fields that will
+ // be specified in the NodeDef.
+ message AttrDef {
+ // A descriptive name for the argument. May be used, e.g. by the
+ // Python client, as a keyword argument name, and so should match
+ // the regexp "[a-z][a-z0-9_]+".
+ string name = 1;
+
+ // One of the type names from attr_value.proto ("string", "list(string)",
+ // "int", etc.).
+ string type = 2;
+
+ // A reasonable default for this attribute if the user does not supply
+ // a value. If not specified, the user must supply a value.
+ AttrValue default_value = 3;
+
+ // Human-readable description.
+ string description = 4;
+
+ // TODO(josh11b): bool is_optional?
+
+ // --- Constraints ---
+ // These constraints are only in effect if specified. Default is no
+ // constraints.
+
+ // For type == "int", this is a minimum value. For "list(___)"
+ // types, this is the minimum length.
+ bool has_minimum = 5;
+ int64 minimum = 6;
+
+ // The set of allowed values. Has type that is the "list" version
+ // of the "type" field above (uses the ..._list, fields of AttrValue).
+ // If type == "type" or "list(type)" above, then the type_list field
+ // of allowed_values has the set of allowed DataTypes.
+ // If type == "string" or "list(string)", then the s_list field has
+ // the set of allowed strings.
+ AttrValue allowed_values = 7;
+ }
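+
+  // For example (illustrative only; the names are placeholders), an attr
+  // declared as "N: int >= 2" in the OpDefBuilder syntax is encoded as:
+  //   attr { name: "N" type: "int" has_minimum: true minimum: 2 }
+  // and "T: {float, int32}" as:
+  //   attr { name: "T" type: "type"
+  //          allowed_values { list { type: [DT_FLOAT, DT_INT32] } } }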
+ repeated AttrDef attr = 4;
+
+ // One-line human-readable description of what the Op does.
+ string summary = 5;
+
+ // Additional, longer human-readable description of what the Op does.
+ string description = 6;
+
+ // -------------------------------------------------------------------------
+ // Which optimizations this operation can participate in.
+
+ // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs)
+ bool is_commutative = 18;
+
+ // If is_aggregate is true, then this operation accepts N >= 2
+ // inputs and produces 1 output all of the same type. Should be
+ // associative and commutative, and produce output with the same
+ // shape as the input. The optimizer may replace an aggregate op
+ // taking input from multiple devices with a tree of aggregate ops
+ // that aggregate locally within each device (and possibly within
+ // groups of nearby devices) before communicating.
+ // TODO(josh11b): Implement that optimization.
+ bool is_aggregate = 16; // for things like add
+
+ // Other optimizations go here, like
+ // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc.
+
+ // -------------------------------------------------------------------------
+ // Optimization constraints.
+
+ // By default Ops may be moved between devices. Stateful ops should
+ // either not be moved, or should only be moved if that state can also
+ // be moved (e.g. via some sort of save / restore).
+ // Stateful ops are guaranteed to never be optimized away by Common
+ // Subexpression Elimination (CSE).
+ bool is_stateful = 17; // for things like variables, queue
+
+ // -------------------------------------------------------------------------
+ // Non-standard options.
+
+ // By default, all inputs to an Op must be initialized Tensors. Ops
+ // that may initialize tensors for the first time should set this
+ // field to true, to allow the Op to take an uninitialized Tensor as
+ // input.
+ bool allows_uninitialized_input = 19; // for Assign, etc.
+};
+
+// A collection of OpDefs
+message OpList {
+ repeated OpDef op = 1;
+};
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
new file mode 100644
index 0000000000..7d7c07de4c
--- /dev/null
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -0,0 +1,447 @@
+#include "tensorflow/core/framework/op_def_builder.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+namespace {
+
+bool RE2Consume(StringPiece* sp, const char* pattern) {
+ RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
+ bool r = RE2::Consume(&base_sp, pattern);
+ *sp = FromRegexpStringPiece(base_sp);
+ return r;
+}
+
+bool RE2Consume(StringPiece* sp, const char* pattern, StringPiece* out) {
+ RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
+ RegexpStringPiece base_out;
+ bool r = RE2::Consume(&base_sp, pattern, &base_out);
+ *sp = FromRegexpStringPiece(base_sp);
+ *out = FromRegexpStringPiece(base_out);
+ return r;
+}
+
+bool RE2Consume(StringPiece* sp, const char* pattern, int64* out) {
+ RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
+ bool r = RE2::Consume(&base_sp, pattern, out);
+ *sp = FromRegexpStringPiece(base_sp);
+ return r;
+}
+
+string AttrError(StringPiece orig, const string& op_name) {
+ return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name);
+}
+
+#define VERIFY(expr, ...) \
+ do { \
+ if (!(expr)) { \
+ errors->push_back( \
+ strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name()))); \
+ return; \
+ } \
+ } while (false)
+
+void FinalizeAttr(StringPiece spec, OpDef* op_def,
+ std::vector<string>* errors) {
+ OpDef::AttrDef* attr = op_def->add_attr();
+ StringPiece orig(spec);
+
+ // Parse "<name>:" at the beginning.
+ StringPiece tmp_name;
+ VERIFY(RE2Consume(&spec, "([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*", &tmp_name),
+ "Trouble parsing '<name>:'");
+ attr->set_name(tmp_name.data(), tmp_name.size());
+
+ // Read "<type>" or "list(<type>)".
+ bool is_list = RE2Consume(&spec, "list\\s*\\(\\s*");
+ string type;
+ if (spec.Consume("string")) {
+ type = "string";
+ } else if (spec.Consume("int")) {
+ type = "int";
+ } else if (spec.Consume("float")) {
+ type = "float";
+ } else if (spec.Consume("bool")) {
+ type = "bool";
+ } else if (spec.Consume("type")) {
+ type = "type";
+ } else if (spec.Consume("shape")) {
+ type = "shape";
+ } else if (spec.Consume("tensor")) {
+ type = "tensor";
+ } else if (spec.Consume("func")) {
+ type = "func";
+ } else if (spec.Consume("numbertype") || spec.Consume("numerictype")) {
+ type = "type";
+ AttrValue* allowed = attr->mutable_allowed_values();
+ for (DataType dt : NumberTypes()) {
+ allowed->mutable_list()->add_type(dt);
+ }
+ } else if (spec.Consume("quantizedtype")) {
+ type = "type";
+ AttrValue* allowed = attr->mutable_allowed_values();
+ for (DataType dt : QuantizedTypes()) {
+ allowed->mutable_list()->add_type(dt);
+ }
+ } else if (spec.Consume("realnumbertype") ||
+ spec.Consume("realnumerictype")) {
+ type = "type";
+ AttrValue* allowed = attr->mutable_allowed_values();
+ for (DataType dt : RealNumberTypes()) {
+ allowed->mutable_list()->add_type(dt);
+ }
+ } else if (spec.Consume("{")) {
+ // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
+ RE2Consume(&spec, "\\s*");
+ AttrValue* allowed = attr->mutable_allowed_values();
+ if (spec.starts_with("\"") || spec.starts_with("'")) {
+ type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
+ while (true) {
+ StringPiece escaped_string;
+ VERIFY((RE2Consume(&spec, R"xx("((?:[^"\\]|\\.)*)"\s*)xx",
+ &escaped_string) ||
+ RE2Consume(&spec, R"xx('((?:[^'\\]|\\.)*)'\s*)xx",
+ &escaped_string)),
+ "Trouble parsing allowed string at '", spec, "'");
+ string unescaped;
+ string error;
+ VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error),
+ "Trouble unescaping \"", escaped_string, "\", got error: ",
+ error);
+ allowed->mutable_list()->add_s(unescaped);
+ if (spec.Consume(",")) {
+ RE2Consume(&spec, "\\s*");
+ if (spec.Consume("}")) break; // Allow ending with ", }".
+ } else {
+ VERIFY(spec.Consume("}"),
+ "Expected , or } after strings in list, not: '", spec, "'");
+ break;
+ }
+ }
+ } else { // "{ int32, float, bool }"
+ type = "type";
+ while (true) {
+ StringPiece type_string;
+ VERIFY(RE2Consume(&spec, "([a-z0-9]+)\\s*", &type_string),
+ "Trouble parsing type string at '", spec, "'");
+ DataType dt;
+ VERIFY(DataTypeFromString(type_string, &dt),
+ "Unrecognized type string '", type_string, "'");
+ allowed->mutable_list()->add_type(dt);
+ if (spec.Consume(",")) {
+ RE2Consume(&spec, "\\s*");
+ if (spec.Consume("}")) break; // Allow ending with ", }".
+ } else {
+ VERIFY(spec.Consume("}"),
+ "Expected , or } after types in list, not: '", spec, "'");
+ break;
+ }
+ }
+ }
+ } else {
+ VERIFY(false, "Trouble parsing type string at '", spec, "'");
+ }
+ RE2Consume(&spec, "\\s*");
+
+ // Write the type into *attr.
+ if (is_list) {
+ VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'");
+ RE2Consume(&spec, "\\s*");
+ attr->set_type(strings::StrCat("list(", type, ")"));
+ } else {
+ attr->set_type(type);
+ }
+
+ // Read optional minimum constraint at the end.
+ if ((is_list || type == "int") && spec.Consume(">=")) {
+ int64 min_limit = -999;
+ VERIFY(RE2Consume(&spec, "\\s*(-?\\d+)\\s*", &min_limit),
+ "Could not parse integer lower limit after '>=', found '", spec,
+ "' instead");
+ attr->set_has_minimum(true);
+ attr->set_minimum(min_limit);
+ }
+
+ // Parse default value, if present.
+ if (spec.Consume("=")) {
+ RE2Consume(&spec, "\\s*");
+ VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
+ "Could not parse default value '", spec, "'");
+ } else {
+ VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
+ }
+}
+
+#undef VERIFY
+
+string InOutError(bool is_output, StringPiece orig, const string& op_name) {
+ return strings::StrCat(" from ", is_output ? "Output" : "Input", "(\"", orig,
+ "\") for Op ", op_name);
+}
+
+#define VERIFY(expr, ...) \
+ do { \
+ if (!(expr)) { \
+ errors->push_back(strings::StrCat( \
+ __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \
+ return; \
+ } \
+ } while (false)
+
+void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
+ std::vector<string>* errors) {
+ OpDef::ArgDef* arg =
+ is_output ? op_def->add_output_arg() : op_def->add_input_arg();
+
+ StringPiece orig(spec);
+
+ // Parse "<name>:" at the beginning.
+ StringPiece tmp_name;
+ VERIFY(RE2Consume(&spec, "([a-z][a-z0-9_]*)\\s*:\\s*", &tmp_name),
+ "Trouble parsing 'name:'");
+ arg->set_name(tmp_name.data(), tmp_name.size());
+
+ // Detect "Ref(...)".
+ if (RE2Consume(&spec, "Ref\\s*\\(\\s*")) {
+ arg->set_is_ref(true);
+ }
+
+ { // Parse "<name|type>" or "<name>*<name|type>".
+ StringPiece first, second, type_or_attr;
+ VERIFY(RE2Consume(&spec, "([a-zA-Z][a-zA-Z0-9_]*)\\s*", &first),
+ "Trouble parsing either a type or an attr name at '", spec, "'");
+ if (RE2Consume(&spec, "[*]\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*", &second)) {
+ arg->set_number_attr(first.data(), first.size());
+ type_or_attr = second;
+ } else {
+ type_or_attr = first;
+ }
+ DataType dt;
+ if (DataTypeFromString(type_or_attr, &dt)) {
+ arg->set_type(dt);
+ } else {
+ const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def);
+ VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'");
+ if (attr->type() == "type") {
+ arg->set_type_attr(type_or_attr.data(), type_or_attr.size());
+ } else {
+ VERIFY(attr->type() == "list(type)", "Reference to attr '",
+ type_or_attr, "' with type ", attr->type(),
+ " that isn't type or list(type)");
+ arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size());
+ }
+ }
+ }
+
+ // Closing ) for Ref(.
+ if (arg->is_ref()) {
+ VERIFY(RE2Consume(&spec, "\\)\\s*"),
+ "Did not find closing ')' for 'Ref(', instead found: '", spec, "'");
+ }
+
+ // Should not have anything else.
+ VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
+
+ // Int attrs that are the length of an input or output get a default
+ // minimum of 1.
+ if (!arg->number_attr().empty()) {
+ OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def);
+ if (attr != nullptr && !attr->has_minimum()) {
+ attr->set_has_minimum(true);
+ attr->set_minimum(1);
+ }
+ } else if (!arg->type_list_attr().empty()) {
+ // If an input or output has type specified by a list(type) attr,
+ // it gets a default minimum of 1 as well.
+ OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def);
+ if (attr != nullptr && attr->type() == "list(type)" &&
+ !attr->has_minimum()) {
+ attr->set_has_minimum(true);
+ attr->set_minimum(1);
+ }
+ }
+}
+
+#undef VERIFY
+
+int num_leading_spaces(StringPiece s) {
+ size_t i = 0;
+ while (i < s.size() && s[i] == ' ') {
+ ++i;
+ }
+ return i;
+}
+
+void FinalizeDoc(const string& text, OpDef* op_def,
+ std::vector<string>* errors) {
+ std::vector<string> lines = str_util::Split(text, '\n');
+
+ // Remove trailing spaces.
+ for (string& line : lines) {
+ str_util::StripTrailingWhitespace(&line);
+ }
+
+ // First non-blank line -> summary.
+ int l = 0;
+ while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
+ if (static_cast<size_t>(l) < lines.size()) {
+ op_def->set_summary(lines[l]);
+ ++l;
+ }
+ while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
+
+ // Lines until we see name: -> description.
+ int start_l = l;
+ while (static_cast<size_t>(l) < lines.size() &&
+ !RE2::PartialMatch(lines[l], "^[a-zA-Z][a-zA-Z0-9_]*\\s*:")) {
+ ++l;
+ }
+ int end_l = l;
+ // Trim trailing blank lines from the description.
+ while (start_l < end_l && lines[end_l - 1].empty()) --end_l;
+ string desc = str_util::Join(
+ gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n");
+ if (!desc.empty()) op_def->set_description(desc);
+
+ // name: description
+ // possibly continued on the next line
+ // if so, we remove the minimum indent
+ StringPiece name;
+ std::vector<StringPiece> description;
+ while (static_cast<size_t>(l) < lines.size()) {
+ description.clear();
+ description.push_back(lines[l]);
+ RE2Consume(&description.back(), "([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*", &name);
+ ++l;
+ while (static_cast<size_t>(l) < lines.size() &&
+ !RE2::PartialMatch(lines[l], "^[a-zA-Z][a-zA-Z0-9_]*\\s*:")) {
+ description.push_back(lines[l]);
+ ++l;
+ }
+ // Remove any trailing blank lines.
+ while (!description.empty() && description.back().empty()) {
+ description.pop_back();
+ }
+ // Compute the minimum indent of all lines after the first.
+ int min_indent = -1;
+ for (size_t i = 1; i < description.size(); ++i) {
+ if (!description[i].empty()) {
+ int indent = num_leading_spaces(description[i]);
+ if (min_indent < 0 || indent < min_indent) min_indent = indent;
+ }
+ }
+ // Remove min_indent spaces from all lines after the first.
+ for (size_t i = 1; i < description.size(); ++i) {
+ if (!description[i].empty()) description[i].remove_prefix(min_indent);
+ }
+ // Concatenate lines into a single string.
+ const string complete(str_util::Join(description, "\n"));
+
+ // Find name.
+ bool found = false;
+ for (int i = 0; !found && i < op_def->input_arg_size(); ++i) {
+ if (op_def->input_arg(i).name() == name) {
+ op_def->mutable_input_arg(i)->set_description(complete);
+ found = true;
+ }
+ }
+ for (int i = 0; !found && i < op_def->output_arg_size(); ++i) {
+ if (op_def->output_arg(i).name() == name) {
+ op_def->mutable_output_arg(i)->set_description(complete);
+ found = true;
+ }
+ }
+ for (int i = 0; !found && i < op_def->attr_size(); ++i) {
+ if (op_def->attr(i).name() == name) {
+ op_def->mutable_attr(i)->set_description(complete);
+ found = true;
+ }
+ }
+ if (!found) {
+ errors->push_back(
+ strings::StrCat("No matching input/output/attr for name '", name,
+ "' from Doc() for Op ", op_def->name()));
+ return;
+ }
+ }
+}
+
+} // namespace
+
+OpDefBuilder::OpDefBuilder(StringPiece op_name) {
+ op_def_.set_name(op_name.ToString()); // NOLINT
+}
+
+OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) {
+ attrs_.emplace_back(spec.data(), spec.size());
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::Input(StringPiece spec) {
+ inputs_.emplace_back(spec.data(), spec.size());
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::Output(StringPiece spec) {
+ outputs_.emplace_back(spec.data(), spec.size());
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::Doc(StringPiece text) {
+ if (!doc_.empty()) {
+ errors_.push_back(
+ strings::StrCat("Extra call to Doc() for Op ", op_def_.name()));
+ } else {
+ doc_.assign(text.data(), text.size());
+ }
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::SetIsCommutative() {
+ op_def_.set_is_commutative(true);
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::SetIsAggregate() {
+ op_def_.set_is_aggregate(true);
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::SetIsStateful() {
+ op_def_.set_is_stateful(true);
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
+ op_def_.set_allows_uninitialized_input(true);
+ return *this;
+}
+
+Status OpDefBuilder::Finalize(OpDef* op_def) const {
+ std::vector<string> errors = errors_;
+ *op_def = op_def_;
+
+ for (StringPiece attr : attrs_) {
+ FinalizeAttr(attr, op_def, &errors);
+ }
+ for (StringPiece input : inputs_) {
+ FinalizeInputOrOutput(input, false, op_def, &errors);
+ }
+ for (StringPiece output : outputs_) {
+ FinalizeInputOrOutput(output, true, op_def, &errors);
+ }
+ FinalizeDoc(doc_, op_def, &errors);
+
+ if (errors.empty()) return Status::OK();
+ return errors::InvalidArgument(str_util::Join(errors, "\n"));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
new file mode 100644
index 0000000000..017338c508
--- /dev/null
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -0,0 +1,109 @@
+// Class and associated machinery for specifying an Op's OpDef for Op
+// registration.
+
+#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+#define TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Builder class passed to the REGISTER_OP() macro.
+class OpDefBuilder {
+ public:
+ // Constructs an OpDef with just the name field set.
+ explicit OpDefBuilder(StringPiece op_name);
+
+ // Adds an attr to this OpDefBuilder (and returns *this). The spec has
+ // format "<name>:<type>" or "<name>:<type>=<default>"
+ // where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]*
+ // (by convention only using capital letters for attrs that can be inferred)
+ // <type> can be:
+  //    "string", "int", "float", "bool", "type", "shape", "func", or "tensor"
+ // "numbertype", "realnumbertype", "quantizedtype", "{int32,int64}"
+ // (meaning "type" with a restriction on valid values)
+ // "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
+ // (meaning "string" with a restriction on valid values)
+ // "list(string)", ..., "list(tensor)", "list(numbertype)", ...
+ // (meaning lists of the above types)
+ // "int >= 2" (meaning "int" with a restriction on valid values)
+ // "list(string) >= 2", "list(int) >= 2"
+ // (meaning "list(string)" / "list(int)" with length at least 2)
+ // <default>, if included, should use the Proto text format
+ // of <type>. For lists use [a, b, c] format.
+ //
+ // Note that any attr specifying the length of an input or output will
+ // get a default minimum of 1 unless the >= # syntax is used.
+ //
+ // TODO(josh11b): Perhaps support restrictions and defaults as optional
+ // extra arguments to Attr() instead of encoding them in the spec string.
+ // TODO(josh11b): Would like to have better dtype handling for tensor attrs:
+ // * Ability to say the type of an input/output matches the type of
+ // the tensor.
+ // * Ability to restrict the type of the tensor like the existing
+ // restrictions for type attrs.
+ // Perhaps by linking the type of the tensor to a type attr?
+ OpDefBuilder& Attr(StringPiece spec);
+
+  // Adds an input or output to this OpDefBuilder (and returns *this).
+ // The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
+ // where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
+ // * For a single tensor: <type>
+ // * For a sequence of tensors with the same type: <number>*<type>
+ // * For a sequence of tensors with different types: <type-list>
+ // Where:
+ // <type> is either one of "float", "int32", "string", ...
+ // or the name of an attr (see above) with type "type".
+ // <number> is the name of an attr with type "int".
+ // <type-list> is the name of an attr with type "list(type)".
+ // TODO(josh11b): Indicate Ref() via an optional argument instead of
+ // in the spec?
+ // TODO(josh11b): SparseInput() and SparseOutput() matching the Python
+ // handling?
+ OpDefBuilder& Input(StringPiece spec);
+ OpDefBuilder& Output(StringPiece spec);
+
+ // Turns on the indicated boolean flag in this OpDefBuilder (and
+ // returns *this).
+ OpDefBuilder& SetIsCommutative();
+ OpDefBuilder& SetIsAggregate();
+ OpDefBuilder& SetIsStateful();
+ OpDefBuilder& SetAllowsUninitializedInput();
+
+ // Adds docs to this OpDefBuilder (and returns *this).
+ // Docs have the format:
+ // <1-line summary>
+ // <rest of the description>
+ // <name>: <description of name>
+ // <name>: <description of name>
+ // <if long, indent the description on subsequent lines>
+ // Where <name> is the name of an attr, input, or output. Please
+ // wrap docs at 72 columns so that it may be indented in the
+ // generated output. For tensor inputs or outputs (not attrs), you
+ // may start the description with an "=" (like name:= <description>)
+ // to suppress the automatically-generated type documentation in
+ // generated output.
+ OpDefBuilder& Doc(StringPiece text);
+
+ // Sets *op_def to the requested OpDef, or returns an error.
+ // Must be called after all of the above methods.
+ // Note that OpDefBuilder only reports parsing errors. You should also
+ // call ValidateOpDef() to detect other problems.
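+  //
+  // Typical use (a sketch; "ExampleOp" and its argument specs are
+  // placeholders):
+  //   OpDef op_def;
+  //   Status s = OpDefBuilder("ExampleOp")
+  //                  .Input("x: T")
+  //                  .Output("y: T")
+  //                  .Attr("T: numbertype")
+  //                  .Finalize(&op_def);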
+ Status Finalize(OpDef* op_def) const;
+
+ private:
+ OpDef op_def_;
+ std::vector<string> attrs_;
+ std::vector<string> inputs_;
+ std::vector<string> outputs_;
+ string doc_;
+ std::vector<string> errors_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc
new file mode 100644
index 0000000000..e53bad7075
--- /dev/null
+++ b/tensorflow/core/framework/op_def_builder_test.cc
@@ -0,0 +1,519 @@
+#include "tensorflow/core/framework/op_def_builder.h"
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+static void CanonicalizeAttrTypeListOrder(OpDef* def) {
+ for (int i = 0; i < def->attr_size(); i++) {
+ AttrValue* a = def->mutable_attr(i)->mutable_allowed_values();
+ std::sort(a->mutable_list()->mutable_type()->begin(),
+ a->mutable_list()->mutable_type()->end());
+ }
+}
+
+class OpDefBuilderTest : public ::testing::Test {
+ protected:
+ OpDefBuilder b() { return OpDefBuilder("Test"); }
+
+ void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto) {
+ OpDef op_def;
+ Status status = builder.Finalize(&op_def);
+ EXPECT_OK(status);
+ if (status.ok()) {
+ OpDef expected;
+ protobuf::TextFormat::ParseFromString(
+ strings::StrCat("name: 'Test' ", proto), &expected);
+ // Allow different orderings
+ CanonicalizeAttrTypeListOrder(&op_def);
+ CanonicalizeAttrTypeListOrder(&expected);
+ EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString());
+ }
+ }
+
+ void ExpectOrdered(const OpDefBuilder& builder, StringPiece proto) {
+ OpDef op_def;
+ Status status = builder.Finalize(&op_def);
+ EXPECT_OK(status);
+ if (status.ok()) {
+ OpDef expected;
+ protobuf::TextFormat::ParseFromString(
+ strings::StrCat("name: 'Test' ", proto), &expected);
+ EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString());
+ }
+ }
+
+ void ExpectFailure(const OpDefBuilder& builder, string error) {
+ OpDef op_def;
+ Status status = builder.Finalize(&op_def);
+ EXPECT_FALSE(status.ok());
+ if (!status.ok()) {
+ EXPECT_EQ(status.error_message(), error);
+ }
+ }
+};
+
+TEST_F(OpDefBuilderTest, Attr) {
+ ExpectSuccess(b().Attr("a:string"), "attr: { name: 'a' type: 'string' }");
+ ExpectSuccess(b().Attr("A: int"), "attr: { name: 'A' type: 'int' }");
+ ExpectSuccess(b().Attr("a1 :float"), "attr: { name: 'a1' type: 'float' }");
+ ExpectSuccess(b().Attr("a_a : bool"), "attr: { name: 'a_a' type: 'bool' }");
+ ExpectSuccess(b().Attr("aB : type"), "attr: { name: 'aB' type: 'type' }");
+ ExpectSuccess(b().Attr("aB_3\t: shape"),
+ "attr: { name: 'aB_3' type: 'shape' }");
+ ExpectSuccess(b().Attr("t: tensor"), "attr: { name: 't' type: 'tensor' }");
+ ExpectSuccess(b().Attr("XYZ\t:\tlist(type)"),
+ "attr: { name: 'XYZ' type: 'list(type)' }");
+ ExpectSuccess(b().Attr("f: func"), "attr { name: 'f' type: 'func'}");
+}
+
+TEST_F(OpDefBuilderTest, AttrFailure) {
+ ExpectFailure(
+ b().Attr("_:string"),
+ "Trouble parsing '<name>:' from Attr(\"_:string\") for Op Test");
+ ExpectFailure(
+ b().Attr("9:string"),
+ "Trouble parsing '<name>:' from Attr(\"9:string\") for Op Test");
+ ExpectFailure(b().Attr(":string"),
+ "Trouble parsing '<name>:' from Attr(\":string\") for Op Test");
+ ExpectFailure(b().Attr("string"),
+ "Trouble parsing '<name>:' from Attr(\"string\") for Op Test");
+ ExpectFailure(b().Attr("a:invalid"),
+ "Trouble parsing type string at 'invalid' from "
+ "Attr(\"a:invalid\") for Op Test");
+ ExpectFailure(
+ b().Attr("b:"),
+ "Trouble parsing type string at '' from Attr(\"b:\") for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
+ // Types with restrictions.
+ ExpectSuccess(b().Attr("a:numbertype"),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
+ "DT_INT8, DT_COMPLEX64, DT_QINT8, DT_QUINT8, DT_QINT32] } } }");
+ ExpectSuccess(b().Attr("a:realnumbertype"),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
+ "DT_INT8] } } }");
+ ExpectSuccess(b().Attr("a:quantizedtype"),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_QINT8, DT_QUINT8, DT_QINT32] } } }");
+ ExpectSuccess(b().Attr("a:{string,int32}"),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_STRING, DT_INT32] } } }");
+ ExpectSuccess(b().Attr("a: { float , complex64 } "),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_FLOAT, DT_COMPLEX64] } } }");
+ ExpectSuccess(b().Attr("a: {float, complex64,} "),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_FLOAT, DT_COMPLEX64] } }");
+ ExpectSuccess(b().Attr(R"(a: { "X", "yz" })"),
+ "attr: { name: 'a' type: 'string' allowed_values { list { s: "
+ "['X', 'yz'] } } }");
+ ExpectSuccess(b().Attr(R"(a: { "X", "yz", })"),
+ "attr: { name: 'a' type: 'string' allowed_values { list { s: "
+ "['X', 'yz'] } } }");
+ ExpectSuccess(
+ b().Attr("i: int >= -5"),
+ "attr: { name: 'i' type: 'int' has_minimum: true minimum: -5 }");
+}
+
+TEST_F(OpDefBuilderTest, AttrRestrictionFailure) {
+ ExpectFailure(
+ b().Attr("a:{}"),
+ "Trouble parsing type string at '}' from Attr(\"a:{}\") for Op Test");
+ ExpectFailure(
+ b().Attr("a:{,}"),
+ "Trouble parsing type string at ',}' from Attr(\"a:{,}\") for Op Test");
+ ExpectFailure(b().Attr("a:{invalid}"),
+ "Unrecognized type string 'invalid' from Attr(\"a:{invalid}\") "
+ "for Op Test");
+ ExpectFailure(b().Attr("a:{\"str\", float}"),
+ "Trouble parsing allowed string at 'float}' from "
+ "Attr(\"a:{\"str\", float}\") for Op Test");
+ ExpectFailure(b().Attr("a:{ float, \"str\" }"),
+ "Trouble parsing type string at '\"str\" }' from Attr(\"a:{ "
+ "float, \"str\" }\") for Op Test");
+ ExpectFailure(b().Attr("a:{float,,string}"),
+ "Trouble parsing type string at ',string}' from "
+ "Attr(\"a:{float,,string}\") for Op Test");
+ ExpectFailure(b().Attr("a:{float,,}"),
+ "Trouble parsing type string at ',}' from "
+ "Attr(\"a:{float,,}\") for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
+ ExpectSuccess(
+ b().Attr("a:list(realnumbertype)"),
+ "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
+ "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
+ "DT_INT8] } } }");
+ ExpectSuccess(
+ b().Attr("a:list(quantizedtype)"),
+ "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
+ "[DT_QINT8, DT_QUINT8, DT_QINT32] } } }");
+ ExpectSuccess(
+ b().Attr("a: list({float, string, bool})"),
+ "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
+ "[DT_FLOAT, DT_STRING, DT_BOOL] } } }");
+ ExpectSuccess(
+ b().Attr(R"(a: list({ "one fish", "two fish" }))"),
+ "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
+ "['one fish', 'two fish'] } } }");
+ ExpectSuccess(
+ b().Attr(R"(a: list({ 'red fish', 'blue fish' }))"),
+ "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
+ "['red fish', 'blue fish'] } } }");
+ ExpectSuccess(
+ b().Attr(R"(a: list({ "single' ", 'double"' }))"),
+ "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
+ "[\"single' \", 'double\"'] } } }");
+ ExpectSuccess(
+ b().Attr(R"(a: list({ 'escape\'\n', "from\\\"NY" }))"),
+ "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
+ "[\"escape'\\n\", 'from\\\\\"NY'] } } }");
+}
+
+TEST_F(OpDefBuilderTest, AttrListWithMinLength) {
+ ExpectSuccess(
+ b().Attr("i: list(bool) >= 4"),
+ "attr: { name: 'i' type: 'list(bool)' has_minimum: true minimum: 4 }");
+}
+
+TEST_F(OpDefBuilderTest, AttrWithDefaults) {
+ ExpectSuccess(b().Attr(R"(a:string="foo")"),
+ "attr: { name: 'a' type: 'string' default_value { s:'foo' } }");
+ ExpectSuccess(b().Attr(R"(a:string='foo')"),
+ "attr: { name: 'a' type: 'string' default_value { s:'foo' } }");
+ ExpectSuccess(b().Attr("a:float = 1.25"),
+ "attr: { name: 'a' type: 'float' default_value { f: 1.25 } }");
+ ExpectSuccess(b().Attr("a:tensor = { dtype: DT_INT32 int_val: 5 }"),
+ "attr: { name: 'a' type: 'tensor' default_value { tensor {"
+ " dtype: DT_INT32 int_val: 5 } } }");
+ ExpectSuccess(b().Attr("a:shape = { dim { size: 3 } dim { size: 4 } }"),
+ "attr: { name: 'a' type: 'shape' default_value { shape {"
+ " dim { size: 3 } dim { size: 4 } } } }");
+}
+
+TEST_F(OpDefBuilderTest, AttrFailedDefaults) {
+ ExpectFailure(b().Attr(R"(a:int="foo")"),
+ "Could not parse default value '\"foo\"' from "
+ "Attr(\"a:int=\"foo\"\") for Op Test");
+ ExpectFailure(b().Attr("a:float = [1.25]"),
+ "Could not parse default value '[1.25]' from Attr(\"a:float = "
+ "[1.25]\") for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, AttrListWithDefaults) {
+ ExpectSuccess(b().Attr(R"(a:list(string)=["foo", "bar"])"),
+ "attr: { name: 'a' type: 'list(string)' "
+ "default_value { list { s: ['foo', 'bar'] } } }");
+ ExpectSuccess(b().Attr("a:list(bool)=[true, false, true]"),
+ "attr: { name: 'a' type: 'list(bool)' "
+ "default_value { list { b: [true, false, true] } } }");
+ ExpectSuccess(b().Attr(R"(a:list(int)=[0, -1, 2, -4, 8])"),
+ "attr: { name: 'a' type: 'list(int)' "
+ "default_value { list { i: [0, -1, 2, -4, 8] } } }");
+}
+
+TEST_F(OpDefBuilderTest, AttrFailedListDefaults) {
+ ExpectFailure(b().Attr(R"(a:list(int)=["foo"])"),
+ "Could not parse default value '[\"foo\"]' from "
+ "Attr(\"a:list(int)=[\"foo\"]\") for Op Test");
+ ExpectFailure(b().Attr(R"(a:list(int)=[7, "foo"])"),
+ "Could not parse default value '[7, \"foo\"]' from "
+ "Attr(\"a:list(int)=[7, \"foo\"]\") for Op Test");
+ ExpectFailure(b().Attr("a:list(float) = [[1.25]]"),
+ "Could not parse default value '[[1.25]]' from "
+ "Attr(\"a:list(float) = [[1.25]]\") for Op Test");
+ ExpectFailure(b().Attr("a:list(float) = 1.25"),
+ "Could not parse default value '1.25' from "
+ "Attr(\"a:list(float) = 1.25\") for Op Test");
+ ExpectFailure(b().Attr(R"(a:list(string)='foo')"),
+ "Could not parse default value ''foo'' from "
+ "Attr(\"a:list(string)='foo'\") for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, InputOutput) {
+ ExpectSuccess(b().Input("a: int32"),
+ "input_arg: { name: 'a' type: DT_INT32 }");
+ ExpectSuccess(b().Output("b: string"),
+ "output_arg: { name: 'b' type: DT_STRING }");
+ ExpectSuccess(b().Input("c: float "),
+ "input_arg: { name: 'c' type: DT_FLOAT }");
+ ExpectSuccess(b().Output("d: Ref(bool)"),
+ "output_arg: { name: 'd' type: DT_BOOL is_ref: true }");
+ ExpectOrdered(b().Input("a: bool")
+ .Output("c: complex64")
+ .Input("b: int64")
+ .Output("d: string"),
+ "input_arg: { name: 'a' type: DT_BOOL } "
+ "input_arg: { name: 'b' type: DT_INT64 } "
+ "output_arg: { name: 'c' type: DT_COMPLEX64 } "
+ "output_arg: { name: 'd' type: DT_STRING }");
+}
+
+TEST_F(OpDefBuilderTest, PolymorphicInputOutput) {
+ ExpectSuccess(b().Input("a: foo").Attr("foo: type"),
+ "input_arg: { name: 'a' type_attr: 'foo' } "
+ "attr: { name: 'foo' type: 'type' }");
+ ExpectSuccess(b().Output("a: foo").Attr("foo: { bool, int32 }"),
+ "output_arg: { name: 'a' type_attr: 'foo' } "
+ "attr: { name: 'foo' type: 'type' "
+ "allowed_values: { list { type: [DT_BOOL, DT_INT32] } } }");
+}
+
+TEST_F(OpDefBuilderTest, InputOutputListSameType) {
+ ExpectSuccess(b().Input("a: n * int32").Attr("n: int"),
+ "input_arg: { name: 'a' number_attr: 'n' type: DT_INT32 } "
+ "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 }");
+ // Polymorphic case:
+ ExpectSuccess(b().Output("b: n * foo").Attr("n: int").Attr("foo: type"),
+ "output_arg: { name: 'b' number_attr: 'n' type_attr: 'foo' } "
+ "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 } "
+ "attr: { name: 'foo' type: 'type' }");
+}
+
+TEST_F(OpDefBuilderTest, InputOutputListAnyType) {
+ ExpectSuccess(
+ b().Input("c: foo").Attr("foo: list(type)"),
+ "input_arg: { name: 'c' type_list_attr: 'foo' } "
+ "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 }");
+ ExpectSuccess(
+ b().Output("c: foo").Attr("foo: list({string, float})"),
+ "output_arg: { name: 'c' type_list_attr: 'foo' } "
+ "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 "
+ "allowed_values: { list { type: [DT_STRING, DT_FLOAT] } } }");
+}
+
+TEST_F(OpDefBuilderTest, InputOutputFailure) {
+ ExpectFailure(b().Input("9: int32"),
+ "Trouble parsing 'name:' from Input(\"9: int32\") for Op Test");
+ ExpectFailure(
+ b().Output("_: int32"),
+ "Trouble parsing 'name:' from Output(\"_: int32\") for Op Test");
+ ExpectFailure(b().Input(": int32"),
+ "Trouble parsing 'name:' from Input(\": int32\") for Op Test");
+ ExpectFailure(b().Output("int32"),
+ "Trouble parsing 'name:' from Output(\"int32\") for Op Test");
+ ExpectFailure(
+ b().Input("CAPS: int32"),
+ "Trouble parsing 'name:' from Input(\"CAPS: int32\") for Op Test");
+ ExpectFailure(b().Input("a: _"),
+ "Trouble parsing either a type or an attr name at '_' from "
+ "Input(\"a: _\") for Op Test");
+ ExpectFailure(b().Input("a: 9"),
+ "Trouble parsing either a type or an attr name at '9' from "
+ "Input(\"a: 9\") for Op Test");
+ ExpectFailure(b().Input("a: 9 * int32"),
+ "Trouble parsing either a type or an attr name at '9 * int32' "
+ "from Input(\"a: 9 * int32\") for Op Test");
+ ExpectFailure(
+ b().Input("a: x * _").Attr("x: type"),
+ "Extra '* _' unparsed at the end from Input(\"a: x * _\") for Op Test");
+ ExpectFailure(b().Input("a: x * y extra").Attr("x: int").Attr("y: type"),
+ "Extra 'extra' unparsed at the end from Input(\"a: x * y "
+ "extra\") for Op Test");
+ ExpectFailure(b().Input("a: Ref(int32"),
+ "Did not find closing ')' for 'Ref(', instead found: '' from "
+ "Input(\"a: Ref(int32\") for Op Test");
+ ExpectFailure(b().Input("a: Ref(x y").Attr("x: type"),
+ "Did not find closing ')' for 'Ref(', instead found: 'y' from "
+ "Input(\"a: Ref(x y\") for Op Test");
+ ExpectFailure(
+ b().Input("a: x"),
+ "Reference to unknown attr 'x' from Input(\"a: x\") for Op Test");
+ ExpectFailure(
+ b().Input("a: x * y").Attr("x: int"),
+ "Reference to unknown attr 'y' from Input(\"a: x * y\") for Op Test");
+ ExpectFailure(b().Input("a: x").Attr("x: int"),
+ "Reference to attr 'x' with type int that isn't type or "
+ "list(type) from Input(\"a: x\") for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, Set) {
+ ExpectSuccess(b().SetIsStateful(), "is_stateful: true");
+ ExpectSuccess(b().SetIsCommutative().SetIsAggregate(),
+ "is_commutative: true is_aggregate: true");
+}
+
+TEST_F(OpDefBuilderTest, DocUnpackSparseFeatures) {
+ ExpectOrdered(b().Input("sf: string")
+ .Output("indices: int32")
+ .Output("ids: int64")
+ .Output("weights: float")
+ .Doc(R"doc(
+Converts a vector of strings with dist_belief::SparseFeatures to tensors.
+
+Note that indices, ids and weights are vectors of the same size and have
+one-to-one correspondence between their elements. ids and weights are each
+obtained by sequentially concatenating sf[i].id and sf[i].weight, for i in
+1...size(sf). Note that if sf[i].weight is not set, the default value for the
+weight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were
+extracted from sf[i], then index[j] is set to i.
+
+sf: vector of string, where each element is the string encoding of
+ SparseFeatures proto.
+indices: vector of indices inside sf
+ids: vector of id extracted from the SparseFeatures proto.
+weights: vector of weight extracted from the SparseFeatures proto.
+)doc"),
+ R"proto(
+input_arg {
+ name: "sf"
+ description: "vector of string, where each element is the string encoding of\nSparseFeatures proto."
+ type: DT_STRING
+}
+output_arg {
+ name: "indices"
+ description: "vector of indices inside sf"
+ type: DT_INT32
+}
+output_arg {
+ name: "ids"
+ description: "vector of id extracted from the SparseFeatures proto."
+ type: DT_INT64
+}
+output_arg {
+ name: "weights"
+ description: "vector of weight extracted from the SparseFeatures proto."
+ type: DT_FLOAT
+}
+summary: "Converts a vector of strings with dist_belief::SparseFeatures to tensors."
+description: "Note that indices, ids and weights are vectors of the same size and have\none-to-one correspondence between their elements. ids and weights are each\nobtained by sequentially concatenating sf[i].id and sf[i].weight, for i in\n1...size(sf). Note that if sf[i].weight is not set, the default value for the\nweight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were\nextracted from sf[i], then index[j] is set to i."
+)proto");
+}
+
+TEST_F(OpDefBuilderTest, DocConcat) {
+ ExpectOrdered(b().Input("concat_dim: int32")
+ .Input("values: num_values * dtype")
+ .Output("output: dtype")
+ .Attr("dtype: type")
+ .Attr("num_values: int >= 2")
+ .Doc(R"doc(
+Concatenate N Tensors along one dimension.
+
+concat_dim: The (scalar) dimension along which to concatenate. Must be
+ in the range [0, rank(values...)).
+values: The N Tensors to concatenate. Their ranks and types must match,
+ and their sizes must match in all dimensions except concat_dim.
+output: A Tensor with the concatenation of values stacked along the
+ concat_dim dimension. This Tensor's shape matches the Tensors in
+ values, except in concat_dim where it has the sum of the sizes.
+)doc"),
+ R"proto(
+input_arg {
+ name: "concat_dim"
+ description: "The (scalar) dimension along which to concatenate. Must be\nin the range [0, rank(values...))."
+ type: DT_INT32
+}
+input_arg {
+ name: "values"
+ description: "The N Tensors to concatenate. Their ranks and types must match,\nand their sizes must match in all dimensions except concat_dim."
+ type_attr: "dtype"
+ number_attr: "num_values"
+}
+output_arg {
+ name: "output"
+ description: "A Tensor with the concatenation of values stacked along the\nconcat_dim dimension. This Tensor\'s shape matches the Tensors in\nvalues, except in concat_dim where it has the sum of the sizes."
+ type_attr: "dtype"
+}
+summary: "Concatenate N Tensors along one dimension."
+attr {
+ name: "dtype"
+ type: "type"
+}
+attr {
+ name: "num_values"
+ type: "int"
+ has_minimum: true
+ minimum: 2
+}
+)proto");
+}
+
+TEST_F(OpDefBuilderTest, DocAttr) {
+ ExpectOrdered(b().Attr("i: int").Doc(R"doc(
+Summary
+
+i: How much to operate.
+)doc"),
+ R"proto(
+summary: "Summary"
+attr {
+ name: "i"
+ type: "int"
+ description: "How much to operate."
+}
+)proto");
+}
+
+TEST_F(OpDefBuilderTest, DocCalledTwiceFailure) {
+ ExpectFailure(b().Doc("What's").Doc("up, doc?"),
+ "Extra call to Doc() for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, DocFailureMissingName) {
+ ExpectFailure(
+ b().Input("a: int32").Doc(R"doc(
+Summary
+
+a: Something for a.
+b: b is not defined.
+)doc"),
+ "No matching input/output/attr for name 'b' from Doc() for Op Test");
+
+ ExpectFailure(
+ b().Input("a: int32").Doc(R"doc(
+Summary
+
+b: b is not defined and by itself.
+)doc"),
+ "No matching input/output/attr for name 'b' from Doc() for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, DefaultMinimum) {
+ ExpectSuccess(b().Input("values: num_values * dtype")
+ .Output("output: anything")
+ .Attr("anything: list(type)")
+ .Attr("dtype: type")
+ .Attr("num_values: int"),
+ R"proto(
+input_arg {
+ name: "values"
+ type_attr: "dtype"
+ number_attr: "num_values"
+}
+output_arg {
+ name: "output"
+ type_list_attr: "anything"
+}
+attr {
+ name: "anything"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+}
+attr {
+ name: "dtype"
+ type: "type"
+}
+attr {
+ name: "num_values"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+}
+)proto");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
new file mode 100644
index 0000000000..e3aef011de
--- /dev/null
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -0,0 +1,344 @@
+#include "tensorflow/core/framework/op_def_util.h"
+
+#include <set>
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+namespace { // ------ Helper functions ------
+
+bool HasAttrStyleType(const OpDef::ArgDef& arg) {
+ return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
+ !arg.type_list_attr().empty();
+}
+
+Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
+ const AttrValue& allowed_values(attr.allowed_values());
+ for (auto allowed : allowed_values.list().type()) {
+ if (dt == allowed) {
+ return Status::OK();
+ }
+ }
+ string allowed_str;
+ for (int i = 0; i < allowed_values.list().type_size(); ++i) {
+ if (!allowed_str.empty()) {
+ strings::StrAppend(&allowed_str, ", ");
+ }
+ strings::StrAppend(&allowed_str,
+ DataTypeString(allowed_values.list().type(i)));
+ }
+ return errors::InvalidArgument(
+ "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
+ " is not in the list of allowed values: ", allowed_str);
+}
+
+Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
+ const AttrValue& allowed_values(attr.allowed_values());
+ for (auto allowed : allowed_values.list().s()) {
+ if (str == allowed) {
+ return Status::OK();
+ }
+ }
+ string allowed_str;
+ for (const string& allowed : allowed_values.list().s()) {
+ if (!allowed_str.empty()) {
+ strings::StrAppend(&allowed_str, ", ");
+ }
+ strings::StrAppend(&allowed_str, "\"", allowed, "\"");
+ }
+ return errors::InvalidArgument(
+ "Value for attr '", attr.name(), "' of \"", str,
+ "\" is not in the list of allowed values: ", allowed_str);
+}
+
+} // namespace
+
+// Requires: attr has already been validated.
+Status ValidateAttrValue(const AttrValue& attr_value,
+ const OpDef::AttrDef& attr) {
+ // Is it a valid value?
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
+ " for attr '", attr.name(), "'");
+
+ // Does the value satisfy the minimum constraint in the AttrDef?
+ if (attr.has_minimum()) {
+ if (attr.type() == "int") {
+ if (attr_value.i() < attr.minimum()) {
+ return errors::InvalidArgument(
+ "Value for attr '", attr.name(), "' of ", attr_value.i(),
+ " must be at least minimum ", attr.minimum());
+ }
+ } else {
+ int length = -1;
+ if (attr.type() == "list(string)") {
+ length = attr_value.list().s_size();
+ } else if (attr.type() == "list(int)") {
+ length = attr_value.list().i_size();
+ } else if (attr.type() == "list(float)") {
+ length = attr_value.list().f_size();
+ } else if (attr.type() == "list(bool)") {
+ length = attr_value.list().b_size();
+ } else if (attr.type() == "list(type)") {
+ length = attr_value.list().type_size();
+ } else if (attr.type() == "list(shape)") {
+ length = attr_value.list().shape_size();
+ } else if (attr.type() == "list(tensor)") {
+ length = attr_value.list().tensor_size();
+ }
+ if (length < attr.minimum()) {
+ return errors::InvalidArgument(
+ "Length for attr '", attr.name(), "' of ", length,
+ " must be at least minimum ", attr.minimum());
+ }
+ }
+ }
+
+ // Does the value satisfy the allowed_value constraint in the AttrDef?
+ if (attr.has_allowed_values()) {
+ if (attr.type() == "type") {
+ TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
+ } else if (attr.type() == "list(type)") {
+ for (int dt : attr_value.list().type()) {
+ TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
+ }
+ } else if (attr.type() == "string") {
+ TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
+ } else if (attr.type() == "list(string)") {
+ for (const string& str : attr_value.list().s()) {
+ TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
+ }
+ } else {
+ return errors::Unimplemented(
+ "Support for allowed_values not implemented for type ", attr.type());
+ }
+ }
+ return Status::OK();
+}
+
+const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
+ for (int i = 0; i < op_def.attr_size(); ++i) {
+ if (op_def.attr(i).name() == name) {
+ return &op_def.attr(i);
+ }
+ }
+ return nullptr;
+}
+
+OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
+ for (int i = 0; i < op_def->attr_size(); ++i) {
+ if (op_def->attr(i).name() == name) {
+ return op_def->mutable_attr(i);
+ }
+ }
+ return nullptr;
+}
+
+#define VALIDATE(EXPR, ...) \
+ do { \
+ if (!(EXPR)) { \
+ return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ", \
+ op_def.ShortDebugString()); \
+ } \
+ } while (false)
+
+static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
+ bool output, std::set<string>* names) {
+ const string suffix = strings::StrCat(
+ output ? " for output '" : " for input '", arg.name(), "'");
+ VALIDATE(gtl::InsertIfNotPresent(names, arg.name()), "Duplicate name: ",
+ arg.name());
+ VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);
+
+ if (!arg.number_attr().empty()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
+ VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
+ suffix);
+ VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
+ suffix, " has type ", attr->type(), " != int");
+ VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
+ suffix, " must have minimum");
+ VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
+ suffix, " must have minimum >= 0");
+ VALIDATE(arg.type_list_attr().empty(),
+ "Can't have both number_attr and type_list_attr", suffix);
+ VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
+ (!arg.type_attr().empty() ? 1 : 0) ==
+ 1,
+ "Exactly one of type, type_attr must be set", suffix);
+ } else {
+ const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
+ (!arg.type_attr().empty() ? 1 : 0) +
+ (!arg.type_list_attr().empty() ? 1 : 0);
+ VALIDATE(num_type_fields == 1,
+ "Exactly one of type, type_attr, type_list_attr must be set",
+ suffix);
+ }
+
+ if (!arg.type_attr().empty()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
+ VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
+ suffix);
+ VALIDATE(attr->type() == "type", "Attr '", attr->name(),
+ "' used as type_attr", suffix, " has type ", attr->type(),
+ " != type");
+ } else if (!arg.type_list_attr().empty()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
+ VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
+ suffix);
+ VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
+ "' used as type_list_attr", suffix, " has type ", attr->type(),
+ " != list(type)");
+ } else {
+ // All argument types should be non-reference types at this point.
+ // ArgDef.is_ref is set to true for reference arguments.
+ VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
+ DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
+ }
+
+ return Status::OK();
+}
+
+Status ValidateOpDef(const OpDef& op_def) {
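+  // An op name must either be CamelCase starting with an uppercase letter
+  // (e.g. "Concat") or begin with an underscore (e.g. a hypothetical
+  // "_InternalFoo"); names like "concat" or "7Op" fail validation
+  // (illustrative examples, not drawn from the tests).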
+ VALIDATE(RE2::FullMatch(op_def.name(), "(?:_.*|[A-Z][a-zA-Z0-9]*)"),
+ "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
+
+ std::set<string> names; // for detecting duplicate names
+ for (const auto& attr : op_def.attr()) {
+ // Validate name
+ VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()), "Duplicate name: ",
+ attr.name());
+ DataType dt;
+ VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
+ attr.name(), " that matches a data type");
+
+ // Validate type
+ StringPiece type(attr.type());
+ bool is_list = type.Consume("list(");
+ bool found = false;
+ for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
+ "tensor", "func"}) {
+ if (type.Consume(valid)) {
+ found = true;
+ break;
+ }
+ }
+ VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
+ "'");
+ if (is_list) {
+ VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ",
+ attr.name(), "'s type ", attr.type());
+ }
+ VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
+ attr.name(), "'s type ", attr.type());
+
+ // Validate minimum
+ if (attr.has_minimum()) {
+ VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
+ "' has minimum for unsupported type ", attr.type());
+ if (is_list) {
+ VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
+ "' with list type must have a non-negative minimum, not ",
+ attr.minimum());
+ }
+ } else {
+ VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
+ "' with has_minimum = false but minimum ", attr.minimum(),
+ " not equal to default of 0");
+ }
+
+ // Validate allowed_values
+ if (attr.has_allowed_values()) {
+ const string list_type =
+ is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
+ attr.name(), "' in Op '", op_def.name(), "'");
+ }
+
+ // Validate default_value (after we have validated the rest of the attr,
+ // so we can use ValidateAttrValue()).
+ if (attr.has_default_value()) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ ValidateAttrValue(attr.default_value(), attr), " in Op '",
+ op_def.name(), "'");
+ }
+ }
+
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
+ }
+
+ for (const auto& arg : op_def.output_arg()) {
+ TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
+ }
+
+ return Status::OK();
+}
+
+#undef VALIDATE
+
+namespace {
+
+string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
+ string ret;
+ for (const OpDef::ArgDef& arg : args) {
+ if (!ret.empty()) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, arg.name(), ":");
+ if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
+ if (!arg.number_attr().empty()) {
+ strings::StrAppend(&ret, arg.number_attr(), "*");
+ }
+ if (arg.type() != DT_INVALID) {
+ strings::StrAppend(&ret, DataTypeString(arg.type()));
+ } else {
+ strings::StrAppend(&ret, arg.type_attr());
+ }
+ if (arg.is_ref()) strings::StrAppend(&ret, ")");
+ }
+ return ret;
+}
+
+} // namespace
+
+string SummarizeOpDef(const OpDef& op_def) {
+ string ret = strings::StrCat("Op<name=", op_def.name());
+ strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
+ " -> ", SummarizeArgs(op_def.output_arg()));
+ for (int i = 0; i < op_def.attr_size(); ++i) {
+ strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
+ op_def.attr(i).type());
+ if (op_def.attr(i).has_default_value()) {
+ strings::StrAppend(&ret, ",default=",
+ SummarizeAttrValue(op_def.attr(i).default_value()));
+ }
+ if (op_def.attr(i).has_minimum()) {
+ strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
+ }
+ if (op_def.attr(i).has_allowed_values()) {
+ strings::StrAppend(&ret, ",allowed=",
+ SummarizeAttrValue(op_def.attr(i).allowed_values()));
+ }
+ }
+ if (op_def.is_commutative()) {
+ strings::StrAppend(&ret, "; is_commutative=true");
+ }
+ if (op_def.is_aggregate()) {
+ strings::StrAppend(&ret, "; is_aggregate=true");
+ }
+ if (op_def.is_stateful()) {
+ strings::StrAppend(&ret, "; is_stateful=true");
+ }
+ if (op_def.allows_uninitialized_input()) {
+ strings::StrAppend(&ret, "; allows_uninitialized_input=true");
+ }
+ strings::StrAppend(&ret, ">");
+ return ret;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h
new file mode 100644
index 0000000000..a9fecf3fa0
--- /dev/null
+++ b/tensorflow/core/framework/op_def_util.h
@@ -0,0 +1,32 @@
+// TODO(josh11b): Probably not needed for OpKernel authors, so doesn't
+// need to be as publicly accessible as other files in framework/.
+
+#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
+
+#include <string>
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Performs a consistency check across the fields of the op_def.
+Status ValidateOpDef(const OpDef& op_def);
+
+// Validates that attr_value satisfies the type and constraints from attr.
+// REQUIRES: attr has already been validated.
+Status ValidateAttrValue(const AttrValue& attr_value,
+ const OpDef::AttrDef& attr);
+
+// The following functions search through op_def for an attr with the
+// indicated name, returning nullptr if no such attr is found.
+const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def);
+OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def);
+
+// Produces a human-readable version of an op_def that is more concise
+// than a text-format proto. Excludes descriptions.
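+// For a hypothetical op (illustrative only, not from this codebase) the
+// summary has the shape:
+//   Op<name=Foo; signature=a:float, n_ints:N*int32 -> out:T;
+//   attr=N:int,min=1; attr=T:type>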
+string SummarizeOpDef(const OpDef& op_def);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc
new file mode 100644
index 0000000000..515e8bb288
--- /dev/null
+++ b/tensorflow/core/framework/op_def_util_test.cc
@@ -0,0 +1,330 @@
+#include "tensorflow/core/framework/op_def_util.h"
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+OpDef FromText(const string& text) {
+ OpDef op_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &op_def));
+ return op_def;
+}
+
+class ValidateOpDefTest : public ::testing::Test {
+ protected:
+ Status TestProto(const string& text) {
+ return ValidateOpDef(FromText(text));
+ }
+
+ Status TestBuilder(const OpDefBuilder& builder) {
+ OpDef op_def;
+ Status status = builder.Finalize(&op_def);
+ EXPECT_OK(status);
+ if (!status.ok()) {
+ return status;
+ } else {
+ return ValidateOpDef(op_def);
+ }
+ }
+
+ void ExpectFailure(const Status& status, const string& message) {
+ EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
+ if (!status.ok()) {
+ LOG(INFO) << "message: " << status;
+ EXPECT_TRUE(StringPiece(status.ToString()).contains(message))
+ << "Actual: " << status << "\nExpected to contain: " << message;
+ }
+ }
+};
+
+TEST_F(ValidateOpDefTest, OpDefValid) {
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Input("a: int32")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Output("a: bool")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("t: type").Input("a: t")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int = 3")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5 = 3")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: numbertype")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("Uppercase")));
+}
+
+TEST_F(ValidateOpDefTest, InvalidName) {
+ ExpectFailure(TestBuilder(OpDefBuilder("lower").Attr("a: int")),
+ "Invalid name");
+ ExpectFailure(TestBuilder(OpDefBuilder("BadSuffix 7%")), "Invalid name");
+}
+
+TEST_F(ValidateOpDefTest, DuplicateName) {
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("DupeName").Input("a: int32").Input("a: float")),
+ "Duplicate name: a");
+ ExpectFailure(
+ TestBuilder(
+ OpDefBuilder("DupeName").Input("a: int32").Output("a: float")),
+ "Duplicate name: a");
+ ExpectFailure(
+ TestBuilder(
+ OpDefBuilder("DupeName").Output("a: int32").Output("a: float")),
+ "Duplicate name: a");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("DupeName").Input("a: int32").Attr("a: float")),
+ "Duplicate name: a");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("DupeName").Output("a: int32").Attr("a: float")),
+ "Duplicate name: a");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("DupeName").Attr("a: int").Attr("a: float")),
+ "Duplicate name: a");
+}
+
+TEST_F(ValidateOpDefTest, BadAttrName) {
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude").Attr("int32: int")),
+ "Attr can't have name int32 that matches a data type");
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude").Attr("float: string")),
+ "Attr can't have name float that matches a data type");
+}
+
+TEST_F(ValidateOpDefTest, BadAttrType) {
+ ExpectFailure(
+ TestProto("name: 'BadAttrType' attr { name: 'a' type: 'illegal' }"),
+ "Unrecognized type");
+ ExpectFailure(
+ TestProto("name: 'BadAttrType' attr { name: 'a' type: 'list(illegal)' }"),
+ "Unrecognized type");
+ ExpectFailure(
+ TestProto("name: 'BadAttrType' attr { name: 'a' type: 'int extra' }"),
+ "Extra ' extra' at the end");
+ ExpectFailure(
+ TestProto(
+ "name: 'BadAttrType' attr { name: 'a' type: 'list(int extra)' }"),
+ "'list(' is missing ')' in attr");
+ ExpectFailure(
+ TestProto(
+ "name: 'BadAttrType' attr { name: 'a' type: 'list(int) extra' }"),
+ "Extra ' extra' at the end");
+}
+
+TEST_F(ValidateOpDefTest, BadAttrDefault) {
+ ExpectFailure(
+ TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'int' default_value { s: 'x' } }"),
+ "AttrValue had value with type string when int expected\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'int' default_value { f: 0.5 } }"),
+ "AttrValue had value with type float when int expected\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(
+ TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'int' "
+ "default_value { i: 5 list { i: [2] } } }"),
+ "AttrValue had value with type list(int) when int expected\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(
+ TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'list(int)' default_value { f: 0.5 } }"),
+ "AttrValue had value with type float when list(int) expected\n\t "
+ "for attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(
+ TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'list(int)' "
+ "default_value { list { i: [5] f: [0.5] } } }"),
+ "AttrValue had value with type list(float) when list(int) "
+ "expected\n\t for attr 'a'\n\t in Op 'BadAttrDef'");
+
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'type' default_value { } }"),
+ "AttrValue missing value with expected type type\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'shape' default_value { } }"),
+ "AttrValue missing value with expected type shape\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'tensor' default_value { } }"),
+ "AttrValue missing value with expected type tensor\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+
+  // default_value {} is indistinguishable from default_value { list {} }
+  // (one with an empty list) in proto3 semantics.
+ EXPECT_OK(
+ TestProto("name: 'GoodAttrDef' attr { name: 'a' "
+ "type: 'list(int)' default_value { } }"));
+
+ // Empty lists are allowed:
+ EXPECT_OK(
+ TestProto("name: 'GoodAttrDef' attr { name: 'a' "
+ "type: 'list(int)' default_value { list { } } }"));
+ // Builder should make the same proto:
+ EXPECT_OK(TestBuilder(OpDefBuilder("GoodAttrDef").Attr("a: list(int) = []")));
+
+ // Unless there is a minimum length specified:
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'list(int)' has_minimum: true minimum: 2 "
+ "default_value { list { } } }"),
+ "Length for attr 'a' of 0 must be at least minimum 2\n\t in Op "
+ "'BadAttrDef'");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("GoodAttrDef").Attr("a: list(bool) >=2 = []")),
+ "Length for attr 'a' of 0 must be at least minimum 2\n\t in Op "
+ "'GoodAttrDef'");
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' type: "
+ "'list(string)' has_minimum: true minimum: 2 "
+ "default_value { list { s: ['foo'] } } }"),
+ "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op "
+ "'BadAttrDef'");
+ ExpectFailure(TestBuilder(OpDefBuilder("GoodAttrDef")
+ .Attr("a: list(type) >=2 = [DT_STRING]")),
+ "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op "
+ "'GoodAttrDef'");
+}
+
+TEST_F(ValidateOpDefTest, NoRefTypes) {
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef").Input("i: float_ref")),
+ "Illegal use of ref type 'float_ref'. "
+ "Use 'Ref(type)' instead for input 'i'");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("BadAttrDef").Attr("T: type = DT_INT32_REF")),
+ "AttrValue must not have reference type value of int32_ref");
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef")
+ .Attr("T: list(type) = [DT_STRING_REF]")),
+ "AttrValue must not have reference type value of string_ref");
+}
+
+TEST_F(ValidateOpDefTest, BadAttrMin) {
+ ExpectFailure(TestProto("name: 'BadAttrMin' attr { name: 'a' type: 'string' "
+ "has_minimum: true minimum: 0 }"),
+ "minimum for unsupported type string");
+ ExpectFailure(
+ TestProto("name: 'BadAttrMin' attr { name: 'a' type: 'int' default_value "
+ "{ i: 2 } has_minimum: true minimum: 7 }"),
+ "Value for attr 'a' of 2 must be at least minimum 7\n\t in Op "
+ "'BadAttrMin'");
+ ExpectFailure(
+ TestProto("name: 'BadAttrMin' attr { name: 'a' "
+ "type: 'list(string)' has_minimum: true minimum: -5 }"),
+ "list type must have a non-negative minimum, not -5");
+ EXPECT_OK(
+ TestProto("name: 'GoodAttrMin' attr { name: 'a' type: 'list(string)' "
+ "has_minimum: true minimum: 1 }"));
+ ExpectFailure(TestProto("name: 'NoHasMin' attr { name: 'a' "
+ "type: 'list(string)' minimum: 3 }"),
+ "Attr 'a' with has_minimum = false but minimum 3 not equal to "
+ "default of 0");
+}
+
+TEST_F(ValidateOpDefTest, BadAttrAllowed) {
+ // Is in list of allowed types.
+ EXPECT_OK(TestBuilder(
+ OpDefBuilder("GoodAttrtude").Attr("x: numbertype = DT_INT32")));
+ // Not in list of allowed types.
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude")
+ .Attr("x: numbertype = DT_STRING")),
+ "attr 'x' of string is not in the list of allowed values");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("BadAttrtude")
+ .Attr("x: list(realnumbertype) = [DT_COMPLEX64]")),
+ "attr 'x' of complex64 is not in the list of allowed values");
+ // Is in list of allowed strings.
+ EXPECT_OK(TestBuilder(
+ OpDefBuilder("GoodAttrtude").Attr("x: {'foo', 'bar'} = 'bar'")));
+ // Not in list of allowed strings.
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude")
+ .Attr("x: {'foo', 'bar'} = 'baz'")),
+ "attr 'x' of \"baz\" is not in the list of allowed values");
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude")
+ .Attr("x: list({'foo', 'bar'}) = ['baz']")),
+ "attr 'x' of \"baz\" is not in the list of allowed values");
+ ExpectFailure(TestProto(
+ "name: 'BadAttrtude' attr { name: 'a' "
+ "type: 'string' allowed_values { s: 'not list' } }"),
+ "with type string when list(string) expected");
+ ExpectFailure(
+ TestProto("name: 'BadAttrtude' attr { name: 'a' "
+ "type: 'string' allowed_values { list { i: [6] } } }"),
+ "with type list(int) when list(string) expected");
+}
+
+TEST_F(ValidateOpDefTest, BadArgType) {
+ ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type: DT_INT32 } input_arg { name: 'b' }"),
+ "Missing type for input 'b'");
+ ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type: DT_INT32 } output_arg { name: 'b' }"),
+ "Missing type for output 'b'");
+ ExpectFailure(
+ TestProto("name: 'BadArg' input_arg { name: 'a' type: "
+ "DT_INT32 type_attr: 'x' } attr { name: 'x' type: 'type' }"),
+ "Exactly one of type, type_attr, type_list_attr must be set for input "
+ "'a'");
+ ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type_attr: 'x' } attr { name: 'x' type: 'int' }"),
+ "Attr 'x' used as type_attr for input 'a' has type int");
+ ExpectFailure(
+ TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type_attr: 'x' } attr { name: 'x' type: 'list(type)' }"),
+ "Attr 'x' used as type_attr for input 'a' has type list(type)");
+ ExpectFailure(
+ TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type_list_attr: 'x' } attr { name: 'x' type: 'int' }"),
+ "Attr 'x' used as type_list_attr for input 'a' has type int");
+ ExpectFailure(
+ TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type_list_attr: 'x' } attr { name: 'x' type: 'type' }"),
+ "Attr 'x' used as type_list_attr for input 'a' has type type");
+ ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type_attr: 'x' }"),
+ "No attr with name 'x' for input 'a'");
+ ExpectFailure(
+ TestProto("name: 'BadArg' input_arg { name: 'a' number_attr: 'n' "
+ "type_attr: 'x' } attr { name: 'x' type: 'list(type)' } "
+ "attr { name: 'n' type: 'int' has_minimum: true minimum: 1 }"),
+ "Attr 'x' used as type_attr for input 'a' has type list(type)");
+ // But list(type) is fine as the type of an arg without a number_attr:
+ EXPECT_OK(TestProto(
+ "name: 'Arg' input_arg { name: 'a' type_list_attr: 'x' } "
+ "attr { name: 'x' type: 'list(type)' } attr { name: 'n' type: 'int' "
+ "has_minimum: true minimum: 1 }"));
+
+ // number_attr
+ EXPECT_OK(TestProto(
+ "name: 'Arg' input_arg { name: 'a' type: DT_INT32 number_attr: 'n' } "
+ "attr { name: 'n' type: 'int' has_minimum: true minimum: 0 }"));
+
+ ExpectFailure(TestProto("name: 'Arg' input_arg { name: 'a' type: DT_INT32 "
+ "number_attr: 'n' }"),
+ "No attr with name 'n'");
+ ExpectFailure(
+ TestProto(
+ "name: 'Arg' input_arg { name: 'a' type: "
+ "DT_INT32 number_attr: 'n' } attr { name: 'n' type: 'string' }"),
+ "Attr 'n' used as length for input 'a' has type string");
+ ExpectFailure(
+ TestProto("name: 'Arg' input_arg { name: 'a' type: "
+ "DT_INT32 number_attr: 'n' } attr { name: 'n' type: 'int' }"),
+ "Attr 'n' used as length for input 'a' must have minimum;");
+ ExpectFailure(
+ TestProto("name: 'Arg' input_arg { name: 'a' type: DT_INT32 number_attr: "
+ "'n' } attr { name: 'n' type: 'int' has_minimum: true minimum: "
+ "-5 }"),
+ "Attr 'n' used as length for input 'a' must have minimum >= 0;");
+ ExpectFailure(
+ TestProto("name: 'Arg' input_arg { name: 'a' number_attr: 'n' } attr { "
+ "name: 'n' type: 'int' has_minimum: true minimum: 2 }"),
+ "Missing type for input 'a'; in OpDef:");
+ ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' number_attr: "
+ "'n' type_list_attr: 'x' } attr { name: 'n' type: "
+ "'int' has_minimum: true minimum: 1 } attr { name: "
+ "'x' type: 'list(type)' }"),
+ "Can't have both number_attr and type_list_attr for input 'a'");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc
new file mode 100644
index 0000000000..04f4b7cacd
--- /dev/null
+++ b/tensorflow/core/framework/op_gen_lib.cc
@@ -0,0 +1,55 @@
+#include "tensorflow/core/framework/op_gen_lib.h"
+
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+string WordWrap(StringPiece prefix, StringPiece str, int width) {
+ const string indent_next_line = "\n" + Spaces(prefix.size());
+ width -= prefix.size();
+ string result;
+ strings::StrAppend(&result, prefix);
+
+ while (!str.empty()) {
+ if (static_cast<int>(str.size()) <= width) {
+ // Remaining text fits on one line.
+ strings::StrAppend(&result, str);
+ break;
+ }
+ auto space = str.rfind(' ', width);
+ if (space == StringPiece::npos) {
+      // No space within width; rather than break mid-word, make a too-long
+      // line and break at the next space.
+ space = str.find(' ');
+ if (space == StringPiece::npos) {
+ strings::StrAppend(&result, str);
+ break;
+ }
+ }
+    // Break at the space at position <space>.
+ StringPiece to_append = str.substr(0, space);
+ str.remove_prefix(space + 1);
+ // Remove spaces at break.
+ while (to_append.ends_with(" ")) {
+ to_append.remove_suffix(1);
+ }
+ while (str.Consume(" ")) {
+ }
+
+ // Go on to the next line.
+ strings::StrAppend(&result, to_append);
+ if (!str.empty()) strings::StrAppend(&result, indent_next_line);
+ }
+
+ return result;
+}
+
+bool ConsumeEquals(StringPiece* description) {
+ if (description->Consume("=")) {
+ while (description->Consume(" ")) { // Also remove spaces after "=".
+ }
+ return true;
+ }
+ return false;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h
new file mode 100644
index 0000000000..9890f1bcec
--- /dev/null
+++ b/tensorflow/core/framework/op_gen_lib.h
@@ -0,0 +1,24 @@
+#ifndef TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
+#define TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
+
+#include <string>
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+inline string Spaces(int n) { return string(n, ' '); }
+
+// Wraps prefix + str so that each line is at most width characters,
+// indenting every line after the first by prefix.size() spaces. The
+// intended use case is something like prefix = " Foo(" with str being a
+// list of arguments (terminated by a ")").
+// TODO(josh11b): Option to wrap on ", " instead of " " when possible.
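+//
+// Illustrative example (hypothetical arguments):
+//   WordWrap("  Foo(", "a, b, c, d)", 14)
+// returns
+//   "  Foo(a, b, c,\n      d)"
+// because only 14 - 6 = 8 columns remain after the 6-character prefix.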
+string WordWrap(StringPiece prefix, StringPiece str, int width);
+
+// Looks for an "=" at the beginning of *description. If found, strips it off
+// (and any following spaces) from *description and returns true. Otherwise
+// returns false.
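+// E.g. (illustrative): if *description is "= default is 3.", this consumes
+// the "= " and leaves "default is 3.", returning true.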
+bool ConsumeEquals(StringPiece* description);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
new file mode 100644
index 0000000000..eb83d393f0
--- /dev/null
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -0,0 +1,749 @@
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
+ const DataTypeSlice expected_outputs,
+ const DataTypeSlice inputs,
+ const DataTypeSlice outputs) {
+ bool signature_mismatch = false;
+
+ if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
+ for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
+ if (!TypesCompatible(expected_inputs[i], inputs[i])) {
+ signature_mismatch = true;
+ }
+ }
+
+ if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
+ for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
+ if (!TypesCompatible(expected_outputs[i], outputs[i])) {
+ signature_mismatch = true;
+ }
+ }
+
+ if (signature_mismatch) {
+ return errors::InvalidArgument("Signature mismatch, have: ",
+ DataTypeSliceString(inputs), "->",
+ DataTypeSliceString(outputs), " expected: ",
+ DataTypeSliceString(expected_inputs), "->",
+ DataTypeSliceString(expected_outputs));
+ }
+ return Status::OK();
+}
+
+// Check HostMemory backward compatibility.
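+// (On GPU, DT_INT32 inputs and outputs are expected to be in host memory;
+// the check below flags kernels that still place them in device memory.)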
+bool CheckHostMemoryCompatibility(const DeviceType device_type,
+ const OpKernel* kernel) {
+ if (device_type == DEVICE_GPU) {
+ for (int i = 0; i < kernel->num_inputs(); ++i) {
+ if (kernel->input_type(i) == DT_INT32 &&
+ kernel->input_memory_types()[i] != HOST_MEMORY) {
+ return false;
+ }
+ }
+ for (int i = 0; i < kernel->num_outputs(); ++i) {
+ if (kernel->output_type(i) == DT_INT32 &&
+ kernel->output_memory_types()[i] != HOST_MEMORY) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+// OpKernel ------------------------------------------------------------------
+
+OpKernel::OpKernel(OpKernelConstruction* context)
+ : def_(context->def()),
+ input_types_(context->input_types().begin(),
+ context->input_types().end()),
+ output_types_(context->output_types().begin(),
+ context->output_types().end()),
+ input_name_map_(context->num_inputs()),
+ output_name_map_(context->num_outputs()) {
+ OP_REQUIRES_OK(context,
+ NameRangesForNode(def_, context->op_def(), &input_name_map_,
+ &output_name_map_));
+
+  // By default, the input and output memory types are DEVICE_MEMORY, but
+  // individual OpKernel implementations may override them in their
+  // constructors.
+ input_memory_types_ = MemoryTypeVector(input_types_.size(), DEVICE_MEMORY);
+ output_memory_types_ = MemoryTypeVector(output_types_.size(), DEVICE_MEMORY);
+ // TODO(yuanbyu): For now we assume the memory types of function
+ // inputs/outputs to be DEVICE_MEMORY.
+ auto lib = context->function_library();
+ if (lib == nullptr || !lib->IsDefined(def_.op())) {
+ OP_REQUIRES_OK(context, MemoryTypesForNode(
+ context->device_type(), def_, context->op_def(),
+ input_name_map_, output_name_map_,
+ &input_memory_types_, &output_memory_types_));
+ // Log all the uses of int32 on GPU.
+    // TODO(yuanbyu): Remove once everyone transitions to HostMemory.
+ if (VLOG_IS_ON(2)) {
+ if (!CheckHostMemoryCompatibility(context->device_type(), this)) {
+ VLOG(2) << "Using int32 on GPU at node: " << SummarizeNodeDef(def());
+ }
+ }
+ }
+}
+
+Status OpKernel::InputRange(const string& input_name, int* start,
+ int* stop) const {
+ const auto result = input_name_map_.find(input_name);
+ if (result == input_name_map_.end()) {
+ return errors::InvalidArgument("Unknown input name: ", input_name);
+ } else {
+ *start = result->second.first;
+ *stop = result->second.second;
+ return Status::OK();
+ }
+}
+
+Status OpKernel::OutputRange(const string& output_name, int* start,
+ int* stop) const {
+ const auto result = output_name_map_.find(output_name);
+ if (result == output_name_map_.end()) {
+ return errors::InvalidArgument("Unknown output name: ", output_name);
+ } else {
+ *start = result->second.first;
+ *stop = result->second.second;
+ return Status::OK();
+ }
+}
+
+void AsyncOpKernel::Compute(OpKernelContext* context) {
+ Notification n;
+ ComputeAsync(context, [&n]() { n.Notify(); });
+ n.WaitForNotification();
+}
+
+// PersistentTensor ----------------------------------------------------------
+
+Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) {
+  // The caller must supply a valid context.
+ CHECK(context);
+ return &tensor_;
+}
+
+Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
+ context->NotifyUseOfPersistentTensor(tensor_);
+ return &tensor_;
+}
+
+// OpKernelConstruction ------------------------------------------------------
+
+Status OpKernelConstruction::MatchSignature(
+ const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
+ return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_,
+ output_types_);
+}
+
+Status OpKernelConstruction::allocate_temp(DataType type,
+ const TensorShape& shape,
+ Tensor* out_temp) {
+ Tensor new_temp(allocator_, type, shape);
+
+ if (!new_temp.IsInitialized() && shape.num_elements() > 0) {
+ return errors::ResourceExhausted(
+ "OOM when allocating temporary tensor with shape", shape.DebugString());
+ }
+ *out_temp = new_temp;
+ return Status::OK();
+}
+
+Status OpKernelConstruction::allocate_persistent(
+ DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
+ Tensor** out_tensor) {
+  // For now, just do the same thing as allocate_temp.
+  // TODO(misard): Add specific memory tracking for persistent tensors.
+ Tensor persistent;
+ Status s = allocate_temp(type, shape, &persistent);
+ if (!s.ok()) {
+ return s;
+ }
+ *out_persistent = PersistentTensor(persistent);
+ Tensor* allocated = out_persistent->AccessTensor(this);
+ if (out_tensor) {
+ *out_tensor = allocated;
+ }
+ return s;
+}
+
+// OpKernelContext -----------------------------------------------------------
+
+OpKernelContext::OpKernelContext(const Params& params)
+ : params_(params),
+ outputs_(params.op_kernel->output_types().size()),
+ output_allocation_types_(params.op_kernel->output_types().size()) {
+ Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
+ eigen_gpu_device_ = params_.device->MakeGpuDevice(params_.op_device_context,
+ eigen_gpu_allocator);
+}
+
+OpKernelContext::~OpKernelContext() {
+ for (TensorValue& value : outputs_) {
+ if (!value.is_ref()) {
+ delete value.tensor;
+ }
+ }
+ for (Tensor* t : temp_tensors_) delete t;
+ delete eigen_gpu_device_;
+}
+
+Status OpKernelContext::input(const string& name, const Tensor** tensor) const {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ name,
+ "' when single-valued input was "
+ "expected");
+ }
+ if ((*params_.inputs)[start].is_ref()) {
+ return errors::InvalidArgument("OpKernel used ref input name '", name,
+ "' when immutable input was expected");
+ }
+ *tensor = (*params_.inputs)[start].tensor;
+ return Status::OK();
+}
+
+Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ name,
+ "' when single-valued input was expected");
+ }
+ *out_mutex = input_ref_mutex(start);
+ return Status::OK();
+}
+
+Status OpKernelContext::mutable_input(const string& name, Tensor* tensor,
+ bool lock_held) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ name,
+ "' when single-valued input was expected");
+ }
+ if (!(*params_.inputs)[start].is_ref()) {
+ return errors::InvalidArgument("OpKernel used immutable input name '", name,
+ "' when ref input was expected");
+ }
+  // Return a copy of the ref tensor, acquiring its mutex first unless the
+  // caller already holds the lock.
+ if (lock_held) {
+ *tensor = *(*params_.inputs)[start].tensor;
+ } else {
+ mutex_lock l(*input_ref_mutex(start));
+ *tensor = *(*params_.inputs)[start].tensor;
+ }
+ return Status::OK();
+}
+
+Status OpKernelContext::replace_ref_input(const string& name,
+ const Tensor& tensor,
+ bool lock_held) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ name,
+ "' when single-valued input was expected");
+ }
+ if (!(*params_.inputs)[start].is_ref()) {
+ return errors::InvalidArgument("OpKernel used immutable input name '", name,
+ "' when ref input was expected");
+ }
+ replace_ref_input(start, tensor, lock_held);
+ return Status::OK();
+}
+
+Status OpKernelContext::input_list(const string& name,
+ OpInputList* list) const {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ *list = OpInputList(this, start, stop);
+ return Status::OK();
+}
+
+Status OpKernelContext::mutable_input_list(const string& name,
+ OpMutableInputList* list) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ *list = OpMutableInputList(this, start, stop);
+ return Status::OK();
+}
+
+Status OpKernelContext::output_list(const string& name, OpOutputList* list) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ *list = OpOutputList(this, start, stop);
+ return Status::OK();
+}
+
+Status OpKernelContext::allocate_output(const string& name,
+ const TensorShape& shape,
+ Tensor** tensor) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ return allocate_output(start, shape, tensor);
+}
+
+Status OpKernelContext::allocate_output(const string& name,
+ const TensorShape& shape,
+ Tensor** tensor,
+ AllocatorAttributes attr) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ return allocate_output(start, shape, tensor, attr);
+}
+
+Status OpKernelContext::set_output(const string& name, const Tensor& tensor) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ set_output(start, tensor);
+ return Status::OK();
+}
+
+Status OpKernelContext::set_output_ref(const string& name, mutex* mu,
+ Tensor* tensor_for_ref) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ set_output_ref(start, mu, tensor_for_ref);
+ return Status::OK();
+}
+
+Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ *tensor = mutable_output(start);
+ return Status::OK();
+}
+
+Status OpKernelContext::release_output(const string& name, TensorValue* value) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ *value = release_output(start);
+ return Status::OK();
+}
+
+bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
+ const auto& inputs = *params_.inputs;
+ for (size_t i = 1; i < inputs.size(); ++i) {
+ if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
+ SetStatus(errors::InvalidArgument(
+ "Inputs to operation ", op->name(), " of type ", op->type_string(),
+ " must have the same size and shape. Input 0: ",
+ inputs[0]->shape().DebugString(), " != input ", i, ": ",
+ inputs[i]->shape().DebugString()));
+ return false;
+ }
+ }
+ return true;
+}
+
+Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
+ const DataTypeSlice expected_outputs) {
+ DataTypeVector inputs;
+ for (const TensorValue& t : *params_.inputs) {
+ inputs.push_back(t.is_ref() ? MakeRefType(t->dtype()) : t->dtype());
+ }
+ DataTypeVector outputs = params_.op_kernel->output_types();
+ return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
+ outputs);
+}
+
+// OpKernel registration ------------------------------------------------------
+
+struct KernelRegistration {
+ KernelRegistration(const KernelDef& d,
+ kernel_factory::OpKernelRegistrar::Factory f)
+ : def(d), factory(f) {}
+ const KernelDef def;
+ const kernel_factory::OpKernelRegistrar::Factory factory;
+};
+
+// This maps from 'op_type' + DeviceType to the set of KernelDefs and
+// factory functions for instantiating the OpKernel that matches the
+// KernelDef.
+typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;
+
+static KernelRegistry* GlobalKernelRegistry() {
+ static KernelRegistry* global_kernel_registry = new KernelRegistry;
+ return global_kernel_registry;
+}
+
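+// Map key for GlobalKernelRegistry, e.g. (illustrative, assuming
+// DeviceTypeString(DEVICE_CPU) yields "CPU"):
+//   Key("MyOp", DeviceType(DEVICE_CPU), "") == "MyOp:CPU:"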
+static string Key(const string& op_type, DeviceType device_type,
+ const string& label) {
+ return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
+ label);
+}
+
+namespace kernel_factory {
+
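+// OpKernelRegistrar objects are typically created during static
+// initialization via the REGISTER_KERNEL_BUILDER macro declared in
+// op_kernel.h, e.g. (illustrative, with a hypothetical kernel class):
+//   REGISTER_KERNEL_BUILDER(Name("MyOp").Device(DEVICE_CPU), MyOpKernel);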
+OpKernelRegistrar::OpKernelRegistrar(const KernelDef* kernel_def,
+ Factory factory) {
+ const string key =
+ Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
+ kernel_def->label());
+ GlobalKernelRegistry()->insert(
+ std::make_pair(key, KernelRegistration(*kernel_def, factory)));
+ delete kernel_def;
+}
+
+} // namespace kernel_factory
+
+namespace {
+
+// Helper for AttrsMatch().
+bool InTypeList(DataType dt, const AttrValue& type_list) {
+ for (int in_list : type_list.list().type()) {
+ if (dt == in_list) return true;
+ }
+ return false;
+}
+
+// Returns whether the attrs in the NodeDef satisfy the constraints in
+// the kernel_def. Returns an error if an attr constrained by kernel_def
+// is missing from the NodeDef or has a mismatched type.
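+// E.g. (illustrative): a constraint restricting attr 'T' to { DT_FLOAT }
+// matches a NodeDef with T = DT_FLOAT, leaves *match == false for
+// T = DT_INT32, and produces an error if the NodeDef lacks attr 'T'.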
+Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def,
+ bool* match) {
+ *match = false;
+ AttrSlice attrs(node_def);
+ for (const auto& constraint : kernel_def.constraint()) {
+ if (constraint.allowed_values().list().type_size() == 0) {
+ return errors::Unimplemented(
+ "KernelDef '", kernel_def.ShortDebugString(),
+ " has constraint on attr '", constraint.name(),
+ "' with unsupported type: ",
+ SummarizeAttrValue(constraint.allowed_values()));
+ }
+
+ const AttrValue* found = attrs.Find(constraint.name());
+ if (found) {
+ if (found->type() != DT_INVALID) {
+ if (!InTypeList(found->type(), constraint.allowed_values())) {
+ return Status::OK();
+ }
+ } else {
+ if (!AttrValueHasType(*found, "list(type)").ok()) {
+ return errors::InvalidArgument(
+ "KernelDef '", kernel_def.ShortDebugString(),
+ "' has constraint on attr '", constraint.name(),
+ "' that has value '", SummarizeAttrValue(*found),
+ "' that does not have type 'type' or 'list(type)' in NodeDef '",
+ SummarizeNodeDef(node_def), "'");
+ }
+
+ for (int t : found->list().type()) {
+ if (!InTypeList(static_cast<DataType>(t),
+ constraint.allowed_values())) {
+ return Status::OK();
+ }
+ }
+ }
+ } else {
+ return errors::InvalidArgument(
+ "OpKernel '", kernel_def.op(), "' has constraint on attr '",
+ constraint.name(), "' not in NodeDef '", SummarizeNodeDef(node_def),
+ "', KernelDef: '", kernel_def.ShortDebugString(), "'");
+ }
+ }
+ *match = true;
+ return Status::OK();
+}
+
+Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def,
+ const KernelRegistration** reg) {
+ *reg = nullptr;
+ string label; // Label defaults to empty if not found in NodeDef.
+ GetNodeAttr(node_def, "_kernel", &label);
+ const string key = Key(node_def.op(), device_type, label);
+ auto regs = GlobalKernelRegistry()->equal_range(key);
+ for (auto iter = regs.first; iter != regs.second; ++iter) {
+ // If there is a kernel registered for the op and device_type,
+ // check that the attrs match.
+ bool match;
+ TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match));
+ if (match) {
+ if (*reg != nullptr) {
+ return errors::InvalidArgument(
+ "Multiple OpKernel registrations match NodeDef '",
+ SummarizeNodeDef(node_def), "': '", (*reg)->def.ShortDebugString(),
+ "' and '", iter->second.def.ShortDebugString(), "'");
+ }
+ *reg = &iter->second;
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status SupportedDeviceTypesForNode(
+ const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
+ DeviceTypeVector* device_types) {
+  // TODO(zhifengc): Change the callers (SimplePlacer and
+  // DynamicPlacer) to consider the possibility that 'def' is a call to
+  // a user-defined function, and to call SupportedDeviceTypesForNode
+  // only for primitive ops.
+ Status s;
+ const OpDef* op_def = OpRegistry::Global()->LookUp(def.op(), &s);
+ if (op_def) {
+ for (const DeviceType& device_type : prioritized_types) {
+ const KernelRegistration* reg = nullptr;
+ TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, def, &reg));
+ if (reg != nullptr) device_types->push_back(device_type);
+ }
+ } else {
+ // Assumes that all device types support this node.
+ for (const DeviceType& device_type : prioritized_types) {
+ device_types->push_back(device_type);
+ }
+ }
+ return Status::OK();
+}
+
+std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
+ DeviceBase* device,
+ Allocator* allocator,
+ const NodeDef& node_def,
+ Status* status) {
+ OpKernel* kernel = nullptr;
+ *status = CreateOpKernel(device_type, device, allocator, nullptr, node_def,
+ &kernel);
+ return std::unique_ptr<OpKernel>(kernel);
+}
+
+Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
+ Allocator* allocator, FunctionLibraryRuntime* flib,
+ const NodeDef& node_def, OpKernel** kernel) {
+ VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);
+
+ // Look up the Op registered for this op name.
+ Status s;
+ const OpDef* op_def = OpRegistry::Global()->LookUp(node_def.op(), &s);
+ if (op_def == nullptr) return s;
+
+ // Validate node_def against OpDef.
+ s = ValidateNodeDef(node_def, *op_def);
+ if (!s.ok()) return s;
+
+ // Look up kernel registration.
+ const KernelRegistration* registration;
+ s = FindKernelRegistration(device_type, node_def, &registration);
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " when instantiating ", node_def.op());
+ return s;
+ }
+ if (registration == nullptr) {
+ s.Update(errors::NotFound("No registered '", node_def.op(),
+ "' OpKernel for ", DeviceTypeString(device_type),
+ " devices compatible with node ",
+ SummarizeNodeDef(node_def)));
+ return s;
+ }
+
+ // Get signature from the OpDef & NodeDef
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " for node: ", SummarizeNodeDef(node_def));
+ return s;
+ }
+
+ // Everything needed for OpKernel construction.
+ OpKernelConstruction context(device_type, device, allocator, &node_def,
+ op_def, flib, inputs, outputs, &s);
+ *kernel = (*registration->factory)(&context);
+ if (!s.ok()) {
+ delete *kernel;
+ *kernel = nullptr;
+ }
+ return s;
+}
+
+namespace { // Helper for MemoryTypesForNode.
+// Fills memory_types for either input or output, setting everything
+// to DEVICE_MEMORY except those args in host_memory_args. Removes
+// elements of host_memory_args that were used.
+void MemoryTypesHelper(const NameRangeMap& name_map,
+ std::vector<string>* host_memory_args,
+ MemoryTypeVector* memory_types) {
+ // Set total to the largest endpoint of anything in the name_map.
+ int total = 0;
+ for (const auto& item : name_map) {
+ total = std::max(total, item.second.second);
+ }
+
+ // Now that we know the size, fill with the default 'DEVICE_MEMORY'.
+ memory_types->clear();
+ memory_types->resize(total, DEVICE_MEMORY);
+
+ // Update args that have been marked as in "HOST_MEMORY".
+ size_t keep = 0;
+ for (size_t i = 0; i < host_memory_args->size(); ++i) {
+ auto iter = name_map.find((*host_memory_args)[i]);
+ if (iter != name_map.end()) {
+ for (int j = iter->second.first; j < iter->second.second; ++j) {
+ (*memory_types)[j] = HOST_MEMORY;
+ }
+ } else {
+ // (*host_memory_args)[i] not found, save it for the next pass.
+ if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i];
+ ++keep;
+ }
+ }
+ host_memory_args->resize(keep);
+}
+} // namespace
+
+Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef,
+ const OpDef& op_def,
+ const NameRangeMap& input_name_map,
+ const NameRangeMap& output_name_map,
+ MemoryTypeVector* input_memory_types,
+ MemoryTypeVector* output_memory_types) {
+ Status status;
+ const KernelRegistration* registration;
+ TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, ndef, &registration));
+
+ if (registration != nullptr) {
+ const auto& from_proto = registration->def.host_memory_arg();
+ std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
+ MemoryTypesHelper(input_name_map, &host_memory_args, input_memory_types);
+ MemoryTypesHelper(output_name_map, &host_memory_args, output_memory_types);
+ if (!host_memory_args.empty()) {
+ return errors::InvalidArgument(
+ "HostMemory args '", str_util::Join(host_memory_args, "', '"),
+ "' not found in OpDef: ", SummarizeOpDef(op_def));
+ }
+ }
+ return status;
+}
+
+Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
+ DeviceType device_type, const NodeDef& ndef,
+ MemoryTypeVector* input_memory_types,
+ MemoryTypeVector* output_memory_types) {
+ // Look up the Op registered for this op name.
+ Status status;
+ const OpDef* op_def = op_registry->LookUp(ndef.op(), &status);
+ if (op_def == nullptr) return status;
+
+ NameRangeMap inputs, outputs;
+ status = NameRangesForNode(ndef, *op_def, &inputs, &outputs);
+ if (!status.ok()) return status;
+
+ return MemoryTypesForNode(device_type, ndef, *op_def, inputs, outputs,
+ input_memory_types, output_memory_types);
+}
+
+namespace {
+
+bool FindArgInOp(const string& arg_name,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
+ for (const auto& arg : args) {
+ if (arg_name == arg.name()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+Status ValidateKernelRegistrations(const OpRegistryInterface* op_registry) {
+ Status unused_status;
+ for (const auto& key_registration : *GlobalKernelRegistry()) {
+ const KernelDef& kernel_def(key_registration.second.def);
+ const OpDef* op_def = op_registry->LookUp(kernel_def.op(), &unused_status);
+ if (op_def == nullptr) {
+ // TODO(josh11b): Make this a hard error.
+ LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString()
+ << "') for unknown op: " << kernel_def.op();
+ continue;
+ }
+ for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
+ if (!FindArgInOp(host_memory_arg, op_def->input_arg()) &&
+ !FindArgInOp(host_memory_arg, op_def->output_arg())) {
+ return errors::InvalidArgument("HostMemory arg '", host_memory_arg,
+ "' not found in OpDef: ",
+ SummarizeOpDef(*op_def));
+ }
+ }
+ }
+ return Status::OK();
+}
+
+template <>
+const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
+ return eigen_cpu_device();
+}
+
+template <>
+const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
+ return eigen_gpu_device();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
new file mode 100644
index 0000000000..34d588c6c9
--- /dev/null
+++ b/tensorflow/core/framework/op_kernel.h
@@ -0,0 +1,1250 @@
+#ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
+#define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
+
+#include <functional>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/tracking_allocator.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace Eigen {
+class ThreadPoolDevice;
+class GpuDevice;
+} // end namespace Eigen
+
+namespace tensorflow {
+
+namespace checkpoint {
+class TensorSliceReaderCacheWrapper;
+} // namespace checkpoint
+
+class AsyncOpKernel;
+class OpKernelConstruction; // declared below
+class OpKernelContext; // declared below
+class ResourceMgr;
+
+// TODO(josh11b): Make reference-counted if needed.
+class OpKernel {
+ public:
+ // OpKernel won't be instantiated by the scheduler, so you may perform
+ // expensive initialization in the descendant's constructor.
+ explicit OpKernel(OpKernelConstruction* context);
+ virtual ~OpKernel() {}
+
+ // An OpKernel's computation can be either synchronous or
+ // asynchronous.
+ //
+ // Most OpKernels should compute synchronously. They should
+ // subclass OpKernel and override the Compute() method and have it
+ // return after completing the supplied work.
+ //
+ // A few special kernels might need to be asynchronous to bound the
+ // number of threads (e.g., network receive operations). These
+ // kernels must subclass AsyncOpKernel and override
+ // AsyncOpKernel::ComputeAsync().
+ //
+ // In both cases, implementations of Compute() and ComputeAsync()
+ // get inputs and write outputs through the given OpKernelContext
+  // and return a status via context->SetStatus(). They must be
+ // thread-safe.
+
+ // Synchronous compute.
+ //
+ // "context" is guaranteed to be alive until Compute() returns.
+ virtual void Compute(OpKernelContext* context) = 0;
+
+ // Returns nullptr iff this op kernel is synchronous.
+ virtual AsyncOpKernel* AsAsync() { return nullptr; }
+
+ // Returns true iff this op kernel is considered "expensive". The
+  // runtime may use this flag to optimize graph execution, for example
+  // by "inlining" inexpensive kernels.
+ virtual bool IsExpensive() { return true; }
+
+ // Accessors.
+ const NodeDef& def() const { return def_; }
+ const string& name() const { return def_.name(); }
+ const string& type_string() const { return def_.op(); }
+
+ int num_inputs() const { return input_types_.size(); }
+ DataType input_type(int i) const { return input_types_[i]; }
+ const DataTypeVector& input_types() const { return input_types_; }
+ const MemoryTypeVector& input_memory_types() const {
+ return input_memory_types_;
+ }
+
+ int num_outputs() const { return output_types_.size(); }
+ DataType output_type(int o) const { return output_types_[o]; }
+ const DataTypeVector& output_types() const { return output_types_; }
+ const MemoryTypeVector& output_memory_types() const {
+ return output_memory_types_;
+ }
+
+ Status InputRange(const string& input_name, int* start, int* stop) const;
+ Status OutputRange(const string& output_name, int* start, int* stop) const;
+
+ private:
+ const NodeDef def_;
+ const DataTypeVector input_types_;
+ const DataTypeVector output_types_;
+ NameRangeMap input_name_map_;
+ NameRangeMap output_name_map_;
+ MemoryTypeVector input_memory_types_;
+ MemoryTypeVector output_memory_types_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
+};
+
+class AsyncOpKernel : public OpKernel {
+ public:
+ using OpKernel::OpKernel; // Lift OpKernel constructors.
+
+ // Asynchronous compute.
+ //
+ // Implementations of ComputeAsync() must run "done" to signal the
+ // completion of the computation. "context" is guaranteed to be
+ // alive until the "done" callback starts.
+ typedef std::function<void()> DoneCallback;
+ virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
+
+ AsyncOpKernel* AsAsync() final { return this; }
+
+ void Compute(OpKernelContext* context) final;
+};
+
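+// A minimal AsyncOpKernel sketch (the class and op names below are
+// illustrative, not part of this library):
+//
+//   class SketchRecvOp : public AsyncOpKernel {
+//    public:
+//     using AsyncOpKernel::AsyncOpKernel;
+//     void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+//       // Start the network receive here, then call done() from its
+//       // completion callback; ctx stays alive until done() starts.
+//       done();
+//     }
+//   };
+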
+// Wraps a tensor that is held by an Op across calls to Compute(). For
+// memory safety when using asynchronous devices like GPUs, the system
+// must be notified when a Tensor is used inside an Op execution. The
+// wrapper ensures that all uses of the Tensor are tracked, because in
+// order to retrieve the Tensor the caller must use AccessTensor, which
+// notifies the context.
+class PersistentTensor {
+ public:
+ PersistentTensor() {}
+ explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}
+
+ // Caller does not own the returned Tensor*.
+ Tensor* AccessTensor(OpKernelConstruction* context);
+ // Caller does not own the returned Tensor*.
+ Tensor* AccessTensor(OpKernelContext* context);
+
+ // The check for initialization does not need to access the
+ // underlying tensor buffer.
+ bool IsInitialized() { return tensor_.IsInitialized(); }
+
+ private:
+ Tensor tensor_;
+};
+
+class OpKernelConstruction {
+ public:
+ // TODO(yuanbyu): Probably reduce the number of arguments.
+ OpKernelConstruction(DeviceType device_type, DeviceBase* device,
+ Allocator* allocator, const NodeDef* node_def,
+ const OpDef* op_def, FunctionLibraryRuntime* flib,
+ const DataTypeSlice& input_types,
+ const DataTypeSlice& output_types, Status* status)
+ : device_type_(device_type),
+ device_(device),
+ allocator_(allocator),
+ def_(node_def),
+ op_def_(op_def),
+ flib_(flib),
+ input_types_(input_types),
+ output_types_(output_types),
+ status_(status) {}
+
+ Env* env() const { return device_->env(); }
+
+ // Allocation of tensors during kernel construction:
+ //
+ // It is legal to temporarily allocate scratch tensor storage during
+ // Op kernel construction. Scratch tensors should be allocated using
+ // allocate_temp below. Some kernels need to keep tensors in between
+ // invocations. If such a Tensor is allocated during kernel
+ // construction this must be done using allocate_persistent, and the
+ // Op may only store the returned PersistentTensor object. When the
+ // Tensor is needed in a subsequent invocation, it can be retrieved
+ // from the PersistentTensor using the AccessTensor method. This
+ // ensures that the system is made aware of any use of the tensor's
+ // allocated memory, which is needed for correctness on asynchronous
+ // devices such as GPUs.
+
+ // Allocates a temporary Tensor of the specified type and shape. The
+ // Tensor must not be used after kernel construction is
+ // complete. See comment above.
+ Status allocate_temp(DataType type, const TensorShape& shape,
+ Tensor* out_temp);
+
+ // Allocates a Tensor of the specified type and shape which the Op
+ // plans to maintain as persistent state. out_persistent holds the
+ // PersistentTensor which is the object the caller should store. For
+ // convenience, if out_tensor is non-null then it will be filled in
+ // with a Tensor* pointing to the newly-allocated tensor which the
+ // caller can use instead of calling
+ // out_persistent->AccessTensor. The caller does not own out_tensor
+ // and should not keep a copy of it. See comment above.
+ Status allocate_persistent(DataType type, const TensorShape& shape,
+ PersistentTensor* out_persistent,
+ Tensor** out_tensor);
+
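+  // A construction-time sketch (assuming, purely for illustration, a
+  // kernel that keeps a float scalar as persistent state):
+  //
+  //   PersistentTensor state;
+  //   Tensor* init = nullptr;
+  //   Status s = context->allocate_persistent(DT_FLOAT, TensorShape(),
+  //                                           &state, &init);
+  //   if (s.ok()) init->scalar<float>()() = 0.0f;
+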
+ // User-supplied configuration of this operation.
+ const NodeDef& def() const { return *def_; }
+
+ // Op registered for this op type.
+ const OpDef& op_def() const { return *op_def_; }
+
+ // For inspecting the inputs to this operation.
+ int num_inputs() const { return input_types_.size(); }
+ DataType input_type(int i) const { return input_types_[i]; }
+ const DataTypeSlice& input_types() const { return input_types_; }
+
+ // For inspecting the outputs expected from this operation.
+ int num_outputs() const { return output_types_.size(); }
+ DataType output_type(int i) const { return output_types_[i]; }
+ const DataTypeSlice& output_types() const { return output_types_; }
+
+  // If expected_inputs == input_types() and expected_outputs ==
+  // output_types(), returns OK, else returns INVALID_ARGUMENT with an
+  // error message.
+ // Recommended for Ops with dynamic signatures.
+ Status MatchSignature(const DataTypeSlice expected_inputs,
+ const DataTypeSlice expected_outputs);
+
+ // For recording configuration errors during construction.
+ void SetStatus(const Status& status) { status_->Update(status); }
+ const Status& status() const { return *status_; }
+
+ // Look up the attr with name attr_name and set *value to its value. If no
+ // attr with attr_name is found in def(), or the attr does not have
+ // a matching type, a non-ok status will be returned.
+ template <class T>
+ Status GetAttr(const string& attr_name, T* value) const {
+ return GetNodeAttr(def(), attr_name, value);
+ }
+
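+  // For example, in a kernel's constructor, assuming the op declared
+  // an attr "T: type":
+  //
+  //   DataType dtype;
+  //   Status s = context->GetAttr("T", &dtype);
+  //   if (!s.ok()) context->SetStatus(s);
+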
+ // May be used, e.g., to get GPU handles, etc.
+ // TODO(tucker): Add example usage.
+ DeviceBase* device() const { return device_; }
+
+ // Return the device type.
+ const DeviceType& device_type() const { return device_type_; }
+
+ // If not nullptr, the kernel can instantiate functions defined in
+ // the library. E.g.,
+ // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
+ FunctionLibraryRuntime* function_library() const { return flib_; }
+
+ private:
+ const DeviceType device_type_;
+ DeviceBase* const device_;
+ Allocator* allocator_;
+ const NodeDef* def_;
+ const OpDef* op_def_;
+ FunctionLibraryRuntime* flib_;
+ DataTypeSlice input_types_;
+ DataTypeSlice output_types_;
+ Status* status_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
+};
+
+// TODO(mrry): Consider converting to a random_access_iterator, and upgrading
+// tensorflow::gtl::iterator_range to make the below container classes
+// unnecessary.
+template <typename ListType, typename ElementType>
+class OpArgIterator {
+ public:
+ typedef OpArgIterator<ListType, ElementType> ME;
+ OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
+ bool operator==(const ME& rhs) {
+ DCHECK(list_ == rhs.list_);
+ return i_ == rhs.i_;
+ }
+ bool operator!=(const ME& rhs) {
+ DCHECK(list_ == rhs.list_);
+ return i_ != rhs.i_;
+ }
+ void operator++() { ++i_; }
+ ElementType& operator*() { return (*list_)[i_]; }
+
+ private:
+ const ListType* const list_;
+ int i_;
+};
+
+// Utility class for representing a list of immutable input tensors
+// that are passed to the op as a single named argument.
+class OpInputList {
+ public:
+ typedef OpArgIterator<OpInputList, const Tensor&> Iterator;
+ OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
+ OpInputList(const OpKernelContext* ctx, int start, int stop)
+ : ctx_(ctx), start_(start), stop_(stop) {}
+ OpInputList& operator=(const OpInputList& other) = default;
+ const Tensor& operator[](int i) const;
+ int size() const { return stop_ - start_; }
+ Iterator begin() const { return Iterator(this, 0); }
+ Iterator end() const { return Iterator(this, size()); }
+
+ private:
+ const OpKernelContext* ctx_; // not owned
+ int start_;
+ int stop_;
+};
+
+// Utility class for representing a list of mutable ("ref") input tensors
+// that are passed to the op as a single named argument.
+class OpMutableInputList {
+ public:
+ typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
+ OpMutableInputList(OpKernelContext* ctx, int start, int stop)
+ : ctx_(ctx), start_(start), stop_(stop) {}
+ OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
+ OpMutableInputList& operator=(const OpMutableInputList& other) = default;
+ Tensor at(int i, bool lock_held);
+ mutex* ref_mutex(int i);
+ int size() const { return stop_ - start_; }
+ Iterator begin() const { return Iterator(this, 0); }
+ Iterator end() const { return Iterator(this, size()); }
+
+ private:
+ OpKernelContext* ctx_; // not owned
+ int start_;
+ int stop_;
+};
+
+// Utility class for representing a list of output tensors that are
+// grouped as a single named output.
+class OpOutputList {
+ public:
+ typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
+ OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
+ OpOutputList(OpKernelContext* ctx, int start, int stop)
+ : ctx_(ctx), start_(start), stop_(stop) {}
+ OpOutputList& operator=(const OpOutputList& other) = default;
+ Tensor* operator[](int i);
+ bool required(int i) const;
+ Status allocate(int i, const TensorShape& shape, Tensor** output);
+ void set(int i, const Tensor& tensor);
+ void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
+ int size() const { return stop_ - start_; }
+ Iterator begin() const { return Iterator(this, 0); }
+ Iterator end() const { return Iterator(this, size()); }
+
+ private:
+ OpKernelContext* ctx_; // not owned
+ int start_;
+ int stop_;
+};
+
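+// For example, a Compute() body might consume a list-valued input like
+// so ("values" is a hypothetical input name declared in the OpDef):
+//
+//   OpInputList values;
+//   Status s = context->input_list("values", &values);
+//   if (!s.ok()) { context->SetStatus(s); return; }
+//   for (const Tensor& t : values) { /* ... use t ... */ }
+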
+// Holds a tensor or tensor reference. For tensor references, we need
+// a mutex to prevent concurrent access to the tensor.
+struct TensorValue {
+ TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
+ TensorValue(Tensor* t) // NOLINT(runtime/explicit)
+ : mutex_if_ref(nullptr),
+ tensor(t) {}
+ TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
+ Tensor* operator->() const { return tensor; }
+ bool is_ref() const { return mutex_if_ref != nullptr; }
+
+ mutex* mutex_if_ref; // nullptr if not a ref, != nullptr if a ref
+ Tensor* tensor;
+};
+
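+// For example (a sketch; 't' is a Tensor owned elsewhere):
+//
+//   TensorValue val(&t);           // non-ref value, so no mutex
+//   DataType dt = val->dtype();    // operator-> forwards to the Tensor
+//   CHECK(!val.is_ref());
+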
+class OpKernelContext {
+ public:
+ // The first element of a WrappedAllocator is a "base" Allocator and
+ // the second element is that Allocator wrapped by a
+ // TrackingAllocator
+ typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
+
+ // TODO(zhifengc): Do some cleanup of Params.
+ struct Params {
+ // The op kernel being computed.
+ OpKernel* op_kernel = nullptr;
+
+ // The device on which the kernel is running.
+ DeviceBase* device = nullptr;
+
+ bool track_allocations = false;
+ std::function<AllocatorAttributes(int index)> output_alloc_attr = nullptr;
+
+ // Shared resources accessible by this op kernel invocation.
+ ResourceMgr* resource_manager = nullptr;
+
+ // Per-step resources accessible by this op kernel invocation.
+ ResourceMgr* step_resource_manager = nullptr;
+
+ // Mechanism used by this op kernel invocation to communicate with
+ // computations running on other devices.
+ Rendezvous* rendezvous = nullptr;
+
+ // Mechanism used by this op kernel invocation to register a callback
+ // for its cancellation.
+ CancellationManager* cancellation_manager = nullptr;
+
+ // Inputs to this op kernel.
+ const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
+ bool is_input_dead = false;
+
+ const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
+ nullptr;
+
+ // Device contexts.
+ const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts =
+ nullptr;
+ DeviceContext* op_device_context = nullptr;
+
+ // Control-flow op supports.
+ FrameAndIter frame_iter;
+
+ // Function call supports.
+ FunctionCallFrame* call_frame = nullptr;
+ FunctionLibraryRuntime* function_library = nullptr;
+
+ // TensorSliceReaderCache support.
+ checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
+ };
+ explicit OpKernelContext(const Params& params);
+ ~OpKernelContext();
+
+ Env* env() const { return params_.device->env(); }
+
+ // Input/output signature.
+
+ int num_inputs() const { return params_.inputs->size(); }
+ DataType input_dtype(int index) const;
+ int num_outputs() const { return outputs_.size(); }
+ DataType expected_output_dtype(int index) const;
+
+ // Input
+
+ // Returns an immutable input tensor. May only be used for non-Ref
+ // inputs. For Ref inputs use mutable_input below.
+ // REQUIRES: !IsRefType(input_dtype(index))
+ // TODO(mrry): Convert this to return Status.
+ const Tensor& input(int index) const;
+
+ // Returns the named immutable input tensor in "tensor", as defined
+ // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
+ // use mutable_input below.
+ // REQUIRES: !IsRefType(input_dtype(index))
+ // REQUIRES: the named input must not be a list.
+ Status input(const string& name, const Tensor** tensor) const;
+
+ // Returns the named list-valued immutable input in "list", as
+  // defined in the OpDef. If the named input is not list-valued,
+ // returns a one-element list. May only be used for non-Ref
+ // inputs. For Ref inputs use mutable_input below.
+ // REQUIRES: !IsRefType(input_dtype(index))
+ Status input_list(const string& name, OpInputList* list) const;
+
+ // For mutable inputs, use the following together to make sure there
+ // is no concurrent access to mutable_input(), e.g.:
+ // {
+ // Tensor& t = context->mutable_input(index);
+ // mutex_lock lock(*context->input_ref_mutex(index));
+ // // modify the values in t
+ // }
+ // REQUIRES: IsRefType(input_dtype(index))
+ // TODO(mrry): Convert this to return Status.
+ mutex* input_ref_mutex(int index);
+ Status input_ref_mutex(const string& name, mutex** out_mutex);
+
+ // Returns a mutable input tensor. Must be used to access Ref
+ // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
+ // modify the values stored in the Tensor buffer, and modifications
+ // will be visible to other Ops reading the same ref tensor. If
+ // !lock_held the input mutex will be acquired before returning the
+ // Tensor.
+ // TODO(mrry):
+ // Convert this to return Status.
+ Tensor mutable_input(int index, bool lock_held);
+
+ // Returns the named mutable input tensor in "tensor", as defined in
+ // the OpDef. Must be used to access Ref inputs. The values stored
+ // in the Tensor buffer may be modified, and modifications will be
+ // visible to other Ops reading the same ref tensor. If !lock_held
+ // the input mutex will be acquired before returning the Tensor.
+ // REQUIRES: the named input must not be a list.
+ // REQUIRES: the named input must be a ref tensor.
+ Status mutable_input(const string& name, Tensor* tensor, bool lock_held);
+
+ // Returns the named list-valued mutable input in "list", as defined
+  // in the OpDef. If the named input is not list-valued, returns a
+ // one-element list. Must be used to access Ref inputs. The values
+ // stored in the Tensor buffer may be modified, and modifications
+ // will be visible to other Ops reading the same ref tensor.
+ // REQUIRES: the named input must be a ref tensor.
+ Status mutable_input_list(const string& name, OpMutableInputList* list);
+
+  // Replaces the corresponding Ref input so that it uses the storage
+  // buffer of 'tensor'. If !lock_held the input mutex will be acquired
+  // while the buffer is replaced.
+ // REQUIRES: IsRefType(input_dtype(index)).
+ void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
+
+  // Replaces the corresponding named Ref input so that it uses the
+  // storage buffer of 'tensor'. If !lock_held the input mutex will be
+  // acquired while the buffer is replaced.
+ // REQUIRES: IsRefType(input_dtype(index)).
+ Status replace_ref_input(const string& name, const Tensor& tensor,
+ bool lock_held);
+
+ // Set the output Ref Tensor at output_index to be an alias of the
+ // input Ref Tensor at input_index.
+ // REQUIRES: IsRefType(input_dtype(input_index)).
+ // REQUIRES: IsRefType(output_dtype(output_index)).
+ void forward_ref_input_to_ref_output(int input_index, int output_index);
+
+ // Deletes the Tensor object used as the Ref Input at
+ // input_index. This is not usually necessary and should be used
+ // with caution. If !lock_held the input mutex will be acquired
+  // while the tensor is deleted.
+ // REQUIRES: IsRefType(input_dtype(input_index)).
+ void delete_ref_input(int input_index, bool lock_held);
+
+  // Returns true if there is an input at the given index. An operator
+  // has no input at an index if the corresponding tensor is null. This
+  // is primarily used by the merge operator.
+ // TODO(mrry): Convert this to return Status.
+ bool has_input(int index) const;
+
+ // Returns true if all inputs are the same shape, otherwise sets the
+ // status to a non-OK value and returns false.
+ // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
+ bool ValidateInputsAreSameShape(OpKernel* op);
+
+ // Output
+
+ // Returns the named list-valued output in "list", as defined in the OpDef.
+ // If the named output is not list-valued, returns a one-element list.
+ Status output_list(const string& name, OpOutputList* list);
+
+ // If output_required(index) returns true, the OpKernel's Compute() method
+ // should call allocate_output(index, ...), set_output(index, ...),
+  // set_output_ref(index, ...), or set the status to a non-OK value.
+  // If it returns false, the kernel may still produce the output, but
+  // is not required to do so.
+ // TODO(mrry): Convert this to return Status, and implement a string
+ // name version.
+ bool output_required(int index) const {
+ return true; // TODO(josh11b): implement
+ }
+
+ // Allocation of tensors during kernel execution inside the Compute
+ // method:
+ //
+ // There are three methods to allocate Tensors when an Op kernel
+ // executes.
+ //
+ // 1) allocate_persistent. This is only needed for Tensors that will
+ // be stored by the Op between invocations, and it *must* be used
+ // for those Tensors. The call returns a PersistentTensor, and that
+ // is the only object the Op is allowed to hold on to between
+ // invocations. When the Tensor is needed in a subsequent
+ // invocation, it can be retrieved from the PersistentTensor using
+ // the AccessTensor method. This ensures that the system is made
+ // aware of any use of the tensor's allocated memory, which is
+ // needed for correctness on asynchronous devices such as GPUs.
+ //
+ // 2) allocate_output. This should be used to allocate any tensor
+ // that is going to be used as an output from the Op at the end of
+ // the current execution. The caller indicates which output the
+ // Tensor will be assigned to, and the call returns the
+ // newly-allocated Tensor. The Tensor can subsequently be assigned
+ // to during kernel execution, and will be used as the designated
+ // output when the kernel execution completes.
+ //
+ // 3) allocate_temp. This should be used to allocate any scratch
+ // storage that is needed while the kernel is executing, and will
+ // not be retained by the Op.
+ //
+ // In some cases a Tensor needs to be used as an output even though
+ // it was previously allocated elsewhere. The Tensor may have been
+ // passed as an input, or stored in a PersistentTensor during a
+ // previous kernel execution, or allocated earlier in the kernel
+ // execution at a time when it was not known which output it would
+ // be assigned to. In this case the kernel can use set_output or
+ // set_output_ref to indicate that the tensor should be used as the
+ // designated output. It is legal to use any previously-allocated
+ // Tensor as an argument to set_output or set_output_ref, including
+ // Tensors allocated via allocate_temp. There may be a performance
+ // penalty to using a Tensor that was not allocated using
+ // allocate_output. This is because allocate_output uses the
+ // AllocatorAttributes stored in output_alloc_attr for the
+ // designated output. In some cases, using the wrong attributes may
+ // cause an extra copy of the Tensor's buffer.
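+  //
+  // As a concrete sketch inside Compute() (the shapes and indices are
+  // illustrative):
+  //
+  //   const Tensor& in = context->input(0);
+  //   Tensor* out = nullptr;
+  //   Status s = context->allocate_output(0, in.shape(), &out);
+  //   if (!s.ok()) { context->SetStatus(s); return; }
+  //   // ... fill in *out ...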
+
+ // Allocates output for the specified output index with shape.
+ // OpKernelContext retains ownership of the returned pointer. See
+ // comment above.
+ //
+ // If memory allocation fails, returns an error status.
+ //
+ // REQUIRES: !IsRefType(expected_output_dtype(index))
+ Status allocate_output(int index, const TensorShape& shape,
+ Tensor** tensor) TF_MUST_USE_RESULT;
+ Status allocate_output(const string& name, const TensorShape& shape,
+ Tensor** tensor) TF_MUST_USE_RESULT;
+ // The following methods use the supplied attributes instead of
+ // those in output_alloc_attr. The caller is responsible for
+ // ensuring that the attributes are "compatible" with the
+ // output_alloc_attr, e.g. the tensor is allocated on the correct
+ // device. See comment above.
+ Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
+ AllocatorAttributes attr) TF_MUST_USE_RESULT;
+ Status allocate_output(const string& name, const TensorShape& shape,
+ Tensor** tensor,
+ AllocatorAttributes attr) TF_MUST_USE_RESULT;
+
+ // Allocates a temporary Tensor of the specified type and
+ // shape. Devices such as GPUs that enqueue Ops for lazy execution
+ // may retain references to the temporary tensors after the Op's
+ // Compute method has run. See comment above.
+ Status allocate_temp(DataType type, const TensorShape& shape,
+ Tensor* out_temp, AllocatorAttributes attr);
+ Status allocate_temp(DataType type, const TensorShape& shape,
+ Tensor* out_temp) {
+ return allocate_temp(type, shape, out_temp, AllocatorAttributes());
+ }
+
+ // Allocates a Tensor of the specified type and shape which the Op
+ // plans to maintain as persistent state. out_persistent holds the
+ // PersistentTensor which is the object the caller should store. For
+ // convenience, if out_tensor is non-null then it will be filled in
+ // with a Tensor* pointing to the newly-allocated tensor which the
+ // caller can use instead of calling
+ // out_persistent->AccessTensor. The caller does not own out_tensor
+ // and should not keep a copy of it. See comment above.
+ Status allocate_persistent(DataType type, const TensorShape& shape,
+ PersistentTensor* out_persistent,
+ Tensor** out_tensor, AllocatorAttributes attr);
+ Status allocate_persistent(DataType type, const TensorShape& shape,
+ PersistentTensor* out_persistent,
+ Tensor** out_tensor) {
+ return allocate_persistent(type, shape, out_persistent, out_tensor,
+ AllocatorAttributes());
+ }
+
+ // Copies a tensor (allocated by the caller) to the specified output
+ // index. REQUIRES: !IsRefType(expected_output_dtype(index))
+ // REQUIRES: 'tensor' must have the same MemoryType as
+ // output_memory_types[index]. See comment above.
+ // TODO(mrry): Convert this to return Status.
+ void set_output(int index, const Tensor& tensor);
+ Status set_output(const string& name, const Tensor& tensor);
+
+ // To output a reference. Caller retains ownership of mu and tensor_for_ref,
+ // and they must outlive all uses within the step. See comment above.
+ // REQUIRES: IsRefType(expected_output_dtype(index))
+ // TODO(mrry): Convert this to return Status.
+ void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
+ Status set_output_ref(const string& name, mutex* mu, Tensor* tensor_for_ref);
+
+ // Returns nullptr if allocate_output() or set_output() have not been called.
+ // TODO(mrry): Convert this to return Status.
+ Tensor* mutable_output(int index);
+ Status mutable_output(const string& name, Tensor** tensor);
+
+ // Transfers ownership of an output tensor to the caller.
+ // NOTE: For non-reference outputs, the caller takes responsibility
+ // for deletion. For reference outputs, the caller does NOT take
+ // responsibility for deletion.
+ // TODO(mrry): Convert this to return Status.
+ TensorValue release_output(int index);
+ Status release_output(const string& name, TensorValue* value);
+
+ // Records device specific state about how the input tensors were
+ // computed.
+ //
+ // If using the templated function, the type must be a subclass
+ // of DeviceContext.
+ //
+ // Get the DeviceContext used for the index input. Returns nullptr
+ // if no DeviceContext was provided.
+ template <typename T>
+ T* input_device_context(int index);
+ DeviceContext* input_device_context(int index);
+
+ // Return the DeviceContext that should be used for this Op.
+ //
+ // If using the templated function, the type must be a subclass
+ // of DeviceContext.
+ //
+ // Returns nullptr if the device did not provide one.
+ template <typename T>
+ T* op_device_context();
+ DeviceContext* op_device_context() {
+ DeviceContext* ret = params_.op_device_context;
+ if (ret == nullptr) {
+ auto* dev_info = device()->tensorflow_gpu_device_info();
+ if (dev_info) ret = dev_info->default_context;
+ }
+ return ret;
+ }
+
+ AllocatorAttributes input_alloc_attr(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.input_alloc_attrs->size());
+ return (*params_.input_alloc_attrs)[index];
+ }
+
+ AllocatorAttributes output_alloc_attr(int index) const {
+ return params_.output_alloc_attr(index);
+ }
+
+ gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators() const {
+ mutex_lock lock(mu_);
+ gtl::InlinedVector<WrappedAllocator, 4> retrieved = wrapped_allocators_;
+ return retrieved;
+ }
+
+ // Communication.
+ //
+ // An op kernel communicates with outside environment through
+ // Rendezvous Send() and Recv().
+ Rendezvous* rendezvous() const { return params_.rendezvous; }
+
+ // Function call support.
+ //
+ // If this kernel invocation is within a function execution,
+ // call_frame() returns the call frame for the function call.
+ FunctionCallFrame* call_frame() const { return params_.call_frame; }
+
+  // If not nullptr, the kernel may invoke functions defined in the
+ // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
+ FunctionLibraryRuntime* function_library() const {
+ return params_.function_library;
+ }
+
+ // Shared resources accessible to this kernel.
+ ResourceMgr* resource_manager() const { return params_.resource_manager; }
+
+ checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
+ return params_.slice_reader_cache;
+ }
+
+ // Execution.
+ //
+ // OpKernels can use these eigen devices to carry out their
+ // numerical computation.
+ const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
+ return *device()->eigen_cpu_device();
+ }
+ const Eigen::GpuDevice& eigen_gpu_device() const {
+ return eigen_gpu_device_->device();
+ }
+ template <typename EigenDeviceType>
+ const EigenDeviceType& eigen_device() const;
+
+ // Error handling.
+
+  // If the inputs' dtypes match expected_inputs and the kernel's output
+  // types match expected_outputs, returns OK, else returns
+  // INVALID_ARGUMENT with an error message.
+ // Recommended for Ops with dynamic signatures, where validation can only
+ // be performed at runtime.
+ Status MatchSignature(const DataTypeSlice expected_inputs,
+ const DataTypeSlice expected_outputs);
+
+ // An OpKernel should call SetStatus() if Compute() encounters an
+ // error.
+ void SetStatus(const Status& status) { status_.Update(status); }
+ const Status& status() const { return status_; }
+
+ // Cancellation.
+ //
+ // EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an
+ // example of how to use this API.
+ CancellationManager* cancellation_manager() const {
+ return params_.cancellation_manager;
+ }
+
+ // Other accessors.
+
+ // For control flow.
+ FrameAndIter frame_iter() const { return params_.frame_iter; }
+ bool is_input_dead() const { return params_.is_input_dead; }
+ bool* is_output_dead() { return &is_output_dead_; }
+
+ // May be used, e.g., to get GPU handles, etc.
+ // TODO(tucker): Add example usage.
+ DeviceBase* device() const { return params_.device; }
+
+ // Access to list of temporary tensors.
+ int num_temps();
+ Tensor* temp(int index);
+
+  // Access to information about whether each output was newly
+  // allocated or copied from an existing tensor.
+ AllocationType output_allocation_type(int index) const {
+ return output_allocation_types_[index];
+ }
+
+ private:
+ Allocator* get_allocator(AllocatorAttributes attr) {
+ Allocator* allocator = params_.device->GetAllocator(attr);
+ if (params_.track_allocations) {
+ mutex_lock lock(mu_);
+ for (const auto& wrapped : wrapped_allocators_) {
+ if (wrapped.first == allocator) {
+ return wrapped.second;
+ }
+ }
+ TrackingAllocator* wrapped_allocator = new TrackingAllocator(allocator);
+ wrapped_allocators_.push_back(
+ std::make_pair(allocator, wrapped_allocator));
+ return wrapped_allocator;
+ } else {
+ return allocator;
+ }
+ }
+
+ // Per-step resource manager for use by white-listed internal ops.
+ friend class TemporaryVariableOp;
+ friend class DestroyTemporaryVariableOp;
+ ResourceMgr* step_resource_manager() const {
+ return params_.step_resource_manager;
+ }
+
+ // Internal common method used when allocating tensor memory
+ Status allocate_tensor(DataType type, const TensorShape& shape,
+ Tensor* out_tensor, AllocatorAttributes attr);
+
+ // This is called by PersistentTensor::AccessTensor whenever the
+ // wrapped tensor is retrieved, to ensure the runtime knows that the
+ // Tensor is being accessed within an Op. This is necessary for
+ // memory safety of devices like GPUs that queue Ops for
+ // asynchronous execution after the Compute() method completes.
+ friend class PersistentTensor;
+ void NotifyUseOfPersistentTensor(const Tensor& tensor);
+
+ Status status_;
+ Params params_; // immutable after construction.
+ const PerOpGpuDevice* eigen_gpu_device_; // owned, with a per-op
+ // wrapped allocator
+ mutable mutex mu_; // mutable so const accessors can acquire the lock
+ gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
+ gtl::InlinedVector<TensorValue, 4> outputs_;
+ gtl::InlinedVector<AllocationType, 4> output_allocation_types_;
+ gtl::InlinedVector<Tensor*, 4> temp_tensors_;
+ bool is_output_dead_ = false;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
+};
+
+// Register your OpKernel by specifying the Op's name, the device the
+// kernel runs on, any type attr constraints for this kernel, any
+// host-memory args, and the class to instantiate. Examples:
+//
+// // A kernel that supports all types.
+// REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
+//
+// // The following are equivalent ways of specifying that the kernel only
+// // works if the "T" type attr is set to DT_FLOAT.
+// REGISTER_KERNEL_BUILDER(
+// Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+// SubOp<float>);
+// // (You would then repeat this for every type supported by "Sub".)
+//
+// // This form allows you to specify a list of types as the constraint.
+// REGISTER_KERNEL_BUILDER(Name("Sub")
+// .Device(DEVICE_CPU)
+// .TypeConstraint("T", {DT_FLOAT}),
+// SubOp<float>);
+//
+// // A kernel that expects one of the input tensors in host memory.
+// REGISTER_KERNEL_BUILDER(
+// Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
+//
+// See kernel_def_builder for details.
+
+// Instantiates an OpKernel that has been registered. Returns nullptr
+// if no kernel is registered for that device / input-signature
+// combination (and sets *status to NOT_FOUND), or if there is an error
+// during construction (and sets *status to INVALID_ARGUMENT).
+// Otherwise, the caller takes ownership of the returned pointer.
+// EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
+// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
+std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
+ DeviceBase* device,
+ Allocator* allocator,
+ const NodeDef& def, Status* status);
+Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
+ Allocator* allocator, FunctionLibraryRuntime* flib,
+ const NodeDef& def, OpKernel** kernel);
+
+// Returns into 'device_types' the subset of prioritized_types that this
+// binary has registered for the given NodeDef.
+//
+// REQUIRES: * 'device_types' is not nullptr.
+// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
+Status SupportedDeviceTypesForNode(
+ const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
+ DeviceTypeVector* device_types);
+
+// Returns into *{input,output}_memory_types the memory type of each
+// {input,output} tensor.
+//
+// REQUIRES: * '*_memory_types' is not nullptr.
+// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
+Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef,
+ const OpDef& op_def,
+ const NameRangeMap& input_name_map,
+ const NameRangeMap& output_name_map,
+ MemoryTypeVector* input_memory_types,
+ MemoryTypeVector* output_memory_types);
+
+Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
+ DeviceType device_type, const NodeDef& ndef,
+ MemoryTypeVector* input_memory_types,
+ MemoryTypeVector* output_memory_types);
+
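+// For example (a sketch; 'ndef' is a NodeDef with all attrs specified):
+//
+//   MemoryTypeVector in_mtypes, out_mtypes;
+//   Status s = MemoryTypesForNode(OpRegistry::Global(),
+//                                 DeviceType(DEVICE_GPU), ndef,
+//                                 &in_mtypes, &out_mtypes);
+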
+// Call once after Op registration has completed.
+Status ValidateKernelRegistrations(const OpRegistryInterface* op_registry);
+
+// -----------------------------------------------------------------------------
+// OpKernel registration implementation follows, please ignore.
+
+// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
+namespace register_kernel {
+typedef ::tensorflow::KernelDefBuilder Name;
+} // namespace register_kernel
+
+#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
+ REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
+
+#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
+ REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
+
+#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
+ static ::tensorflow::kernel_factory::OpKernelRegistrar \
+ registrar__body__##ctr##__object( \
+ ::tensorflow::register_kernel::kernel_builder.Build(), \
+ +[](::tensorflow::OpKernelConstruction* context) \
+ -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); })
+
+namespace kernel_factory {
+
+class OpKernelRegistrar {
+ public:
+ typedef OpKernel* (*Factory)(OpKernelConstruction*);
+ OpKernelRegistrar(const KernelDef* kernel_def, Factory factory);
+};
+
+} // namespace kernel_factory
+
+// -----------------------------------------------------------------------------
+// Template and inline method implementations, please ignore
+
+inline DataType OpKernelContext::input_dtype(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ const TensorValue& value((*params_.inputs)[index]);
+ if (value.is_ref()) {
+ return MakeRefType(value->dtype());
+ } else {
+ return value->dtype();
+ }
+}
+
+inline DataType OpKernelContext::expected_output_dtype(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.op_kernel->output_types().size());
+ return params_.op_kernel->output_type(index);
+}
+
+inline const Tensor& OpKernelContext::input(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ DCHECK(!(*params_.inputs)[index].is_ref());
+ return *((*params_.inputs)[index].tensor);
+}
+
+inline Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ DCHECK((*params_.inputs)[index].is_ref());
+ // return a copy of the Ref acquired while holding the mutex
+ if (lock_held) {
+ return *((*params_.inputs)[index].tensor);
+ } else {
+ mutex_lock l(*input_ref_mutex(index));
+ return *((*params_.inputs)[index].tensor);
+ }
+}
+
+inline void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
+ bool lock_held) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ DCHECK((*params_.inputs)[index].is_ref());
+ // should only modify the tensor while holding the mutex
+ if (lock_held) {
+ *(*params_.inputs)[index].tensor = tensor;
+ } else {
+ mutex_lock l(*input_ref_mutex(index));
+ *(*params_.inputs)[index].tensor = tensor;
+ }
+}
+
+inline void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
+ int output_index) {
+ DCHECK_GE(input_index, 0);
+ DCHECK_LT(input_index, params_.inputs->size());
+ DCHECK((*params_.inputs)[input_index].is_ref());
+ set_output_ref(output_index, (*params_.inputs)[input_index].mutex_if_ref,
+ (*params_.inputs)[input_index].tensor);
+}
+
+inline void OpKernelContext::delete_ref_input(int index, bool lock_held) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ DCHECK((*params_.inputs)[index].is_ref());
+ // should only modify the tensor while holding the mutex
+ if (lock_held) {
+ delete (*params_.inputs)[index].tensor;
+ } else {
+ mutex_lock l(*input_ref_mutex(index));
+ delete (*params_.inputs)[index].tensor;
+ }
+}
+
+// no input if tensor == nullptr.
+inline bool OpKernelContext::has_input(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ return (*params_.inputs)[index].tensor != nullptr;
+}
+
+inline mutex* OpKernelContext::input_ref_mutex(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ DCHECK((*params_.inputs)[index].is_ref());
+ return (*params_.inputs)[index].mutex_if_ref;
+}
+
+inline Status OpKernelContext::allocate_output(int index,
+ const TensorShape& shape,
+ Tensor** output) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, num_outputs());
+ DCHECK(params_.output_alloc_attr);
+ AllocatorAttributes attr = params_.output_alloc_attr(index);
+ return allocate_output(index, shape, output, attr);
+}
+
+inline Status OpKernelContext::allocate_tensor(DataType type,
+ const TensorShape& shape,
+ Tensor* out_tensor,
+ AllocatorAttributes attr) {
+ Allocator* a = get_allocator(attr);
+ Tensor new_tensor(a, type, shape);
+
+ if (!new_tensor.IsInitialized() && shape.num_elements() > 0) {
+    return errors::ResourceExhausted("OOM when allocating tensor with shape ",
+                                     shape.DebugString());
+ }
+ *out_tensor = new_tensor;
+ return Status::OK();
+}
+
+inline Status OpKernelContext::allocate_output(int index,
+ const TensorShape& shape,
+ Tensor** output,
+ AllocatorAttributes attr) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ // Record the fact that this output tensor was allocated by the Op
+ DCHECK_LT(index, output_allocation_types_.size());
+ output_allocation_types_[index] = AT_ALLOCATED;
+ const DataType type = params_.op_kernel->output_type(index);
+ DCHECK(!IsRefType(type));
+ DCHECK(mutable_output(index) == nullptr);
+ Tensor* output_tensor = new Tensor();
+ Status s = allocate_tensor(type, shape, output_tensor, attr);
+ if (s.ok()) {
+ outputs_[index] = TensorValue(output_tensor);
+ *output = outputs_[index].tensor;
+ }
+ return s;
+}
+
+inline Status OpKernelContext::allocate_temp(DataType type,
+ const TensorShape& shape,
+ Tensor* out_temp,
+ AllocatorAttributes attr) {
+ Status s = allocate_tensor(type, shape, out_temp, attr);
+ if (s.ok()) {
+ if (params_.device->SaveTemporaryTensors()) {
+ // keep a reference to the underlying memory around
+ temp_tensors_.push_back(new Tensor(*out_temp));
+ }
+ }
+ return s;
+}
+
+inline Status OpKernelContext::allocate_persistent(
+ DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
+ Tensor** out_tensor, AllocatorAttributes attr) {
+ // TODO(misard) add specific memory tracking for persistent tensors
+ Tensor persistent;
+ Status s = allocate_tensor(type, shape, &persistent, attr);
+ if (s.ok()) {
+ *out_persistent = PersistentTensor(persistent);
+ // This call saves a reference to the newly-allocated tensor if we
+ // are saving temporary tensors
+ Tensor* allocated = out_persistent->AccessTensor(this);
+ if (out_tensor) {
+ *out_tensor = allocated;
+ }
+ }
+ return s;
+}
+
+inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
+ if (t.IsInitialized() && params_.device->SaveTemporaryTensors()) {
+ // keep a reference to the underlying memory around
+ temp_tensors_.push_back(new Tensor(t));
+ }
+}
+
+inline int OpKernelContext::num_temps() { return temp_tensors_.size(); }
+
+inline Tensor* OpKernelContext::temp(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, temp_tensors_.size());
+ return temp_tensors_[index];
+}
+
+inline void OpKernelContext::set_output(int index, const Tensor& tensor) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ // Record the fact that this output tensor was set by the Op
+ DCHECK_LT(index, output_allocation_types_.size());
+ output_allocation_types_[index] = AT_EXISTING;
+ DCHECK(!IsRefType(params_.op_kernel->output_type(index)));
+ DCHECK_EQ(mutable_output(index), nullptr);
+ outputs_[index] = TensorValue(new Tensor(tensor));
+}
+
+inline void OpKernelContext::set_output_ref(int index, mutex* mu,
+ Tensor* tensor_for_ref) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ // Record the fact that this output tensor was set by reference the Op
+ DCHECK_LT(index, output_allocation_types_.size());
+ output_allocation_types_[index] = AT_REF;
+ DCHECK(IsRefType(params_.op_kernel->output_type(index)));
+ outputs_[index] = TensorValue(mu, tensor_for_ref);
+}
+
+inline Tensor* OpKernelContext::mutable_output(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ return outputs_[index].tensor;
+}
+
+inline TensorValue OpKernelContext::release_output(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ TensorValue value = outputs_[index];
+ outputs_[index] = TensorValue();
+ return value;
+}
+
+template <typename T>
+T* OpKernelContext::op_device_context() {
+ static_assert(std::is_base_of<DeviceContext, T>::value,
+ "T is not a subclass of DeviceContext");
+ return static_cast<T*>(op_device_context());
+}
+
+template <typename T>
+T* OpKernelContext::input_device_context(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.input_device_contexts->size());
+ static_assert(std::is_base_of<DeviceContext, T>::value,
+ "T is not a subclass of DeviceContext");
+ return static_cast<T*>((*params_.input_device_contexts)[index]);
+}
+
+inline DeviceContext* OpKernelContext::input_device_context(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.input_device_contexts->size());
+ return (*params_.input_device_contexts)[index];
+}
+
+inline const Tensor& OpInputList::operator[](int i) const {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->input(start_ + i);
+}
+
+inline mutex* OpMutableInputList::ref_mutex(int i) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->input_ref_mutex(start_ + i);
+}
+
+inline Tensor OpMutableInputList::at(int i, bool lock_held) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->mutable_input(start_ + i, lock_held);
+}
+
+inline Tensor* OpOutputList::operator[](int i) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->mutable_output(start_ + i);
+}
+
+inline bool OpOutputList::required(int i) const {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->output_required(start_ + i);
+}
+
+inline Status OpOutputList::allocate(int i, const TensorShape& shape,
+ Tensor** output) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->allocate_output(start_ + i, shape, output);
+}
+
+inline void OpOutputList::set(int i, const Tensor& tensor) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ ctx_->set_output(start_ + i, tensor);
+}
+
+inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+  ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
new file mode 100644
index 0000000000..9400ef24f8
--- /dev/null
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -0,0 +1,803 @@
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include <memory>
+#include <vector>
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include <gtest/gtest.h>
+
+class DummyKernel : public tensorflow::OpKernel {
+ public:
+ explicit DummyKernel(tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {}
+ void Compute(tensorflow::OpKernelContext* context) override {}
+};
+
+// Test that registration works outside a namespace.
+REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8");
+REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),
+ DummyKernel);
+
+namespace foo {
+bool match_signature_ = false;
+
+// Test that registration works inside a different namespace.
+class TestOp2 : public ::tensorflow::OpKernel {
+ public:
+ explicit TestOp2(::tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {
+ ::tensorflow::Status status = context->MatchSignature(
+ {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32});
+ match_signature_ = status.ok();
+ context->SetStatus(status);
+ }
+ void Compute(::tensorflow::OpKernelContext* context) override {}
+};
+
+REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type");
+REGISTER_KERNEL_BUILDER(Name("Test2")
+ .Device(::tensorflow::DEVICE_GPU)
+ .HostMemory("i")
+ .HostMemory("o"),
+ TestOp2);
+} // namespace foo
+
+namespace tensorflow {
+
+// One op ("Test3"), with kernels registered below for different devices.
+REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type");
+
+class TestOp3Cpu : public tensorflow::OpKernel {
+ public:
+ explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override {}
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("Test3").Device(DEVICE_CPU).TypeConstraint<int8>("T"), TestOp3Cpu);
+
+namespace {
+
+class TestOp3Gpu : public tensorflow::OpKernel {
+ public:
+ explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override {}
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("Test3").Device(DEVICE_GPU).TypeConstraint<float>("T"), TestOp3Cpu);
+
+// An op with kernels registered for both CPU and GPU.
+REGISTER_OP("Test4").Input("i: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel);
+REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel);
+
+static std::vector<DeviceType> DeviceTypes() {
+ return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)};
+}
+
+class OpKernelTest : public ::testing::Test {
+ public:
+ OpKernelTest() : device_(Env::Default()) {}
+
+ protected:
+ NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs) {
+ NodeDefBuilder builder(op_type + "-op", op_type);
+ for (DataType dt : inputs) {
+ builder.Input(FakeInput(dt));
+ }
+ NodeDef node_def;
+ TF_CHECK_OK(builder.Finalize(&node_def));
+ return node_def;
+ }
+
+ void ExpectEqual(const string& what, const DataTypeVector& expected,
+ const DataTypeVector& observed) {
+ EXPECT_EQ(expected.size(), observed.size()) << what;
+ const int size = std::min(expected.size(), observed.size());
+ for (int i = 0; i < size; ++i) {
+ bool match = TypesCompatible(expected[i], observed[i]);
+ EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i]
+ << ", observed: " << observed[i];
+ }
+ }
+
+ void ExpectSuccess(const string& op_type, DeviceType device_type,
+ const DataTypeVector& inputs,
+ const DataTypeVector& outputs) {
+ Status status;
+ std::unique_ptr<OpKernel> op(
+ CreateOpKernel(device_type, &device_, cpu_allocator(),
+ CreateNodeDef(op_type, inputs), &status));
+ EXPECT_TRUE(status.ok()) << status;
+ EXPECT_TRUE(op != nullptr);
+ if (op != nullptr) {
+ ExpectEqual("inputs", op->input_types(), inputs);
+ ExpectEqual("outputs", op->output_types(), outputs);
+ }
+ }
+
+ void ExpectFailure(const string& ascii_node_def, DeviceType device_type,
+ error::Code code) {
+ NodeDef node_def;
+ protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def);
+ Status status;
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ device_type, &device_, cpu_allocator(), node_def, &status));
+ EXPECT_TRUE(op == nullptr);
+ EXPECT_FALSE(status.ok());
+ if (!status.ok()) {
+ LOG(INFO) << "Status message: " << status.error_message();
+ EXPECT_EQ(code, status.code());
+ }
+ }
+
+ private:
+ DeviceBase device_;
+};
+
+TEST_F(OpKernelTest, SuccessCpu) {
+ ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, {DT_UINT8});
+ ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32}, {DT_UINT8});
+}
+
+TEST_F(OpKernelTest, SuccessGpu) {
+ foo::match_signature_ = false;
+ ExpectSuccess("Test2", DEVICE_GPU, {DT_INT32}, {DT_INT32});
+ EXPECT_TRUE(foo::match_signature_);
+}
+
+TEST_F(OpKernelTest, SuccessBothCpuAndGpu) {
+ ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, {});
+ ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, {});
+}
+
+TEST_F(OpKernelTest, CpuTypeRegistered) {
+ NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
+ DeviceTypeVector devs;
+ ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+ EXPECT_EQ(1, devs.size());
+ EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
+}
+
+TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
+ {
+ // Try a node def of an op that is registered for a specific type
+ // only on CPU.
+ NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8});
+ DeviceTypeVector devs;
+ ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+ EXPECT_EQ(1, devs.size());
+ EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
+ }
+ {
+ // Try a node def of an op that is registered for a specific type
+ // only on GPU.
+ NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT});
+ DeviceTypeVector devs;
+ ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+ EXPECT_EQ(1, devs.size());
+ EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
+ }
+ {
+ // Try a node def of an op that is only registered for other types.
+ NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING});
+ DeviceTypeVector devs;
+ ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+ EXPECT_EQ(0, devs.size());
+ }
+
+ {
+ // Try a node def of an op that is registered for both.
+ NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT});
+ DeviceTypeVector devs;
+ ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+ EXPECT_EQ(2, devs.size());
+ EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
+ EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]);
+ }
+}
+
+TEST_F(OpKernelTest, NotFound) {
+ const auto not_found = error::NOT_FOUND;
+ // Something with that op type name exists, but only with a
+ // different DeviceType.
+ ExpectFailure(CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}).DebugString(),
+ DEVICE_GPU, not_found);
+ ExpectFailure(CreateNodeDef("Test3", {DT_INT8, DT_INT8}).DebugString(),
+ DEVICE_GPU, not_found);
+ ExpectFailure(CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}).DebugString(),
+ DEVICE_CPU, not_found);
+
+ // No kernel with that signature registered.
+ ExpectFailure(CreateNodeDef("Test3", {DT_INT32, DT_INT32}).DebugString(),
+ DEVICE_GPU, not_found);
+
+ // Nothing with that op type name exists.
+ ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_CPU, not_found);
+ ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_GPU, not_found);
+}
+
+TEST_F(OpKernelTest, TooFewInputs) {
+ const auto invalid = error::INVALID_ARGUMENT;
+ NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
+ node_def.clear_input();
+ ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
+ node_def.add_input("a");
+ ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
+}
+
+TEST_F(OpKernelTest, TooManyInputs) {
+ const auto invalid = error::INVALID_ARGUMENT;
+ NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
+ node_def.add_input("c");
+ ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
+}
+
+TEST_F(OpKernelTest, MatchSignatureFails) {
+ const auto invalid = error::INVALID_ARGUMENT;
+ foo::match_signature_ = true;
+ ExpectFailure(CreateNodeDef("Test2", {DT_FLOAT}).DebugString(), DEVICE_GPU,
+ invalid);
+ EXPECT_FALSE(foo::match_signature_);
+}
+
+class DummyDevice : public DeviceBase {
+ public:
+ DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
+ bool SaveTemporaryTensors() const override { return save_; }
+ Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
+ return cpu_allocator();
+ }
+
+ private:
+ bool save_;
+};
+
+TEST_F(OpKernelTest, SaveTempFalse) {
+ Env* env = Env::Default();
+ OpKernelContext::Params params;
+ params.device = new DummyDevice(env, false);
+ Status status;
+ std::unique_ptr<OpKernel> op(
+ CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
+ CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), &status));
+ EXPECT_TRUE(status.ok());
+ params.op_kernel = op.get();
+ OpKernelContext* ctx = new OpKernelContext(params);
+
+ Tensor t;
+ EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));
+
+ EXPECT_EQ(0, ctx->num_temps());
+
+ delete ctx;
+ delete params.device;
+}
+
+TEST_F(OpKernelTest, SaveTempTrue) {
+ Env* env = Env::Default();
+ OpKernelContext::Params params;
+ params.device = new DummyDevice(env, true);
+ Status status;
+ std::unique_ptr<OpKernel> op(
+ CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
+ CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), &status));
+ EXPECT_TRUE(status.ok());
+ params.op_kernel = op.get();
+ OpKernelContext* ctx = new OpKernelContext(params);
+
+ Tensor t;
+ EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));
+
+ EXPECT_EQ(1, ctx->num_temps());
+
+ delete ctx;
+ delete params.device;
+}
+
+class OpKernelBuilderTest : public ::testing::Test {
+ protected:
+  // Each attr is described by a "name|type|value" string.
+ NodeDef CreateNodeDef(const string& op_type,
+ const std::vector<string>& attrs) {
+ NodeDef node_def;
+ node_def.set_name(op_type + "-op");
+ node_def.set_op(op_type);
+ for (const string& attr_desc : attrs) {
+ std::vector<string> parts = str_util::Split(attr_desc, '|');
+ CHECK_EQ(parts.size(), 3);
+ AttrValue attr_value;
+ CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc;
+ node_def.mutable_attr()->insert(
+ AttrValueMap::value_type(parts[0], attr_value));
+ }
+ return node_def;
+ }
+
+ std::unique_ptr<OpKernel> ExpectSuccess(const string& op_type,
+ DeviceType device_type,
+ const std::vector<string>& attrs,
+ DataTypeSlice input_types = {}) {
+ Status status;
+ NodeDef def = CreateNodeDef(op_type, attrs);
+ for (size_t i = 0; i < input_types.size(); ++i) {
+ def.add_input("a:0");
+ }
+
+ Env* env = Env::Default();
+ DeviceBase device(env);
+
+ // Test CreateOpKernel()
+ std::unique_ptr<OpKernel> op(
+ CreateOpKernel(device_type, &device, cpu_allocator(), def, &status));
+ EXPECT_TRUE(status.ok()) << status;
+ EXPECT_TRUE(op != nullptr);
+ if (op != nullptr) {
+ EXPECT_EQ(input_types.size(), op->num_inputs());
+ EXPECT_EQ(0, op->num_outputs());
+ }
+
+ // Test SupportedDeviceTypesForNode()
+ DeviceTypeVector devices;
+ EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
+ bool found = false;
+ for (DeviceType dt : devices) {
+ if (dt == device_type) {
+ found = true;
+ }
+ }
+ EXPECT_TRUE(found) << "Missing " << device_type << " from "
+ << devices.size() << " devices.";
+
+ // In case the caller wants to use the OpKernel
+ return op;
+ }
+
+ void ExpectFailure(const string& op_type, DeviceType device_type,
+ const std::vector<string>& attrs, error::Code code) {
+ Status status;
+ const NodeDef def = CreateNodeDef(op_type, attrs);
+ Env* env = Env::Default();
+ DeviceBase device(env);
+
+ // Test CreateOpKernel().
+ std::unique_ptr<OpKernel> op(
+ CreateOpKernel(device_type, &device, cpu_allocator(), def, &status));
+ EXPECT_TRUE(op == nullptr);
+ EXPECT_FALSE(status.ok());
+ if (!status.ok()) {
+ LOG(INFO) << "Status message: " << status.error_message();
+ EXPECT_EQ(code, status.code());
+
+ // Test SupportedDeviceTypesForNode().
+ DeviceTypeVector devices;
+ if (errors::IsNotFound(status)) {
+ EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
+ for (DeviceType dt : devices) {
+ EXPECT_NE(dt, device_type);
+ }
+ } else {
+ Status status2 =
+ SupportedDeviceTypesForNode(DeviceTypes(), def, &devices);
+ EXPECT_EQ(status.code(), status2.code());
+ }
+ }
+ }
+};
+
+REGISTER_OP("BuildCPU");
+REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BuilderCPU) {
+ ExpectSuccess("BuildCPU", DEVICE_CPU, {});
+ ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
+}
+
+REGISTER_OP("BuildGPU");
+REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BuilderGPU) {
+ ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND);
+ ExpectSuccess("BuildGPU", DEVICE_GPU, {});
+}
+
+REGISTER_OP("BuildBoth");
+REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel);
+REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BuilderBoth) {
+ ExpectSuccess("BuildBoth", DEVICE_CPU, {});
+ ExpectSuccess("BuildBoth", DEVICE_GPU, {});
+}
+
+REGISTER_OP("BuildTypeAttr").Attr("T: type");
+REGISTER_KERNEL_BUILDER(Name("BuildTypeAttr")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BuilderTypeAttr) {
+ ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"});
+ ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"},
+ error::NOT_FOUND);
+ ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
+ ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"},
+ error::INVALID_ARGUMENT);
+}
+
+REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)");
+REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<bool>("T"),
+ DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
+ ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
+ ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
+ ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
+ {"T|list(type)|[DT_BOOL, DT_BOOL]"});
+ ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
+ error::NOT_FOUND);
+ ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
+ ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
+ error::INVALID_ARGUMENT);
+}
+
+REGISTER_OP("DuplicateKernel");
+REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
+ DummyKernel);
+REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
+ DummyKernel);
+
+TEST_F(OpKernelBuilderTest, DuplicateKernel) {
+ const NodeDef ndef = CreateNodeDef("DuplicateKernel", {});
+ DeviceTypeVector devs;
+ Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Multiple OpKernel registrations match NodeDef"));
+
+ ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
+}
+
+REGISTER_OP("DuplicateKernelForT").Attr("T: type");
+REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ DummyKernel);
+REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ DummyKernel);
+
+TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
+ const NodeDef ndef =
+ CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"});
+ DeviceTypeVector devs;
+ Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Multiple OpKernel registrations match NodeDef"));
+
+ ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"},
+ error::INVALID_ARGUMENT);
+ ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"},
+ error::NOT_FOUND);
+}
+
+REGISTER_OP("BadConstraint").Attr("dtype: type");
+REGISTER_KERNEL_BUILDER(Name("BadConstraint")
+ .Device(DEVICE_CPU)
+ // Mistake: "T" should be "dtype".
+ .TypeConstraint<float>("T"),
+ DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BadConstraint) {
+ const NodeDef ndef = CreateNodeDef("BadConstraint", {});
+ DeviceTypeVector devs;
+ Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("OpKernel 'BadConstraint' has constraint on attr "
+ "'T' not in NodeDef"));
+
+ ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"},
+ error::INVALID_ARGUMENT);
+}
+
+class GetAttrKernel : public ::tensorflow::OpKernel {
+ public:
+ explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) {
+ string attr_name;
+ OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name));
+
+ status.emplace_back("s", context->GetAttr(attr_name, &s));
+ status.emplace_back("s_list", context->GetAttr(attr_name, &s_list));
+ status.emplace_back("i", context->GetAttr(attr_name, &i));
+ status.emplace_back("i_list", context->GetAttr(attr_name, &i_list));
+ status.emplace_back("i32", context->GetAttr(attr_name, &i32));
+ status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list));
+ status.emplace_back("f", context->GetAttr(attr_name, &f));
+ status.emplace_back("f_list", context->GetAttr(attr_name, &f_list));
+ status.emplace_back("b", context->GetAttr(attr_name, &b));
+ status.emplace_back("b_list", context->GetAttr(attr_name, &b_list));
+ status.emplace_back("type", context->GetAttr(attr_name, &type));
+ status.emplace_back("type_list", context->GetAttr(attr_name, &type_list));
+ status.emplace_back("type_vector",
+ context->GetAttr(attr_name, &type_vector));
+ status.emplace_back("shape_proto",
+ context->GetAttr(attr_name, &shape_proto));
+ status.emplace_back("shape_proto_list",
+ context->GetAttr(attr_name, &shape_proto_list));
+ status.emplace_back("shape", context->GetAttr(attr_name, &shape));
+ status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list));
+ }
+ void Compute(::tensorflow::OpKernelContext* context) override {}
+
+ void ExpectOk(std::initializer_list<string> keys) {
+ for (const auto& key_status : status) {
+ // Only the status for keys in "keys" should be ok().
+ bool in_keys = false;
+ for (const string& key : keys) {
+ if (key_status.first == key) {
+ in_keys = true;
+ }
+ }
+ EXPECT_EQ(in_keys, key_status.second.ok())
+ << "key_status: " << key_status.first << ", " << key_status.second;
+ }
+ }
+
+ string s;
+ std::vector<string> s_list;
+ int64 i;
+ std::vector<int64> i_list;
+ int32 i32;
+ std::vector<int32> i32_list;
+ float f;
+ std::vector<float> f_list;
+ bool b;
+ std::vector<bool> b_list;
+ DataType type;
+ std::vector<DataType> type_list;
+ DataTypeVector type_vector;
+ TensorShapeProto shape_proto;
+ std::vector<TensorShapeProto> shape_proto_list;
+ TensorShape shape;
+ std::vector<TensorShape> shape_list;
+ std::vector<std::pair<string, Status>> status;
+};
+
+class GetAttrTest : public OpKernelBuilderTest {};
+
+REGISTER_OP("GetAttrStringList")
+ .Attr("attr_name: string")
+ .Attr("a: list(string)");
+REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU),
+ GetAttrKernel);
+
+TEST_F(GetAttrTest, StringList) {
+ std::unique_ptr<OpKernel> op_kernel =
+ ExpectSuccess("GetAttrStringList", DEVICE_CPU,
+ {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"});
+ auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"s_list"});
+ EXPECT_EQ(std::vector<string>({"foo", "bar"}), get_attr_kernel->s_list);
+
+ op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU,
+ {"attr_name|string|'b'", "a|list(string)|['baz']"});
+ get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({});
+ EXPECT_TRUE(get_attr_kernel->s_list.empty());
+}
+
+REGISTER_OP("GetAttrInt")
+ .Attr("attr_name: string")
+ .Attr("a: int")
+ .Attr("b: list(int)");
+REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel);
+
+TEST_F(GetAttrTest, Int) {
+ std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
+ "GetAttrInt", DEVICE_CPU,
+ {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
+ auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"i", "i32"});
+ EXPECT_EQ(35, get_attr_kernel->i);
+ EXPECT_EQ(35, get_attr_kernel->i32);
+
+ op_kernel = ExpectSuccess(
+ "GetAttrInt", DEVICE_CPU,
+ {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
+ get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"i_list", "i32_list"});
+ EXPECT_EQ(std::vector<int64>({-1, 2, -4}), get_attr_kernel->i_list);
+ EXPECT_EQ(std::vector<int32>({-1, 2, -4}), get_attr_kernel->i32_list);
+
+ // 8589934592 == 2^33, too big to fit in an int32
+ op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
+ {"attr_name|string|'a'", "a|int|8589934592",
+ "b|list(int)|[-8589934592]"});
+ get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"i"}); // no i32
+ EXPECT_EQ(8589934592ll, get_attr_kernel->i);
+ for (const auto& key_status : get_attr_kernel->status) {
+ if (key_status.first == "i32") {
+ EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
+ EXPECT_EQ("Attr a has value 8589934592 out of range for an int32",
+ key_status.second.error_message());
+ }
+ }
+
+ op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
+ {"attr_name|string|'b'", "a|int|8589934592",
+ "b|list(int)|[-8589934592]"});
+ get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"i_list"}); // no i32_list
+ EXPECT_EQ(std::vector<int64>({-8589934592ll}), get_attr_kernel->i_list);
+ for (const auto& key_status : get_attr_kernel->status) {
+ if (key_status.first == "i32_list") {
+ EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
+ EXPECT_EQ("Attr b has value -8589934592 out of range for an int32",
+ key_status.second.error_message());
+ }
+ }
+}
+
+REGISTER_OP("GetAttrShape")
+ .Attr("attr_name: string")
+ .Attr("a: shape")
+ .Attr("b: list(shape)");
+REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel);
+
+TEST_F(GetAttrTest, Shape) {
+ std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
+ "GetAttrShape", DEVICE_CPU,
+ {"attr_name|string|'a'", "a|shape|{ dim { size: 3 } }",
+ "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
+ auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"shape", "shape_proto"});
+ EXPECT_EQ(get_attr_kernel->shape_proto.ShortDebugString(), "dim { size: 3 }");
+ EXPECT_EQ("[3]", get_attr_kernel->shape.ShortDebugString());
+
+ op_kernel = ExpectSuccess(
+ "GetAttrShape", DEVICE_CPU,
+ {"attr_name|string|'b'", "a|shape|{ dim { size: 3 } }",
+ "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
+ get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"shape_list", "shape_proto_list"});
+ ASSERT_EQ(2, get_attr_kernel->shape_proto_list.size());
+ EXPECT_EQ(get_attr_kernel->shape_proto_list[0].ShortDebugString(),
+ "dim { size: 2 }");
+ EXPECT_EQ(get_attr_kernel->shape_proto_list[1].ShortDebugString(),
+ "dim { size: 4 }");
+ ASSERT_EQ(2, get_attr_kernel->shape_list.size());
+ EXPECT_EQ("[2]", get_attr_kernel->shape_list[0].ShortDebugString());
+ EXPECT_EQ("[4]", get_attr_kernel->shape_list[1].ShortDebugString());
+}
+
+REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type");
+REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel);
+
+TEST_F(GetAttrTest, Type) {
+ std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
+ "GetAttrType", DEVICE_CPU, {"attr_name|string|'a'", "a|type|DT_FLOAT"});
+ auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"type"});
+ EXPECT_EQ(DT_FLOAT, get_attr_kernel->type);
+}
+
+REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)");
+REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU),
+ GetAttrKernel);
+
+TEST_F(GetAttrTest, TypeList) {
+ std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
+ "GetAttrTypeList", DEVICE_CPU,
+ {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"});
+ auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+
+ get_attr_kernel->ExpectOk({"type_list", "type_vector"});
+ ASSERT_EQ(2, get_attr_kernel->type_list.size());
+ EXPECT_EQ(DT_INT32, get_attr_kernel->type_list[0]);
+ EXPECT_EQ(DT_BOOL, get_attr_kernel->type_list[1]);
+ ASSERT_EQ(2, get_attr_kernel->type_vector.size());
+ EXPECT_EQ(DT_INT32, get_attr_kernel->type_vector[0]);
+ EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]);
+}
+
+REGISTER_OP("HostMemoryTest")
+ .Input("a: float")
+ .Input("b: T")
+ .Input("c: N * string")
+ .Output("o: N * T")
+ .Attr("T: type")
+ .Attr("N: int");
+REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
+REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
+ .Device(DEVICE_GPU)
+ .HostMemory("a")
+ .HostMemory("c")
+ .HostMemory("o"),
+ DummyKernel);
+
+TEST(MemoryTypesForNode, Simple) {
+ NodeDef node_def;
+ ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest")
+ .Input(FakeInput())
+ .Input(FakeInput(DT_BOOL))
+ .Input(FakeInput(3))
+ .Finalize(&node_def));
+ MemoryTypeVector input, output;
+
+ EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
+ &input, &output));
+ EXPECT_EQ(MemoryTypeVector(5, DEVICE_MEMORY), input);
+ EXPECT_EQ(MemoryTypeVector(3, DEVICE_MEMORY), output);
+
+ EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
+ &input, &output));
+ EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
+ HOST_MEMORY, HOST_MEMORY}),
+ input);
+ EXPECT_EQ(MemoryTypeVector(3, HOST_MEMORY), output);
+}
+
+class BaseKernel : public ::tensorflow::OpKernel {
+ public:
+ explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(::tensorflow::OpKernelContext* context) override {}
+ virtual int Which() const = 0;
+};
+
+template <int WHICH>
+class LabeledKernel : public BaseKernel {
+ public:
+ using BaseKernel::BaseKernel;
+ int Which() const override { return WHICH; }
+};
+
+class LabelTest : public OpKernelBuilderTest {};
+
+REGISTER_OP("LabeledKernel");
+REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU),
+ LabeledKernel<0>);
+REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"),
+ LabeledKernel<1>);
+REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
+ LabeledKernel<2>);
+REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
+ LabeledKernel<3>);
+
+TEST_F(LabelTest, Default) {
+ std::unique_ptr<OpKernel> op_kernel =
+ ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
+ auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
+ EXPECT_EQ(0, get_labeled_kernel->Which());
+}
+
+TEST_F(LabelTest, Specified) {
+ std::unique_ptr<OpKernel> op_kernel =
+ ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
+ auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
+ EXPECT_EQ(1, get_labeled_kernel->Which());
+}
+
+TEST_F(LabelTest, Duplicate) {
+ ExpectFailure("LabeledKernel", DEVICE_CPU, {"_kernel|string|'dupe'"},
+ error::INVALID_ARGUMENT);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc
new file mode 100644
index 0000000000..a39bebd854
--- /dev/null
+++ b/tensorflow/core/framework/op_segment.cc
@@ -0,0 +1,86 @@
+#include "tensorflow/core/framework/op_segment.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+OpSegment::Item::~Item() {
+ for (auto kv : name_kernel) delete kv.second;
+}
+
+OpSegment::OpSegment() {}
+
+OpSegment::~OpSegment() {
+ for (auto kv : sessions_) delete kv.second;
+}
+
+Status OpSegment::FindOrCreate(const string& session_handle,
+ const string& node_name, OpKernel** kernel,
+ CreateKernelFn create_fn) {
+ {
+ mutex_lock l(mu_);
+ auto item = gtl::FindPtrOrNull(sessions_, session_handle);
+ if (item == nullptr) {
+ return errors::NotFound("Session ", session_handle, " is not found.");
+ }
+ *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name);
+ if (*kernel != nullptr) {
+ return Status::OK();
+ }
+ }
+ Status s = create_fn(kernel);
+ if (!s.ok()) {
+ LOG(ERROR) << "Create kernel failed: " << s;
+ return s;
+ }
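+  // create_fn() ran without the lock held, so another thread may have
+  // created and cached a kernel for the same node in the meantime.
+  // Re-acquire the lock and keep whichever kernel was inserted first.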
+ {
+ mutex_lock l(mu_);
+ auto item = gtl::FindPtrOrNull(sessions_, session_handle);
+ if (item == nullptr) {
+ return errors::NotFound("Session ", session_handle, " is not found.");
+ }
+ OpKernel** p_kernel = &(item->name_kernel[node_name]);
+ if (*p_kernel == nullptr) {
+ *p_kernel = *kernel; // Inserts 'kernel' in the map.
+ } else {
+ delete *kernel;
+ *kernel = *p_kernel;
+ }
+ }
+ return Status::OK();
+}
+
+void OpSegment::AddHold(const string& session_handle) {
+ mutex_lock l(mu_);
+ Item** item = &sessions_[session_handle];
+ if (*item == nullptr) {
+ *item = new Item; // num_holds == 1
+ } else {
+ ++((*item)->num_holds);
+ }
+}
+
+void OpSegment::RemoveHold(const string& session_handle) {
+ Item* item = nullptr;
+ {
+ mutex_lock l(mu_);
+ auto siter = sessions_.find(session_handle);
+ if (siter == sessions_.end()) {
+ VLOG(1) << "Session " << session_handle << " is not found.";
+ return;
+ }
+ item = siter->second;
+ if (--(item->num_holds) > 0) {
+ return;
+ } else {
+ sessions_.erase(siter);
+ }
+ }
+ delete item;
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h
new file mode 100644
index 0000000000..55249d2a38
--- /dev/null
+++ b/tensorflow/core/framework/op_segment.h
@@ -0,0 +1,67 @@
+#ifndef TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_
+#define TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// OpSegment keeps track of OpKernels registered for sessions running
+// on a device.
+//
+// The implementation maintains a two-level map. The first level maps
+// a session handle to that session's map of registered OpKernels. The
+// second level maps node names to instantiated OpKernel objects.
+//
+// Each second-level map is reference-counted: the caller can call
+// AddHold to obtain a reference on all kernels of a session, ensuring
+// those kernels stay alive until a corresponding RemoveHold is called
+// on the same session.
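+//
+// For example (an illustrative sketch; "create_fn" is assumed to wrap
+// CreateOpKernel() for the node in question):
+//
+//   OpSegment segment;
+//   segment.AddHold("session0");
+//   OpKernel* kernel = nullptr;
+//   Status s = segment.FindOrCreate("session0", "node0", &kernel, create_fn);
+//   // ... use *kernel; OpSegment retains ownership ...
+//   segment.RemoveHold("session0");  // may delete the cached kernels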
+class OpSegment {
+ public:
+ OpSegment();
+ ~OpSegment();
+
+ // A hold can be placed on a session, preventing all its kernels
+ // from being deleted.
+ void AddHold(const string& session_handle);
+ void RemoveHold(const string& session_handle);
+
+ // If the kernel for "node_name" has been created in the
+ // "session_handle", returns the existing op kernel in "*kernel".
+  // Otherwise, creates the kernel by calling create_fn(), caches it,
+  // and returns it in "*kernel". If create_fn() fails, returns the
+ // error.
+ //
+ // OpSegment keeps the ownership of the returned "*kernel".
+ typedef std::function<Status(OpKernel**)> CreateKernelFn;
+ Status FindOrCreate(const string& session_handle, const string& node_name,
+ OpKernel** kernel, CreateKernelFn create_fn);
+
+ private:
+ // op name -> OpKernel
+ typedef std::unordered_map<string, OpKernel*> KernelMap;
+ struct Item {
+ int num_holds = 1; // Num of holds put on the session.
+ KernelMap name_kernel; // op name -> kernel.
+ ~Item();
+ };
+
+ // session handle -> item.
+ // Session handles are produced by strings::FpToString()
+ typedef std::unordered_map<string, Item*> SessionMap;
+
+ mutable mutex mu_;
+ SessionMap sessions_ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OpSegment);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_
diff --git a/tensorflow/core/framework/op_segment_test.cc b/tensorflow/core/framework/op_segment_test.cc
new file mode 100644
index 0000000000..6297718df8
--- /dev/null
+++ b/tensorflow/core/framework/op_segment_test.cc
@@ -0,0 +1,142 @@
+#include "tensorflow/core/framework/op_segment.h"
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+
+class OpSegmentTest : public ::testing::Test {
+ protected:
+ DeviceBase device_;
+ std::vector<NodeDef> int32_nodedefs_;
+ std::vector<NodeDef> float_nodedefs_;
+
+ OpSegmentTest() : device_(Env::Default()) {
+ RequireDefaultOps();
+ for (int i = 0; i < 10; ++i) {
+ NodeDef def;
+ TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
+ .Input("x", 0, DT_INT32)
+ .Input("y", 0, DT_INT32)
+ .Finalize(&def));
+ int32_nodedefs_.push_back(def);
+ TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
+ .Input("x", 0, DT_FLOAT)
+ .Input("y", 0, DT_FLOAT)
+ .Finalize(&def));
+ float_nodedefs_.push_back(def);
+ }
+ }
+
+ void ValidateOpAndTypes(OpKernel* op, const NodeDef& expected, DataType dt) {
+ ASSERT_NE(op, nullptr);
+ EXPECT_EQ(expected.DebugString(), op->def().DebugString());
+ EXPECT_EQ(2, op->num_inputs());
+ EXPECT_EQ(dt, op->input_type(0));
+ EXPECT_EQ(dt, op->input_type(1));
+ EXPECT_EQ(1, op->num_outputs());
+ EXPECT_EQ(dt, op->output_type(0));
+ }
+
+ OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) {
+ return [this, ndef](OpKernel** kernel) {
+ Status s;
+ auto created =
+ CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(), *ndef, &s);
+ if (s.ok()) {
+ *kernel = created.release();
+ }
+ return s;
+ };
+ }
+};
+
+TEST_F(OpSegmentTest, Basic) {
+ OpSegment opseg;
+ OpKernel* op;
+
+ opseg.AddHold("A");
+ opseg.AddHold("B");
+ for (int i = 0; i < 10; ++i) {
+ // Register in session A.
+ auto* ndef = &float_nodedefs_[i];
+ EXPECT_OK(opseg.FindOrCreate("A", ndef->name(), &op, GetFn(ndef)));
+ ValidateOpAndTypes(op, *ndef, DT_FLOAT);
+
+ // Register in session B.
+ ndef = &int32_nodedefs_[i];
+ EXPECT_OK(opseg.FindOrCreate("B", ndef->name(), &op, GetFn(ndef)));
+ ValidateOpAndTypes(op, *ndef, DT_INT32);
+ }
+
+ auto reterr = [](OpKernel** kernel) {
+ return errors::Internal("Should not be called");
+ };
+ for (int i = 0; i < 10; ++i) {
+ // Lookup op in session A.
+ EXPECT_OK(opseg.FindOrCreate("A", strings::StrCat("op", i), &op, reterr));
+ ValidateOpAndTypes(op, float_nodedefs_[i], DT_FLOAT);
+
+ // Lookup op in session B.
+ EXPECT_OK(opseg.FindOrCreate("B", strings::StrCat("op", i), &op, reterr));
+ ValidateOpAndTypes(op, int32_nodedefs_[i], DT_INT32);
+ }
+
+ opseg.RemoveHold("A");
+ opseg.RemoveHold("B");
+}
+
+TEST_F(OpSegmentTest, SessionNotFound) {
+ OpSegment opseg;
+ OpKernel* op;
+ NodeDef def = float_nodedefs_[0];
+ Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def));
+ EXPECT_TRUE(errors::IsNotFound(s)) << s;
+}
+
+TEST_F(OpSegmentTest, CreateFailure) {
+ OpSegment opseg;
+ OpKernel* op;
+ NodeDef def = float_nodedefs_[0];
+ def.set_op("nonexistop");
+ opseg.AddHold("A");
+ Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def));
+ EXPECT_TRUE(errors::IsNotFound(s)) << s;
+ opseg.RemoveHold("A");
+}
+
+TEST_F(OpSegmentTest, AddRemoveHolds) {
+ OpSegment opseg;
+ OpKernel* op;
+ const auto& ndef = int32_nodedefs_[0];
+
+  // Removing a hold on an unknown session is a no-op.
+ opseg.RemoveHold("null");
+
+  // Thread1 registers the op and wants to ensure it stays alive.
+ opseg.AddHold("foo");
+ EXPECT_OK(opseg.FindOrCreate("foo", ndef.name(), &op, GetFn(&ndef)));
+
+  // Thread2 starts some execution that needs "op" to be alive.
+ opseg.AddHold("foo");
+
+ // Thread1 clears session "foo". E.g., a master sends CleanupGraph
+ // before an execution finishes.
+ opseg.RemoveHold("foo");
+
+ // Thread2 should still be able to access "op".
+ ValidateOpAndTypes(op, ndef, DT_INT32);
+
+  // Thread2 then removes its hold on "foo".
+ opseg.RemoveHold("foo");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/queue_interface.h b/tensorflow/core/framework/queue_interface.h
new file mode 100644
index 0000000000..a765c211cb
--- /dev/null
+++ b/tensorflow/core/framework/queue_interface.h
@@ -0,0 +1,77 @@
+#ifndef TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
+#define TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+// All implementations must be thread-safe.
+class QueueInterface : public ResourceBase {
+ public:
+ typedef std::vector<Tensor> Tuple;
+ typedef AsyncOpKernel::DoneCallback DoneCallback;
+ typedef std::function<void(const Tuple&)> CallbackWithTuple;
+
+ virtual Status ValidateTuple(const Tuple& tuple) = 0;
+ virtual Status ValidateManyTuple(const Tuple& tuple) = 0;
+
+  // Stashes a function object for future execution that will eventually
+ // enqueue the tuple of tensors into the queue, and returns immediately. The
+ // function object is guaranteed to call 'callback'.
+ virtual void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) = 0;
+
+ // Same as above, but the component tensors are sliced along the 0th dimension
+ // to make multiple queue-element components.
+ virtual void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) = 0;
+
+  // Stashes a function object for future execution that will eventually
+ // dequeue an element from the queue and call 'callback' with that tuple
+ // element as argument.
+ virtual void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) = 0;
+
+ // Same as above, but the stashed function object will attempt to dequeue
+ // num_elements items.
+ virtual void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ CallbackWithTuple callback) = 0;
+
+ // Signals that no more elements will be enqueued, and optionally
+ // cancels pending Enqueue(Many) operations.
+ //
+ // After calling this function, subsequent calls to Enqueue(Many)
+ // will fail. If `cancel_pending_enqueues` is true, all pending
+ // calls to Enqueue(Many) will fail as well.
+ //
+ // After calling this function, all current and subsequent calls to
+ // Dequeue(Many) will fail instead of blocking (though they may
+ // succeed if they can be satisfied by the elements in the queue at
+ // the time it was closed).
+ virtual void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+ DoneCallback callback) = 0;
+
+ // Assuming *this represents a shared queue, verify that it matches
+ // another instantiation indicated by node_def.
+ virtual Status MatchesNodeDef(const NodeDef& node_def) = 0;
+
+ // Returns the number of elements in the queue.
+ virtual int32 size() = 0;
+
+ virtual const DataTypeVector& component_dtypes() const = 0;
+
+ string DebugString() override { return "A queue"; }
+
+ protected:
+ virtual ~QueueInterface() {}
+};
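+
+// Example (an illustrative sketch; "queue" is any QueueInterface
+// implementation, and "ctx" and "done" come from an AsyncOpKernel's
+// ComputeAsync): dequeue one element and forward its components as the
+// kernel's outputs.
+//
+//   queue->TryDequeue(ctx, [ctx, done](const QueueInterface::Tuple& tuple) {
+//     if (ctx->status().ok()) {
+//       for (size_t i = 0; i < tuple.size(); ++i) {
+//         ctx->set_output(i, tuple[i]);
+//       }
+//     }
+//     done();
+//   });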
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
diff --git a/tensorflow/core/framework/reader_interface.h b/tensorflow/core/framework/reader_interface.h
new file mode 100644
index 0000000000..b307c37f01
--- /dev/null
+++ b/tensorflow/core/framework/reader_interface.h
@@ -0,0 +1,66 @@
+#ifndef TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
+#define TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+class QueueInterface;
+class ReaderInterface;
+
+// Readers are the mechanism for reading records from files in
+// TensorFlow graphs. Each supported file format has a corresponding
+// ReaderInterface descendant and a corresponding Op & OpKernel
+// (implemented using ReaderOpKernel from reader_op_kernel.h).
+//
+// To use a Reader, you first encode "work" (some string, typically a
+// filename) in the Reader's "work queue". It then processes the
+// "work" (reading records from the file), to produce key/value
+// strings. The methods of this class are called by ReaderFoo ops,
+// so see ../ops/io_ops.cc for detailed descriptions.
+//
+// All descendants of this class must be thread-safe.
+//
+// See the design document here:
+// https://docs.google.com/document/d/1UAgZOoeehYr20TdzW2CoZ30V-aqQphU4SwKXsW7eJv4/edit#
+
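+// For example (an illustrative sketch; "reader" is any ReaderInterface
+// implementation, "queue" its work queue, and "context" the calling
+// kernel's OpKernelContext):
+//
+//   string key, value;
+//   reader->Read(queue, &key, &value, context);
+//   if (context->status().ok()) { /* consume key and value */ }
+//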
+// TODO(josh11b): Switch this to Async.
+class ReaderInterface : public ResourceBase {
+ public:
+ // Read a single record into *key / *value. May get more work from
+ // *queue if the current work is complete. Sets the status on
+  // *context with an OutOfRange Status if the current work is
+ // complete and the queue is done (closed and empty).
+ // This method may block.
+ virtual void Read(QueueInterface* queue, string* key, string* value,
+ OpKernelContext* context) = 0;
+
+ // Restore this reader to its newly-constructed state.
+ virtual Status Reset() = 0;
+
+ // Accessors
+ virtual int64 NumRecordsProduced() = 0;
+ virtual int64 NumWorkUnitsCompleted() = 0;
+
+ // -- Serialization/Restoration support --
+ // Not all readers will support saving and restoring state.
+ virtual Status SerializeState(string* state) = 0;
+ // Note: Must Reset on error.
+ virtual Status RestoreState(const string& state) = 0;
+
+ string DebugString() override { return "a reader"; }
+
+ protected:
+ virtual ~ReaderInterface() {}
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
diff --git a/tensorflow/core/framework/reader_op_kernel.cc b/tensorflow/core/framework/reader_op_kernel.cc
new file mode 100644
index 0000000000..719f27d94b
--- /dev/null
+++ b/tensorflow/core/framework/reader_op_kernel.cc
@@ -0,0 +1,39 @@
+#include "tensorflow/core/framework/reader_op_kernel.h"
+
+namespace tensorflow {
+
+ReaderOpKernel::ReaderOpKernel(OpKernelConstruction* context)
+ : OpKernel(context), have_handle_(false) {
+ OP_REQUIRES_OK(context, context->allocate_persistent(
+ tensorflow::DT_STRING,
+ tensorflow::TensorShape({2}), &handle_, nullptr));
+}
+
+ReaderOpKernel::~ReaderOpKernel() {
+ if (have_handle_ && cinfo_.resource_is_private_to_kernel()) {
+ TF_CHECK_OK(cinfo_.resource_manager()->Delete<ReaderInterface>(
+ cinfo_.container(), cinfo_.name()));
+ }
+}
+
+void ReaderOpKernel::Compute(OpKernelContext* ctx) {
+ mutex_lock l(mu_);
+ if (!have_handle_) {
+ OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), false));
+ ReaderInterface* reader;
+ OP_REQUIRES_OK(ctx,
+ cinfo_.resource_manager()->LookupOrCreate<ReaderInterface>(
+ cinfo_.container(), cinfo_.name(), &reader,
+ [this](ReaderInterface** ret) {
+ *ret = factory_();
+ return Status::OK();
+ }));
+ auto h = handle_.AccessTensor(ctx)->flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ have_handle_ = true;
+ }
+ ctx->set_output_ref(0, &mu_, handle_.AccessTensor(ctx));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/reader_op_kernel.h b/tensorflow/core/framework/reader_op_kernel.h
new file mode 100644
index 0000000000..8e5cc50c9b
--- /dev/null
+++ b/tensorflow/core/framework/reader_op_kernel.h
@@ -0,0 +1,42 @@
+#ifndef TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
+#define TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
+
+#include <functional>
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/reader_interface.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Implementation for ops providing a Reader.
+class ReaderOpKernel : public OpKernel {
+ public:
+ explicit ReaderOpKernel(OpKernelConstruction* context);
+ ~ReaderOpKernel() override;
+
+ void Compute(OpKernelContext* context) override;
+
+ // Must be called by descendants before the first call to Compute()
+ // (typically called during construction). factory must return a
+ // ReaderInterface descendant allocated with new that ReaderOpKernel
+ // will take ownership of.
+ void SetReaderFactory(std::function<ReaderInterface*()> factory) {
+ mutex_lock l(mu_);
+ DCHECK(!have_handle_);
+ factory_ = factory;
+ }
+
+ private:
+ mutex mu_;
+ bool have_handle_ GUARDED_BY(mu_);
+ PersistentTensor handle_ GUARDED_BY(mu_);
+ ContainerInfo cinfo_;
+ std::function<ReaderInterface*()> factory_;
+};
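+
+// Example subclass (an illustrative sketch; "MyReader" is a
+// hypothetical ReaderInterface descendant):
+//
+//   class MyReaderOp : public ReaderOpKernel {
+//    public:
+//     explicit MyReaderOp(OpKernelConstruction* context)
+//         : ReaderOpKernel(context) {
+//       SetReaderFactory([]() { return new MyReader; });
+//     }
+//   };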
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
new file mode 100644
index 0000000000..18473aea2e
--- /dev/null
+++ b/tensorflow/core/framework/register_types.h
@@ -0,0 +1,90 @@
+#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
+#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
+// This file is used by cuda code and must remain compilable by nvcc.
+
+#include "tensorflow/core/platform/port.h"
+
+// Macros to apply another macro to lists of supported types. If you change
+// the lists of types, please also update the list in types.cc.
+//
+// See example uses of these macros in core/ops.
+//
+//
+// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple
+// times by passing each invocation a data type supported by TensorFlow.
+//
+// The different variations pass different subsets of the types.
+// TF_CALL_ALL_TYPES(m) applies "m" to all types supported by TensorFlow.
+// The set of types depends on the compilation platform.
+//
+// This can be used to register a different template instantiation of
+// an OpKernel for different signatures, e.g.:
+/*
+ #define REGISTER_PARTITION(type) \
+ REGISTER_TF_OP_KERNEL("partition", DEVICE_CPU, #type ", int32", \
+ PartitionOp<type>);
+ TF_CALL_ALL_TYPES(REGISTER_PARTITION)
+ #undef REGISTER_PARTITION
+*/
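+//
+// Each invocation expands to one "m(type);" per listed type; e.g. on
+// non-Android builds TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARTITION)
+// expands to "REGISTER_PARTITION(float); REGISTER_PARTITION(double);".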
+
+#ifndef __ANDROID__
+
+// Call "m" for all number types that support the comparison operations "<" and
+// ">".
+#define TF_CALL_REAL_NUMBER_TYPES(m) \
+ m(float); \
+ m(double); \
+ m(int64); \
+ m(int32); \
+ m(uint8); \
+ m(int16); \
+ m(int8)
+
+#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
+ m(float); \
+ m(double); \
+ m(int64); \
+ m(uint8); \
+ m(int16); \
+ m(int8)
+
+// Call "m" for all number types, including complex64.
+#define TF_CALL_NUMBER_TYPES(m) \
+ TF_CALL_REAL_NUMBER_TYPES(m); \
+ m(complex64)
+
+#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
+ TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m); \
+ m(complex64)
+
+// Call "m" on all types.
+#define TF_CALL_ALL_TYPES(m) \
+ TF_CALL_NUMBER_TYPES(m); \
+ m(bool); \
+ m(string)
+
+// Call "m" on all types supported on GPU.
+#define TF_CALL_GPU_NUMBER_TYPES(m) \
+ m(float); \
+ m(double)
+
+#else // __ANDROID__
+
+#define TF_CALL_REAL_NUMBER_TYPES(m) \
+ m(float); \
+ m(int32)
+
+#define TF_CALL_NUMBER_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+
+#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) m(float)
+
+#define TF_CALL_NUMBER_TYPES_NO_INT32(m) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)
+
+#define TF_CALL_ALL_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+
+// Maybe we could put an empty macro here for Android?
+#define TF_CALL_GPU_NUMBER_TYPES(m) m(float)
+
+#endif // __ANDROID__
+
+#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc
new file mode 100644
index 0000000000..7f551ea65f
--- /dev/null
+++ b/tensorflow/core/framework/rendezvous.cc
@@ -0,0 +1,263 @@
+#include "tensorflow/core/framework/rendezvous.h"
+
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+/* static */
+string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation,
+ const string& dst_device, const string& name,
+ const FrameAndIter& frame_iter) {
+ // NOTE: ';' is not used in the device name's job name.
+ //
+ // We include both sender and receiver in the key to facilitate
+ // debugging. For correctness, we only need to encode the receiver.
+ //
+ // "src_incarnation" is used to distinguish a worker when it
+ // restarts.
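+  //
+  // An example key (illustrative; the second field is the FpToString()
+  // rendering of src_incarnation):
+  //   /job:w/replica:0/task:0/cpu:0;0000000000000001;/job:w/replica:0/task:0/gpu:0;edge_1;0:0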
+ return strings::StrCat(src_device, ";", strings::FpToString(src_incarnation),
+ ";", dst_device, ";", name, ";", frame_iter.frame_id,
+ ":", frame_iter.iter_id);
+}
+
+/* static */
+Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
+ // TODO(zhifengc): This code is not fast enough.
+ std::vector<string> parts = str_util::Split(key, ';');
+ if (parts.size() == 5 &&
+ DeviceNameUtils::ParseFullName(parts[0], &out->src) &&
+ strings::StringToFp(parts[1], &out->src_incarnation) &&
+ DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
+ !parts[3].empty()) {
+ out->src_device = parts[0];
+ out->dst_device = parts[2];
+ out->edge_name = parts[3];
+ return Status::OK();
+ }
+ return errors::InvalidArgument("Invalid rendezvous key: ", key);
+}
+
+Rendezvous::~Rendezvous() {}
+
+Status Rendezvous::Recv(const string& key, const Args& recv_args, Tensor* val,
+ bool* is_dead) {
+ Status ret;
+ Notification n;
+ RecvAsync(key, recv_args,
+ [&ret, &n, val, is_dead](const Status& s, const Args& send_args,
+ const Args& recv_args, const Tensor& v,
+ const bool dead) {
+ ret = s;
+ *val = v;
+ *is_dead = dead;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return ret;
+}
+
+class LocalRendezvousImpl : public Rendezvous {
+ public:
+ explicit LocalRendezvousImpl(bool tolerate_dup_recv)
+ : tolerate_dup_recv_(tolerate_dup_recv) {}
+
+ Status Send(const string& key, const Args& send_args, const Tensor& val,
+ const bool is_dead) override {
+ VLOG(2) << "Send " << this << " " << key;
+ DoneCallback waiter = nullptr;
+ Args recv_args;
+ {
+ mutex_lock l(mu_);
+ if (!status_.ok()) {
+ return status_;
+ }
+ Item* item = nullptr;
+ Table::iterator iter = table_.find(key);
+ if (iter == table_.end()) {
+        // There is no waiter for this message. Insert the message
+        // into the table. A future waiter will pick it up when it
+        // arrives.
+ item = new Item;
+ item->waiter = nullptr;
+ item->value = val;
+ item->is_dead = is_dead;
+ if (send_args.device_context) {
+ send_args.device_context->Ref();
+ item->send_dev_context = send_args.device_context;
+ }
+ item->recv_dev_context = nullptr;
+
+ // The allocator attributes of item->value.
+ item->send_alloc_attrs = send_args.alloc_attrs;
+
+ CHECK(table_.insert({key, item}).second);
+ return Status::OK();
+ } else {
+ item = iter->second;
+ if (item->waiter == nullptr) {
+          // The item under this key is a previously sent message with
+          // no waiter attached, so this is a duplicated send, which is
+          // not allowed.
+ return errors::Aborted("Duplicated send: ", key);
+ }
+ // Mark item as complete.
+ item->has_been_recvd = true;
+ waiter = item->waiter;
+ item->waiter = nullptr;
+ // The ref on recv_dev_context transfers below.
+ recv_args.device_context = item->recv_dev_context;
+ recv_args.alloc_attrs = item->recv_alloc_attrs;
+ item->recv_dev_context = nullptr;
+ if (tolerate_dup_recv_) {
+ item->value = val;
+ item->is_dead = is_dead;
+ if (send_args.device_context) {
+ send_args.device_context->Ref();
+ item->send_dev_context = send_args.device_context;
+ }
+ item->send_alloc_attrs = send_args.alloc_attrs;
+ }
+ }
+ } // mutex
+ // Notify the waiter by invoking its done closure, outside scope
+ // of the table lock.
+ waiter(Status::OK(), send_args, recv_args, val, is_dead);
+ if (recv_args.device_context) recv_args.device_context->Unref();
+ return Status::OK();
+ }
+
+ void RecvAsync(const string& key, const Args& recv_args,
+ DoneCallback done) override {
+ VLOG(2) << "Recv " << this << " " << key;
+ mu_.lock();
+ if (!status_.ok()) {
+ // Rendezvous has been aborted.
+ Status s = status_;
+ mu_.unlock();
+ done(s, Args(), recv_args, Tensor(), false);
+ return;
+ }
+ Table::iterator iter = table_.find(key);
+ if (iter != table_.end()) {
+ Item* item = iter->second;
+ if (item->has_been_recvd && !tolerate_dup_recv_) {
+ mu_.unlock();
+ done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
+ Tensor(), false);
+ } else if (item->waiter == nullptr || tolerate_dup_recv_) {
+ // A message has already arrived and is stored in the table
+ // under this key. Consumes the message and invokes the done
+ // closure.
+ Tensor v = item->value;
+ if (!tolerate_dup_recv_) {
+ item->value = Tensor();
+ }
+ item->has_been_recvd = true;
+ // Before dropping the table lock, capture the item values.
+ // DeviceContext is only non-null for non-CPU devices.
+ // If we capture the send_dev_context, we need to hold a ref on
+ // it. Our caller will have a ref on the recv_dev_context,
+ // which is not in our table.
+ DeviceContext* send_dev_context = item->send_dev_context;
+ if (send_dev_context) send_dev_context->Ref();
+        bool is_dead = item->is_dead;
+        Args send_args;
+        send_args.device_context = send_dev_context;
+        send_args.alloc_attrs = item->send_alloc_attrs;
+        mu_.unlock();
+ done(Status::OK(), send_args, recv_args, v, is_dead);
+ if (send_dev_context) send_dev_context->Unref();
+ } else {
+ // Already have a waiter in the waiters table under this key,
+ // which should not happen.
+ mu_.unlock();
+ done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
+ Tensor(), false);
+ }
+ return;
+ }
+ // Waiting for a message that has not arrived yet. Insert into the
+ // waiting table. The done closure will be invoked when the
+ // message arrives.
+ Item* item = new Item;
+ item->waiter = done;
+ if (recv_args.device_context) {
+ item->recv_dev_context = recv_args.device_context;
+ item->recv_alloc_attrs = recv_args.alloc_attrs;
+ item->recv_dev_context->Ref();
+ }
+ CHECK(table_.insert({key, item}).second);
+ mu_.unlock();
+ return;
+ }
+
+ void StartAbort(const Status& status) override {
+ CHECK(!status.ok());
+ std::vector<Item*> items;
+ {
+ mutex_lock l(mu_);
+ if (!status_.ok()) return;
+ status_ = status;
+ items.reserve(table_.size());
+ for (const auto& p : table_) items.push_back(p.second);
+ table_.clear();
+ }
+ for (Item* item : items) {
+ if (item->waiter != nullptr) {
+ item->waiter(status, Args(), Args(), Tensor(), false);
+ }
+ delete item;
+ }
+ }
+
+ private:
+ typedef LocalRendezvousImpl ME;
+ const bool tolerate_dup_recv_;
+
+ struct Item {
+ DoneCallback waiter = nullptr;
+ Tensor value;
+ bool is_dead = false;
+ bool has_been_recvd = false;
+ DeviceContext* send_dev_context = nullptr;
+ DeviceContext* recv_dev_context = nullptr;
+ AllocatorAttributes send_alloc_attrs;
+ AllocatorAttributes recv_alloc_attrs;
+
+ ~Item() {
+ if (send_dev_context) {
+ send_dev_context->Unref();
+ }
+ if (recv_dev_context) {
+ recv_dev_context->Unref();
+ }
+ }
+ };
+ typedef std::unordered_map<string, Item*> Table;
+
+ // TODO(zhifengc): shard table_.
+ mutex mu_;
+ Table table_ GUARDED_BY(mu_);
+ Status status_;
+
+ ~LocalRendezvousImpl() override {
+ for (auto i : table_) {
+ delete i.second;
+ }
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl);
+};
+
+Rendezvous* NewLocalRendezvous(bool tolerate_dup_recv) {
+ return new LocalRendezvousImpl(tolerate_dup_recv);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h
new file mode 100644
index 0000000000..94fbfb2523
--- /dev/null
+++ b/tensorflow/core/framework/rendezvous.h
@@ -0,0 +1,102 @@
+#ifndef TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
+#define TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
+
+#include <string>
+
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+// A Rendezvous is an abstraction for passing a Tensor
+// from a producer to a consumer, where the consumer may safely
+// request the Tensor before or after it has been produced. A
+// producer never blocks when using a Rendezvous. A consumer has the
+// choice of making a blocking call or providing a callback: in either
+// case, the consumer receives the Tensor as soon as it is available.
+//
+// A Rendezvous key encodes a single <producer, consumer> pair. It is
+// an error to call Send() or Recv*() more than once with the same
+// key.
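+//
+// For example (an illustrative sketch; "r" is any Rendezvous
+// implementation, e.g. one returned by NewLocalRendezvous()):
+//
+//   string key = Rendezvous::CreateKey("/job:w/replica:0/task:0/cpu:0", 1,
+//                                      "/job:w/replica:0/task:0/gpu:0",
+//                                      "tensor0", FrameAndIter(0, 0));
+//   TF_CHECK_OK(r->Send(key, Rendezvous::Args(), tensor, false));
+//   Tensor val;
+//   bool is_dead = false;
+//   TF_CHECK_OK(r->Recv(key, Rendezvous::Args(), &val, &is_dead));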
+class Rendezvous : public core::RefCounted {
+ public:
+ struct Args {
+ DeviceContext* device_context = nullptr;
+ AllocatorAttributes alloc_attrs;
+ };
+
+ // Constructs a rendezvous key for the tensor of "name" sent from
+ // "src_device" to "dst_device". The tensor is generated in the frame
+ // and iteration specified by "frame_iter".
+ static string CreateKey(const string& src_device, uint64 src_incarnation,
+ const string& dst_device, const string& name,
+ const FrameAndIter& frame_iter);
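+ //
+ // For example (this matches the unit test in rendezvous_test.cc; the
+ // incarnation is rendered as 16 hex digits):
+ //   CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
+ //             "/job:mnist/replica:1/task:2/GPU:0", "var0",
+ //             FrameAndIter(0, 0))
+ //   returns "/job:mnist/replica:1/task:2/CPU:0;0000000000001ed2;"
+ //           "/job:mnist/replica:1/task:2/GPU:0;var0;0:0".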
+
+ // Parses a key constructed by CreateKey, filling in the src/dst
+ // device names and their parsed structures.
+ struct ParsedKey {
+ string src_device;
+ DeviceNameUtils::ParsedName src;
+ uint64 src_incarnation = 0;
+ string dst_device;
+ DeviceNameUtils::ParsedName dst;
+ string edge_name;
+ };
+ static Status ParseKey(const string& key, ParsedKey* out);
+
+ // The caller is a tensor producer and it sends a message (a tensor
+ // "val" and a bool "is_dead") under the given "key".
+ //
+ // {val, is_dead} is bundled as a message sent and received.
+ // Typically, is_dead is set by some control flow nodes
+ // (e.g., a not-taken branch). args is passed by Send to the
+ // Recv function to communicate any information that the Recv
+ // function might need. This is typically only necessary for
+ // Send/Recv on the same worker.
+ //
+ // Send() never blocks.
+ virtual Status Send(const string& key, const Args& args, const Tensor& val,
+ const bool is_dead) = 0;
+
+ // Callback provided by a tensor consumer waiting on the rendezvous.
+ // It will be invoked when the tensor is available, or when a non-OK
+ // status arises in the production of that tensor. It also gets
+ // two Rendezvous::Args, one provided by the sender, the other by the
+ // receiver, which may be needed when a non-CPU device is in use
+ // by either side.
+ typedef std::function<void(const Status&, const Args&, const Args&,
+ const Tensor&, const bool)> DoneCallback;
+
+ virtual void RecvAsync(const string& key, const Args& args,
+ DoneCallback done) = 0;
+
+ // Synchronous wrapper for RecvAsync.
+ Status Recv(const string& key, const Args& args, Tensor* val, bool* is_dead);
+
+ // Aborts all pending and future Send/Recv with the given "status".
+ //
+ // StartAbort() does not wait for ongoing calls to finish.
+ // REQUIRES: !status.ok()
+ virtual void StartAbort(const Status& status) = 0;
+
+ protected:
+ ~Rendezvous() override;
+};
+
+// Returns a Rendezvous instance that is limited to use only by
+// producers and consumers in the local process. The caller assumes
+// ownership of one Ref() on the returned object.
+//
+// If "tolerate_dup_recv" is true, then the Rendezvous will retain
+// already Recv'd values and make them available to duplicate Recv
+// calls. This may be useful if the RPC layer is not reliable, but
+// comes at the cost of higher memory consumption.
+Rendezvous* NewLocalRendezvous(bool tolerate_dup_recv = false);
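+
+// A minimal local usage sketch (mirroring rendezvous_test.cc;
+// "some_tensor" stands for any Tensor):
+//
+//   Rendezvous* rendez = NewLocalRendezvous();
+//   Rendezvous::Args args;
+//   TF_CHECK_OK(rendez->Send("key", args, some_tensor, false /* is_dead */));
+//   Tensor val;
+//   bool is_dead = false;
+//   TF_CHECK_OK(rendez->Recv("key", args, &val, &is_dead));
+//   rendez->Unref();  // Release the ref owned by the caller.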
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc
new file mode 100644
index 0000000000..32011a468f
--- /dev/null
+++ b/tensorflow/core/framework/rendezvous_test.cc
@@ -0,0 +1,314 @@
+#include "tensorflow/core/framework/rendezvous.h"
+
+#include <gtest/gtest.h>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+TEST(RendezvousTest, Key) {
+ const string key = Rendezvous::CreateKey(
+ "/job:mnist/replica:1/task:2/CPU:0", 7890,
+ "/job:mnist/replica:1/task:2/GPU:0", "var0", FrameAndIter(0, 0));
+ EXPECT_EQ(key,
+ "/job:mnist/replica:1/task:2/CPU:0;"
+ "0000000000001ed2;" // 7890 = 0x1ed2
+ "/job:mnist/replica:1/task:2/GPU:0;"
+ "var0;"
+ "0:0");
+ Rendezvous::ParsedKey parsed;
+ EXPECT_OK(Rendezvous::ParseKey(key, &parsed));
+ EXPECT_EQ(parsed.src_device, "/job:mnist/replica:1/task:2/CPU:0");
+ EXPECT_EQ(parsed.src_incarnation, 7890);
+ EXPECT_EQ(parsed.src.type, "CPU");
+ EXPECT_EQ(parsed.dst_device, "/job:mnist/replica:1/task:2/GPU:0");
+ EXPECT_EQ(parsed.dst.type, "GPU");
+
+ EXPECT_FALSE(Rendezvous::ParseKey("foo;bar;baz", &parsed).ok());
+ EXPECT_FALSE(Rendezvous::ParseKey("/job:mnist/replica:1/task:2/CPU:0;"
+ "/job:mnist/replica:1/task:2/GPU:0;",
+ &parsed)
+ .ok());
+ EXPECT_FALSE(
+ Rendezvous::ParseKey(strings::StrCat(key, ";", key), &parsed).ok());
+}
+
+class LocalRendezvousTest : public ::testing::Test {
+ public:
+ LocalRendezvousTest()
+ : threads_(new thread::ThreadPool(Env::Default(), "test", 16)) {
+ rendez_ = NewLocalRendezvous();
+ }
+
+ ~LocalRendezvousTest() override {
+ rendez_->Unref();
+ delete threads_;
+ }
+
+ void SchedClosure(std::function<void()> fn) { threads_->Schedule(fn); }
+
+ Rendezvous* rendez_;
+
+ private:
+ thread::ThreadPool* threads_;
+};
+
+// string -> Tensor<string>
+Tensor V(const string& content) {
+ Tensor tensor(DT_STRING, TensorShape({}));
+ tensor.scalar<string>()() = content;
+ return tensor;
+}
+
+// Tensor<string> -> string
+string V(const Tensor& tensor) {
+ CHECK_EQ(tensor.dtype(), DT_STRING);
+ CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+ return tensor.scalar<string>()();
+}
+
+TEST_F(LocalRendezvousTest, SendRecv) {
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
+ EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, V("hello"), false)));
+ Tensor val(DT_STRING);
+ bool is_dead = false;
+ ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
+ EXPECT_EQ("hello", V(val));
+}
+
+TEST_F(LocalRendezvousTest, RecvSend) {
+ SchedClosure([this]() {
+ Env::Default()->SleepForMicroseconds(10000);
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
+ });
+ Tensor val(DT_STRING);
+ bool is_dead = false;
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
+ EXPECT_EQ("hello", V(val));
+}
+
+TEST_F(LocalRendezvousTest, DuplicateWaiterRecv) {
+ SchedClosure([this]() {
+ Tensor t(DT_STRING);
+ bool is_dead = false;
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
+ ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
+ });
+ Env::Default()->SleepForMicroseconds(1000000);
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
+ ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
+ ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
+ EXPECT_EQ("secret msg", V(val));
+}
+
+TEST_F(LocalRendezvousTest, DuplicateSerialRecv) {
+ SchedClosure([this]() {
+ Tensor t(DT_STRING);
+ bool is_dead = false;
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
+ ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
+ });
+ Env::Default()->SleepForMicroseconds(1000000);
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
+ ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
+ EXPECT_EQ("secret msg", V(val));
+ EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
+}
+
+// A simple structure that behaves a bit like a blocking counter. The
+// thread that decrements the counter to 0 calls done.Notify(), and the
+// main thread waits for done to be notified.
+struct BlockingState {
+ mutex lock;
+ int counter;
+ Notification done;
+};
+
+TEST_F(LocalRendezvousTest, RandomSendRecv) {
+ static const int N = 1000;
+ BlockingState state;
+ state.counter = N;
+ for (int i = 0; i < N; ++i) {
+ SchedClosure([this, i]() {
+ random::PhiloxRandom philox(testing::RandomSeed() + i, 17);
+ random::SimplePhilox rnd(&philox);
+ Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000));
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Send(strings::StrCat(i), args, V(strings::StrCat(i)),
+ false));
+ });
+ SchedClosure([this, &state, i]() {
+ random::PhiloxRandom philox(testing::RandomSeed() + N + i, 17);
+ random::SimplePhilox rnd(&philox);
+ Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000));
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Recv(strings::StrCat(i), args, &val, &val_dead));
+ EXPECT_EQ(strings::StrCat(i), V(val));
+ bool done = false;
+ {
+ mutex_lock l(state.lock);
+ state.counter--;
+ if (state.counter == 0) {
+ done = true;
+ }
+ }
+ if (done) {
+ state.done.Notify();
+ }
+ });
+ }
+
+ state.done.WaitForNotification();
+}
+
+TEST_F(LocalRendezvousTest, RecvAbort) {
+ rendez_->Ref();
+ SchedClosure([this]() {
+ rendez_->StartAbort(errors::Aborted("")); // abort
+ rendez_->Unref();
+ });
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ Status status = rendez_->Recv("foo", args, &val, &val_dead);
+ EXPECT_TRUE(errors::IsAborted(status));
+}
+
+// Similar to RecvAbort. But this test case ensures the main thread
+// Recv() call happens after StartAbort().
+TEST_F(LocalRendezvousTest, RecvSleepAbort) {
+ rendez_->Ref();
+ SchedClosure([this]() {
+ Env::Default()->SleepForMicroseconds(1000000);
+ rendez_->StartAbort(errors::Aborted("")); // abort
+ rendez_->Unref();
+ });
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ Status status = rendez_->Recv("foo", args, &val, &val_dead);
+ EXPECT_TRUE(errors::IsAborted(status));
+}
+
+TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) {
+ rendez_->StartAbort(errors::Aborted(""));
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, val, val_dead)));
+ EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
+}
+
+class DummyDeviceContext : public DeviceContext {
+ public:
+ explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
+ ~DummyDeviceContext() override {}
+ int stream_id() const { return stream_id_; }
+
+ private:
+ const int stream_id_;
+};
+
+TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
+ Rendezvous::Args args;
+ args.device_context = new DummyDeviceContext(123);
+
+ ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
+
+ Notification n;
+ Rendezvous::Args args1;
+ args1.device_context = new DummyDeviceContext(1);
+ rendez_->RecvAsync("foo", args1, [&n](const Status& s,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& val, bool is_dead) {
+ CHECK_EQ(123,
+ dynamic_cast<const DummyDeviceContext*>(send_args.device_context)
+ ->stream_id());
+ n.Notify();
+ });
+
+ n.WaitForNotification();
+ args.device_context->Unref();
+ args1.device_context->Unref();
+}
+
+static void BM_SendRecv(int iters) {
+ Rendezvous* rendez = NewLocalRendezvous();
+ Tensor orig = V("val");
+ Tensor val(DT_STRING, TensorShape({}));
+ bool is_dead = false;
+ Rendezvous::Args args;
+ Status s;
+ if (iters > 0) {
+ while (iters--) {
+ s = rendez->Send("foo", args, orig, is_dead);
+ s = rendez->Recv("foo", args, &val, &is_dead);
+ }
+ CHECK_EQ(V(val), V(orig));
+ }
+ rendez->Unref();
+}
+BENCHMARK(BM_SendRecv);
+
+static void BM_RecvSend(int iters) {
+ thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);
+
+ // The main thread sends "foo" iters/2 times and receives "bar"
+ // iters/2 times. The other thread sends "bar" iters/2 times and
+ // receives "foo" iters/2 times.
+ Rendezvous* rendez = NewLocalRendezvous();
+ pool->Schedule([rendez, iters]() {
+ Tensor bar = V("bar");
+ Tensor foo(DT_STRING, TensorShape({}));
+ bool is_dead = false;
+ Rendezvous::Args args;
+ Status s;
+ for (int i = 0; i < iters / 2; ++i) {
+ s = rendez->Recv("foo", args, &foo, &is_dead);
+ s = rendez->Send("bar", args, bar, is_dead);
+ }
+ CHECK_EQ("foo", V(foo));
+ });
+ Tensor foo = V("foo");
+ Tensor bar(DT_STRING, TensorShape({}));
+ bool is_dead = false;
+ Rendezvous::Args args;
+ Status s;
+ for (int i = 0; i < iters / 2; ++i) {
+ s = rendez->Send("foo", args, foo, is_dead);
+ s = rendez->Recv("bar", args, &bar, &is_dead);
+ }
+ CHECK_EQ("bar", V(bar));
+ delete pool;
+}
+BENCHMARK(BM_RecvSend);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
new file mode 100644
index 0000000000..42326f068e
--- /dev/null
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -0,0 +1,146 @@
+#include "tensorflow/core/framework/resource_mgr.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+ResourceMgr::ResourceMgr() : default_container_("localhost") {}
+
+ResourceMgr::ResourceMgr(const string& default_container)
+ : default_container_(default_container) {}
+
+ResourceMgr::~ResourceMgr() { Clear(); }
+
+void ResourceMgr::Clear() {
+ mutex_lock l(mu_);
+ for (const auto& p : containers_) {
+ for (const auto& q : *p.second) {
+ q.second->Unref();
+ }
+ delete p.second;
+ }
+ containers_.clear();
+}
+
+Status ResourceMgr::DoCreate(const string& container, std::type_index type,
+ const string& name, ResourceBase* resource) {
+ {
+ mutex_lock l(mu_);
+ Container** b = &containers_[container];
+ if (*b == nullptr) {
+ *b = new Container;
+ }
+ if ((*b)->insert({{type, name}, resource}).second) {
+ return Status::OK();
+ }
+ }
+ resource->Unref();
+ return errors::AlreadyExists("Resource ", container, "/", name, "/",
+ type.name());
+}
+
+Status ResourceMgr::DoLookup(const string& container, std::type_index type,
+ const string& name,
+ ResourceBase** resource) const {
+ mutex_lock l(mu_);
+ const Container* b = gtl::FindPtrOrNull(containers_, container);
+ if (b == nullptr) {
+ return errors::NotFound("Container ", container, " does not exist.");
+ }
+ auto r = gtl::FindPtrOrNull(*b, {type, name});
+ if (r == nullptr) {
+ return errors::NotFound("Resource ", container, "/", name, "/", type.name(),
+ " does not exist.");
+ }
+ *resource = const_cast<ResourceBase*>(r);
+ (*resource)->Ref();
+ return Status::OK();
+}
+
+Status ResourceMgr::DoDelete(const string& container, std::type_index type,
+ const string& name) {
+ ResourceBase* base = nullptr;
+ {
+ mutex_lock l(mu_);
+ Container* b = gtl::FindPtrOrNull(containers_, container);
+ if (b == nullptr) {
+ return errors::NotFound("Container ", container, " does not exist.");
+ }
+ auto iter = b->find({type, name});
+ if (iter == b->end()) {
+ return errors::NotFound("Resource ", container, "/", name, "/",
+ type.name(), " does not exist.");
+ }
+ base = iter->second;
+ b->erase(iter);
+ }
+ CHECK(base != nullptr);
+ base->Unref();
+ return Status::OK();
+}
+
+Status ResourceMgr::Cleanup(const string& container) {
+ Container* b = nullptr;
+ {
+ mutex_lock l(mu_);
+ auto iter = containers_.find(container);
+ if (iter == containers_.end()) {
+ return errors::NotFound("Container ", container, " does not exist.");
+ }
+ b = iter->second;
+ containers_.erase(iter);
+ }
+ CHECK(b != nullptr);
+ for (const auto& p : *b) {
+ p.second->Unref();
+ }
+ delete b;
+ return Status::OK();
+}
+
+Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef,
+ bool use_node_name_as_default) {
+ CHECK(rmgr);
+ rmgr_ = rmgr;
+ string attr_container;
+ TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container));
+ static RE2 container_re("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
+ if (!attr_container.empty() &&
+ !RE2::FullMatch(attr_container, container_re)) {
+ return errors::InvalidArgument("container contains invalid characters: ",
+ attr_container);
+ }
+ string attr_shared_name;
+ TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &attr_shared_name));
+ if (!attr_shared_name.empty() && (attr_shared_name[0] == '_')) {
+ return errors::InvalidArgument("shared_name cannot start with '_':",
+ attr_shared_name);
+ }
+ if (!attr_container.empty()) {
+ container_ = attr_container;
+ } else {
+ container_ = rmgr_->default_container();
+ }
+ if (!attr_shared_name.empty()) {
+ name_ = attr_shared_name;
+ } else if (use_node_name_as_default) {
+ name_ = ndef.name();
+ } else {
+ resource_is_private_to_kernel_ = true;
+ static std::atomic<int64> counter(0);
+ name_ = strings::StrCat("_", counter.fetch_add(1), "_", ndef.name());
+ }
+ return Status::OK();
+}
+
+string ContainerInfo::DebugString() const {
+ return strings::StrCat("[", container(), ",", name(), ",",
+ resource_is_private_to_kernel() ? "private" : "public",
+ "]");
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
new file mode 100644
index 0000000000..65e859caf1
--- /dev/null
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -0,0 +1,280 @@
+#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
+#define TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
+
+#include <string>
+#include <typeindex>
+#include <typeinfo>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// A ResourceMgr instance keeps track of named and typed resources
+// grouped into containers.
+//
+// Each resource must be represented as a sub-class of ResourceBase,
+// which is reference counted explicitly. Each named resource is
+// registered with ResourceMgr within a named "container". At any
+// time, there is at most one instance of a resource given the container
+// name, the resource type and the resource name.
+//
+// All resources for a given container can be dropped by one call of
+// Cleanup().
+//
+// E.g.,
+// struct MyVar : public ResourceBase {
+// mutex mu;
+// Tensor val;
+// };
+//
+// ResourceMgr rm;
+//
+// // Create a var.
+// MyVar* my_var = new MyVar;
+// my_var->val = Tensor(DT_FLOAT, my_shape);
+// my_var->val.flat<float>().setZero();  // 0 initialized.
+// ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
+//
+// // += a variable.
+// MyVar* my_var = nullptr;
+// Status s = rm.Lookup("my_container", "my_name", &my_var);
+// if (s.ok()) {
+// my_var->val.flat<float>() += grad;
+// }
+// my_var->Unref(); // Or use ScopedUnref().
+// ctx->SetStatus(s);
+class ResourceBase : public core::RefCounted {
+ public:
+ // Returns a debug string for *this.
+ virtual string DebugString() = 0;
+};
+
+class ResourceMgr {
+ public:
+ ResourceMgr();
+ explicit ResourceMgr(const string& default_container);
+ ~ResourceMgr();
+
+ // Returns the default container name for *this.
+ const string& default_container() const { return default_container_; }
+
+ // Creates a resource "name" in the "container". The caller transfers
+ // ownership of one ref on "resource" to *this.
+ //
+ // REQUIRES: std::is_base_of<ResourceBase, T>
+ // REQUIRES: resource != nullptr.
+ template <typename T>
+ Status Create(const string& container, const string& name,
+ T* resource) TF_MUST_USE_RESULT;
+
+ // If "container" has a resource "name", returns it in "*resource" and
+ // the caller takes ownership of one ref on "*resource".
+ //
+ // REQUIRES: std::is_base_of<ResourceBase, T>
+ // REQUIRES: resource != nullptr
+ template <typename T>
+ Status Lookup(const string& container, const string& name,
+ T** resource) const TF_MUST_USE_RESULT;
+
+ // If "container" has a resource "name", returns it in
+ // "*resource". Otherwise, invokes creator() to create the resource.
+ // The caller takes ownership of one ref on "*resource".
+ //
+ // REQUIRES: std::is_base_of<ResourceBase, T>
+ // REQUIRES: resource != nullptr
+ template <typename T>
+ Status LookupOrCreate(const string& container, const string& name,
+ T** resource,
+ std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
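+ //
+ // Illustrative usage (adapted from resource_mgr_test.cc; "MyVar" is the
+ // example type sketched above):
+ //   MyVar* var = nullptr;
+ //   Status s = rm.LookupOrCreate<MyVar>(
+ //       "my_container", "my_name", &var,
+ //       [](MyVar** ret) { *ret = new MyVar; return Status::OK(); });
+ //   if (s.ok()) { /* use var, then */ var->Unref(); }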
+
+ // Deletes the resource "name" from the "container".
+ //
+ // REQUIRES: std::is_base_of<ResourceBase, T>
+ template <typename T>
+ Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT;
+
+ // Deletes all resources from the "container" and removes the container.
+ Status Cleanup(const string& container) TF_MUST_USE_RESULT;
+
+ // Deletes all resources in all containers.
+ void Clear();
+
+ private:
+ typedef std::pair<std::type_index, string> Key;
+ struct KeyHash {
+ std::size_t operator()(const Key& k) const {
+ return Hash64(k.second.data(), k.second.size(), k.first.hash_code());
+ }
+ };
+ struct KeyEqual {
+ bool operator()(const Key& x, const Key& y) const {
+ return (x.second == y.second) && (x.first == y.first);
+ }
+ };
+ typedef std::unordered_map<Key, ResourceBase*, KeyHash, KeyEqual> Container;
+
+ const string default_container_;
+ mutable mutex mu_;
+ std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
+
+ Status DoCreate(const string& container, std::type_index type,
+ const string& name,
+ ResourceBase* resource) TF_MUST_USE_RESULT;
+ Status DoLookup(const string& container, std::type_index type,
+ const string& name,
+ ResourceBase** resource) const TF_MUST_USE_RESULT;
+ Status DoDelete(const string& container, std::type_index type,
+ const string& name) TF_MUST_USE_RESULT;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
+};
+
+// Policy helper to decide which container/shared_name to use for a
+// stateful kernel that accesses a shared resource.
+class ContainerInfo {
+ public:
+ // Analyzes the node attributes of 'ndef' and decides the container and
+ // resource name the kernel should use for accessing the shared
+ // resource.
+ //
+ // 'ndef' is expected to have node attributes "container" and
+ // "shared_name". Returns non-OK if they are not provided or they are
+ // invalid.
+ //
+ // The policy is as follows:
+ // * If the attribute "container" is non-empty, it is used as is.
+ // Otherwise, uses the resource manager's default container.
+ // * If the attribute "shared_name" is non-empty, it is used as is.
+ // Otherwise, if "use_node_name_as_default" is true, the kernel's
+ // node name is used as the resource name. Otherwise, a string
+ // unique to this process is used.
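+ //
+ // For example (per the policy above and resource_mgr_test.cc), a node
+ // named "foo" with empty "container" and "shared_name" attributes and
+ // use_node_name_as_default == true resolves to the default container
+ // ("localhost") and the resource name "foo"; with
+ // use_node_name_as_default == false it gets a process-unique private
+ // name such as "_0_foo".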
+ Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
+ bool use_node_name_as_default);
+ Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
+ return Init(rmgr, ndef, false);
+ }
+
+ // The policy decides that the kernel should access the resource via
+ // resource_manager(); the resource lives in container() under the name
+ // name(). If resource_is_private_to_kernel() is true, the kernel
+ // should delete the resource when the kernel itself is deleted.
+ ResourceMgr* resource_manager() const { return rmgr_; }
+ const string& container() const { return container_; }
+ const string& name() const { return name_; }
+ bool resource_is_private_to_kernel() const {
+ return resource_is_private_to_kernel_;
+ }
+
+ // Returns a readable string for *this.
+ string DebugString() const;
+
+ private:
+ ResourceMgr* rmgr_ = nullptr;
+ string container_;
+ string name_;
+ bool resource_is_private_to_kernel_ = false;
+};
+
+// Helper for kernels to obtain 'resource' from the
+// ctx->resource_manager().
+//
+// "input_name" specifies the kernel's ref input which gives a string
+// tensor with two elements, which specifies the container and
+// resource name.
+//
+// Returns OK if the resource is found and transfers one ref of
+// *resource to the caller. Otherwise, returns an error.
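+//
+// E.g., if the kernel's ref input named by "input_name" holds the string
+// tensor {"my_container", "my_name"}, the resource is looked up as
+// "my_name" within container "my_container".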
+template <typename T>
+Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
+ T** resource);
+
+// Implementation details below.
+
+template <typename T>
+void CheckDeriveFromResourceBase() {
+ static_assert(std::is_base_of<ResourceBase, T>::value,
+ "T must derive from ResourceBase");
+}
+
+template <typename T>
+Status ResourceMgr::Create(const string& container, const string& name,
+ T* resource) {
+ CheckDeriveFromResourceBase<T>();
+ CHECK(resource != nullptr);
+ return DoCreate(container, std::type_index(typeid(T)), name, resource);
+}
+
+template <typename T>
+Status ResourceMgr::Lookup(const string& container, const string& name,
+ T** resource) const {
+ CheckDeriveFromResourceBase<T>();
+ ResourceBase* found = nullptr;
+ Status s = DoLookup(container, std::type_index(typeid(T)), name, &found);
+ if (s.ok()) {
+ // It's safe to down cast 'found' to T* since
+ // typeid(T).hash_code() is part of the map key.
+ *resource = static_cast<T*>(found);
+ }
+ return s;
+}
+
+template <typename T>
+Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
+ T** resource,
+ std::function<Status(T**)> creator) {
+ Status s;
+ *resource = nullptr;
+ while (*resource == nullptr) {
+ s = Lookup(container, name, resource);
+ if (s.ok()) break;
+ s = creator(resource);
+ if (!s.ok()) break;
+ s = Create(container, name, *resource);
+ if (s.ok()) {
+ (*resource)->Ref();
+ break;
+ }
+ // Rare event. Concurrent racy creation. Redo the lookup.
+ *resource = nullptr;
+ }
+ return s;
+}
+
+template <typename T>
+Status ResourceMgr::Delete(const string& container, const string& name) {
+ CheckDeriveFromResourceBase<T>();
+ return DoDelete(container, std::type_index(typeid(T)), name);
+}
+
+template <typename T>
+Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
+ T** resource) {
+ string container;
+ string shared_name;
+ {
+ mutex* mu;
+ TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
+ mutex_lock l(*mu);
+ Tensor tensor;
+ TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
+ if (tensor.NumElements() != 2) {
+ return errors::InvalidArgument(
+ "Resource handle must have 2 elements, but had shape: ",
+ tensor.shape().DebugString());
+ }
+ container = tensor.flat<string>()(0);
+ shared_name = tensor.flat<string>()(1);
+ }
+ return ctx->resource_manager()->Lookup(container, shared_name, resource);
+}
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc
new file mode 100644
index 0000000000..9f7ce3dde3
--- /dev/null
+++ b/tensorflow/core/framework/resource_mgr_test.cc
@@ -0,0 +1,173 @@
+#include "tensorflow/core/framework/resource_mgr.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+class Resource : public ResourceBase {
+ public:
+ explicit Resource(const string& label) : label_(label) {}
+ ~Resource() override {}
+
+ string DebugString() { return strings::StrCat("R/", label_); }
+
+ private:
+ string label_;
+};
+
+class Other : public ResourceBase {
+ public:
+ explicit Other(const string& label) : label_(label) {}
+ ~Other() override {}
+
+ string DebugString() { return strings::StrCat("O/", label_); }
+
+ private:
+ string label_;
+};
+
+template <typename T>
+string Find(const ResourceMgr& rm, const string& container,
+ const string& name) {
+ T* r;
+ TF_CHECK_OK(rm.Lookup(container, name, &r));
+ const string ret = r->DebugString();
+ r->Unref();
+ return ret;
+}
+
+template <typename T>
+string LookupOrCreate(ResourceMgr* rm, const string& container,
+ const string& name, const string& label) {
+ T* r;
+ TF_CHECK_OK(rm->LookupOrCreate<T>(container, name, &r, [&label](T** ret) {
+ *ret = new T(label);
+ return Status::OK();
+ }));
+ const string ret = r->DebugString();
+ r->Unref();
+ return ret;
+}
+
+static void HasError(const Status& s, const string& substr) {
+ EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
+ << s << ", expected substring " << substr;
+}
+
+template <typename T>
+Status FindErr(const ResourceMgr& rm, const string& container,
+ const string& name) {
+ T* r;
+ Status s = rm.Lookup(container, name, &r);
+ CHECK(!s.ok());
+ return s;
+}
+
+TEST(ResourceMgrTest, Basic) {
+ ResourceMgr rm;
+ TF_CHECK_OK(rm.Create("foo", "bar", new Resource("cat")));
+ TF_CHECK_OK(rm.Create("foo", "baz", new Resource("dog")));
+ TF_CHECK_OK(rm.Create("foo", "bar", new Other("tiger")));
+
+ // Expected to fail.
+ HasError(rm.Create("foo", "bar", new Resource("kitty")),
+ "Already exists: Resource foo/bar");
+
+ // Expected to be found.
+ EXPECT_EQ("R/cat", Find<Resource>(rm, "foo", "bar"));
+ EXPECT_EQ("R/dog", Find<Resource>(rm, "foo", "baz"));
+ EXPECT_EQ("O/tiger", Find<Other>(rm, "foo", "bar"));
+
+ // Expected to be not found.
+ HasError(FindErr<Resource>(rm, "bar", "foo"), "Not found: Container bar");
+ HasError(FindErr<Resource>(rm, "foo", "xxx"), "Not found: Resource foo/xxx");
+ HasError(FindErr<Other>(rm, "foo", "baz"), "Not found: Resource foo/baz");
+
+ // Delete foo/bar/Resource.
+ TF_CHECK_OK(rm.Delete<Resource>("foo", "bar"));
+ HasError(FindErr<Resource>(rm, "foo", "bar"), "Not found: Resource foo/bar");
+
+ TF_CHECK_OK(rm.Create("foo", "bar", new Resource("kitty")));
+ EXPECT_EQ("R/kitty", Find<Resource>(rm, "foo", "bar"));
+
+ // Drop the whole container foo.
+ TF_CHECK_OK(rm.Cleanup("foo"));
+ HasError(FindErr<Resource>(rm, "foo", "bar"), "Not found: Container foo");
+}
+
+TEST(ResourceMgr, CreateOrLookup) {
+ ResourceMgr rm;
+ EXPECT_EQ("R/cat", LookupOrCreate<Resource>(&rm, "foo", "bar", "cat"));
+ EXPECT_EQ("R/cat", LookupOrCreate<Resource>(&rm, "foo", "bar", "dog"));
+ EXPECT_EQ("R/cat", Find<Resource>(rm, "foo", "bar"));
+
+ EXPECT_EQ("O/tiger", LookupOrCreate<Other>(&rm, "foo", "bar", "tiger"));
+ EXPECT_EQ("O/tiger", LookupOrCreate<Other>(&rm, "foo", "bar", "lion"));
+ TF_CHECK_OK(rm.Delete<Other>("foo", "bar"));
+ HasError(FindErr<Other>(rm, "foo", "bar"), "Not found: Resource foo/bar");
+}
+
+Status ComputePolicy(const string& attr_container,
+ const string& attr_shared_name,
+ bool use_node_name_as_default, string* result) {
+ ContainerInfo cinfo;
+ ResourceMgr rmgr;
+ NodeDef ndef;
+ ndef.set_name("foo");
+ if (attr_container != "none") {
+ AddNodeAttr("container", attr_container, &ndef);
+ }
+ if (attr_shared_name != "none") {
+ AddNodeAttr("shared_name", attr_shared_name, &ndef);
+ }
+ TF_RETURN_IF_ERROR(cinfo.Init(&rmgr, ndef, use_node_name_as_default));
+ *result = cinfo.DebugString();
+ return Status::OK();
+}
+
+string Policy(const string& attr_container, const string& attr_shared_name,
+ bool use_node_name_as_default) {
+ string ret;
+ TF_CHECK_OK(ComputePolicy(attr_container, attr_shared_name,
+ use_node_name_as_default, &ret));
+ return ret;
+}
+
+TEST(ContainerInfo, Basic) {
+ // Correct cases.
+ EXPECT_EQ(Policy("", "", false), "[localhost,_0_foo,private]");
+ EXPECT_EQ(Policy("", "", true), "[localhost,foo,public]");
+ EXPECT_EQ(Policy("", "bar", false), "[localhost,bar,public]");
+ EXPECT_EQ(Policy("", "bar", true), "[localhost,bar,public]");
+ EXPECT_EQ(Policy("cat", "", false), "[cat,_1_foo,private]");
+ EXPECT_EQ(Policy("cat", "", true), "[cat,foo,public]");
+ EXPECT_EQ(Policy("cat", "bar", false), "[cat,bar,public]");
+ EXPECT_EQ(Policy("cat", "bar", true), "[cat,bar,public]");
+}
+
+Status WrongPolicy(const string& attr_container, const string& attr_shared_name,
+ bool use_node_name_as_default) {
+ string dbg;
+ auto s = ComputePolicy(attr_container, attr_shared_name,
+ use_node_name_as_default, &dbg);
+ CHECK(!s.ok());
+ return s;
+}
+
+TEST(ContainerInfo, Error) {
+ // Missing attribute.
+ HasError(WrongPolicy("none", "", false), "No attr");
+ HasError(WrongPolicy("", "none", false), "No attr");
+ HasError(WrongPolicy("none", "none", false), "No attr");
+
+ // Invalid container.
+ HasError(WrongPolicy("12$%", "", false), "container contains invalid char");
+
+ // Invalid shared name.
+ HasError(WrongPolicy("", "_foo", false), "shared_name cannot start with '_'");
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/step_stats.proto b/tensorflow/core/framework/step_stats.proto
new file mode 100644
index 0000000000..78610350ec
--- /dev/null
+++ b/tensorflow/core/framework/step_stats.proto
@@ -0,0 +1,58 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/tensor_description.proto";
+
+// TODO(tucker): The next 4 message defs are very similar to
+// the *LogEntry messages in profile.proto. They should be
+// unified in one place.
+
+message AllocatorMemoryUsed {
+ string allocator_name = 1;
+ int64 total_bytes = 2;
+ int64 peak_bytes = 3;
+}
+
+enum AllocationType {
+ AT_NOTUSED = 0; // tensor was not filled in
+ AT_ALLOCATED = 1; // tensor was allocated by the Op
+ AT_EXISTING = 2; // tensor was set to share the value of an existing tensor
+ AT_REF = 3; // tensor was set to be a reference to an existing tensor
+}
+
+// Output sizes recorded for a single execution of a graph node.
+message NodeOutput {
+ int32 slot = 1;
+ // Was the tensor allocated by this Op or a previous computation?
+ AllocationType allocation_type = 2;
+ TensorDescription tensor_description = 3;
+};
+
+// Time/size stats recorded for a single execution of a graph node.
+message NodeExecStats {
+ // TODO(tucker): Use some more compact form of node identity than
+ // the full string name. Either all processes should agree on a
+ // global id (cost_id?) for each node, or we should use a hash of
+ // the name.
+ string node_name = 1;
+ int64 all_start_micros = 2;
+ int64 op_start_rel_micros = 3;
+ int64 op_end_rel_micros = 4;
+ int64 all_end_rel_micros = 5;
+ repeated AllocatorMemoryUsed memory = 6;
+ repeated NodeOutput output = 7;
+ string timeline_label = 8;
+ int64 scheduled_micros = 9;
+ uint32 thread_id = 10;
+};
+
+message DeviceStepStats {
+ string device = 1;
+ repeated NodeExecStats node_stats = 2;
+}
+
+message StepStats {
+ repeated DeviceStepStats dev_stats = 1;
+};
diff --git a/tensorflow/core/framework/summary.proto b/tensorflow/core/framework/summary.proto
new file mode 100644
index 0000000000..0e6e659f2f
--- /dev/null
+++ b/tensorflow/core/framework/summary.proto
@@ -0,0 +1,67 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+// Serialization format for histogram module in
+// core/lib/histogram/histogram.h
+message HistogramProto {
+ double min = 1;
+ double max = 2;
+ double num = 3;
+ double sum = 4;
+ double sum_squares = 5;
+
+ // Parallel arrays encoding the bucket boundaries and the bucket values.
+ // bucket(i) is the count for bucket i. The range for
+ // a bucket is:
+ // i == 0: -DBL_MAX .. bucket_limit(0)
+ // i != 0: bucket_limit(i-1) .. bucket_limit(i)
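+ //
+ // For example, bucket_limit = [0.0, 10.0] with bucket = [1, 3] encodes
+ // one value in -DBL_MAX .. 0.0 and three values in 0.0 .. 10.0.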
+ repeated double bucket_limit = 6 [packed = true];
+ repeated double bucket = 7 [packed = true];
+};
+
+// A Summary is a set of named values to be displayed by the
+// visualizer.
+//
+// Summaries are produced regularly during training, as controlled by
+// the "summary_interval_secs" attribute of the training operation.
+// Summaries are also produced at the end of an evaluation.
+message Summary {
+ message Image {
+ // Dimensions of the image.
+ int32 height = 1;
+ int32 width = 2;
+ // Valid colorspace values are
+ // 1 - grayscale
+ // 2 - grayscale + alpha
+ // 3 - RGB
+ // 4 - RGBA
+ // 5 - DIGITAL_YUV
+ // 6 - BGRA
+ int32 colorspace = 3;
+ // Image data in encoded format. All image formats supported by
+ // image_codec::CoderUtil can be stored here.
+ bytes encoded_image_string = 4;
+ }
+
+ message Value {
+ // Tag name for the data. Will be used as the title of the graph
+ // in the visualizer.
+ //
+ // Tag is usually "op_name:value_name", where "op_name" itself can have
+ // structure to indicate grouping.
+ string tag = 1;
+
+ // Value associated with the tag.
+ oneof value {
+ float simple_value = 2;
+ bytes obsolete_old_style_histogram = 3;
+ Image image = 4;
+ HistogramProto histo = 5;
+ }
+ }
+
+ // Set of values for the summary.
+ repeated Value value = 1;
+}
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
new file mode 100644
index 0000000000..4a1b65db97
--- /dev/null
+++ b/tensorflow/core/framework/tensor.cc
@@ -0,0 +1,570 @@
+// Implementation notes:
+//
+// Tensor.cc uses a few templated classes and structs to facilitate
+// implementation of the Tensor class.
+//
+// * Buffer<T>: provides the implementation for a typed array T[n].
+// The array is allocated by the given allocator. It runs T's
+// default constructors and destructors when T is not a simple type
+// (e.g., string), and skips them otherwise.
+//
+// * Helper<T>: provides various routines given type T. The routines
+// include running the constructor and destructor of T[], encoding and
+// decoding T[] into/from a Cord, etc.
+
+#include "tensorflow/core/public/tensor.h"
+
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/tensor_coding.h"
+
+namespace tensorflow {
+namespace {
+
+// Typed ref-counted buffer: T[n].
+template <typename T>
+class Buffer : public TensorBuffer {
+ public:
+ Buffer(Allocator* a, int64 n);
+
+ void* data() const override { return data_; }
+ size_t size() const override { return sizeof(T) * elem_; }
+ TensorBuffer* root_buffer() override { return this; }
+ void FillAllocationDescription(AllocationDescription* proto) const override {
+ int64 rb = size();
+ proto->set_requested_bytes(rb);
+ proto->set_allocator_name(alloc_->Name());
+ if (alloc_->TracksAllocationSizes()) {
+ int64 ab = alloc_->AllocatedSize(data_);
+ proto->set_allocated_bytes(ab);
+ }
+ }
+
+ private:
+ Allocator* alloc_;
+ T* data_;
+ int64 elem_;
+
+ ~Buffer() override;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
+};
+
+// is_simple<T>::value is true if T[] can be safely constructed and
+// destructed without running T() and ~T(). We do not use
+// std::is_trivial<T> directly because std::complex<float> is not
+// trivial, but its array can be constructed and destructed without
+// running its default ctor and dtor.
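+//
+// For example, is_simple<float>::value and is_simple<complex64>::value
+// are true, while is_simple<string>::value is false, so Buffer<string>
+// must run string() and ~string() on each element.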
+template <typename T>
+struct is_simple {
+ static const bool value = std::is_trivial<T>::value ||
+ std::is_same<T, complex64>::value ||
+ is_quantized<T>::value;
+};
+
+template <>
+struct is_simple<bfloat16> {
+ static const bool value = true;
+};
+
+// A set of helper functions depending on T.
+template <typename T>
+struct Helper {
+ // By default, we assume T is a simple type (float, int32, etc.)
+ static_assert(is_simple<T>::value, "T is not a simple type.");
+ typedef protobuf::RepeatedField<T> RepeatedFieldType;
+
+ // No constructor to run.
+ static void RunCtor(T* p, int n) {}
+
+ // No destructor to run.
+ static void RunDtor(T* p, int n) {}
+
+ // Encoder of simple type T to a string. We do a copy.
+ template <typename Destination>
+ static void Encode(TensorBuffer* in, int64 n, Destination* out) {
+ DCHECK_EQ(in->size(), sizeof(T) * n);
+ port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
+ out);
+ }
+
+ // Decoder of simple type T. Copy the bytes from "in" into the
+ // tensor buffer.
+ template <typename Source>
+ static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
+ if (in.size() != sizeof(T) * n) {
+ LOG(ERROR) << "Input size was " << in.size() << " and expected "
+ << sizeof(T) * n;
+ return nullptr;
+ }
+ Buffer<T>* buf = new Buffer<T>(a, n);
+ port::CopyToArray(in, buf->template base<char>());
+ return buf;
+ }
+
+ // Memory usage.
+ static int64 TotalBytes(TensorBuffer* in, int64 n) {
+ DCHECK_EQ(in->size(), sizeof(T) * n);
+ return in->size();
+ }
+};
+
+// Helper specialization for string (the only non-simple type we
+// support).
+template <>
+struct Helper<string> {
+ // Proto message uses RepeatedFieldType to hold repeated T.
+ typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
+
+ // Runs string's default constructor for p[0], p[1], ..., p[n-1].
+ static void RunCtor(string* p, int n) {
+ for (int i = 0; i < n; ++p, ++i) new (p) string();
+ }
+
+ // Runs T's default destructor for p[0], p[1], ..., p[n-1].
+ static void RunDtor(string* p, int n) {
+ for (int i = 0; i < n; ++p, ++i) p->~string();
+ }
+
+ // Encodes "n" elements of type string stored in "in" into Cord
+ // "out", which is usually the TensorProto::tensor_content.
+ template <typename Destination>
+ static void Encode(TensorBuffer* in, int64 n, Destination* out) {
+ port::EncodeStringList(in->base<const string>(), n, out);
+ }
+
+ // Decodes "n" elements of type string from "in" and constructs a
+ // buffer out of it. Returns nullptr if the decoding fails. "in" is
+ // usually the TensorProto::tensor_content.
+ template <typename Source>
+ static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
+ Buffer<string>* buf = new Buffer<string>(a, n);
+ string* strings = buf->template base<string>();
+ if (port::DecodeStringList(in, strings, n)) {
+ return buf;
+ } else {
+ buf->Unref();
+ return nullptr;
+ }
+ }
+
+ // Returns the estimated memory usage of "n" elements of type T
+ // stored in buffer "in".
+ static int64 TotalBytes(TensorBuffer* in, int n) {
+ int64 tot = in->size();
+ DCHECK_EQ(tot, sizeof(string) * n);
+ const string* p = in->base<const string>();
+ for (int i = 0; i < n; ++i, ++p) tot += p->size();
+ return tot;
+ }
+};
+
+template <typename T>
+struct ProtoHelper {};
+
+// For a C++ type "T" (float, double, int32, etc.), the repeated field
+// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
+// int32, string, etc) in the TensorProto is used for serializing the
+// tensor of type "T".
+#define PROTO_TRAITS(T, F, N) \
+ template <> \
+ struct ProtoHelper<T> { \
+ typedef Helper<F>::RepeatedFieldType FieldType; \
+ static FieldType::const_iterator Begin(const TensorProto& proto) { \
+ return proto.N##_val().begin(); \
+ } \
+ static size_t NumElements(const TensorProto& proto) { \
+ return proto.N##_val().size(); \
+ } \
+ static void Fill(const T* data, size_t n, TensorProto* proto) { \
+ typename ProtoHelper<T>::FieldType copy(data, data + n); \
+ proto->mutable_##N##_val()->Swap(&copy); \
+ } \
+ };
+PROTO_TRAITS(float, float, float);
+PROTO_TRAITS(double, double, double);
+PROTO_TRAITS(int32, int32, int);
+PROTO_TRAITS(uint8, int32, int);
+PROTO_TRAITS(int16, int32, int);
+PROTO_TRAITS(int8, int32, int);
+PROTO_TRAITS(int64, int64, int64);
+PROTO_TRAITS(bool, bool, bool);
+PROTO_TRAITS(string, string, string);
+PROTO_TRAITS(qint8, int32, int);
+PROTO_TRAITS(quint8, int32, int);
+#undef PROTO_TRAITS
+
+template <>
+struct ProtoHelper<complex64> {
+ typedef Helper<float>::RepeatedFieldType FieldType;
+ static const complex64* Begin(const TensorProto& proto) {
+ return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
+ }
+ static size_t NumElements(const TensorProto& proto) {
+ return proto.scomplex_val().size() / 2;
+ }
+ static void Fill(const complex64* data, size_t n, TensorProto* proto) {
+ const float* p = reinterpret_cast<const float*>(data);
+ FieldType copy(p, p + n * 2);
+ proto->mutable_scomplex_val()->Swap(&copy);
+ }
+};
+
+template <>
+struct ProtoHelper<qint32> {
+ typedef Helper<int32>::RepeatedFieldType FieldType;
+ static const qint32* Begin(const TensorProto& proto) {
+ return reinterpret_cast<const qint32*>(proto.int_val().data());
+ }
+ static size_t NumElements(const TensorProto& proto) {
+ return proto.int_val().size();
+ }
+ static void Fill(const qint32* data, size_t n, TensorProto* proto) {
+ const int32* p = reinterpret_cast<const int32*>(data);
+ FieldType copy(p, p + n);
+ proto->mutable_int_val()->Swap(&copy);
+ }
+};
+
+template <>
+struct ProtoHelper<bfloat16> {
+ typedef Helper<float>::RepeatedFieldType FieldType;
+ static const bfloat16* Begin(const TensorProto& proto) {
+ return reinterpret_cast<const bfloat16*>(proto.int_val().data());
+ }
+ static size_t NumElements(const TensorProto& proto) {
+ return proto.int_val().size();
+ }
+ static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
+ proto->mutable_int_val()->Reserve(n);
+ for (size_t i = 0; i < n; ++i) {
+ proto->mutable_int_val()->AddAlreadyReserved(data[i].value);
+ }
+ }
+};
+
+template <typename T>
+Buffer<T>::Buffer(Allocator* a, int64 n)
+ : alloc_(a), data_(a->Allocate<T>(n)), elem_(n) {
+ if (data_) Helper<T>::RunCtor(data_, elem_);
+}
+
+template <typename T>
+Buffer<T>::~Buffer() {
+ if (data_) {
+ Helper<T>::RunDtor(data_, elem_);
+ alloc_->Deallocate<T>(data_);
+ }
+}
+
+// Allocates a T[n] buffer. Fills in the buffer with repeated values
+// in "in". If "in" has less values than "n", fills the rest of T[n]
+// with the last value. If "in" has no values, fills T[n] with the
+// default value for T.
+//
+// This routine uses the typed fields (float_val, etc.) in the
+// tensor proto as opposed to the untyped binary representation
+// (tensor_content). This is used when we expect the TensorProto to be
+// used by a client program which may not know how to encode a tensor
+// in the compact binary representation.
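+//
+// E.g., a TensorProto with dtype DT_FLOAT, shape [4] and float_val = {7}
+// decodes to the buffer {7, 7, 7, 7}: the single value is repeated to
+// fill the remaining elements.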
+template <typename T>
+TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
+ CHECK_GT(n, 0);
+ Buffer<T>* buf = new Buffer<T>(a, n);
+ T* data = buf->template base<T>();
+ const int64 in_n = ProtoHelper<T>::NumElements(in);
+ auto begin = ProtoHelper<T>::Begin(in);
+ if (n <= in_n) {
+ std::copy_n(begin, n, data);
+ } else if (in_n > 0) {
+ std::copy_n(begin, in_n, data);
+ const T& last = *(data + in_n - 1);
+ std::fill_n(data + in_n, n - in_n, last);
+ } else {
+ std::fill_n(data, n, T());
+ }
+ return buf;
+}
+
+// Copies T[n] stored in the buffer "in" into the repeated field in
+// "out" corresponding to type T.
+template <typename T>
+void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
+ const T* data = in.base<const T>();
+ // NOTE: T may not be the same as
+ // ProtoHelper<T>::FieldType::value_type. E.g., T==int16,
+ // ProtoHelper<T>::FieldType::value_type==int32. If performance is
+ // critical, we can specialize T=float and do memcpy directly.
+ ProtoHelper<T>::Fill(data, n, out);
+}
+
+void RefIfNonNull(core::RefCounted* buf) {
+ if (buf) buf->Ref();
+}
+
+void UnrefIfNonNull(core::RefCounted* buf) {
+ if (buf) buf->Unref();
+}
+
+} // end namespace
+
+Tensor::Tensor() : Tensor(DT_FLOAT) {}
+
+Tensor::Tensor(DataType type) : type_(type), shape_({0}), buf_(nullptr) {}
+
+Tensor::Tensor(const Tensor& other)
+ : type_(other.dtype()), shape_(other.shape()), buf_(other.buf_) {
+ RefIfNonNull(buf_);
+}
+
+Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
+ : type_(type), shape_(shape), buf_(buf) {
+ RefIfNonNull(buf);
+}
+
+bool Tensor::IsInitialized() const {
+ return buf_ != nullptr && buf_->data() != nullptr;
+}
+
+Tensor::~Tensor() { UnrefIfNonNull(buf_); }
+
+void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
+ CHECK_EQ(shape.num_elements(), other.NumElements());
+ type_ = other.dtype();
+ shape_ = shape;
+ if (buf_ != other.buf_) {
+ UnrefIfNonNull(buf_);
+ buf_ = other.buf_;
+ RefIfNonNull(buf_);
+ }
+}
+
+// The macro CASES() expands to a switch statement conditioned on
+// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
+#define SINGLE_ARG(...) __VA_ARGS__
+#define CASE(TYPE, STMTS) \
+ case DataTypeToEnum<TYPE>::value: { \
+ typedef TYPE T; \
+ STMTS; \
+ break; \
+ }
+#define CASES(TYPE_ENUM, STMTS) \
+ switch (TYPE_ENUM) { \
+ CASE(float, SINGLE_ARG(STMTS)) \
+ CASE(double, SINGLE_ARG(STMTS)) \
+ CASE(int32, SINGLE_ARG(STMTS)) \
+ CASE(uint8, SINGLE_ARG(STMTS)) \
+ CASE(int16, SINGLE_ARG(STMTS)) \
+ CASE(int8, SINGLE_ARG(STMTS)) \
+ CASE(string, SINGLE_ARG(STMTS)) \
+ CASE(complex64, SINGLE_ARG(STMTS)) \
+ CASE(int64, SINGLE_ARG(STMTS)) \
+ CASE(bool, SINGLE_ARG(STMTS)) \
+ CASE(qint32, SINGLE_ARG(STMTS)) \
+ CASE(quint8, SINGLE_ARG(STMTS)) \
+ CASE(qint8, SINGLE_ARG(STMTS)) \
+ CASE(bfloat16, SINGLE_ARG(STMTS)) \
+ case DT_INVALID: \
+ LOG(FATAL) << "Type not set"; \
+ break; \
+ default: \
+ LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \
+ break; \
+ }
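+
+// Illustrative expansion: CASES(dtype(), return sizeof(T)) turns into a
+// switch on dtype() in which each case introduces the matching typedef,
+// e.g. "case DT_FLOAT: { typedef float T; return sizeof(T); break; }",
+// and DT_INVALID or an unexpected enum value is fatal.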
+
+Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
+ : type_(type), shape_(shape), buf_(nullptr) {
+ CHECK_NOTNULL(a);
+ if (shape_.num_elements() > 0) {
+ CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
+ }
+}
+
+Tensor::Tensor(DataType type, const TensorShape& shape)
+ : Tensor(cpu_allocator(), type, shape) {}
+
+template <typename T>
+class SubBuffer : public TensorBuffer {
+ public:
+ // This buffer is an alias to buf[delta, delta + n).
+ SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
+ : root_(buf->root_buffer()), data_(buf->base<T>() + delta), elem_(n) {
+ // Sanity check. The caller should ensure the sub buffer is valid.
+ CHECK_LE(root_->base<T>(), this->base<T>());
+ T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
+ CHECK_LE(this->base<T>(), root_limit);
+ CHECK_LE(this->base<T>() + n, root_limit);
+ // Hold a ref of the underlying root buffer.
+ // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
+ root_->Ref();
+ }
+
+ void* data() const override { return data_; }
+ size_t size() const override { return sizeof(T) * elem_; }
+ TensorBuffer* root_buffer() override { return root_; }
+ void FillAllocationDescription(AllocationDescription* proto) const override {
+ root_->FillAllocationDescription(proto);
+ }
+
+ private:
+ TensorBuffer* root_;
+ T* data_;
+ int64 elem_;
+
+ ~SubBuffer() override { root_->Unref(); }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
+};
+
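+// E.g., for a tensor of shape [10, 20], Slice(2, 5) returns a tensor of
+// shape [3, 20] that aliases elements [2, 5) along dimension 0 of the
+// original buffer (no copy is made; the result holds a ref on the root
+// buffer via SubBuffer).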
+Tensor Tensor::Slice(int64 start, int64 limit) const {
+ CHECK_GE(dims(), 1);
+ CHECK_LE(0, start);
+ CHECK_LE(start, limit);
+ int64 dim0_size = shape_.dim_size(0);
+ CHECK_LE(limit, dim0_size);
+ if ((start == 0) && (limit == dim0_size)) {
+ return *this;
+ }
+ Tensor ret;
+ ret.type_ = type_;
+ ret.shape_ = shape_;
+ ret.buf_ = nullptr;
+ if (dim0_size > 0) {
+ const int64 elems_per_dim0 = NumElements() / dim0_size;
+ const int64 delta = start * elems_per_dim0;
+ dim0_size = limit - start;
+ ret.shape_.set_dim(0, dim0_size);
+ const int64 num_elems = dim0_size * elems_per_dim0;
+ if (buf_) {
+ CASES(type_, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
+ }
+ }
+ return ret;
+}
+
+bool Tensor::FromProto(const TensorProto& proto) {
+ return FromProto(cpu_allocator(), proto);
+}
+
+bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
+ CHECK_NOTNULL(a);
+ TensorBuffer* p = nullptr;
+ if (!TensorShape::IsValid(proto.tensor_shape())) return false;
+ if (proto.dtype() == DT_INVALID) return false;
+ TensorShape shape(proto.tensor_shape());
+ const int64 N = shape.num_elements();
+ if (N > 0 && proto.dtype()) {
+ if (!proto.tensor_content().empty()) {
+ const auto& content = proto.tensor_content();
+ CASES(proto.dtype(), p = Helper<T>::Decode(a, content, N));
+ } else {
+ CASES(proto.dtype(), p = FromProtoField<T>(a, proto, N));
+ }
+ if (p == nullptr) return false;
+ }
+ type_ = proto.dtype();
+ shape_ = shape;
+ UnrefIfNonNull(buf_);
+ buf_ = p;
+ return true;
+}
+
+void Tensor::AsProtoField(TensorProto* proto) const {
+ proto->Clear();
+ proto->set_dtype(dtype());
+ shape_.AsProto(proto->mutable_tensor_shape());
+ if (buf_) {
+ CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
+ }
+}
+
+void Tensor::AsProtoTensorContent(TensorProto* proto) const {
+ proto->Clear();
+ proto->set_dtype(type_);
+ shape_.AsProto(proto->mutable_tensor_shape());
+ if (buf_) {
+ CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
+ proto->mutable_tensor_content()));
+ }
+}
+
+size_t Tensor::TotalBytes() const {
+ if (shape_.num_elements() == 0) return 0;
+ CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
+ CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
+ return 0; // Makes compiler happy.
+}
+
+bool Tensor::CanUseDMA() const {
+ CASES(dtype(), return is_simple<T>::value);
+ return false; // Makes compiler happy.
+}
+
+#undef CASES
+#undef CASE
+
+string Tensor::SummarizeValue(int64 max_entries) const {
+ string ret;
+ for (int64 i = 0; i < std::min(max_entries, NumElements()); ++i) {
+ if (i > 0) strings::StrAppend(&ret, " ");
+ switch (dtype()) {
+ case DT_STRING:
+ strings::StrAppend(&ret, str_util::CEscape(flat<string>()(i)));
+ break;
+ case DT_BOOL:
+ strings::StrAppend(&ret, flat<bool>()(i) ? "True" : "False");
+ break;
+
+#define CASE(DT_ENUM) \
+ case DT_ENUM: \
+ strings::StrAppend(&ret, flat<EnumToDataType<DT_ENUM>::Type>()(i)); \
+ break
+
+ CASE(DT_FLOAT);
+ CASE(DT_DOUBLE);
+ CASE(DT_INT32);
+ CASE(DT_UINT8);
+ CASE(DT_INT16);
+ CASE(DT_INT8);
+ CASE(DT_INT64);
+
+#undef CASE
+ default:
+ // TODO(zhifengc, josh11b): Pretty-print other types (bool,
+ // complex64, quantized, bfloat16).
+ strings::StrAppend(&ret, " ?");
+ }
+ }
+ if (max_entries < NumElements()) strings::StrAppend(&ret, "...");
+
+ return ret;
+}
+
+StringPiece Tensor::tensor_data() const {
+ if (buf_ == nullptr) return StringPiece(); // Don't die for empty tensors
+ return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
+}
+
+string Tensor::DebugString() const {
+ return strings::StrCat("Tensor<type: ", DataTypeString(dtype()), " shape: ",
+ shape().ShortDebugString(), " values: ",
+ SummarizeValue(3), ">");
+}
+
+void Tensor::FillDescription(TensorDescription* description) const {
+ description->set_dtype(dtype());
+ shape().AsProto(description->mutable_shape());
+ buf_->FillAllocationDescription(
+ description->mutable_allocation_description());
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor.proto b/tensorflow/core/framework/tensor.proto
new file mode 100644
index 0000000000..b42694afde
--- /dev/null
+++ b/tensorflow/core/framework/tensor.proto
@@ -0,0 +1,57 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+// Protocol buffer representing a tensor.
+message TensorProto {
+ DataType dtype = 1;
+
+ // Shape of the tensor. TODO(mdevin): sort out the 0-rank issues.
+ TensorShapeProto tensor_shape = 2;
+
+ // Only one of the representations below is set, either "tensor_content"
+ // or one of the "xxx_val" attributes. We are not using oneof because
+ // oneofs cannot contain repeated fields, so it would require another
+ // extra set of messages.
+
+ // Version number.
+ //
+ // In version 0, if the "repeated xxx" representations contain only one
+ // element, that element is repeated to fill the shape. This makes it easy
+ // to represent a constant Tensor with a single value.
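+ //
+ // For example, in version 0 a proto with dtype = DT_FLOAT,
+ // tensor_shape = [2, 2] and float_val = {5} represents a 2x2 tensor
+ // filled with the value 5.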
+ int32 version_number = 3;
+
+ // Serialized content from TensorBase::Serialize(). This representation can be
+ // used for all tensor types.
+ bytes tensor_content = 4;
+
+ // Type specific representations that make it easy to create tensor protos in
+ // all languages. Only the representation corresponding to "dtype" can
+ // be set. The values hold the flattened representation of the tensor in
+ // row major order.
+
+ // DT_FLOAT.
+ repeated float float_val = 5 [packed = true];
+
+ // DT_DOUBLE.
+ repeated double double_val = 6 [packed = true];
+
+ // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
+ repeated int32 int_val = 7 [packed = true];
+
+ // DT_STRING
+ repeated bytes string_val = 8;
+
+ // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
+ // and imaginary parts of i-th single precision complex.
+ repeated float scomplex_val = 9 [packed = true];
+
+ // DT_INT64
+ repeated int64 int64_val = 10 [packed = true];
+
+ // DT_BOOL
+ repeated bool bool_val = 11 [packed = true];
+};
diff --git a/tensorflow/core/framework/tensor_description.proto b/tensorflow/core/framework/tensor_description.proto
new file mode 100644
index 0000000000..1fff3ee155
--- /dev/null
+++ b/tensorflow/core/framework/tensor_description.proto
@@ -0,0 +1,19 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/allocation_description.proto";
+
+message TensorDescription {
+ // Data type of tensor elements
+ DataType dtype = 1;
+
+ // Shape of the tensor.
+ TensorShapeProto shape = 2;
+
+ // Information about the size and allocator used for the data
+ AllocationDescription allocation_description = 4;
+};
diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc
new file mode 100644
index 0000000000..3db2ffaaca
--- /dev/null
+++ b/tensorflow/core/framework/tensor_shape.cc
@@ -0,0 +1,138 @@
+#include "tensorflow/core/public/tensor_shape.h"
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// An upper limit of the total number of elements in a tensor.
+static const int64 kMaxElements = (1LL << 40);
+
+bool TensorShape::IsValid(const TensorShapeProto& proto) {
+ int64 num_elements = 1;
+ for (const auto& d : proto.dim()) {
+ if (d.size() < 0) return false;
+ num_elements *= d.size();
+ if (num_elements > kMaxElements) return false;
+ }
+ return true;
+}
+
+TensorShape::TensorShape(const TensorShapeProto& proto) {
+ dim_sizes_.reserve(proto.dim_size());
+ num_elements_ = 1;
+ for (const auto& d : proto.dim()) {
+ AddDim(d.size());
+ }
+}
+
+TensorShape::TensorShape(gtl::ArraySlice<int64> dim_sizes) {
+ dim_sizes_.reserve(dim_sizes.size());
+ num_elements_ = 1;
+ for (auto s : dim_sizes) {
+ AddDim(s);
+ }
+}
+
+TensorShape::TensorShape() : num_elements_(1) {}
+
+void TensorShape::Clear() {
+ dim_sizes_.clear();
+ num_elements_ = 1;
+}
+
+void TensorShape::AddDim(int64 size) {
+ CHECK_GE(size, 0);
+ dim_sizes_.push_back(size);
+ num_elements_ *= size;
+ CHECK_LE(0, num_elements_);
+ CHECK_LE(num_elements_, kMaxElements);
+}
+
+void TensorShape::AppendShape(const TensorShape& shape) {
+ for (auto d : shape) AddDim(d.size);
+}
+
+void TensorShape::InsertDim(int d, int64 size) {
+ CHECK_GE(d, 0);
+ CHECK_LE(d, dims());
+ CHECK_GE(size, 0);
+ dim_sizes_.insert(dim_sizes_.begin() + d, size);
+ num_elements_ *= size;
+ CHECK_LE(0, num_elements_);
+ CHECK_LE(num_elements_, kMaxElements);
+}
+
+void TensorShape::set_dim(int d, int64 size) {
+ CHECK_GE(d, 0);
+ CHECK_LT(d, dims());
+ CHECK_GE(size, 0);
+
+ // Update the number of elements. num_elements_ is int64.
+ dim_sizes_[d] = size;
+ recompute_dims();
+}
+
+void TensorShape::RemoveDim(int d) {
+ CHECK_GE(d, 0);
+ CHECK_LT(d, dims());
+
+ // Update the number of elements and remove the dimension from the
+ // sizes.
+ dim_sizes_.erase(dim_sizes_.begin() + d);
+ recompute_dims();
+}
+
+void TensorShape::recompute_dims() {
+ num_elements_ = 1;
+ for (auto s : dim_sizes_) {
+ num_elements_ *= s;
+ CHECK_LE(0, num_elements_);
+ CHECK_LE(num_elements_, kMaxElements);
+ }
+}
+
+bool TensorShape::IsSameSize(const TensorShape& b) const {
+ if (b.dims() != dims()) return false;
+ for (int d = 0; d < dims(); d++) {
+ if (dim_size(d) != b.dim_size(d)) return false;
+ }
+ return true;
+}
+
+void TensorShape::AsProto(TensorShapeProto* proto) const {
+ proto->Clear();
+ for (size_t d = 0; d < dim_sizes_.size(); ++d) {
+ auto* dim = proto->add_dim();
+ dim->set_size(dim_sizes_[d]);
+ }
+}
+
+TensorShapeIter TensorShape::begin() const { return TensorShapeIter(this, 0); }
+
+TensorShapeIter TensorShape::end() const {
+ return TensorShapeIter(this, dims());
+}
+
+string TensorShape::DebugString() const {
+ TensorShapeProto proto;
+ AsProto(&proto);
+ return proto.ShortDebugString();
+}
+
+string TensorShape::ShortDebugString() const {
+ return strings::StrCat(
+ "[", str_util::Join(gtl::ArraySlice<int64>(dim_sizes_), ","), "]");
+}
+
+bool TensorShapeUtils::StartsWith(const TensorShape& shape,
+ const TensorShape& prefix) {
+ if (shape.dims() < prefix.dims()) return false;
+ for (int i = 0; i < prefix.dims(); i++) {
+ if (shape.dim_size(i) != prefix.dim_size(i)) return false;
+ }
+ return true;
+}
+
+} // namespace tensorflow
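
The mutators above keep num_elements_ equal to the product of all dimension sizes, CHECK-failing if it would exceed kMaxElements (2^40). A minimal sketch of that accounting (function name is illustrative):

    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/public/tensor_shape.h"

    void ShapeMutationSketch() {
      tensorflow::TensorShape s({10, 5});  // num_elements() == 50
      s.AddDim(3);                         // [10, 5, 3]    -> 150
      s.InsertDim(0, 2);                   // [2, 10, 5, 3] -> 300
      s.set_dim(1, 4);                     // [2, 4, 5, 3]  -> 120
      s.RemoveDim(2);                      // [2, 4, 3]     -> 24
      CHECK_EQ(24, s.num_elements());
    }
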
diff --git a/tensorflow/core/framework/tensor_shape.proto b/tensorflow/core/framework/tensor_shape.proto
new file mode 100644
index 0000000000..8fe7cce13d
--- /dev/null
+++ b/tensorflow/core/framework/tensor_shape.proto
@@ -0,0 +1,29 @@
+// Protocol buffer representing the shape of tensors.
+
+syntax = "proto3";
+// option cc_enable_arenas = true;
+
+package tensorflow;
+
+// Dimensions of a tensor and the type of data it contains.
+message TensorShapeProto {
+ // One dimension of the tensor.
+ message Dim {
+ // Size of the tensor in that dimension.
+ int64 size = 1;
+
+ // Optional name of the tensor dimension.
+ string name = 2;
+ };
+
+ // Dimensions of the tensor, such as {"input", 30}, {"output", 40} for a 30 x
+ // 40 2D tensor. The names are optional.
+ //
+ // The order of entries in "dim" matters: It indicates the layout of the
+ // values in the tensor in-memory representation.
+ //
+ // The first entry in "dim" is the outermost dimension used to layout the
+ // values, the last entry is the innermost dimension. This matches the
+ // in-memory layout of RowMajor Eigen tensors.
+ repeated Dim dim = 2;
+};
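
A short sketch of building the message above, illustrating that dim(0) is the outermost (slowest-varying) dimension of the row-major layout (the builder function is an assumption):

    #include "tensorflow/core/framework/tensor_shape.pb.h"

    tensorflow::TensorShapeProto MakeShapeProto() {
      tensorflow::TensorShapeProto proto;
      auto* d0 = proto.add_dim();
      d0->set_size(30);
      d0->set_name("input");   // names are optional
      auto* d1 = proto.add_dim();
      d1->set_size(40);
      d1->set_name("output");
      // For this 30 x 40 tensor, element (i, j) sits at flat offset i*40 + j.
      return proto;
    }
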
diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc
new file mode 100644
index 0000000000..adac1a4787
--- /dev/null
+++ b/tensorflow/core/framework/tensor_shape_test.cc
@@ -0,0 +1,75 @@
+#include "tensorflow/core/public/tensor_shape.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+TEST(TensorShapeTest, Default) {
+ // The default TensorShape constructor constructs a shape of 0-dim
+ // and 1-element.
+ TensorShape s;
+ EXPECT_EQ(s.dims(), 0);
+ EXPECT_EQ(s.num_elements(), 1);
+}
+
+TEST(TensorShapeTest, set_dim) {
+ TensorShape s({10, 5});
+
+ s.set_dim(0, 20);
+ ASSERT_EQ(2, s.dims());
+ EXPECT_EQ(20, s.dim_size(0));
+ EXPECT_EQ(100, s.num_elements());
+
+ s.set_dim(1, 2);
+ ASSERT_EQ(2, s.dims());
+ EXPECT_EQ(2, s.dim_size(1));
+ EXPECT_EQ(40, s.num_elements());
+}
+
+TEST(TensorShapeTest, RemoveDim) {
+ TensorShape s({10, 5});
+ s.RemoveDim(0);
+ EXPECT_EQ(5, s.num_elements());
+ ASSERT_EQ(1, s.dims());
+}
+
+TEST(TensorShapeTest, RemoveAndAddDim) {
+ TensorShape s({10, 5, 20});
+ s.RemoveDim(1);
+ s.AddDim(100);
+
+ EXPECT_EQ(20000, s.num_elements());
+ ASSERT_EQ(3, s.dims());
+}
+
+TEST(TensorShapeTest, InvalidShapeProto) {
+ TensorShapeProto proto;
+ EXPECT_TRUE(TensorShape::IsValid(proto));
+
+ proto.add_dim()->set_size(357);
+ proto.add_dim()->set_size(982);
+ EXPECT_TRUE(TensorShape::IsValid(proto));
+
+ proto.Clear();
+ proto.add_dim()->set_size(-357);
+ proto.add_dim()->set_size(-982);
+ EXPECT_FALSE(TensorShape::IsValid(proto));
+
+ proto.Clear();
+ proto.add_dim()->set_size(1LL << 20);
+ proto.add_dim()->set_size((1LL << 20) + 1);
+ EXPECT_FALSE(TensorShape::IsValid(proto));
+}
+
+TEST(TensorShapeTest, SetDimForEmptyTensor) {
+ TensorShape s({10, 5, 20});
+ EXPECT_EQ(1000, s.num_elements());
+ s.set_dim(1, 0);
+ EXPECT_EQ(0, s.num_elements());
+ s.set_dim(1, 7);
+ EXPECT_EQ(1400, s.num_elements());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_slice.cc b/tensorflow/core/framework/tensor_slice.cc
new file mode 100644
index 0000000000..473d9463ee
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice.cc
@@ -0,0 +1,226 @@
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); }
+
+TensorSlice::TensorSlice(const TensorSliceProto& proto) {
+ starts_.reserve(proto.extent_size());
+ lengths_.reserve(proto.extent_size());
+ for (const auto& e : proto.extent()) {
+ starts_.push_back(e.start());
+ lengths_.push_back(GetExtentLength(e));
+ }
+}
+
+TensorSlice::TensorSlice(std::initializer_list<std::pair<int, int>> extents) {
+ starts_.reserve(extents.size());
+ lengths_.reserve(extents.size());
+ for (const auto& e : extents) {
+ starts_.push_back(e.first);
+ lengths_.push_back(e.second);
+ }
+}
+
+Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
+ std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
+ slice->starts_.reserve(items.size());
+ slice->lengths_.reserve(items.size());
+ for (const string& x : items) {
+ int s, l;
+ if (x == "-") {
+ // "everything"
+ s = 0;
+ l = kFullExtent;
+ } else {
+ char junk;
+ if (sscanf(x.c_str(), "%d,%d%c", &s, &l, &junk) != 2) {
+ return errors::InvalidArgument(
+ "Expected a pair of numbers or '-' "
+ "but got '",
+ x, "': string = ", str);
+ }
+ if (s < 0 || l <= 0) {
+ return errors::InvalidArgument(
+ "Expected non-negative start and "
+ "positive length but got start = ",
+ s, ", length = ", l, ": string = ", str);
+ }
+ }
+ slice->starts_.push_back(s);
+ slice->lengths_.push_back(l);
+ }
+
+ return Status::OK();
+}
+
+void TensorSlice::Clear() {
+ starts_.clear();
+ lengths_.clear();
+}
+
+void TensorSlice::SetFullSlice(int dim) {
+ Clear();
+ starts_.reserve(dim);
+ lengths_.reserve(dim);
+ for (int d = 0; d < dim; ++d) {
+ starts_.push_back(0);
+ lengths_.push_back(kFullExtent);
+ }
+}
+
+void TensorSlice::Extend(int dim) {
+ int old_dim = dims();
+ DCHECK_LE(old_dim, dim);
+ starts_.resize(dim);
+ lengths_.resize(dim);
+ for (int d = old_dim; d < dim; ++d) {
+ starts_[d] = 0;
+ lengths_[d] = kFullExtent;
+ }
+}
+
+void TensorSlice::AsProto(TensorSliceProto* proto) const {
+ for (int d = 0; d < dims(); ++d) {
+ TensorSliceProto::Extent* e = proto->add_extent();
+ // We only need to record the explicit slice for non-full slices
+ if (!IsFullAt(d)) {
+ e->set_start(starts_[d]);
+ e->set_length(lengths_[d]);
+ }
+ }
+}
+
+string TensorSlice::DebugString() const {
+ string buffer;
+ bool first = true;
+ for (int d = 0; d < dims(); ++d) {
+ if (!first) {
+ buffer.append(":");
+ }
+ string s;
+ if (IsFullAt(d)) {
+ buffer.append("-");
+ } else {
+ strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]);
+ }
+ first = false;
+ }
+ return buffer;
+}
+
+bool TensorSlice::Intersect(const TensorSlice& other,
+ TensorSlice* result) const {
+ // First, if two slices have different ranks, they obviously don't overlap
+ // -- in fact they are not compatible.
+ if (dims() != other.dims()) {
+ return false;
+ }
+
+ // Setting the result to the right dimension
+ if (result) {
+ result->SetFullSlice(dims());
+ }
+ // The two slices overlap if they overlap in all dimensions.
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ if (result) {
+ result->set_start(d, other.start(d));
+ result->set_length(d, other.length(d));
+ }
+ } else if (other.IsFullAt(d)) {
+ if (result) {
+ result->set_start(d, start(d));
+ result->set_length(d, length(d));
+ }
+ } else {
+ // If we have an intersection here, it should have a start that is the
+ // max of the two starts and an end that is the min of the two ends.
+ int s = std::max(start(d), other.start(d));
+ int l = std::min(end(d), other.end(d)) - s;
+ if (l > 0) {
+ // We have a real intersection
+ if (result) {
+ result->set_start(d, s);
+ result->set_length(d, l);
+ }
+ } else {
+ // We don't have an intersection for this dimension -- thus we don't
+ // have any intersection at all.
+ if (result) {
+ result->Clear();
+ }
+ return false;
+ }
+ }
+ }
+ // If we are here, we know there is overlap in every dimension.
+ return true;
+}
+
+void TensorSlice::ComputeRelative(const TensorSlice& sub,
+ TensorSlice* relative) const {
+ DCHECK_EQ(dims(), sub.dims());
+ relative->SetFullSlice(dims());
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ relative->set_start(d, sub.start(d));
+ relative->set_length(d, sub.length(d));
+ } else {
+ // Otherwise the relative start is the difference between the start of
+ // sub and the start of base
+ relative->set_start(d, sub.start(d) - start(d));
+ relative->set_length(d, sub.length(d));
+ }
+ }
+}
+
+// static
+bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) {
+ return extent.has_length_case() == TensorSliceProto::Extent::kLength;
+}
+
+// static
+int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) {
+ if (!HasExtentLength(extent)) return -1;
+ return extent.length();
+}
+
+Status TensorSlice::SliceTensorShape(const TensorShape& shape,
+ TensorShape* result_shape) const {
+ result_shape->Clear();
+ // Mismatching ranks: we can't apply the slice at all.
+ if (shape.dims() != dims()) {
+ return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(),
+ ", slice = ", DebugString());
+ }
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ result_shape->AddDim(shape.dim_size(d));
+ } else {
+ // Check if the extent applies to the dimension
+ if (end(d) <= shape.dim_size(d)) {
+ // Yes: the end is within the range of the dim -- we adjust the result
+ // shape so that its size along this dimension is the length of the
+ // slice.
+ result_shape->AddDim(length(d));
+ } else {
+ // The extent doesn't apply to the dimension
+ result_shape->Clear();
+ return errors::Internal("Extent in dimension ", d,
+ " out of bounds: shape = ", shape.DebugString(),
+ ", slice = ", DebugString());
+ }
+ }
+ }
+ // If we are here, we have successfully applied the shape.
+ return Status::OK();
+}
+
+const int TensorSlice::kFullExtent = -1;
+
+} // namespace tensorflow
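
A hedged usage sketch of the string format and Intersect() defined above (values and function name are illustrative):

    #include "tensorflow/core/framework/tensor_slice.h"
    #include "tensorflow/core/platform/logging.h"

    void IntersectSketch() {
      // "1,2" = start 1, length 2; "-" = the full extent of that dimension.
      tensorflow::TensorSlice a = tensorflow::TensorSlice::ParseOrDie("1,2:-");
      tensorflow::TensorSlice b = tensorflow::TensorSlice::ParseOrDie("2,4:0,3");
      tensorflow::TensorSlice c;
      if (a.Intersect(b, &c)) {
        // Per dimension: [max(start), min(end)), so this logs "2,1:0,3".
        LOG(INFO) << c.DebugString();
      }
    }
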
diff --git a/tensorflow/core/framework/tensor_slice.h b/tensorflow/core/framework/tensor_slice.h
new file mode 100644
index 0000000000..8e2f108c3f
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice.h
@@ -0,0 +1,189 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
+
+#include <string>
+#include "tensorflow/core/framework/tensor_slice.pb.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// A tensor slice represents a slice of a given tensor. It is represented by a
+// list of (start, length) pairs, where the size of the list is the rank of the
+// tensor.
+
+class TensorSlice {
+ public:
+  // There are several ways to construct a tensor slice:
+ // -- creating an empty slice
+ // -- from just a dimension (in this case it will create a full slice)
+ // -- from an array of pairs of integers.
+ // -- from a TensorSliceProto protocol buffer
+  // -- from a string format of "start,length:start,length..." where each
+ // "start,length" pair represents the slice on one dimension. We allow a
+ // special "-" that means "everything for this dimension". One such example
+ // is: 0,10:-:14,1:-:-
+ TensorSlice() {}
+ explicit TensorSlice(int dim);
+ explicit TensorSlice(const TensorSliceProto& proto);
+ explicit TensorSlice(std::initializer_list<std::pair<int, int>> extents);
+
+ static Status Parse(const string& str, TensorSlice* output);
+ static TensorSlice ParseOrDie(const string& str) {
+ TensorSlice ret;
+ Status s = Parse(str, &ret);
+ if (!s.ok()) {
+ LOG(FATAL) << "Could not parse TensorSlice";
+ }
+ return ret;
+ }
+
+ void Clear();
+
+ // Accessors
+ int dims() const { return starts_.size(); }
+
+ int start(int d) const {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ return starts_[d];
+ }
+
+ int length(int d) const {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ return lengths_[d];
+ }
+
+ int end(int d) const {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ return start(d) + length(d);
+ }
+
+ void set_start(int d, int x) {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ DCHECK_GE(x, 0);
+ starts_[d] = x;
+ }
+
+ void set_length(int d, int x) {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ lengths_[d] = x;
+ }
+
+ // If we have a full slice along dimension "d".
+ bool IsFullAt(int d) const { return lengths_[d] < 0; }
+
+ // Set the slice to be a full slice of "dim" dimensions
+ void SetFullSlice(int dim);
+
+ // Extend a slice to "dim" dimensions: all the added dimensions are full.
+ // Requires: dim >= dims().
+ void Extend(int dim);
+
+ // Conversion of a TensorSlice to other formats
+ void AsProto(TensorSliceProto* proto) const;
+ string DebugString() const;
+
+  // Fill *indices and *sizes from *this (so that we can use the slice()
+  // function in Eigen tensors). We need a tensor shape in case some of the
+  // slices are full slices.
+ // We allow NDIMS to be greater than dims(), in which case we will pad the
+ // higher dimensions with trivial dimensions.
+ template <int NDIMS>
+ void FillIndicesAndSizes(const TensorShape& shape,
+ Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
+ Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const;
+
+ // Interaction with other TensorSlices.
+
+  // Compute the intersection with another slice and if "result" is not
+  // nullptr, store the results in *result; returns true if there is any real
+  // intersection.
+ bool Intersect(const TensorSlice& other, TensorSlice* result) const;
+  // A shorthand.
+ bool Overlaps(const TensorSlice& other) const {
+ return Intersect(other, nullptr);
+ }
+
+ // Interaction with TensorShape.
+
+ // Slices a shape and stores the result into *result_shape.
+ // Requires that the shape and *this have the same rank.
+ // For example, given a tensor shape of {3, 4, 5}, and a slice of
+ // 1,2:-:0,2, the result shape is {2, 4, 2}.
+ Status SliceTensorShape(const TensorShape& shape,
+ TensorShape* result_shape) const;
+
+ // Given slice "sub" where "sub" is fully contained in *this,
+ // (meaning that the intersection of "sub" and *this equals "sub"), computes
+ // the "relative" slice of "sub" with respect to *this.
+ //
+ // In other words, if we use A>S to denote slicing a shape S with a slice A,
+ // then the function is computing a slice X such that:
+ // X > (this > S) = sub > S
+ // for any shape S.
+ //
+ // In general, along every dimension, the start of the relative slice is the
+ // start of the "sub" slice minus the start of *this; the length of the
+ // relative slice is the length of the "sub" slice.
+ //
+ // For example, say we have a shape of {3, 4, 5}, "this" is 0,2:-:1,2, and
+ // "sub" is 1,1:2:2,1,2, then the related slice is 1,1:2,2:0,2.
+ //
+ // The caller needs to make sure that "sub" is indeed a sub-slice of *this;
+ // otherwise the result is undefined.
+ void ComputeRelative(const TensorSlice& sub, TensorSlice* relative) const;
+
+ // Returns true if the length field was specified in an Extent.
+ static bool HasExtentLength(const TensorSliceProto::Extent& extent);
+
+ // Returns the value of the length field in an Extent, or -1 if it
+ // is not present.
+ static int64 GetExtentLength(const TensorSliceProto::Extent& extent);
+
+ private:
+ // a length value of kFullExtent (-1) means we have a full slice at this
+ // dimension. It's defined in tensor_slice.cc.
+ static const int kFullExtent;
+
+ // TODO(yangke): switch to Eigen once it supports variable size arrays.
+ gtl::InlinedVector<int, 4> starts_;
+ gtl::InlinedVector<int, 4> lengths_;
+};
+
+template <int NDIMS>
+void TensorSlice::FillIndicesAndSizes(
+ const TensorShape& shape, Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
+ Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const {
+  CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape "
+                                 << "and slice: shape = " << shape.DebugString()
+                                 << ", slice = " << DebugString();
+ CHECK_GE(NDIMS, dims()) << "Asking for a " << NDIMS << "-dim slice from "
+ << "a slice of dimension " << dims();
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ (*indices)[d] = 0;
+ (*sizes)[d] = shape.dim_size(d);
+ } else {
+ (*indices)[d] = starts_[d];
+ (*sizes)[d] = lengths_[d];
+ }
+ }
+ for (int d = dims(); d < NDIMS; ++d) {
+ (*indices)[d] = 0;
+ (*sizes)[d] = 1;
+ }
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
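
To make the ComputeRelative() contract above concrete, a small sketch in the same "start,length" notation (names are illustrative):

    #include "tensorflow/core/framework/tensor_slice.h"

    void RelativeSketch() {
      // base selects rows [1, 6) of dim 0; sub selects rows [3, 5) of the
      // original tensor, which lie inside base.
      tensorflow::TensorSlice base = tensorflow::TensorSlice::ParseOrDie("1,5:-");
      tensorflow::TensorSlice sub = tensorflow::TensorSlice::ParseOrDie("3,2:0,4");
      tensorflow::TensorSlice rel;
      base.ComputeRelative(sub, &rel);
      // Dim 0: start 3 - 1 = 2, length 2; dim 1 is copied from sub.
      // rel.DebugString() == "2,2:0,4".
    }
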
diff --git a/tensorflow/core/framework/tensor_slice.proto b/tensorflow/core/framework/tensor_slice.proto
new file mode 100644
index 0000000000..ca676bc766
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice.proto
@@ -0,0 +1,34 @@
+// Protocol buffer representing slices of a tensor
+
+syntax = "proto3";
+// option cc_enable_arenas = true;
+
+package tensorflow;
+
+// Can only be interpreted if you know the corresponding TensorShape.
+message TensorSliceProto {
+ // Extent of the slice in one dimension.
+ message Extent {
+    // Either both attributes are set, or neither is. When neither is set, it
+    // means: all data in that dimension.
+
+ // Start index of the slice, starting at 0.
+ int64 start = 1;
+
+ // Length of the slice: if the length is missing or -1 we will
+ // interpret this as "everything in this dimension". We use
+ // "oneof" to preserve information about whether the length is
+ // present without changing the serialization format from the
+ // prior proto2 version of this proto.
+ oneof has_length {
+ int64 length = 2;
+ }
+ };
+
+ // Extent of the slice in all tensor dimensions.
+ //
+ // Must have one entry for each of the dimension of the tensor that this
+ // slice belongs to. The order of sizes is the same as the order of
+ // dimensions in the TensorShape.
+ repeated Extent extent = 1;
+};
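
A sketch of the oneof semantics above, using the case enum protoc generates for "has_length" (the NOT_SET constant name follows the usual protoc convention and is an assumption here):

    #include "tensorflow/core/framework/tensor_slice.pb.h"
    #include "tensorflow/core/platform/logging.h"

    void ExtentOneofSketch() {
      tensorflow::TensorSliceProto::Extent e;
      // The oneof distinguishes "length unset" (full dimension) from 0.
      CHECK(e.has_length_case() ==
            tensorflow::TensorSliceProto::Extent::HAS_LENGTH_NOT_SET);
      e.set_length(10);
      CHECK(e.has_length_case() == tensorflow::TensorSliceProto::Extent::kLength);
    }
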
diff --git a/tensorflow/core/framework/tensor_slice_test.cc b/tensorflow/core/framework/tensor_slice_test.cc
new file mode 100644
index 0000000000..5f718a56b6
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice_test.cc
@@ -0,0 +1,246 @@
+#include "tensorflow/core/framework/tensor_slice.h"
+
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+// Basic tests
+TEST(TensorSliceTest, Basic) {
+ {
+ // Repeatedly setting FullSlice should work.
+ TensorSlice s(3);
+ EXPECT_EQ("-:-:-", s.DebugString());
+
+ s.SetFullSlice(4);
+ EXPECT_EQ("-:-:-:-", s.DebugString());
+ }
+}
+
+// Testing for serialization and parsing for the string format of slices.
+TEST(TensorSliceTest, Serialization) {
+ // Serialization
+ {
+ TensorSlice s({{0, -1}, {0, 10}, {14, 1}, {0, -1}});
+ EXPECT_EQ("-:0,10:14,1:-", s.DebugString());
+ }
+
+ {
+ TensorSliceProto proto;
+ // Define ptxt outside ASSERT_TRUE call to work around bug in some
+ // versions of gcc that breaks when you use raw string literals
+ // inside macro expansions.
+ // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971
+ const char* ptxt = R"PROTO(
+ extent { }
+ extent { start: 0 length: 10 }
+ extent { start: 14 length: 1 }
+ extent { }
+ )PROTO";
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto));
+ TensorSlice s(proto);
+ EXPECT_EQ("-:0,10:14,1:-", s.DebugString());
+ }
+
+ // Parsing
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5");
+ TensorSliceProto proto;
+ s.AsProto(&proto);
+ EXPECT_EQ(
+ "extent { } "
+ "extent { } "
+ "extent { start: 1 length: 3 } "
+ "extent { start: 4 length: 5 }",
+ proto.ShortDebugString());
+ }
+
+ // Failed parsing
+ {
+ TensorSlice slice;
+ Status s = TensorSlice::Parse("-:-:1,3:4:5", &slice);
+ EXPECT_EQ(
+ "Invalid argument: "
+ "Expected a pair of numbers or '-' but got '4': "
+ "string = -:-:1,3:4:5",
+ s.ToString());
+ }
+ {
+ TensorSlice slice;
+ Status s = TensorSlice::Parse("-:-1,3", &slice);
+ EXPECT_EQ(
+ "Invalid argument: "
+ "Expected non-negative start and positive length but got "
+ "start = -1, length = 3: string = -:-1,3",
+ s.ToString());
+ }
+}
+
+// Testing the slice intersection
+TEST(TensorSliceTest, Intersection) {
+ // "EVERYTHING" intersects with everything
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("-:-");
+ TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4");
+ TensorSlice c;
+ EXPECT_TRUE(a.Intersect(b, &c));
+ EXPECT_EQ("1,2:3,4", c.DebugString());
+ }
+
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("-:-");
+ TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4");
+ TensorSlice c;
+ EXPECT_TRUE(b.Intersect(a, &c));
+ EXPECT_EQ("1,2:3,4", c.DebugString());
+ }
+
+ // Overlap at all dimensions
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,5:2,6:3,7:5,10");
+ TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4:9,10:12,1");
+ TensorSlice c;
+ EXPECT_TRUE(a.Intersect(b, &c));
+ EXPECT_EQ("1,2:3,4:9,1:12,1", c.DebugString());
+ }
+
+ // A mixture of everything and non-trivial slices
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("-:1,1");
+ TensorSlice b = TensorSlice::ParseOrDie("-:0,2");
+ TensorSlice c;
+ EXPECT_TRUE(a.Intersect(b, &c));
+ EXPECT_EQ("-:1,1", c.DebugString());
+ }
+
+ // No overlap on dimension 3: "3,1" and "4,5" don't intersect
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:5,6");
+ TensorSlice b = TensorSlice::ParseOrDie("1,3:4,5:1,6");
+ TensorSlice c;
+ EXPECT_FALSE(a.Intersect(b, &c));
+ EXPECT_EQ("", c.DebugString());
+ }
+ // No intersection when there are different dimensions
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:-");
+ TensorSlice b = TensorSlice::ParseOrDie("-:-");
+ TensorSlice c;
+ EXPECT_FALSE(a.Intersect(b, &c));
+ EXPECT_EQ("", c.DebugString());
+ }
+}
+
+// Testing applying a slice to a tensor shape
+TEST(TensorSliceTest, SliceTensorShape) {
+ // A proper application
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,1:-:4,1:2,6");
+ TensorShape x({2, 4, 5, 8});
+ TensorShape y;
+ EXPECT_OK(a.SliceTensorShape(x, &y));
+ EXPECT_EQ(
+ "dim { size: 1 } "
+ "dim { size: 4 } "
+ "dim { size: 1 } "
+ "dim { size: 6 }",
+ y.DebugString());
+ }
+
+ // An invalid application -- dimension 2 is out of bound
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,1:1,4:-:-");
+ TensorShape x({2, 4, 5, 8});
+ TensorShape y;
+ EXPECT_EQ(
+ "Internal: "
+ "Extent in dimension 1 out of bounds: "
+ "shape = dim { size: 2 } "
+ "dim { size: 4 } "
+ "dim { size: 5 } "
+ "dim { size: 8 }, "
+ "slice = 1,1:1,4:-:-",
+ a.SliceTensorShape(x, &y).ToString());
+ EXPECT_EQ("", y.DebugString());
+ }
+}
+
+// Testing the computation of relative slices.
+TEST(TensorSliceTest, ComputeRelative) {
+ // Easy case: base is "everything"
+ {
+ TensorSlice base = TensorSlice::ParseOrDie("-:-:-:-");
+ TensorSlice sub = TensorSlice::ParseOrDie("-:1,2:-:3,4");
+ TensorSlice relative;
+ base.ComputeRelative(sub, &relative);
+ EXPECT_EQ("-:1,2:-:3,4", relative.DebugString());
+ }
+
+ // A slightly more complicated case
+ {
+ TensorSlice base = TensorSlice::ParseOrDie("1,2:3,4:-:5,1");
+ TensorSlice sub = TensorSlice::ParseOrDie("1,1:4,2:3,3:5,1");
+ TensorSlice relative;
+ base.ComputeRelative(sub, &relative);
+ EXPECT_EQ("0,1:1,2:3,3:0,1", relative.DebugString());
+ }
+}
+
+TEST(TensorSliceTest, ExtentLength) {
+ TensorSliceProto proto;
+ // Define ptxt outside ASSERT_TRUE call to work around bug in some
+ // versions of gcc that breaks when you use raw string literals
+ // inside macro expansions.
+ // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971
+ const char* ptxt = R"PROTO(
+ extent { }
+ extent { start: 0 length: 10 }
+ extent { start: 14 length: 1 }
+ extent { }
+ )PROTO";
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto));
+ EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(0)));
+ EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(1)));
+ EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(2)));
+ EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(3)));
+ EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(0)));
+ EXPECT_EQ(10, TensorSlice::GetExtentLength(proto.extent(1)));
+ EXPECT_EQ(1, TensorSlice::GetExtentLength(proto.extent(2)));
+ EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(3)));
+}
+
+TEST(TensorSliceTest, Deserialization) {
+ // Serialization of
+ // extent { length: 5 }
+ // extent { start: 0 length: 10 }
+ // extent { start: 14 length: 1 }
+ // extent { start: 1 }
+ // extent { }
+ // in proto2 and proto3:
+ const char pb2[] =
+ "\x0A\x02\x10\x05\x0A\x04\x08\x00"
+ "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00";
+ const char pb3[] =
+ "\x0A\x02\x10\x05\x0A\x02"
+ "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00";
+ // (The difference is that in the proto3 version, "start: 0" isn't included
+ // since 0 is start's default value.)
+
+ TensorSliceProto proto2;
+ ASSERT_TRUE(proto2.ParseFromArray(pb2, sizeof(pb2) - 1));
+ TensorSlice ts2(proto2);
+
+ TensorSliceProto proto3;
+ ASSERT_TRUE(proto3.ParseFromArray(pb3, sizeof(pb3) - 1));
+ TensorSlice ts3(proto3);
+
+ // Both serializations should be interpreted the same.
+ EXPECT_EQ("0,5:0,10:14,1:-:-", ts2.DebugString());
+ EXPECT_EQ("0,5:0,10:14,1:-:-", ts3.DebugString());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
new file mode 100644
index 0000000000..4963c2c219
--- /dev/null
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -0,0 +1,551 @@
+#include "tensorflow/core/public/tensor.h"
+
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(TensorTest, Default) {
+ Tensor t;
+ EXPECT_EQ(t.dtype(), DT_FLOAT);
+ EXPECT_EQ(t.dims(), 1);
+ EXPECT_EQ(t.NumElements(), 0);
+}
+
+TEST(TensorTest, DataType_Traits) {
+ EXPECT_TRUE(std::is_trivial<float>::value);
+ EXPECT_TRUE(std::is_trivial<double>::value);
+ EXPECT_TRUE(std::is_trivial<int32>::value);
+ EXPECT_TRUE(std::is_trivial<uint8>::value);
+ EXPECT_TRUE(std::is_trivial<int16>::value);
+ EXPECT_TRUE(std::is_trivial<int8>::value);
+ EXPECT_TRUE(std::is_trivial<int64>::value);
+ EXPECT_TRUE(std::is_trivial<bool>::value);
+ EXPECT_FALSE(std::is_trivial<string>::value);
+
+ EXPECT_EQ(sizeof(bool), 1);
+
+  // Unfortunately, std::complex::complex() initializes to (0, 0).
+ EXPECT_FALSE(std::is_trivial<complex64>::value);
+ EXPECT_FALSE(std::is_trivial<std::complex<double>>::value);
+ EXPECT_TRUE(std::is_trivial<float[2]>::value);
+ struct MyComplex {
+ float re, im;
+ };
+ EXPECT_TRUE(std::is_trivial<MyComplex>::value);
+}
+
+template <typename T>
+void TestCopies(const Tensor& t) {
+ {
+ LOG(INFO) << "CopyFrom()";
+ Tensor t2(t.dtype());
+ EXPECT_TRUE(t2.CopyFrom(t, t.shape()));
+ test::ExpectTensorEqual<T>(t, t2);
+ }
+ {
+ LOG(INFO) << "operator=()";
+ Tensor t2(t.dtype());
+ t2 = t;
+ test::ExpectTensorEqual<T>(t, t2);
+ }
+ {
+ LOG(INFO) << "deep copy";
+ Tensor t2(t.dtype(), t.shape());
+ t2.flat<T>() = t.flat<T>();
+ test::ExpectTensorEqual<T>(t, t2);
+ }
+ {
+ LOG(INFO) << "AsProtoField()";
+ TensorProto proto;
+ t.AsProtoField(&proto);
+ Tensor t2(t.dtype());
+ EXPECT_TRUE(t2.FromProto(proto));
+ test::ExpectTensorEqual<T>(t, t2);
+ }
+ {
+ LOG(INFO) << "AsProtoTensorContent()";
+ TensorProto proto;
+ t.AsProtoTensorContent(&proto);
+ Tensor t2(t.dtype());
+ EXPECT_TRUE(t2.FromProto(proto));
+ test::ExpectTensorEqual<T>(t, t2);
+ // Make another copy via tensor_content field.
+ *proto.mutable_tensor_content() = proto.tensor_content();
+ Tensor t3(t.dtype());
+ EXPECT_TRUE(t3.FromProto(proto));
+    test::ExpectTensorEqual<T>(t, t3);
+ }
+ {
+ LOG(INFO) << "AsTensor";
+ gtl::ArraySlice<T> values(t.flat<T>().data(), t.NumElements());
+ Tensor t2 = test::AsTensor(values, t.shape());
+ test::ExpectTensorEqual<T>(t, t2);
+ }
+}
+
+TEST(Tensor_Float, Simple) {
+ Tensor t(DT_FLOAT, TensorShape({10, 20}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 20})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<float>()(a, b) = static_cast<float>(a * b);
+ }
+ }
+ TestCopies<float>(t);
+}
+
+TEST(Tensor_QInt8, Simple) {
+ Tensor t(DT_QINT8, TensorShape({2, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<qint8>()(a, b) = qint8(a * b);
+ }
+ }
+ TestCopies<qint8>(t);
+}
+
+TEST(Tensor_QUInt8, Simple) {
+ Tensor t(DT_QUINT8, TensorShape({2, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<Eigen::QUInt8>()(a, b) = Eigen::QUInt8(a * b);
+ }
+ }
+ TestCopies<Eigen::QUInt8>(t);
+}
+
+TEST(Tensor_QInt32, Simple) {
+ Tensor t(DT_QINT32, TensorShape({2, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<qint32>()(a, b) = qint32(static_cast<int32>(a * b));
+ }
+ }
+ TestCopies<qint32>(t);
+}
+
+TEST(Tensor_Float, Reshape) {
+ Tensor t(DT_FLOAT, TensorShape({2, 3, 4, 5}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 3, 4, 5})));
+
+ {
+ auto tensor = t.tensor<float, 4>();
+ EXPECT_EQ(2, tensor.dimension(0));
+ EXPECT_EQ(3, tensor.dimension(1));
+ EXPECT_EQ(4, tensor.dimension(2));
+ EXPECT_EQ(5, tensor.dimension(3));
+
+ // Set first and last elements.
+ tensor(0, 0, 0, 0) = 0.01f;
+ tensor(1, 2, 3, 4) = 0.02f;
+ }
+ {
+ auto shaped = t.shaped<float, 1>({120});
+ EXPECT_EQ(120, shaped.dimension(0));
+ EXPECT_EQ(shaped(0), 0.01f);
+ EXPECT_EQ(shaped(119), 0.02f);
+ }
+ {
+ auto shaped = t.shaped<float, 2>({6, 20});
+ EXPECT_EQ(6, shaped.dimension(0));
+ EXPECT_EQ(20, shaped.dimension(1));
+ EXPECT_EQ(shaped(0, 0), 0.01f);
+ EXPECT_EQ(shaped(5, 19), 0.02f);
+ }
+ {
+ auto shaped = t.shaped<float, 3>({6, 4, 5});
+ EXPECT_EQ(6, shaped.dimension(0));
+ EXPECT_EQ(4, shaped.dimension(1));
+ EXPECT_EQ(5, shaped.dimension(2));
+ EXPECT_EQ(shaped(0, 0, 0), 0.01f);
+ EXPECT_EQ(shaped(5, 3, 4), 0.02f);
+ }
+ {
+ auto shaped = t.shaped<float, 4>({2, 3, 4, 5});
+ EXPECT_EQ(2, shaped.dimension(0));
+ EXPECT_EQ(3, shaped.dimension(1));
+ EXPECT_EQ(4, shaped.dimension(2));
+ EXPECT_EQ(5, shaped.dimension(3));
+
+ EXPECT_EQ(shaped(0, 0, 0, 0), 0.01f);
+ EXPECT_EQ(shaped(1, 2, 3, 4), 0.02f);
+ }
+ {
+ auto flat = t.flat<float>();
+ EXPECT_EQ(flat(0), 0.01f);
+ EXPECT_EQ(120, flat.dimension(0));
+ EXPECT_EQ(flat(0), 0.01f);
+ EXPECT_EQ(flat(119), 0.02f);
+ }
+ {
+ auto flat_inner_dims = t.flat_inner_dims<float>();
+ EXPECT_EQ(24, flat_inner_dims.dimension(0));
+ EXPECT_EQ(5, flat_inner_dims.dimension(1));
+ EXPECT_EQ(flat_inner_dims(0, 0), 0.01f);
+ EXPECT_EQ(flat_inner_dims(23, 4), 0.02f);
+ }
+}
+
+TEST(Tensor_Scalar, Basics) {
+ {
+ Tensor t(DT_FLOAT, TensorShape({}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<float>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.scalar<float>()() = 123.45f;
+ EXPECT_FLOAT_EQ(123.45f, Tt());
+ }
+ {
+ Tensor t(DT_FLOAT, TensorShape({1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.vec<float>();
+ EXPECT_EQ(1, Tt.size());
+ t.vec<float>()(0) = 123.45f;
+ EXPECT_FLOAT_EQ(123.45f, Tt(0));
+ }
+ {
+ Tensor t(DT_FLOAT, TensorShape({1, 1, 1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<float>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.flat<float>()(0) = 123.45f;
+ EXPECT_FLOAT_EQ(123.45f, Tt());
+ }
+ {
+ Tensor t(DT_STRING, TensorShape({}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<string>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.scalar<string>()() = "foo";
+ EXPECT_EQ("foo", Tt());
+ }
+ {
+ Tensor t(DT_STRING, TensorShape({1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.vec<string>();
+ EXPECT_EQ(1, Tt.size());
+ t.flat<string>()(0) = "foo";
+ EXPECT_EQ("foo", Tt(0));
+ }
+ {
+ Tensor t(DT_STRING, TensorShape({1, 1, 1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<string>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.flat<string>()(0) = "bar";
+ EXPECT_EQ("bar", Tt());
+ }
+ {
+ Tensor t(DT_FLOAT, TensorShape({0, 1}));
+ EXPECT_EQ(0, t.NumElements());
+ auto Tt = t.flat<float>();
+ EXPECT_EQ(0, Tt.size());
+ auto Tm = t.matrix<float>();
+ EXPECT_EQ(0, Tm.size());
+ EXPECT_EQ(0, Tm.dimensions()[0]);
+ EXPECT_EQ(1, Tm.dimensions()[1]);
+ }
+}
+
+TEST(Tensor_Float, Reshape_And_Slice_Assignment) {
+ // A test to experiment with a way to assign to a subset of a tensor
+ Tensor t(DT_FLOAT, TensorShape({10, 4, 3, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 4, 3, 2})));
+
+ // Get the N dimensional tensor (N==4 here)
+ auto e_t = t.tensor<float, 4>();
+ // Reshape to view it as a two-dimensional tensor
+ auto e_2d = t.shaped<float, 2>({10, 4 * 3 * 2});
+ for (int i = 0; i < 10; i++) {
+ // Assign a 1 x 4*3*2 matrix (really vector) to a slice of size
+ // 1 x 4*3*2 in e_t.
+ Eigen::Tensor<float, 2, Eigen::RowMajor> m(1, 4 * 3 * 2);
+ m.setConstant(i * 2.0);
+
+ Eigen::DSizes<Eigen::DenseIndex, 2> indices(i, 0);
+ Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, 4 * 3 * 2);
+ e_2d.slice(indices, sizes) = m;
+ }
+ for (int i = 0; i < 10; i++) {
+ for (int j = 0; j < 4; j++) {
+ for (int k = 0; k < 3; k++) {
+ for (int l = 0; l < 2; l++) {
+ EXPECT_EQ(e_t(i, j, k, l), i * 2.0f);
+ LOG(INFO) << i << "," << j << "," << k << "," << l
+ << " &e_t(i, j, k, l): " << &e_t(i, j, k, l) << " = "
+ << e_t(i, j, k, l);
+ }
+ }
+ }
+ }
+}
+
+TEST(Tensor_String, Simple) {
+ Tensor t = test::AsTensor<string>(
+ {"hello", "world", "machine", "learning", "new", "york"},
+ TensorShape({3, 2}));
+ auto s = t.shape();
+ ASSERT_EQ(s.dims(), 2);
+ ASSERT_EQ(s.dim_size(0), 3);
+ ASSERT_EQ(s.dim_size(1), 2);
+ auto m = t.matrix<string>();
+ EXPECT_EQ(t.TotalBytes(), 3 * 2 * sizeof(string) + 5 + 5 + 7 + 8 + 3 + 4);
+
+ EXPECT_EQ(m(0, 0), "hello");
+ EXPECT_EQ(m(0, 1), "world");
+ EXPECT_EQ(m(1, 0), "machine");
+ EXPECT_EQ(m(1, 1), "learning");
+ EXPECT_EQ(m(2, 0), "new");
+ EXPECT_EQ(m(2, 1), "york");
+
+ TestCopies<string>(t);
+}
+
+TEST(Tensor_Float, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<float>({0, 1, 2, 3, 4, 5}, {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<float>() = t1.flat<float>() * 2.0f;
+ Tensor t3 = test::AsTensor<float>({0, 2, 4, 6, 8, 10}, t1.shape());
+ test::ExpectTensorEqual<float>(t2, t3);
+}
+
+TEST(Tensor_Int32, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<int32>({0, 1, 2, 3, 4, 5}, {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<int32>() = t1.flat<int32>() * 2;
+ Tensor t3 = test::AsTensor<int32>({0, 2, 4, 6, 8, 10}, t1.shape());
+ test::ExpectTensorEqual<int32>(t2, t3);
+}
+
+TEST(Tensor_QInt8, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<qint8>({0, 1, 2, 3, 4, 5}, {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<qint8>() = t1.flat<qint8>() + qint8(-2);
+ Tensor t3 = test::AsTensor<qint8>({-2, -1, 0, 1, 2, 3}, {2, 3});
+ test::ExpectTensorEqual<qint8>(t2, t3);
+}
+
+TEST(Tensor_QUInt8, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<quint8>({0, 1, 2, 3, 4, 5}, {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<quint8>() = t1.flat<quint8>() + quint8(2);
+ Tensor t3 = test::AsTensor<quint8>({2, 3, 4, 5, 6, 7}, {2, 3});
+ test::ExpectTensorEqual<quint8>(t2, t3);
+}
+
+TEST(Tensor_Int64, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<int64>(
+ {0LL << 48, 1LL << 48, 2LL << 48, 3LL << 48, 4LL << 48, 5LL << 48},
+ {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<int64>() = t1.flat<int64>() * static_cast<int64>(2);
+ Tensor t3 = test::AsTensor<int64>(
+ {0LL << 48, 2LL << 48, 4LL << 48, 6LL << 48, 8LL << 48, 10LL << 48},
+ {2, 3});
+ test::ExpectTensorEqual<int64>(t2, t3);
+}
+
+TEST(Tensor_String, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<string>({"0", "1", "2", "3", "4", "5"}, {2, 3});
+ Tensor t2(DT_STRING, {2, 3});
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ t2.matrix<string>()(i, j) = strings::StrCat(i * 3 + j);
+ }
+ }
+
+ // Test with helper.
+ test::ExpectTensorEqual<string>(t1, t2);
+}
+
+TEST(Tensor_Bool, SimpleWithHelper) {
+ Tensor t1 =
+ test::AsTensor<bool>({false, true, false, true, false, true}, {2, 3});
+
+ Tensor t2(DT_BOOL, {2, 3});
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ t2.matrix<bool>()(i, j) = (((i + j) % 2) != 0);
+ }
+ }
+
+ // Test with helper.
+ test::ExpectTensorEqual<bool>(t1, t2);
+}
+
+TEST(Tensor_Complex, Simple) {
+ Tensor t(DT_COMPLEX64, {4, 5, 3, 7});
+ t.flat<complex64>().setRandom();
+ TestCopies<complex64>(t);
+}
+
+TEST(Tensor_Complex, SimpleWithHelper) {
+ {
+ Tensor t1 = test::AsTensor<complex64>({0,
+ {1, 1},
+ complex64(2),
+ complex64(3, 3),
+ complex64(0, 4),
+ complex64(2, 5)},
+ {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<complex64>() = t1.flat<complex64>() * complex64(0, 2);
+ Tensor t3 = test::AsTensor<complex64>(
+ {0, {-2, 2}, {0, 4}, {-6, 6}, {-8, 0}, {-10, 4}},
+ // shape
+ {2, 3});
+ test::ExpectTensorEqual<complex64>(t2, t3);
+ }
+
+ // Does some numeric operations for complex numbers.
+ {
+ const float PI = std::acos(-1);
+ const complex64 rotate_45 = std::polar(1.0f, PI / 4);
+
+ // x contains all the 8-th root of unity.
+ Tensor x(DT_COMPLEX64, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ x.vec<complex64>()(i) = std::pow(rotate_45, i);
+ }
+
+ // Shift the roots by 45 degree.
+ Tensor y(DT_COMPLEX64, TensorShape({8}));
+ y.vec<complex64>() = x.vec<complex64>() * rotate_45;
+ Tensor y_expected(DT_COMPLEX64, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ y_expected.vec<complex64>()(i) = std::pow(rotate_45, i + 1);
+ }
+ test::ExpectTensorNear<complex64>(y, y_expected, 1e-5);
+
+ // Raise roots to the power of 8.
+ Tensor z(DT_COMPLEX64, TensorShape({8}));
+ z.vec<complex64>() = x.vec<complex64>().pow(8);
+ Tensor z_expected(DT_COMPLEX64, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ z_expected.vec<complex64>()(i) = 1;
+ }
+ test::ExpectTensorNear<complex64>(z, z_expected, 1e-5);
+ }
+}
+
+// On the alignment.
+//
+// As of 2015/8, tensorflow::Tensor allocates its buffer with 32-byte
+// alignment. The Tensor::tensor/flat/vec/matrix methods require that the
+// buffer satisfies Eigen::Aligned (e.g., 16-byte aligned usually,
+// and 32-byte for AVX). Tensor::Slice requires the caller to ensure
+// its result is aligned if the caller intends to use those methods.
+// In this test case, we simply make sure each slice is 32-byte
+// aligned: sizeof(float) * 4 * 2 = 32.
+TEST(Tensor, Slice_Basic) {
+ Tensor saved;
+ { // General
+ Tensor x(DT_FLOAT, TensorShape({10, 4, 34}));
+ // Fills in known values.
+ for (int i = 0; i < 10; ++i) {
+ x.Slice(i, i + 1).flat<float>().setConstant(i * 1.f);
+ }
+ // A simple slice along dim0.
+ Tensor y = x.Slice(4, 8);
+ EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 4, 34})));
+ auto tx = x.tensor<float, 3>();
+ auto ty = y.tensor<float, 3>();
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 34; ++k) {
+ EXPECT_EQ(ty(i, j, k), 4.0 + i);
+ EXPECT_EQ(&tx(4 + i, j, k), &ty(i, j, k));
+ }
+ }
+ }
+ // A simple slice equivalent to identity.
+ TestCopies<float>(y);
+ y = x.Slice(0, 10);
+ test::ExpectTensorEqual<float>(x, y);
+ EXPECT_EQ(x.flat<float>().data(), y.flat<float>().data());
+
+ // A slice of a slice.
+ auto z = x.Slice(4, 8).Slice(2, 3);
+ auto tz = z.tensor<float, 3>();
+ EXPECT_EQ(1, z.dim_size(0));
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 34; ++k) {
+ EXPECT_EQ(tz(0, j, k), 6.0);
+ }
+ }
+
+ // x and y will be out of scope. But 'saved' should be alive.
+ saved = z;
+ }
+ {
+ EXPECT_EQ(1, saved.dim_size(0));
+ auto tsaved = saved.tensor<float, 3>();
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 34; ++k) {
+ EXPECT_EQ(tsaved(0, j, k), 6.0);
+ }
+ }
+ }
+ { // Empty
+ Tensor x(DT_FLOAT, TensorShape({10, 0, 34}));
+ x.flat<float>().setRandom();
+ Tensor y = x.Slice(4, 8);
+ EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 0, 34})));
+ }
+
+ {
+ // Test unaligned access via a Slice.
+ Tensor x(DT_FLOAT, TensorShape({30}));
+ x.flat<float>().setConstant(0.0);
+
+ // Take an unaligned slice.
+ Tensor y = x.Slice(1, 13);
+ y.unaligned_flat<float>().setConstant(1.0);
+ for (int64 i = 0; i < y.NumElements(); ++i) {
+ EXPECT_EQ(1.0, y.unaligned_flat<float>()(i));
+ }
+ }
+}
+
+static void BM_CreateAndDestroy(int iters) {
+ TensorShape shape({10, 20});
+ while (--iters) {
+ Tensor t(DT_FLOAT, shape);
+ }
+}
+BENCHMARK(BM_CreateAndDestroy);
+
+static void BM_Assign(int iters) {
+ Tensor a(DT_FLOAT, TensorShape({10, 20}));
+ Tensor b(DT_FLOAT, TensorShape({10, 20}));
+ bool a_to_b = true;
+ while (--iters) {
+ if (a_to_b) {
+ b = a;
+ } else {
+ a = b;
+ }
+ a_to_b = !a_to_b;
+ }
+}
+BENCHMARK(BM_Assign);
+
+// Ensure tensor_data() works on empty tensors
+TEST(Tensor, EmptyTensorData) {
+ Tensor empty;
+ EXPECT_EQ(empty.tensor_data().size(), 0);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_testutil.cc b/tensorflow/core/framework/tensor_testutil.cc
new file mode 100644
index 0000000000..b6cd12a864
--- /dev/null
+++ b/tensorflow/core/framework/tensor_testutil.cc
@@ -0,0 +1,43 @@
+#include <cmath>
+#include "tensorflow/core/framework/tensor_testutil.h"
+
+namespace tensorflow {
+namespace test {
+
+template <typename T>
+bool IsClose(const T& x, const T& y, double atol, double rtol) {
+ return fabs(x - y) < atol + rtol * fabs(x);
+}
+
+template <typename T>
+void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
+ auto Tx = x.flat<T>();
+ auto Ty = y.flat<T>();
+ for (int i = 0; i < Tx.size(); ++i) {
+ if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
+ LOG(ERROR) << "x = " << x.DebugString();
+ LOG(ERROR) << "y = " << y.DebugString();
+ LOG(ERROR) << "atol = " << atol << " rtol = " << rtol
+ << " tol = " << atol + rtol * std::fabs(Tx(i));
+ EXPECT_TRUE(false) << i << "-th element is not close " << Tx(i) << " vs. "
+ << Ty(i);
+ }
+ }
+}
+
+void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
+ internal::AssertSameTypeDims(x, y);
+ switch (x.dtype()) {
+ case DT_FLOAT:
+ ExpectClose<float>(x, y, atol, rtol);
+ break;
+ case DT_DOUBLE:
+ ExpectClose<double>(x, y, atol, rtol);
+ break;
+ default:
+ LOG(FATAL) << "Unexpected type : " << DataTypeString(x.dtype());
+ }
+}
+
+} // end namespace test
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_testutil.h b/tensorflow/core/framework/tensor_testutil.h
new file mode 100644
index 0000000000..53d6da0fb2
--- /dev/null
+++ b/tensorflow/core/framework/tensor_testutil.h
@@ -0,0 +1,189 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace test {
+
+// Constructs a scalar tensor with 'val'.
+template <typename T>
+Tensor AsScalar(const T& val) {
+ Tensor ret(DataTypeToEnum<T>::value, {});
+ ret.scalar<T>()() = val;
+ return ret;
+}
+
+// Constructs a flat tensor with 'vals'.
+template <typename T>
+Tensor AsTensor(gtl::ArraySlice<T> vals) {
+ Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64>(vals.size())});
+ std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
+ return ret;
+}
+
+// Constructs a tensor of "shape" with values "vals".
+template <typename T>
+Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
+ Tensor ret;
+ CHECK(ret.CopyFrom(AsTensor(vals), shape));
+ return ret;
+}
+
+// Fills in '*tensor' with 'vals'. E.g.,
+// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+// test::FillValues<float>(&x, {11, 21, 21, 22});
+template <typename T>
+void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
+ auto flat = tensor->flat<T>();
+ CHECK_EQ(flat.size(), vals.size());
+ if (flat.size() > 0) {
+ std::copy_n(vals.data(), vals.size(), flat.data());
+ }
+}
+
+// Fills in '*tensor' with a sequence of value of val, val+1, val+2, ...
+// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+// test::FillIota<float>(&x, 1.0);
+template <typename T>
+void FillIota(Tensor* tensor, const T& val) {
+ auto flat = tensor->flat<T>();
+ std::iota(flat.data(), flat.data() + flat.size(), val);
+}
+
+// Fills in '*tensor' with a sequence of value of fn(0), fn(1), ...
+// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+// test::FillFn<float>(&x, [](int i)->float { return i*i; });
+template <typename T>
+void FillFn(Tensor* tensor, std::function<T(int)> fn) {
+ auto flat = tensor->flat<T>();
+ for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
+}
+
+// Expects "x" and "y" are tensors of the same type, same shape, and
+// identical values.
+template <typename T>
+void ExpectTensorEqual(const Tensor& x, const Tensor& y);
+
+// Expects "x" and "y" are tensors of the same type, same shape, and
+// approximately equal values, each within "abs_err".
+template <typename T>
+void ExpectTensorNear(const Tensor& x, const Tensor& y, const T& abs_err);
+
+// Expects "x" and "y" are tensors of the same type (float or double),
+// same shape and element-wise difference between x and y is no more
+// than atol + rtol * abs(x).
+void ExpectClose(const Tensor& x, const Tensor& y, double atol = 1e-6,
+ double rtol = 1e-6);
+
+// Implementation details.
+
+namespace internal {
+
+template <typename T>
+struct is_floating_point_type {
+ static const bool value = std::is_same<T, float>::value ||
+ std::is_same<T, double>::value ||
+ std::is_same<T, std::complex<float> >::value ||
+ std::is_same<T, std::complex<double> >::value;
+};
+
+template <typename T>
+static void ExpectEqual(const T& a, const T& b) {
+ EXPECT_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<float>(const float& a, const float& b) {
+ EXPECT_FLOAT_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<double>(const double& a, const double& b) {
+ EXPECT_DOUBLE_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<complex64>(const complex64& a, const complex64& b) {
+ EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b;
+ EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b;
+}
+
+inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) {
+ ASSERT_EQ(x.dtype(), y.dtype());
+ ASSERT_TRUE(x.IsSameSize(y))
+ << "x.shape [" << x.shape().DebugString() << "] vs "
+ << "y.shape [ " << y.shape().DebugString() << "]";
+}
+
+template <typename T, bool is_fp = is_floating_point_type<T>::value>
+struct Expector;
+
+template <typename T>
+struct Expector<T, false> {
+ static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
+
+ static void Equal(const Tensor& x, const Tensor& y) {
+ ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
+ AssertSameTypeDims(x, y);
+ auto a = x.flat<T>();
+ auto b = y.flat<T>();
+ for (int i = 0; i < a.size(); ++i) {
+ ExpectEqual(a(i), b(i));
+ }
+ }
+};
+
+// Partial specialization for float and double.
+template <typename T>
+struct Expector<T, true> {
+ static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
+
+ static void Equal(const Tensor& x, const Tensor& y) {
+ ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
+ AssertSameTypeDims(x, y);
+ auto a = x.flat<T>();
+ auto b = y.flat<T>();
+ for (int i = 0; i < a.size(); ++i) {
+ ExpectEqual(a(i), b(i));
+ }
+ }
+
+ static void Near(const T& a, const T& b, const double abs_err) {
+ if (a != b) { // Takes care of inf.
+ EXPECT_LE(std::abs(a - b), abs_err) << "a = " << a << " b = " << b;
+ }
+ }
+
+ static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
+ ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
+ AssertSameTypeDims(x, y);
+ auto a = x.flat<T>();
+ auto b = y.flat<T>();
+ for (int i = 0; i < a.size(); ++i) {
+ Near(a(i), b(i), abs_err);
+ }
+ }
+};
+
+} // namespace internal
+
+template <typename T>
+void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
+ internal::Expector<T>::Equal(x, y);
+}
+
+template <typename T>
+void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
+ static_assert(internal::is_floating_point_type<T>::value,
+ "T is not a floating point types.");
+ internal::Expector<T>::Near(x, y, abs_err);
+}
+
+} // namespace test
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
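
The tolerance check above accepts |x - y| < atol + rtol * |x|; a small sketch (values illustrative):

    #include "tensorflow/core/framework/tensor_testutil.h"

    void CloseSketch() {
      tensorflow::Tensor a = tensorflow::test::AsTensor<float>({1.0f, 2.0f});
      tensorflow::Tensor b =
          tensorflow::test::AsTensor<float>({1.0f + 1e-7f, 2.0f});
      // Passes: |1e-7| < 1e-6 + 1e-6 * 1.0 with the default tolerances.
      tensorflow::test::ExpectClose(a, b);
    }
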
diff --git a/tensorflow/core/framework/tensor_types.h b/tensorflow/core/framework/tensor_types.h
new file mode 100644
index 0000000000..077d86d442
--- /dev/null
+++ b/tensorflow/core/framework/tensor_types.h
@@ -0,0 +1,92 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+// Helper to define Tensor types given that the scalar is of type T.
+template <typename T, int NDIMS = 1>
+struct TTypes {
+ // Rank-<NDIMS> tensor of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor>,
+ Eigen::Aligned> Tensor;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor>,
+ Eigen::Aligned> ConstTensor;
+
+ // Unaligned Rank-<NDIMS> tensor of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor> >
+ UnalignedTensor;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor> >
+ UnalignedConstTensor;
+
+ typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>,
+ Eigen::Aligned> Tensor32Bit;
+
+ // Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
+ typedef Eigen::TensorMap<
+ Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor>,
+ Eigen::Aligned> Scalar;
+ typedef Eigen::TensorMap<
+ Eigen::TensorFixedSize<const T, Eigen::Sizes<>, Eigen::RowMajor>,
+ Eigen::Aligned> ConstScalar;
+
+ // Unaligned Scalar tensor of scalar type T.
+ typedef Eigen::TensorMap<Eigen::TensorFixedSize<
+ T, Eigen::Sizes<>, Eigen::RowMajor> > UnalignedScalar;
+ typedef Eigen::TensorMap<Eigen::TensorFixedSize<
+ const T, Eigen::Sizes<>, Eigen::RowMajor> > UnalignedConstScalar;
+
+ // Rank-1 tensor (vector) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, Eigen::Aligned>
+ Flat;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
+ Eigen::Aligned> ConstFlat;
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, Eigen::Aligned>
+ Vec;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
+ Eigen::Aligned> ConstVec;
+
+ // Unaligned Rank-1 tensor (vector) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor> > UnalignedFlat;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor> >
+ UnalignedConstFlat;
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor> > UnalignedVec;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor> >
+ UnalignedConstVec;
+
+ // Rank-2 tensor (matrix) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>
+ Matrix;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Aligned> ConstMatrix;
+
+ // Unaligned Rank-2 tensor (matrix) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor> >
+ UnalignedMatrix;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor> >
+ UnalignedConstMatrix;
+};
+
+typedef typename TTypes<float, 1>::Tensor32Bit::Index Index32;
+
+template <typename DSizes>
+Eigen::DSizes<Index32, DSizes::count> To32BitDims(const DSizes& in) {
+ Eigen::DSizes<Index32, DSizes::count> out;
+ for (int i = 0; i < DSizes::count; ++i) {
+ out[i] = in[i];
+ }
+ return out;
+}
+
+template <typename TensorType>
+typename TTypes<typename TensorType::Scalar,
+ TensorType::NumIndices>::Tensor32Bit
+To32Bit(TensorType in) {
+ typedef typename TTypes<typename TensorType::Scalar,
+ TensorType::NumIndices>::Tensor32Bit RetType;
+ return RetType(in.data(), To32BitDims(in.dimensions()));
+}
+
+} // namespace tensorflow
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
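
A brief sketch of mapping a Tensor's buffer through these typedefs and remapping it with 32-bit indices via To32Bit() (the function name is illustrative):

    #include "tensorflow/core/public/tensor.h"

    void Index32Sketch() {
      tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 3}));
      tensorflow::TTypes<float, 2>::Tensor m = t.tensor<float, 2>();
      // Same buffer, reinterpreted with int indices (useful where 64-bit
      // indexing is unnecessary or costly).
      auto m32 = tensorflow::To32Bit(m);
      m32(1, 2) = 42.0f;
    }
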
diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc
new file mode 100644
index 0000000000..7353191c74
--- /dev/null
+++ b/tensorflow/core/framework/tensor_util.cc
@@ -0,0 +1,28 @@
+#include "tensorflow/core/framework/tensor_util.h"
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace tensor {
+
+Tensor DeepCopy(const Tensor& other) {
+ Tensor tmp = Tensor(other.dtype(), other.shape());
+ if (DataTypeCanUseMemcpy(other.dtype())) {
+ StringPiece other_data = other.tensor_data();
+
+ // We use StringPiece as a convenient map over the tensor buffer,
+ // but we cast the type to get to the underlying buffer to do the
+ // copy.
+ StringPiece tmp_data = tmp.tensor_data();
+ memcpy(const_cast<char*>(tmp_data.data()), other_data.data(),
+ other_data.size());
+ } else {
+ CHECK_EQ(DT_STRING, other.dtype());
+ tmp.flat<string>() = other.flat<string>();
+ }
+ return tmp;
+}
+
+} // namespace tensor
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
new file mode 100644
index 0000000000..a8dde1d0ca
--- /dev/null
+++ b/tensorflow/core/framework/tensor_util.h
@@ -0,0 +1,21 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
+
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace tensor {
+
+// DeepCopy returns a tensor whose contents are a deep copy of the
+// contents of 'other'. This function is intended only for
+// convenience, not speed.
+//
+// REQUIRES: 'other' must point to data stored in CPU memory.
+// REQUIRES: 'other' must be a Tensor of a copy-able type if
+// 'other' is not appropriately memory-aligned.
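+//
+// Example (sketch):
+//   Tensor a(DT_FLOAT, TensorShape({2}));
+//   a.flat<float>().setConstant(1.0f);
+//   Tensor b = DeepCopy(a);      // 'b' owns a separate buffer, so
+//   b.flat<float>()(0) = 2.0f;   // writing to 'b' leaves 'a' unchanged.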
+Tensor DeepCopy(const Tensor& other);
+
+} // namespace tensor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc
new file mode 100644
index 0000000000..fef7468151
--- /dev/null
+++ b/tensorflow/core/framework/tensor_util_test.cc
@@ -0,0 +1,124 @@
+#include "tensorflow/core/framework/tensor_util.h"
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+TEST(TensorUtil, DeepCopy0d) {
+ Tensor x(DT_FLOAT, TensorShape({}));
+ x.scalar<float>()() = 10.0;
+
+ // Make y a deep copy of x and then change it.
+ Tensor y = tensor::DeepCopy(x);
+ y.scalar<float>()() = 20.0;
+
+  // x doesn't change.
+ EXPECT_EQ(10.0, x.scalar<float>()());
+
+ // Change x.
+ x.scalar<float>()() = 30.0;
+
+  // y doesn't change.
+ EXPECT_EQ(20.0, y.scalar<float>()());
+
+ Tensor z = tensor::DeepCopy(y);
+
+ // Change y.
+ y.scalar<float>()() = 40.0;
+
+ // The final states should all be different.
+ EXPECT_EQ(20.0, z.scalar<float>()());
+ EXPECT_EQ(30.0, x.scalar<float>()());
+ EXPECT_EQ(40.0, y.scalar<float>()());
+
+ // Should have the same shape and type.
+ EXPECT_EQ(TensorShape({}), x.shape());
+ EXPECT_EQ(TensorShape({}), y.shape());
+ EXPECT_EQ(TensorShape({}), z.shape());
+
+ EXPECT_EQ(DT_FLOAT, x.dtype());
+ EXPECT_EQ(DT_FLOAT, y.dtype());
+ EXPECT_EQ(DT_FLOAT, z.dtype());
+}
+
+TEST(TensorUtil, DeepCopy) {
+ Tensor x(DT_FLOAT, TensorShape({1}));
+ x.flat<float>()(0) = 10.0;
+
+ // Make y a deep copy of x and then change it.
+ Tensor y = tensor::DeepCopy(x);
+ y.flat<float>()(0) = 20.0;
+
+  // x doesn't change.
+ EXPECT_EQ(10.0, x.flat<float>()(0));
+
+ // Change x.
+ x.flat<float>()(0) = 30.0;
+
+  // y doesn't change.
+ EXPECT_EQ(20.0, y.flat<float>()(0));
+
+ Tensor z = tensor::DeepCopy(y);
+
+ // Change y.
+ y.flat<float>()(0) = 40.0;
+
+ // The final states should all be different.
+ EXPECT_EQ(20.0, z.flat<float>()(0));
+ EXPECT_EQ(30.0, x.flat<float>()(0));
+ EXPECT_EQ(40.0, y.flat<float>()(0));
+
+ // Should have the same shape and type.
+ EXPECT_EQ(TensorShape({1}), x.shape());
+ EXPECT_EQ(TensorShape({1}), y.shape());
+ EXPECT_EQ(TensorShape({1}), z.shape());
+
+ EXPECT_EQ(DT_FLOAT, x.dtype());
+ EXPECT_EQ(DT_FLOAT, y.dtype());
+ EXPECT_EQ(DT_FLOAT, z.dtype());
+
+ // Test string deep copy
+ Tensor str1(DT_STRING, TensorShape({2}));
+ str1.flat<string>()(0) = "foo1";
+ str1.flat<string>()(1) = "foo2";
+ Tensor str2 = tensor::DeepCopy(str1);
+ str2.flat<string>()(0) = "bar1";
+ str2.flat<string>()(1) = "bar2";
+ EXPECT_NE(str2.flat<string>()(0), str1.flat<string>()(0));
+}
+
+TEST(TensorUtil, DeepCopySlice) {
+ Tensor x(DT_INT32, TensorShape({10}));
+ x.flat<int32>().setConstant(1);
+
+ // Slice 'x' -- y still refers to the same buffer.
+ Tensor y = x.Slice(2, 6);
+
+ // Do a deep copy of y, which is a slice.
+ Tensor z = tensor::DeepCopy(y);
+
+ // Set x to be different.
+ x.flat<int32>().setConstant(2);
+
+ EXPECT_EQ(TensorShape({10}), x.shape());
+ EXPECT_EQ(TensorShape({4}), y.shape());
+ EXPECT_EQ(TensorShape({4}), z.shape());
+ EXPECT_EQ(DT_INT32, x.dtype());
+ EXPECT_EQ(DT_INT32, y.dtype());
+ EXPECT_EQ(DT_INT32, z.dtype());
+
+  // x and y should now both be '2', but z should still be '1'.
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(2, x.flat<int32>()(i));
+ }
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(2, y.unaligned_flat<int32>()(i));
+ EXPECT_EQ(1, z.flat<int32>()(i));
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tracking_allocator.cc b/tensorflow/core/framework/tracking_allocator.cc
new file mode 100644
index 0000000000..78311ded19
--- /dev/null
+++ b/tensorflow/core/framework/tracking_allocator.cc
@@ -0,0 +1,100 @@
+#include "tensorflow/core/framework/tracking_allocator.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+TrackingAllocator::TrackingAllocator(Allocator* allocator)
+ : allocator_(allocator),
+ ref_(1),
+ allocated_(0),
+ high_watermark_(0),
+ total_bytes_(0) {}
+
+void* TrackingAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+ void* ptr = allocator_->AllocateRaw(alignment, num_bytes);
+  // If memory is exhausted, AllocateRaw returns nullptr, and we
+  // pass this through to the caller.
+ if (nullptr == ptr) {
+ return ptr;
+ }
+ if (allocator_->TracksAllocationSizes()) {
+ size_t allocated_bytes = allocator_->AllocatedSize(ptr);
+ {
+ mutex_lock lock(mu_);
+ allocated_ += allocated_bytes;
+ high_watermark_ = std::max(high_watermark_, allocated_);
+ total_bytes_ += allocated_bytes;
+ ++ref_;
+ }
+ } else {
+ mutex_lock lock(mu_);
+ total_bytes_ += num_bytes;
+ ++ref_;
+ }
+ return ptr;
+}
+
+void TrackingAllocator::DeallocateRaw(void* ptr) {
+  // Freeing a null pointer is a no-op.
+ if (nullptr == ptr) {
+ return;
+ }
+ bool should_delete;
+  // Fetch the following outside the lock in case the call to
+  // AllocatedSize is slow.
+ bool tracks_allocation_sizes = allocator_->TracksAllocationSizes();
+ size_t allocated_bytes = 0;
+ if (tracks_allocation_sizes) {
+ allocated_bytes = allocator_->AllocatedSize(ptr);
+ }
+ Allocator* allocator = allocator_;
+ {
+ mutex_lock lock(mu_);
+ if (tracks_allocation_sizes) {
+ CHECK_GE(allocated_, allocated_bytes);
+ allocated_ -= allocated_bytes;
+ }
+ should_delete = UnRef();
+ }
+ allocator->DeallocateRaw(ptr);
+ if (should_delete) {
+ delete this;
+ }
+}
+
+bool TrackingAllocator::TracksAllocationSizes() {
+ return allocator_->TracksAllocationSizes();
+}
+
+size_t TrackingAllocator::RequestedSize(void* ptr) {
+ return allocator_->RequestedSize(ptr);
+}
+
+size_t TrackingAllocator::AllocatedSize(void* ptr) {
+ return allocator_->AllocatedSize(ptr);
+}
+
+std::pair<size_t, size_t> TrackingAllocator::GetSizesAndUnRef() {
+ size_t high_watermark;
+ size_t total_bytes;
+ bool should_delete;
+ {
+ mutex_lock lock(mu_);
+ high_watermark = high_watermark_;
+ total_bytes = total_bytes_;
+ should_delete = UnRef();
+ }
+ if (should_delete) {
+ delete this;
+ }
+ return std::make_pair(total_bytes, high_watermark);
+}
+
+bool TrackingAllocator::UnRef() {
+ CHECK_GE(ref_, 1);
+ --ref_;
+ return (ref_ == 0);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h
new file mode 100644
index 0000000000..f809e3822c
--- /dev/null
+++ b/tensorflow/core/framework/tracking_allocator.h
@@ -0,0 +1,80 @@
+#ifndef TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
+#define TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+// TrackingAllocator is a wrapper for an Allocator. It keeps a running
+// count of the number of bytes allocated through the wrapper. It is
+// used by the Executor to "charge" allocations to particular Op
+// executions. Each Op gets a separate TrackingAllocator wrapper
+// around the underlying allocator.
+//
+// The implementation assumes the invariant that all calls to
+// AllocateRaw by an Op (or work items spawned by the Op) will occur
+// before the Op's Compute method returns. Thus the high watermark is
+// established once Compute returns.
+//
+// DeallocateRaw can be called long after the Op has finished,
+// e.g. when an output tensor is deallocated, and the wrapper cannot
+// be deleted until the last of these calls has occurred. The
+// TrackingAllocator keeps track of outstanding calls using a
+// reference count, and deletes itself once the last call has been
+// received and the high watermark has been retrieved.
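+//
+// Typical use (an illustrative sketch, not actual Executor code;
+// 'underlying' stands for whatever allocator the device provides):
+//
+//   TrackingAllocator* ta = new TrackingAllocator(underlying);
+//   // ... the Op allocates and frees through 'ta' during Compute() ...
+//   std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
+//   // 'ta' deletes itself once all outstanding allocations are freed.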
+class TrackingAllocator : public Allocator {
+ public:
+ explicit TrackingAllocator(Allocator* allocator);
+ string Name() override { return allocator_->Name(); }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+ bool TracksAllocationSizes() override;
+ size_t RequestedSize(void* ptr) override;
+ size_t AllocatedSize(void* ptr) override;
+
+ // If the underlying allocator tracks allocation sizes, this returns
+ // a pair where the first value is the total number of bytes
+ // allocated through this wrapper, and the second value is the high
+ // watermark of bytes allocated through this wrapper. If the
+  // underlying allocator does not track allocation sizes, the first
+ // value is the total number of bytes requested through this wrapper
+ // and the second is 0.
+ //
+  // After GetSizesAndUnRef is called, the only further calls allowed
+ // on this wrapper are calls to DeallocateRaw with pointers that
+ // were allocated by this wrapper and have not yet been
+ // deallocated. After this call completes and all allocated pointers
+ // have been deallocated the wrapper will delete itself.
+ std::pair<size_t, size_t> GetSizesAndUnRef();
+
+ private:
+ ~TrackingAllocator() override {}
+ bool UnRef() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ Allocator* allocator_; // not owned.
+ mutex mu_;
+ // the number of calls to AllocateRaw that have not yet been matched
+  // by a corresponding call to DeallocateRaw, plus 1 if the Executor
+ // has not yet read out the high watermark.
+ int ref_ GUARDED_BY(mu_);
+ // the current number of outstanding bytes that have been allocated
+ // by this wrapper, or 0 if the underlying allocator does not track
+ // allocation sizes.
+ size_t allocated_ GUARDED_BY(mu_);
+ // the maximum number of outstanding bytes that have been allocated
+ // by this wrapper, or 0 if the underlying allocator does not track
+ // allocation sizes.
+ size_t high_watermark_ GUARDED_BY(mu_);
+ // the total number of bytes that have been allocated by this
+ // wrapper if the underlying allocator tracks allocation sizes,
+  // otherwise the total number of bytes that have been requested
+  // through this wrapper.
+ size_t total_bytes_ GUARDED_BY(mu_);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
diff --git a/tensorflow/core/framework/tracking_allocator_test.cc b/tensorflow/core/framework/tracking_allocator_test.cc
new file mode 100644
index 0000000000..90ce851775
--- /dev/null
+++ b/tensorflow/core/framework/tracking_allocator_test.cc
@@ -0,0 +1,115 @@
+#include "tensorflow/core/framework/tracking_allocator.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+class TestableSizeTrackingAllocator : public Allocator {
+ public:
+ string Name() override { return "test"; }
+ void* AllocateRaw(size_t /*alignment*/, size_t num_bytes) override {
+ void* ptr = malloc(num_bytes);
+ size_map_[ptr] = num_bytes;
+ return ptr;
+ }
+ void DeallocateRaw(void* ptr) override {
+ const auto& iter = size_map_.find(ptr);
+ EXPECT_NE(size_map_.end(), iter);
+ size_map_.erase(iter);
+ free(ptr);
+ }
+ bool TracksAllocationSizes() override { return true; }
+ size_t RequestedSize(void* ptr) override {
+ const auto& iter = size_map_.find(ptr);
+ EXPECT_NE(size_map_.end(), iter);
+ return iter->second;
+ }
+
+ private:
+ std::unordered_map<void*, size_t> size_map_;
+};
+
+class NoMemoryAllocator : public Allocator {
+ public:
+ string Name() override { return "test"; }
+ void* AllocateRaw(size_t /*alignment*/, size_t num_bytes) override {
+ return nullptr;
+ }
+ void DeallocateRaw(void* ptr) override {}
+ bool TracksAllocationSizes() override { return true; }
+};
+
+TEST(TrackingAllocatorTest, SimpleNoTracking) {
+ Allocator* a = cpu_allocator();
+
+ EXPECT_FALSE(a->TracksAllocationSizes());
+
+ TrackingAllocator* ta = new TrackingAllocator(a);
+
+ void* p1 = ta->AllocateRaw(4, 4);
+ ta->Deallocate(p1);
+ void* p2 = ta->AllocateRaw(4, 12);
+
+ std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
+
+ EXPECT_EQ(16, sizes.first);
+ EXPECT_EQ(0, sizes.second);
+
+ ta->Deallocate(p2);
+}
+
+TEST(TrackingAllocatorTest, SimpleTracking) {
+ TestableSizeTrackingAllocator a = TestableSizeTrackingAllocator();
+
+ EXPECT_TRUE(a.TracksAllocationSizes());
+
+ TrackingAllocator* ta = new TrackingAllocator(&a);
+
+ void* p1 = ta->AllocateRaw(4, 12);
+ ta->Deallocate(p1);
+ void* p2 = ta->AllocateRaw(4, 4);
+
+ std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
+
+ EXPECT_EQ(16, sizes.first);
+ EXPECT_EQ(12, sizes.second);
+
+ ta->Deallocate(p2);
+}
+
+TEST(TrackingAllocatorTest, OutOfMemory) {
+ NoMemoryAllocator a;
+
+ EXPECT_TRUE(a.TracksAllocationSizes());
+
+ TrackingAllocator* ta = new TrackingAllocator(&a);
+
+ void* p1 = ta->AllocateRaw(4, 12);
+ EXPECT_EQ(nullptr, p1);
+
+ std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
+
+ EXPECT_EQ(0, sizes.first);
+ EXPECT_EQ(0, sizes.second);
+}
+
+TEST(TrackingAllocatorTest, FreeNullPtr) {
+ NoMemoryAllocator a;
+
+ EXPECT_TRUE(a.TracksAllocationSizes());
+
+ TrackingAllocator* ta = new TrackingAllocator(&a);
+
+ ta->DeallocateRaw(nullptr);
+
+ std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
+
+ EXPECT_EQ(0, sizes.first);
+ EXPECT_EQ(0, sizes.second);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/type_traits.h b/tensorflow/core/framework/type_traits.h
new file mode 100644
index 0000000000..d87b6ff49b
--- /dev/null
+++ b/tensorflow/core/framework/type_traits.h
@@ -0,0 +1,69 @@
+#ifndef TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
+#define TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
+
+#include <limits>
+#include <utility>
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Helper types to define the quantization attribute of types.
+struct true_type {
+ static const bool value = true;
+};
+struct false_type {
+ static const bool value = false;
+};
+
+// Default is_quantized is false.
+template <typename T>
+struct is_quantized : false_type {};
+
+// Specialize the quantized types.
+template <>
+struct is_quantized<qint8> : true_type {};
+template <>
+struct is_quantized<quint8> : true_type {};
+template <>
+struct is_quantized<qint32> : true_type {};
+
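+// For example (sketch), both of these hold at compile time:
+//   static_assert(is_quantized<qint8>::value, "qint8 is quantized");
+//   static_assert(!is_quantized<int8>::value, "int8 is not quantized");
+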
+// All types not specialized are marked invalid.
+template <class T>
+struct IsValidDataType {
+ static constexpr bool value = false;
+};
+
+// Extra validity checking; not part of public API.
+struct TestIsValidDataType {
+ static_assert(IsValidDataType<int64>::value, "Incorrect impl for int64");
+ static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32");
+};
+
+} // namespace tensorflow
+
+// Define numeric limits for our quantized types as subclasses of the
+// standard types.
+namespace std {
+template <>
+class numeric_limits<tensorflow::qint8>
+ : public numeric_limits<tensorflow::int8> {};
+template <>
+class numeric_limits<tensorflow::quint8>
+ : public numeric_limits<tensorflow::uint8> {};
+template <>
+class numeric_limits<tensorflow::qint32>
+ : public numeric_limits<tensorflow::int32> {};
+
+// Specialize is_signed for quantized types.
+template <>
+struct is_signed<tensorflow::qint8> : public is_signed<tensorflow::int8> {};
+template <>
+struct is_signed<tensorflow::quint8> : public is_signed<tensorflow::uint8> {};
+template <>
+struct is_signed<tensorflow::qint32> : public is_signed<tensorflow::int32> {};
+
+} // namespace std
+
+#endif // TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
new file mode 100644
index 0000000000..01b9fca3b6
--- /dev/null
+++ b/tensorflow/core/framework/types.cc
@@ -0,0 +1,210 @@
+#include "tensorflow/core/framework/types.h"
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+bool DeviceType::operator<(const DeviceType& other) const {
+ return type_ < other.type_;
+}
+
+bool DeviceType::operator==(const DeviceType& other) const {
+ return type_ == other.type_;
+}
+
+std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
+ os << d.type();
+ return os;
+}
+
+const char* const DEVICE_CPU = "CPU";
+const char* const DEVICE_GPU = "GPU";
+
+string DataTypeString(DataType dtype) {
+ if (IsRefType(dtype)) {
+ DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
+ return strings::StrCat(DataTypeString(non_ref), "_ref");
+ }
+ switch (dtype) {
+ case DT_INVALID:
+ return "INVALID";
+ case DT_FLOAT:
+ return "float";
+ case DT_DOUBLE:
+ return "double";
+ case DT_INT32:
+ return "int32";
+ case DT_UINT8:
+ return "uint8";
+ case DT_INT16:
+ return "int16";
+ case DT_INT8:
+ return "int8";
+ case DT_STRING:
+ return "string";
+ case DT_COMPLEX64:
+ return "complex64";
+ case DT_INT64:
+ return "int64";
+ case DT_BOOL:
+ return "bool";
+ case DT_QINT8:
+ return "qint8";
+ case DT_QUINT8:
+ return "quint8";
+ case DT_QINT32:
+ return "qint32";
+ case DT_BFLOAT16:
+ return "bfloat16";
+ default:
+ LOG(FATAL) << "Unrecognized DataType enum value " << dtype;
+ return "";
+ }
+}
+
+bool DataTypeFromString(StringPiece sp, DataType* dt) {
+ if (sp.ends_with("_ref")) {
+ sp.remove_suffix(4);
+ DataType non_ref;
+ if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
+ *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ if (sp == "float" || sp == "float32") {
+ *dt = DT_FLOAT;
+ return true;
+ } else if (sp == "double" || sp == "float64") {
+ *dt = DT_DOUBLE;
+ return true;
+ } else if (sp == "int32") {
+ *dt = DT_INT32;
+ return true;
+ } else if (sp == "uint8") {
+ *dt = DT_UINT8;
+ return true;
+ } else if (sp == "int16") {
+ *dt = DT_INT16;
+ return true;
+ } else if (sp == "int8") {
+ *dt = DT_INT8;
+ return true;
+ } else if (sp == "string") {
+ *dt = DT_STRING;
+ return true;
+ } else if (sp == "complex64") {
+ *dt = DT_COMPLEX64;
+ return true;
+ } else if (sp == "int64") {
+ *dt = DT_INT64;
+ return true;
+ } else if (sp == "bool") {
+ *dt = DT_BOOL;
+ return true;
+ } else if (sp == "qint8") {
+ *dt = DT_QINT8;
+ return true;
+ } else if (sp == "quint8") {
+ *dt = DT_QUINT8;
+ return true;
+ } else if (sp == "qint32") {
+ *dt = DT_QINT32;
+ return true;
+ } else if (sp == "bfloat16") {
+ *dt = DT_BFLOAT16;
+ return true;
+ }
+ return false;
+}
+
+string DeviceTypeString(DeviceType device_type) { return device_type.type(); }
+
+string DataTypeSliceString(const DataTypeSlice types) {
+ string out;
+ for (auto it = types.begin(); it != types.end(); ++it) {
+ strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
+ DataTypeString(*it));
+ }
+ return out;
+}
+
+DataTypeVector AllTypes() {
+ return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
+ DT_INT8, DT_STRING, DT_COMPLEX64, DT_INT64, DT_BOOL,
+ DT_QINT8, DT_QUINT8, DT_QINT32};
+}
+
+#ifndef __ANDROID__
+
+DataTypeVector RealNumberTypes() {
+ return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, DT_INT8};
+}
+
+DataTypeVector QuantizedTypes() { return {DT_QINT8, DT_QUINT8, DT_QINT32}; }
+
+DataTypeVector RealAndQuantizedTypes() {
+ return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8,
+ DT_INT16, DT_INT8, DT_QINT8, DT_QUINT8, DT_QINT32};
+}
+
+DataTypeVector NumberTypes() {
+ return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16,
+ DT_INT8, DT_COMPLEX64, DT_QINT8, DT_QUINT8, DT_QINT32};
+}
+
+#else // __ANDROID__
+
+DataTypeVector RealNumberTypes() { return {DT_FLOAT, DT_INT32}; }
+
+DataTypeVector NumberTypes() {
+ return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32};
+}
+
+DataTypeVector QuantizedTypes() { return {DT_QINT8, DT_QUINT8, DT_QINT32}; }
+
+DataTypeVector RealAndQuantizedTypes() {
+ return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32};
+}
+
+#endif // __ANDROID__
+
+// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying
+// is_simple<T> in tensor.cc (and possibly choose a more general name?)
+bool DataTypeCanUseMemcpy(DataType dt) {
+ switch (dt) {
+ case DT_FLOAT:
+ case DT_DOUBLE:
+ case DT_INT32:
+ case DT_UINT8:
+ case DT_INT16:
+ case DT_INT8:
+ case DT_COMPLEX64:
+ case DT_INT64:
+ case DT_BOOL:
+ case DT_QINT8:
+ case DT_QUINT8:
+ case DT_QINT32:
+ case DT_BFLOAT16:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool DataTypeIsQuantized(DataType dt) {
+ switch (dt) {
+ case DT_QINT8:
+ case DT_QUINT8:
+ case DT_QINT32:
+ return true;
+ default:
+ return false;
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
new file mode 100644
index 0000000000..2d417cf076
--- /dev/null
+++ b/tensorflow/core/framework/types.h
@@ -0,0 +1,168 @@
+#ifndef TENSORFLOW_FRAMEWORK_TYPES_H_
+#define TENSORFLOW_FRAMEWORK_TYPES_H_
+
+#include <map>
+#include <set>
+#include <string>
+
+#include "tensorflow/core/framework/bfloat16.h"
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
+
+namespace tensorflow {
+
+// MemoryType is used to describe whether input or output Tensors of
+// an OpKernel should reside in "Host memory" (e.g., CPU memory) or
+// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU
+// devices).
+enum MemoryType {
+ DEVICE_MEMORY = 0,
+ HOST_MEMORY = 1,
+};
+
+// A DeviceType is just a string, but we wrap it up in a class to give
+// some type checking as we're passing these around.
+class DeviceType {
+ public:
+ DeviceType(const char* type) // NOLINT(runtime/explicit)
+ : type_(type) {}
+
+ explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {}
+
+ const char* type() const { return type_.c_str(); }
+
+ bool operator<(const DeviceType& other) const;
+ bool operator==(const DeviceType& other) const;
+ bool operator!=(const DeviceType& other) const { return !(*this == other); }
+
+ private:
+ string type_;
+};
+std::ostream& operator<<(std::ostream& os, const DeviceType& d);
+
+// Convenient constants that can be passed to a DeviceType constructor
+extern const char* const DEVICE_CPU; // "CPU"
+extern const char* const DEVICE_GPU; // "GPU"
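+
+// For example (sketch):
+//   DeviceType d(DEVICE_GPU);       // Implicit from const char*.
+//   CHECK(d != DeviceType("CPU"));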
+
+typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
+
+typedef gtl::InlinedVector<DataType, 4> DataTypeVector;
+typedef gtl::ArraySlice<DataType> DataTypeSlice;
+
+typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector;
+
+// Convert the enums to strings for errors:
+string DataTypeString(DataType dtype);
+string DeviceTypeString(DeviceType device_type);
+string DataTypeSliceString(const DataTypeSlice dtypes);
+inline string DataTypeVectorString(const DataTypeVector& dtypes) {
+ return DataTypeSliceString(dtypes);
+}
+
+// If "sp" names a valid type, store it in "*dt" and return true. Otherwise,
+// return false.
+bool DataTypeFromString(StringPiece sp, DataType* dt);
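+//
+// For example (sketch):
+//   DataType dt;
+//   CHECK(DataTypeFromString("float32", &dt));       // dt == DT_FLOAT
+//   CHECK(!DataTypeFromString("no_such_type", &dt));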
+
+// DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc.
+enum { kDataTypeRefOffset = 100 };
+inline bool IsRefType(DataType dtype) {
+ return dtype > static_cast<DataType>(kDataTypeRefOffset);
+}
+inline DataType MakeRefType(DataType dtype) {
+ DCHECK(!IsRefType(dtype));
+ return static_cast<DataType>(dtype + kDataTypeRefOffset);
+}
+inline DataType RemoveRefType(DataType dtype) {
+ DCHECK(IsRefType(dtype));
+ return static_cast<DataType>(dtype - kDataTypeRefOffset);
+}
+inline DataType BaseType(DataType dtype) {
+ return IsRefType(dtype) ? RemoveRefType(dtype) : dtype;
+}
+
+// Returns true if the actual type is the same as or ref of the expected type.
+inline bool TypesCompatible(DataType expected, DataType actual) {
+ return expected == actual || expected == BaseType(actual);
+}
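+
+// For example (sketch), since DT_FLOAT == 1 and DT_FLOAT_REF == 101:
+//
+//   MakeRefType(DT_FLOAT) == DT_FLOAT_REF
+//   BaseType(DT_FLOAT_REF) == DT_FLOAT
+//   TypesCompatible(DT_FLOAT, DT_FLOAT_REF)  // true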
+
+// Does not include _ref types.
+DataTypeVector AllTypes();
+
+// Return the list of all numeric types.
+// NOTE: On Android, we only include the float and int32 types for now.
+DataTypeVector RealNumberTypes(); // Types that support '<' and '>'.
+DataTypeVector NumberTypes(); // Includes complex and quantized types.
+
+DataTypeVector QuantizedTypes();
+DataTypeVector RealAndQuantizedTypes(); // Types that support '<' and
+ // '>', including quantized
+                                         // types.
+
+// Validates type T for whether it is a supported DataType.
+template <class T>
+struct IsValidDataType;
+
+// DataTypeToEnum<T>::v() and DataTypeToEnum<T>::value are the DataType
+// constants for T, e.g. DataTypeToEnum<float>::v() is DT_FLOAT.
+template <class T>
+struct DataTypeToEnum {
+ static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
+}; // Specializations below
+
+// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
+// EnumToDataType<DT_FLOAT>::Type is float.
+template <DataType VALUE>
+struct EnumToDataType {}; // Specializations below
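+
+// For example (sketch):
+//   static_assert(DataTypeToEnum<float>::value == DT_FLOAT, "");
+//   EnumToDataType<DT_INT32>::Type i = 0;  // 'i' has type int32.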
+
+// Template specialization for both DataTypeToEnum and EnumToDataType.
+#define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \
+ template <> \
+ struct DataTypeToEnum<TYPE> { \
+ static DataType v() { return ENUM; } \
+ static DataType ref() { return MakeRefType(ENUM); } \
+ static constexpr DataType value = ENUM; \
+ }; \
+ template <> \
+ struct IsValidDataType<TYPE> { \
+ static constexpr bool value = true; \
+ }; \
+ template <> \
+ struct EnumToDataType<ENUM> { \
+ typedef TYPE Type; \
+ }
+
+// We use Eigen's QInt implementations for our quantized int types.
+typedef Eigen::QInt8 qint8;
+typedef Eigen::QUInt8 quint8;
+typedef Eigen::QInt32 qint32;
+
+MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
+MATCH_TYPE_AND_ENUM(double, DT_DOUBLE);
+MATCH_TYPE_AND_ENUM(int32, DT_INT32);
+MATCH_TYPE_AND_ENUM(uint8, DT_UINT8);
+MATCH_TYPE_AND_ENUM(int16, DT_INT16);
+MATCH_TYPE_AND_ENUM(int8, DT_INT8);
+MATCH_TYPE_AND_ENUM(string, DT_STRING);
+MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64);
+MATCH_TYPE_AND_ENUM(int64, DT_INT64);
+MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
+MATCH_TYPE_AND_ENUM(qint8, DT_QINT8);
+MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8);
+MATCH_TYPE_AND_ENUM(qint32, DT_QINT32);
+MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16);
+
+#undef MATCH_TYPE_AND_ENUM
+
+bool DataTypeCanUseMemcpy(DataType dt);
+
+bool DataTypeIsQuantized(DataType dt);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TYPES_H_
diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto
new file mode 100644
index 0000000000..e5dc9c45a0
--- /dev/null
+++ b/tensorflow/core/framework/types.proto
@@ -0,0 +1,48 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+enum DataType {
+ // Not a legal value for DataType. Used to indicate a DataType field
+ // has not been set.
+ DT_INVALID = 0;
+
+ // Data types that all computation devices are expected to be
+  // capable of supporting.
+ DT_FLOAT = 1;
+ DT_DOUBLE = 2;
+ DT_INT32 = 3;
+ DT_UINT8 = 4;
+ DT_INT16 = 5;
+ DT_INT8 = 6;
+ DT_STRING = 7;
+ DT_COMPLEX64 = 8; // Single-precision complex
+ DT_INT64 = 9;
+ DT_BOOL = 10;
+ DT_QINT8 = 11; // Quantized int8
+ DT_QUINT8 = 12; // Quantized uint8
+ DT_QINT32 = 13; // Quantized int32
+ DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
+
+ // TODO(josh11b): DT_GENERIC_PROTO = ??;
+ // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? DT_UINT16?
+ // TODO(zhifengc): DT_COMPLEX128 (double-precision complex)?
+
+ // Do not use! These are only for parameters. Every enum above
+ // should have a corresponding value below (verified by types_test).
+ DT_FLOAT_REF = 101;
+ DT_DOUBLE_REF = 102;
+ DT_INT32_REF = 103;
+ DT_UINT8_REF = 104;
+ DT_INT16_REF = 105;
+ DT_INT8_REF = 106;
+ DT_STRING_REF = 107;
+ DT_COMPLEX64_REF = 108;
+ DT_INT64_REF = 109;
+ DT_BOOL_REF = 110;
+ DT_QINT8_REF = 111;
+ DT_QUINT8_REF = 112;
+ DT_QINT32_REF = 113;
+ DT_BFLOAT16_REF = 114;
+}
diff --git a/tensorflow/core/framework/types_test.cc b/tensorflow/core/framework/types_test.cc
new file mode 100644
index 0000000000..eb92600397
--- /dev/null
+++ b/tensorflow/core/framework/types_test.cc
@@ -0,0 +1,117 @@
+#include "tensorflow/core/framework/types.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(TypesTest, DeviceTypeName) {
+ EXPECT_EQ("CPU", DeviceTypeString(DeviceType(DEVICE_CPU)));
+ EXPECT_EQ("GPU", DeviceTypeString(DeviceType(DEVICE_GPU)));
+}
+
+TEST(TypesTest, kDataTypeRefOffset) {
+ // Basic sanity check
+ EXPECT_EQ(DT_FLOAT + kDataTypeRefOffset, DT_FLOAT_REF);
+
+ // Use the meta-data provided by proto2 to iterate through the basic
+ // types and validate that adding kDataTypeRefOffset gives the
+ // corresponding reference type.
+ const auto* enum_descriptor = DataType_descriptor();
+ int e = DataType_MIN;
+ if (e == DT_INVALID) ++e;
+ int e_ref = e + kDataTypeRefOffset;
+ EXPECT_FALSE(DataType_IsValid(e_ref - 1))
+ << "Reference enum "
+ << enum_descriptor->FindValueByNumber(e_ref - 1)->name()
+ << " without corresponding base enum with value " << e - 1;
+ for (;
+ DataType_IsValid(e) && DataType_IsValid(e_ref) && e_ref <= DataType_MAX;
+ ++e, ++e_ref) {
+ string enum_name = enum_descriptor->FindValueByNumber(e)->name();
+ string enum_ref_name = enum_descriptor->FindValueByNumber(e_ref)->name();
+ EXPECT_EQ(enum_name + "_REF", enum_ref_name)
+ << enum_name << "_REF should have value " << e_ref << " not "
+ << enum_ref_name;
+ // Validate DataTypeString() as well.
+ DataType dt_e = static_cast<DataType>(e);
+ DataType dt_e_ref = static_cast<DataType>(e_ref);
+ EXPECT_EQ(DataTypeString(dt_e) + "_ref", DataTypeString(dt_e_ref));
+
+ // Test DataTypeFromString reverse conversion
+ DataType dt_e2, dt_e2_ref;
+ EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e), &dt_e2));
+ EXPECT_EQ(dt_e, dt_e2);
+ EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e_ref), &dt_e2_ref));
+ EXPECT_EQ(dt_e_ref, dt_e2_ref);
+ }
+ ASSERT_FALSE(DataType_IsValid(e))
+ << "Should define " << enum_descriptor->FindValueByNumber(e)->name()
+ << "_REF to be " << e_ref;
+ ASSERT_FALSE(DataType_IsValid(e_ref))
+ << "Extra reference enum "
+ << enum_descriptor->FindValueByNumber(e_ref)->name()
+ << " without corresponding base enum with value " << e;
+ ASSERT_LT(DataType_MAX, e_ref) << "Gap in reference types, missing value for "
+ << e_ref;
+
+ // Make sure there are no enums defined after the last regular type before
+ // the first reference type.
+ for (; e < DataType_MIN + kDataTypeRefOffset; ++e) {
+ EXPECT_FALSE(DataType_IsValid(e))
+ << "Discontinuous enum value "
+ << enum_descriptor->FindValueByNumber(e)->name() << " = " << e;
+ }
+}
+
+TEST(TypesTest, DataTypeFromString) {
+ DataType dt;
+ ASSERT_TRUE(DataTypeFromString("int32", &dt));
+ EXPECT_EQ(DT_INT32, dt);
+ ASSERT_TRUE(DataTypeFromString("int32_ref", &dt));
+ EXPECT_EQ(DT_INT32_REF, dt);
+ EXPECT_FALSE(DataTypeFromString("int32_ref_ref", &dt));
+ EXPECT_FALSE(DataTypeFromString("foo", &dt));
+ EXPECT_FALSE(DataTypeFromString("foo_ref", &dt));
+ ASSERT_TRUE(DataTypeFromString("int64", &dt));
+ EXPECT_EQ(DT_INT64, dt);
+ ASSERT_TRUE(DataTypeFromString("int64_ref", &dt));
+ EXPECT_EQ(DT_INT64_REF, dt);
+ ASSERT_TRUE(DataTypeFromString("quint8_ref", &dt));
+ EXPECT_EQ(DT_QUINT8_REF, dt);
+ ASSERT_TRUE(DataTypeFromString("bfloat16", &dt));
+ EXPECT_EQ(DT_BFLOAT16, dt);
+}
+
+template <typename T>
+static bool GetQuantized() {
+ return is_quantized<T>::value;
+}
+
+TEST(TypesTest, QuantizedTypes) {
+  // NOTE: GUnit cannot parse is_quantized<TYPE>::value within the
+ // EXPECT_TRUE() clause, so we delegate through a template function.
+ EXPECT_TRUE(GetQuantized<qint8>());
+ EXPECT_TRUE(GetQuantized<quint8>());
+ EXPECT_TRUE(GetQuantized<qint32>());
+
+ EXPECT_FALSE(GetQuantized<int8>());
+ EXPECT_FALSE(GetQuantized<uint8>());
+ EXPECT_FALSE(GetQuantized<int16>());
+ EXPECT_FALSE(GetQuantized<int32>());
+
+ EXPECT_TRUE(DataTypeIsQuantized(DT_QINT8));
+ EXPECT_TRUE(DataTypeIsQuantized(DT_QUINT8));
+ EXPECT_TRUE(DataTypeIsQuantized(DT_QINT32));
+
+ EXPECT_FALSE(DataTypeIsQuantized(DT_INT8));
+ EXPECT_FALSE(DataTypeIsQuantized(DT_UINT8));
+ EXPECT_FALSE(DataTypeIsQuantized(DT_INT16));
+ EXPECT_FALSE(DataTypeIsQuantized(DT_INT32));
+ EXPECT_FALSE(DataTypeIsQuantized(DT_BFLOAT16));
+}
+
+} // namespace
+} // namespace tensorflow