Diffstat (limited to 'tensorflow/core/framework')
99 files changed, 17650 insertions, 0 deletions
diff --git a/tensorflow/core/framework/allocation_description.proto b/tensorflow/core/framework/allocation_description.proto
new file mode 100644
index 0000000000..f6f4bc0126
--- /dev/null
+++ b/tensorflow/core/framework/allocation_description.proto
@@ -0,0 +1,15 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+message AllocationDescription {
+  // Total number of bytes requested
+  int64 requested_bytes = 1;
+
+  // Total number of bytes allocated if known
+  int64 allocated_bytes = 2;
+
+  // Name of the allocator used
+  string allocator_name = 3;
+};
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
new file mode 100644
index 0000000000..93f68dcccb
--- /dev/null
+++ b/tensorflow/core/framework/allocator.cc
@@ -0,0 +1,25 @@
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+Allocator::~Allocator() {}
+
+class CPUAllocator : public Allocator {
+ public:
+  ~CPUAllocator() override {}
+
+  string Name() override { return "cpu"; }
+  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+    return port::aligned_malloc(num_bytes, alignment);
+  }
+
+  void DeallocateRaw(void* ptr) override { port::aligned_free(ptr); }
+};
+
+Allocator* cpu_allocator() {
+  static CPUAllocator* cpu_alloc = new CPUAllocator;
+  return cpu_alloc;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
new file mode 100644
index 0000000000..6f162a608c
--- /dev/null
+++ b/tensorflow/core/framework/allocator.h
@@ -0,0 +1,132 @@
+#ifndef TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
+#define TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
+
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <limits>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// Allocator is an abstract interface for allocating and deallocating
+// device memory.
+class Allocator {
+ public:
+  virtual ~Allocator();
+
+  // Returns a string identifying this allocator.
+  virtual string Name() = 0;
+
+  // Returns an uninitialized block of memory that is "num_bytes" bytes
+  // in size. The returned pointer is guaranteed to be aligned to a
+  // multiple of "alignment" bytes.
+  // REQUIRES: "alignment" is a power of 2.
+  virtual void* AllocateRaw(size_t alignment, size_t num_bytes) = 0;
+
+  // Deallocates a block of memory pointed to by "ptr".
+  // REQUIRES: "ptr" was previously returned by a call to AllocateRaw.
+  virtual void DeallocateRaw(void* ptr) = 0;
+
+  // Convenience functions to do typed allocation. Note that these functions
+  // do not invoke C++ constructors or destructors. May return NULL if the
+  // tensor has too many elements to represent in a single allocation.
+  template <typename T>
+  T* Allocate(size_t num_elements) {
+    // TODO(jeff): Do we need to allow clients to pass in alignment
+    // requirements?
+
+    if (num_elements > (std::numeric_limits<size_t>::max() / sizeof(T))) {
+      return NULL;
+    }
+
+    void* p = AllocateRaw(32 /* align to 32 byte boundary */,
+                          sizeof(T) * num_elements);
+    return reinterpret_cast<T*>(p);
+  }
+
+  template <typename T>
+  void Deallocate(T* ptr) {
+    DeallocateRaw(ptr);
+  }
+
+  // Returns true if this allocator tracks the sizes of allocations.
+  // RequestedSize and AllocatedSize must be overridden if
+  // TracksAllocationSizes is overridden to return true.
+ virtual bool TracksAllocationSizes() { return false; } + + // Returns the user-requested size of the data allocated at + // 'ptr'. Note that the actual buffer allocated might be larger + // than requested, but this function returns the size requested by + // the user. + // + // REQUIRES: TracksAllocationSizes() is true. + // + // REQUIRES: 'ptr!=nullptr' and points to a buffer previously + // allocated by this allocator. + virtual size_t RequestedSize(void* ptr) { + CHECK(false) << "allocator doesn't track sizes"; + } + + // Returns the allocated size of the buffer at 'ptr' if known, + // otherwise returns RequestedSize(ptr). AllocatedSize(ptr) is + // guaranteed to be >= RequestedSize(ptr). + // + // REQUIRES: TracksAllocationSizes() is true. + // + // REQUIRES: 'ptr!=nullptr' and points to a buffer previously + // allocated by this allocator. + virtual size_t AllocatedSize(void* ptr) { return RequestedSize(ptr); } + + // TODO(jeff): Maybe provide some interface to give info about + // current allocation state (total number of bytes available for + // allocation, number of bytes free on device, etc.) +}; + +// A tensorflow Op may need access to different kinds of memory that +// are not simply a function of the device to which the Op has been +// assigned. For example, an Op executing on a GPU may still need +// to allocate CPU RAM for some purpose. Internal to the tensorflow +// runtime we may choose to allocate CPU ram from special regions +// that have been prepared for higher performance in some use +// contexts, e.g. doing DMA with particular devices. For these +// reasons, the Device interface does not expose just one memory +// Allocator, but instead provides an accessor that takes a +// specification of the desired memory attributes in order to select +// an Allocator. +// +// NOTE: The upper 8 bits of the value are reserved for +// device-specific uses. Implementors of a device can interpret these +// upper 8 bits in device-specific ways, and ops implemented for those +// devices are responsible for setting those 8 bits appropriately. +// +// Example use: +// // Allocator for ordinary device memory: +// Allocator* a = allocator(AllocatorAttributes()); +// ... +// // Allocator for CPU RAM, regardless of where Op is executing: +// AllocatorAttributes attr; +// attr.set_on_host(true); +// Allocator* a = allocator(attr); +struct AllocatorAttributes { + void set_on_host(bool v) { value |= (static_cast<int>(v)); } + bool on_host() const { return value & 0x1; } + void set_nic_compatible(bool v) { value |= (static_cast<int>(v) << 1); } + bool nic_compatible() const { return value & (0x1 << 1); } + void set_gpu_compatible(bool v) { value |= (static_cast<int>(v) << 2); } + bool gpu_compatible() const { return value & (0x1 << 2); } + + void Merge(AllocatorAttributes other) { value |= other.value; } + + uint32 value = 0; +}; + +// Returns a trivial implementation of Allocator which uses the system +// default malloc. 
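+// For example (an illustrative sketch; allocator_test.cc in this commit
+// exercises the same calls):
+//
+//   Allocator* a = cpu_allocator();
+//   float* buf = a->Allocate<float>(1024);
+//   if (buf != NULL) {
+//     // ... use buf[0] .. buf[1023] ...
+//     a->Deallocate(buf);
+//   }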
+Allocator* cpu_allocator(); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_ALLOCATOR_H_ diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc new file mode 100644 index 0000000000..6b1e52cfc4 --- /dev/null +++ b/tensorflow/core/framework/allocator_test.cc @@ -0,0 +1,61 @@ +#include "tensorflow/core/framework/allocator.h" +#include <algorithm> +#include "tensorflow/core/platform/logging.h" +#include <gtest/gtest.h> +namespace tensorflow { + +TEST(CPUAllocatorTest, Simple) { + Allocator* a = cpu_allocator(); + std::vector<void*> ptrs; + for (int s = 1; s < 1024; s++) { + void* raw = a->AllocateRaw(1, s); + ptrs.push_back(raw); + } + std::sort(ptrs.begin(), ptrs.end()); + for (size_t i = 0; i < ptrs.size(); i++) { + if (i > 0) { + CHECK_NE(ptrs[i], ptrs[i - 1]); // No dups + } + a->DeallocateRaw(ptrs[i]); + } + float* t1 = a->Allocate<float>(1024); + double* t2 = a->Allocate<double>(1048576); + a->Deallocate(t1); + a->Deallocate(t2); +} + +// Define a struct that we will use to observe behavior in the unit tests +struct TestStruct { + int x; // not used just want to make sure sizeof(TestStruct) > 1 +}; + +TEST(CPUAllocatorTest, CheckStructSize) { CHECK_GT(sizeof(TestStruct), 1); } + +TEST(CPUAllocatorTest, AllocateOverflowMaxSizeT) { + Allocator* a = cpu_allocator(); + + // The maximum size_t value will definitely overflow. + size_t count_to_allocate = std::numeric_limits<size_t>::max(); + TestStruct* const test_pointer = a->Allocate<TestStruct>(count_to_allocate); + + CHECK_EQ(test_pointer, reinterpret_cast<TestStruct*>(NULL)); +} + +TEST(CPUAllocatorTest, AllocateOverflowSmallest) { + Allocator* a = cpu_allocator(); + + // count_to_allocate is the smallest count that will cause overflow. + const size_t count_to_allocate = + (std::numeric_limits<size_t>::max() / sizeof(TestStruct)) + 1; + TestStruct* const test_pointer = a->Allocate<TestStruct>(count_to_allocate); + + CHECK_EQ(test_pointer, reinterpret_cast<TestStruct*>(NULL)); +} + +TEST(CPUAllocatorTest, Sizes) { + Allocator* a = cpu_allocator(); + + EXPECT_EQ(false, a->TracksAllocationSizes()); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/attr_value.proto b/tensorflow/core/framework/attr_value.proto new file mode 100644 index 0000000000..c6a9940815 --- /dev/null +++ b/tensorflow/core/framework/attr_value.proto @@ -0,0 +1,57 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + // TODO(zhifengc/josh11b): implements list(func) if needed. 
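+    // As an illustration of the text form (not part of this schema),
+    // a "list(int)" value renders as:
+    //   list { i: [1, 2, 3] }
+    // and a "list(string)" value as:
+    //   list { s: ['foo', 'bar'] }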
+ } + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map<string, AttrValue> attr = 2; +} diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc new file mode 100644 index 0000000000..400ef118b8 --- /dev/null +++ b/tensorflow/core/framework/attr_value_util.cc @@ -0,0 +1,382 @@ +#include "tensorflow/core/framework/attr_value_util.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +namespace { + +string SummarizeString(const string& str) { + return strings::StrCat("\"", str_util::CEscape(str), "\""); +} + +string SummarizeShape(const TensorShapeProto& proto) { + TensorShape shape(proto); + return shape.ShortDebugString(); +} + +string SummarizeTensor(const TensorProto& tensor_proto) { + Tensor t; + if (!t.FromProto(tensor_proto)) { + return strings::StrCat("<Invalid TensorProto: ", + tensor_proto.ShortDebugString(), ">"); + } + return t.DebugString(); +} + +} // namespace + +string SummarizeAttrValue(const AttrValue& attr_value) { + switch (attr_value.value_case()) { + case AttrValue::kS: + return SummarizeString(attr_value.s()); + case AttrValue::kI: + return strings::StrCat(attr_value.i()); + case AttrValue::kF: + return strings::StrCat(attr_value.f()); + case AttrValue::kB: + return attr_value.b() ? 
"true" : "false"; + case AttrValue::kType: + return DataType_Name(attr_value.type()); + case AttrValue::kShape: + return SummarizeShape(attr_value.shape()); + case AttrValue::kTensor: + return SummarizeTensor(attr_value.tensor()); + case AttrValue::kList: { + string ret = "["; + if (attr_value.list().s_size() > 0) { + for (int i = 0; i < attr_value.list().s_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, SummarizeString(attr_value.list().s(i))); + } + } else if (attr_value.list().i_size() > 0) { + for (int i = 0; i < attr_value.list().i_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().i(i)); + } + } else if (attr_value.list().f_size() > 0) { + for (int i = 0; i < attr_value.list().f_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().f(i)); + } + } else if (attr_value.list().b_size() > 0) { + for (int i = 0; i < attr_value.list().b_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false"); + } + } else if (attr_value.list().type_size() > 0) { + for (int i = 0; i < attr_value.list().type_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, DataType_Name(attr_value.list().type(i))); + } + } else if (attr_value.list().shape_size() > 0) { + for (int i = 0; i < attr_value.list().shape_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, SummarizeShape(attr_value.list().shape(i))); + } + } else if (attr_value.list().tensor_size() > 0) { + for (int i = 0; i < attr_value.list().tensor_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, + SummarizeTensor(attr_value.list().tensor(i))); + } + } + strings::StrAppend(&ret, "]"); + return ret; + } + case AttrValue::kFunc: { + std::vector<string> entries; + for (auto p : attr_value.func().attr()) { + entries.push_back( + strings::StrCat(p.first, "=", SummarizeAttrValue(p.second))); + } + sort(entries.begin(), entries.end()); + return strings::StrCat(attr_value.func().name(), "[", + str_util::Join(entries, ", "), "]"); + } + case AttrValue::kPlaceholder: + return strings::StrCat("$", attr_value.placeholder()); + case AttrValue::VALUE_NOT_SET: + return "<Unknown AttrValue type>"; + } + return "<Unknown AttrValue type>"; // Prevent missing return warning +} + +Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) { + int num_set = 0; + +#define VALIDATE_FIELD(name, type_string, oneof_case) \ + do { \ + if (attr_value.has_list()) { \ + if (attr_value.list().name##_size() > 0) { \ + if (type != "list(" type_string ")") { \ + return errors::InvalidArgument( \ + "AttrValue had value with type list(" type_string ") when ", \ + type, " expected"); \ + } \ + ++num_set; \ + } \ + } else if (attr_value.value_case() == AttrValue::oneof_case) { \ + if (type != type_string) { \ + return errors::InvalidArgument( \ + "AttrValue had value with type " type_string " when ", type, \ + " expected"); \ + } \ + ++num_set; \ + } \ + } while (false) + + VALIDATE_FIELD(s, "string", kS); + VALIDATE_FIELD(i, "int", kI); + VALIDATE_FIELD(f, "float", kF); + VALIDATE_FIELD(b, "bool", kB); + VALIDATE_FIELD(type, "type", kType); + VALIDATE_FIELD(shape, "shape", kShape); + VALIDATE_FIELD(tensor, "tensor", kTensor); + +#undef VALIDATE_FIELD + + if (attr_value.value_case() == AttrValue::kFunc) { + if (type != "func") { + return errors::InvalidArgument( + 
"AttrValue had value with type 'func' when ", type, " expected"); + } + ++num_set; + } + + if (attr_value.value_case() == AttrValue::kPlaceholder) { + return errors::InvalidArgument( + "AttrValue had value with unexpected type 'placeholder"); + } + + // If the attr type is 'list', we expect attr_value.has_list() to be true. + // However, proto3's attr_value.has_list() can be false when set to an empty + // list. So we simply check if has_list is false and some other field in + // attr_value is set to flag the error. + if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) { + if (num_set) { + return errors::InvalidArgument( + "AttrValue missing value with expected type ", type); + } else { + // Indicate that we have a list, but an empty one. + ++num_set; + } + } + + // Okay to have an empty list, but not to be missing a non-list value. + if (num_set == 0 && !StringPiece(type).starts_with("list(")) { + return errors::InvalidArgument( + "AttrValue missing value with expected type ", type); + } + + // Ref types and DT_INVALID are illegal. + if (type == "type") { + if (IsRefType(attr_value.type())) { + return errors::InvalidArgument( + "AttrValue must not have reference type value of ", + DataTypeString(attr_value.type())); + } + if (attr_value.type() == DT_INVALID) { + return errors::InvalidArgument("AttrValue has invalid DataType"); + } + } else if (type == "list(type)") { + for (auto as_int : attr_value.list().type()) { + const DataType dtype = static_cast<DataType>(as_int); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "AttrValue must not have reference type value of ", + DataTypeString(dtype)); + } + if (dtype == DT_INVALID) { + return errors::InvalidArgument("AttrValue contains invalid DataType"); + } + } + } + + return Status::OK(); +} + +bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { + // Parse type. + string field_name; + bool is_list = type.Consume("list("); + if (type.Consume("string")) { + field_name = "s"; + } else if (type.Consume("int")) { + field_name = "i"; + } else if (type.Consume("float")) { + field_name = "f"; + } else if (type.Consume("bool")) { + field_name = "b"; + } else if (type.Consume("type")) { + field_name = "type"; + } else if (type.Consume("shape")) { + field_name = "shape"; + } else if (type.Consume("tensor")) { + field_name = "tensor"; + } else if (type.Consume("func")) { + field_name = "func"; + } else if (type.Consume("placeholder")) { + field_name = "placeholder"; + } else { + return false; + } + if (is_list && !type.Consume(")")) { + return false; + } + + // Construct a valid text proto message to parse. + string to_parse; + if (is_list) { + // TextFormat parser considers "i: 7" to be the same as "i: [7]", + // but we only want to allow list values with []. + if (!RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[.*\\]\\s*")) { + return false; + } + if (RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[\\s*\\]\\s*")) { + // User wrote "[]", so return empty list without invoking the TextFormat + // parse which returns an error for "i: []". + out->Clear(); + out->mutable_list(); + return true; + } + to_parse = strings::StrCat("list { ", field_name, ": ", text, " }"); + } else { + to_parse = strings::StrCat(field_name, ": ", text); + } + + // Parse if we can. 
+ return protobuf::TextFormat::ParseFromString(to_parse, out); +} + +#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ + void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); } + +#define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD) \ + void SetAttrValue(ARG_TYPE value, AttrValue* out) { \ + out->mutable_list(); /* create list() even if value empty */ \ + for (const auto& v : value) { \ + out->mutable_list()->add_##FIELD(v); \ + } \ + } + +#define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \ + DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ + DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD) + +DEFINE_SET_ATTR_VALUE_ONE(const string&, s) +DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s) +DEFINE_SET_ATTR_VALUE_BOTH(const char*, s) +DEFINE_SET_ATTR_VALUE_BOTH(int64, i) +DEFINE_SET_ATTR_VALUE_BOTH(int32, i) +DEFINE_SET_ATTR_VALUE_BOTH(float, f) +DEFINE_SET_ATTR_VALUE_BOTH(double, f) +DEFINE_SET_ATTR_VALUE_BOTH(bool, b) +DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b) +DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b) +DEFINE_SET_ATTR_VALUE_BOTH(DataType, type) + +void SetAttrValue(StringPiece value, AttrValue* out) { + out->set_s(value.data(), value.size()); +} + +void SetAttrValue(const TensorShape& value, AttrValue* out) { + value.AsProto(out->mutable_shape()); +} + +void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) { + out->mutable_list(); // Create list() even if value empty. + for (const auto& v : value) { + v.AsProto(out->mutable_list()->add_shape()); + } +} + +void SetAttrValue(const Tensor& value, AttrValue* out) { + if (value.NumElements() > 1) { + value.AsProtoTensorContent(out->mutable_tensor()); + } else { + value.AsProtoField(out->mutable_tensor()); + } +} + +void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) { + out->mutable_list(); // Create list() even if value empty. + for (const auto& v : value) { + if (v.NumElements() > 1) { + v.AsProtoTensorContent(out->mutable_list()->add_tensor()); + } else { + v.AsProtoField(out->mutable_list()->add_tensor()); + } + } +} + +void SetAttrValue(const TensorProto& value, AttrValue* out) { + *out->mutable_tensor() = value; +} + +void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) { + out->mutable_list(); // Create list() even if value empty. + for (const auto& v : value) { + *out->mutable_list()->add_tensor() = v; + } +} + +void SetAttrValue(const NameAttrList& value, AttrValue* out) { + *out->mutable_func() = value; +} + +bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) { + string a_str, b_str; + a.SerializeToString(&a_str); + b.SerializeToString(&b_str); + // Note: it should be safe to compare proto serializations of the attr + // values since at most one field should be set in each (indeed, it + // must be the same field if they are to compare equal). + // Exception: there are multiple equivalent representations of + // TensorProtos. So a return value of true implies a == b, but not the + // converse. 
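+  // For example, the SetAttrValue(const Tensor&, ...) overload in this file
+  // encodes a tensor either via AsProtoTensorContent or AsProtoField
+  // depending on its element count; the two serializations differ, so such
+  // values compare unequal here even when they denote the same tensor.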
+  return a_str == b_str;
+}
+
+bool HasPlaceHolder(const AttrValue& val) {
+  switch (val.value_case()) {
+    case AttrValue::kFunc:
+      for (const auto& p : val.func().attr()) {
+        if (HasPlaceHolder(p.second)) {
+          return true;
+        }
+      }
+      break;
+    case AttrValue::kPlaceholder:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
+bool SubstitutePlaceholders(SubstituteFunc substitute, AttrValue* value) {
+  switch (value->value_case()) {
+    case AttrValue::kFunc:
+      for (auto& p : *(value->mutable_func()->mutable_attr())) {
+        if (!SubstitutePlaceholders(substitute, &p.second)) {
+          return false;
+        }
+      }
+      break;
+    case AttrValue::kPlaceholder:
+      return substitute(value->placeholder(), value);
+    case AttrValue::VALUE_NOT_SET:
+      return false;
+    default:
+      break;
+  }
+  return true;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h
new file mode 100644
index 0000000000..1faf74a327
--- /dev/null
+++ b/tensorflow/core/framework/attr_value_util.h
@@ -0,0 +1,83 @@
+#ifndef TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+
+#include <string>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+// A human-readable rendering of attr_value that is more concise than a
+// text-format proto.
+string SummarizeAttrValue(const AttrValue& attr_value);
+
+// Generates an error if attr_value doesn't have the indicated attr type.
+Status AttrValueHasType(const AttrValue& attr_value, StringPiece type);
+
+// Converts a text proto value from "text" into the field of *out
+// indicated by "type" (e.g. from the type field of an AttrDef).
+// Examples:
+// * If type:"int" and text:"-14", then *out is set to "i: -14"
+// * If type:"list(string)" and text:"['foo', 'bar']",
+//   then *out is set to "list { s: ['foo', 'bar'] }"
+// Returns true on success.
+bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out);
+
+// Sets *out based on the type of value.
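+//
+// For example (an illustrative sketch):
+//
+//   AttrValue v;
+//   SetAttrValue(static_cast<int64>(7), &v);       // v holds "i: 7"
+//   std::vector<float> fs = {1.0f, 2.0f};
+//   SetAttrValue(gtl::ArraySlice<float>(fs), &v);  // "list { f: [1, 2] }"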
+void SetAttrValue(const string& value, AttrValue* out); +void SetAttrValue(const char* value, AttrValue* out); +void SetAttrValue(StringPiece value, AttrValue* out); +void SetAttrValue(int64 value, AttrValue* out); +void SetAttrValue(int32 value, AttrValue* out); +void SetAttrValue(float value, AttrValue* out); +void SetAttrValue(double value, AttrValue* out); +void SetAttrValue(bool value, AttrValue* out); +void SetAttrValue(DataType value, AttrValue* out); +void SetAttrValue(const TensorShape& value, AttrValue* out); +void SetAttrValue(const Tensor& value, AttrValue* out); +void SetAttrValue(const TensorProto& value, AttrValue* out); +void SetAttrValue(const NameAttrList& value, AttrValue* out); + +void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice<int64> value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice<int32> value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice<float> value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice<double> value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice<bool> value, AttrValue* out); +void SetAttrValue(const std::vector<bool>& value, AttrValue* out); +void SetAttrValue(std::initializer_list<bool> value, AttrValue* out); +void SetAttrValue(DataTypeSlice value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice<TensorShape> value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out); +void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out); + +inline void SetAttrValue(const AttrValue& value, AttrValue* out) { + *out = value; +} + +// Returns true if a and b have the same value. +// NOTE: May return false negatives for tensor values. +bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b); + +// Returns true if "val" has a placeholder. +bool HasPlaceHolder(const AttrValue& val); + +// SubstitutePlaceholders recursively replaces placeholders in 'value' +// with an attr value by calling SubstituteFunc. Returns true iff all +// placeholders in "value" are replaced with a value. +// +// SubstituteFunc is given a placeholder string. If the placeholder is +// unknown, SubstituteFunc returns false. Otherwise, overwrites the +// attr value and returns true. +typedef std::function<bool(const string&, AttrValue*)> SubstituteFunc; +bool SubstitutePlaceholders(SubstituteFunc substitute, AttrValue* value); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_ diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc new file mode 100644 index 0000000000..bdfbf1707a --- /dev/null +++ b/tensorflow/core/framework/attr_value_util_test.cc @@ -0,0 +1,91 @@ +#include "tensorflow/core/framework/attr_value_util.h" + +#include <gtest/gtest.h> + +namespace tensorflow { + +// A few helpers to construct AttrValue protos. 
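+// V(x) wraps a plain C++ value via SetAttrValue, P(name) builds a
+// placeholder value, and F(name, {...}) builds a "func" value with the
+// given attrs.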
+template <typename T> +AttrValue V(T value) { + AttrValue ret; + SetAttrValue(value, &ret); + return ret; +} + +AttrValue P(const string& p) { + AttrValue ret; + ret.set_placeholder(p); + return ret; +} + +AttrValue F(const string& name, + std::vector<std::pair<string, AttrValue> > pairs) { + AttrValue ret; + ret.mutable_func()->set_name(name); + ret.mutable_func()->mutable_attr()->insert(pairs.begin(), pairs.end()); + return ret; +} + +TEST(AttrValueUtil, HasType) { + // OK + EXPECT_TRUE(AttrValueHasType(V(123), "int").ok()); + EXPECT_TRUE(AttrValueHasType(V(1.2), "float").ok()); + EXPECT_TRUE(AttrValueHasType(V(DT_FLOAT), "type").ok()); + EXPECT_TRUE(AttrValueHasType(F("f", {}), "func").ok()); + + // not OK. + EXPECT_FALSE(AttrValueHasType(V(123), "func").ok()); + EXPECT_FALSE(AttrValueHasType(V(1.2), "int").ok()); + EXPECT_FALSE(AttrValueHasType(V(DT_FLOAT), "shape").ok()); + EXPECT_FALSE(AttrValueHasType(F("f", {}), "string").ok()); + EXPECT_FALSE(AttrValueHasType(P("T"), "float").ok()); +} + +SubstituteFunc ReplaceTWith(const AttrValue& val) { + return [val](const string& placeholder, AttrValue* target) { + if (placeholder == "T") { + *target = val; + return true; + } else { + return false; + } + }; +} + +TEST(AttrValueUtil, Basic) { + auto v = F("MatMul", {{"dtype", P("T")}, + {"transpose_a", V(false)}, + {"transpose_b", V(true)}, + {"use_cublas", V(true)}}); + TF_CHECK_OK(AttrValueHasType(v, "func")); + EXPECT_TRUE(HasPlaceHolder(v)); + + EXPECT_EQ( + SummarizeAttrValue(v), + "MatMul[dtype=$T, transpose_a=false, transpose_b=true, use_cublas=true]"); + + SubstitutePlaceholders(ReplaceTWith(V(DT_FLOAT)), &v); + EXPECT_TRUE(!HasPlaceHolder(v)); + EXPECT_EQ(SummarizeAttrValue(v), + "MatMul[dtype=DT_FLOAT, transpose_a=false, transpose_b=true, " + "use_cublas=true]"); +} + +TEST(AttrValueUtil, DeepAttr) { + auto v = F("f", {{"T", P("T")}}); + TF_CHECK_OK(AttrValueHasType(v, "func")); + EXPECT_TRUE(HasPlaceHolder(v)); + + for (int i = 0; i < 3; ++i) { + v = F("f", {{"T", P("T")}, {"F", v}}); + EXPECT_TRUE(HasPlaceHolder(v)); + } + EXPECT_EQ(SummarizeAttrValue(v), "f[F=f[F=f[F=f[T=$T], T=$T], T=$T], T=$T]"); + + SubstitutePlaceholders(ReplaceTWith(F("x", {})), &v); + EXPECT_TRUE(!HasPlaceHolder(v)); + EXPECT_EQ(SummarizeAttrValue(v), + "f[F=f[F=f[F=f[T=x[]], T=x[]], T=x[]], T=x[]]"); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/bfloat16.cc b/tensorflow/core/framework/bfloat16.cc new file mode 100644 index 0000000000..0068283367 --- /dev/null +++ b/tensorflow/core/framework/bfloat16.cc @@ -0,0 +1,22 @@ +#include "tensorflow/core/framework/bfloat16.h" + +namespace tensorflow { + +void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) { + const uint16_t* p = reinterpret_cast<const uint16_t*>(src); + uint16_t* q = reinterpret_cast<uint16_t*>(dst); + for (; size; p += 2, q++, size--) { + *q = p[1]; + } +} + +void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) { + const uint16_t* p = reinterpret_cast<const uint16_t*>(src); + uint16_t* q = reinterpret_cast<uint16_t*>(dst); + for (; size; p++, q += 2, size--) { + q[0] = 0; + q[1] = *p; + } +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/bfloat16.h b/tensorflow/core/framework/bfloat16.h new file mode 100644 index 0000000000..9cd260ee13 --- /dev/null +++ b/tensorflow/core/framework/bfloat16.h @@ -0,0 +1,58 @@ +#ifndef TENSORFLOW_FRAMEWORK_BFLOAT16_H_ +#define TENSORFLOW_FRAMEWORK_BFLOAT16_H_ + +#include "tensorflow/core/platform/port.h" +#include 
"third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +// Compact 16-bit encoding of floating point numbers. This representation uses +// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. It +// is assumed that floats are in IEEE 754 format so the representation is just +// bits 16-31 of a single precision float. +// +// NOTE: The IEEE floating point standard defines a float16 format that +// is different than this format (it has fewer bits of exponent and more +// bits of mantissa). We don't use that format here because conversion +// to/from 32-bit floats is more complex for that format, and the +// conversion for this format is very simple. +// +// Because of the existing IEEE float16 type, we do not name our representation +// "float16" but just use "uint16". +// +// <-----our 16bits float-------> +// s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f +// <------------------------------float--------------------------> +// 3 3 2 2 1 1 0 +// 1 0 3 2 5 4 0 +// +// +// This type only supports conversion back and forth with float. +// +// This file must be compilable by nvcc. + +namespace tensorflow { +struct bfloat16 { + EIGEN_DEVICE_FUNC bfloat16() {} + EIGEN_DEVICE_FUNC explicit bfloat16(const uint16_t v) : value(v) {} + + uint16_t value; +}; + +// Conversion routines between an array of float and bfloat16 of +// "size". +void FloatToBFloat16(const float* src, bfloat16* dst, int64 size); +void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size); + +} // namespace tensorflow + +namespace Eigen { +template <> +struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {}; + +EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a, + const tensorflow::bfloat16 b) { + return a.value == b.value; +} + +} // namespace Eigen + +#endif // TENSORFLOW_FRAMEWORK_BFLOAT16_H_ diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc new file mode 100644 index 0000000000..4fe791fdeb --- /dev/null +++ b/tensorflow/core/framework/bfloat16_test.cc @@ -0,0 +1,69 @@ +#include "tensorflow/core/framework/bfloat16.h" + +#include "tensorflow/core/platform/test_benchmark.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +TEST(Bfloat16Test, Simple) { + bfloat16 a(12); + EXPECT_EQ(12, a.value); +} + +TEST(Bfloat16Test, Conversion) { + float a[100]; + for (int i = 0; i < 100; ++i) { + a[i] = i + 1.25; + } + bfloat16 b[100]; + float c[100]; + FloatToBFloat16(a, b, 100); + BFloat16ToFloat(b, c, 100); + for (int i = 0; i < 100; ++i) { + // The relative error should be less than 1/(2^7) since bfloat16 + // has 7 bits mantissa. 
+ EXPECT_LE(fabs(c[i] - a[i]) / a[i], 1.0 / 128); + } +} + +static void BM_FloatToBFloat16(int iters) { + testing::StopTiming(); + static const int N = 32 << 20; + const int64 tot = static_cast<int64>(iters) * N; + testing::ItemsProcessed(tot); + testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); + + float* inp = new float[N]; + bfloat16* out = new bfloat16[N]; + + testing::StartTiming(); + while (iters--) { + FloatToBFloat16(inp, out, N); + } + delete[] inp; + delete[] out; +} +BENCHMARK(BM_FloatToBFloat16); + +static void BM_BFloat16ToFloat(int iters) { + testing::StopTiming(); + static const int N = 32 << 20; + const int64 tot = static_cast<int64>(iters) * N; + testing::ItemsProcessed(tot); + testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16))); + + bfloat16* inp = new bfloat16[N]; + float* out = new float[N]; + + testing::StartTiming(); + while (iters--) { + BFloat16ToFloat(inp, out, N); + } + delete[] inp; + delete[] out; +} +BENCHMARK(BM_BFloat16ToFloat); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc new file mode 100644 index 0000000000..51423792a8 --- /dev/null +++ b/tensorflow/core/framework/cancellation.cc @@ -0,0 +1,79 @@ +#include "tensorflow/core/framework/cancellation.h" + +#include <vector> + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +const CancellationToken CancellationManager::kInvalidToken = -1; + +CancellationManager::CancellationManager() + : is_cancelling_(false), is_cancelled_(0), next_cancellation_token_(0) {} + +void CancellationManager::StartCancel() { + std::unordered_map<CancellationToken, CancelCallback> callbacks_to_run; + { + mutex_lock l(mu_); + if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) { + return; + } + is_cancelling_ = true; + std::swap(callbacks_, callbacks_to_run); + } + // We call these callbacks without holding mu_, so that concurrent + // calls to DeregisterCallback, which can happen asynchronously, do + // not block. The callbacks remain valid because any concurrent call + // to DeregisterCallback will block until the + // cancelled_notification_ is notified. + for (auto key_and_value : callbacks_to_run) { + key_and_value.second(); + } + { + mutex_lock l(mu_); + is_cancelling_ = false; + is_cancelled_.store(true, std::memory_order_release); + } + cancelled_notification_.Notify(); +} + +CancellationToken CancellationManager::get_cancellation_token() { + mutex_lock l(mu_); + return next_cancellation_token_++; +} + +bool CancellationManager::RegisterCallback(CancellationToken token, + CancelCallback callback) { + mutex_lock l(mu_); + CHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token"; + bool should_register = !is_cancelled_ && !is_cancelling_; + if (should_register) { + std::swap(callbacks_[token], callback); + } + return should_register; +} + +bool CancellationManager::DeregisterCallback(CancellationToken token) { + mu_.lock(); + if (is_cancelled_) { + mu_.unlock(); + return false; + } else if (is_cancelling_) { + mu_.unlock(); + // Wait for all of the cancellation callbacks to be called. This + // wait ensures that the caller of DeregisterCallback does not + // return immediately and free objects that may be used in the + // execution of any currently pending callbacks in StartCancel. 
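+    // StartCancel() only notifies cancelled_notification_ after every
+    // callback has returned, so this wait also guarantees that the
+    // callback registered under 'token' has finished running.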
+ cancelled_notification_.WaitForNotification(); + return false; + } else { + callbacks_.erase(token); + mu_.unlock(); + return true; + } +} + +CancellationManager::~CancellationManager() { StartCancel(); } + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h new file mode 100644 index 0000000000..feda548e97 --- /dev/null +++ b/tensorflow/core/framework/cancellation.h @@ -0,0 +1,121 @@ +#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_ +#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_ + +#include <atomic> +#include <functional> +#include <unordered_map> + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// A token that can be used to register and deregister a +// CancelCallback with a CancellationManager. +// +// CancellationToken values must be created by a call to +// CancellationManager::get_cancellation_token. +typedef int64 CancellationToken; + +// A callback that is invoked when a step is cancelled. +// +// NOTE(mrry): See caveats about CancelCallback implementations in the +// comment for CancellationManager::RegisterCallback. +typedef std::function<void()> CancelCallback; + +class CancellationManager { + public: + // A value that won't be returned by get_cancellation_token(). + static const CancellationToken kInvalidToken; + + CancellationManager(); + ~CancellationManager(); + + // Run all callbacks associated with this manager. + void StartCancel(); + + // Returns true iff StartCancel() has been called. + bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); } + + // Returns a token that must be used in calls to RegisterCallback + // and DeregisterCallback. + CancellationToken get_cancellation_token(); + + // Attempts to register the given callback to be invoked when this + // manager is cancelled. Returns true if the callback was + // registered; returns false if this manager was already cancelled, + // and the callback was not registered. + // + // If this method returns false, it is the caller's responsibility + // to perform any cancellation cleanup. + // + // This method is tricky to use correctly. The following usage pattern + // is recommended: + // + // class ObjectWithCancellableOperation { + // mutex mu_; + // void CancellableOperation(CancellationManager* cm, + // std::function<void(Status)> callback) { + // bool already_cancelled; + // CancellationToken token = cm->get_cancellation_token(); + // { + // mutex_lock(mu_); + // already_cancelled = cm->RegisterCallback( + // [this, token]() { Cancel(token); }); + // if (!already_cancelled) { + // // Issue asynchronous operation. Associate the pending operation + // // with `token` in some object state, or provide another way for + // // the Cancel method to look up the operation for cancellation. + // // Ensure that `cm->DeregisterCallback(token)` is called without + // // holding `mu_`, before `callback` is invoked. + // // ... + // } + // } + // if (already_cancelled) { + // callback(errors::Cancelled("Operation was cancelled")); + // } + // } + // + // void Cancel(CancellationToken token) { + // mutex_lock(mu_); + // // Take action to cancel the operation with the given cancellation + // // token. + // } + // + // NOTE(mrry): The caller should take care that (i) the calling code + // is robust to `callback` being invoked asynchronously (e.g. 
from + // another thread), (ii) `callback` is deregistered by a call to + // this->DeregisterCallback(token) when the operation completes + // successfully, and (iii) `callback` does not invoke any method + // on this cancellation manager. Furthermore, it is important that + // the eventual caller of the complementary DeregisterCallback does not + // hold any mutexes that are required by `callback`. + bool RegisterCallback(CancellationToken token, CancelCallback callback); + + // Deregister the callback that, when registered, was associated + // with the given cancellation token. Returns true iff the callback + // was deregistered and will not be invoked; otherwise returns false + // after the callback has been invoked, blocking if necessary. + // + // NOTE(mrry): This method may block if cancellation is in progress. + // The caller of this method must not hold any mutexes that are required + // to invoke any cancellation callback that has been registered with this + // cancellation manager. + bool DeregisterCallback(CancellationToken token); + + private: + bool is_cancelling_; + std::atomic_bool is_cancelled_; + + mutex mu_; + Notification cancelled_notification_; + CancellationToken next_cancellation_token_ GUARDED_BY(mu_); + std::unordered_map<CancellationToken, CancelCallback> callbacks_ + GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_CANCELLATION_H_ diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc new file mode 100644 index 0000000000..1925dd20cc --- /dev/null +++ b/tensorflow/core/framework/cancellation_test.cc @@ -0,0 +1,102 @@ +#include "tensorflow/core/framework/cancellation.h" + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +TEST(Cancellation, SimpleNoCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + bool deregistered = manager->DeregisterCallback(token); + EXPECT_TRUE(deregistered); + delete manager; + EXPECT_FALSE(is_cancelled); +} + +TEST(Cancellation, SimpleCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + manager->StartCancel(); + EXPECT_TRUE(is_cancelled); + delete manager; +} + +TEST(Cancellation, CancelBeforeRegister) { + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + manager->StartCancel(); + bool registered = manager->RegisterCallback(token, nullptr); + EXPECT_FALSE(registered); + delete manager; +} + +TEST(Cancellation, DeregisterAfterCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + manager->StartCancel(); + EXPECT_TRUE(is_cancelled); + bool deregistered = manager->DeregisterCallback(token); + EXPECT_FALSE(deregistered); + delete manager; +} + +TEST(Cancellation, CancelMultiple) { + bool is_cancelled_1 = false, is_cancelled_2 = 
false, is_cancelled_3 = false; + CancellationManager* manager = new CancellationManager(); + auto token_1 = manager->get_cancellation_token(); + bool registered_1 = manager->RegisterCallback( + token_1, [&is_cancelled_1]() { is_cancelled_1 = true; }); + EXPECT_TRUE(registered_1); + auto token_2 = manager->get_cancellation_token(); + bool registered_2 = manager->RegisterCallback( + token_2, [&is_cancelled_2]() { is_cancelled_2 = true; }); + EXPECT_TRUE(registered_2); + EXPECT_FALSE(is_cancelled_1); + EXPECT_FALSE(is_cancelled_2); + manager->StartCancel(); + EXPECT_TRUE(is_cancelled_1); + EXPECT_TRUE(is_cancelled_2); + EXPECT_FALSE(is_cancelled_3); + auto token_3 = manager->get_cancellation_token(); + bool registered_3 = manager->RegisterCallback( + token_3, [&is_cancelled_3]() { is_cancelled_3 = true; }); + EXPECT_FALSE(registered_3); + EXPECT_FALSE(is_cancelled_3); + delete manager; +} + +TEST(Cancellation, IsCancelled) { + CancellationManager* cm = new CancellationManager(); + thread::ThreadPool w(Env::Default(), "test", 4); + std::vector<Notification> done(8); + for (size_t i = 0; i < done.size(); ++i) { + Notification* n = &done[i]; + w.Schedule([n, cm]() { + while (!cm->IsCancelled()) { + } + n->Notify(); + }); + } + Env::Default()->SleepForMicroseconds(1000000 /* 1 second */); + cm->StartCancel(); + for (size_t i = 0; i < done.size(); ++i) { + done[i].WaitForNotification(); + } + delete cm; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/config.proto b/tensorflow/core/framework/config.proto new file mode 100644 index 0000000000..f0def3d6d7 --- /dev/null +++ b/tensorflow/core/framework/config.proto @@ -0,0 +1,61 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +message GPUOptions { + // A value between 0 and 1 that indicates what fraction of the + // available GPU memory to pre-allocate for each process. 1 means + // to pre-allocate all of the GPU memory, 0.5 means the process + // allocates ~50% of the available GPU memory. + double per_process_gpu_memory_fraction = 1; +}; + +// Session configuration parameters. +// The system picks an appropriate values for fields that are not set. +message ConfigProto { + // Map from device type name (e.g., "CPU" or "GPU" ) to maximum + // number of devices of that type to use. If a particular device + // type is not found in the map, the system picks an appropriate + // number. + map<string, int32> device_count = 1; + + // The execution of an individual op (for some op types) can be + // parallelized on a pool of intra_op_parallelism_threads. + // 0 means the system picks an appropriate number. + int32 intra_op_parallelism_threads = 2; + + // Nodes that perform blocking operations are enqueued on a pool of + // inter_op_parallelism_threads available in each process. + // + // 0 means the system picks an appropriate number. + // + // Note that the first Session created in the process sets the + // number of threads for all future sessions. + int32 inter_op_parallelism_threads = 5; + + // Assignment of Nodes to Devices is recomputed every placement_period + // steps until the system warms up (at which point the recomputation + // typically slows down automatically). + int32 placement_period = 3; + + // When any filters are present sessions will ignore all devices which do not + // match the filters. Each filter can be partially specified, e.g. "/job:ps" + // "/job:worker/replica:3", etc. + repeated string device_filters = 4; + + // Options that apply to all GPUs. 
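+  // For example, in text form (illustrative):
+  //   gpu_options { per_process_gpu_memory_fraction: 0.5 }
+  // pre-allocates roughly half of the available GPU memory per process.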
+ GPUOptions gpu_options = 6; + + // Whether soft placement is allowed. If allow_soft_placement is true, + // an op will be placed on CPU if + // 1. there's no GPU implementation for the OP + // or + // 2. no GPU devices are known or registered + // or + // 3. need to co-locate with reftype input(s) which are from CPU. + bool allow_soft_placement = 7; + + // Whether device placements should be logged. + bool log_device_placement = 8; +}; diff --git a/tensorflow/core/framework/control_flow.h b/tensorflow/core/framework/control_flow.h new file mode 100644 index 0000000000..f59e0f5310 --- /dev/null +++ b/tensorflow/core/framework/control_flow.h @@ -0,0 +1,43 @@ +#ifndef TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_ +#define TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_ + +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { + +const uint64 kIllegalFrameId = ~0uLL; +const int64 kIllegalIterId = -1; + +// For the purpose of control flow, every tensor produced by TensorFlow is +// conceptually tagged by a 'FrameAndIter'. FrameAndIter consists of a +// 'frame_id' and an 'iter_id'. The tensor value it represents is produced +// in the frame with frame_id at the iteration of iter_id. +struct FrameAndIter { + uint64 frame_id = kIllegalFrameId; + int64 iter_id = kIllegalIterId; + + FrameAndIter() {} + + FrameAndIter(uint64 frame, int64 iter) { + frame_id = frame; + iter_id = iter; + } + + bool operator==(const FrameAndIter& other) const { + return (frame_id == other.frame_id && iter_id == other.iter_id); + } +}; + +struct FrameAndIterHash { + size_t operator()(const FrameAndIter& key) const { + // Make sure there are no padding bytes that we don't want + CHECK_EQ(sizeof(uint64) + sizeof(int64), sizeof(FrameAndIter)); + return Hash64(reinterpret_cast<const char*>(&key), sizeof(FrameAndIter)); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_ diff --git a/tensorflow/core/framework/device_attributes.proto b/tensorflow/core/framework/device_attributes.proto new file mode 100644 index 0000000000..7592215d1e --- /dev/null +++ b/tensorflow/core/framework/device_attributes.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +// BusAdjacency identifies the ability of a device to participate in +// maximally efficient DMA operations within the local context of a +// process. +// +// This is currently ignored. +enum BusAdjacency { + BUS_0 = 0; + BUS_1 = 1; + BUS_ANY = 2; + BUS_NUM_ADJACENCIES = 3; +}; + +message DeviceAttributes { + string name = 1; + + // String representation of device_type. + string device_type = 2; + + // Memory capacity of device in bytes. + int64 memory_limit = 4; + + BusAdjacency bus_adjacency = 5; + + // A device is assigned a global unique number each time it is + // initialized. "incarnation" should never be 0. + fixed64 incarnation = 6; + + // String representation of the physical device that this device maps to. 
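+  // For a GPU this would typically include the device ordinal, hardware
+  // name and PCI bus id (the exact format is device-specific).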
+ string physical_device_desc = 7; +} diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc new file mode 100644 index 0000000000..83ad199062 --- /dev/null +++ b/tensorflow/core/framework/device_base.cc @@ -0,0 +1,7 @@ +#include "tensorflow/core/framework/device_base.h" + +namespace tensorflow { + +DeviceBase::~DeviceBase() {} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h new file mode 100644 index 0000000000..ed4ffc5d94 --- /dev/null +++ b/tensorflow/core/framework/device_base.h @@ -0,0 +1,172 @@ +#ifndef TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_ +#define TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_ + +#include <memory> +#include <unordered_map> + +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/public/status.h" + +namespace Eigen { +class ThreadPoolDevice; +} // end namespace Eigen + +namespace perftools { +namespace gputools { +class Stream; +} // namespace gputools +} // namespace perftools + +namespace tensorflow { + +class Device; +class Env; +class EventMgr; + +namespace thread { +class ThreadPool; +} + +// A wrapper for an Eigen Gpu Device that includes per-op state +class PerOpGpuDevice { + public: + virtual ~PerOpGpuDevice() {} + virtual const Eigen::GpuDevice& device() const = 0; +}; + +// A class that devices can subclass to pass around +// Device-specific context to OpKernels. +class DeviceContext : public core::RefCounted { + public: + ~DeviceContext() override {} + virtual perftools::gputools::Stream* stream() const { return nullptr; } + virtual void MaintainLifetimeOnStream( + const Tensor* t, perftools::gputools::Stream* stream) const {} + + // "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into + // "device_tensor" which is on a GPU device "device". "device_tensor" + // must be allocated to be of the same size as "cpu_tensor". + virtual void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, + StatusCallback done) const { + done(errors::Internal("Unrecognized device type in CPU-to-device Copy")); + } + + // "device_tensor" is a tensor on a non-CPU device. Copies + // device_tensor into "cpu_tensor". "cpu_tensor" must be allocated + // to be of the same size as "device_tensor". + virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor, + const string& tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) { + done(errors::Internal("Unrecognized device type in device-to-CPU Copy")); + } +}; + +typedef std::unordered_map<int, DeviceContext*> DeviceContextMap; + +class DeviceBase { + public: + explicit DeviceBase(Env* env) : env_(env) {} + virtual ~DeviceBase(); + + Env* env() const { return env_; } + + // Override this to return true for devices that require an Op's + // compute method to save references to the temporary tensors it + // allocates until the Op execution completes + virtual bool SaveTemporaryTensors() const { return false; } + + struct CpuWorkerThreads { + int num_threads = 0; + thread::ThreadPool* workers = nullptr; + }; + + // Does not take ownership. 
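+  // For example (an illustrative sketch; "device" stands in for a concrete
+  // DeviceBase subclass):
+  //
+  //   DeviceBase::CpuWorkerThreads* threads =
+  //       new DeviceBase::CpuWorkerThreads;
+  //   threads->num_threads = 4;
+  //   threads->workers = new thread::ThreadPool(Env::Default(), "worker", 4);
+  //   device->set_tensorflow_cpu_worker_threads(threads);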
+ void set_tensorflow_cpu_worker_threads(CpuWorkerThreads* t) { + cpu_worker_threads_ = t; + } + + const CpuWorkerThreads* tensorflow_cpu_worker_threads() const { + CHECK(cpu_worker_threads_ != nullptr); + return cpu_worker_threads_; + } + + // "stream" is used in special circumstances (such as the + // constructors of Ops) where there is no available OpKernelContext. + // "default_context" is used by OpKernelContext whenever a device does not + // supply a DeviceContext for an op in FillContextMap (e.g. when only + // using a single stream.) + // "event_mgr" is used to delay deallocation of temporary GPU buffers. + // TODO(pbar) Work out how to move this out of DeviceBase. + struct GpuDeviceInfo { + perftools::gputools::Stream* stream; + DeviceContext* default_context; + EventMgr* event_mgr; + }; + + // Does not take ownership. + void set_tensorflow_gpu_device_info(GpuDeviceInfo* g) { + gpu_device_info_ = g; + } + + const GpuDeviceInfo* tensorflow_gpu_device_info() const { + return gpu_device_info_; + } + + // Does not take ownership. + void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) { + eigen_cpu_device_ = d; + } + + // Return the Allocator implementation to use based on the allocator + // attributes requested. See allocator.h for more details. + virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) { + LOG(FATAL) << "GetAllocator() is not implemented."; + } + + const Eigen::ThreadPoolDevice* eigen_cpu_device() { + CHECK(eigen_cpu_device_ != nullptr); + return eigen_cpu_device_; + } + + // The caller owns the returned device and must free it by calling + // DisposeGpuDevice below + virtual const PerOpGpuDevice* MakeGpuDevice(DeviceContext* /*dc*/, + Allocator* /*allocator*/) { + // The OpKernelContext calls this even for devices that do not + // implement an eigen_gpu_device + return nullptr; + } + + virtual const DeviceAttributes& attributes() const { + LOG(FATAL) << "Device does not implement attributes()"; + } + + // Materializes the given TensorProto into 'tensor' stored in Device + // memory. Most devices will want to override this. + // + // TODO(vrv): We should be able to put this function into + // OpKernelContext and handle the copies from device memory via send + // and receive nodes, instead of requiring that each device handle + // the copies here as well as in copy ops. 
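+  // A host-memory device might override it roughly as follows (an
+  // illustrative sketch, not part of this commit):
+  //
+  //   Status MakeTensorFromProto(const TensorProto& tensor_proto,
+  //                              const AllocatorAttributes alloc_attrs,
+  //                              Tensor* tensor) override {
+  //     Tensor parsed;
+  //     if (!parsed.FromProto(tensor_proto)) {
+  //       return errors::InvalidArgument("Cannot parse tensor from proto");
+  //     }
+  //     *tensor = parsed;  // OK on the host; device memory must be copied.
+  //     return Status::OK();
+  //   }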
+ virtual Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) { + return errors::Internal("Device does not implement MakeTensorFromProto()"); + } + + private: + Env* const env_; + CpuWorkerThreads* cpu_worker_threads_ = nullptr; + GpuDeviceInfo* gpu_device_info_ = nullptr; + Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_ diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc new file mode 100644 index 0000000000..493c35e05f --- /dev/null +++ b/tensorflow/core/framework/fake_input.cc @@ -0,0 +1,214 @@ +#include "tensorflow/core/framework/fake_input.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { +namespace { + +class FakeInputImpl { + public: + FakeInputImpl(const OpDef* op_def, int in_index, const NodeDef* node_def, + NodeDefBuilder* builder); + void SetN(int n); + void SetDataType(DataType dt); + void SetTypeList(DataTypeSlice dts); + Status AddInputToBuilder(); + + private: + static string FakeNodeName(int in_index); + Status GetN(int* n) const; + Status GetDataType(DataType* dt) const; + void NSources(int n, DataType dt) const; + void SourceList(DataTypeSlice dts) const; + + const OpDef* const op_def_; + const OpDef::ArgDef* const arg_; + const string in_node_; + const NodeDef* const node_def_; + NodeDefBuilder* const builder_; + + bool n_specified_; + int n_; + bool dt_specified_; + DataType dt_; + bool dts_specified_; + DataTypeSlice dts_; +}; + +FakeInputImpl::FakeInputImpl(const OpDef* op_def, int in_index, + const NodeDef* node_def, NodeDefBuilder* builder) + : op_def_(op_def), + arg_(&op_def->input_arg(in_index)), + in_node_(FakeNodeName(in_index)), + node_def_(node_def), + builder_(builder), + n_specified_(false), + dt_specified_(false), + dts_specified_(false) {} + +void FakeInputImpl::SetN(int n) { + n_specified_ = true; + n_ = n; +} + +void FakeInputImpl::SetDataType(DataType dt) { + dt_specified_ = true; + dt_ = dt; +} + +void FakeInputImpl::SetTypeList(DataTypeSlice dts) { + dts_specified_ = true; + dts_ = dts; +} + +Status FakeInputImpl::AddInputToBuilder() { + if (dts_specified_) { + SourceList(dts_); + + } else if (n_specified_ || !arg_->number_attr().empty()) { + int n; + TF_RETURN_IF_ERROR(GetN(&n)); + + DataType dt; + if (n > 0) { + TF_RETURN_IF_ERROR(GetDataType(&dt)); + } else { + dt = DT_FLOAT; + } + + NSources(n, dt); + } else { + if (!dt_specified_ && !arg_->type_list_attr().empty()) { + DataTypeVector dts; + Status status = + GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts); + if (!status.ok()) { + return errors::InvalidArgument( + "Could not infer list of types for input '", arg_->name(), "': ", + status.error_message()); + } + SourceList(dts); + return Status::OK(); + } + + DataType dt; + TF_RETURN_IF_ERROR(GetDataType(&dt)); + builder_->Input(in_node_, 0, dt); + } + return Status::OK(); +} + +// static +string FakeInputImpl::FakeNodeName(int in_index) { + char c = 'a' + (in_index % 26); + return string(&c, 1); +} + +Status FakeInputImpl::GetN(int* n) const { + if (n_specified_) { + *n = n_; + } else { + Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n); + if (!status.ok()) { + return errors::InvalidArgument("Could not infer 
length of input '", + arg_->name(), "': ", + status.error_message()); + } + } + return Status::OK(); +} + +Status FakeInputImpl::GetDataType(DataType* dt) const { + if (dt_specified_) { + *dt = dt_; + } else if (arg_->type() != DT_INVALID) { + *dt = arg_->type(); + } else if (!arg_->type_attr().empty()) { + Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt); + if (!status.ok()) { + return errors::InvalidArgument("Could not infer type for input '", + arg_->name(), "': ", + status.error_message()); + } + } else { + return errors::InvalidArgument("No type or type_attr field in arg '", + arg_->name(), "'"); + } + return Status::OK(); +} + +void FakeInputImpl::NSources(int n, DataType dt) const { + std::vector<NodeDefBuilder::NodeOut> srcs; + srcs.reserve(n); + for (int i = 0; i < n; ++i) { + srcs.emplace_back(in_node_, i, dt); + } + builder_->Input(srcs); +} + +void FakeInputImpl::SourceList(DataTypeSlice dts) const { + std::vector<NodeDefBuilder::NodeOut> srcs; + srcs.reserve(dts.size()); + for (size_t i = 0; i < dts.size(); ++i) { + srcs.emplace_back(in_node_, i, dts[i]); + } + builder_->Input(srcs); +} + +} // namespace + +// Public interface ------------------------------------------------------------ + +FakeInputFunctor FakeInput() { + return [](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(DataType dt) { + return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetDataType(dt); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(int n) { + return [n](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetN(n); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(int n, DataType dt) { + return [n, dt](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetN(n); + impl.SetDataType(dt); + return impl.AddInputToBuilder(); + }; +} + +FakeInputFunctor FakeInput(DataTypeSlice dts) { + // Make a copy to ensure the data will still be around when the lambda is + // called. + DataTypeVector dtv(dts.begin(), dts.end()); + return [dtv](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + FakeInputImpl impl(&op_def, in_index, &node_def, builder); + impl.SetTypeList(dtv); + return impl.AddInputToBuilder(); + }; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/fake_input.h b/tensorflow/core/framework/fake_input.h new file mode 100644 index 0000000000..39b38e9a59 --- /dev/null +++ b/tensorflow/core/framework/fake_input.h @@ -0,0 +1,25 @@ +#ifndef TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_ +#define TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +// These functions return values that may be passed to +// NodeDefBuilder::Input() to add an input for a test. Use them when +// you don't care about the node names/output indices providing the +// input. They also allow you to omit the input types and/or +// list length when they may be inferred. 
+FakeInputFunctor FakeInput(); // Infer everything +FakeInputFunctor FakeInput(DataType dt); +FakeInputFunctor FakeInput(int n); // List of length n +FakeInputFunctor FakeInput(int n, DataType dt); +FakeInputFunctor FakeInput(DataTypeSlice dts); +inline FakeInputFunctor FakeInput(std::initializer_list<DataType> dts) { + return FakeInput(DataTypeSlice(dts)); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_ diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc new file mode 100644 index 0000000000..b73e1ab8a9 --- /dev/null +++ b/tensorflow/core/framework/function.cc @@ -0,0 +1,878 @@ +#include "tensorflow/core/framework/function.h" + +#include <unordered_set> + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { + +REGISTER_OP("_Arg") + .Output("output: T") + .Attr("T: type") + .Attr("index: int >= 0") + .Doc(R"doc( +A graph node which represents an argument to a function. + +output: The argument. +index: This argument is the index-th argument of the function. +)doc"); + +REGISTER_OP("_Retval") + .Input("input: T") + .Attr("T: type") + .Attr("index: int >= 0") + .Doc(R"doc( +A graph node which represents a return value of a function. + +input: The return value. +index: This return value is the index-th return value of the function. +)doc"); + +REGISTER_OP("_ListToArray") + .Input("input: Tin") + .Output("output: N * T") + .Attr("Tin: list(type)") + .Attr("T: type") + .Attr("N: int >= 1") + .Doc(R"doc( +Converts a list of tensors to an array of tensors. +)doc"); + +REGISTER_OP("_ArrayToList") + .Input("input: N * T") + .Output("output: out_types") + .Attr("T: type") + .Attr("N: int >= 1") + .Attr("out_types: list(type)") + .Doc(R"doc( +Converts an array of tensors to a list of tensors. +)doc"); + +namespace { + +// Extracts the actual type from "attr_values" based on its definition +// "arg_def". 
+Status ArgNumType(const InstantiateAttrValueMap& attrs, + const OpDef::ArgDef& arg_def, int* num, DataType* dtype) { + if (!arg_def.type_list_attr().empty()) { + return errors::Unimplemented("type_list is not supported."); + } + + if (arg_def.number_attr().empty()) { + *num = 1; + } else { + const AttrValue* v = gtl::FindOrNull(attrs, arg_def.number_attr()); + if (v == nullptr) { + return errors::NotFound("number attr not found: ", arg_def.number_attr()); + } + *num = v->i(); + } + + if (arg_def.type() != DT_INVALID) { + *dtype = arg_def.type(); + } else if (arg_def.type_attr().empty()) { + *dtype = DT_INVALID; + } else { + const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_attr()); + if (v == nullptr) { + return errors::NotFound("type attr not found: ", arg_def.type_attr()); + } + *dtype = v->type(); + } + return Status::OK(); +} + +string Name(int node_index) { return strings::StrCat("n", node_index); } + +string Name(int node_index, int output_index) { + if (output_index == 0) { + return Name(node_index); + } else { + return strings::StrCat("n", node_index, ":", output_index); + } +} + +string Dep(int node_index) { return strings::StrCat("^", Name(node_index)); } + +template <typename T> +void AddAttr(const string& name, const T& val, NodeDef* ndef) { + SetAttrValue(val, &((*ndef->mutable_attr())[name])); +} + +Status ValidateSignatureWithAttrs(const OpDef& sig, + const InstantiateAttrValueMap& attr_values) { + // attr_values should specify all attrs defined in the signature. + for (const auto& a : sig.attr()) { + if (attr_values.find(a.name()) == attr_values.end()) { + return errors::NotFound("Attr ", a.name(), " is not found."); + } + } + + for (const auto& p : attr_values) { + if (HasPlaceHolder(p.second)) { + return errors::InvalidArgument(p.first, + " in attr_values is still a placeholder."); + } + } + + return Status::OK(); +} + +// We build a small index for all names that can be used as a node's +// input arguments. +// +// If is_func_arg is true, the name is a function's argument. In +// this case, the produced graph def has gdef.node[nid ... nid + +// num). +// +// Otherwise, the name is a function body's node return value. In +// this case, the produced graph def has one node gdef.node[nid] and +// the node's output index [idx ... idx + num) corresponds to the +// named outputs. +// +// In all cases, "dtype" specifies the data type. +struct NameInfoItem { + bool is_func_arg; + int nid; + int idx; + int num; + DataType dtype; +}; +typedef std::unordered_map<string, NameInfoItem> NameInfoIndex; + +Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, + const InstantiateAttrValueMap& attr_values, + NameInfoIndex* name_info, + InstantiationResult* result) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR(ArgNumType(attr_values, arg_def, &num, &dtype)); + CHECK_GE(num, 1); + GraphDef* gdef = &result->gdef; + int arg_index = gdef->node_size(); + if (!name_info->insert({arg_def.name(), {true, arg_index, 0, num, dtype}}) + .second) { + return errors::InvalidArgument("Duplicated arg name."); + } + // Creates "num" nodes in the gdef.
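  // E.g., an input arg declared "x: N*T" and instantiated with N=3,
  // T=float expands into three consecutive _Arg nodes (n0, n1, n2), each
  // carrying T=float and its running "index" attr; see the AddSquared
  // test in function_test.cc.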
+ for (int i = 0; i < num; ++i) { + DCHECK_EQ(arg_index, gdef->node_size()); + NodeDef* gnode = gdef->add_node(); + gnode->set_name(Name(arg_index)); + gnode->set_op("_Arg"); + AddAttr("T", dtype, gnode); + AddAttr("index", arg_index, gnode); + result->arg_types.push_back(dtype); + ++arg_index; + } + return Status::OK(); +} + +Status BuildNodeOutputIndex(const FunctionDef::Node& node, + const InstantiateAttrValueMap& attrs, + GetFunctionSignature get_function, + const int arg_index, NameInfoIndex* name_info) { + const OpDef* node_sig = nullptr; + TF_RETURN_IF_ERROR(get_function(node.op(), &node_sig)); + if (node_sig->output_arg_size() == 0) { + // This node produces no output. + if (node.ret_size() != 1) { + return errors::InvalidArgument("Expect one ret name."); + } + if (!name_info->insert({node.ret(0), {false, arg_index, 0, 0, DT_INVALID}}) + .second) { + return errors::InvalidArgument("Duplicated ret name."); + } + return Status::OK(); + } + + // When the signature says the last return value is of list(type), + // i.e., it's variadic, we need to consult + // attrs[last_retval.type_list_attr] to determine for the last arg + // * the actual number of outputs; + // * the actual data type of outputs. + const int num_retval = node_sig->output_arg_size(); + const OpDef::ArgDef& last_retval = node_sig->output_arg(num_retval - 1); + const bool last_retval_is_typelist = !last_retval.type_list_attr().empty(); + if (!last_retval_is_typelist && (node.ret_size() != num_retval)) { + return errors::InvalidArgument("Malformed function node (#ret)."); + } + int start = 0; + const int num_fixed_size_retval = + last_retval_is_typelist ? num_retval - 1 : num_retval; + for (int i = 0; i < num_fixed_size_retval; ++i) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR( + ArgNumType(attrs, node_sig->output_arg(i), &num, &dtype)); + if (!name_info->insert({node.ret(i), {false, arg_index, start, num, dtype}}) + .second) { + return errors::InvalidArgument("Duplicated ret name."); + } + start += num; + } + if (last_retval_is_typelist) { + const AttrValue* typelist = + gtl::FindOrNull(attrs, last_retval.type_list_attr()); + if (typelist == nullptr) { + return errors::InvalidArgument("Missing attr ", + last_retval.type_list_attr(), "."); + } + if (num_fixed_size_retval + typelist->list().type_size() != + node.ret_size()) { + return errors::InvalidArgument("Wrong #ret: ", num_fixed_size_retval, " ", + typelist->list().type_size(), " ", + node.ret_size(), "."); + } + for (int i = 0; i < typelist->list().type_size(); ++i) { + if (!name_info->insert({node.ret(i), + {false, arg_index, start, 1, + typelist->list().type(i)}}) + .second) { + return errors::InvalidArgument("Duplicated ret name."); + } + ++start; + } + } + return Status::OK(); +} + +Status InstantiateNode(const FunctionDef::Node& fnode, + const InstantiateAttrValueMap& attrs, + GetFunctionSignature get_function, + const NameInfoIndex& name_info, GraphDef* gdef) { + const OpDef* fnode_sig = nullptr; + TF_CHECK_OK(get_function(fnode.op(), &fnode_sig)); + NodeDef* gnode = gdef->add_node(); + gnode->set_name(Name(gdef->node_size() - 1)); + gnode->set_op(fnode.op()); + + // Input + // + // When the signature says the last argument is of list(type), + // i.e., it's variadic, we need to consult + // attrs[last_arg.type_list_attr] to determine for the last arg + // * the number of arguments; + // * the data types of arguments. 
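  // E.g., for a call to _ListToArray (declared above with "input: Tin"),
  // "Tin" is not looked up in "attrs"; it is accumulated from the data
  // types of the actual inputs. The Body_TypeList test in
  // function_test.cc shows Tin={float, float} being inferred this way.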
+ const int num_arg = fnode_sig->input_arg_size(); + bool last_arg_is_typelist = false; + if (num_arg > 0 && + !fnode_sig->input_arg(num_arg - 1).type_list_attr().empty()) { + last_arg_is_typelist = true; + } + if (!last_arg_is_typelist && (fnode.arg_size() != num_arg)) { + return errors::InvalidArgument("arg.size != sig.arg.size."); + } + const int num_fixed_size_args = last_arg_is_typelist ? num_arg - 1 : num_arg; + for (int i = 0; i < num_fixed_size_args; ++i) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR( + ArgNumType(attrs, fnode_sig->input_arg(i), &num, &dtype)); + const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.arg(i)); + if (item == nullptr) { + return errors::InvalidArgument("arg[", i, "] is not found: ", + fnode.ShortDebugString()); + } + if (num != item->num || dtype != item->dtype) { + return errors::InvalidArgument("Invalid arg(", i, ") for function arg: ", + " ", num, "/", dtype, " vs. ", item->num, + "/", item->dtype, "."); + } + for (int j = 0; j < num; ++j) { + if (item->is_func_arg) { + gnode->add_input(Name(item->nid + j)); + } else { + gnode->add_input(Name(item->nid, item->idx + j)); + } + } + } + if (last_arg_is_typelist) { + AttrValue typelist; + for (int i = num_fixed_size_args; i < fnode.arg_size(); ++i) { + const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.arg(i)); + if (item == nullptr) { + return errors::InvalidArgument("arg[", i, "] is not found."); + } + for (int j = 0; j < item->num; ++j) { + if (item->is_func_arg) { + gnode->add_input(Name(item->nid + j)); + } else { + gnode->add_input(Name(item->nid, item->idx + j)); + } + typelist.mutable_list()->add_type(item->dtype); + } + } + + // 'typelist' is inferred from the inputs' data types. + const auto& last_arg = fnode_sig->input_arg(num_arg - 1); + gnode->mutable_attr()->insert({last_arg.type_list_attr(), typelist}); + } + + // Control deps. + for (int i = 0; i < fnode.dep_size(); ++i) { + const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.dep(i)); + if (item == nullptr) { + return errors::InvalidArgument("dep[", i, "] is not found."); + } + gnode->add_input(Dep(item->nid)); + } + + // Attrs. + for (const auto& p : attrs) { + (*gnode->mutable_attr())[p.first] = p.second; + } + + return Status::OK(); +} + +Status AddReturnNode(const OpDef::ArgDef& ret_def, + const InstantiateAttrValueMap& attrs, + const NameInfoIndex& name_info, int* ret_index, + InstantiationResult* result) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &num, &dtype)); + CHECK_GE(num, 1); + const NameInfoItem* item = gtl::FindOrNull(name_info, ret_def.name()); + if (item == nullptr) { + return errors::InvalidArgument("ret is not found."); + } + if (num != item->num || dtype != item->dtype) { + return errors::InvalidArgument("Invalid ret name."); + } + GraphDef* gdef = &result->gdef; + for (int i = 0; i < num; ++i) { + NodeDef* gnode = gdef->add_node(); + gnode->set_name(Name(gdef->node_size() - 1)); + gnode->set_op("_Retval"); + gnode->add_input(Name(item->nid, item->idx + i)); + AddAttr("T", dtype, gnode); + AddAttr("index", (*ret_index)++, gnode); + result->ret_types.push_back(dtype); + } + return Status::OK(); +} + +// Various helpers Print(proto) to print relevant protos to ascii. 
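// E.g., an ArgDef prints as "x:float" or "x:N*T", and a FunctionDef prints
// in the multi-line form that the expected strings in function_test.cc
// are written against.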
+string Print(const OpDef::ArgDef& arg) { + string out; + strings::StrAppend(&out, arg.name(), ":"); + if (arg.is_ref()) strings::StrAppend(&out, "Ref("); + if (!arg.number_attr().empty()) { + strings::StrAppend(&out, arg.number_attr(), "*"); + } + if (arg.type() != DT_INVALID) { + strings::StrAppend(&out, DataTypeString(arg.type())); + } else { + strings::StrAppend(&out, arg.type_attr()); + } + if (arg.is_ref()) strings::StrAppend(&out, ")"); + return out; +} + +string Print(const AttrValue& attr_value) { + if (attr_value.value_case() == AttrValue::kType) { + return DataTypeString(attr_value.type()); + } else if ((attr_value.value_case() == AttrValue::kList) && + (attr_value.list().type_size() > 0)) { + string ret = "{"; + for (int i = 0; i < attr_value.list().type_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i))); + } + strings::StrAppend(&ret, "}"); + return ret; + } else if (attr_value.value_case() == AttrValue::kFunc) { + if (attr_value.func().attr_size() == 0) { + return attr_value.func().name(); + } + std::vector<string> entries; + for (auto p : attr_value.func().attr()) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } + sort(entries.begin(), entries.end()); + return strings::StrCat(attr_value.func().name(), "[", + str_util::Join(entries, ", "), "]"); + } + return SummarizeAttrValue(attr_value); +} + +string Print(const FunctionDef::Node& node) { + string out; + for (int i = 0; i < node.ret_size(); ++i) { + const auto& name = node.ret(i); + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, name); + } + strings::StrAppend(&out, " = ", node.op()); + if (node.attr_size() > 0) { + std::vector<string> entries; + for (auto p : node.attr()) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } + sort(entries.begin(), entries.end()); + strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]"); + } + strings::StrAppend(&out, "("); + for (int i = 0; i < node.arg_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, node.arg(i)); + } + strings::StrAppend(&out, ")"); + if (node.dep_size() > 0) { + strings::StrAppend(&out, " @ "); + for (int i = 0; i < node.dep_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, node.dep(i)); + } + } + return out; +} + +string Print(const FunctionDef& fdef) { + string out; + const OpDef& sig = fdef.signature(); + strings::StrAppend(&out, "\n", sig.name()); + if (sig.attr_size() > 0) { + strings::StrAppend(&out, "["); + for (int i = 0; i < sig.attr_size(); ++i) { + const auto& a = sig.attr(i); + if (i > 0) strings::StrAppend(&out, ", "); + if (a.type() == "type") { + strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values())); + } else { + strings::StrAppend(&out, a.name(), ":", a.type()); + } + } + strings::StrAppend(&out, "]"); + } + strings::StrAppend(&out, "("); + for (int i = 0; i < sig.input_arg_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, Print(sig.input_arg(i))); + } + strings::StrAppend(&out, ") -> ("); + for (int i = 0; i < sig.output_arg_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, Print(sig.output_arg(i))); + } + strings::StrAppend(&out, ") {\n"); + for (const auto& n : fdef.node()) { + strings::StrAppend(&out, " ", Print(n), "\n"); + } + strings::StrAppend(&out, "}\n"); + return out; +} + +string Print(const NodeDef& n) { + string out; + 
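  // Renders a node as, e.g., "n4 = AddN[N=3, T=float](n3, n3:1, n3:2)";
  // control inputs (prefixed with "^") are split out below and appended
  // after " @ ".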
strings::StrAppend(&out, n.name(), " = ", n.op()); + if (n.attr_size() > 0) { + std::vector<string> entries; + for (auto& a : n.attr()) { + entries.push_back(strings::StrCat(a.first, "=", Print(a.second))); + } + sort(entries.begin(), entries.end()); + strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]"); + } + strings::StrAppend(&out, "("); + std::vector<StringPiece> dat; + std::vector<string> dep; + for (StringPiece s : n.input()) { + if (s.Consume("^")) { + dep.push_back(s.ToString()); + } else { + dat.push_back(s); + } + } + strings::StrAppend(&out, str_util::Join(dat, ", "), ")"); + if (!dep.empty()) { + strings::StrAppend(&out, " @ ", str_util::Join(dep, ", ")); + } + return out; +} + +string Print(const GraphDef& gdef) { + std::vector<const NodeDef*> arg; + std::vector<const NodeDef*> ret; + std::vector<const NodeDef*> body; + for (const NodeDef& n : gdef.node()) { + if (n.op() == "_Arg") { + arg.push_back(&n); + } else if (n.op() == "_Retval") { + ret.push_back(&n); + } else { + body.push_back(&n); + } + } + auto comp = [](const NodeDef* x, const NodeDef* y) { + int xi; + TF_CHECK_OK(GetNodeAttr(*x, "index", &xi)); + int yi; + TF_CHECK_OK(GetNodeAttr(*y, "index", &yi)); + return xi < yi; + }; + sort(arg.begin(), arg.end(), comp); + sort(ret.begin(), ret.end(), comp); + string out; + strings::StrAppend(&out, "\n("); + auto get_type = [](const NodeDef& n) { + for (auto a : n.attr()) { + if (a.first == "T") { + return DataTypeString(a.second.type()); + } + } + return DataTypeString(DT_INVALID); + }; + for (size_t i = 0; i < arg.size(); ++i) { + const NodeDef* n = arg[i]; + if (i > 0) strings::StrAppend(&out, ", "); + CHECK_EQ(2, n->attr_size()); + strings::StrAppend(&out, n->name(), ":", get_type(*n)); + } + strings::StrAppend(&out, ") -> ("); + for (size_t i = 0; i < ret.size(); ++i) { + const NodeDef* n = ret[i]; + if (i > 0) strings::StrAppend(&out, ", "); + CHECK_EQ(2, n->attr_size()); + CHECK_EQ(1, n->input_size()); + strings::StrAppend(&out, n->input(0), ":", get_type(*n)); + } + strings::StrAppend(&out, ") {\n"); + for (size_t i = 0; i < body.size(); ++i) { + strings::StrAppend(&out, " ", Print(*body[i]), "\n"); + } + strings::StrAppend(&out, "}\n"); + return out; +} + +} // end namespace + +Status InstantiateFunction(const FunctionDef& fdef, + const InstantiateAttrValueMap& attr_values, + GetFunctionSignature get_function, + InstantiationResult* result) { + const OpDef& sig = fdef.signature(); + GraphDef* gdef = &result->gdef; + gdef->Clear(); + + TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values)); + + auto substitute = [&attr_values](const string& name, AttrValue* val) { + auto iter = attr_values.find(name); + if (iter == attr_values.end()) { + return false; + } else { + *val = iter->second; + return true; + } + }; + + // Makes a copy of all attrs in fdef and substitutes placeholders. + // After this step, every attr is bound to a concrete value. 
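  // E.g., with attr_values = {T=float}, a node attr {"T", "$T"} becomes
  // T=float; a placeholder that names no entry in attr_values fails below
  // with "Failed to bind all placeholders".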
+ std::vector<InstantiateAttrValueMap> node_attrs; + node_attrs.resize(fdef.node_size()); + for (int i = 0; i < fdef.node_size(); ++i) { + for (auto attr : fdef.node(i).attr()) { + if (!SubstitutePlaceholders(substitute, &attr.second)) { + return errors::InvalidArgument("Failed to bind all placeholders in ", + SummarizeAttrValue(attr.second)); + } + CHECK(node_attrs[i].insert(attr).second); + } + } + + NameInfoIndex name_info; + Status s; + for (const OpDef::ArgDef& arg_def : sig.input_arg()) { + s = BuildInputArgIndex(arg_def, attr_values, &name_info, result); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(arg_def)); + return s; + } + } + for (int i = 0; i < fdef.node_size(); ++i) { + s = BuildNodeOutputIndex(fdef.node(i), node_attrs[i], get_function, + gdef->node_size() + i, &name_info); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(fdef.node(i))); + return s; + } + } + + // Emits one gdef.node for each fdef.node. + for (int i = 0; i < fdef.node_size(); ++i) { + s = InstantiateNode(fdef.node(i), node_attrs[i], get_function, name_info, + gdef); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(fdef.node(i))); + return s; + } + } + + // Emits nodes for the function's return values. + int ret_index = 0; + for (const OpDef::ArgDef& ret_def : sig.output_arg()) { + s = AddReturnNode(ret_def, attr_values, name_info, &ret_index, result); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(ret_def)); + return s; + } + } + + return Status::OK(); +} + +string DebugString(const FunctionDef& func_def) { return Print(func_def); } + +string DebugString(const GraphDef& instantiated_func_def) { + return Print(instantiated_func_def); +} + +string DebugStringWhole(const GraphDef& gdef) { + string ret; + for (auto fdef : gdef.library().function()) { + strings::StrAppend(&ret, Print(fdef)); + } + strings::StrAppend(&ret, "\n"); + for (auto ndef : gdef.node()) { + strings::StrAppend(&ret, Print(ndef), "\n"); + } + return ret; +} + +string Canonicalize(const string& funcname, + const InstantiateAttrValueMap& attrs) { + std::vector<string> entries; + entries.reserve(attrs.size()); + for (auto p : attrs) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } + sort(entries.begin(), entries.end()); + return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]"); +} + +FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types, + DataTypeSlice ret_types) + : arg_types_(arg_types.begin(), arg_types.end()), + ret_types_(ret_types.begin(), ret_types.end()) { + args_.resize(arg_types_.size()); + rets_.resize(ret_types_.size()); +} + +FunctionCallFrame::~FunctionCallFrame() {} + +Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) { + // Input type checks. 
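  // E.g., a frame built with arg_types = {DT_FLOAT, DT_FLOAT} rejects both
  // a call with the wrong number of tensors and a call that passes an int64
  // tensor in a float slot; see the FunctionCallFrame tests in
  // function_test.cc.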
+ if (args.size() != arg_types_.size()) { + return errors::InvalidArgument("Expects ", arg_types_.size(), + " arguments, but ", args.size(), + " is provided"); + } + for (size_t i = 0; i < args.size(); ++i) { + if (arg_types_[i] != args[i].dtype()) { + return errors::InvalidArgument( + "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ", + DataTypeString(args[i].dtype()), " is provided"); + } + args_[i] = args[i]; + } + return Status::OK(); +} + +Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const { + rets->clear(); + rets->reserve(rets_.size()); + for (size_t i = 0; i < rets_.size(); ++i) { + auto item = rets_[i]; + if (item.has_val) { + rets->push_back(item.val); + } else { + return errors::Internal("Retval[", i, "] does not have value"); + } + } + return Status::OK(); +} + +Status FunctionCallFrame::GetArg(int index, Tensor* val) const { + if (index < 0 || static_cast<size_t>(index) >= args_.size()) { + return errors::OutOfRange("GetArg ", index, " is not within [0, ", + args_.size(), ")"); + } + *val = args_[index]; + return Status::OK(); +} + +Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { + if (index < 0 || static_cast<size_t>(index) >= rets_.size()) { + return errors::OutOfRange("SetRetval ", index, " is not within [0, ", + rets_.size(), ")"); + } + if (val.dtype() != ret_types_[index]) { + return errors::InvalidArgument( + "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]), + ", but ", DataTypeString(val.dtype()), " is provided."); + } + Retval* item = &rets_[index]; + if (!item->has_val) { + item->has_val = true; + item->val = val; + } else { + return errors::Internal("Retval[", index, "] has already been set."); + } + return Status::OK(); +} + +FunctionLibraryDefinition::FunctionLibraryDefinition( + const FunctionDefLibrary& def_lib) + : function_defs_(def_lib.function_size()) { + for (auto fdef : def_lib.function()) { + // The latter function definition wins. 
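  // E.g., if def_lib contains two FunctionDefs named "F", Find("F")
  // returns the second one.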
+ function_defs_[fdef.signature().name()] = fdef; + } +} + +FunctionLibraryDefinition::~FunctionLibraryDefinition() {} + +const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const { + auto iter = function_defs_.find(name); + if (iter == function_defs_.end()) { + return nullptr; + } else { + return &iter->second; + } +} + +const OpDef* FunctionLibraryDefinition::LookUp(const string& op, + Status* status) const { + auto fdef = Find(op); + if (fdef != nullptr) { + return &(fdef->signature()); + } + return OpRegistry::Global()->LookUp(op, status); +} + +Status InstantiateFunction(const FunctionDef& fdef, + InstantiateAttrValueSlice attr_values, + GetFunctionSignature get_function, + InstantiationResult* result) { + InstantiateAttrValueMap m; + for (const auto& aval : attr_values) { + m.insert({aval.first, aval.second.proto}); + } + return InstantiateFunction(fdef, m, get_function, result); +} + +string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs) { + InstantiateAttrValueMap m; + for (const auto& aval : attrs) { + m.insert({aval.first, aval.second.proto}); + } + return Canonicalize(funcname, m); +} + +Status FunctionLibraryRuntime::Instantiate(const string& function_name, + InstantiateAttrValueSlice attrs, + Handle* handle) { + InstantiateAttrValueMap m; + for (const auto& aval : attrs) { + m.insert({aval.first, aval.second.proto}); + } + return Instantiate(function_name, m, handle); +} + +void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) { + if (val.size() >= 2 && val[0] == '$') { + proto.set_placeholder(val.data() + 1, val.size() - 1); + } else { + SetAttrValue(val, &proto); + } +} + +FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef( + const string& name, + gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) { + AttrValueWrapper ret; + ret.proto.mutable_func()->set_name(name); + for (const auto& a : attrs) { + ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto}); + } + return ret; +} + +FunctionDef::Node FunctionDefHelper::Node::ToProto() const { + FunctionDef::Node n; + for (const string& r : this->ret) { + n.add_ret(r); + } + n.set_op(this->op); + for (const string& a : arg) { + n.add_arg(a); + } + for (const auto& a : this->attr) { + n.mutable_attr()->insert({a.first, a.second.proto}); + } + for (const string& d : dep) { + n.add_dep(d); + } + return n; +} + +/* static */ +FunctionDef FunctionDefHelper::Define(const string& name, + gtl::ArraySlice<string> arg_def, + gtl::ArraySlice<string> ret_def, + gtl::ArraySlice<string> attr_def, + gtl::ArraySlice<Node> node_def) { + FunctionDef fdef; + OpDefBuilder b(name); + for (const auto& a : arg_def) b.Input(a); + for (const auto& r : ret_def) b.Output(r); + for (const auto& a : attr_def) b.Attr(a); + TF_CHECK_OK(b.Finalize(fdef.mutable_signature())); + for (const auto& n : node_def) { + *(fdef.add_node()) = n.ToProto(); + } + return fdef; +} + +FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def, + gtl::ArraySlice<string> ret_def, + gtl::ArraySlice<string> attr_def, + gtl::ArraySlice<Node> node_def) { + return Define("_", arg_def, ret_def, attr_def, node_def); +} + +namespace gradient { + +typedef std::unordered_map<string, Creator> OpGradFactory; + +OpGradFactory* GetOpGradFactory() { + static OpGradFactory* factory = new OpGradFactory; + return factory; +} + +bool RegisterOp(const string& op, Creator func) { + CHECK(GetOpGradFactory()->insert({op, func}).second) + << "Duplicated gradient for " << op; + return 
true; +} + +Status GetOpGradientCreator(const string& op, Creator* creator) { + auto fac = GetOpGradFactory(); + auto iter = fac->find(op); + if (iter == fac->end()) { + return errors::NotFound("No gradient defined for op: ", op); + } + *creator = iter->second; + return Status::OK(); +} + +} // end namespace gradient + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h new file mode 100644 index 0000000000..1ef93a0533 --- /dev/null +++ b/tensorflow/core/framework/function.h @@ -0,0 +1,376 @@ +#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_ +#define TENSORFLOW_FRAMEWORK_FUNCTION_H_ + +#include <unordered_map> + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +class CancellationManager; +class Node; +class OpKernel; + +// FunctionDefHelper::Define is a convenient helper to construct a +// FunctionDef proto. +// +// E.g., +// FunctionDef my_func = FunctionDefHelper::Define( +// "my_func_name", +// {"x:T", "y:T" /* one string per argument */}, +// {"z:T" /* one string per return value */}, +// {"T: {float, double}" /* one string per attribute */}, +// { +// {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}} +// /* one entry per function node */ +// }); +// +// NOTE: When we have a TFLang parser, we can add another helper: +// FunctionDef FunctionDefHelper::Define(const string& tf_func); +class FunctionDefHelper { + public: + // AttrValueWrapper has a converting constructor for the type T so that + // it's easy to construct a simple AttrValue proto. + // + // If T is a string type (const char*, string, or StringPiece), and + // it starts with "$", we construct an AttrValue of "placeholder". + // + // E.g., + // std::pair<string, AttrValueWrapper> x = {"T", "$T"} + // is a named attr value placeholder. + struct AttrValueWrapper { + AttrValue proto; + + AttrValueWrapper() {} + + template <typename T> + AttrValueWrapper(T val) { // NOLINT(runtime/explicit) + SetAttrValue(val, &proto); + } + + private: + void InitFromString(StringPiece val); + }; + + // Constructs an AttrValue.func given the "name" and "attrs". + static AttrValueWrapper FunctionRef( + const string& name, + gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs); + static AttrValueWrapper FunctionRef(const string& name) { + return FunctionRef(name, {}); + } + + // Node is used to construct FunctionDef.Node using initialization + // lists. E.g., + // Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y + struct Node { + std::vector<string> ret; + string op; + std::vector<string> arg; + std::vector<std::pair<string, AttrValueWrapper>> attr; + std::vector<string> dep; + + FunctionDef::Node ToProto() const; + }; + + static FunctionDef Define(const string& function_name, + gtl::ArraySlice<string> arg_def, + gtl::ArraySlice<string> ret_def, + gtl::ArraySlice<string> attr_def, + gtl::ArraySlice<Node> node_def); + + // Defines an anonymous function. I.e., its name is not relevant. + static FunctionDef Define(gtl::ArraySlice<string> arg_def, + gtl::ArraySlice<string> ret_def, + gtl::ArraySlice<string> attr_def, + gtl::ArraySlice<Node> node_def); + + // Helpers to construct a constant scalar.
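  // E.g., FDH::Const("zero", 0) produces the node {{"zero"}, "Const"}
  // with attrs {"dtype", DT_INT32} and {"value", <int32 scalar 0>}.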
+ template <typename T> + static Node Const(const string& name, const T& val) { + Node n = {{name}, "Const"}; + const DataType dtype = DataTypeToEnum<T>::value; + n.attr.push_back({"dtype", dtype}); + Tensor t(dtype, TensorShape({})); + t.scalar<T>()() = val; + n.attr.push_back({"value", t}); + return n; + } + + template <typename T> + static Node Const(const string& name, gtl::ArraySlice<T> vals) { + Node n = {{name}, "Const"}; + const DataType dtype = DataTypeToEnum<T>::value; + n.attr.push_back({"dtype", dtype}); + int64 num = vals.size(); + Tensor t(dtype, TensorShape({num})); + for (size_t i = 0; i < vals.size(); ++i) { + t.flat<T>()(i) = vals[i]; + } + n.attr.push_back({"value", t}); + return n; + } +}; + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) { + InitFromString(val); +} + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper( + const string& val) { + InitFromString(val); +} + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) { + InitFromString(val); +} + +// Instantiate a function. +// +// "fdef" encodes a TF function with some attrs in fdef.signature.attr +// containing placeholders. InstantiateFunction binds these +// placeholders and produces an instantiated function encoded in +// "result.gdef". The value to substitute a placeholder is given by +// "attr_values", which is a map from a placeholder name to an attr +// value. +// +// InstantiateFunction calls "get_function" to find signatures of other +// functions and primitive ops. + +// Placeholders in "fdef" are substituted based on "attr_values" here. +typedef ::tensorflow::protobuf::Map<string, AttrValue> InstantiateAttrValueMap; +typedef gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>> + InstantiateAttrValueSlice; + +// GetFunctionSignature(func name, opdef) returns OK if the func name is found +// and opdef is filled with a pointer to the corresponding signature +// (an OpDef proto). Otherwise, returns an error. +typedef std::function<Status(const string&, const OpDef**)> + GetFunctionSignature; + +struct InstantiationResult { + DataTypeVector arg_types; + DataTypeVector ret_types; + GraphDef gdef; +}; +Status InstantiateFunction(const FunctionDef& fdef, + const InstantiateAttrValueMap& attr_values, + GetFunctionSignature get_function, + InstantiationResult* result); +Status InstantiateFunction(const FunctionDef& fdef, + InstantiateAttrValueSlice attr_values, + GetFunctionSignature get_function, + InstantiationResult* result); + +// Returns a debug string for a function definition. +// +// The returned text is multi-line. It is intended to be +// human-readable rather than being friendly to parsers. It is _NOT_ +// intended to be the canonical string representation of "func_def". +// In particular, it may not include all information presented in +// "func_def" (e.g., comments, description of the function arguments, +// etc.) +string DebugString(const FunctionDef& func_def); +string DebugString(const GraphDef& instantiated_func_def); + +// Returns a debug string for a top level graph (the main program and +// its supporting functions defined in its library). +string DebugStringWhole(const GraphDef& gdef); + +// Returns a canonicalized string for the instantiation of the +// function of the given "name" and attributes "attrs". E.g., +// Canonicalize("Foo", {{"T", DT_FLOAT}}) gives "Foo[T=float]". +// +// The returned string is guaranteed to be stable within one address +// space. But it may change as the implementation +// evolves.
Therefore, it should not be persisted or compared across +// address spaces. +string Canonicalize(const string& funcname, + const InstantiateAttrValueMap& attrs); +string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs); + +// Represents a function call frame. I.e., the data structure used to +// pass arguments to a function and retrieve its results. +// +// The runtime must arrange accesses to one FunctionCallFrame such that +// 1. SetArgs() happens before any GetArg(); +// 2. GetRetvals happens after all SetRetval(). +class FunctionCallFrame { + public: + FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types); + ~FunctionCallFrame(); + + // Caller methods. + Status SetArgs(gtl::ArraySlice<Tensor> args); + Status GetRetvals(std::vector<Tensor>* rets) const; + + // Callee methods. + Status GetArg(int index, Tensor* val) const; + Status SetRetval(int index, const Tensor& val); + + private: + DataTypeVector arg_types_; + DataTypeVector ret_types_; + gtl::InlinedVector<Tensor, 4> args_; + struct Retval { + bool has_val = false; + Tensor val; + }; + gtl::InlinedVector<Retval, 4> rets_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame); +}; + +// Helper to maintain a map between function names in a given +// FunctionDefLibrary and function definitions. +class FunctionLibraryDefinition : public OpRegistryInterface { + public: + explicit FunctionLibraryDefinition(const FunctionDefLibrary& lib_def); + ~FunctionLibraryDefinition() override; + + // Returns nullptr if "func" is not defined in "lib_def". Otherwise, + // returns its definition proto. + const FunctionDef* Find(const string& func) const; + + // OpRegistryInterface method. Useful for constructing a Graph. + // + // If "op" is defined in the library, returns its signature. + // Otherwise, assumes "op" is a primitive op and returns its op + // signature. + const OpDef* LookUp(const string& op, Status* status) const override; + + private: + std::unordered_map<string, FunctionDef> function_defs_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryDefinition); +}; + +// Forward declare. Defined in common_runtime/function.h +struct FunctionBody; + +class FunctionLibraryRuntime { + public: + virtual ~FunctionLibraryRuntime() {} + + // Instantiates a function with the given "attrs". + // + // Returns OK and fills in "handle" if the instantiation succeeds. + // Otherwise returns an error and "handle" is undefined. + typedef uint64 Handle; + virtual Status Instantiate(const string& function_name, + const InstantiateAttrValueMap& attrs, + Handle* handle) = 0; + Status Instantiate(const string& function_name, + InstantiateAttrValueSlice attrs, Handle* handle); + + // Returns the function body for the instantiated function given its + // handle 'h'. Returns nullptr if "h" is not found. + // + // *this keeps ownership of the returned object, which remains alive + // as long as *this. + virtual const FunctionBody* GetFunctionBody(Handle h) = 0; + + // Asynchronously invokes the instantiated function identified by + // "handle". + // + // If function execution succeeds, "done" is called with OK and + // "*rets" is filled with the function's return values. Otherwise, + // "done" is called with an error status. + // + // Does not take ownership of "rets".
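  //
  // A typical call sequence (a sketch; "MyFunc", its "T" attr, and the
  // input tensor "x" are hypothetical):
  //
  //   FunctionLibraryRuntime::Handle h;
  //   TF_CHECK_OK(flr->Instantiate("MyFunc", {{"T", DT_FLOAT}}, &h));
  //   std::vector<Tensor> rets;
  //   flr->Run(FunctionLibraryRuntime::Options(), h, {x}, &rets,
  //            [](const Status& s) { TF_CHECK_OK(s); });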
+ struct Options { + CancellationManager* cancellation_manager = nullptr; + }; + typedef std::function<void(const Status&)> DoneCallback; + virtual void Run(const Options& opts, Handle handle, + gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, + DoneCallback done) = 0; + + // Creates a "kernel" for the given node def "ndef". + + // On success, returns OK and the caller takes ownership of the + // returned "*kernel". Otherwise, returns an error. + virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0; + + // Returns true iff 'function_name' is the name of a defined function. + virtual bool IsDefined(const string& function_name) = 0; +}; + +// To register a gradient function for a builtin op, one should use +// REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>); +// +// Typically, the c++ grad factory is a plain function that can be +// converted into ::tensorflow::gradient::Creator, which is +// std::function<Status(const AttrSlice&, FunctionDef*)>. +// +// A ::tensorflow::gradient::Creator should populate the FunctionDef* with a +// definition of a function which computes the gradient for the +// <op_name> when the <op_name> is instantiated with the given attrs. +// +// E.g., +// +// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) { +// bool transpose_a; +// TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a)); +// bool transpose_b; +// TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b)); +// DataType dtype; +// TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype)); +// if (!transpose_a && !transpose_b) { +// *g = FunctionDefHelper::Define( +// "MatMulGrad", +// {"x:T", "y:T", "dz:T"}, // Inputs to this function +// {"dx:T", "dy:T"}, // Outputs from this function +// {"T: {float, double}"}, // Attributes needed by this function +// { +// {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}}, +// {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}}, +// {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}}, +// {{"dy"}, "MatMul", {"x_t", "dz"}, {{"T", "$T"}}}, +// }); +// } else { +// ... ... +// } +// return Status::OK(); +// } +// +// NOTE: $T is substituted with the type variable "T" when the +// gradient function MatMulGrad is instantiated. +// +// TODO(zhifengc): Better documentation somewhere. + +// Macros to define a gradient function factory for a primitive +// operation. +#define REGISTER_OP_GRADIENT(name, fn) \ + REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn) + +#define REGISTER_OP_NO_GRADIENT(name) \ + REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr) + +#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \ + REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) + +#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \ + static bool unused_grad_##ctr = ::tensorflow::gradient::RegisterOp(name, fn) + +namespace gradient { +// Register a gradient creator for the "op". +typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator; +bool RegisterOp(const string& op, Creator func); + +// Returns OK if the gradient creator for the "op" is found (the creator +// may be nullptr if REGISTER_OP_NO_GRADIENT is used).
+Status GetOpGradientCreator(const string& op, Creator* creator); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_ diff --git a/tensorflow/core/framework/function.proto b/tensorflow/core/framework/function.proto new file mode 100644 index 0000000000..4b8a26947c --- /dev/null +++ b/tensorflow/core/framework/function.proto @@ -0,0 +1,68 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/attr_value.proto"; +import "tensorflow/core/framework/op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have a binding for every attr defined in the signature. +// +// TODO(zhifengc): +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // The body of the function. + repeated Node node = 2; // function.node.ret[*] are unique. + + // A node is a multi-value assignment: + // (ret[0], ret[1], ...) = func(arg[0], arg[1], ...) + // + // By convention, "func" is resolved by consulting a user-defined + // library first. If not resolved, "func" is assumed to be a builtin op. + message Node { + // This node produces multiple outputs. They are named ret[0], + // ret[1], ..., etc. + // + // REQUIRES: function.node.ret[*] are unique across all nodes. + // REQUIRES: ret.size == func/op def's number of output args. + repeated string ret = 1; + + // The op/function name. + string op = 2; + + // Arguments passed to this func/op. + // + // arg[i] must be either one of + // function.signature.input_args[*].name or one of + // function.node[*].ret[*]. + // + // REQUIRES: arg.size == func/op def's number of input args. + repeated string arg = 3; + + // Control dependencies. + // + // dep[i] must be one of function.node[*].ret[*] or one of + // function.signature.input_args[*].name. + repeated string dep = 4; + + // Attrs. + // + // 'attr' maps names defined by 'func's attr defs to attr values. + // attr values may have placeholders which are substituted + // recursively with concrete values when this node is instantiated. + // These placeholders must name an attr listed in the FunctionDef's + // signature.
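    // E.g., a node attr {"T": "$T"} is bound, at instantiation time, to
    // the concrete value the caller supplies for the signature attr "T".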
+ map<string, AttrValue> attr = 5; + } +} diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc new file mode 100644 index 0000000000..c9483fad18 --- /dev/null +++ b/tensorflow/core/framework/function_test.cc @@ -0,0 +1,634 @@ +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/port.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +typedef FunctionDefHelper FDH; + +Status GetOpSig(const string& op, const OpDef** sig) { + Status s; + *sig = OpRegistry::Global()->LookUp(op, &s); + return s; +} + +REGISTER_OP("One") + .Output("y: T") + .Attr("T: {float, double, int32, int64}") + .Doc(R"doc( +Returns a tensor with a single element (1) of type T. + +y: A scalar in type T. + +)doc"); + +static InstantiateAttrValueMap kNoAttrs; + +TEST(TFunc, SquarePlusOne) { + RequireDefaultOps(); + auto fdef = FDH::Define( + // Name + "SquarePlusOne", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attrs + {"T: {float, double, int32, int64}"}, + // Nodes + {// a = Square<T>(x) + {{"a"}, "Square", {"x"}, {{"T", "$T"}}}, + // o = One<T>() + // NOTE: We can also have a Cast<Tin, Tout>(x) instead. + {{"o"}, "One", {}, {{"T", "$T"}}}, + // y = Add<T>(a, o) + {{"y"}, "Add", {"a", "o"}, {{"T", "$T"}}}}); + + const char* e = R"P( +SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) { + a = Square[T=$T](x) + o = One[T=$T]() + y = Add[T=$T](a, o) +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + // Instantiate one with T=float + InstantiationResult result; + TF_CHECK_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result)); + const char* e2 = R"P( +(n0:float) -> (n3:float) { + n1 = Square[T=float](n0) + n2 = One[T=float]() + n3 = Add[T=float](n1, n2) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.gdef), e2); +} + +// NOTE: This is the simplest Map op. It takes a f:T->U. +REGISTER_OP("Map") + .Input("x: N * T") + .Output("y: N * U") + .Attr("T: type") + .Attr("U: type") + .Attr("N: int >= 1") + // .Attr("func: func_name_with_attr") + .Doc(R"doc( +Applies the 'func' on every input. 
I.e., + +y[i] = func<...>(x[i]) + +x: N tensors, each of type T; +y: N tensors, each of type U; + +)doc"); + +TEST(TFunc, AddSquared) { + auto fdef = FDH::Define( + // Name + "AddSquared", + // Args + {"x: N*T"}, + // Return values + {"y: T"}, + // Attrs + {"N:int", "T:{float, double, int32, int64}"}, + // Nodes + {// a = Map<func=Square<$T>,T=$T,U=$T,N=$N>(x) + {{"a"}, + "Map", + {"x"}, + {{"func", FDH::FunctionRef("Square", {{"T", "$T"}})}, + {"T", "$T"}, + {"U", "$T"}, + {"N", "$N"}}}, + // y = AddN<N=$N,T=$T>(a) + {{"y"}, "AddN", {"a"}, {{"N", "$N"}, {"T", "$T"}}}}); + + const char* e = R"P( +AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) { + a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x) + y = AddN[N=$N, T=$T](a) +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + // Instantiate one with T=float + InstantiationResult result; + TF_CHECK_OK(InstantiateFunction(fdef, {{"N", 3}, {"T", DT_FLOAT}}, GetOpSig, + &result)); + const char* e2 = R"P( +(n0:float, n1:float, n2:float) -> (n4:float) { + n3 = Map[N=3, T=float, U=float, func=Square[T=float]](n0, n1, n2) + n4 = AddN[N=3, T=float](n3, n3:1, n3:2) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.gdef), e2); +} + +TEST(TFunc, ControlDeps) { + auto fdef = FDH::Define( + // Name + "ControlDeps", + // Args + {"x: float"}, + // Return values + {}, + // Attrs + {}, + // Nodes + { + {{"a"}, "One", {}, {{"T", DT_FLOAT}}, {"x"}}, + {{"u"}, "NoOp", {}, {}, {"a"}}, + {{"b"}, "One", {}, {{"T", DT_FLOAT}}, {"u"}}, + {{"v"}, "NoOp", {}, {}, {"b"}}, + {{"c"}, "One", {}, {{"T", DT_FLOAT}}, {"a", "v"}}, + }); + const char* e = R"P( +ControlDeps(x:float) -> () { + a = One[T=float]() @ x + u = NoOp() @ a + b = One[T=float]() @ u + v = NoOp() @ b + c = One[T=float]() @ a, v +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + InstantiationResult result; + TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + const char* e2 = R"P( +(n0:float) -> () { + n1 = One[T=float]() @ n0 + n2 = NoOp() @ n1 + n3 = One[T=float]() @ n2 + n4 = NoOp() @ n3 + n5 = One[T=float]() @ n1, n4 +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({})); + EXPECT_EQ(DebugString(result.gdef), e2); +} + +TEST(TFunc, XTimesTwo) { + auto expect = R"P( +XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { + two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() + scale = Cast[DstT=$T, SrcT=int64](two) + y = Mul[T=$T](x, scale) +} +)P"; + EXPECT_EQ(expect, DebugString(test::function::XTimesTwo())); +} + +TEST(TFunc, WXPlusB) { + auto expect = R"P( +WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) { + mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x) + y = Add[T=$T](mm, b) +} +)P"; + EXPECT_EQ(expect, DebugString(test::function::WXPlusB())); +} + +TEST(TFunc, Body_TypeList) { + const Tensor kZero = test::AsScalar<int32>(0); + auto fdef = FDH::Define( + // Name + "Test", + // Args + {"i:float"}, + // Return values + {"o:float"}, + // Attrs + {}, + // Nodes + {{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}}, + {{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}}, + {{"a", "b", "c", "d"}, + "_ArrayToList", + {"s"}, + {{"N", 4}, + {"T", DT_FLOAT}, + {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}}, + {{"l"}, "Mul", {"a", "b"}, {{"T", DT_FLOAT}}}, + {{"r"}, "Mul", {"c", 
"d"}, {{"T", DT_FLOAT}}}, + {{"x"}, "_ListToArray", {"l", "r"}, {{"N", 2}, {"T", DT_FLOAT}}}, + {{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}}); + + const char* e = R"P( +Test(i:float) -> (o:float) { + zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() + s = Split[T=float, num_split=4](zero, i) + a, b, c, d = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](s) + l = Mul[T=float](a, b) + r = Mul[T=float](c, d) + x = _ListToArray[N=2, T=float](l, r) + o = AddN[N=2, T=float](x) +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + InstantiationResult result; + TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + const char* e2 = R"P( +(n0:float) -> (n7:float) { + n1 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() + n2 = Split[T=float, num_split=4](n1, n0) + n3 = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](n2, n2:1, n2:2, n2:3) + n4 = Mul[T=float](n3, n3:1) + n5 = Mul[T=float](n3:2, n3:3) + n6 = _ListToArray[N=2, T=float, Tin={float, float}](n4, n5) + n7 = AddN[N=2, T=float](n6, n6:1) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.gdef), e2); +} + +REGISTER_OP("Cond") + .Input("input: Tin") + .Output("output: out_types") + .Attr("Tin: list(type)") + .Attr("out_types: list(type)") + .Attr("cond: func") + .Attr("then_branch: func") + .Attr("else_branch: func") + .Doc(R"doc( +output = Cond(input) ? then_branch(input) : else_branch(input) + +cond: A function takes 'input' and returns a scalar. +then_branch: A funcion takes 'input' and returns 'output'. +else_branch: A funcion takes 'input' and returns 'output'. +)doc"); + +TEST(TFunc, Body_Array_List_Converter) { + auto fdef = FDH::Define( + // Name + "MySelect", + // Args + {"x:float"}, + // Return values + {"z:float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x"}, + {{"Tin", DataTypeSlice{DT_FLOAT}}, + {"out_types", DataTypeSlice{DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond")}, + {"then_branch", FDH::FunctionRef("MyThen")}, + {"else_branch", FDH::FunctionRef("MyElse")}}}, + {{"z"}, + "Cond", + {"y", "y"}, + {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"out_types", DataTypeSlice{DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }); + + const char* e = R"P( +MySelect(x:float) -> (z:float) { + y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x) + z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y) +} +)P"; + EXPECT_EQ(DebugString(fdef), e); + + InstantiationResult result; + TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); + const char* e2 = R"P( +(n0:float) -> (n2:float) { + n1 = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](n0) + n2 = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](n1, n1) +} +)P"; + EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT})); + EXPECT_EQ(DebugString(result.gdef), e2); +} + +static void HasError(const Status& s, const string& substr) { + EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) + << s << ", expected substring " << substr; +} + +TEST(InstantiateErrors, Not_Sufficient_Attrs) { + auto fdef = + 
FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, {{"U", DT_FLOAT}}, GetOpSig, &result), + "T is not found"); +} + +TEST(InstantiateErrors, AttrValue_Value_Placeholder) { + auto fdef = + FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, {{"T", "$bad"}}, GetOpSig, &result), + "T in attr_values is still a placeholder"); +} + +TEST(InstantiateErrors, Unbounded_Attr) { + auto fdef = FDH::Define("test", {}, {}, {"T:{float, double, int32, int64}"}, + { + {{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result), + "Failed to bind all placeholders"); +} + +TEST(InstantiateErrors, DupArgs) { + auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {}); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Duplicated arg name"); +} + +TEST(InstantiateErrors, Dup_Arg_Node_Name) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Duplicated ret name"); +} + +TEST(InstantiateErrors, Dup_Node_Names) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y"}, "One", {}, {{"T", DT_FLOAT}}}, + {{"y"}, "One", {}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Duplicated ret name"); +} + +TEST(InstantiateErrors, Node_Signature_Mismatch_NoOp) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y", "z"}, "NoOp", {}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Expect one ret name"); +} + +TEST(InstantiateErrors, Node_Signature_Mismatch) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y", "z"}, "One", {}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Malformed function node (#ret)"); +} + +TEST(InstantiateErrors, Node_Arg_Notfound) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "arg[1] is not found"); +} + +TEST(InstantiateErrors, Node_Arg_Mismatch) { + auto fdef = FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Invalid arg(0) for function arg"); +} + +TEST(InstantiateErrors, Node_Arg_ControlMissing) { + auto fdef = + FDH::Define("test", {"x:float"}, {}, {}, + { + {{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "dep[0] is not found"); +} + +TEST(InstantiateErrors, FuncRet_Missing) { + auto fdef = FDH::Define("test", {}, {"y: float"}, {}, + { + {{"x"}, "One", {}, {{"T", DT_FLOAT}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "ret is not found"); +} + +TEST(InstantiateErrors, FuncRet_Mismatch) { + auto fdef = FDH::Define("test", {}, {"y: float"}, {}, + { + {{"y"}, "One", {}, 
{{"T", DT_DOUBLE}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Invalid ret name.\n\t In y"); +} + +TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) { + auto fdef = FDH::Define( + // Name + "MySelect", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x", "x"}, + {{"tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Missing attr out_types"); +} + +TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) { + auto fdef = FDH::Define( + // Name + "MySelect", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x", "x"}, + {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "Wrong #ret: 0 2 1"); +} + +TEST(InstantiateErrors, TypeList_Missing_Arg) { + auto fdef = FDH::Define( + // Name + "MySelect", + // Args + {"x: float"}, + // Return values + {"y: float"}, + // Attrs + {}, + // Nodes + { + {{"y"}, + "Cond", + {"x", "unknown"}, + {{"out_types", DataTypeSlice{DT_FLOAT}}, + {"cond", FDH::FunctionRef("MyCond2")}, + {"then_branch", FDH::FunctionRef("MyThen2")}, + {"else_branch", FDH::FunctionRef("MyElse2")}}}, + }); + InstantiationResult result; + HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), + "arg[1] is not found"); +} + +TEST(FunctionCallFrame, Void_Void) { + FunctionCallFrame frame({}, {}); + EXPECT_OK(frame.SetArgs({})); + auto a = test::AsTensor<float>({100}); + HasError(frame.SetArgs({a}), "Invalid argument"); + Tensor v; + HasError(frame.GetArg(0, &v), "Out of range"); + HasError(frame.SetRetval(0, v), "Out of range"); + std::vector<Tensor> rets; + EXPECT_OK(frame.GetRetvals(&rets)); + EXPECT_EQ(rets.size(), 0); +} + +TEST(FunctionCallFrame, Float_Float_Float) { + FunctionCallFrame frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT}); + HasError(frame.SetArgs({}), "Invalid argument: Expects 2 arguments"); + auto a = test::AsTensor<float>({100}); + auto b = test::AsTensor<float>({200}); + auto c = test::AsTensor<int64>({300}); + HasError(frame.SetArgs({a, c}), + "Invalid argument: Expects arg[1] to be float"); + EXPECT_OK(frame.SetArgs({a, b})); + + Tensor v; + HasError(frame.GetArg(-1, &v), "Out of range"); + HasError(frame.GetArg(2, &v), "Out of range"); + EXPECT_OK(frame.GetArg(0, &v)); + test::ExpectTensorEqual<float>(a, v); + EXPECT_OK(frame.GetArg(1, &v)); + test::ExpectTensorEqual<float>(b, v); + + v = test::AsTensor<float>({-100}); + HasError(frame.SetRetval(-1, v), "Out of range"); + HasError(frame.SetRetval(1, v), "Out of range"); + HasError(frame.SetRetval(0, test::AsTensor<int64>({-100})), + "Invalid argument: Expects ret[0] to be float"); + + std::vector<Tensor> rets; + HasError(frame.GetRetvals(&rets), "does not have value"); + EXPECT_OK(frame.SetRetval(0, v)); + HasError(frame.SetRetval(0, v), "has already been set"); + + EXPECT_OK(frame.GetRetvals(&rets)); + EXPECT_EQ(rets.size(), 1); + test::ExpectTensorEqual<float>(rets[0], v); +} + +TEST(Canonicalize, Basic) 
{ + EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT}, + {"transpose_a", false}, + {"transpose_b", false}}), + "MatMul[T=float,transpose_a=false,transpose_b=false]"); + EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT}, + {"transpose_b", false}, + {"transpose_a", false}}), + "MatMul[T=float,transpose_a=false,transpose_b=false]"); + EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_DOUBLE}, + {"transpose_b", true}, + {"transpose_a", false}}), + "MatMul[T=double,transpose_a=false,transpose_b=true]"); +} + +TEST(FunctionLibraryDefinitionTest, Find) { + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + FunctionLibraryDefinition lib_def(proto); + + EXPECT_EQ(lib_def.Find("XTimes16"), nullptr); + + auto expect = R"P( +XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { + two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() + scale = Cast[DstT=$T, SrcT=int64](two) + y = Mul[T=$T](x, scale) +} +)P"; + auto found = lib_def.Find("XTimesTwo"); + ASSERT_NE(found, nullptr); + EXPECT_EQ(expect, DebugString(*found)); +} + +TEST(FunctionLibraryDefinitionTest, LookUp) { + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + FunctionLibraryDefinition lib_def(proto); + + Status s; + EXPECT_EQ(lib_def.LookUp("XTimes16", &s), nullptr); + + auto found = lib_def.LookUp("XTimesTwo", &s); + ASSERT_NE(found, nullptr); + EXPECT_EQ(found->DebugString(), + test::function::XTimesTwo().signature().DebugString()); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc new file mode 100644 index 0000000000..5ead947076 --- /dev/null +++ b/tensorflow/core/framework/function_testlib.cc @@ -0,0 +1,146 @@ +#include "tensorflow/core/framework/function_testlib.h" + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/tensor_testutil.h" + +namespace tensorflow { +namespace test { +namespace function { + +typedef FunctionDefHelper FDH; + +GraphDef GDef(gtl::ArraySlice<NodeDef> nodes, + gtl::ArraySlice<FunctionDef> funcs) { + GraphDef g; + for (auto n : nodes) { + *(g.add_node()) = n; + } + auto lib = g.mutable_library(); + for (auto f : funcs) { + *(lib->add_function()) = f; + } + return g; +} + +// Helper to construct a NodeDef. 
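+// A usage sketch (the node names, ops, and attrs below are illustrative
+// values, not part of the helper's contract):
+//
+//   NodeDef n = NDef("x", "Const", {}, {{"dtype", DT_FLOAT}});
+//   NodeDef m = NDef("y", "Identity", {"x"}, {{"T", DT_FLOAT}}, "/job:worker");
+//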
+NodeDef NDef(const string& name, const string& op, + gtl::ArraySlice<string> inputs, + gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs, + const string& device) { + NodeDef n; + n.set_name(name); + n.set_op(op); + for (auto in : inputs) n.add_input(in); + n.set_device(device); + for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto}); + return n; +} + +FunctionDef NonZero() { + return FDH::Define( + // Name + "NonZero", + // Args + {"x:T"}, + // Return values + {"y:T"}, + // Attr def + {"T:{float, double, int32, int64, string}"}, + // Nodes + { + {{"y"}, "Identity", {"x"}, {{"T", "$T"}}}, + }); +} + +FunctionDef XTimesTwo() { + const Tensor kTwo = test::AsScalar<int64>(2); + return FDH::Define( + // Name + "XTimesTwo", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, + {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, + {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}}, + }); +} + +FunctionDef XTimesFour() { + return FDH::Define( + // Name + "XTimesFour", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}}, + {{"y"}, "XTimesTwo", {"x2"}, {{"T", "$T"}}}, + }); +} + +FunctionDef XTimes16() { + return FDH::Define( + // Name + "XTimes16", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}}, + {{"y"}, "XTimesFour", {"x4"}, {{"T", "$T"}}}, + }); +} + +FunctionDef WXPlusB() { + return FDH::Define( + // Name + "WXPlusB", + // Args + {"w: T", "x: T", "b: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double}"}, + // Nodes + {{{"mm"}, + "MatMul", + {"w", "x"}, + {{"T", "$T"}, + {"transpose_a", false}, + {"transpose_b", false}, + {"_kernel", "eigen"}}}, + {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}}); +} + +FunctionDef Swap() { + return FDH::Define( + // Name + "Swap", + // Args + {"i0: T", "i1: T"}, + // Return values + {"o0: T", "o1: T"}, + // Attr def + {"T: {float, double}"}, + // Nodes + {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}}, + {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}}); +} + +} // end namespace function +} // end namespace test +} // end namespace tensorflow diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h new file mode 100644 index 0000000000..ed0446ea85 --- /dev/null +++ b/tensorflow/core/framework/function_testlib.h @@ -0,0 +1,53 @@ +#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_ +#define TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_ + +#include <string> + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { +namespace test { +namespace function { + +// Helper to construct a NodeDef. +NodeDef NDef( + const string& name, const string& op, gtl::ArraySlice<string> inputs, + gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>> + attrs = {}, + const string& device = ""); + +// Helper to construct a GraphDef proto. 
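+// For instance, a small graph that calls the XTimesTwo() function declared
+// below could be assembled like this (a sketch; the Placeholder node and
+// wiring are illustrative):
+//
+//   GraphDef g = GDef(
+//       {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),
+//        NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}})},
+//       {XTimesTwo()});
+//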
+GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
+              gtl::ArraySlice<FunctionDef> funcs = {});
+
+// For testing convenience, we provide a few simple functions that can
+// be easily executed and tested.
+
+// x:T -> x * 2.
+FunctionDef XTimesTwo();
+
+// x:T -> (x * 2) * 2.
+FunctionDef XTimesFour();
+
+// x:T -> ((x * 2) * 2) * 2.
+FunctionDef XTimes16();
+
+// w:T, x:T, b:T -> MatMul(w, x) + b
+FunctionDef WXPlusB();
+
+// x:T -> x:T, T is a type which is automatically converted to a bool.
+FunctionDef NonZero();
+
+// x:T, y:T -> y:T, x:T
+FunctionDef Swap();
+
+}  // end namespace function
+}  // end namespace test
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
diff --git a/tensorflow/core/framework/graph.proto b/tensorflow/core/framework/graph.proto new file mode 100644 index 0000000000..a9bc07e88c --- /dev/null +++ b/tensorflow/core/framework/graph.proto @@ -0,0 +1,103 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/attr_value.proto";
+import "tensorflow/core/framework/function.proto";
+
+// Represents the graph of operations
+// TODO(sanjay): Also want to put the following somewhere:
+// * random_seed
+// * replicas: Do we stamp them out in python itself?
+// * where to load parameters
+// * optimizer info? does it go with the parameter layers/ops?
+message GraphDef {
+  repeated NodeDef node = 1;
+
+  // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
+  //
+  // "library" provides user-defined functions.
+  //
+  // Naming:
+  //   * Function names (library.function.name) are in a flat namespace.
+  //     NOTE: We may need to change it to be hierarchical to support
+  //     different orgs. E.g.,
+  //     { "/google/nn", { ... }},
+  //     { "/google/vision", { ... }}
+  //     { "/org_foo/module_bar", {...}}
+  //     map<string, FunctionDefLib> named_lib;
+  //   * If node[i].op is the name of one function in "library",
+  //     node[i] is treated as a function call. Otherwise, node[i].op
+  //     must be a primitive operation supported by the runtime.
+  //
+  //
+  // Function call semantics:
+  //
+  //   * The callee may start execution as soon as some of its inputs
+  //     are ready. The caller may want to use the Tuple() mechanism to
+  //     ensure all inputs are ready at the same time.
+  //
+  //   * The consumer of return values may start executing as soon as
+  //     the return values the consumer depends on are ready. The
+  //     consumer may want to use the Tuple() mechanism to ensure the
+  //     consumer does not start until all return values of the callee
+  //     function are ready.
+  FunctionDefLibrary library = 2;
+};
+
+message NodeDef {
+  // The name given to this operator. Used for naming inputs,
+  // logging, visualization, etc. Unique within a single GraphDef.
+  // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
+  string name = 1;
+
+  // The operation name. There may be custom parameters in attrs.
+  // Op names starting with an underscore are reserved for internal use.
+  string op = 2;
+
+  // Each input is "node:src_output" with "node" being a string name and
+  // "src_output" indicating which output tensor to use from "node". If
+  // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
+  // may optionally be followed by control inputs that have the format
+  // "^node".
+  repeated string input = 3;
+
+  // A (possibly partial) specification for the device on which this
+  // node should be placed. 
+ // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= COLOCATED_NODE | PARTIAL_SPEC + // + // COLOCATED_NODE ::= "@" NODE_NAME // See NodeDef.name above. + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "@other/node" (colocate with "other/node") + // * "/job:worker/replica:0/task:1/gpu:3" (full specification) + // * "/job:worker/gpu:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // TODO(josh11b): Add some examples here showing best practices. + map<string, AttrValue> attr = 5; +}; diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc new file mode 100644 index 0000000000..1e0d280126 --- /dev/null +++ b/tensorflow/core/framework/graph_def_util.cc @@ -0,0 +1,25 @@ +#include "tensorflow/core/framework/graph_def_util.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +string SummarizeGraphDef(const GraphDef& graph_def) { + string ret; + for (const NodeDef& node : graph_def.node()) { + strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n"); + } + return ret; +} + +Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) { + for (const NodeDef& node : graph_def.node()) { + TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node)); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h new file mode 100644 index 0000000000..7a2ec9c7a7 --- /dev/null +++ b/tensorflow/core/framework/graph_def_util.h @@ -0,0 +1,29 @@ +#ifndef TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Produce a human-readable version of a GraphDef that is more concise +// than a text-format proto. +string SummarizeGraphDef(const GraphDef& graph_def); + +// Validates the syntax of a GraphDef provided externally. +// +// The following is an EBNF-style syntax for GraphDef objects. Note that +// Node objects are actually specified as tensorflow::NodeDef protocol buffers, +// which contain many other fields that are not (currently) validated. +// +// Graph = Node * +// Node = NodeName, Inputs +// Inputs = ( DataInput * ), ( ControlInput * ) +// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ? 
+// ControlInput = "^", NodeName +// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] * +Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_ diff --git a/tensorflow/core/framework/kernel_def.proto b/tensorflow/core/framework/kernel_def.proto new file mode 100644 index 0000000000..db7856a156 --- /dev/null +++ b/tensorflow/core/framework/kernel_def.proto @@ -0,0 +1,33 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/attr_value.proto"; + +message KernelDef { + // Must match the name of an Op. + string op = 1; + + // Type of device this kernel runs on. + string device_type = 2; + + message AttrConstraint { + // Name of an attr from the Op. + string name = 1; + + // A list of values that this kernel supports for this attr. + // Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops. + AttrValue allowed_values = 2; + } + repeated AttrConstraint constraint = 3; + + // Names of the Op's input_/output_args that reside in host memory + // instead of device memory. + repeated string host_memory_arg = 4; + + // This allows experimental kernels to be registered for an op that + // won't be used unless the user specifies a "_kernel" attr with + // value matching this. + string label = 5; +} diff --git a/tensorflow/core/framework/kernel_def_builder.cc b/tensorflow/core/framework/kernel_def_builder.cc new file mode 100644 index 0000000000..8fba883a16 --- /dev/null +++ b/tensorflow/core/framework/kernel_def_builder.cc @@ -0,0 +1,47 @@ +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { + +KernelDefBuilder::KernelDefBuilder(const char* op_name) { + kernel_def_ = new KernelDef; + kernel_def_->set_op(op_name); +} + +KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) { + kernel_def_->set_device_type(device_type); + return *this; +} + +KernelDefBuilder& KernelDefBuilder::TypeConstraint( + const char* attr_name, gtl::ArraySlice<DataType> allowed) { + auto* constraint = kernel_def_->add_constraint(); + constraint->set_name(attr_name); + auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); + for (DataType dt : allowed) { + allowed_values->add_type(dt); + } + return *this; +} + +KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name, + DataType allowed) { + auto* constraint = kernel_def_->add_constraint(); + constraint->set_name(attr_name); + constraint->mutable_allowed_values()->mutable_list()->add_type(allowed); + return *this; +} + +KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) { + kernel_def_->add_host_memory_arg(arg_name); + return *this; +} + +KernelDefBuilder& KernelDefBuilder::Label(const char* label) { + CHECK_EQ(kernel_def_->label(), "") + << "Trying to set a kernel's label a second time: '" << label + << "' in: " << kernel_def_->ShortDebugString(); + kernel_def_->set_label(label); + return *this; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/kernel_def_builder.h b/tensorflow/core/framework/kernel_def_builder.h new file mode 100644 index 0000000000..0c14d1e006 --- /dev/null +++ b/tensorflow/core/framework/kernel_def_builder.h @@ -0,0 +1,77 @@ +#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_ +#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_ + +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" 
+#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Builder class passed to the REGISTER_KERNEL_BUILDER() macro. +class KernelDefBuilder { + public: + // Starts with just the name field set. + // Caller MUST call Build() and take ownership of the result. + explicit KernelDefBuilder(const char* op_name); + + ~KernelDefBuilder() { + DCHECK(kernel_def_ == nullptr) << "Did not call Build()"; + } + + // Required: specify the type of device this kernel supports. + // Returns *this. + KernelDefBuilder& Device(const char* device_type); + // KernelDefBuilder& Device(DeviceType device_type); + + // Specify that this kernel supports a limited set of values for a + // particular type or list(type) attr (a further restriction than + // what the Op allows). + // Returns *this. + KernelDefBuilder& TypeConstraint(const char* attr_name, + gtl::ArraySlice<DataType> allowed); + + // Like TypeConstraint but supports just a single type. + KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed); + + // Like TypeConstraint, but (a) gets the type from a template parameter + // and (b) only supports a constraint to a single type. + template <class T> + KernelDefBuilder& TypeConstraint(const char* attr_name); + // TODO(josh11b): Support other types of attr constraints as needed. + + // Specify that this kernel requires/provides an input/output arg + // in host memory (instead of the default, device memory). + // Returns *this. + KernelDefBuilder& HostMemory(const char* arg_name); + + // Specify that this kernel requires a particular value for the + // "_kernel" attr. May only be specified once. Returns *this. + KernelDefBuilder& Label(const char* label); + + // Returns a pointer to a KernelDef with fields set based on the + // above calls to this instance. + // Caller takes ownership of the result. 
+ const KernelDef* Build() { + KernelDef* r = kernel_def_; + kernel_def_ = nullptr; + return r; + } + + private: + KernelDef* kernel_def_; + + TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder); +}; + +// IMPLEMENTATION + +template <class T> +inline KernelDefBuilder& KernelDefBuilder::TypeConstraint( + const char* attr_name) { + return this->TypeConstraint(attr_name, DataTypeToEnum<T>::v()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_ diff --git a/tensorflow/core/framework/kernel_def_builder_test.cc b/tensorflow/core/framework/kernel_def_builder_test.cc new file mode 100644 index 0000000000..eba7144b59 --- /dev/null +++ b/tensorflow/core/framework/kernel_def_builder_test.cc @@ -0,0 +1,76 @@ +#include "tensorflow/core/framework/kernel_def_builder.h" + +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +TEST(KernelDefBuilderTest, Basic) { + const KernelDef* def = KernelDefBuilder("A").Device(DEVICE_CPU).Build(); + KernelDef expected; + protobuf::TextFormat::ParseFromString("op: 'A' device_type: 'CPU'", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; +} + +TEST(KernelDefBuilderTest, TypeConstraint) { + const KernelDef* def = KernelDefBuilder("B") + .Device(DEVICE_GPU) + .TypeConstraint<float>("T") + .Build(); + KernelDef expected; + protobuf::TextFormat::ParseFromString(R"proto( + op: 'B' device_type: 'GPU' + constraint { name: 'T' allowed_values { list { type: DT_FLOAT } } } )proto", + &expected); + + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; + + def = KernelDefBuilder("C") + .Device(DEVICE_GPU) + .TypeConstraint<int32>("U") + .TypeConstraint<bool>("V") + .Build(); + + protobuf::TextFormat::ParseFromString(R"proto( + op: 'C' device_type: 'GPU' + constraint { name: 'U' allowed_values { list { type: DT_INT32 } } } + constraint { name: 'V' allowed_values { list { type: DT_BOOL } } } )proto", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; + + def = KernelDefBuilder("D") + .Device(DEVICE_CPU) + .TypeConstraint("W", {DT_DOUBLE, DT_STRING}) + .Build(); + protobuf::TextFormat::ParseFromString(R"proto( + op: 'D' device_type: 'CPU' + constraint { name: 'W' + allowed_values { list { type: [DT_DOUBLE, DT_STRING] } } } )proto", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; +} + +TEST(KernelDefBuilderTest, HostMemory) { + const KernelDef* def = KernelDefBuilder("E") + .Device(DEVICE_GPU) + .HostMemory("in") + .HostMemory("out") + .Build(); + KernelDef expected; + protobuf::TextFormat::ParseFromString( + "op: 'E' device_type: 'GPU' " + "host_memory_arg: ['in', 'out']", + &expected); + EXPECT_EQ(def->DebugString(), expected.DebugString()); + delete def; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc new file mode 100644 index 0000000000..c660b84aa0 --- /dev/null +++ b/tensorflow/core/framework/lookup_interface.cc @@ -0,0 +1,45 @@ +#include "tensorflow/core/framework/lookup_interface.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace lookup { + +Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key, + const Tensor& value) { + if (key.dtype() != key_dtype()) { + return errors::InvalidArgument("Key must be type ", key_dtype(), + " but got ", key.dtype()); + } + if 
(value.dtype() != value_dtype()) {
+    return errors::InvalidArgument("Value must be type ", value_dtype(),
+                                   " but got ", value.dtype());
+  }
+  if (key.NumElements() != value.NumElements()) {
+    return errors::InvalidArgument("Number of elements of key(",
+                                   key.NumElements(), ") and value(",
+                                   value.NumElements(), ") are different.");
+  }
+  if (!key.shape().IsSameSize(value.shape())) {
+    return errors::InvalidArgument("key and value have different shapes.");
+  }
+  return Status::OK();
+}
+
+Status LookupInterface::CheckFindArguments(const Tensor& key,
+                                           const Tensor& value,
+                                           const Tensor& default_value) {
+  TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(key, value));
+
+  if (default_value.dtype() != value_dtype()) {
+    return errors::InvalidArgument("Default value must be type ", value_dtype(),
+                                   " but got ", default_value.dtype());
+  }
+  if (!TensorShapeUtils::IsScalar(default_value.shape())) {
+    return errors::InvalidArgument("Default values must be scalar.");
+  }
+  return Status::OK();
+}
+
+}  // namespace lookup
+}  // namespace tensorflow
diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h new file mode 100644 index 0000000000..d4036d2019 --- /dev/null +++ b/tensorflow/core/framework/lookup_interface.h @@ -0,0 +1,65 @@
+#ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+#define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace lookup {
+
+// Lookup interface for batch lookups used by table lookup ops.
+class LookupInterface : public ResourceBase {
+ public:
+  // Performs batch lookups: for every element in the key tensor, Find writes
+  // the corresponding value into the values tensor.
+  // If an element is not present in the table, the given default value is used.
+
+  // For tables that require initialization, Find is available once the table
+  // is marked as initialized.
+
+  // Returns the following statuses:
+  // - OK: when the find finishes successfully.
+  // - FailedPrecondition: if the table is not initialized.
+  // - InvalidArgument: if any of the preconditions on the lookup key or value
+  //   fails.
+  // - In addition, other implementations may provide another non-OK status
+  //   specific to their failure modes.
+  virtual Status Find(const Tensor& keys, Tensor* values,
+                      const Tensor& default_value) = 0;
+
+  // Returns the number of elements in the table.
+  virtual size_t size() const = 0;
+
+  // Returns the data type of the key.
+  virtual DataType key_dtype() const = 0;
+
+  // Returns the data type of the value.
+  virtual DataType value_dtype() const = 0;
+
+  string DebugString() override { return "A lookup table"; }
+
+ protected:
+  virtual ~LookupInterface() = default;
+
+  // Check format of the key and value tensors.
+  // Returns OK if all the following requirements are satisfied, otherwise it
+  // returns InvalidArgument:
+  // - DataType of the key tensor equals the table key_dtype
+  // - DataType of the value tensor equals the table value_dtype
+  // - key and value have the same size and shape
+  Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values);
+
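+  // For illustration, a concrete table would typically call these helpers
+  // before touching its backing store. A minimal sketch (MyTable is a
+  // hypothetical subclass, not part of this interface):
+  //
+  //   Status MyTable::Find(const Tensor& keys, Tensor* values,
+  //                        const Tensor& default_value) {
+  //     TF_RETURN_IF_ERROR(CheckFindArguments(keys, *values, default_value));
+  //     // ... per-key lookup that fills *values ...
+  //     return Status::OK();
+  //   }
+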
+  // Check the arguments of a find operation. Returns OK if all the following
+  // requirements are satisfied, otherwise it returns InvalidArgument:
+  // - All requirements of CheckKeyAndValueTensors
+  // - default_value type equals the table value_dtype
+  // - default_value is scalar
+  Status CheckFindArguments(const Tensor& keys, const Tensor& values,
+                            const Tensor& default_value);
+};
+
+}  // namespace lookup
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc new file mode 100644 index 0000000000..12757f153a --- /dev/null +++ b/tensorflow/core/framework/node_def_builder.cc @@ -0,0 +1,194 @@
+#include "tensorflow/core/framework/node_def_builder.h"
+
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name,
+                               const OpRegistryInterface* op_registry) {
+  node_def_.set_name(name);
+  Status status;
+  op_def_ = op_registry->LookUp(op_name, &status);
+  if (op_def_ == nullptr) {
+    errors_.push_back(status.error_message());
+    inputs_specified_ = 0;
+  } else {
+    Initialize();
+  }
+}
+
+NodeDefBuilder::NodeDefBuilder(const string& name, const OpDef* op_def)
+    : op_def_(op_def) {
+  node_def_.set_name(name);
+  Initialize();
+}
+
+void NodeDefBuilder::Initialize() {
+  inputs_specified_ = 0;
+  node_def_.set_op(op_def_->name());
+}
+
+const OpDef::ArgDef* NodeDefBuilder::NextArgDef() {
+  if (!NextArgAvailable()) return nullptr;
+  return &op_def_->input_arg(inputs_specified_++);
+}
+
+bool NodeDefBuilder::NextArgAvailable() {
+  if (op_def_ == nullptr) {
+    return false;
+  } else if (inputs_specified_ >= op_def_->input_arg_size()) {
+    errors_.push_back(strings::StrCat("More Input() calls than the ",
+                                      op_def_->input_arg_size(),
+                                      " input_args"));
+    return false;
+  }
+  return true;
+}
+
+NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
+  if (NextArgAvailable()) {
+    Status status =
+        fake_input(*op_def_, inputs_specified_, node_def_, this);
+    if (!status.ok()) errors_.push_back(status.error_message());
+  }
+  return *this;
+}
+
+void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
+                                 const string& src_node, int src_index,
+                                 DataType dt) {
+  AddInput(src_node, src_index);
+
+  if (!input_arg->number_attr().empty() ||
+      !input_arg->type_list_attr().empty()) {
+    errors_.push_back(strings::StrCat("Single tensor passed to '",
+                                      input_arg->name(), "', expected list"));
+    return;
+  }
+
+  if (input_arg->type() != DT_INVALID) {
+    const DataType expected = MaybeAddRef(input_arg, input_arg->type());
+    VerifyInputType(input_arg, expected, dt);
+  } else {
+    VerifyInputRef(input_arg, dt);
+    Attr(input_arg->type_attr(), BaseType(dt));
+  }
+}
+
+void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg,
+                               gtl::ArraySlice<NodeOut> src_list) {
+  for (const auto& node_out : src_list) {
+    AddInput(node_out.node, node_out.index);
+  }
+
+  if (!input_arg->number_attr().empty()) {
+    Attr(input_arg->number_attr(), static_cast<int64>(src_list.size()));
+    if (input_arg->type() != DT_INVALID) {
+      const DataType expected = MaybeAddRef(input_arg, input_arg->type());
+      for (const auto& node_out : src_list) {
+        VerifyInputType(input_arg, expected, node_out.data_type);
+      }
+    } else if (!src_list.empty()) {
+      const DataType base = BaseType(src_list[0].data_type);
+      Attr(input_arg->type_attr(), base);
+      const DataType expected 
= MaybeAddRef(input_arg, base); + for (const auto& node_out : src_list) { + VerifyInputType(input_arg, expected, node_out.data_type); + } + } + } else if (!input_arg->type_list_attr().empty()) { + DataTypeVector type_vec; + type_vec.reserve(src_list.size()); + for (const auto& node_out : src_list) { + const DataType dt = node_out.data_type; + VerifyInputRef(input_arg, dt); + type_vec.push_back(BaseType(dt)); + } + Attr(input_arg->type_list_attr(), type_vec); + } else { + errors_.push_back(strings::StrCat("List provided to input '", + input_arg->name(), + "' when single Tensor expected")); + } +} + +void NodeDefBuilder::AddInput(const string& src_node, int src_index) { + if (src_node.empty()) { + errors_.push_back("Empty input node name"); + } else if (src_node[0] == '^') { + errors_.push_back( + strings::StrCat("Non-control input starting with ^: ", src_node)); + } else if (src_index > 0) { + node_def_.add_input(strings::StrCat(src_node, ":", src_index)); + } else { + node_def_.add_input(src_node); + } +} + +void NodeDefBuilder::VerifyInputType(const OpDef::ArgDef* input_arg, + DataType expected, DataType dt) { + if (!TypesCompatible(expected, dt)) { + errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ", + DataTypeString(dt), " expected ", + DataTypeString(expected))); + } +} + +void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg, + DataType dt) { + if (input_arg->is_ref() && !IsRefType(dt)) { + errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ", + DataTypeString(dt), + " expected ref type")); + } +} + +Status NodeDefBuilder::Finalize(NodeDef* node_def) const { + const std::vector<string>* errors_ptr = &errors_; + std::vector<string> errors_storage; + if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) { + // Since this is a const method, to add an error, we have to make + // a copy of the existing errors. + errors_storage = errors_; + errors_storage.push_back( + strings::StrCat(inputs_specified_, " inputs specified of ", + op_def_->input_arg_size(), " inputs in Op")); + errors_ptr = &errors_storage; + } + + if (!errors_ptr->empty()) { + if (errors_ptr->size() == 1) { + if (op_def_ == nullptr) { + return errors::InvalidArgument((*errors_ptr)[0], + " while building NodeDef '", + node_def_.name(), "'"); + } + return errors::InvalidArgument( + (*errors_ptr)[0], " while building NodeDef '", node_def_.name(), + "' using ", SummarizeOpDef(*op_def_)); + } else { + return errors::InvalidArgument( + errors_ptr->size(), " errors while building NodeDef '", + node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n", + str_util::Join(*errors_ptr, "\n")); + } + } else { + NodeDef node_def_backup; + if (node_def == nullptr) node_def = &node_def_backup; + *node_def = node_def_; + + // Add control inputs after the regular inputs. + for (const auto& control_input : control_inputs_) { + node_def->add_input(strings::StrCat("^", control_input)); + } + + // Add default values for unspecified attrs. 
+ AddDefaultsToNodeDef(*op_def_, node_def); + + return Status::OK(); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h new file mode 100644 index 0000000000..706f072608 --- /dev/null +++ b/tensorflow/core/framework/node_def_builder.h @@ -0,0 +1,176 @@ +#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_ +#define TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_ + +#include <functional> +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +class NodeDefBuilder; +typedef std::function<Status(const OpDef&, int, const NodeDef&, + NodeDefBuilder*)> FakeInputFunctor; + +// This is a helper for creating a NodeDef. Automatically sets attrs +// that can be inferred from the inputs, and uses default values +// (where they exist) for unspecified attrs. Example usage: +// +// NodeDef node_def; +// Status status = NodeDefBuilder(node_name, op_name) +// .Input(...) +// .Attr(...) +// .Finalize(&node_def); +// if (!status.ok()) return status; +// // Use node_def here. +class NodeDefBuilder { + public: + // To specify an output to be consumed by one of the Input() methods below. + struct NodeOut { + NodeOut(const string& n, int i, DataType dt) + : node(n), index(i), data_type(dt) {} + NodeOut() {} // uninitialized, call Reset() before use. + void Reset(const string& n, int i, DataType dt) { + node = n; + index = i; + data_type = dt; + } + string node; + int index; + DataType data_type; + }; + + // Specify the name and the Op (either via an OpDef or the name of + // the Op plus a registry) for the NodeDef. Other fields are + // specified by calling the methods below. + // REQUIRES: The OpDef must satisfy ValidateOpDef(). + NodeDefBuilder(const string& name, const string& op_name, + const OpRegistryInterface* op_registry = OpRegistry::Global()); + // REQUIRES: in addition, *op_def must outlive *this. + NodeDefBuilder(const string& name, const OpDef* op_def); + + // You must call one Input() function per input_arg in the Op, + // *and in the same order as the input_args appear in the OpDef.* + + // For inputs that take a single tensor. + NodeDefBuilder& Input(const string& src_node, int src_index, DataType dt) { + const OpDef::ArgDef* arg = NextArgDef(); + if (arg != nullptr) SingleInput(arg, src_node, src_index, dt); + return *this; + } + NodeDefBuilder& Input(const NodeOut& src) { + Input(src.node, src.index, src.data_type); + return *this; + } + + // For inputs that take a list of tensors. + NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list) { + const OpDef::ArgDef* arg = NextArgDef(); + if (arg != nullptr) ListInput(arg, src_list); + return *this; + } + + // To create inputs in tests, see fake_input.h. + NodeDefBuilder& Input(FakeInputFunctor fake_input); + + // Specify that this node must only run after src_node. + NodeDefBuilder& ControlInput(const string& src_node) { + control_inputs_.push_back(src_node); + return *this; + } + + // Constrains what devices this node may be scheduled on. 
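+  // For example, using one of the partial specifications documented in
+  // graph.proto (illustrative only):
+  //
+  //   builder.Device("/job:worker/gpu:3");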
+  NodeDefBuilder& Device(const string& device_spec) {
+    node_def_.set_device(device_spec);
+    return *this;
+  }
+
+  // Sets the attr, if not already set. If already set with a different
+  // value, an error will be returned from Finalize().
+  template <class T>
+  NodeDefBuilder& Attr(const string& attr_name, T&& value);
+  // Note: overload needed to allow {...} expressions for value.
+  template <class T>
+  NodeDefBuilder& Attr(const string& attr_name,
+                       std::initializer_list<T> value) {
+    Attr<std::initializer_list<T>>(attr_name, std::move(value));
+    return *this;
+  }
+
+  // Finish building the NodeDef, returning any errors or setting
+  // *node_def if none.
+  // WARNING: Not all problems are detected! The resulting NodeDef may
+  // not be valid! Call ValidateNodeDef() from node_def_utils to be sure.
+  Status Finalize(NodeDef* node_def) const;
+
+  // Accessor for the OpDef set in the constructor.
+  const OpDef& op_def() const { return *op_def_; }
+
+ private:
+  // Called in the constructors.
+  void Initialize();
+
+  // Get the current ArgDef and advance to the next one. Returns nullptr
+  // if no more inputs are available.
+  const OpDef::ArgDef* NextArgDef();
+
+  // Returns true if there is still an input_arg available in *op_def_,
+  // otherwise adds to errors_ and returns false.
+  bool NextArgAvailable();
+
+  // These do the main work of the Input() methods.
+  void SingleInput(const OpDef::ArgDef* input_arg, const string& src_node,
+                   int src_index, DataType dt);
+  void ListInput(const OpDef::ArgDef* input_arg,
+                 gtl::ArraySlice<NodeOut> src_list);
+
+  // Add "src_node:src_index" to the list of inputs in the node_def_.
+  void AddInput(const string& src_node, int src_index);
+
+  // Generate an error if dt is not compatible with the expected input type.
+  void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected,
+                       DataType dt);
+
+  // If input_arg->is_ref() is true, generate an error if dt is not a ref.
+  void VerifyInputRef(const OpDef::ArgDef* input_arg, DataType dt);
+
+  // Makes dt a ref type if that is what the input_arg specifies.
+  DataType MaybeAddRef(const OpDef::ArgDef* input_arg, DataType dt) {
+    return input_arg->is_ref() ? MakeRefType(dt) : dt;
+  }
+
+  const OpDef* op_def_;
+  NodeDef node_def_;
+  int inputs_specified_;
+  std::vector<string> control_inputs_;
+  std::vector<string> errors_;
+};
+
+// IMPLEMENTATION -------------------------------------------------------------
+
+template <class T>
+NodeDefBuilder& NodeDefBuilder::Attr(const string& attr_name, T&& value) {
+  const AttrValue* found = AttrSlice(node_def_).Find(attr_name);
+  if (found == nullptr) {
+    AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
+  } else {
+    AttrValue attr_value;
+    SetAttrValue(std::forward<T>(value), &attr_value);
+    if (!AreAttrValuesEqual(*found, attr_value)) {
+      errors_.push_back(strings::StrCat(
+          "Inconsistent values for attr '", attr_name, "' ",
+          SummarizeAttrValue(*found), " vs. 
", SummarizeAttrValue(attr_value))); + } + } + return *this; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_ diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc new file mode 100644 index 0000000000..6fd4a8d1ed --- /dev/null +++ b/tensorflow/core/framework/node_def_builder_test.cc @@ -0,0 +1,1036 @@ +#include "tensorflow/core/framework/node_def_builder.h" + +#include <memory> +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +class NodeDefBuilderTest : public ::testing::Test { + protected: + // Specify an OpDef via an OpDefBuilder. + void Op(const OpDefBuilder& op_def_builder) { + EXPECT_OK(op_def_builder.Finalize(&op_def_)); + } + + // Resets builder_ with a new NodeDefBuilder using the Op from the last call + // to Op() above. + NodeDefBuilder& Builder() { + EXPECT_FALSE(op_def_.name().empty()) << "Must call Op() before Builder()"; + builder_.reset(new NodeDefBuilder("n", &op_def_)); + return *builder_; + } + + // Calls Finalize() and verifies it returns success and the result matches + // expectations. + void ExpectSuccess(const NodeDefBuilder& builder, + DataTypeSlice expected_in_types, + DataTypeSlice expected_out_types, StringPiece proto) { + NodeDef node_def; + Status status = builder.Finalize(&node_def); + EXPECT_OK(status); + if (!status.ok()) return; + NodeDef expected; + protobuf::TextFormat::ParseFromString(strings::StrCat("name: 'n' ", proto), + &expected); + EXPECT_EQ(node_def.DebugString(), expected.DebugString()); + + DataTypeVector in_types, out_types; + status = + InOutTypesForNode(node_def, builder.op_def(), &in_types, &out_types); + EXPECT_OK(status); + if (!status.ok()) return; + EXPECT_EQ(DataTypeSliceString(expected_in_types), + DataTypeVectorString(in_types)); + EXPECT_EQ(DataTypeSliceString(expected_out_types), + DataTypeVectorString(out_types)); + + status = ValidateNodeDef(node_def, op_def_); + EXPECT_OK(status); + } + + // Calls Finalize() and verifies it returns an error. + // Each message must appear as a substring of the error. + void ExpectFailures(const NodeDefBuilder& builder, + const std::vector<string>& messages) { + NodeDef node_def; + Status status = builder.Finalize(&node_def); + EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def); + if (status.ok()) return; + for (const string& message : messages) { + EXPECT_TRUE(StringPiece(status.error_message()).contains(message)) + << status << ", " << message; + } + } + + // Calls Finalize() and verifies it returns an error. + // Message must appear as a substring of the error. + void ExpectFailure(const NodeDefBuilder& builder, const string& message) { + ExpectFailures(builder, {message}); + } + + // Like ExpectFailure(), except that the error can come from + // ValidateNodeDef(). 
+ void ExpectInvalid(const NodeDefBuilder& builder, const string& message) { + NodeDef node_def; + Status status = builder.Finalize(&node_def); + if (status.ok()) { + status = ValidateNodeDef(node_def, op_def_); + } + EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def); + if (status.ok()) return; + EXPECT_TRUE(StringPiece(status.error_message()).contains(message)) + << "Actual error: " << status.error_message() + << "\nDoes not contain: " << message; + } + + OpDef op_def_; + std::unique_ptr<NodeDefBuilder> builder_; +}; + +TEST_F(NodeDefBuilderTest, Simple) { + Op(OpDefBuilder("Simple").Input("a: int32").Output("out: float")); + + ExpectSuccess(Builder().Input("x", 0, DT_INT32), {DT_INT32}, {DT_FLOAT}, + R"proto( op: "Simple" input: "x" )proto"); + + // Port != 0 + ExpectSuccess(Builder().Input("y", 2, DT_INT32), {DT_INT32}, {DT_FLOAT}, + R"proto( op: "Simple" input: "y:2" )proto"); + + // FakeInput + ExpectSuccess(Builder().Input(FakeInput()), {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: "a" )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_FLOAT}, + R"proto( op: "Simple" input: "a" )proto"); + + // Ref input + ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32}, + {DT_FLOAT}, R"proto( op: "Simple" input: "a" )proto"); + + // ControlInput + ExpectSuccess( + Builder().ControlInput("x").Input(FakeInput()).ControlInput("y"), + {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: ["a", "^x", "^y"] )proto"); + + // Device + ExpectSuccess(Builder().Input(FakeInput()).Device("ddd"), {DT_INT32}, + {DT_FLOAT}, R"proto( + op: "Simple" input: "a" device: "ddd" )proto"); + + // Extra input + ExpectFailure(Builder().Input("x", 0, DT_INT32).Input("y", 0, DT_INT32), + "More Input() calls than the 1 input_args while building " + "NodeDef 'n' using Op<name=Simple; signature=a:int32 -> " + "out:float>"); + + // Missing input + ExpectFailure(Builder(), "0 inputs specified of 1 inputs in Op while"); + + { // Finalize() twice. + NodeDefBuilder& builder = Builder(); + builder.Input(FakeInput()).Finalize(nullptr); // First call to Finalize() + // ExpectSuccess() also calls Finalize(). + ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: "a" )proto"); + } + + { // Input() after Finalize() + NodeDefBuilder& builder = Builder(); + // Calling Finalize() before enough inputs -> error. + ExpectFailure(builder, "0 inputs specified of 1 inputs in Op while"); + builder.Input(FakeInput()); + // Calling Finalize() with enough inputs -> success + ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: "a" )proto"); + // Calling Finalize() with too many inputs -> error. 
+    builder.Input(FakeInput(DT_INT32));
+    ExpectFailure(builder, "More Input() calls than the 1 input_args while");
+  }
+
+  // Wrong input type
+  ExpectFailure(Builder().Input("x", 0, DT_FLOAT),
+                "Input 'a' passed float expected int32 ");
+
+  ExpectFailure(Builder().Input("x", 0, DT_FLOAT_REF),
+                "Input 'a' passed float_ref expected int32 ");
+
+  // List input
+  ExpectFailure(Builder().Input(FakeInput(3, DT_FLOAT)),
+                "List provided to input 'a' when single Tensor expected while");
+
+  ExpectFailure(Builder().Input(FakeInput(3)),
+                "List provided to input 'a' when single Tensor expected while");
+
+  // Bad ControlInput
+  ExpectInvalid(Builder().Input(FakeInput()).ControlInput("z:2"),
+                "Control input '^z:2' must not have ':' in NodeDef:");
+
+  // Bad input name
+  ExpectFailure(Builder().Input("", 0, DT_INT32),
+                "Empty input node name while");
+
+  ExpectFailure(Builder().Input("^x", 0, DT_INT32),
+                "Non-control input starting with ^: ^x while");
+}
+
+TEST_F(NodeDefBuilderTest, OpDoesNotExist) {
+  NodeDefBuilder builder("n", "Op Does Not Exist");
+  builder.Input(FakeInput())
+      .Input(FakeInput(12))
+      .ControlInput("y")
+      .Attr("foo", 12)
+      .Device("device");
+  ExpectFailure(
+      builder,
+      "Op type not registered 'Op Does Not Exist' while building NodeDef 'n'");
+}
+
+TEST_F(NodeDefBuilderTest, Polymorphic) {
+  Op(OpDefBuilder("Polymorphic")
+         .Input("v: T")
+         .Output("out: T")
+         .Attr("T: type"));
+
+  ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_INT32},
+                R"proto(
+      op: "Polymorphic" input: "a"
+      attr { key: "T" value { type: DT_INT32 } } )proto");
+
+  ExpectSuccess(Builder().Input(FakeInput(DT_FLOAT)), {DT_FLOAT}, {DT_FLOAT},
+                R"proto(
+      op: "Polymorphic" input: "a"
+      attr { key: "T" value { type: DT_FLOAT } } )proto");
+
+  // Redundant Attr()
+  ExpectSuccess(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_BOOL),
+                {DT_BOOL}, {DT_BOOL}, R"proto(
+      op: "Polymorphic" input: "a"
+      attr { key: "T" value { type: DT_BOOL } } )proto");
+
+  // Conflicting Attr()
+  ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_STRING),
+                "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while");
+
+  ExpectFailure(Builder().Attr("T", DT_STRING).Input(FakeInput(DT_BOOL)),
+                "Inconsistent values for attr 'T' DT_STRING vs. DT_BOOL while");
+
+  ExpectFailure(Builder().Attr("T", 12).Input(FakeInput(DT_BOOL)),
+                "Inconsistent values for attr 'T' 12 vs. DT_BOOL while");
+}
+
+TEST_F(NodeDefBuilderTest, PolymorphicOut) {
+  Op(OpDefBuilder("PolymorphicOut").Output("out: T").Attr("T: type"));
+
+  ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32}, R"proto(
+      op: "PolymorphicOut"
+      attr { key: "T" value { type: DT_INT32 } } )proto");
+
+  ExpectSuccess(Builder().Attr("T", DT_FLOAT), {}, {DT_FLOAT}, R"proto(
+      op: "PolymorphicOut"
+      attr { key: "T" value { type: DT_FLOAT } } )proto");
+
+  // Redundant attr
+  ExpectSuccess(Builder().Attr("T", DT_FLOAT).Attr("T", DT_FLOAT), {},
+                {DT_FLOAT}, R"proto(
+      op: "PolymorphicOut"
+      attr { key: "T" value { type: DT_FLOAT } } )proto");
+
+  // Conflicting attr
+  ExpectFailure(Builder().Attr("T", DT_BOOL).Attr("T", DT_FLOAT),
+                "Inconsistent values for attr 'T' DT_BOOL vs. 
DT_FLOAT while"); + + // Missing attr + ExpectInvalid(Builder(), "NodeDef missing attr 'T' from"); + + // Attr has the wrong type + ExpectInvalid(Builder().Attr("T", {DT_INT32, DT_BOOL}), + "AttrValue had value with type list(type) when type expected"); + + ExpectInvalid(Builder().Attr("T", 12), + "AttrValue had value with type int when type expected"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicDefaultOut) { + Op(OpDefBuilder("PolymorphicDefaultOut") + .Output("out: T") + .Attr("T: type = DT_STRING")); + + ExpectSuccess(Builder(), {}, {DT_STRING}, R"proto( + op: "PolymorphicDefaultOut" + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectSuccess(Builder().Attr("T", DT_BOOL), {}, {DT_BOOL}, R"proto( + op: "PolymorphicDefaultOut" + attr { key: "T" value { type: DT_BOOL } } )proto"); +} + +TEST_F(NodeDefBuilderTest, Binary) { + Op(OpDefBuilder("Binary").Input("a: T").Input("b: T").Output("out: T").Attr( + "T: type")); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32)).Input(FakeInput(DT_INT32)), + {DT_INT32, DT_INT32}, {DT_INT32}, R"proto( + op: "Binary" input: "a" input: "b" + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DT_STRING)).Input(FakeInput()), + {DT_STRING, DT_STRING}, {DT_STRING}, R"proto( + op: "Binary" input: "a" input: "b" + attr { key: "T" value { type: DT_STRING } } )proto"); + + // Type mismatch + ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Input(FakeInput(DT_STRING)), + "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, Restrict) { + Op(OpDefBuilder("Restrict") + .Input("a: T") + .Output("out: T") + .Attr("T: {string, bool}")); + ExpectSuccess(Builder().Input(FakeInput(DT_STRING)), {DT_STRING}, {DT_STRING}, + R"proto( + op: "Restrict" input: "a" + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput(DT_INT32)), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, TypeList) { + Op(OpDefBuilder("TypeList").Input("a: T").Attr("T: list(type)")); + + ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_INT32})), + {DT_STRING, DT_INT32}, {}, R"proto( + op: "TypeList" input: ["a", "a:1"] + attr { key: "T" value { list { type: [DT_STRING, DT_INT32] } } } + )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_BOOL)), + {DT_BOOL, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "TypeList" input: ["a", "a:1", "a:2"] + attr { key: "T" value { list { type: [DT_BOOL, DT_BOOL, DT_BOOL] } } } + )proto"); + + ExpectInvalid(Builder().Input(FakeInput(0)), + "Length for attr 'T' of 0 must be at least minimum 1"); + + ExpectInvalid(Builder().Input(FakeInput({})), + "Length for attr 'T' of 0 must be at least minimum 1"); + + ExpectInvalid(Builder().Input(FakeInput(DT_BOOL)), + "Single tensor passed to 'a', expected list while"); + + ExpectFailures(Builder().Input(FakeInput()), + {"2 errors while building NodeDef", + "Could not infer list of types for input 'a': " + "No attr named 'T' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); +} + +TEST_F(NodeDefBuilderTest, TypeListNoMin) { + Op(OpDefBuilder("TypeListNoMin").Input("a: T").Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(0)), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DataTypeVector())), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + 
ExpectSuccess(Builder().Input(FakeInput({})), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput({DT_BOOL})), {DT_BOOL}, {}, R"proto( + op: "TypeListNoMin" input: "a" + attr { key: "T" value { list { type: DT_BOOL } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, TypeListTwice) { + Op(OpDefBuilder("TypeListTwice") + .Input("a: T") + .Input("b: T") + .Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder() + .Input(FakeInput({DT_INT32, DT_BOOL})) + .Input(FakeInput({DT_INT32, DT_BOOL})), + {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto( + op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"] + attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto"); + + ExpectSuccess( + Builder().Input(FakeInput({DT_INT32, DT_BOOL})).Input(FakeInput()), + {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto( + op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"] + attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(0)), {}, {}, + R"proto( + op: "TypeListTwice" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {}, + R"proto( + op: "TypeListTwice" attr { key: "T" value { list { } } } )proto"); + + ExpectFailure(Builder() + .Input(FakeInput({DT_INT32, DT_BOOL})) + .Input(FakeInput({DT_INT32, DT_STRING})), + "Inconsistent values for attr 'T' [DT_INT32, DT_BOOL] vs. " + "[DT_INT32, DT_STRING] while"); +} + +TEST_F(NodeDefBuilderTest, OutTypeList) { + Op(OpDefBuilder("OutTypeList").Output("out: T").Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder().Attr("T", {DT_FLOAT}), {}, {DT_FLOAT}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { type: DT_FLOAT } } } )proto"); + + ExpectSuccess(Builder().Attr("T", {DT_STRING, DT_BOOL}), {}, + {DT_STRING, DT_BOOL}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto"); + + ExpectSuccess(Builder().Attr("T", DataTypeVector()), {}, {}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { } } } )proto"); + + ExpectInvalid(Builder().Attr("T", DT_FLOAT), + "AttrValue had value with type type when list(type) expected"); +} + +TEST_F(NodeDefBuilderTest, TypeListRestrict) { + Op(OpDefBuilder("TypeListRestrict") + .Input("a: T") + .Attr("T: list({string, bool}) >= 0")); + + ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_BOOL})), + {DT_STRING, DT_BOOL}, {}, R"proto( + op: "TypeListRestrict" input: ["a", "a:1"] + attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput({DT_STRING, DT_INT32})), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, OutTypeListRestrict) { + Op(OpDefBuilder("OutTypeListRestrict") + .Output("out: t") + .Attr("t: list({string, bool}) >= 0")); + + ExpectSuccess(Builder().Attr("t", {DT_BOOL, DT_STRING}), {}, + {DT_BOOL, DT_STRING}, R"proto( + op: "OutTypeListRestrict" + attr { key: "t" value { list { type: [DT_BOOL, DT_STRING] } } } )proto"); + + ExpectInvalid(Builder().Attr("t", {DT_STRING, DT_INT32}), + "Value for attr 't' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, Attr) { + Op(OpDefBuilder("Attr").Attr("a: int")); + + ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto( + op: "Attr" attr { key: "a" value { i: 12 } } )proto"); + + // Attr has 
wrong type + ExpectInvalid(Builder().Attr("a", "bad"), + "AttrValue had value with type string when int expected"); + + ExpectInvalid(Builder().Attr("a", {12}), + "AttrValue had value with type list(int) when int expected"); + + // Missing attr + ExpectInvalid(Builder(), "NodeDef missing attr 'a' from Op<"); + + // Wrong attr + ExpectInvalid(Builder().Attr("b", 12), + "NodeDef mentions attr 'b' not in Op<"); + + // Extra attr + ExpectInvalid(Builder().Attr("a", 12).Attr("extra", 12), + "NodeDef mentions attr 'extra' not in Op<"); +} + +TEST_F(NodeDefBuilderTest, AttrFloat) { + Op(OpDefBuilder("AttrFloat").Attr("a: float")); + + ExpectSuccess(Builder().Attr("a", 1.2f /* float */), {}, {}, R"proto( + op: "AttrFloat" attr { key: "a" value { f: 1.2 } } + )proto"); + + ExpectSuccess(Builder().Attr("a", 1.2 /* double */), {}, {}, R"proto( + op: "AttrFloat" attr { key: "a" value { f: 1.2 } } + )proto"); + + // Won't automatically cast int to float + ExpectInvalid(Builder().Attr("a", 12), + "AttrValue had value with type int when float expected"); +} + +TEST_F(NodeDefBuilderTest, AttrBoolList) { + Op(OpDefBuilder("AttrBoolList").Attr("a: list(bool)")); + + ExpectSuccess(Builder().Attr("a", {true, false, true}), {}, {}, R"proto( + op: "AttrBoolList" + attr { key: "a" value { list { b: [true, false, true] } } } + )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector<bool>()), {}, {}, R"proto( + op: "AttrBoolList" attr { key: "a" value { list { } } } + )proto"); + + // Won't cast int -> bool. + ExpectInvalid(Builder().Attr("a", {0}), + "AttrValue had value with type list(int) when list(bool) " + "expected"); +} + +TEST_F(NodeDefBuilderTest, AttrMin) { + Op(OpDefBuilder("AttrMin").Attr("a: int >= 5")); + + ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto( + op: "AttrMin" attr { key: "a" value { i: 12 } } )proto"); + + ExpectInvalid(Builder().Attr("a", 2), + "Value for attr 'a' of 2 must be at least minimum 5"); +} + +TEST_F(NodeDefBuilderTest, AttrListMin) { + Op(OpDefBuilder("AttrListMin").Attr("a: list(int) >= 2")); + + ExpectSuccess(Builder().Attr("a", {1, 2}), {}, {}, R"proto( + op: "AttrListMin" + attr { key: "a" value { list { i: [1, 2] } } } )proto"); + + ExpectInvalid(Builder().Attr("a", {17}), + "Length for attr 'a' of 1 must be at least minimum 2"); +} + +TEST_F(NodeDefBuilderTest, AttrEnum) { + Op(OpDefBuilder("AttrEnum").Attr("a: {'apples', 'oranges'}")); + + ExpectSuccess(Builder().Attr("a", "oranges"), {}, {}, R"proto( + op: "AttrEnum" + attr { key: "a" value { s: "oranges" } } )proto"); + + ExpectInvalid( + Builder().Attr("a", "invalid"), + "Value for attr 'a' of \"invalid\" is not in the list of allowed values: " + "\"apples\", \"oranges\""); +} + +TEST_F(NodeDefBuilderTest, AttrEnumList) { + Op(OpDefBuilder("AttrEnumList").Attr("a: list({'apples', 'oranges'})")); + + ExpectSuccess(Builder().Attr("a", {"oranges", "apples"}), {}, {}, R"proto( + op: "AttrEnumList" + attr { key: "a" value { list { s: ["oranges", "apples"] } } } )proto"); + + ExpectInvalid( + Builder().Attr("a", {"apples", "invalid", "oranges"}), + "Value for attr 'a' of \"invalid\" is not in the list of allowed values: " + "\"apples\", \"oranges\""); +} + +TEST_F(NodeDefBuilderTest, AttrShape) { + Op(OpDefBuilder("AttrShape").Attr("a: shape")); + + ExpectSuccess(Builder().Attr("a", TensorShape({5})), {}, {}, R"proto( + op: "AttrShape" + attr { key: "a" value { shape { dim { size: 5 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape({4, 3, 2})), {}, {}, R"proto( + op: "AttrShape" + attr { key: 
"a" value { shape { + dim { size: 4 } dim { size: 3 } dim { size: 2 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape({3, 2})), {}, {}, + R"proto( + op: "AttrShape" + attr { key: "a" value { shape { + dim { size: 3 } dim { size: 2 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape()), {}, {}, R"proto( + op: "AttrShape" + attr { key: "a" value { shape { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrDefault) { + Op(OpDefBuilder("AttrDefault").Attr("a: string = 'banana'")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrDefault" + attr { key: "a" value { s: "banana" } } )proto"); + + ExpectSuccess(Builder().Attr("a", "kiwi"), {}, {}, R"proto( + op: "AttrDefault" + attr { key: "a" value { s: "kiwi" } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrManyDefault) { + Op(OpDefBuilder("AttrManyDefault") + .Attr("a: string = 'banana'") + .Attr("b: string = 'kiwi'")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrManyDefault" + attr { key: "a" value { s: "banana" } } + attr { key: "b" value { s: "kiwi" } } )proto"); + + Op(OpDefBuilder("AttrManyDefaultWithMandatory") + .Attr("a: string = 'banana'") + .Attr("b: string = 'kiwi'") + .Attr("c: string")); + + ExpectSuccess(Builder().Attr("c", "strawberry"), {}, {}, R"proto( + op: "AttrManyDefaultWithMandatory" + attr { key: "c" value { s: "strawberry" } } + attr { key: "a" value { s: "banana" } } + attr { key: "b" value { s: "kiwi" } } )proto"); + + Op(OpDefBuilder("AttrManyDefaultAndInferred") + .Input("input: T") + .Attr("T: {float, double}") + .Attr("a: string") + .Attr("b: list(string) >= 1") + .Attr("c: bool = true") + .Attr("d: float = 0.3") + .Attr("e: string") + .Attr("f: float = 0.25")); + + ExpectSuccess(Builder() + .Input(FakeInput(DT_FLOAT)) + .Attr("a", "foo") + .Attr("e", "foo") + .Attr("b", std::vector<string>({"bar", "baz"})) + .Attr("f", 1.0f), + {DT_FLOAT}, {}, R"proto( + op: "AttrManyDefaultAndInferred" + input: "a" + attr { key: "T" value { type: DT_FLOAT } } + attr { key: "a" value { s: "foo" } } + attr { key: "e" value { s: "foo" } } + attr { key: "b" value { list { s: "bar" s: "baz" } } } + attr { key: "f" value { f: 1.0 } } + attr { key: "c" value { b: true } } + attr { key: "d" value { f: 0.3 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrListDefault) { + Op(OpDefBuilder("AttrListDefault").Attr("a: list(int) = [5, 15]")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { i: [5, 15] } } } )proto"); + + ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { i: 3 } } } )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrEmptyListDefault) { + Op(OpDefBuilder("AttrEmptyListDefault").Attr("a: list(int) = []")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { } } } )proto"); + + ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { i: 3 } } } )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NIntsIn) { + Op(OpDefBuilder("NIntsIn").Input("a: N*int32").Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2)), {DT_INT32, DT_INT32}, {}, + 
R"proto( + op: "NIntsIn" input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(5, DT_INT32)), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "NIntsIn" + input: ["a", "a:1", "a:2", "a:3", "a:4"] + attr { key: "N" value { i: 5 } } )proto"); + + ExpectFailures(Builder().Input(FakeInput(2, DT_STRING)), + {"2 errors while building NodeDef", + "Input 'a' passed string expected int32"}); + + ExpectInvalid(Builder().Input(FakeInput(1)), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectFailures( + Builder().Input(FakeInput(DT_INT32)), + {"2 errors while building NodeDef", + "Could not infer length of input 'a': No attr named 'N' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); + + ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}), + "Input 'a' passed string expected int32 while"); + + ExpectFailures( + Builder().Input(FakeInput()), + {"2 errors while building NodeDef", + "Could not infer length of input 'a': No attr named 'N' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicIn) { + Op(OpDefBuilder("NPolymorphicIn") + .Input("a: N*T") + .Attr("T: type") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)), {DT_INT32, DT_INT32}, + {}, R"proto( + op: "NPolymorphicIn" input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)), + {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto( + op: "NPolymorphicIn" + input: ["a", "a:1", "a:2"] + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectFailures( + Builder().Input(FakeInput(2)), + {"2 errors while building NodeDef", + "Could not infer type for input 'a': No attr named 'T' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); + + ExpectFailure(Builder().Input(FakeInput({DT_INT32, DT_STRING})), + "Input 'a' passed string expected int32 while"); + + ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}), + "Input 'a' passed string expected int32 while"); + + ExpectInvalid(Builder().Input(FakeInput(1, DT_INT32)), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectFailure(Builder().Input("in", 0, DT_INT32), + "Single tensor passed to 'a', expected list while"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicRestrictIn) { + Op(OpDefBuilder("NPolymorphicRestrictIn") + .Input("a: N*T") + .Attr("T: {string, bool}") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_BOOL)), {DT_BOOL, DT_BOOL}, {}, + R"proto( + op: "NPolymorphicRestrictIn" input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)), + {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto( + op: "NPolymorphicRestrictIn" + input: ["a", "a:1", "a:2"] + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput(2, DT_INT32)), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, NInTwice) { + Op(OpDefBuilder("NInTwice") + .Input("a: N*int32") + .Input("b: N*string") + .Attr("N: int >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(2)).Input(FakeInput(2)), + {DT_INT32, DT_INT32, DT_STRING, DT_STRING}, {}, R"proto( + op: "NInTwice" + 
input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {}, + R"proto( + op: "NInTwice" attr { key: "N" value { i: 0 } } )proto"); + + ExpectFailure(Builder().Input(FakeInput(3)).Input(FakeInput(1)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); +} + +TEST_F(NodeDefBuilderTest, NInPolymorphicTwice) { + Op(OpDefBuilder("NInPolymorphicTwice") + .Input("a: N*T") + .Input("b: N*T") + .Attr("T: type") + .Attr("N: int >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput()), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "NInPolymorphicTwice" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_INT32)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); + + ExpectFailure(Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, NInTwoTypeVariables) { + Op(OpDefBuilder("NInTwoTypeVariables") + .Input("a: N*S") + .Input("b: N*T") + .Attr("S: type") + .Attr("T: type") + .Attr("N: int >= 0")); + + ExpectSuccess( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_BOOL)), + {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "NInTwoTypeVariables" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "S" value { type: DT_INT32 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_BOOL)), + {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "NInTwoTypeVariables" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "S" value { type: DT_INT32 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_STRING)), + "Inconsistent values for attr 'N' 3 vs. 
1 while"); +} + +TEST_F(NodeDefBuilderTest, InPolymorphicTwice) { + Op(OpDefBuilder("InPolymorphicTwice") + .Input("a: N*T") + .Input("b: M*T") + .Attr("T: type") + .Attr("N: int >= 0") + .Attr("M: int >= 0")); + + ExpectSuccess( + Builder().Input(FakeInput(1, DT_INT32)).Input(FakeInput(3, DT_INT32)), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "InPolymorphicTwice" + input: ["a", "b", "b:1", "b:2"] + attr { key: "N" value { i: 1 } } + attr { key: "T" value { type: DT_INT32 } } + attr { key: "M" value { i: 3 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(1, DT_BOOL)).Input(FakeInput(0)), + {DT_BOOL}, {}, R"proto( + op: "InPolymorphicTwice" input: "a" + attr { key: "N" value { i: 1 } } + attr { key: "T" value { type: DT_BOOL } } + attr { key: "M" value { i: 0 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(1, DT_BOOL)), + {DT_BOOL}, {}, R"proto( + op: "InPolymorphicTwice" input: "b" + attr { key: "N" value { i: 0 } } + attr { key: "M" value { i: 1 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, NIntsOut) { + Op(OpDefBuilder("NIntsOut").Output("a: N*int32").Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto( + op: "NIntsOut" + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3), {}, {DT_INT32, DT_INT32, DT_INT32}, + R"proto( + op: "NIntsOut" + attr { key: "N" value { i: 3 } } )proto"); + + ExpectInvalid(Builder().Attr("N", 1), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectInvalid(Builder().Attr("N", {3}), + "AttrValue had value with type list(int) when int expected"); + + ExpectInvalid(Builder(), "NodeDef missing attr 'N' from"); +} + +TEST_F(NodeDefBuilderTest, NIntsOutDefault) { + Op(OpDefBuilder("NIntsOutDefault") + .Output("a: N*int32") + .Attr("N: int >= 2 = 3")); + + ExpectSuccess(Builder(), {}, {DT_INT32, DT_INT32, DT_INT32}, R"proto( + op: "NIntsOutDefault" + attr { key: "N" value { i: 3 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto( + op: "NIntsOutDefault" + attr { key: "N" value { i: 2 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicOut) { + Op(OpDefBuilder("NPolymorphicOut") + .Output("a: N*T") + .Attr("T: type") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("T", DT_INT32).Attr("N", 2), {}, + {DT_INT32, DT_INT32}, R"proto( + op: "NPolymorphicOut" + attr { key: "T" value { type: DT_INT32 } } + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_STRING), {}, + {DT_STRING, DT_STRING, DT_STRING}, R"proto( + op: "NPolymorphicOut" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Attr("N", 1).Attr("T", DT_STRING), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectInvalid(Builder().Attr("N", 3).Attr("T", {DT_STRING}), + "AttrValue had value with type list(type) when type expected"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicOutDefault) { + Op(OpDefBuilder("NPolymorphicOutDefault") + .Output("a: N*T") + .Attr("T: type = DT_BOOL") + .Attr("N: int >= 2 = 2")); + + ExpectSuccess(Builder(), {}, {DT_BOOL, DT_BOOL}, R"proto( + op: "NPolymorphicOutDefault" + attr { key: "T" value { type: DT_BOOL } } + attr { key: "N" value { i: 2 } } 
)proto"); + + ExpectSuccess(Builder().Attr("N", 3), {}, {DT_BOOL, DT_BOOL, DT_BOOL}, + R"proto( + op: "NPolymorphicOutDefault" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32, DT_INT32}, + R"proto( + op: "NPolymorphicOutDefault" + attr { key: "T" value { type: DT_INT32 } } + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_INT32), {}, + {DT_INT32, DT_INT32, DT_INT32}, R"proto( + op: "NPolymorphicOutDefault" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicRestrictOut) { + Op(OpDefBuilder("NPolymorphicRestrictOut") + .Output("a: N*T") + .Attr("T: {string, bool}") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_BOOL), {}, + {DT_BOOL, DT_BOOL, DT_BOOL}, R"proto( + op: "NPolymorphicRestrictOut" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectInvalid(Builder().Attr("N", 3).Attr("T", DT_INT32), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, RefIn) { + Op(OpDefBuilder("RefIn").Input("a: Ref(int32)")); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32_REF}, {}, + R"proto( + op: "RefIn" input: "a" )proto"); + + ExpectFailure(Builder().Input(FakeInput(DT_BOOL_REF)), + "Input 'a' passed bool_ref expected int32_ref while"); + + ExpectFailure(Builder().Input(FakeInput(DT_INT32)), + "Input 'a' passed int32 expected int32_ref while"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicRefIn) { + Op(OpDefBuilder("PolymorphicRefIn").Input("a: Ref(T)").Attr("T: type")); + + ExpectSuccess(Builder().Input(FakeInput(DT_BOOL_REF)), {DT_BOOL_REF}, {}, + R"proto( + op: "PolymorphicRefIn" input: "a" + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure(Builder().Input(FakeInput(DT_BOOL)), + "Input 'a' passed bool expected ref type while"); +} + +TEST_F(NodeDefBuilderTest, RefOut) { + Op(OpDefBuilder("RefOut").Output("a: Ref(string)")); + + ExpectSuccess(Builder(), {}, {DT_STRING_REF}, R"proto( + op: "RefOut" )proto"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicRefOut) { + Op(OpDefBuilder("PolymorphicRefOut").Output("a: Ref(t)").Attr("t: type")); + + ExpectSuccess(Builder().Attr("t", DT_BOOL), {}, {DT_BOOL_REF}, R"proto( + op: "PolymorphicRefOut" + attr { key: "t" value { type: DT_BOOL } } )proto"); +} + +TEST_F(NodeDefBuilderTest, SpecifyDevice) { + Op(OpDefBuilder("SpecifyDevice")); + + ExpectSuccess(Builder().Device("ADevice"), {}, {}, R"proto( + op: "SpecifyDevice" device: "ADevice" )proto"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc new file mode 100644 index 0000000000..aefd416187 --- /dev/null +++ b/tensorflow/core/framework/node_def_util.cc @@ -0,0 +1,414 @@ +#include "tensorflow/core/framework/node_def_util.h" + +#include <algorithm> +#include <unordered_map> + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +string 
SummarizeNodeDef(const NodeDef& node_def) { + string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "["); + + // We sort the attrs so the output is deterministic. + std::vector<string> attr_names; + attr_names.reserve(node_def.attr().size()); + for (const auto& attr : node_def.attr()) { + attr_names.push_back(attr.first); + } + std::sort(attr_names.begin(), attr_names.end()); + bool first = true; + for (const string& attr_name : attr_names) { + if (!first) strings::StrAppend(&ret, ", "); + first = false; + auto iter = node_def.attr().find(attr_name); + strings::StrAppend(&ret, attr_name, "=", SummarizeAttrValue(iter->second)); + } + + // Consider the device to be a final attr with name "_device". + if (!node_def.device().empty()) { + if (!first) strings::StrAppend(&ret, ", "); + first = false; + strings::StrAppend(&ret, "_device=\"", node_def.device(), "\""); + } + strings::StrAppend(&ret, "]("); + + // Output inputs, including control inputs, verbatim. + first = true; + for (const string& input : node_def.input()) { + if (!first) strings::StrAppend(&ret, ", "); + first = false; + strings::StrAppend(&ret, input); + } + strings::StrAppend(&ret, ")"); + return ret; +} + +const AttrValue* AttrSlice::Find(const string& attr_name) const { + auto iter = attrs_->find(attr_name); + if (iter == attrs_->end()) return nullptr; + return &iter->second; +} + +Status AttrSlice::Find(const string& attr_name, + const AttrValue** attr_value) const { + *attr_value = Find(attr_name); + if (*attr_value != nullptr) { + return Status::OK(); + } + Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:"); + if (ndef_) { + s = AttachDef(s, *ndef_); + } + return s; +} + +// The ... is to allow the caller to inject some value validation code. Use +// just ; if no additional validation code is needed. +#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) 
\ + Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \ + TYPE* value) { \ + const AttrValue* attr_value; \ + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \ + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE)); \ + const auto& v = attr_value->FIELD(); \ + __VA_ARGS__; \ + *value = CAST; \ + return Status::OK(); \ + } \ + Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \ + std::vector<TYPE>* value) { \ + const AttrValue* attr_value; \ + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \ + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \ + for (const auto& v : attr_value->list().FIELD()) { \ + __VA_ARGS__; \ + value->APPEND_OP(CAST); \ + } \ + return Status::OK(); \ + } + +DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;) +DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;) +DEFINE_GET_ATTR(int32, i, "int", emplace_back, static_cast<int32>(v), + if (static_cast<int64>(static_cast<int32>(v)) != v) { + return errors::InvalidArgument("Attr ", attr_name, + " has value ", v, + " out of range for an int32"); + }) +DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;) +// std::vector<bool> specialization does not have emplace_back until +// c++14, so we have to use push_back (see +// http://en.cppreference.com/w/cpp/container/vector/emplace_back) +DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;) +DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v), + ;) +DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;) +DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v), ;) +DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t; + if (!t.FromProto(v)) { + return errors::InvalidArgument( + "Attr ", attr_name, " has value ", v.ShortDebugString(), + " that can't be converted to a Tensor"); + }) + +#undef DEFINE_GET_ATTR + +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + DataTypeVector* value) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)")); + for (const auto& v : attr_value->list().type()) { + value->push_back(static_cast<DataType>(v)); + } + return Status::OK(); +} + +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + const TensorProto** value) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor")); + *value = &attr_value->tensor(); + return Status::OK(); +} + +Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, + const NameAttrList** value) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); + TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func")); + *value = &attr_value->func(); + return Status::OK(); +} + +namespace { // Helper for InOutTypesForNode(). + +Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def, + DataTypeVector* sig) { + const int original_size = sig->size(); + if (!arg_def.number_attr().empty()) { + // Same type repeated "repeats" times. 
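+    // For example (illustrative values, not tied to any particular op):
+    // an input declared as "a: N*T" with N == 3 and T == DT_FLOAT adds
+    // DT_FLOAT to the signature three times.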
+ int32 repeats = -1; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.number_attr(), &repeats)); + if (repeats < 0) { + return errors::InvalidArgument("Value for number_attr() ", repeats, + " < 0"); + } + + if (!arg_def.type_attr().empty()) { + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.type_attr(), &dtype)); + for (int i = 0; i < repeats; ++i) { + sig->push_back(dtype); + } + } else if (arg_def.type() != DT_INVALID) { + for (int i = 0; i < repeats; ++i) { + sig->push_back(arg_def.type()); + } + } else { + return errors::InvalidArgument("Missing type or type_attr field in ", + arg_def.ShortDebugString()); + } + } else if (!arg_def.type_attr().empty()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + AttrSlice(node_def).Find(arg_def.type_attr(), &attr_value)); + sig->push_back(attr_value->type()); + } else if (!arg_def.type_list_attr().empty()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value)); + for (int dtype : attr_value->list().type()) { + sig->push_back(static_cast<DataType>(dtype)); + } + } else if (arg_def.type() != DT_INVALID) { + sig->push_back(arg_def.type()); + } else { + return errors::InvalidArgument("No type fields in ", + arg_def.ShortDebugString()); + } + if (arg_def.is_ref()) { + // For all types that were added by this function call, make them refs. + for (size_t i = original_size; i < sig->size(); ++i) { + (*sig)[i] = MakeRefType((*sig)[i]); + } + } + return Status::OK(); +} + +} // namespace + +Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs, DataTypeVector* outputs) { + for (const auto& arg : op_def.input_arg()) { + TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); + } + for (const auto& arg : op_def.output_arg()) { + TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs)); + } + return Status::OK(); +} + +Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { + if (node_def.op() != op_def.name()) { + return errors::InvalidArgument("NodeDef op '", node_def.op(), + "' does not match ", SummarizeOpDef(op_def), + "; NodeDef: ", SummarizeNodeDef(node_def)); + } + + bool seen_control = false; + size_t num_inputs = 0; + // TODO(josh11b): Unify the input field validation. + for (const string& input : node_def.input()) { + if (StringPiece(input).starts_with("^")) { + seen_control = true; + if (input.find(':') != string::npos) { + return errors::InvalidArgument("Control input '", input, + "' must not have ':' in NodeDef: ", + SummarizeNodeDef(node_def)); + } + } else if (seen_control) { + return errors::InvalidArgument("Non-control input '", input, + "' after control input in NodeDef: ", + SummarizeNodeDef(node_def)); + } else { + ++num_inputs; + } + } + + std::unordered_map<string, const OpDef::AttrDef*> op_attrs; + for (const auto& attr : op_def.attr()) { + if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) { + return errors::InvalidArgument("OpDef has duplicate attr name '", + attr.name(), "': ", + SummarizeOpDef(op_def)); + } + } + for (const auto& attr : node_def.attr()) { + // Allow internal optional attributes with names starting with "_". 
+ if (StringPiece(attr.first).starts_with("_")) { + continue; + } + auto iter = op_attrs.find(attr.first); + if (iter == op_attrs.end()) { + return errors::InvalidArgument("NodeDef mentions attr '", attr.first, + "' not in ", SummarizeOpDef(op_def), + "; NodeDef: ", SummarizeNodeDef(node_def)); + } + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ValidateAttrValue(attr.second, *iter->second), "; NodeDef: ", + SummarizeNodeDef(node_def), "; ", SummarizeOpDef(op_def)); + // Keep track of which attr names have (not) been found in the NodeDef. + op_attrs.erase(iter); + } + + // Were all attrs in the OpDef found in the NodeDef? + if (!op_attrs.empty()) { + string attrs; + for (const auto& attr_pair : op_attrs) { + if (!attrs.empty()) strings::StrAppend(&attrs, "', '"); + strings::StrAppend(&attrs, attr_pair.first); + } + return errors::InvalidArgument("NodeDef missing attr", + op_attrs.size() == 1 ? " '" : "s '", attrs, + "' from ", SummarizeOpDef(op_def), + "; NodeDef: ", SummarizeNodeDef(node_def)); + } + + // Validate the number of inputs. + DataTypeVector inputs, outputs; + TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs)); + + if (num_inputs != inputs.size()) { + return errors::InvalidArgument( + "NodeDef expected inputs '", DataTypeVectorString(inputs), + "' do not match ", num_inputs, " inputs specified; ", + SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def)); + } + + return Status::OK(); +} + +namespace { // Helpers for NameRangesForNode() + +Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def, + const OpDef& op_def, int* num) { + if (!arg_def.number_attr().empty()) { + // Same type repeated "num" times. + return GetNodeAttr(node_def, arg_def.number_attr(), num); + } else if (!arg_def.type_list_attr().empty()) { + const AttrValue* attr_value; + TF_RETURN_IF_ERROR( + AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value)); + *num = attr_value->list().type_size(); + } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) { + *num = 1; + } else { + return errors::InvalidArgument("Argument '", arg_def.name(), + "' incorrectly specified in op definition: ", + SummarizeOpDef(op_def)); + } + return Status::OK(); +} + +Status NameRangesHelper(const NodeDef& node_def, + const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, + const OpDef& op_def, NameRangeMap* result) { + int start = 0; + int num; + for (const auto& arg : args) { + TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num)); + (*result)[arg.name()] = std::make_pair(start, start + num); + start += num; + } + return Status::OK(); +} + +} // namespace + +Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs) { + TF_RETURN_IF_ERROR( + NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs)); + return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs); +} + +void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) { + for (const auto& attr_def : op_def.attr()) { + AttrSlice attrs(*node_def); + if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) { + AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def); + } + } +} + +namespace { + +static RE2* valid_op_name_pattern = new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); +static RE2* valid_data_input_pattern = + new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*(\\:(0|([1-9][0-9]*)))?"); +static RE2* valid_control_input_pattern = + new RE2("\\^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); + +} // namespace + +Status 
ValidateOpInput(const string& input_name, bool* is_control_input) {
+  *is_control_input = false;
+  if (RE2::FullMatch(input_name, *valid_data_input_pattern)) {
+    return Status::OK();
+  } else if (RE2::FullMatch(input_name, *valid_control_input_pattern)) {
+    *is_control_input = true;
+    return Status::OK();
+  } else {
+    return errors::InvalidArgument("Illegal op input name '", input_name, "'");
+  }
+}
+
+Status ValidateOpName(const string& op_name) {
+  if (RE2::FullMatch(op_name, *valid_op_name_pattern)) {
+    return Status::OK();
+  } else {
+    return errors::InvalidArgument("Illegal op name '", op_name, "'");
+  }
+}
+
+Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
+  Status s = ValidateOpName(node_def.name());
+  if (!s.ok()) {
+    return AttachDef(s, node_def);
+  }
+  bool in_control_inputs = false;
+  for (const string& input_name : node_def.input()) {
+    bool is_control_input;
+    s = ValidateOpInput(input_name, &is_control_input);
+    if (!s.ok()) {
+      return AttachDef(s, node_def);
+    }
+
+    if (in_control_inputs && !is_control_input) {
+      return AttachDef(errors::InvalidArgument(
+                           "All control inputs must follow all data inputs"),
+                       node_def);
+    }
+    in_control_inputs = is_control_input;
+  }
+  return Status::OK();
+}
+
+Status AttachDef(const Status& status, const NodeDef& node_def) {
+  Status ret = status;
+  errors::AppendToMessage(
+      &ret, strings::StrCat(" [[Node: ", SummarizeNodeDef(node_def), "]]"));
+  return ret;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
new file mode 100644
index 0000000000..fce6fd2433
--- /dev/null
+++ b/tensorflow/core/framework/node_def_util.h
@@ -0,0 +1,157 @@
+#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// Produce a human-readable version of a NodeDef that is more concise
+// than a text-format proto.
+string SummarizeNodeDef(const NodeDef& node_def);
+
+typedef protobuf::Map<string, AttrValue> AttrValueMap;
+
+// Adds an attr with name <name> and value <value> to *node_def.
+// The type of the attr is based on the type of value.
+template <class T>
+void AddNodeAttr(const string& name, T&& value, NodeDef* node_def) {
+  AttrValue attr_value;
+  SetAttrValue(std::forward<T>(value), &attr_value);
+  node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+}
+
+// Version to work around C++'s "perfect" forwarding not being able to
+// forward {...} initialization.
+template <class T>
+void AddNodeAttr(const string& name, std::initializer_list<T> value,
+                 NodeDef* node_def) {
+  AttrValue attr_value;
+  SetAttrValue(value, &attr_value);
+  node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+}
+
+class AttrSlice {
+ public:
+  AttrSlice(const NodeDef& node_def)  // NOLINT(runtime/explicit)
+      : ndef_(&node_def),
+        attrs_(&ndef_->attr()) {}
+
+  explicit AttrSlice(const AttrValueMap* a) : attrs_(a) {}
+
+  // Returns the attr with attr_name if found.  Otherwise, returns
+  // nullptr.
+  const AttrValue* Find(const string& attr_name) const;
+
+  // Returns the attr_value for attr_name if found. Otherwise, returns a
+  // NotFound status.
+  Status Find(const string& attr_name, const AttrValue** attr_value) const;
+
+ private:
+  const NodeDef* ndef_ = nullptr;
+  const AttrValueMap* attrs_;
+};
+
+// Look up the attr with name attr_name and set *value to its value.  If no
+// attr with attr_name is found in node_def, or the attr does not have
+// a matching type, a non-ok status will be returned.
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   string* value);  // type: "string"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   int64* value);  // type: "int"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   int32* value);  // type: "int"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   float* value);  // type: "float"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   bool* value);  // type: "bool"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   DataType* value);  // type: "type"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   TensorShapeProto* value);  // type: "shape"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   TensorShape* value);  // type: "shape"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   Tensor* value);  // type: "tensor"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   std::vector<string>* value);  // type "list(string)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   std::vector<int64>* value);  // type "list(int)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   std::vector<int32>* value);  // type "list(int)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   std::vector<float>* value);  // type "list(float)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   std::vector<bool>* value);  // type "list(bool)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   std::vector<DataType>* value);  // type "list(type)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   DataTypeVector* value);  // type "list(type)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   std::vector<TensorShapeProto>* value);  // type "list(shape)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   std::vector<TensorShape>* value);  // type "list(shape)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   std::vector<Tensor>* value);  // type: "list(tensor)"
+
+// This version avoids copying the TensorProto.
+// REQUIRES: Must not use *value beyond the lifetime of node_def.
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   const TensorProto** value);  // type: "tensor"
+
+// This version avoids copying the NameAttrList.
+// REQUIRES: Must not use *value beyond the lifetime of node_def.
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+                   const NameAttrList** value);  // type: "func"
+
+// Computes the input and output types for a specific node, for
+// attr-style ops.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+                         DataTypeVector* inputs, DataTypeVector* outputs);
+
+// Validates that the NodeDef:
+// * Defines all expected attrs from the OpDef.
+// * All attrs satisfy constraints from the OpDef.
+// * Has a signature matching SignatureForNode().
+// etc.
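+// On failure, the returned InvalidArgument status includes summaries of
+// both the NodeDef and the OpDef.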
+Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
+
+// Computes the mapping from input/output argument name to the
+// corresponding input/output index range.  For example,
+// input "foo" corresponds to input indices
+//   [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
+typedef std::unordered_map<string, std::pair<int, int>> NameRangeMap;
+Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
+                         NameRangeMap* inputs, NameRangeMap* outputs);
+
+// Adds default values to *node_def for unspecified attrs from op_def.
+void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
+
+// Validates the syntax of a NodeDef provided externally.
+//
+// The following is an EBNF-style syntax for NodeDef objects.  Note that
+// Node objects are actually specified as tensorflow::NodeDef protocol buffers,
+// which contain many other fields that are not (currently) validated.
+//
+// Node         = NodeName, Inputs
+// Inputs       = ( DataInput * ), ( ControlInput * )
+// DataInput    = NodeName, ( ":", ( "0" | [1-9], [0-9] * ) ) ?
+// ControlInput = "^", NodeName
+// NodeName     = [A-Za-z0-9.], [A-Za-z0-9_.\-/] *
+Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
+
+// Returns "status" with kernel's NodeDef attached as additional text
+// in the error message.
+Status AttachDef(const Status& status, const NodeDef& node_def);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
new file mode 100644
index 0000000000..71f1760a09
--- /dev/null
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -0,0 +1,442 @@
+#include "tensorflow/core/framework/node_def_util.h"
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+OpDef ToOpDef(const OpDefBuilder& builder) {
+  OpDef op_def;
+  EXPECT_OK(builder.Finalize(&op_def));
+  return op_def;
+}
+
+NodeDef ToNodeDef(const string& text) {
+  NodeDef node_def;
+  EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
+  return node_def;
+}
+
+NodeDef ToNodeDef(const NodeDefBuilder& builder) {
+  NodeDef node_def;
+  EXPECT_OK(builder.Finalize(&node_def));
+  return node_def;
+}
+
+void ExpectSuccess(const NodeDef& good, const OpDef& op_def) {
+  EXPECT_EQ(Status::OK(), ValidateNodeDef(good, op_def))
+      << "NodeDef: " << SummarizeNodeDef(good)
+      << "; OpDef: " << SummarizeOpDef(op_def);
+}
+
+void ExpectFailure(const NodeDef& bad, const OpDef& op_def,
+                   const string& message) {
+  Status status = ValidateNodeDef(bad, op_def);
+
+  EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad)
+                            << "; OpDef: " << SummarizeOpDef(op_def);
+  if (status.ok()) return;
+
+  EXPECT_TRUE(errors::IsInvalidArgument(status))
+      << status << "; NodeDef: " << SummarizeNodeDef(bad)
+      << "; OpDef: " << SummarizeOpDef(op_def);
+
+  LOG(INFO) << "Message: " << status.error_message();
+  EXPECT_TRUE(StringPiece(status.ToString()).contains(message))
+      << "NodeDef: " << SummarizeNodeDef(bad)
+      << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
+      << "\nDoes not contain: " << message;
+}
+
+TEST(NodeDefUtilTest, In) {
+  const OpDef op = ToOpDef(OpDefBuilder("In").Input("i: T").Attr("T: type"));
+  const NodeDef node_def = ToNodeDef(R"proto(
+    name:'n' op:'In' input:'a' attr { key:'T' value { type:DT_FLOAT } }
+    )proto");
+  ExpectSuccess(node_def, op);
+
+  EXPECT_EQ("n = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def));
+
+  // Mismatching Op names.
+  NodeDef bad = node_def;
+  bad.set_op("Wrong");
+  ExpectFailure(bad, op, "NodeDef op 'Wrong' does not match Op<name=In;");
+
+  // Missing attr
+  bad = node_def;
+  bad.clear_attr();
+  ExpectFailure(bad, op, "NodeDef missing attr 'T' from Op<name=In;");
+
+  // Extra attr
+  bad = node_def;
+  AddNodeAttr("EXTRA", 17, &bad);
+  ExpectFailure(bad, op, "NodeDef mentions attr 'EXTRA' not in Op<name=In;");
+
+  // Attr has wrong type
+  bad = node_def;
+  bad.clear_attr();
+  AddNodeAttr("T", 17, &bad);
+  ExpectFailure(
+      bad, op,
+      "AttrValue had value with type int when type expected\n\t for attr "
+      "'T'\n\t; NodeDef: ");
+
+  // Wrong number of inputs
+  bad = node_def;
+  bad.add_input("b");
+  ExpectFailure(
+      bad, op,
+      "NodeDef expected inputs 'float' do not match 2 inputs specified;");
+
+  bad = node_def;
+  bad.clear_input();
+  ExpectFailure(
+      bad, op,
+      "NodeDef expected inputs 'float' do not match 0 inputs specified;");
+
+  // Control inputs must appear after data inputs
+  NodeDef good = node_def;
+  good.add_input("^b");
+  ExpectSuccess(good, op);
+
+  bad = node_def;
+  bad.clear_input();
+  bad.add_input("^b");
+  bad.add_input("a");
+  ExpectFailure(bad, op,
+                "Invalid argument: Non-control input 'a' after control input "
+                "in NodeDef:");
+
+  bad = node_def;
+  bad.add_input("^b:0");
+  ExpectFailure(bad, op, "Control input '^b:0' must not have ':' in NodeDef:");
+}
+
+TEST(NodeDefUtilTest, Out) {
+  const OpDef op =
+      ToOpDef(OpDefBuilder("Out").Output("o: T").Attr("T: numbertype"));
+  const NodeDef node_def = ToNodeDef(R"proto(
+    name:'n' op:'Out' attr { key:'T' value { type:DT_INT32 } }
+    )proto");
+  ExpectSuccess(node_def, op);
+
+  EXPECT_EQ("n = Out[T=DT_INT32]()", SummarizeNodeDef(node_def));
+
+  // Non-number type.
+  NodeDef bad = node_def;
+  bad.clear_attr();
+  AddNodeAttr("T", DT_STRING, &bad);
+  ExpectFailure(bad, op,
+                "Value for attr 'T' of string is not in the list of allowed "
+                "values: float, double, int64, int32, uint8, int16, int8, "
+                "complex64, qint8, quint8, qint32");
+}
+
+TEST(NodeDefUtilTest, Enum) {
+  const OpDef op = ToOpDef(OpDefBuilder("Enum").Attr("e: {'apple','orange'}"));
+  const NodeDef node_def = ToNodeDef(R"proto(
+    name:'n' op:'Enum' attr { key:'e' value { s:'apple' } }
+    )proto");
+  ExpectSuccess(node_def, op);
+
+  EXPECT_EQ("n = Enum[e=\"apple\"]()", SummarizeNodeDef(node_def));
+
+  NodeDef good = node_def;
+  good.clear_attr();
+  AddNodeAttr("e", "orange", &good);
+  ExpectSuccess(good, op);
+
+  // Non-allowed value.
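+  // ("foo" is rejected because only "apple" and "orange" are allowed.)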
+ NodeDef bad = node_def; + bad.clear_attr(); + AddNodeAttr("e", "foo", &bad); + ExpectFailure(bad, op, + "Value for attr 'e' of \"foo\" is not in the list of allowed " + "values: \"apple\", \"orange\""); +} + +TEST(NodeDefUtilTest, SameIn) { + const OpDef op = ToOpDef(OpDefBuilder("SameIn") + .Input("i: N * T") + .Attr("N: int >= 2") + .Attr("T: {float,double}")); + const NodeDef node_def = ToNodeDef(R"proto( + name:'n' op:'SameIn' input:'a' input:'b' + attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_DOUBLE } } + )proto"); + ExpectSuccess(node_def, op); + + EXPECT_EQ("n = SameIn[N=2, T=DT_DOUBLE](a, b)", SummarizeNodeDef(node_def)); + + // Illegal type + NodeDef bad = ToNodeDef(R"proto( + name:'n' op:'SameIn' input:'a' input:'b' + attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_STRING } } + )proto"); + ExpectFailure(bad, op, + "Value for attr 'T' of string is not in the list of allowed " + "values: float, double"); + + // Too few inputs + bad = ToNodeDef(R"proto( + name:'n' op:'SameIn' input:'a' input:'b' + attr { key:'N' value { i:1 } } attr { key:'T' value { type:DT_FLOAT } } + )proto"); + ExpectFailure(bad, op, "Value for attr 'N' of 1 must be at least minimum 2"); +} + +TEST(NodeDefUtilTest, AnyIn) { + const OpDef op = + ToOpDef(OpDefBuilder("AnyIn").Input("i: T").Attr("T: list(type) >= 1")); + + const NodeDef node_def = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectSuccess(node_def, op); + + EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a, b)", + SummarizeNodeDef(node_def)); + + const NodeDef bad = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' attr { key:'T' value { list { } } } + )proto"); + ExpectFailure(bad, op, "Length for attr 'T' of 0 must be at least minimum 1"); + + // With proto3 semantics, an empty value {} is indistinguishable from a value + // with an empty list in it. So we simply expect to get a message complaining + // about empty list for value {}. 
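+  // In other words, `value { }` below is parsed the same as
+  // `value { list { } }`.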
+ const NodeDef bad2 = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' attr { key:'T' value { } } + )proto"); + ExpectFailure(bad2, op, + "Length for attr 'T' of 0 must be at least minimum 1"); +} + +TEST(NodeDefUtilTest, Device) { + const OpDef op_def1 = ToOpDef(OpDefBuilder("None")); + const NodeDef node_def1 = + ToNodeDef(NodeDefBuilder("d", &op_def1).Device("/cpu:17")); + ExpectSuccess(node_def1, op_def1); + EXPECT_EQ("d = None[_device=\"/cpu:17\"]()", SummarizeNodeDef(node_def1)); + + const OpDef op_def2 = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int")); + const NodeDef node_def2 = + ToNodeDef(NodeDefBuilder("d", &op_def2).Attr("v", 7).Device("/cpu:5")); + ExpectSuccess(node_def2, op_def2); + EXPECT_EQ("d = WithAttr[v=7, _device=\"/cpu:5\"]()", + SummarizeNodeDef(node_def2)); +} + +void ExpectValidSyntax(const NodeDef& good) { + EXPECT_EQ(Status::OK(), ValidateExternalNodeDefSyntax(good)) + << "NodeDef: " << SummarizeNodeDef(good); +} + +void ExpectInvalidSyntax(const NodeDef& bad, const string& message) { + Status status = ValidateExternalNodeDefSyntax(bad); + + ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad); + + EXPECT_TRUE(errors::IsInvalidArgument(status)) + << status << "; NodeDef: " << SummarizeNodeDef(bad); + + EXPECT_TRUE(StringPiece(status.ToString()).contains(message)) + << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", " + << message; +} + +TEST(NodeDefUtilTest, ValidSyntax) { + const NodeDef node_def = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectValidSyntax(node_def); + + const NodeDef node_def_explicit_inputs = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a:0' input:'b:123' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectValidSyntax(node_def_explicit_inputs); + + EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)", + SummarizeNodeDef(node_def_explicit_inputs)); + + const NodeDef node_def_control_input = ToNodeDef(R"proto( + name:'n-' op:'AnyIn' input:'a' input:'^b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectValidSyntax(node_def_control_input); + + const NodeDef node_def_invalid_name = ToNodeDef(R"proto( + name:'n:0' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_invalid_name, "Illegal op name 'n:0'"); + + const NodeDef node_def_internal_name = ToNodeDef(R"proto( + name:'_n' op:'AnyIn' input:'a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'"); + + const NodeDef node_def_internal_input_name = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'_a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_internal_input_name, + "Illegal op input name '_a'"); + + const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'a' input:'^b:0' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_invalid_control_input_name, + "Illegal op input name '^b:0'"); + + const NodeDef node_def_data_input_after_control = ToNodeDef(R"proto( + name:'n' op:'AnyIn' input:'^a' input:'b' + attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } } + )proto"); + ExpectInvalidSyntax(node_def_data_input_after_control, + 
"All control inputs must follow all data inputs"); +} + +TEST(NameRangesForNodeTest, Simple) { + const OpDef op_def = ToOpDef(OpDefBuilder("Simple") + .Input("a: float") + .Input("b: int32") + .Output("c: string") + .Output("d: bool")); + NameRangeMap inputs, outputs; + const NodeDef node_def = ToNodeDef( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), outputs); + + EXPECT_EQ("simple = Simple[](a, b)", SummarizeNodeDef(node_def)); + + OpDef bad_op_def = op_def; + bad_op_def.mutable_input_arg(0)->clear_type(); + EXPECT_FALSE(NameRangesForNode(node_def, bad_op_def, &inputs, &outputs).ok()); +} + +TEST(NameRangesForNodeTest, Polymorphic) { + const OpDef op_def = ToOpDef(OpDefBuilder("Polymorphic") + .Input("a: T") + .Input("b: T") + .Output("c: T") + .Attr("T: type")); + NameRangeMap inputs, outputs; + const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("poly", &op_def) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32))); + EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs); + EXPECT_EQ("poly = Polymorphic[T=DT_INT32](a, b)", + SummarizeNodeDef(node_def1)); + + const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("poly", &op_def) + .Input(FakeInput(DT_BOOL)) + .Input(FakeInput(DT_BOOL))); + EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs); + EXPECT_EQ("poly = Polymorphic[T=DT_BOOL](a, b)", SummarizeNodeDef(node_def2)); +} + +TEST(NameRangesForNodeTest, NRepeats) { + const OpDef op_def = ToOpDef(OpDefBuilder("NRepeats") + .Input("a: N * int32") + .Input("b: N * T") + .Output("c: T") + .Output("d: N * string") + .Output("e: M * bool") + .Attr("N: int") + .Attr("M: int") + .Attr("T: type")); + NameRangeMap inputs, outputs; + const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("nr", &op_def) + .Input(FakeInput(4, DT_INT32)) + .Input(FakeInput(4, DT_FLOAT)) + .Attr("M", 3)); + EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}), + outputs); + EXPECT_EQ( + "nr = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, b:2, b:3)", + SummarizeNodeDef(node_def1)); + + const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("nr", &op_def) + .Input(FakeInput(2, DT_INT32)) + .Input(FakeInput(2, DT_DOUBLE)) + .Attr("M", 7)); + EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}), + outputs); + EXPECT_EQ("nr = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)", + SummarizeNodeDef(node_def2)); + + NodeDef bad_node_def = node_def2; + bad_node_def.clear_attr(); + EXPECT_FALSE(NameRangesForNode(bad_node_def, op_def, &inputs, &outputs).ok()); +} + +TEST(NameRangesForNodeTest, TypeList) { + const OpDef op_def = ToOpDef(OpDefBuilder("TypeList") + .Input("a: T1") + .Input("b: T2") + .Output("c: T2") + .Output("d: T3") + .Output("e: T1") + .Attr("T1: list(type)") + .Attr("T2: list(type)") + .Attr("T3: list(type)")); + 
NameRangeMap inputs, outputs; + const NodeDef node_def1 = + ToNodeDef(NodeDefBuilder("tl", &op_def) + .Input(FakeInput({DT_BOOL, DT_FLOAT})) + .Input(FakeInput(4, DT_FLOAT)) + .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING})); + EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}), + outputs); + EXPECT_EQ( + "tl = TypeList[T1=[DT_BOOL, DT_FLOAT]," + " T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT]," + " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)", + SummarizeNodeDef(node_def1)); + + const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("tl", &op_def) + .Input(FakeInput(7, DT_INT32)) + .Input(FakeInput({DT_DOUBLE})) + .Attr("T3", {DT_DOUBLE, DT_STRING})); + EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); + EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs); + EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}), + outputs); + EXPECT_EQ( + "tl = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32," + " DT_INT32, DT_INT32], T2=[DT_DOUBLE], T3=[DT_DOUBLE, DT_STRING]]" + "(a, a:1, a:2, a:3, a:4, a:5, a:6, b)", + SummarizeNodeDef(node_def2)); + + NodeDef bad_node_def = node_def2; + bad_node_def.clear_attr(); + EXPECT_FALSE(NameRangesForNode(bad_node_def, op_def, &inputs, &outputs).ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/numeric_op.h b/tensorflow/core/framework/numeric_op.h new file mode 100644 index 0000000000..8413d18f33 --- /dev/null +++ b/tensorflow/core/framework/numeric_op.h @@ -0,0 +1,96 @@ +#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_ +#define TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// One input and one output, both the same type. +template <class T> +class UnaryOp : public OpKernel { + public: + explicit UnaryOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = DataTypeToEnum<T>::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt})); + } +}; + +// Two inputs and one output, all the same type. +template <class T> +class BinaryOp : public OpKernel { + public: + explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = DataTypeToEnum<T>::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt})); + } +}; + +// For operations where the input and output are the same shape. +// +// For usage, see ../framework/elementwise_ops.cc. +template <class T, class CHILD> +class UnaryElementWiseOp : public UnaryOp<T> { + public: + using UnaryOp<T>::UnaryOp; + + void Compute(OpKernelContext* context) override { + // Output shape is the same as input shape. + const Tensor& input = context->input(0); + Tensor* output; + OP_REQUIRES_OK(context, + context->allocate_output(0, input.shape(), &output)); + static_cast<CHILD*>(this)->Operate(context, input, output); + } +}; + +// For binary elementwise operations. 
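+//
+// A subclass is expected to supply a rank-templated Operate(); as a
+// hypothetical sketch (AddOp is not part of this file):
+//
+//   class AddOp : public BinaryElementWiseOp<float, AddOp> {
+//    public:
+//     using BinaryElementWiseOp<float, AddOp>::BinaryElementWiseOp;
+//     template <int NDIMS>
+//     void Operate(OpKernelContext* context, const Tensor& a,
+//                  const Tensor& b, Tensor* output) {
+//       // Combine a and b element-wise into *output here.
+//     }
+//   };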
+template <class T, class CHILD> +class BinaryElementWiseOp : public BinaryOp<T> { + public: + using BinaryOp<T>::BinaryOp; + + void Compute(OpKernelContext* context) override { + const Tensor& a = context->input(0); + const Tensor& b = context->input(1); + + if (!context->ValidateInputsAreSameShape(this)) { + return; + } + + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(), &output)); + + // Dispatch to the descendant's Operate() function. + switch (a.dims()) { +#define NDIM_CASE(NDIMS) \ + case NDIMS: { \ + static_cast<CHILD*>(this)->template Operate<NDIMS>(context, a, b, output); \ + break; \ + } + + NDIM_CASE(1); + NDIM_CASE(2); + NDIM_CASE(3); + NDIM_CASE(4); + NDIM_CASE(5); + NDIM_CASE(6); + NDIM_CASE(7); + NDIM_CASE(8); +#undef NDIM_CASE + + default: + context->SetStatus(errors::OutOfRange( + "We only handle Tensor::dims() up to 8, not ", a.dims())); + break; + } + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_ diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h new file mode 100644 index 0000000000..366f00ae03 --- /dev/null +++ b/tensorflow/core/framework/numeric_types.h @@ -0,0 +1,15 @@ +#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ + +#include <complex> + +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +// Single precision complex. +typedef std::complex<float> complex64; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_ diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc new file mode 100644 index 0000000000..15b7eab4da --- /dev/null +++ b/tensorflow/core/framework/op.cc @@ -0,0 +1,135 @@ +#include "tensorflow/core/framework/op.h" + +#include <algorithm> +#include <memory> +#include <vector> +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// OpRegistry ----------------------------------------------------------------- + +OpRegistryInterface::~OpRegistryInterface() {} + +OpRegistry::OpRegistry() : initialized_(false) {} + +void OpRegistry::Register(std::function<OpDef(void)> func) { + mutex_lock lock(mu_); + if (initialized_) { + OpDef def = func(); + TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: " + << SummarizeOpDef(def); + } else { + deferred_.push_back(func); + } +} + +const OpDef* OpRegistry::LookUp(const string& op_type_name, + Status* status) const { + const OpDef* op_def = nullptr; + bool first_call = false; + { // Scope for lock. + mutex_lock lock(mu_); + first_call = CallDeferred(); + op_def = gtl::FindWithDefault(registry_, op_type_name, nullptr); + // Note: Can't hold mu_ while calling Export() below.
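+    // (Export() acquires mu_ itself, so calling it while still
+    // holding the lock would deadlock.)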
+ } + if (first_call) { + TF_QCHECK_OK(ValidateKernelRegistrations(this)); + } + if (op_def == nullptr) { + status->Update( + errors::NotFound("Op type not registered '", op_type_name, "'")); + static bool first = true; + if (first) { + OpList op_list; + Export(true, &op_list); + LOG(INFO) << "All registered Ops:"; + for (const auto& op : op_list.op()) { + LOG(INFO) << SummarizeOpDef(op); + } + first = false; + } + } + return op_def; +} + +void OpRegistry::Export(bool include_internal, OpList* ops) const { + mutex_lock lock(mu_); + CallDeferred(); + + std::vector<std::pair<string, const OpDef*>> sorted(registry_.begin(), + registry_.end()); + std::sort(sorted.begin(), sorted.end()); + + auto out = ops->mutable_op(); + out->Clear(); + out->Reserve(sorted.size()); + + for (const auto& item : sorted) { + if (include_internal || !StringPiece(item.first).starts_with("_")) { + *out->Add() = *item.second; + } + } +} + +string OpRegistry::DebugString(bool include_internal) const { + OpList op_list; + Export(include_internal, &op_list); + string ret; + for (const auto& op : op_list.op()) { + strings::StrAppend(&ret, SummarizeOpDef(op), "\n"); + } + return ret; +} + +bool OpRegistry::CallDeferred() const { + if (initialized_) return false; + initialized_ = true; + for (const auto& fn : deferred_) { + OpDef def = fn(); + TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: " + << SummarizeOpDef(def); + } + deferred_.clear(); + return true; +} + +Status OpRegistry::RegisterAlreadyLocked(const OpDef& def) const { + TF_RETURN_IF_ERROR(ValidateOpDef(def)); + + std::unique_ptr<OpDef> copy(new OpDef(def)); + if (gtl::InsertIfNotPresent(&registry_, def.name(), copy.get())) { + copy.release(); // Ownership transferred to op_registry + return Status::OK(); + } else { + return errors::AlreadyExists("Op with name ", def.name()); + } +} + +// static +OpRegistry* OpRegistry::Global() { + static OpRegistry* global_op_registry = new OpRegistry; + return global_op_registry; +} + +namespace register_op { +OpDefBuilder& RegisterOp(StringPiece name) { + VLOG(1) << "RegisterOp: " << name; + OpDefBuilder* b = new OpDefBuilder(name); + OpRegistry::Global()->Register([b]() -> ::tensorflow::OpDef { + OpDef op_def; + TF_QCHECK_OK(b->Finalize(&op_def)); + delete b; + return op_def; + }); + return *b; +} +} // namespace register_op + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h new file mode 100644 index 0000000000..95ad32df35 --- /dev/null +++ b/tensorflow/core/framework/op.h @@ -0,0 +1,122 @@ +#ifndef TENSORFLOW_FRAMEWORK_OP_H_ +#define TENSORFLOW_FRAMEWORK_OP_H_ + +#include <functional> +#include <unordered_map> + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Users that want to look up an OpDef by type name should take an +// OpRegistryInterface. Functions accepting a +// (const) OpRegistryInterface* may call LookUp() from multiple threads.
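+//
+// For example (illustrative only):
+//
+//   Status status;
+//   const OpDef* op_def = op_registry->LookUp("MyOp", &status);
+//   if (op_def == nullptr) return status;  // "MyOp" was never registered.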
+class OpRegistryInterface { + public: + virtual ~OpRegistryInterface(); + + // Returns nullptr and sets *status if no OpDef is registered under that + // name, otherwise returns the registered OpDef. + // Caller must not delete the returned pointer. + virtual const OpDef* LookUp(const string& op_type_name, + Status* status) const = 0; +}; + +// The standard implementation of OpRegistryInterface, along with a +// global singleton used for registering OpDefs via the REGISTER +// macros below. Thread-safe. +// +// Example registration: +// OpRegistry::Global()->Register([]()->OpDef{ +// OpDef def; +// // Populate def here. +// return def; +// }); +class OpRegistry : public OpRegistryInterface { + public: + OpRegistry(); + ~OpRegistry() override {} + + // Calls func() and registers the returned OpDef. Since Register() + // is normally called during program initialization (before main()), + // we defer calling func() until the first call to LookUp() or + // Export() (if one of those has already been called, func() is + // called immediately). + void Register(std::function<OpDef(void)> func); + + const OpDef* LookUp(const string& op_type_name, + Status* status) const override; + + // Fills *ops with all registered OpDefs (except those with names + // starting with '_' if include_internal == false). + void Export(bool include_internal, OpList* ops) const; + + // Returns ASCII-format OpList for all registered OpDefs (except + // those with names starting with '_' if include_internal == false). + string DebugString(bool include_internal) const; + + // A singleton available at startup. + static OpRegistry* Global(); + + private: + // Ensures that all the functions in deferred_ get called, their OpDefs + // registered, and returns with deferred_ empty. Returns true the first + // time it is called. + bool CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Add 'def' to the registry. On failure, or if there is already an + // OpDef with that name registered, returns a non-okay status. + Status RegisterAlreadyLocked(const OpDef& def) const + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutable mutex mu_; + // Functions in deferred_ may only be called with mu_ held. + mutable std::vector<std::function<OpDef(void)>> deferred_ GUARDED_BY(mu_); + mutable std::unordered_map<string, OpDef*> registry_ GUARDED_BY(mu_); + mutable bool initialized_ GUARDED_BY(mu_); +}; + +// Support for defining the OpDef (specifying the semantics of the Op and how +// it should be created) and registering it in the OpRegistry::Global() +// registry. Usage: +// +// REGISTER_OP("my_op_name") +// .Attr("<name>:<type>") +// .Attr("<name>:<type>=<default>") +// .Input("<name>:<type-expr>") +// .Input("<name>:Ref(<type-expr>)") +// .Output("<name>:<type-expr>") +// .Doc(R"( +// <1-line summary> +// <rest of the description (potentially many lines)> +// <name-of-attr-input-or-output>: <description of name> +// <name-of-attr-input-or-output>: <description of name; +// if long, indent the description on subsequent lines> +// )"); +// +// Note: .Doc() should be last. +// For details, see the OpDefBuilder class in op_def_builder.h. + +namespace register_op { +// To call OpRegistry::Global()->Register(...), used by the +// REGISTER_OP macro below.
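+//
+// For instance (illustrative), REGISTER_OP("Foo") expands via __COUNTER__
+// to roughly:
+//
+//   static ::tensorflow::OpDefBuilder& register_op0 TF_ATTRIBUTE_UNUSED =
+//       ::tensorflow::register_op::RegisterOp("Foo");
+//
+// so every registration becomes a uniquely named static initializer.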
+OpDefBuilder& RegisterOp(StringPiece name); +} // namespace register_op + +#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name) +#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name) +#define REGISTER_OP_UNIQ(ctr, name) \ + static ::tensorflow::OpDefBuilder& register_op##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::register_op::RegisterOp(name) + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_H_ diff --git a/tensorflow/core/framework/op_def.proto b/tensorflow/core/framework/op_def.proto new file mode 100644 index 0000000000..4a2e90b1b9 --- /dev/null +++ b/tensorflow/core/framework/op_def.proto @@ -0,0 +1,142 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/attr_value.proto"; +import "tensorflow/core/framework/types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field, which should match the name of an OpDef. +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human-readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + // TODO(josh11b): bool is_optional? + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints.
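+    // For example (illustrative), an OpDefBuilder attr spec "N: int >= 2"
+    // is encoded below as has_minimum = true, minimum = 2.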
+ + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the ..._list fields of AttrValue). + // If type == "type" or "list(type)" above, then the type_list field + // of allowed_values has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the s_list field has + // the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + // TODO(josh11b): Implement that optimization. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc.
+}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc new file mode 100644 index 0000000000..7d7c07de4c --- /dev/null +++ b/tensorflow/core/framework/op_def_builder.cc @@ -0,0 +1,447 @@ +#include "tensorflow/core/framework/op_def_builder.h" + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +namespace { + +bool RE2Consume(StringPiece* sp, const char* pattern) { + RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); + bool r = RE2::Consume(&base_sp, pattern); + *sp = FromRegexpStringPiece(base_sp); + return r; +} + +bool RE2Consume(StringPiece* sp, const char* pattern, StringPiece* out) { + RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); + RegexpStringPiece base_out; + bool r = RE2::Consume(&base_sp, pattern, &base_out); + *sp = FromRegexpStringPiece(base_sp); + *out = FromRegexpStringPiece(base_out); + return r; +} + +bool RE2Consume(StringPiece* sp, const char* pattern, int64* out) { + RegexpStringPiece base_sp = ToRegexpStringPiece(*sp); + bool r = RE2::Consume(&base_sp, pattern, out); + *sp = FromRegexpStringPiece(base_sp); + return r; +} + +string AttrError(StringPiece orig, const string& op_name) { + return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name); +} + +#define VERIFY(expr, ...) \ + do { \ + if (!(expr)) { \ + errors->push_back( \ + strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name()))); \ + return; \ + } \ + } while (false) + +void FinalizeAttr(StringPiece spec, OpDef* op_def, + std::vector<string>* errors) { + OpDef::AttrDef* attr = op_def->add_attr(); + StringPiece orig(spec); + + // Parse "<name>:" at the beginning. + StringPiece tmp_name; + VERIFY(RE2Consume(&spec, "([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*", &tmp_name), + "Trouble parsing '<name>:'"); + attr->set_name(tmp_name.data(), tmp_name.size()); + + // Read "<type>" or "list(<type>)". + bool is_list = RE2Consume(&spec, "list\\s*\\(\\s*"); + string type; + if (spec.Consume("string")) { + type = "string"; + } else if (spec.Consume("int")) { + type = "int"; + } else if (spec.Consume("float")) { + type = "float"; + } else if (spec.Consume("bool")) { + type = "bool"; + } else if (spec.Consume("type")) { + type = "type"; + } else if (spec.Consume("shape")) { + type = "shape"; + } else if (spec.Consume("tensor")) { + type = "tensor"; + } else if (spec.Consume("func")) { + type = "func"; + } else if (spec.Consume("numbertype") || spec.Consume("numerictype")) { + type = "type"; + AttrValue* allowed = attr->mutable_allowed_values(); + for (DataType dt : NumberTypes()) { + allowed->mutable_list()->add_type(dt); + } + } else if (spec.Consume("quantizedtype")) { + type = "type"; + AttrValue* allowed = attr->mutable_allowed_values(); + for (DataType dt : QuantizedTypes()) { + allowed->mutable_list()->add_type(dt); + } + } else if (spec.Consume("realnumbertype") || + spec.Consume("realnumerictype")) { + type = "type"; + AttrValue* allowed = attr->mutable_allowed_values(); + for (DataType dt : RealNumberTypes()) { + allowed->mutable_list()->add_type(dt); + } + } else if (spec.Consume("{")) { + // e.g. 
"{ int32, float, bool }" or "{ \"foo\", \"bar\" }" + RE2Consume(&spec, "\\s*"); + AttrValue* allowed = attr->mutable_allowed_values(); + if (spec.starts_with("\"") || spec.starts_with("'")) { + type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }" + while (true) { + StringPiece escaped_string; + VERIFY((RE2Consume(&spec, R"xx("((?:[^"\\]|\\.)*)"\s*)xx", + &escaped_string) || + RE2Consume(&spec, R"xx('((?:[^'\\]|\\.)*)'\s*)xx", + &escaped_string)), + "Trouble parsing allowed string at '", spec, "'"); + string unescaped; + string error; + VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error), + "Trouble unescaping \"", escaped_string, "\", got error: ", + error); + allowed->mutable_list()->add_s(unescaped); + if (spec.Consume(",")) { + RE2Consume(&spec, "\\s*"); + if (spec.Consume("}")) break; // Allow ending with ", }". + } else { + VERIFY(spec.Consume("}"), + "Expected , or } after strings in list, not: '", spec, "'"); + break; + } + } + } else { // "{ int32, float, bool }" + type = "type"; + while (true) { + StringPiece type_string; + VERIFY(RE2Consume(&spec, "([a-z0-9]+)\\s*", &type_string), + "Trouble parsing type string at '", spec, "'"); + DataType dt; + VERIFY(DataTypeFromString(type_string, &dt), + "Unrecognized type string '", type_string, "'"); + allowed->mutable_list()->add_type(dt); + if (spec.Consume(",")) { + RE2Consume(&spec, "\\s*"); + if (spec.Consume("}")) break; // Allow ending with ", }". + } else { + VERIFY(spec.Consume("}"), + "Expected , or } after types in list, not: '", spec, "'"); + break; + } + } + } + } else { + VERIFY(false, "Trouble parsing type string at '", spec, "'"); + } + RE2Consume(&spec, "\\s*"); + + // Write the type into *attr. + if (is_list) { + VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'"); + RE2Consume(&spec, "\\s*"); + attr->set_type(strings::StrCat("list(", type, ")")); + } else { + attr->set_type(type); + } + + // Read optional minimum constraint at the end. + if ((is_list || type == "int") && spec.Consume(">=")) { + int64 min_limit = -999; + VERIFY(RE2Consume(&spec, "\\s*(-?\\d+)\\s*", &min_limit), + "Could not parse integer lower limit after '>=', found '", spec, + "' instead"); + attr->set_has_minimum(true); + attr->set_minimum(min_limit); + } + + // Parse default value, if present. + if (spec.Consume("=")) { + RE2Consume(&spec, "\\s*"); + VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()), + "Could not parse default value '", spec, "'"); + } else { + VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end"); + } +} + +#undef VERIFY + +string InOutError(bool is_output, StringPiece orig, const string& op_name) { + return strings::StrCat(" from ", is_output ? "Output" : "Input", "(\"", orig, + "\") for Op ", op_name); +} + +#define VERIFY(expr, ...) \ + do { \ + if (!(expr)) { \ + errors->push_back(strings::StrCat( \ + __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \ + return; \ + } \ + } while (false) + +void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def, + std::vector<string>* errors) { + OpDef::ArgDef* arg = + is_output ? op_def->add_output_arg() : op_def->add_input_arg(); + + StringPiece orig(spec); + + // Parse "<name>:" at the beginning. + StringPiece tmp_name; + VERIFY(RE2Consume(&spec, "([a-z][a-z0-9_]*)\\s*:\\s*", &tmp_name), + "Trouble parsing 'name:'"); + arg->set_name(tmp_name.data(), tmp_name.size()); + + // Detect "Ref(...)". 
+ if (RE2Consume(&spec, "Ref\\s*\\(\\s*")) { + arg->set_is_ref(true); + } + + { // Parse "<name|type>" or "<name>*<name|type>". + StringPiece first, second, type_or_attr; + VERIFY(RE2Consume(&spec, "([a-zA-Z][a-zA-Z0-9_]*)\\s*", &first), + "Trouble parsing either a type or an attr name at '", spec, "'"); + if (RE2Consume(&spec, "[*]\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*", &second)) { + arg->set_number_attr(first.data(), first.size()); + type_or_attr = second; + } else { + type_or_attr = first; + } + DataType dt; + if (DataTypeFromString(type_or_attr, &dt)) { + arg->set_type(dt); + } else { + const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def); + VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'"); + if (attr->type() == "type") { + arg->set_type_attr(type_or_attr.data(), type_or_attr.size()); + } else { + VERIFY(attr->type() == "list(type)", "Reference to attr '", + type_or_attr, "' with type ", attr->type(), + " that isn't type or list(type)"); + arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size()); + } + } + } + + // Closing ) for Ref(. + if (arg->is_ref()) { + VERIFY(RE2Consume(&spec, "\\)\\s*"), + "Did not find closing ')' for 'Ref(', instead found: '", spec, "'"); + } + + // Should not have anything else. + VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end"); + + // Int attrs that are the length of an input or output get a default + // minimum of 1. + if (!arg->number_attr().empty()) { + OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def); + if (attr != nullptr && !attr->has_minimum()) { + attr->set_has_minimum(true); + attr->set_minimum(1); + } + } else if (!arg->type_list_attr().empty()) { + // If an input or output has type specified by a list(type) attr, + // it gets a default minimum of 1 as well. + OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def); + if (attr != nullptr && attr->type() == "list(type)" && + !attr->has_minimum()) { + attr->set_has_minimum(true); + attr->set_minimum(1); + } + } +} + +#undef VERIFY + +int num_leading_spaces(StringPiece s) { + size_t i = 0; + while (i < s.size() && s[i] == ' ') { + ++i; + } + return i; +} + +void FinalizeDoc(const string& text, OpDef* op_def, + std::vector<string>* errors) { + std::vector<string> lines = str_util::Split(text, '\n'); + + // Remove trailing spaces. + for (string& line : lines) { + str_util::StripTrailingWhitespace(&line); + } + + // First non-blank line -> summary. + int l = 0; + while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l; + if (static_cast<size_t>(l) < lines.size()) { + op_def->set_summary(lines[l]); + ++l; + } + while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l; + + // Lines until we see name: -> description. + int start_l = l; + while (static_cast<size_t>(l) < lines.size() && + !RE2::PartialMatch(lines[l], "^[a-zA-Z][a-zA-Z0-9_]*\\s*:")) { + ++l; + } + int end_l = l; + // Trim trailing blank lines from the description. 
+ while (start_l < end_l && lines[end_l - 1].empty()) --end_l; + string desc = str_util::Join( + gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n"); + if (!desc.empty()) op_def->set_description(desc); + + // name: description + // possibly continued on the next line + // if so, we remove the minimum indent + StringPiece name; + std::vector<StringPiece> description; + while (static_cast<size_t>(l) < lines.size()) { + description.clear(); + description.push_back(lines[l]); + RE2Consume(&description.back(), "([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*", &name); + ++l; + while (static_cast<size_t>(l) < lines.size() && + !RE2::PartialMatch(lines[l], "^[a-zA-Z][a-zA-Z0-9_]*\\s*:")) { + description.push_back(lines[l]); + ++l; + } + // Remove any trailing blank lines. + while (!description.empty() && description.back().empty()) { + description.pop_back(); + } + // Compute the minimum indent of all lines after the first. + int min_indent = -1; + for (size_t i = 1; i < description.size(); ++i) { + if (!description[i].empty()) { + int indent = num_leading_spaces(description[i]); + if (min_indent < 0 || indent < min_indent) min_indent = indent; + } + } + // Remove min_indent spaces from all lines after the first. + for (size_t i = 1; i < description.size(); ++i) { + if (!description[i].empty()) description[i].remove_prefix(min_indent); + } + // Concatenate lines into a single string. + const string complete(str_util::Join(description, "\n")); + + // Find name. + bool found = false; + for (int i = 0; !found && i < op_def->input_arg_size(); ++i) { + if (op_def->input_arg(i).name() == name) { + op_def->mutable_input_arg(i)->set_description(complete); + found = true; + } + } + for (int i = 0; !found && i < op_def->output_arg_size(); ++i) { + if (op_def->output_arg(i).name() == name) { + op_def->mutable_output_arg(i)->set_description(complete); + found = true; + } + } + for (int i = 0; !found && i < op_def->attr_size(); ++i) { + if (op_def->attr(i).name() == name) { + op_def->mutable_attr(i)->set_description(complete); + found = true; + } + } + if (!found) { + errors->push_back( + strings::StrCat("No matching input/output/attr for name '", name, + "' from Doc() for Op ", op_def->name())); + return; + } + } +} + +} // namespace + +OpDefBuilder::OpDefBuilder(StringPiece op_name) { + op_def_.set_name(op_name.ToString()); // NOLINT +} + +OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) { + attrs_.emplace_back(spec.data(), spec.size()); + return *this; +} + +OpDefBuilder& OpDefBuilder::Input(StringPiece spec) { + inputs_.emplace_back(spec.data(), spec.size()); + return *this; +} + +OpDefBuilder& OpDefBuilder::Output(StringPiece spec) { + outputs_.emplace_back(spec.data(), spec.size()); + return *this; +} + +OpDefBuilder& OpDefBuilder::Doc(StringPiece text) { + if (!doc_.empty()) { + errors_.push_back( + strings::StrCat("Extra call to Doc() for Op ", op_def_.name())); + } else { + doc_.assign(text.data(), text.size()); + } + return *this; +} + +OpDefBuilder& OpDefBuilder::SetIsCommutative() { + op_def_.set_is_commutative(true); + return *this; +} + +OpDefBuilder& OpDefBuilder::SetIsAggregate() { + op_def_.set_is_aggregate(true); + return *this; +} + +OpDefBuilder& OpDefBuilder::SetIsStateful() { + op_def_.set_is_stateful(true); + return *this; +} + +OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() { + op_def_.set_allows_uninitialized_input(true); + return *this; +} + +Status OpDefBuilder::Finalize(OpDef* op_def) const { + std::vector<string> errors = errors_; + *op_def = op_def_; + + for 
(StringPiece attr : attrs_) { + FinalizeAttr(attr, op_def, &errors); + } + for (StringPiece input : inputs_) { + FinalizeInputOrOutput(input, false, op_def, &errors); + } + for (StringPiece output : outputs_) { + FinalizeInputOrOutput(output, true, op_def, &errors); + } + FinalizeDoc(doc_, op_def, &errors); + + if (errors.empty()) return Status::OK(); + return errors::InvalidArgument(str_util::Join(errors, "\n")); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h new file mode 100644 index 0000000000..017338c508 --- /dev/null +++ b/tensorflow/core/framework/op_def_builder.h @@ -0,0 +1,109 @@ +// Class and associated machinery for specifying an Op's OpDef for Op +// registration. + +#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ +#define TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ + +#include <string> +#include <vector> +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Builder class passed to the REGISTER_OP() macro. +class OpDefBuilder { + public: + // Constructs an OpDef with just the name field set. + explicit OpDefBuilder(StringPiece op_name); + + // Adds an attr to this OpDefBuilder (and returns *this). The spec has + // format "<name>:<type>" or "<name>:<type>=<default>" + // where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]* + // (by convention only using capital letters for attrs that can be inferred) + // <type> can be: + // "string", "int", "float", "bool", "type", "shape", or "tensor" + // "numbertype", "realnumbertype", "quantizedtype", "{int32,int64}" + // (meaning "type" with a restriction on valid values) + // "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}" + // (meaning "string" with a restriction on valid values) + // "list(string)", ..., "list(tensor)", "list(numbertype)", ... + // (meaning lists of the above types) + // "int >= 2" (meaning "int" with a restriction on valid values) + // "list(string) >= 2", "list(int) >= 2" + // (meaning "list(string)" / "list(int)" with length at least 2) + // <default>, if included, should use the Proto text format + // of <type>. For lists use [a, b, c] format. + // + // Note that any attr specifying the length of an input or output will + // get a default minimum of 1 unless the >= # syntax is used. + // + // TODO(josh11b): Perhaps support restrictions and defaults as optional + // extra arguments to Attr() instead of encoding them in the spec string. + // TODO(josh11b): Would like to have better dtype handling for tensor attrs: + // * Ability to say the type of an input/output matches the type of + // the tensor. + // * Ability to restrict the type of the tensor like the existing + // restrictions for type attrs. + // Perhaps by linking the type of the tensor to a type attr? + OpDefBuilder& Attr(StringPiece spec); + + // Adds an input or output to this OpDefBuilder (and returns *this). + // The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)" + // where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be: + // * For a single tensor: <type> + // * For a sequence of tensors with the same type: <number>*<type> + // * For a sequence of tensors with different types: <type-list> + // Where: + // <type> is either one of "float", "int32", "string", ... + // or the name of an attr (see above) with type "type". + // <number> is the name of an attr with type "int".
+ // <type-list> is the name of an attr with type "list(type)". + // TODO(josh11b): Indicate Ref() via an optional argument instead of + // in the spec? + // TODO(josh11b): SparseInput() and SparseOutput() matching the Python + // handling? + OpDefBuilder& Input(StringPiece spec); + OpDefBuilder& Output(StringPiece spec); + + // Turns on the indicated boolean flag in this OpDefBuilder (and + // returns *this). + OpDefBuilder& SetIsCommutative(); + OpDefBuilder& SetIsAggregate(); + OpDefBuilder& SetIsStateful(); + OpDefBuilder& SetAllowsUninitializedInput(); + + // Adds docs to this OpDefBuilder (and returns *this). + // Docs have the format: + // <1-line summary> + // <rest of the description> + // <name>: <description of name> + // <name>: <description of name> + // <if long, indent the description on subsequent lines> + // Where <name> is the name of an attr, input, or output. Please + // wrap docs at 72 columns so that it may be indented in the + // generated output. For tensor inputs or outputs (not attrs), you + // may start the description with an "=" (like name:= <description>) + // to suppress the automatically-generated type documentation in + // generated output. + OpDefBuilder& Doc(StringPiece text); + + // Sets *op_def to the requested OpDef, or returns an error. + // Must be called after all of the above methods. + // Note that OpDefBuilder only reports parsing errors. You should also + // call ValidateOpDef() to detect other problems. + Status Finalize(OpDef* op_def) const; + + private: + OpDef op_def_; + std::vector<string> attrs_; + std::vector<string> inputs_; + std::vector<string> outputs_; + string doc_; + std::vector<string> errors_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_ diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc new file mode 100644 index 0000000000..e53bad7075 --- /dev/null +++ b/tensorflow/core/framework/op_def_builder_test.cc @@ -0,0 +1,519 @@ +#include "tensorflow/core/framework/op_def_builder.h" + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +static void CanonicalizeAttrTypeListOrder(OpDef* def) { + for (int i = 0; i < def->attr_size(); i++) { + AttrValue* a = def->mutable_attr(i)->mutable_allowed_values(); + std::sort(a->mutable_list()->mutable_type()->begin(), + a->mutable_list()->mutable_type()->end()); + } +} + +class OpDefBuilderTest : public ::testing::Test { + protected: + OpDefBuilder b() { return OpDefBuilder("Test"); } + + void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto) { + OpDef op_def; + Status status = builder.Finalize(&op_def); + EXPECT_OK(status); + if (status.ok()) { + OpDef expected; + protobuf::TextFormat::ParseFromString( + strings::StrCat("name: 'Test' ", proto), &expected); + // Allow different orderings + CanonicalizeAttrTypeListOrder(&op_def); + CanonicalizeAttrTypeListOrder(&expected); + EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString()); + } + } + + void ExpectOrdered(const OpDefBuilder& builder, StringPiece proto) { + OpDef op_def; + Status status = builder.Finalize(&op_def); + EXPECT_OK(status); + if (status.ok()) { + OpDef expected; + 
protobuf::TextFormat::ParseFromString( + strings::StrCat("name: 'Test' ", proto), &expected); + EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString()); + } + } + + void ExpectFailure(const OpDefBuilder& builder, string error) { + OpDef op_def; + Status status = builder.Finalize(&op_def); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + EXPECT_EQ(status.error_message(), error); + } + } +}; + +TEST_F(OpDefBuilderTest, Attr) { + ExpectSuccess(b().Attr("a:string"), "attr: { name: 'a' type: 'string' }"); + ExpectSuccess(b().Attr("A: int"), "attr: { name: 'A' type: 'int' }"); + ExpectSuccess(b().Attr("a1 :float"), "attr: { name: 'a1' type: 'float' }"); + ExpectSuccess(b().Attr("a_a : bool"), "attr: { name: 'a_a' type: 'bool' }"); + ExpectSuccess(b().Attr("aB : type"), "attr: { name: 'aB' type: 'type' }"); + ExpectSuccess(b().Attr("aB_3\t: shape"), + "attr: { name: 'aB_3' type: 'shape' }"); + ExpectSuccess(b().Attr("t: tensor"), "attr: { name: 't' type: 'tensor' }"); + ExpectSuccess(b().Attr("XYZ\t:\tlist(type)"), + "attr: { name: 'XYZ' type: 'list(type)' }"); + ExpectSuccess(b().Attr("f: func"), "attr { name: 'f' type: 'func'}"); +} + +TEST_F(OpDefBuilderTest, AttrFailure) { + ExpectFailure( + b().Attr("_:string"), + "Trouble parsing '<name>:' from Attr(\"_:string\") for Op Test"); + ExpectFailure( + b().Attr("9:string"), + "Trouble parsing '<name>:' from Attr(\"9:string\") for Op Test"); + ExpectFailure(b().Attr(":string"), + "Trouble parsing '<name>:' from Attr(\":string\") for Op Test"); + ExpectFailure(b().Attr("string"), + "Trouble parsing '<name>:' from Attr(\"string\") for Op Test"); + ExpectFailure(b().Attr("a:invalid"), + "Trouble parsing type string at 'invalid' from " + "Attr(\"a:invalid\") for Op Test"); + ExpectFailure( + b().Attr("b:"), + "Trouble parsing type string at '' from Attr(\"b:\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, AttrWithRestrictions) { + // Types with restrictions. 
+ ExpectSuccess(b().Attr("a:numbertype"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_INT8, DT_COMPLEX64, DT_QINT8, DT_QUINT8, DT_QINT32] } } }"); + ExpectSuccess(b().Attr("a:realnumbertype"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_INT8] } } }"); + ExpectSuccess(b().Attr("a:quantizedtype"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_QINT8, DT_QUINT8, DT_QINT32] } } }"); + ExpectSuccess(b().Attr("a:{string,int32}"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_STRING, DT_INT32] } } }"); + ExpectSuccess(b().Attr("a: { float , complex64 } "), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_FLOAT, DT_COMPLEX64] } } }"); + ExpectSuccess(b().Attr("a: {float, complex64,} "), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_FLOAT, DT_COMPLEX64] } } }"); + ExpectSuccess(b().Attr(R"(a: { "X", "yz" })"), + "attr: { name: 'a' type: 'string' allowed_values { list { s: " + "['X', 'yz'] } } }"); + ExpectSuccess(b().Attr(R"(a: { "X", "yz", })"), + "attr: { name: 'a' type: 'string' allowed_values { list { s: " + "['X', 'yz'] } } }"); + ExpectSuccess( + b().Attr("i: int >= -5"), + "attr: { name: 'i' type: 'int' has_minimum: true minimum: -5 }"); +} + +TEST_F(OpDefBuilderTest, AttrRestrictionFailure) { + ExpectFailure( + b().Attr("a:{}"), + "Trouble parsing type string at '}' from Attr(\"a:{}\") for Op Test"); + ExpectFailure( + b().Attr("a:{,}"), + "Trouble parsing type string at ',}' from Attr(\"a:{,}\") for Op Test"); + ExpectFailure(b().Attr("a:{invalid}"), + "Unrecognized type string 'invalid' from Attr(\"a:{invalid}\") " + "for Op Test"); + ExpectFailure(b().Attr("a:{\"str\", float}"), + "Trouble parsing allowed string at 'float}' from " + "Attr(\"a:{\"str\", float}\") for Op Test"); + ExpectFailure(b().Attr("a:{ float, \"str\" }"), + "Trouble parsing type string at '\"str\" }' from Attr(\"a:{ " + "float, \"str\" }\") for Op Test"); + ExpectFailure(b().Attr("a:{float,,string}"), + "Trouble parsing type string at ',string}' from " + "Attr(\"a:{float,,string}\") for Op Test"); + ExpectFailure(b().Attr("a:{float,,}"), + "Trouble parsing type string at ',}' from " + "Attr(\"a:{float,,}\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, AttrListOfRestricted) { + ExpectSuccess( + b().Attr("a:list(realnumbertype)"), + "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " + "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_INT8] } } }"); + ExpectSuccess( + b().Attr("a:list(quantizedtype)"), + "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " + "[DT_QINT8, DT_QUINT8, DT_QINT32] } } }"); + ExpectSuccess( + b().Attr("a: list({float, string, bool})"), + "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " + "[DT_FLOAT, DT_STRING, DT_BOOL] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ "one fish", "two fish" }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " + "['one fish', 'two fish'] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ 'red fish', 'blue fish' }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " + "['red fish', 'blue fish'] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ "single' ", 'double"' }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " +
"[\"single' \", 'double\"'] } } }"); + ExpectSuccess( + b().Attr(R"(a: list({ 'escape\'\n', "from\\\"NY" }))"), + "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: " + "[\"escape'\\n\", 'from\\\\\"NY'] } } }"); +} + +TEST_F(OpDefBuilderTest, AttrListWithMinLength) { + ExpectSuccess( + b().Attr("i: list(bool) >= 4"), + "attr: { name: 'i' type: 'list(bool)' has_minimum: true minimum: 4 }"); +} + +TEST_F(OpDefBuilderTest, AttrWithDefaults) { + ExpectSuccess(b().Attr(R"(a:string="foo")"), + "attr: { name: 'a' type: 'string' default_value { s:'foo' } }"); + ExpectSuccess(b().Attr(R"(a:string='foo')"), + "attr: { name: 'a' type: 'string' default_value { s:'foo' } }"); + ExpectSuccess(b().Attr("a:float = 1.25"), + "attr: { name: 'a' type: 'float' default_value { f: 1.25 } }"); + ExpectSuccess(b().Attr("a:tensor = { dtype: DT_INT32 int_val: 5 }"), + "attr: { name: 'a' type: 'tensor' default_value { tensor {" + " dtype: DT_INT32 int_val: 5 } } }"); + ExpectSuccess(b().Attr("a:shape = { dim { size: 3 } dim { size: 4 } }"), + "attr: { name: 'a' type: 'shape' default_value { shape {" + " dim { size: 3 } dim { size: 4 } } } }"); +} + +TEST_F(OpDefBuilderTest, AttrFailedDefaults) { + ExpectFailure(b().Attr(R"(a:int="foo")"), + "Could not parse default value '\"foo\"' from " + "Attr(\"a:int=\"foo\"\") for Op Test"); + ExpectFailure(b().Attr("a:float = [1.25]"), + "Could not parse default value '[1.25]' from Attr(\"a:float = " + "[1.25]\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, AttrListWithDefaults) { + ExpectSuccess(b().Attr(R"(a:list(string)=["foo", "bar"])"), + "attr: { name: 'a' type: 'list(string)' " + "default_value { list { s: ['foo', 'bar'] } } }"); + ExpectSuccess(b().Attr("a:list(bool)=[true, false, true]"), + "attr: { name: 'a' type: 'list(bool)' " + "default_value { list { b: [true, false, true] } } }"); + ExpectSuccess(b().Attr(R"(a:list(int)=[0, -1, 2, -4, 8])"), + "attr: { name: 'a' type: 'list(int)' " + "default_value { list { i: [0, -1, 2, -4, 8] } } }"); +} + +TEST_F(OpDefBuilderTest, AttrFailedListDefaults) { + ExpectFailure(b().Attr(R"(a:list(int)=["foo"])"), + "Could not parse default value '[\"foo\"]' from " + "Attr(\"a:list(int)=[\"foo\"]\") for Op Test"); + ExpectFailure(b().Attr(R"(a:list(int)=[7, "foo"])"), + "Could not parse default value '[7, \"foo\"]' from " + "Attr(\"a:list(int)=[7, \"foo\"]\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = [[1.25]]"), + "Could not parse default value '[[1.25]]' from " + "Attr(\"a:list(float) = [[1.25]]\") for Op Test"); + ExpectFailure(b().Attr("a:list(float) = 1.25"), + "Could not parse default value '1.25' from " + "Attr(\"a:list(float) = 1.25\") for Op Test"); + ExpectFailure(b().Attr(R"(a:list(string)='foo')"), + "Could not parse default value ''foo'' from " + "Attr(\"a:list(string)='foo'\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, InputOutput) { + ExpectSuccess(b().Input("a: int32"), + "input_arg: { name: 'a' type: DT_INT32 }"); + ExpectSuccess(b().Output("b: string"), + "output_arg: { name: 'b' type: DT_STRING }"); + ExpectSuccess(b().Input("c: float "), + "input_arg: { name: 'c' type: DT_FLOAT }"); + ExpectSuccess(b().Output("d: Ref(bool)"), + "output_arg: { name: 'd' type: DT_BOOL is_ref: true }"); + ExpectOrdered(b().Input("a: bool") + .Output("c: complex64") + .Input("b: int64") + .Output("d: string"), + "input_arg: { name: 'a' type: DT_BOOL } " + "input_arg: { name: 'b' type: DT_INT64 } " + "output_arg: { name: 'c' type: DT_COMPLEX64 } " + "output_arg: { name: 'd' type: DT_STRING }"); +} + 
+TEST_F(OpDefBuilderTest, PolymorphicInputOutput) { + ExpectSuccess(b().Input("a: foo").Attr("foo: type"), + "input_arg: { name: 'a' type_attr: 'foo' } " + "attr: { name: 'foo' type: 'type' }"); + ExpectSuccess(b().Output("a: foo").Attr("foo: { bool, int32 }"), + "output_arg: { name: 'a' type_attr: 'foo' } " + "attr: { name: 'foo' type: 'type' " + "allowed_values: { list { type: [DT_BOOL, DT_INT32] } } }"); +} + +TEST_F(OpDefBuilderTest, InputOutputListSameType) { + ExpectSuccess(b().Input("a: n * int32").Attr("n: int"), + "input_arg: { name: 'a' number_attr: 'n' type: DT_INT32 } " + "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 }"); + // Polymorphic case: + ExpectSuccess(b().Output("b: n * foo").Attr("n: int").Attr("foo: type"), + "output_arg: { name: 'b' number_attr: 'n' type_attr: 'foo' } " + "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 } " + "attr: { name: 'foo' type: 'type' }"); +} + +TEST_F(OpDefBuilderTest, InputOutputListAnyType) { + ExpectSuccess( + b().Input("c: foo").Attr("foo: list(type)"), + "input_arg: { name: 'c' type_list_attr: 'foo' } " + "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 }"); + ExpectSuccess( + b().Output("c: foo").Attr("foo: list({string, float})"), + "output_arg: { name: 'c' type_list_attr: 'foo' } " + "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 " + "allowed_values: { list { type: [DT_STRING, DT_FLOAT] } } }"); +} + +TEST_F(OpDefBuilderTest, InputOutputFailure) { + ExpectFailure(b().Input("9: int32"), + "Trouble parsing 'name:' from Input(\"9: int32\") for Op Test"); + ExpectFailure( + b().Output("_: int32"), + "Trouble parsing 'name:' from Output(\"_: int32\") for Op Test"); + ExpectFailure(b().Input(": int32"), + "Trouble parsing 'name:' from Input(\": int32\") for Op Test"); + ExpectFailure(b().Output("int32"), + "Trouble parsing 'name:' from Output(\"int32\") for Op Test"); + ExpectFailure( + b().Input("CAPS: int32"), + "Trouble parsing 'name:' from Input(\"CAPS: int32\") for Op Test"); + ExpectFailure(b().Input("a: _"), + "Trouble parsing either a type or an attr name at '_' from " + "Input(\"a: _\") for Op Test"); + ExpectFailure(b().Input("a: 9"), + "Trouble parsing either a type or an attr name at '9' from " + "Input(\"a: 9\") for Op Test"); + ExpectFailure(b().Input("a: 9 * int32"), + "Trouble parsing either a type or an attr name at '9 * int32' " + "from Input(\"a: 9 * int32\") for Op Test"); + ExpectFailure( + b().Input("a: x * _").Attr("x: type"), + "Extra '* _' unparsed at the end from Input(\"a: x * _\") for Op Test"); + ExpectFailure(b().Input("a: x * y extra").Attr("x: int").Attr("y: type"), + "Extra 'extra' unparsed at the end from Input(\"a: x * y " + "extra\") for Op Test"); + ExpectFailure(b().Input("a: Ref(int32"), + "Did not find closing ')' for 'Ref(', instead found: '' from " + "Input(\"a: Ref(int32\") for Op Test"); + ExpectFailure(b().Input("a: Ref(x y").Attr("x: type"), + "Did not find closing ')' for 'Ref(', instead found: 'y' from " + "Input(\"a: Ref(x y\") for Op Test"); + ExpectFailure( + b().Input("a: x"), + "Reference to unknown attr 'x' from Input(\"a: x\") for Op Test"); + ExpectFailure( + b().Input("a: x * y").Attr("x: int"), + "Reference to unknown attr 'y' from Input(\"a: x * y\") for Op Test"); + ExpectFailure(b().Input("a: x").Attr("x: int"), + "Reference to attr 'x' with type int that isn't type or " + "list(type) from Input(\"a: x\") for Op Test"); +} + +TEST_F(OpDefBuilderTest, Set) { + ExpectSuccess(b().SetIsStateful(), "is_stateful: 
true"); + ExpectSuccess(b().SetIsCommutative().SetIsAggregate(), + "is_commutative: true is_aggregate: true"); +} + +TEST_F(OpDefBuilderTest, DocUnpackSparseFeatures) { + ExpectOrdered(b().Input("sf: string") + .Output("indices: int32") + .Output("ids: int64") + .Output("weights: float") + .Doc(R"doc( +Converts a vector of strings with dist_belief::SparseFeatures to tensors. + +Note that indices, ids and weights are vectors of the same size and have +one-to-one correspondence between their elements. ids and weights are each +obtained by sequentially concatenating sf[i].id and sf[i].weight, for i in +1...size(sf). Note that if sf[i].weight is not set, the default value for the +weight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were +extracted from sf[i], then index[j] is set to i. + +sf: vector of string, where each element is the string encoding of + SparseFeatures proto. +indices: vector of indices inside sf +ids: vector of id extracted from the SparseFeatures proto. +weights: vector of weight extracted from the SparseFeatures proto. +)doc"), + R"proto( +input_arg { + name: "sf" + description: "vector of string, where each element is the string encoding of\nSparseFeatures proto." + type: DT_STRING +} +output_arg { + name: "indices" + description: "vector of indices inside sf" + type: DT_INT32 +} +output_arg { + name: "ids" + description: "vector of id extracted from the SparseFeatures proto." + type: DT_INT64 +} +output_arg { + name: "weights" + description: "vector of weight extracted from the SparseFeatures proto." + type: DT_FLOAT +} +summary: "Converts a vector of strings with dist_belief::SparseFeatures to tensors." +description: "Note that indices, ids and weights are vectors of the same size and have\none-to-one correspondence between their elements. ids and weights are each\nobtained by sequentially concatenating sf[i].id and sf[i].weight, for i in\n1...size(sf). Note that if sf[i].weight is not set, the default value for the\nweight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were\nextracted from sf[i], then index[j] is set to i." +)proto"); +} + +TEST_F(OpDefBuilderTest, DocConcat) { + ExpectOrdered(b().Input("concat_dim: int32") + .Input("values: num_values * dtype") + .Output("output: dtype") + .Attr("dtype: type") + .Attr("num_values: int >= 2") + .Doc(R"doc( +Concatenate N Tensors along one dimension. + +concat_dim: The (scalar) dimension along which to concatenate. Must be + in the range [0, rank(values...)). +values: The N Tensors to concatenate. Their ranks and types must match, + and their sizes must match in all dimensions except concat_dim. +output: A Tensor with the concatenation of values stacked along the + concat_dim dimension. This Tensor's shape matches the Tensors in + values, except in concat_dim where it has the sum of the sizes. +)doc"), + R"proto( +input_arg { + name: "concat_dim" + description: "The (scalar) dimension along which to concatenate. Must be\nin the range [0, rank(values...))." + type: DT_INT32 +} +input_arg { + name: "values" + description: "The N Tensors to concatenate. Their ranks and types must match,\nand their sizes must match in all dimensions except concat_dim." + type_attr: "dtype" + number_attr: "num_values" +} +output_arg { + name: "output" + description: "A Tensor with the concatenation of values stacked along the\nconcat_dim dimension. This Tensor\'s shape matches the Tensors in\nvalues, except in concat_dim where it has the sum of the sizes." 
+ type_attr: "dtype" +} +summary: "Concatenate N Tensors along one dimension." +attr { + name: "dtype" + type: "type" +} +attr { + name: "num_values" + type: "int" + has_minimum: true + minimum: 2 +} +)proto"); +} + +TEST_F(OpDefBuilderTest, DocAttr) { + ExpectOrdered(b().Attr("i: int").Doc(R"doc( +Summary + +i: How much to operate. +)doc"), + R"proto( +summary: "Summary" +attr { + name: "i" + type: "int" + description: "How much to operate." +} +)proto"); +} + +TEST_F(OpDefBuilderTest, DocCalledTwiceFailure) { + ExpectFailure(b().Doc("What's").Doc("up, doc?"), + "Extra call to Doc() for Op Test"); +} + +TEST_F(OpDefBuilderTest, DocFailureMissingName) { + ExpectFailure( + b().Input("a: int32").Doc(R"doc( +Summary + +a: Something for a. +b: b is not defined. +)doc"), + "No matching input/output/attr for name 'b' from Doc() for Op Test"); + + ExpectFailure( + b().Input("a: int32").Doc(R"doc( +Summary + +b: b is not defined and by itself. +)doc"), + "No matching input/output/attr for name 'b' from Doc() for Op Test"); +} + +TEST_F(OpDefBuilderTest, DefaultMinimum) { + ExpectSuccess(b().Input("values: num_values * dtype") + .Output("output: anything") + .Attr("anything: list(type)") + .Attr("dtype: type") + .Attr("num_values: int"), + R"proto( +input_arg { + name: "values" + type_attr: "dtype" + number_attr: "num_values" +} +output_arg { + name: "output" + type_list_attr: "anything" +} +attr { + name: "anything" + type: "list(type)" + has_minimum: true + minimum: 1 +} +attr { + name: "dtype" + type: "type" +} +attr { + name: "num_values" + type: "int" + has_minimum: true + minimum: 1 +} +)proto"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc new file mode 100644 index 0000000000..e3aef011de --- /dev/null +++ b/tensorflow/core/framework/op_def_util.cc @@ -0,0 +1,344 @@ +#include "tensorflow/core/framework/op_def_util.h" + +#include <set> +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { +namespace { // ------ Helper functions ------ + +bool HasAttrStyleType(const OpDef::ArgDef& arg) { + return arg.type() != DT_INVALID || !arg.type_attr().empty() || + !arg.type_list_attr().empty(); +} + +Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { + const AttrValue& allowed_values(attr.allowed_values()); + for (auto allowed : allowed_values.list().type()) { + if (dt == allowed) { + return Status::OK(); + } + } + string allowed_str; + for (int i = 0; i < allowed_values.list().type_size(); ++i) { + if (!allowed_str.empty()) { + strings::StrAppend(&allowed_str, ", "); + } + strings::StrAppend(&allowed_str, + DataTypeString(allowed_values.list().type(i))); + } + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of ", DataTypeString(dt), + " is not in the list of allowed values: ", allowed_str); +} + +Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { + const AttrValue& allowed_values(attr.allowed_values()); + for (auto allowed : allowed_values.list().s()) { + if (str == allowed) { + return Status::OK(); + } + } + string allowed_str; + for (const string& allowed : allowed_values.list().s()) { + if 
(!allowed_str.empty()) { + strings::StrAppend(&allowed_str, ", "); + } + strings::StrAppend(&allowed_str, "\"", allowed, "\""); + } + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of \"", str, + "\" is not in the list of allowed values: ", allowed_str); +} + +} // namespace + +// Requires: attr has already been validated. +Status ValidateAttrValue(const AttrValue& attr_value, + const OpDef::AttrDef& attr) { + // Is it a valid value? + TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()), + " for attr '", attr.name(), "'"); + + // Does the value satisfy the minimum constraint in the AttrDef? + if (attr.has_minimum()) { + if (attr.type() == "int") { + if (attr_value.i() < attr.minimum()) { + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of ", attr_value.i(), + " must be at least minimum ", attr.minimum()); + } + } else { + int length = -1; + if (attr.type() == "list(string)") { + length = attr_value.list().s_size(); + } else if (attr.type() == "list(int)") { + length = attr_value.list().i_size(); + } else if (attr.type() == "list(float)") { + length = attr_value.list().f_size(); + } else if (attr.type() == "list(bool)") { + length = attr_value.list().b_size(); + } else if (attr.type() == "list(type)") { + length = attr_value.list().type_size(); + } else if (attr.type() == "list(shape)") { + length = attr_value.list().shape_size(); + } else if (attr.type() == "list(tensor)") { + length = attr_value.list().tensor_size(); + } + if (length < attr.minimum()) { + return errors::InvalidArgument( + "Length for attr '", attr.name(), "' of ", length, + " must be at least minimum ", attr.minimum()); + } + } + } + + // Does the value satisfy the allowed_value constraint in the AttrDef? + if (attr.has_allowed_values()) { + if (attr.type() == "type") { + TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr)); + } else if (attr.type() == "list(type)") { + for (int dt : attr_value.list().type()) { + TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr)); + } + } else if (attr.type() == "string") { + TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr)); + } else if (attr.type() == "list(string)") { + for (const string& str : attr_value.list().s()) { + TF_RETURN_IF_ERROR(AllowedStringValue(str, attr)); + } + } else { + return errors::Unimplemented( + "Support for allowed_values not implemented for type ", attr.type()); + } + } + return Status::OK(); +} + +const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) { + for (int i = 0; i < op_def.attr_size(); ++i) { + if (op_def.attr(i).name() == name) { + return &op_def.attr(i); + } + } + return nullptr; +} + +OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) { + for (int i = 0; i < op_def->attr_size(); ++i) { + if (op_def->attr(i).name() == name) { + return op_def->mutable_attr(i); + } + } + return nullptr; +} + +#define VALIDATE(EXPR, ...) \ + do { \ + if (!(EXPR)) { \ + return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ", \ + op_def.ShortDebugString()); \ + } \ + } while (false) + +static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, + bool output, std::set<string>* names) { + const string suffix = strings::StrCat( + output ? 
" for output '" : " for input '", arg.name(), "'"); + VALIDATE(gtl::InsertIfNotPresent(names, arg.name()), "Duplicate name: ", + arg.name()); + VALIDATE(HasAttrStyleType(arg), "Missing type", suffix); + + if (!arg.number_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'", + suffix); + VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length", + suffix, " has type ", attr->type(), " != int"); + VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length", + suffix, " must have minimum"); + VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length", + suffix, " must have minimum >= 0"); + VALIDATE(arg.type_list_attr().empty(), + "Can't have both number_attr and type_list_attr", suffix); + VALIDATE((arg.type() != DT_INVALID ? 1 : 0) + + (!arg.type_attr().empty() ? 1 : 0) == + 1, + "Exactly one of type, type_attr must be set", suffix); + } else { + const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) + + (!arg.type_attr().empty() ? 1 : 0) + + (!arg.type_list_attr().empty() ? 1 : 0); + VALIDATE(num_type_fields == 1, + "Exactly one of type, type_attr, type_list_attr must be set", + suffix); + } + + if (!arg.type_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'", + suffix); + VALIDATE(attr->type() == "type", "Attr '", attr->name(), + "' used as type_attr", suffix, " has type ", attr->type(), + " != type"); + } else if (!arg.type_list_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'", + suffix); + VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(), + "' used as type_list_attr", suffix, " has type ", attr->type(), + " != list(type)"); + } else { + // All argument types should be non-reference types at this point. + // ArgDef.is_ref is set to true for reference arguments. + VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '", + DataTypeString(arg.type()), "'. 
Use 'Ref(type)' instead", suffix); + } + + return Status::OK(); +} + +Status ValidateOpDef(const OpDef& op_def) { + VALIDATE(RE2::FullMatch(op_def.name(), "(?:_.*|[A-Z][a-zA-Z0-9]*)"), + "Invalid name: ", op_def.name(), " (Did you use CamelCase?)"); + + std::set<string> names; // for detecting duplicate names + for (const auto& attr : op_def.attr()) { + // Validate name + VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()), "Duplicate name: ", + attr.name()); + DataType dt; + VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ", + attr.name(), " that matches a data type"); + + // Validate type + StringPiece type(attr.type()); + bool is_list = type.Consume("list("); + bool found = false; + for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape", + "tensor", "func"}) { + if (type.Consume(valid)) { + found = true; + break; + } + } + VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(), + "'"); + if (is_list) { + VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ", + attr.name(), "'s type ", attr.type()); + } + VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ", + attr.name(), "'s type ", attr.type()); + + // Validate minimum + if (attr.has_minimum()) { + VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(), + "' has minimum for unsupported type ", attr.type()); + if (is_list) { + VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(), + "' with list type must have a non-negative minimum, not ", + attr.minimum()); + } + } else { + VALIDATE(attr.minimum() == 0, "Attr '", attr.name(), + "' with has_minimum = false but minimum ", attr.minimum(), + " not equal to default of 0"); + } + + // Validate allowed_values + if (attr.has_allowed_values()) { + const string list_type = + is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")"); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + AttrValueHasType(attr.allowed_values(), list_type), " for attr '", + attr.name(), "' in Op '", op_def.name(), "'"); + } + + // Validate default_value (after we have validated the rest of the attr, + // so we can use ValidateAttrValue()). 
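    // (Illustrative note, not part of this commit: this check is what
    // rejects an OpDef whose default violates its own constraints, e.g.
    //   attr { name: 'a' type: 'int' has_minimum: true minimum: 7
    //          default_value { i: 2 } }
    // fails with "Value for attr 'a' of 2 must be at least minimum 7".)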
+ if (attr.has_default_value()) { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ValidateAttrValue(attr.default_value(), attr), " in Op '", + op_def.name(), "'"); + } + } + + for (const auto& arg : op_def.input_arg()) { + TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names)); + } + + for (const auto& arg : op_def.output_arg()) { + TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names)); + } + + return Status::OK(); +} + +#undef VALIDATE + +namespace { + +string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { + string ret; + for (const OpDef::ArgDef& arg : args) { + if (!ret.empty()) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, arg.name(), ":"); + if (arg.is_ref()) strings::StrAppend(&ret, "Ref("); + if (!arg.number_attr().empty()) { + strings::StrAppend(&ret, arg.number_attr(), "*"); + } + if (arg.type() != DT_INVALID) { + strings::StrAppend(&ret, DataTypeString(arg.type())); + } else { + strings::StrAppend(&ret, arg.type_attr()); + } + if (arg.is_ref()) strings::StrAppend(&ret, ")"); + } + return ret; +} + +} // namespace + +string SummarizeOpDef(const OpDef& op_def) { + string ret = strings::StrCat("Op<name=", op_def.name()); + strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()), + " -> ", SummarizeArgs(op_def.output_arg())); + for (int i = 0; i < op_def.attr_size(); ++i) { + strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":", + op_def.attr(i).type()); + if (op_def.attr(i).has_default_value()) { + strings::StrAppend(&ret, ",default=", + SummarizeAttrValue(op_def.attr(i).default_value())); + } + if (op_def.attr(i).has_minimum()) { + strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum()); + } + if (op_def.attr(i).has_allowed_values()) { + strings::StrAppend(&ret, ",allowed=", + SummarizeAttrValue(op_def.attr(i).allowed_values())); + } + } + if (op_def.is_commutative()) { + strings::StrAppend(&ret, "; is_commutative=true"); + } + if (op_def.is_aggregate()) { + strings::StrAppend(&ret, "; is_aggregate=true"); + } + if (op_def.is_stateful()) { + strings::StrAppend(&ret, "; is_stateful=true"); + } + if (op_def.allows_uninitialized_input()) { + strings::StrAppend(&ret, "; allows_uninitialized_input=true"); + } + strings::StrAppend(&ret, ">"); + return ret; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h new file mode 100644 index 0000000000..a9fecf3fa0 --- /dev/null +++ b/tensorflow/core/framework/op_def_util.h @@ -0,0 +1,32 @@ +// TODO(josh11b): Probably not needed for OpKernel authors, so doesn't +// need to be as publicly accessible as other files in framework/. + +#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_ + +#include <string> +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// Performs a consistency check across the fields of the op_def. +Status ValidateOpDef(const OpDef& op_def); + +// Validates that attr_value satisfies the type and constraints from attr. +// REQUIRES: attr has already been validated. +Status ValidateAttrValue(const AttrValue& attr_value, + const OpDef::AttrDef& attr); + +// The following search through op_def for an attr with the indicated name. +// Returns nullptr if no such attr is found. 
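// Usage sketch (illustrative only):
//   const OpDef::AttrDef* attr = FindAttr("dtype", op_def);
//   if (attr == nullptr) { /* op has no attr named "dtype" */ }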
+const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def); +OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def); + +// Produce a human-readable version of an op_def that is more concise +// than a text-format proto. Excludes descriptions. +string SummarizeOpDef(const OpDef& op_def); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_ diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc new file mode 100644 index 0000000000..515e8bb288 --- /dev/null +++ b/tensorflow/core/framework/op_def_util_test.cc @@ -0,0 +1,330 @@ +#include "tensorflow/core/framework/op_def_util.h" + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +OpDef FromText(const string& text) { + OpDef op_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &op_def)); + return op_def; +} + +class ValidateOpDefTest : public ::testing::Test { + protected: + Status TestProto(const string& text) { + return ValidateOpDef(FromText(text)); + } + + Status TestBuilder(const OpDefBuilder& builder) { + OpDef op_def; + Status status = builder.Finalize(&op_def); + EXPECT_OK(status); + if (!status.ok()) { + return status; + } else { + return ValidateOpDef(op_def); + } + } + + void ExpectFailure(const Status& status, const string& message) { + EXPECT_FALSE(status.ok()) << "Did not see error with: " << message; + if (!status.ok()) { + LOG(INFO) << "message: " << status; + EXPECT_TRUE(StringPiece(status.ToString()).contains(message)) + << "Actual: " << status << "\nExpected to contain: " << message; + } + } +}; + +TEST_F(ValidateOpDefTest, OpDefValid) { + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Input("a: int32"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Output("a: bool"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("t: type").Input("a: t"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int = 3"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5 = 3"))); + EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: numbertype"))); + EXPECT_OK(TestBuilder(OpDefBuilder("Uppercase"))); +} + +TEST_F(ValidateOpDefTest, InvalidName) { + ExpectFailure(TestBuilder(OpDefBuilder("lower").Attr("a: int")), + "Invalid name"); + ExpectFailure(TestBuilder(OpDefBuilder("BadSuffix 7%")), "Invalid name"); +} + +TEST_F(ValidateOpDefTest, DuplicateName) { + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Input("a: int32").Input("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder( + OpDefBuilder("DupeName").Input("a: int32").Output("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder( + OpDefBuilder("DupeName").Output("a: int32").Output("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Input("a: int32").Attr("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Output("a: int32").Attr("a: float")), + "Duplicate name: a"); + ExpectFailure( + TestBuilder(OpDefBuilder("DupeName").Attr("a: int").Attr("a: 
float")), + "Duplicate name: a"); +} + +TEST_F(ValidateOpDefTest, BadAttrName) { + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude").Attr("int32: int")), + "Attr can't have name int32 that matches a data type"); + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude").Attr("float: string")), + "Attr can't have name float that matches a data type"); +} + +TEST_F(ValidateOpDefTest, BadAttrType) { + ExpectFailure( + TestProto("name: 'BadAttrType' attr { name: 'a' type: 'illegal' }"), + "Unrecognized type"); + ExpectFailure( + TestProto("name: 'BadAttrType' attr { name: 'a' type: 'list(illegal)' }"), + "Unrecognized type"); + ExpectFailure( + TestProto("name: 'BadAttrType' attr { name: 'a' type: 'int extra' }"), + "Extra ' extra' at the end"); + ExpectFailure( + TestProto( + "name: 'BadAttrType' attr { name: 'a' type: 'list(int extra)' }"), + "'list(' is missing ')' in attr"); + ExpectFailure( + TestProto( + "name: 'BadAttrType' attr { name: 'a' type: 'list(int) extra' }"), + "Extra ' extra' at the end"); +} + +TEST_F(ValidateOpDefTest, BadAttrDefault) { + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'int' default_value { s: 'x' } }"), + "AttrValue had value with type string when int expected\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'int' default_value { f: 0.5 } }"), + "AttrValue had value with type float when int expected\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'int' " + "default_value { i: 5 list { i: [2] } } }"), + "AttrValue had value with type list(int) when int expected\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'list(int)' default_value { f: 0.5 } }"), + "AttrValue had value with type float when list(int) expected\n\t " + "for attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure( + TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'list(int)' " + "default_value { list { i: [5] f: [0.5] } } }"), + "AttrValue had value with type list(float) when list(int) " + "expected\n\t for attr 'a'\n\t in Op 'BadAttrDef'"); + + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'type' default_value { } }"), + "AttrValue missing value with expected type type\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'shape' default_value { } }"), + "AttrValue missing value with expected type shape\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'tensor' default_value { } }"), + "AttrValue missing value with expected type tensor\n\t for " + "attr 'a'\n\t in Op 'BadAttrDef'"); + + // default_value {} is indistinguishable from default_value{ list{} } (one + // with an empty list) in proto3 semantics. 
+ EXPECT_OK( + TestProto("name: 'GoodAttrDef' attr { name: 'a' " + "type: 'list(int)' default_value { } }")); + + // Empty lists are allowed: + EXPECT_OK( + TestProto("name: 'GoodAttrDef' attr { name: 'a' " + "type: 'list(int)' default_value { list { } } }")); + // Builder should make the same proto: + EXPECT_OK(TestBuilder(OpDefBuilder("GoodAttrDef").Attr("a: list(int) = []"))); + + // Unless there is a minimum length specified: + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' " + "type: 'list(int)' has_minimum: true minimum: 2 " + "default_value { list { } } }"), + "Length for attr 'a' of 0 must be at least minimum 2\n\t in Op " + "'BadAttrDef'"); + ExpectFailure( + TestBuilder(OpDefBuilder("GoodAttrDef").Attr("a: list(bool) >=2 = []")), + "Length for attr 'a' of 0 must be at least minimum 2\n\t in Op " + "'GoodAttrDef'"); + ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' type: " + "'list(string)' has_minimum: true minimum: 2 " + "default_value { list { s: ['foo'] } } }"), + "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op " + "'BadAttrDef'"); + ExpectFailure(TestBuilder(OpDefBuilder("GoodAttrDef") + .Attr("a: list(type) >=2 = [DT_STRING]")), + "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op " + "'GoodAttrDef'"); +} + +TEST_F(ValidateOpDefTest, NoRefTypes) { + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef").Input("i: float_ref")), + "Illegal use of ref type 'float_ref'. " + "Use 'Ref(type)' instead for input 'i'"); + ExpectFailure( + TestBuilder(OpDefBuilder("BadAttrDef").Attr("T: type = DT_INT32_REF")), + "AttrValue must not have reference type value of int32_ref"); + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef") + .Attr("T: list(type) = [DT_STRING_REF]")), + "AttrValue must not have reference type value of string_ref"); +} + +TEST_F(ValidateOpDefTest, BadAttrMin) { + ExpectFailure(TestProto("name: 'BadAttrMin' attr { name: 'a' type: 'string' " + "has_minimum: true minimum: 0 }"), + "minimum for unsupported type string"); + ExpectFailure( + TestProto("name: 'BadAttrMin' attr { name: 'a' type: 'int' default_value " + "{ i: 2 } has_minimum: true minimum: 7 }"), + "Value for attr 'a' of 2 must be at least minimum 7\n\t in Op " + "'BadAttrMin'"); + ExpectFailure( + TestProto("name: 'BadAttrMin' attr { name: 'a' " + "type: 'list(string)' has_minimum: true minimum: -5 }"), + "list type must have a non-negative minimum, not -5"); + EXPECT_OK( + TestProto("name: 'GoodAttrMin' attr { name: 'a' type: 'list(string)' " + "has_minimum: true minimum: 1 }")); + ExpectFailure(TestProto("name: 'NoHasMin' attr { name: 'a' " + "type: 'list(string)' minimum: 3 }"), + "Attr 'a' with has_minimum = false but minimum 3 not equal to " + "default of 0"); +} + +TEST_F(ValidateOpDefTest, BadAttrAllowed) { + // Is in list of allowed types. + EXPECT_OK(TestBuilder( + OpDefBuilder("GoodAttrtude").Attr("x: numbertype = DT_INT32"))); + // Not in list of allowed types. + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: numbertype = DT_STRING")), + "attr 'x' of string is not in the list of allowed values"); + ExpectFailure( + TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: list(realnumbertype) = [DT_COMPLEX64]")), + "attr 'x' of complex64 is not in the list of allowed values"); + // Is in list of allowed strings. + EXPECT_OK(TestBuilder( + OpDefBuilder("GoodAttrtude").Attr("x: {'foo', 'bar'} = 'bar'"))); + // Not in list of allowed strings. 
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: {'foo', 'bar'} = 'baz'")), + "attr 'x' of \"baz\" is not in the list of allowed values"); + ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude") + .Attr("x: list({'foo', 'bar'}) = ['baz']")), + "attr 'x' of \"baz\" is not in the list of allowed values"); + ExpectFailure(TestProto( + "name: 'BadAttrtude' attr { name: 'a' " + "type: 'string' allowed_values { s: 'not list' } }"), + "with type string when list(string) expected"); + ExpectFailure( + TestProto("name: 'BadAttrtude' attr { name: 'a' " + "type: 'string' allowed_values { list { i: [6] } } }"), + "with type list(int) when list(string) expected"); +} + +TEST_F(ValidateOpDefTest, BadArgType) { + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type: DT_INT32 } input_arg { name: 'b' }"), + "Missing type for input 'b'"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type: DT_INT32 } output_arg { name: 'b' }"), + "Missing type for output 'b'"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' type: " + "DT_INT32 type_attr: 'x' } attr { name: 'x' type: 'type' }"), + "Exactly one of type, type_attr, type_list_attr must be set for input " + "'a'"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_attr: 'x' } attr { name: 'x' type: 'int' }"), + "Attr 'x' used as type_attr for input 'a' has type int"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_attr: 'x' } attr { name: 'x' type: 'list(type)' }"), + "Attr 'x' used as type_attr for input 'a' has type list(type)"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_list_attr: 'x' } attr { name: 'x' type: 'int' }"), + "Attr 'x' used as type_list_attr for input 'a' has type int"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_list_attr: 'x' } attr { name: 'x' type: 'type' }"), + "Attr 'x' used as type_list_attr for input 'a' has type type"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' " + "type_attr: 'x' }"), + "No attr with name 'x' for input 'a'"); + ExpectFailure( + TestProto("name: 'BadArg' input_arg { name: 'a' number_attr: 'n' " + "type_attr: 'x' } attr { name: 'x' type: 'list(type)' } " + "attr { name: 'n' type: 'int' has_minimum: true minimum: 1 }"), + "Attr 'x' used as type_attr for input 'a' has type list(type)"); + // But list(type) is fine as the type of an arg without a number_attr: + EXPECT_OK(TestProto( + "name: 'Arg' input_arg { name: 'a' type_list_attr: 'x' } " + "attr { name: 'x' type: 'list(type)' } attr { name: 'n' type: 'int' " + "has_minimum: true minimum: 1 }")); + + // number_attr + EXPECT_OK(TestProto( + "name: 'Arg' input_arg { name: 'a' type: DT_INT32 number_attr: 'n' } " + "attr { name: 'n' type: 'int' has_minimum: true minimum: 0 }")); + + ExpectFailure(TestProto("name: 'Arg' input_arg { name: 'a' type: DT_INT32 " + "number_attr: 'n' }"), + "No attr with name 'n'"); + ExpectFailure( + TestProto( + "name: 'Arg' input_arg { name: 'a' type: " + "DT_INT32 number_attr: 'n' } attr { name: 'n' type: 'string' }"), + "Attr 'n' used as length for input 'a' has type string"); + ExpectFailure( + TestProto("name: 'Arg' input_arg { name: 'a' type: " + "DT_INT32 number_attr: 'n' } attr { name: 'n' type: 'int' }"), + "Attr 'n' used as length for input 'a' must have minimum;"); + ExpectFailure( + TestProto("name: 'Arg' input_arg { name: 'a' type: DT_INT32 number_attr: " + "'n' } attr { name: 'n' type: 'int' has_minimum: true 
minimum: " + "-5 }"), + "Attr 'n' used as length for input 'a' must have minimum >= 0;"); + ExpectFailure( + TestProto("name: 'Arg' input_arg { name: 'a' number_attr: 'n' } attr { " + "name: 'n' type: 'int' has_minimum: true minimum: 2 }"), + "Missing type for input 'a'; in OpDef:"); + ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' number_attr: " + "'n' type_list_attr: 'x' } attr { name: 'n' type: " + "'int' has_minimum: true minimum: 1 } attr { name: " + "'x' type: 'list(type)' }"), + "Can't have both number_attr and type_list_attr for input 'a'"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc new file mode 100644 index 0000000000..04f4b7cacd --- /dev/null +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -0,0 +1,55 @@ +#include "tensorflow/core/framework/op_gen_lib.h" + +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +string WordWrap(StringPiece prefix, StringPiece str, int width) { + const string indent_next_line = "\n" + Spaces(prefix.size()); + width -= prefix.size(); + string result; + strings::StrAppend(&result, prefix); + + while (!str.empty()) { + if (static_cast<int>(str.size()) <= width) { + // Remaining text fits on one line. + strings::StrAppend(&result, str); + break; + } + auto space = str.rfind(' ', width); + if (space == StringPiece::npos) { + // Rather make a too-long line and break at a space. + space = str.find(' '); + if (space == StringPiece::npos) { + strings::StrAppend(&result, str); + break; + } + } + // Breaking at character at position <space>. + StringPiece to_append = str.substr(0, space); + str.remove_prefix(space + 1); + // Remove spaces at break. + while (to_append.ends_with(" ")) { + to_append.remove_suffix(1); + } + while (str.Consume(" ")) { + } + + // Go on to the next line. + strings::StrAppend(&result, to_append); + if (!str.empty()) strings::StrAppend(&result, indent_next_line); + } + + return result; +} + +bool ConsumeEquals(StringPiece* description) { + if (description->Consume("=")) { + while (description->Consume(" ")) { // Also remove spaces after "=". + } + return true; + } + return false; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h new file mode 100644 index 0000000000..9890f1bcec --- /dev/null +++ b/tensorflow/core/framework/op_gen_lib.h @@ -0,0 +1,24 @@ +#ifndef TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_ +#define TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_ + +#include <string> +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +inline string Spaces(int n) { return string(n, ' '); } + +// Wrap prefix + str to be at most width characters, indenting every line +// after the first by prefix.size() spaces. Intended use case is something +// like prefix = " Foo(" and str is a list of arguments (terminated by a ")"). +// TODO(josh11b): Option to wrap on ", " instead of " " when possible. +string WordWrap(StringPiece prefix, StringPiece str, int width); + +// Looks for an "=" at the beginning of *description. If found, strips it off +// (and any following spaces) from *description and return true. Otherwise +// returns false. 
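// Example (illustrative):
//   StringPiece s("= 42");
//   if (ConsumeEquals(&s)) { /* s is now "42" */ }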
+bool ConsumeEquals(StringPiece* description); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_ diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc new file mode 100644 index 0000000000..eb83d393f0 --- /dev/null +++ b/tensorflow/core/framework/op_kernel.cc @@ -0,0 +1,749 @@ +#include "tensorflow/core/framework/op_kernel.h" + +#include <unordered_map> + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +namespace { + +Status MatchSignatureHelper(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs, + const DataTypeSlice inputs, + const DataTypeSlice outputs) { + bool signature_mismatch = false; + + if (inputs.size() != expected_inputs.size()) signature_mismatch = true; + for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) { + if (!TypesCompatible(expected_inputs[i], inputs[i])) { + signature_mismatch = true; + } + } + + if (outputs.size() != expected_outputs.size()) signature_mismatch = true; + for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) { + if (!TypesCompatible(expected_outputs[i], outputs[i])) { + signature_mismatch = true; + } + } + + if (signature_mismatch) { + return errors::InvalidArgument("Signature mismatch, have: ", + DataTypeSliceString(inputs), "->", + DataTypeSliceString(outputs), " expected: ", + DataTypeSliceString(expected_inputs), "->", + DataTypeSliceString(expected_outputs)); + } + return Status::OK(); +} + +// Check HostMemory backward compatibility. +bool CheckHostMemoryCompatibility(const DeviceType device_type, + const OpKernel* kernel) { + if (device_type == DEVICE_GPU) { + for (int i = 0; i < kernel->num_inputs(); ++i) { + if (kernel->input_type(i) == DT_INT32 && + kernel->input_memory_types()[i] != HOST_MEMORY) { + return false; + } + } + for (int i = 0; i < kernel->num_outputs(); ++i) { + if (kernel->output_type(i) == DT_INT32 && + kernel->output_memory_types()[i] != HOST_MEMORY) { + return false; + } + } + } + return true; +} + +} // namespace + +// OpKernel ------------------------------------------------------------------ + +OpKernel::OpKernel(OpKernelConstruction* context) + : def_(context->def()), + input_types_(context->input_types().begin(), + context->input_types().end()), + output_types_(context->output_types().begin(), + context->output_types().end()), + input_name_map_(context->num_inputs()), + output_name_map_(context->num_outputs()) { + OP_REQUIRES_OK(context, + NameRangesForNode(def_, context->op_def(), &input_name_map_, + &output_name_map_)); + + // By default, the input and output memory types are always in device memory, + // but can be overridden by individual implementations of OpKernels in their + // constructor. 
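  // (Illustrative sketch: a registration can pin named args to host memory,
  // e.g.
  //   REGISTER_KERNEL_BUILDER(Name("SomeOp")
  //                               .Device(DEVICE_GPU)
  //                               .HostMemory("output"),
  //                           SomeOpKernel);
  // where "SomeOp"/SomeOpKernel are placeholders; MemoryTypesForNode below
  // then reports HOST_MEMORY for the pinned args.)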
+ input_memory_types_ = MemoryTypeVector(input_types_.size(), DEVICE_MEMORY); + output_memory_types_ = MemoryTypeVector(output_types_.size(), DEVICE_MEMORY); + // TODO(yuanbyu): For now we assume the memory types of function + // inputs/outputs to be DEVICE_MEMORY. + auto lib = context->function_library(); + if (lib == nullptr || !lib->IsDefined(def_.op())) { + OP_REQUIRES_OK(context, MemoryTypesForNode( + context->device_type(), def_, context->op_def(), + input_name_map_, output_name_map_, + &input_memory_types_, &output_memory_types_)); + // Log all the uses of int32 on GPU. + // TODO(yunabyu): Remove once everyone transitions to HostMemory. + if (VLOG_IS_ON(2)) { + if (!CheckHostMemoryCompatibility(context->device_type(), this)) { + VLOG(2) << "Using int32 on GPU at node: " << SummarizeNodeDef(def()); + } + } + } +} + +Status OpKernel::InputRange(const string& input_name, int* start, + int* stop) const { + const auto result = input_name_map_.find(input_name); + if (result == input_name_map_.end()) { + return errors::InvalidArgument("Unknown input name: ", input_name); + } else { + *start = result->second.first; + *stop = result->second.second; + return Status::OK(); + } +} + +Status OpKernel::OutputRange(const string& output_name, int* start, + int* stop) const { + const auto result = output_name_map_.find(output_name); + if (result == output_name_map_.end()) { + return errors::InvalidArgument("Unknown output name: ", output_name); + } else { + *start = result->second.first; + *stop = result->second.second; + return Status::OK(); + } +} + +void AsyncOpKernel::Compute(OpKernelContext* context) { + Notification n; + ComputeAsync(context, [&n]() { n.Notify(); }); + n.WaitForNotification(); +} + +// PersistentTensor ---------------------------------------------------------- + +Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) { + // the caller has to have a valid context + CHECK(context); + return &tensor_; +} + +Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) { + context->NotifyUseOfPersistentTensor(tensor_); + return &tensor_; +} + +// OpKernelConstruction ------------------------------------------------------ + +Status OpKernelConstruction::MatchSignature( + const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) { + return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_, + output_types_); +} + +Status OpKernelConstruction::allocate_temp(DataType type, + const TensorShape& shape, + Tensor* out_temp) { + Tensor new_temp(allocator_, type, shape); + + if (!new_temp.IsInitialized() && shape.num_elements() > 0) { + return errors::ResourceExhausted( + "OOM when allocating temporary tensor with shape", shape.DebugString()); + } + *out_temp = new_temp; + return Status::OK(); +} + +Status OpKernelConstruction::allocate_persistent( + DataType type, const TensorShape& shape, PersistentTensor* out_persistent, + Tensor** out_tensor) { + // for now just do the same thing as allocate_temp + // TODO(misard) add specific memory tracking for persistent tensors + Tensor persistent; + Status s = allocate_temp(type, shape, &persistent); + if (!s.ok()) { + return s; + } + *out_persistent = PersistentTensor(persistent); + Tensor* allocated = out_persistent->AccessTensor(this); + if (out_tensor) { + *out_tensor = allocated; + } + return s; +} + +// OpKernelContext ----------------------------------------------------------- + +OpKernelContext::OpKernelContext(const Params& params) + : params_(params), + 
outputs_(params.op_kernel->output_types().size()), + output_allocation_types_(params.op_kernel->output_types().size()) { + Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); + eigen_gpu_device_ = params_.device->MakeGpuDevice(params_.op_device_context, + eigen_gpu_allocator); +} + +OpKernelContext::~OpKernelContext() { + for (TensorValue& value : outputs_) { + if (!value.is_ref()) { + delete value.tensor; + } + } + for (Tensor* t : temp_tensors_) delete t; + delete eigen_gpu_device_; +} + +Status OpKernelContext::input(const string& name, const Tensor** tensor) const { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + if ((*params_.inputs)[start].is_ref()) { + return errors::InvalidArgument("OpKernel used ref input name '", name, + "' when immutable input was expected"); + } + *tensor = (*params_.inputs)[start].tensor; + return Status::OK(); +} + +Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + *out_mutex = input_ref_mutex(start); + return Status::OK(); +} + +Status OpKernelContext::mutable_input(const string& name, Tensor* tensor, + bool lock_held) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + if (!(*params_.inputs)[start].is_ref()) { + return errors::InvalidArgument("OpKernel used immutable input name '", name, + "' when ref input was expected"); + } + // return a copy of the Ref acquired while holding the mutex + if (lock_held) { + *tensor = *(*params_.inputs)[start].tensor; + } else { + mutex_lock l(*input_ref_mutex(start)); + *tensor = *(*params_.inputs)[start].tensor; + } + return Status::OK(); +} + +Status OpKernelContext::replace_ref_input(const string& name, + const Tensor& tensor, + bool lock_held) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was expected"); + } + if (!(*params_.inputs)[start].is_ref()) { + return errors::InvalidArgument("OpKernel used immutable input name '", name, + "' when ref input was expected"); + } + replace_ref_input(start, tensor, lock_held); + return Status::OK(); +} + +Status OpKernelContext::input_list(const string& name, + OpInputList* list) const { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + *list = OpInputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::mutable_input_list(const string& name, + OpMutableInputList* list) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop)); + *list = OpMutableInputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::output_list(const string& name, OpOutputList* list) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + *list = 
OpOutputList(this, start, stop); + return Status::OK(); +} + +Status OpKernelContext::allocate_output(const string& name, + const TensorShape& shape, + Tensor** tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + return allocate_output(start, shape, tensor); +} + +Status OpKernelContext::allocate_output(const string& name, + const TensorShape& shape, + Tensor** tensor, + AllocatorAttributes attr) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + return allocate_output(start, shape, tensor, attr); +} + +Status OpKernelContext::set_output(const string& name, const Tensor& tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + set_output(start, tensor); + return Status::OK(); +} + +Status OpKernelContext::set_output_ref(const string& name, mutex* mu, + Tensor* tensor_for_ref) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + set_output_ref(start, mu, tensor_for_ref); + return Status::OK(); +} + +Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + *tensor = mutable_output(start); + return Status::OK(); +} + +Status OpKernelContext::release_output(const string& name, TensorValue* value) { + int start, stop; + TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued output name '", + name, + "' when single-valued output was " + "expected"); + } + *value = release_output(start); + return Status::OK(); +} + +bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { + const auto& inputs = *params_.inputs; + for (size_t i = 1; i < inputs.size(); ++i) { + if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) { + SetStatus(errors::InvalidArgument( + "Inputs to operation ", op->name(), " of type ", op->type_string(), + " must have the same size and shape. Input 0: ", + inputs[0]->shape().DebugString(), " != input ", i, ": ", + inputs[i]->shape().DebugString())); + return false; + } + } + return true; +} + +Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs) { + DataTypeVector inputs; + for (const TensorValue& t : *params_.inputs) { + inputs.push_back(t.is_ref() ? 
MakeRefType(t->dtype()) : t->dtype()); + } + DataTypeVector outputs = params_.op_kernel->output_types(); + return MatchSignatureHelper(expected_inputs, expected_outputs, inputs, + outputs); +} + +// OpKernel registration ------------------------------------------------------ + +struct KernelRegistration { + KernelRegistration(const KernelDef& d, + kernel_factory::OpKernelRegistrar::Factory f) + : def(d), factory(f) {} + const KernelDef def; + const kernel_factory::OpKernelRegistrar::Factory factory; +}; + +// This maps from 'op_type' + DeviceType to the set of KernelDefs and +// factory functions for instantiating the OpKernel that matches the +// KernelDef. +typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry; + +static KernelRegistry* GlobalKernelRegistry() { + static KernelRegistry* global_kernel_registry = new KernelRegistry; + return global_kernel_registry; +} + +static string Key(const string& op_type, DeviceType device_type, + const string& label) { + return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":", + label); +} + +namespace kernel_factory { + +OpKernelRegistrar::OpKernelRegistrar(const KernelDef* kernel_def, + Factory factory) { + const string key = + Key(kernel_def->op(), DeviceType(kernel_def->device_type()), + kernel_def->label()); + GlobalKernelRegistry()->insert( + std::make_pair(key, KernelRegistration(*kernel_def, factory))); + delete kernel_def; +} + +} // namespace kernel_factory + +namespace { + +// Helper for AttrsMatch(). +bool InTypeList(DataType dt, const AttrValue& type_list) { + for (int in_list : type_list.list().type()) { + if (dt == in_list) return true; + } + return false; +} + +// Returns whether the attrs in the NodeDef satisfy the constraints in +// the kernel_def. Returns an error if attrs in kernel_def are not +// found, or have a mismatching type. +Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, + bool* match) { + *match = false; + AttrSlice attrs(node_def); + for (const auto& constraint : kernel_def.constraint()) { + if (constraint.allowed_values().list().type_size() == 0) { + return errors::Unimplemented( + "KernelDef '", kernel_def.ShortDebugString(), + " has constraint on attr '", constraint.name(), + "' with unsupported type: ", + SummarizeAttrValue(constraint.allowed_values())); + } + + const AttrValue* found = attrs.Find(constraint.name()); + if (found) { + if (found->type() != DT_INVALID) { + if (!InTypeList(found->type(), constraint.allowed_values())) { + return Status::OK(); + } + } else { + if (!AttrValueHasType(*found, "list(type)").ok()) { + return errors::InvalidArgument( + "KernelDef '", kernel_def.ShortDebugString(), + "' has constraint on attr '", constraint.name(), + "' that has value '", SummarizeAttrValue(*found), + "' that does not have type 'type' or 'list(type)' in NodeDef '", + SummarizeNodeDef(node_def), "'"); + } + + for (int t : found->list().type()) { + if (!InTypeList(static_cast<DataType>(t), + constraint.allowed_values())) { + return Status::OK(); + } + } + } + } else { + return errors::InvalidArgument( + "OpKernel '", kernel_def.op(), "' has constraint on attr '", + constraint.name(), "' not in NodeDef '", SummarizeNodeDef(node_def), + "', KernelDef: '", kernel_def.ShortDebugString(), "'"); + } + } + *match = true; + return Status::OK(); +} + +Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def, + const KernelRegistration** reg) { + *reg = nullptr; + string label; // Label defaults to empty if not found in NodeDef. 
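  // (A NodeDef can select among multiple kernels registered for the same
  // op and device by carrying a "_kernel" attr whose value matches a
  // KernelDef label; illustrative: a registration built with
  // .Label("custom") is only matched by nodes whose "_kernel" attr equals
  // "custom".)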
+ GetNodeAttr(node_def, "_kernel", &label); + const string key = Key(node_def.op(), device_type, label); + auto regs = GlobalKernelRegistry()->equal_range(key); + for (auto iter = regs.first; iter != regs.second; ++iter) { + // If there is a kernel registered for the op and device_type, + // check that the attrs match. + bool match; + TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match)); + if (match) { + if (*reg != nullptr) { + return errors::InvalidArgument( + "Multiple OpKernel registrations match NodeDef '", + SummarizeNodeDef(node_def), "': '", (*reg)->def.ShortDebugString(), + "' and '", iter->second.def.ShortDebugString(), "'"); + } + *reg = &iter->second; + } + } + return Status::OK(); +} + +} // namespace + +Status SupportedDeviceTypesForNode( + const std::vector<DeviceType>& prioritized_types, const NodeDef& def, + DeviceTypeVector* device_types) { + // TODO(zhifengc): Changes the callers (SimplePlacer and + // DynamicPlacer) to consider the possibility that 'def' is call to + // a user-defined function and only calls this + // SupportedDeviceTypesForNode for primitive ops. + Status s; + const OpDef* op_def = OpRegistry::Global()->LookUp(def.op(), &s); + if (op_def) { + for (const DeviceType& device_type : prioritized_types) { + const KernelRegistration* reg = nullptr; + TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, def, ®)); + if (reg != nullptr) device_types->push_back(device_type); + } + } else { + // Assumes that all device types support this node. + for (const DeviceType& device_type : prioritized_types) { + device_types->push_back(device_type); + } + } + return Status::OK(); +} + +std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type, + DeviceBase* device, + Allocator* allocator, + const NodeDef& node_def, + Status* status) { + OpKernel* kernel = nullptr; + *status = CreateOpKernel(device_type, device, allocator, nullptr, node_def, + &kernel); + return std::unique_ptr<OpKernel>(kernel); +} + +Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + const NodeDef& node_def, OpKernel** kernel) { + VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def); + + // Look up the Op registered for this op name. + Status s; + const OpDef* op_def = OpRegistry::Global()->LookUp(node_def.op(), &s); + if (op_def == nullptr) return s; + + // Validate node_def against OpDef. + s = ValidateNodeDef(node_def, *op_def); + if (!s.ok()) return s; + + // Look up kernel registration. + const KernelRegistration* registration; + s = FindKernelRegistration(device_type, node_def, ®istration); + if (!s.ok()) { + errors::AppendToMessage(&s, " when instantiating ", node_def.op()); + return s; + } + if (registration == nullptr) { + s.Update(errors::NotFound("No registered '", node_def.op(), + "' OpKernel for ", DeviceTypeString(device_type), + " devices compatible with node ", + SummarizeNodeDef(node_def))); + return s; + } + + // Get signature from the OpDef & NodeDef + DataTypeVector inputs; + DataTypeVector outputs; + s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs)); + if (!s.ok()) { + errors::AppendToMessage(&s, " for node: ", SummarizeNodeDef(node_def)); + return s; + } + + // Everything needed for OpKernel construction. 
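  // (The registered factory runs the kernel's constructor against this
  // context; a failed OP_REQUIRES_OK inside that constructor surfaces in
  // 's' via context.SetStatus(), which is why *kernel is deleted below on
  // a non-OK status.)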
+ OpKernelConstruction context(device_type, device, allocator, &node_def, + op_def, flib, inputs, outputs, &s); + *kernel = (*registration->factory)(&context); + if (!s.ok()) { + delete *kernel; + *kernel = nullptr; + } + return s; +} + +namespace { // Helper for MemoryTypesForNode. +// Fills memory_types for either input or output, setting everything +// to DEVICE_MEMORY except those args in host_memory_args. Removes +// elements of host_memory_args that were used. +void MemoryTypesHelper(const NameRangeMap& name_map, + std::vector<string>* host_memory_args, + MemoryTypeVector* memory_types) { + // Set total to the largest endpoint of anything in the name_map. + int total = 0; + for (const auto& item : name_map) { + total = std::max(total, item.second.second); + } + + // Now that we know the size, fill with the default 'DEVICE_MEMORY'. + memory_types->clear(); + memory_types->resize(total, DEVICE_MEMORY); + + // Update args that have been marked as in "HOST_MEMORY". + size_t keep = 0; + for (size_t i = 0; i < host_memory_args->size(); ++i) { + auto iter = name_map.find((*host_memory_args)[i]); + if (iter != name_map.end()) { + for (int j = iter->second.first; j < iter->second.second; ++j) { + (*memory_types)[j] = HOST_MEMORY; + } + } else { + // (*host_memory_args)[i] not found, save it for the next pass. + if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i]; + ++keep; + } + } + host_memory_args->resize(keep); +} +} // namespace + +Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef, + const OpDef& op_def, + const NameRangeMap& input_name_map, + const NameRangeMap& output_name_map, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types) { + Status status; + const KernelRegistration* registration; + TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, ndef, ®istration)); + + if (registration != nullptr) { + const auto& from_proto = registration->def.host_memory_arg(); + std::vector<string> host_memory_args(from_proto.begin(), from_proto.end()); + MemoryTypesHelper(input_name_map, &host_memory_args, input_memory_types); + MemoryTypesHelper(output_name_map, &host_memory_args, output_memory_types); + if (!host_memory_args.empty()) { + return errors::InvalidArgument( + "HostMemory args '", str_util::Join(host_memory_args, "', '"), + "' not found in OpDef: ", SummarizeOpDef(op_def)); + } + } + return status; +} + +Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + DeviceType device_type, const NodeDef& ndef, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types) { + // Look up the Op registered for this op name. 
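  // (Illustrative call, assuming `ndef` names a registered op:
  //    MemoryTypeVector in_mt, out_mt;
  //    TF_CHECK_OK(MemoryTypesForNode(OpRegistry::Global(),
  //                                   DeviceType(DEVICE_GPU), ndef,
  //                                   &in_mt, &out_mt));
  //  the lookup below otherwise fails with the registry's status.)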
+ Status status; + const OpDef* op_def = op_registry->LookUp(ndef.op(), &status); + if (op_def == nullptr) return status; + + NameRangeMap inputs, outputs; + status = NameRangesForNode(ndef, *op_def, &inputs, &outputs); + if (!status.ok()) return status; + + return MemoryTypesForNode(device_type, ndef, *op_def, inputs, outputs, + input_memory_types, output_memory_types); +} + +namespace { + +bool FindArgInOp(const string& arg_name, + const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { + for (const auto& arg : args) { + if (arg_name == arg.name()) { + return true; + } + } + return false; +} + +} // namespace + +Status ValidateKernelRegistrations(const OpRegistryInterface* op_registry) { + Status unused_status; + for (const auto& key_registration : *GlobalKernelRegistry()) { + const KernelDef& kernel_def(key_registration.second.def); + const OpDef* op_def = op_registry->LookUp(kernel_def.op(), &unused_status); + if (op_def == nullptr) { + // TODO(josh11b): Make this a hard error. + LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString() + << "') for unknown op: " << kernel_def.op(); + continue; + } + for (const auto& host_memory_arg : kernel_def.host_memory_arg()) { + if (!FindArgInOp(host_memory_arg, op_def->input_arg()) && + !FindArgInOp(host_memory_arg, op_def->output_arg())) { + return errors::InvalidArgument("HostMemory arg '", host_memory_arg, + "' not found in OpDef: ", + SummarizeOpDef(*op_def)); + } + } + } + return Status::OK(); +} + +template <> +const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const { + return eigen_cpu_device(); +} + +template <> +const Eigen::GpuDevice& OpKernelContext::eigen_device() const { + return eigen_gpu_device(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h new file mode 100644 index 0000000000..34d588c6c9 --- /dev/null +++ b/tensorflow/core/framework/op_kernel.h @@ -0,0 +1,1250 @@ +#ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ +#define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ + +#include <functional> + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace Eigen { +class ThreadPoolDevice; +class GpuDevice; +} // end namespace Eigen + +namespace tensorflow { + +namespace checkpoint { +class TensorSliceReaderCacheWrapper; +} // namespace checkpoint + +class AsyncOpKernel; +class OpKernelConstruction; // declared below +class 
OpKernelContext; // declared below +class ResourceMgr; + +// TODO(josh11b): Make reference-counted if needed. +class OpKernel { + public: + // OpKernel won't be instantiated by the scheduler, so you may perform + // expensive initialization in the descendant's constructor. + explicit OpKernel(OpKernelConstruction* context); + virtual ~OpKernel() {} + + // An OpKernel's computation can be either synchronous or + // asynchronous. + // + // Most OpKernels should compute synchronously. They should + // subclass OpKernel and override the Compute() method and have it + // return after completing the supplied work. + // + // A few special kernels might need to be asynchronous to bound the + // number of threads (e.g., network receive operations). These + // kernels must subclass AsyncOpKernel and override + // AsyncOpKernel::ComputeAsync(). + // + // In both cases, implementations of Compute() and ComputeAsync() + // get inputs and write outputs through the given OpKernelContext + // and returns a status via context->SetStatus(). They must be + // thread-safe. + + // Synchronous compute. + // + // "context" is guaranteed to be alive until Compute() returns. + virtual void Compute(OpKernelContext* context) = 0; + + // Returns nullptr iff this op kernel is synchronous. + virtual AsyncOpKernel* AsAsync() { return nullptr; } + + // Returns true iff this op kernel is considered "expensive". The + // runtime may use this flag to optimize graph execution for example + // to "inline" inexpensive kernels. + virtual bool IsExpensive() { return true; } + + // Accessors. + const NodeDef& def() const { return def_; } + const string& name() const { return def_.name(); } + const string& type_string() const { return def_.op(); } + + int num_inputs() const { return input_types_.size(); } + DataType input_type(int i) const { return input_types_[i]; } + const DataTypeVector& input_types() const { return input_types_; } + const MemoryTypeVector& input_memory_types() const { + return input_memory_types_; + } + + int num_outputs() const { return output_types_.size(); } + DataType output_type(int o) const { return output_types_[o]; } + const DataTypeVector& output_types() const { return output_types_; } + const MemoryTypeVector& output_memory_types() const { + return output_memory_types_; + } + + Status InputRange(const string& input_name, int* start, int* stop) const; + Status OutputRange(const string& output_name, int* start, int* stop) const; + + private: + const NodeDef def_; + const DataTypeVector input_types_; + const DataTypeVector output_types_; + NameRangeMap input_name_map_; + NameRangeMap output_name_map_; + MemoryTypeVector input_memory_types_; + MemoryTypeVector output_memory_types_; + + TF_DISALLOW_COPY_AND_ASSIGN(OpKernel); +}; + +class AsyncOpKernel : public OpKernel { + public: + using OpKernel::OpKernel; // Lift OpKernel constructors. + + // Asynchronous compute. + // + // Implementations of ComputeAsync() must run "done" to signal the + // completion of the computation. "context" is guaranteed to be + // alive until the "done" callback starts. + typedef std::function<void()> DoneCallback; + virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0; + + AsyncOpKernel* AsAsync() final { return this; } + + void Compute(OpKernelContext* context) final; +}; + +// Wraps a tensor that is held by an Op across calls to Compute(). For +// memory safety when using asynchronous devices like GPUs, the system +// must be notified when a Tensor is used inside an Op execution. 
The +// wrapper ensures that all uses of the Tensor are tracked, because in +// order to retrieve the Tensor the caller must use AccessTensor which +// notifies the context. +class PersistentTensor { + public: + PersistentTensor() {} + explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {} + + // Caller does not own the returned Tensor*. + Tensor* AccessTensor(OpKernelConstruction* context); + // Caller does not own the returned Tensor*. + Tensor* AccessTensor(OpKernelContext* context); + + // The check for initialization does not need to access the + // underlying tensor buffer. + bool IsInitialized() { return tensor_.IsInitialized(); } + + private: + Tensor tensor_; +}; + +class OpKernelConstruction { + public: + // TODO(yuanbyu): Probably reduce the number of arguments. + OpKernelConstruction(DeviceType device_type, DeviceBase* device, + Allocator* allocator, const NodeDef* node_def, + const OpDef* op_def, FunctionLibraryRuntime* flib, + const DataTypeSlice& input_types, + const DataTypeSlice& output_types, Status* status) + : device_type_(device_type), + device_(device), + allocator_(allocator), + def_(node_def), + op_def_(op_def), + flib_(flib), + input_types_(input_types), + output_types_(output_types), + status_(status) {} + + Env* env() const { return device_->env(); } + + // Allocation of tensors during kernel construction: + // + // It is legal to temporarily allocate scratch tensor storage during + // Op kernel construction. Scratch tensors should be allocated using + // allocate_temp below. Some kernels need to keep tensors in between + // invocations. If such a Tensor is allocated during kernel + // construction this must be done using allocate_persistent, and the + // Op may only store the returned PersistentTensor object. When the + // Tensor is needed in a subsequent invocation, it can be retrieved + // from the PersistentTensor using the AccessTensor method. This + // ensures that the system is made aware of any use of the tensor's + // allocated memory, which is needed for correctness on asynchronous + // devices such as GPUs. + + // Allocates a temporary Tensor of the specified type and shape. The + // Tensor must not be used after kernel construction is + // complete. See comment above. + Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp); + + // Allocates a Tensor of the specified type and shape which the Op + // plans to maintain as persistent state. out_persistent holds the + // PersistentTensor which is the object the caller should store. For + // convenience, if out_tensor is non-null then it will be filled in + // with a Tensor* pointing to the newly-allocated tensor which the + // caller can use instead of calling + // out_persistent->AccessTensor. The caller does not own out_tensor + // and should not keep a copy of it. See comment above. + Status allocate_persistent(DataType type, const TensorShape& shape, + PersistentTensor* out_persistent, + Tensor** out_tensor); + + // User-supplied configuration of this operation. + const NodeDef& def() const { return *def_; } + + // Op registered for this op type. + const OpDef& op_def() const { return *op_def_; } + + // For inspecting the inputs to this operation. + int num_inputs() const { return input_types_.size(); } + DataType input_type(int i) const { return input_types_[i]; } + const DataTypeSlice& input_types() const { return input_types_; } + + // For inspecting the outputs expected from this operation. 
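  // (Illustrative: a constructor may branch on these, e.g.
  //    if (context->output_type(0) == DT_FLOAT) { ... }
  //  or check the whole signature at once via MatchSignature() below.)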
+ int num_outputs() const { return output_types_.size(); } + DataType output_type(int i) const { return output_types_[i]; } + const DataTypeSlice& output_types() const { return output_types_; } + + // If expected_inputs == inputs() and expected_outputs == output_types(), + // returns OK, else returns INVALID_ARGUMENT with an error message. + // Recommended for Ops with dynamic signatures. + Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs); + + // For recording configuration errors during construction. + void SetStatus(const Status& status) { status_->Update(status); } + const Status& status() const { return *status_; } + + // Look up the attr with name attr_name and set *value to its value. If no + // attr with attr_name is found in def(), or the attr does not have + // a matching type, a non-ok status will be returned. + template <class T> + Status GetAttr(const string& attr_name, T* value) const { + return GetNodeAttr(def(), attr_name, value); + } + + // May be used, e.g., to get GPU handles, etc. + // TODO(tucker): Add example usage. + DeviceBase* device() const { return device_; } + + // Return the device type. + const DeviceType& device_type() const { return device_type_; } + + // If not nullptr, the kernel can instantiate functions defined in + // the library. E.g., + // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...). + FunctionLibraryRuntime* function_library() const { return flib_; } + + private: + const DeviceType device_type_; + DeviceBase* const device_; + Allocator* allocator_; + const NodeDef* def_; + const OpDef* op_def_; + FunctionLibraryRuntime* flib_; + DataTypeSlice input_types_; + DataTypeSlice output_types_; + Status* status_; + + TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction); +}; + +// TODO(mrry): Consider converting to a random_access_iterator, and upgrading +// tensorflow::gtl::iterator_range to make the below container classes +// unnecessary. +template <typename ListType, typename ElementType> +class OpArgIterator { + public: + typedef OpArgIterator<ListType, ElementType> ME; + OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {} + bool operator==(const ME& rhs) { + DCHECK(list_ == rhs.list_); + return i_ == rhs.i_; + } + bool operator!=(const ME& rhs) { + DCHECK(list_ == rhs.list_); + return i_ != rhs.i_; + } + void operator++() { ++i_; } + ElementType& operator*() { return (*list_)[i_]; } + + private: + const ListType* const list_; + int i_; +}; + +// Utility class for representing a list of immutable input tensors +// that are passed to the op as a single named argument. +class OpInputList { + public: + typedef OpArgIterator<OpInputList, const Tensor&> Iterator; + OpInputList() : ctx_(nullptr), start_(0), stop_(0) {} + OpInputList(const OpKernelContext* ctx, int start, int stop) + : ctx_(ctx), start_(start), stop_(stop) {} + OpInputList& operator=(const OpInputList& other) = default; + const Tensor& operator[](int i) const; + int size() const { return stop_ - start_; } + Iterator begin() const { return Iterator(this, 0); } + Iterator end() const { return Iterator(this, size()); } + + private: + const OpKernelContext* ctx_; // not owned + int start_; + int stop_; +}; + +// Utility class for representing a list of mutable ("ref") input tensors +// that are passed to the op as a single named argument. 
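+// For example, a kernel whose OpDef declares an input "values: N * Ref(float)" +// might iterate the list as follows. (An illustrative sketch; "values" and +// the surrounding kernel are hypothetical, and OP_REQUIRES_OK is the usual +// status-checking convenience macro.) +// +// OpMutableInputList values; +// OP_REQUIRES_OK(context, context->mutable_input_list("values", &values)); +// for (int i = 0; i < values.size(); ++i) { +// mutex_lock l(*values.ref_mutex(i)); +// Tensor t = values.at(i, true /* lock_held */); +// // ... read or modify the tensor's buffer while holding the lock ... +// }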
+class OpMutableInputList { + public: + typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator; + OpMutableInputList(OpKernelContext* ctx, int start, int stop) + : ctx_(ctx), start_(start), stop_(stop) {} + OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {} + OpMutableInputList& operator=(const OpMutableInputList& other) = default; + Tensor at(int i, bool lock_held); + mutex* ref_mutex(int i); + int size() const { return stop_ - start_; } + Iterator begin() const { return Iterator(this, 0); } + Iterator end() const { return Iterator(this, size()); } + + private: + OpKernelContext* ctx_; // not owned + int start_; + int stop_; +}; + +// Utility class for representing a list of output tensors that are +// grouped as a single named output. +class OpOutputList { + public: + typedef OpArgIterator<OpOutputList, const Tensor*> Iterator; + OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {} + OpOutputList(OpKernelContext* ctx, int start, int stop) + : ctx_(ctx), start_(start), stop_(stop) {} + OpOutputList& operator=(const OpOutputList& other) = default; + Tensor* operator[](int i); + bool required(int i) const; + Status allocate(int i, const TensorShape& shape, Tensor** output); + void set(int i, const Tensor& tensor); + void set_ref(int i, mutex* mu, Tensor* tensor_for_ref); + int size() const { return stop_ - start_; } + Iterator begin() const { return Iterator(this, 0); } + Iterator end() const { return Iterator(this, size()); } + + private: + OpKernelContext* ctx_; // not owned + int start_; + int stop_; +}; + +// Holds a tensor or tensor reference. For tensor references, we need +// a mutex to prevent concurrent access to the tensor. +struct TensorValue { + TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {} + TensorValue(Tensor* t) // NOLINT(runtime/explicit) + : mutex_if_ref(nullptr), + tensor(t) {} + TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {} + Tensor* operator->() const { return tensor; } + bool is_ref() const { return mutex_if_ref != nullptr; } + + mutex* mutex_if_ref; // nullptr if not a ref, != nullptr if a ref + Tensor* tensor; +}; + +class OpKernelContext { + public: + // The first element of a WrappedAllocator is a "base" Allocator and + // the second element is that Allocator wrapped by a + // TrackingAllocator + typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator; + + // TODO(zhifengc): Do some cleanup of Params. + struct Params { + // The op kernel being computed. + OpKernel* op_kernel = nullptr; + + // The device on which the kernel is running. + DeviceBase* device = nullptr; + + bool track_allocations = false; + std::function<AllocatorAttributes(int index)> output_alloc_attr = nullptr; + + // Shared resources accessible by this op kernel invocation. + ResourceMgr* resource_manager = nullptr; + + // Per-step resources accessible by this op kernel invocation. + ResourceMgr* step_resource_manager = nullptr; + + // Mechanism used by this op kernel invocation to communicate with + // computations running on other devices. + Rendezvous* rendezvous = nullptr; + + // Mechanism used by this op kernel invocation to register a callback + // for its cancellation. + CancellationManager* cancellation_manager = nullptr; + + // Inputs to this op kernel. + const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr; + bool is_input_dead = false; + + const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs = + nullptr; + + // Device contexts. 
+ const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts = + nullptr; + DeviceContext* op_device_context = nullptr; + + // Control-flow op supports. + FrameAndIter frame_iter; + + // Function call supports. + FunctionCallFrame* call_frame = nullptr; + FunctionLibraryRuntime* function_library = nullptr; + + // TensorSliceReaderCache support. + checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; + }; + explicit OpKernelContext(const Params& params); + ~OpKernelContext(); + + Env* env() const { return params_.device->env(); } + + // Input/output signature. + + int num_inputs() const { return params_.inputs->size(); } + DataType input_dtype(int index) const; + int num_outputs() const { return outputs_.size(); } + DataType expected_output_dtype(int index) const; + + // Input + + // Returns an immutable input tensor. May only be used for non-Ref + // inputs. For Ref inputs use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + // TODO(mrry): Convert this to return Status. + const Tensor& input(int index) const; + + // Returns the named immutable input tensor in "tensor", as defined + // in the OpDef. May only be used for non-Ref inputs. For Ref inputs + // use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + // REQUIRES: the named input must not be a list. + Status input(const string& name, const Tensor** tensor) const; + + // Returns the named list-valued immutable input in "list", as + // defined in the OpDef. If the named input is not list-valued, + // returns a one-element list. May only be used for non-Ref + // inputs. For Ref inputs use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + Status input_list(const string& name, OpInputList* list) const; + + // For mutable inputs, use the following together to make sure there + // is no concurrent access to mutable_input(), e.g.: + // { + // mutex_lock lock(*context->input_ref_mutex(index)); + // Tensor t = context->mutable_input(index, true /* lock_held */); + // // modify the values in t + // } + // REQUIRES: IsRefType(input_dtype(index)) + // TODO(mrry): Convert this to return Status. + mutex* input_ref_mutex(int index); + Status input_ref_mutex(const string& name, mutex** out_mutex); + + // Returns a mutable input tensor. Must be used to access Ref + // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may + // modify the values stored in the Tensor buffer, and modifications + // will be visible to other Ops reading the same ref tensor. If + // !lock_held the input mutex will be acquired before returning the + // Tensor. + // TODO(mrry): + // Convert this to return Status. + Tensor mutable_input(int index, bool lock_held); + + // Returns the named mutable input tensor in "tensor", as defined in + // the OpDef. Must be used to access Ref inputs. The values stored + // in the Tensor buffer may be modified, and modifications will be + // visible to other Ops reading the same ref tensor. If !lock_held + // the input mutex will be acquired before returning the Tensor. + // REQUIRES: the named input must not be a list. + // REQUIRES: the named input must be a ref tensor. + Status mutable_input(const string& name, Tensor* tensor, bool lock_held); + + // Returns the named list-valued mutable input in "list", as defined + // in the OpDef. If the named input is not list-valued, returns a + // one-element list. Must be used to access Ref inputs.
The values + // stored in the Tensor buffer may be modified, and modifications + // will be visible to other Ops reading the same ref tensor. + // REQUIRES: the named input must be a ref tensor. + Status mutable_input_list(const string& name, OpMutableInputList* list); + + // Replace the corresponding Ref Input to use the storage buffer + // used by tensor. If !lock_held the input mutex will be acquired + // before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(index)). + void replace_ref_input(int index, const Tensor& tensor, bool lock_held); + + // Replace the corresponding named Ref Input to use the storage + // buffer used by tensor. If !lock_held the input mutex will be + // acquired before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(index)). + Status replace_ref_input(const string& name, const Tensor& tensor, + bool lock_held); + + // Set the output Ref Tensor at output_index to be an alias of the + // input Ref Tensor at input_index. + // REQUIRES: IsRefType(input_dtype(input_index)). + // REQUIRES: IsRefType(output_dtype(output_index)). + void forward_ref_input_to_ref_output(int input_index, int output_index); + + // Deletes the Tensor object used as the Ref Input at + // input_index. This is not usually necessary and should be used + // with caution. If !lock_held the input mutex will be acquired + // before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(input_index)). + void delete_ref_input(int input_index, bool lock_held); + + // Return true if there is input at the given index. An operator has no + // input at index if its tensor is null. This is primarily used by the + // merge operator. + // TODO(mrry): Convert this to return Status. + bool has_input(int index) const; + + // Returns true if all inputs are the same shape, otherwise sets the + // status to a non-OK value and returns false. + // Usage: if (!context->ValidateInputsAreSameShape(this)) return; + bool ValidateInputsAreSameShape(OpKernel* op); + + // Output + + // Returns the named list-valued output in "list", as defined in the OpDef. + // If the named output is not list-valued, returns a one-element list. + Status output_list(const string& name, OpOutputList* list); + + // If output_required(index) returns true, the OpKernel's Compute() method + // should call allocate_output(index, ...), set_output(index, ...), + // set_output_ref(index, ...), or set the status to a non-ok value. + // If it returns false, it may output, but is not required to do so. + // TODO(mrry): Convert this to return Status, and implement a string + // name version. + bool output_required(int index) const { + return true; // TODO(josh11b): implement + } + + // Allocation of tensors during kernel execution inside the Compute + // method: + // + // There are three methods to allocate Tensors when an Op kernel + // executes. + // + // 1) allocate_persistent. This is only needed for Tensors that will + // be stored by the Op between invocations, and it *must* be used + // for those Tensors. The call returns a PersistentTensor, and that + // is the only object the Op is allowed to hold on to between + // invocations. When the Tensor is needed in a subsequent + // invocation, it can be retrieved from the PersistentTensor using + // the AccessTensor method. This ensures that the system is made + // aware of any use of the tensor's allocated memory, which is + // needed for correctness on asynchronous devices such as GPUs. + // + // 2) allocate_output. 
This should be used to allocate any tensor + // that is going to be used as an output from the Op at the end of + // the current execution. The caller indicates which output the + // Tensor will be assigned to, and the call returns the + // newly-allocated Tensor. The Tensor can subsequently be assigned + // to during kernel execution, and will be used as the designated + // output when the kernel execution completes. + // + // 3) allocate_temp. This should be used to allocate any scratch + // storage that is needed while the kernel is executing, and will + // not be retained by the Op. + // + // In some cases a Tensor needs to be used as an output even though + // it was previously allocated elsewhere. The Tensor may have been + // passed as an input, or stored in a PersistentTensor during a + // previous kernel execution, or allocated earlier in the kernel + // execution at a time when it was not known which output it would + // be assigned to. In this case the kernel can use set_output or + // set_output_ref to indicate that the tensor should be used as the + // designated output. It is legal to use any previously-allocated + // Tensor as an argument to set_output or set_output_ref, including + // Tensors allocated via allocate_temp. There may be a performance + // penalty to using a Tensor that was not allocated using + // allocate_output. This is because allocate_output uses the + // AllocatorAttributes stored in output_alloc_attr for the + // designated output. In some cases, using the wrong attributes may + // cause an extra copy of the Tensor's buffer. + + // Allocates output for the specified output index with shape. + // OpKernelContext retains ownership of the returned pointer. See + // comment above. + // + // If memory allocation fails, returns an error status. + // + // REQUIRES: !IsRefType(expected_output_dtype(index)) + Status allocate_output(int index, const TensorShape& shape, + Tensor** tensor) TF_MUST_USE_RESULT; + Status allocate_output(const string& name, const TensorShape& shape, + Tensor** tensor) TF_MUST_USE_RESULT; + // The following methods use the supplied attributes instead of + // those in output_alloc_attr. The caller is responsible for + // ensuring that the attributes are "compatible" with the + // output_alloc_attr, e.g. the tensor is allocated on the correct + // device. See comment above. + Status allocate_output(int index, const TensorShape& shape, Tensor** tensor, + AllocatorAttributes attr) TF_MUST_USE_RESULT; + Status allocate_output(const string& name, const TensorShape& shape, + Tensor** tensor, + AllocatorAttributes attr) TF_MUST_USE_RESULT; + + // Allocates a temporary Tensor of the specified type and + // shape. Devices such as GPUs that enqueue Ops for lazy execution + // may retain references to the temporary tensors after the Op's + // Compute method has run. See comment above. + Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp, AllocatorAttributes attr); + Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp) { + return allocate_temp(type, shape, out_temp, AllocatorAttributes()); + } + + // Allocates a Tensor of the specified type and shape which the Op + // plans to maintain as persistent state. out_persistent holds the + // PersistentTensor which is the object the caller should store. 
For + // convenience, if out_tensor is non-null then it will be filled in + // with a Tensor* pointing to the newly-allocated tensor which the + // caller can use instead of calling + // out_persistent->AccessTensor. The caller does not own out_tensor + // and should not keep a copy of it. See comment above. + Status allocate_persistent(DataType type, const TensorShape& shape, + PersistentTensor* out_persistent, + Tensor** out_tensor, AllocatorAttributes attr); + Status allocate_persistent(DataType type, const TensorShape& shape, + PersistentTensor* out_persistent, + Tensor** out_tensor) { + return allocate_persistent(type, shape, out_persistent, out_tensor, + AllocatorAttributes()); + } + + // Copies a tensor (allocated by the caller) to the specified output + // index. REQUIRES: !IsRefType(expected_output_dtype(index)) + // REQUIRES: 'tensor' must have the same MemoryType as + // output_memory_types[index]. See comment above. + // TODO(mrry): Convert this to return Status. + void set_output(int index, const Tensor& tensor); + Status set_output(const string& name, const Tensor& tensor); + + // To output a reference. Caller retains ownership of mu and tensor_for_ref, + // and they must outlive all uses within the step. See comment above. + // REQUIRES: IsRefType(expected_output_dtype(index)) + // TODO(mrry): Convert this to return Status. + void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref); + Status set_output_ref(const string& name, mutex* mu, Tensor* tensor_for_ref); + + // Returns nullptr if allocate_output() or set_output() have not been called. + // TODO(mrry): Convert this to return Status. + Tensor* mutable_output(int index); + Status mutable_output(const string& name, Tensor** tensor); + + // Transfers ownership of an output tensor to the caller. + // NOTE: For non-reference outputs, the caller takes responsibility + // for deletion. For reference outputs, the caller does NOT take + // responsibility for deletion. + // TODO(mrry): Convert this to return Status. + TensorValue release_output(int index); + Status release_output(const string& name, TensorValue* value); + + // Records device specific state about how the input tensors were + // computed. + // + // If using the templated function, the type must be a subclass + // of DeviceContext. + // + // Get the DeviceContext used for the index input. Returns nullptr + // if no DeviceContext was provided. + template <typename T> + T* input_device_context(int index); + DeviceContext* input_device_context(int index); + + // Return the DeviceContext that should be used for this Op. + // + // If using the templated function, the type must be a subclass + // of DeviceContext. + // + // Returns nullptr if the device did not provide one. 
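+ // For example (an illustrative sketch; "MyDeviceContext" stands in for + // some DeviceContext subclass defined elsewhere): + // auto* dc = context->op_device_context<MyDeviceContext>(); + // if (dc != nullptr) { /* use device-specific state */ }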
+ template <typename T> + T* op_device_context(); + DeviceContext* op_device_context() { + DeviceContext* ret = params_.op_device_context; + if (ret == nullptr) { + auto* dev_info = device()->tensorflow_gpu_device_info(); + if (dev_info) ret = dev_info->default_context; + } + return ret; + } + + AllocatorAttributes input_alloc_attr(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.input_alloc_attrs->size()); + return (*params_.input_alloc_attrs)[index]; + } + + AllocatorAttributes output_alloc_attr(int index) const { + return params_.output_alloc_attr(index); + } + + gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators() const { + mutex_lock lock(mu_); + gtl::InlinedVector<WrappedAllocator, 4> retrieved = wrapped_allocators_; + return retrieved; + } + + // Communication. + // + // An op kernel communicates with the outside environment through + // Rendezvous Send() and Recv(). + Rendezvous* rendezvous() const { return params_.rendezvous; } + + // Function call support. + // + // If this kernel invocation is within a function execution, + // call_frame() returns the call frame for the function call. + FunctionCallFrame* call_frame() const { return params_.call_frame; } + + // If not nullptr, the kernel can invoke functions defined in the + // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...). + FunctionLibraryRuntime* function_library() const { + return params_.function_library; + } + + // Shared resources accessible to this kernel. + ResourceMgr* resource_manager() const { return params_.resource_manager; } + + checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const { + return params_.slice_reader_cache; + } + + // Execution. + // + // OpKernels can use these eigen devices to carry out their + // numerical computation. + const Eigen::ThreadPoolDevice& eigen_cpu_device() const { + return *device()->eigen_cpu_device(); + } + const Eigen::GpuDevice& eigen_gpu_device() const { + return eigen_gpu_device_->device(); + } + template <typename EigenDeviceType> + const EigenDeviceType& eigen_device() const; + + // Error handling. + + // If expected_inputs == inputs() and expected_outputs == output_types(), + // returns OK, else returns INVALID_ARGUMENT with an error message. + // Recommended for Ops with dynamic signatures, where validation can only + // be performed at runtime. + Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs); + + // An OpKernel should call SetStatus() if Compute() encounters an + // error. + void SetStatus(const Status& status) { status_.Update(status); } + const Status& status() const { return status_; } + + // Cancellation. + // + // EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an + // example of how to use this API. + CancellationManager* cancellation_manager() const { + return params_.cancellation_manager; + } + + // Other accessors. + + // For control flow. + FrameAndIter frame_iter() const { return params_.frame_iter; } + bool is_input_dead() const { return params_.is_input_dead; } + bool* is_output_dead() { return &is_output_dead_; } + + // May be used, e.g., to get GPU handles, etc. + // TODO(tucker): Add example usage. + DeviceBase* device() const { return params_.device; } + + // Access to list of temporary tensors.
+ int num_temps(); + Tensor* temp(int index); + + // Access to information about whether each output was newly + // allocated or copied from an existing tensor. + AllocationType output_allocation_type(int index) const { + return output_allocation_types_[index]; + } + + private: + Allocator* get_allocator(AllocatorAttributes attr) { + Allocator* allocator = params_.device->GetAllocator(attr); + if (params_.track_allocations) { + mutex_lock lock(mu_); + for (const auto& wrapped : wrapped_allocators_) { + if (wrapped.first == allocator) { + return wrapped.second; + } + } + TrackingAllocator* wrapped_allocator = new TrackingAllocator(allocator); + wrapped_allocators_.push_back( + std::make_pair(allocator, wrapped_allocator)); + return wrapped_allocator; + } else { + return allocator; + } + } + + // Per-step resource manager for use by white-listed internal ops. + friend class TemporaryVariableOp; + friend class DestroyTemporaryVariableOp; + ResourceMgr* step_resource_manager() const { + return params_.step_resource_manager; + } + + // Internal common method used when allocating tensor memory. + Status allocate_tensor(DataType type, const TensorShape& shape, + Tensor* out_tensor, AllocatorAttributes attr); + + // This is called by PersistentTensor::AccessTensor whenever the + // wrapped tensor is retrieved, to ensure the runtime knows that the + // Tensor is being accessed within an Op. This is necessary for + // memory safety of devices like GPUs that queue Ops for + // asynchronous execution after the Compute() method completes. + friend class PersistentTensor; + void NotifyUseOfPersistentTensor(const Tensor& tensor); + + Status status_; + Params params_; // immutable after construction. + const PerOpGpuDevice* eigen_gpu_device_; // owned, with a per-op + // wrapped allocator + mutable mutex mu_; // mutable so const accessors can acquire the lock + gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_); + gtl::InlinedVector<TensorValue, 4> outputs_; + gtl::InlinedVector<AllocationType, 4> output_allocation_types_; + gtl::InlinedVector<Tensor*, 4> temp_tensors_; + bool is_output_dead_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext); +}; + +// Register your OpKernel by specifying the Op's name, the device the +// kernel runs on, any type attr constraints for this kernel, any +// host-memory args, and the class to instantiate. Examples: +// +// // A kernel that supports all types. +// REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp); +// +// // The following are equivalent ways of specifying that the kernel only +// // works if the "T" type attr is set to DT_FLOAT. +// REGISTER_KERNEL_BUILDER( +// Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"), +// SubOp<float>); +// // (You would then repeat this for every type supported by "Sub".) +// +// // This form allows you to specify a list of types as the constraint. +// REGISTER_KERNEL_BUILDER(Name("Sub") +// .Device(DEVICE_CPU) +// .TypeConstraint("T", {DT_FLOAT}), +// SubOp<float>); +// +// // A kernel that expects one of the input tensors in host memory. +// REGISTER_KERNEL_BUILDER( +// Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp); +// +// See kernel_def_builder for details. + +// Instantiate an OpKernel that has been registered. Returns nullptr +// if no kernel is registered for that op type / device / input signature +// combination (and a NOT_FOUND *status), or if there is an error in +// construction (and an INVALID_ARGUMENT *status).
Otherwise, the caller takes ownership +// of the returned pointer. +// EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...); +// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()). +std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type, + DeviceBase* device, + Allocator* allocator, + const NodeDef& def, Status* status); +Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + const NodeDef& def, OpKernel** kernel); + +// Returns into 'device_types' the subset of prioritized_types that this +// binary has registered for the given NodeDef. +// +// REQUIRES: * 'device_types' is not nullptr. +// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). +Status SupportedDeviceTypesForNode( + const std::vector<DeviceType>& prioritized_types, const NodeDef& def, + DeviceTypeVector* device_types); + +// Returns into *{input,output}_memory_types the memory type of each +// {input,output} tensor. +// +// REQUIRES: * '*_memory_types' is not nullptr. +// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). +Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef, + const OpDef& op_def, + const NameRangeMap& input_name_map, + const NameRangeMap& output_name_map, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types); + +Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + DeviceType device_type, const NodeDef& ndef, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types); + +// Call once after Op registration has completed. +Status ValidateKernelRegistrations(const OpRegistryInterface* op_registry); + +// ----------------------------------------------------------------------------- +// OpKernel registration implementation follows, please ignore. + +// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax. +namespace register_kernel { +typedef ::tensorflow::KernelDefBuilder Name; +} // namespace register_kernel + +#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \ + REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__) + +#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \ + REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__) + +#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) 
\ + static ::tensorflow::kernel_factory::OpKernelRegistrar \ + registrar__body__##ctr##__object( \ + ::tensorflow::register_kernel::kernel_builder.Build(), \ + +[](::tensorflow::OpKernelConstruction* context) \ + -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); }) + +namespace kernel_factory { + +class OpKernelRegistrar { + public: + typedef OpKernel* (*Factory)(OpKernelConstruction*); + OpKernelRegistrar(const KernelDef* kernel_def, Factory factory); +}; + +} // namespace kernel_factory + +// ----------------------------------------------------------------------------- +// Template and inline method implementations, please ignore + +inline DataType OpKernelContext::input_dtype(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + const TensorValue& value((*params_.inputs)[index]); + if (value.is_ref()) { + return MakeRefType(value->dtype()); + } else { + return value->dtype(); + } +} + +inline DataType OpKernelContext::expected_output_dtype(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.op_kernel->output_types().size()); + return params_.op_kernel->output_type(index); +} + +inline const Tensor& OpKernelContext::input(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + DCHECK(!(*params_.inputs)[index].is_ref()); + return *((*params_.inputs)[index].tensor); +} + +inline Tensor OpKernelContext::mutable_input(int index, bool lock_held) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + DCHECK((*params_.inputs)[index].is_ref()); + // return a copy of the Ref acquired while holding the mutex + if (lock_held) { + return *((*params_.inputs)[index].tensor); + } else { + mutex_lock l(*input_ref_mutex(index)); + return *((*params_.inputs)[index].tensor); + } +} + +inline void OpKernelContext::replace_ref_input(int index, const Tensor& tensor, + bool lock_held) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + DCHECK((*params_.inputs)[index].is_ref()); + // should only modify the tensor while holding the mutex + if (lock_held) { + *(*params_.inputs)[index].tensor = tensor; + } else { + mutex_lock l(*input_ref_mutex(index)); + *(*params_.inputs)[index].tensor = tensor; + } +} + +inline void OpKernelContext::forward_ref_input_to_ref_output(int input_index, + int output_index) { + DCHECK_GE(input_index, 0); + DCHECK_LT(input_index, params_.inputs->size()); + DCHECK((*params_.inputs)[input_index].is_ref()); + set_output_ref(output_index, (*params_.inputs)[input_index].mutex_if_ref, + (*params_.inputs)[input_index].tensor); +} + +inline void OpKernelContext::delete_ref_input(int index, bool lock_held) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + DCHECK((*params_.inputs)[index].is_ref()); + // should only modify the tensor while holding the mutex + if (lock_held) { + delete (*params_.inputs)[index].tensor; + } else { + mutex_lock l(*input_ref_mutex(index)); + delete (*params_.inputs)[index].tensor; + } +} + +// no input if tensor == nullptr. 
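+// For example, a Merge-style kernel might probe for its live input like +// so (an illustrative sketch only): +// for (int i = 0; i < context->num_inputs(); ++i) { +// if (context->has_input(i)) { /* input i is present; forward it */ } +// }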
+inline bool OpKernelContext::has_input(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + return (*params_.inputs)[index].tensor != nullptr; +} + +inline mutex* OpKernelContext::input_ref_mutex(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.inputs->size()); + DCHECK((*params_.inputs)[index].is_ref()); + return (*params_.inputs)[index].mutex_if_ref; +} + +inline Status OpKernelContext::allocate_output(int index, + const TensorShape& shape, + Tensor** output) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + DCHECK(params_.output_alloc_attr); + AllocatorAttributes attr = params_.output_alloc_attr(index); + return allocate_output(index, shape, output, attr); +} + +inline Status OpKernelContext::allocate_tensor(DataType type, + const TensorShape& shape, + Tensor* out_tensor, + AllocatorAttributes attr) { + Allocator* a = get_allocator(attr); + Tensor new_tensor(a, type, shape); + + if (!new_tensor.IsInitialized() && shape.num_elements() > 0) { + return errors::ResourceExhausted("OOM when allocating tensor with shape", + shape.DebugString()); + } + *out_tensor = new_tensor; + return Status::OK(); +} + +inline Status OpKernelContext::allocate_output(int index, + const TensorShape& shape, + Tensor** output, + AllocatorAttributes attr) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + // Record the fact that this output tensor was allocated by the Op + DCHECK_LT(index, output_allocation_types_.size()); + output_allocation_types_[index] = AT_ALLOCATED; + const DataType type = params_.op_kernel->output_type(index); + DCHECK(!IsRefType(type)); + DCHECK(mutable_output(index) == nullptr); + Tensor* output_tensor = new Tensor(); + Status s = allocate_tensor(type, shape, output_tensor, attr); + if (s.ok()) { + outputs_[index] = TensorValue(output_tensor); + *output = outputs_[index].tensor; + } + return s; +} + +inline Status OpKernelContext::allocate_temp(DataType type, + const TensorShape& shape, + Tensor* out_temp, + AllocatorAttributes attr) { + Status s = allocate_tensor(type, shape, out_temp, attr); + if (s.ok()) { + if (params_.device->SaveTemporaryTensors()) { + // keep a reference to the underlying memory around + temp_tensors_.push_back(new Tensor(*out_temp)); + } + } + return s; +} + +inline Status OpKernelContext::allocate_persistent( + DataType type, const TensorShape& shape, PersistentTensor* out_persistent, + Tensor** out_tensor, AllocatorAttributes attr) { + // TODO(misard) add specific memory tracking for persistent tensors + Tensor persistent; + Status s = allocate_tensor(type, shape, &persistent, attr); + if (s.ok()) { + *out_persistent = PersistentTensor(persistent); + // This call saves a reference to the newly-allocated tensor if we + // are saving temporary tensors + Tensor* allocated = out_persistent->AccessTensor(this); + if (out_tensor) { + *out_tensor = allocated; + } + } + return s; +} + +inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) { + if (t.IsInitialized() && params_.device->SaveTemporaryTensors()) { + // keep a reference to the underlying memory around + temp_tensors_.push_back(new Tensor(t)); + } +} + +inline int OpKernelContext::num_temps() { return temp_tensors_.size(); } + +inline Tensor* OpKernelContext::temp(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, temp_tensors_.size()); + return temp_tensors_[index]; +} + +inline void OpKernelContext::set_output(int index, const Tensor& tensor) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + // 
Record the fact that this output tensor was set by the Op + DCHECK_LT(index, output_allocation_types_.size()); + output_allocation_types_[index] = AT_EXISTING; + DCHECK(!IsRefType(params_.op_kernel->output_type(index))); + DCHECK_EQ(mutable_output(index), nullptr); + outputs_[index] = TensorValue(new Tensor(tensor)); +} + +inline void OpKernelContext::set_output_ref(int index, mutex* mu, + Tensor* tensor_for_ref) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + // Record the fact that this output tensor was set by reference by the Op + DCHECK_LT(index, output_allocation_types_.size()); + output_allocation_types_[index] = AT_REF; + DCHECK(IsRefType(params_.op_kernel->output_type(index))); + outputs_[index] = TensorValue(mu, tensor_for_ref); +} + +inline Tensor* OpKernelContext::mutable_output(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + return outputs_[index].tensor; +} + +inline TensorValue OpKernelContext::release_output(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, outputs_.size()); + TensorValue value = outputs_[index]; + outputs_[index] = TensorValue(); + return value; +} + +template <typename T> +T* OpKernelContext::op_device_context() { + static_assert(std::is_base_of<DeviceContext, T>::value, + "T is not a subclass of DeviceContext"); + return static_cast<T*>(op_device_context()); +} + +template <typename T> +T* OpKernelContext::input_device_context(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.input_device_contexts->size()); + static_assert(std::is_base_of<DeviceContext, T>::value, + "T is not a subclass of DeviceContext"); + return static_cast<T*>((*params_.input_device_contexts)[index]); +} + +inline DeviceContext* OpKernelContext::input_device_context(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_.input_device_contexts->size()); + return (*params_.input_device_contexts)[index]; +} + +inline const Tensor& OpInputList::operator[](int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->input(start_ + i); +} + +inline mutex* OpMutableInputList::ref_mutex(int i) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->input_ref_mutex(start_ + i); +} + +inline Tensor OpMutableInputList::at(int i, bool lock_held) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->mutable_input(start_ + i, lock_held); +} + +inline Tensor* OpOutputList::operator[](int i) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->mutable_output(start_ + i); +} + +inline bool OpOutputList::required(int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->output_required(start_ + i); +} + +inline Status OpOutputList::allocate(int i, const TensorShape& shape, + Tensor** output) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->allocate_output(start_ + i, shape, output); +} + +inline void OpOutputList::set(int i, const Tensor& tensor) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + ctx_->set_output(start_ + i, tensor); +} + +inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + ctx_->set_output_ref(start_ + i, mu, tensor_for_ref); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_ diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc new file mode 100644 index 0000000000..9400ef24f8 --- /dev/null +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -0,0 +1,803 @@ +#include 
"tensorflow/core/framework/op_kernel.h" + +#include <memory> +#include <vector> +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include <gtest/gtest.h> + +class DummyKernel : public tensorflow::OpKernel { + public: + explicit DummyKernel(tensorflow::OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(tensorflow::OpKernelContext* context) override {} +}; + +// Test that registration works outside a namespace. +REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8"); +REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU), + DummyKernel); + +namespace foo { +bool match_signature_ = false; + +// Test that registration works inside a different namespace. +class TestOp2 : public ::tensorflow::OpKernel { + public: + explicit TestOp2(::tensorflow::OpKernelConstruction* context) + : OpKernel(context) { + ::tensorflow::Status status = context->MatchSignature( + {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32}); + match_signature_ = status.ok(); + context->SetStatus(status); + } + void Compute(::tensorflow::OpKernelContext* context) override {} +}; + +REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("Test2") + .Device(::tensorflow::DEVICE_GPU) + .HostMemory("i") + .HostMemory("o"), + TestOp2); +} // namespace foo + +namespace tensorflow { + +// Two operations with the same name but different devices. 
+REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type"); + +class TestOp3Cpu : public tensorflow::OpKernel { + public: + explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER( + Name("Test3").Device(DEVICE_CPU).TypeConstraint<int8>("T"), TestOp3Cpu); + +namespace { + +class TestOp3Gpu : public tensorflow::OpKernel { + public: + explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER( + Name("Test3").Device(DEVICE_GPU).TypeConstraint<float>("T"), TestOp3Cpu); + +// An Op registered for both +REGISTER_OP("Test4").Input("i: float").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel); + +static std::vector<DeviceType> DeviceTypes() { + return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}; +} + +class OpKernelTest : public ::testing::Test { + public: + OpKernelTest() : device_(Env::Default()) {} + + protected: + NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs) { + NodeDefBuilder builder(op_type + "-op", op_type); + for (DataType dt : inputs) { + builder.Input(FakeInput(dt)); + } + NodeDef node_def; + TF_CHECK_OK(builder.Finalize(&node_def)); + return node_def; + } + + void ExpectEqual(const string& what, const DataTypeVector& expected, + const DataTypeVector& observed) { + EXPECT_EQ(expected.size(), observed.size()) << what; + const int size = std::min(expected.size(), observed.size()); + for (int i = 0; i < size; ++i) { + bool match = TypesCompatible(expected[i], observed[i]); + EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i] + << ", observed: " << observed[i]; + } + } + + void ExpectSuccess(const string& op_type, DeviceType device_type, + const DataTypeVector& inputs, + const DataTypeVector& outputs) { + Status status; + std::unique_ptr<OpKernel> op( + CreateOpKernel(device_type, &device_, cpu_allocator(), + CreateNodeDef(op_type, inputs), &status)); + EXPECT_TRUE(status.ok()) << status; + EXPECT_TRUE(op != nullptr); + if (op != nullptr) { + ExpectEqual("inputs", op->input_types(), inputs); + ExpectEqual("outputs", op->output_types(), outputs); + } + } + + void ExpectFailure(const string& ascii_node_def, DeviceType device_type, + error::Code code) { + NodeDef node_def; + protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def); + Status status; + std::unique_ptr<OpKernel> op(CreateOpKernel( + device_type, &device_, cpu_allocator(), node_def, &status)); + EXPECT_TRUE(op == nullptr); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + LOG(INFO) << "Status message: " << status.error_message(); + EXPECT_EQ(code, status.code()); + } + } + + private: + DeviceBase device_; +}; + +TEST_F(OpKernelTest, SuccessCpu) { + ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, {DT_UINT8}); + ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32}, {DT_UINT8}); +} + +TEST_F(OpKernelTest, SuccessGpu) { + foo::match_signature_ = false; + ExpectSuccess("Test2", DEVICE_GPU, {DT_INT32}, {DT_INT32}); + EXPECT_TRUE(foo::match_signature_); +} + +TEST_F(OpKernelTest, SuccessBothCpuAndGpu) { + ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, {}); + ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, {}); +} + +TEST_F(OpKernelTest, CpuTypeRegistered) { + NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + 
DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); +} + +TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) { + { + // Try a node def of an op that is registered for a specific type + // only on CPU. + NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); + } + { + // Try a node def of an op that is registered for a specific type + // only on GPU. + NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); + } + { + // Try a node def of an op that is only registered for other types. + NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(0, devs.size()); + } + + { + // Try a node def of an op that is registered for both. + NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(2, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]); + } +} + +TEST_F(OpKernelTest, NotFound) { + const auto not_found = error::NOT_FOUND; + // Something with that op type name exists, but only with a + // different DeviceType. + ExpectFailure(CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}).DebugString(), + DEVICE_GPU, not_found); + ExpectFailure(CreateNodeDef("Test3", {DT_INT8, DT_INT8}).DebugString(), + DEVICE_GPU, not_found); + ExpectFailure(CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}).DebugString(), + DEVICE_CPU, not_found); + + // No kernel with that signature registered. + ExpectFailure(CreateNodeDef("Test3", {DT_INT32, DT_INT32}).DebugString(), + DEVICE_GPU, not_found); + + // Nothing with that op type name exists. 
+ ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_CPU, not_found); + ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_GPU, not_found); +} + +TEST_F(OpKernelTest, TooFewInputs) { + const auto invalid = error::INVALID_ARGUMENT; + NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + node_def.clear_input(); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); + node_def.add_input("a"); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); +} + +TEST_F(OpKernelTest, TooManyInputs) { + const auto invalid = error::INVALID_ARGUMENT; + NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + node_def.add_input("c"); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); +} + +TEST_F(OpKernelTest, MatchSignatureFails) { + const auto invalid = error::INVALID_ARGUMENT; + foo::match_signature_ = true; + ExpectFailure(CreateNodeDef("Test2", {DT_FLOAT}).DebugString(), DEVICE_GPU, + invalid); + EXPECT_FALSE(foo::match_signature_); +} + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool SaveTemporaryTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +TEST_F(OpKernelTest, SaveTempFalse) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.device = new DummyDevice(env, false); + Status status; + std::unique_ptr<OpKernel> op( + CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), &status)); + EXPECT_TRUE(status.ok()); + params.op_kernel = op.get(); + OpKernelContext* ctx = new OpKernelContext(params); + + Tensor t; + EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t)); + + EXPECT_EQ(0, ctx->num_temps()); + + delete ctx; + delete params.device; +} + +TEST_F(OpKernelTest, SaveTempTrue) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.device = new DummyDevice(env, true); + Status status; + std::unique_ptr<OpKernel> op( + CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), &status)); + EXPECT_TRUE(status.ok()); + params.op_kernel = op.get(); + OpKernelContext* ctx = new OpKernelContext(params); + + Tensor t; + EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t)); + + EXPECT_EQ(1, ctx->num_temps()); + + delete ctx; + delete params.device; +} + +class OpKernelBuilderTest : public ::testing::Test { + protected: + // Each attr is described by a "name|type|value".
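+ // For example, "T|type|DT_FLOAT" sets attr "T" (of type "type") to + // DT_FLOAT.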
+ NodeDef CreateNodeDef(const string& op_type, + const std::vector<string>& attrs) { + NodeDef node_def; + node_def.set_name(op_type + "-op"); + node_def.set_op(op_type); + for (const string& attr_desc : attrs) { + std::vector<string> parts = str_util::Split(attr_desc, '|'); + CHECK_EQ(parts.size(), 3); + AttrValue attr_value; + CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc; + node_def.mutable_attr()->insert( + AttrValueMap::value_type(parts[0], attr_value)); + } + return node_def; + } + + std::unique_ptr<OpKernel> ExpectSuccess(const string& op_type, + DeviceType device_type, + const std::vector<string>& attrs, + DataTypeSlice input_types = {}) { + Status status; + NodeDef def = CreateNodeDef(op_type, attrs); + for (size_t i = 0; i < input_types.size(); ++i) { + def.add_input("a:0"); + } + + Env* env = Env::Default(); + DeviceBase device(env); + + // Test CreateOpKernel() + std::unique_ptr<OpKernel> op( + CreateOpKernel(device_type, &device, cpu_allocator(), def, &status)); + EXPECT_TRUE(status.ok()) << status; + EXPECT_TRUE(op != nullptr); + if (op != nullptr) { + EXPECT_EQ(input_types.size(), op->num_inputs()); + EXPECT_EQ(0, op->num_outputs()); + } + + // Test SupportedDeviceTypesForNode() + DeviceTypeVector devices; + EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); + bool found = false; + for (DeviceType dt : devices) { + if (dt == device_type) { + found = true; + } + } + EXPECT_TRUE(found) << "Missing " << device_type << " from " + << devices.size() << " devices."; + + // In case the caller wants to use the OpKernel + return op; + } + + void ExpectFailure(const string& op_type, DeviceType device_type, + const std::vector<string>& attrs, error::Code code) { + Status status; + const NodeDef def = CreateNodeDef(op_type, attrs); + Env* env = Env::Default(); + DeviceBase device(env); + + // Test CreateOpKernel(). + std::unique_ptr<OpKernel> op( + CreateOpKernel(device_type, &device, cpu_allocator(), def, &status)); + EXPECT_TRUE(op == nullptr); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + LOG(INFO) << "Status message: " << status.error_message(); + EXPECT_EQ(code, status.code()); + + // Test SupportedDeviceTypesForNode(). 
+ DeviceTypeVector devices; + if (errors::IsNotFound(status)) { + EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); + for (DeviceType dt : devices) { + EXPECT_NE(dt, device_type); + } + } else { + Status status2 = + SupportedDeviceTypesForNode(DeviceTypes(), def, &devices); + EXPECT_EQ(status.code(), status2.code()); + } + } + } +}; + +REGISTER_OP("BuildCPU"); +REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderCPU) { + ExpectSuccess("BuildCPU", DEVICE_CPU, {}); + ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND); +} + +REGISTER_OP("BuildGPU"); +REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderGPU) { + ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND); + ExpectSuccess("BuildGPU", DEVICE_GPU, {}); +} + +REGISTER_OP("BuildBoth"); +REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderBoth) { + ExpectSuccess("BuildBoth", DEVICE_CPU, {}); + ExpectSuccess("BuildBoth", DEVICE_GPU, {}); +} + +REGISTER_OP("BuildTypeAttr").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("BuildTypeAttr") + .Device(DEVICE_CPU) + .TypeConstraint<float>("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderTypeAttr) { + ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"}); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"}, + error::NOT_FOUND); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"}, + error::INVALID_ARGUMENT); +} + +REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)"); +REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr") + .Device(DEVICE_CPU) + .TypeConstraint<bool>("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) { + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"}); + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"}); + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, + {"T|list(type)|[DT_BOOL, DT_BOOL]"}); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"}, + error::NOT_FOUND); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"}, + error::INVALID_ARGUMENT); +} + +REGISTER_OP("DuplicateKernel"); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU), + DummyKernel); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU), + DummyKernel); + +TEST_F(OpKernelBuilderTest, DuplicateKernel) { + const NodeDef ndef = CreateNodeDef("DuplicateKernel", {}); + DeviceTypeVector devs; + Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Multiple OpKernel registrations match NodeDef")); + + ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT); +} + +REGISTER_OP("DuplicateKernelForT").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT") + .Device(DEVICE_CPU) + .TypeConstraint<float>("T"), + DummyKernel); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT") + .Device(DEVICE_CPU) + .TypeConstraint<float>("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, DuplicateKernelForT) { + const NodeDef ndef = + CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"}); + 
DeviceTypeVector devs; + Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Multiple OpKernel registrations match NodeDef")); + + ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"}, + error::INVALID_ARGUMENT); + ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"}, + error::NOT_FOUND); +} + +REGISTER_OP("BadConstraint").Attr("dtype: type"); +REGISTER_KERNEL_BUILDER(Name("BadConstraint") + .Device(DEVICE_CPU) + // Mistake: "T" should be "dtype". + .TypeConstraint<float>("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BadConstraint) { + const NodeDef ndef = CreateNodeDef("BadConstraint", {}); + DeviceTypeVector devs; + Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("OpKernel 'BadConstraint' has constraint on attr " + "'T' not in NodeDef")); + + ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"}, + error::INVALID_ARGUMENT); +} + +class GetAttrKernel : public ::tensorflow::OpKernel { + public: + explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) { + string attr_name; + OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name)); + + status.emplace_back("s", context->GetAttr(attr_name, &s)); + status.emplace_back("s_list", context->GetAttr(attr_name, &s_list)); + status.emplace_back("i", context->GetAttr(attr_name, &i)); + status.emplace_back("i_list", context->GetAttr(attr_name, &i_list)); + status.emplace_back("i32", context->GetAttr(attr_name, &i32)); + status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list)); + status.emplace_back("f", context->GetAttr(attr_name, &f)); + status.emplace_back("f_list", context->GetAttr(attr_name, &f_list)); + status.emplace_back("b", context->GetAttr(attr_name, &b)); + status.emplace_back("b_list", context->GetAttr(attr_name, &b_list)); + status.emplace_back("type", context->GetAttr(attr_name, &type)); + status.emplace_back("type_list", context->GetAttr(attr_name, &type_list)); + status.emplace_back("type_vector", + context->GetAttr(attr_name, &type_vector)); + status.emplace_back("shape_proto", + context->GetAttr(attr_name, &shape_proto)); + status.emplace_back("shape_proto_list", + context->GetAttr(attr_name, &shape_proto_list)); + status.emplace_back("shape", context->GetAttr(attr_name, &shape)); + status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list)); + } + void Compute(::tensorflow::OpKernelContext* context) override {} + + void ExpectOk(std::initializer_list<string> keys) { + for (const auto& key_status : status) { + // Only the status for keys in "keys" should be ok(). 
+ bool in_keys = false; + for (const string& key : keys) { + if (key_status.first == key) { + in_keys = true; + } + } + EXPECT_EQ(in_keys, key_status.second.ok()) + << "key_status: " << key_status.first << ", " << key_status.second; + } + } + + string s; + std::vector<string> s_list; + int64 i; + std::vector<int64> i_list; + int32 i32; + std::vector<int32> i32_list; + float f; + std::vector<float> f_list; + bool b; + std::vector<bool> b_list; + DataType type; + std::vector<DataType> type_list; + DataTypeVector type_vector; + TensorShapeProto shape_proto; + std::vector<TensorShapeProto> shape_proto_list; + TensorShape shape; + std::vector<TensorShape> shape_list; + std::vector<std::pair<string, Status>> status; +}; + +class GetAttrTest : public OpKernelBuilderTest {}; + +REGISTER_OP("GetAttrStringList") + .Attr("attr_name: string") + .Attr("a: list(string)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU), + GetAttrKernel); + +TEST_F(GetAttrTest, StringList) { + std::unique_ptr<OpKernel> op_kernel = + ExpectSuccess("GetAttrStringList", DEVICE_CPU, + {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"}); + auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"s_list"}); + EXPECT_EQ(std::vector<string>({"foo", "bar"}), get_attr_kernel->s_list); + + op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU, + {"attr_name|string|'b'", "a|list(string)|['baz']"}); + get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({}); + EXPECT_TRUE(get_attr_kernel->s_list.empty()); +} + +REGISTER_OP("GetAttrInt") + .Attr("attr_name: string") + .Attr("a: int") + .Attr("b: list(int)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Int) { + std::unique_ptr<OpKernel> op_kernel = ExpectSuccess( + "GetAttrInt", DEVICE_CPU, + {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"}); + auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"i", "i32"}); + EXPECT_EQ(35, get_attr_kernel->i); + EXPECT_EQ(35, get_attr_kernel->i32); + + op_kernel = ExpectSuccess( + "GetAttrInt", DEVICE_CPU, + {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"}); + get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"i_list", "i32_list"}); + EXPECT_EQ(std::vector<int64>({-1, 2, -4}), get_attr_kernel->i_list); + EXPECT_EQ(std::vector<int32>({-1, 2, -4}), get_attr_kernel->i32_list); + + // 8589934592 == 2^33, too big to fit in an int32 + op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU, + {"attr_name|string|'a'", "a|int|8589934592", + "b|list(int)|[-8589934592]"}); + get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"i"}); // no i32 + EXPECT_EQ(8589934592ll, get_attr_kernel->i); + for (const auto& key_status : get_attr_kernel->status) { + if (key_status.first == "i32") { + EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code()); + EXPECT_EQ("Attr a has value 8589934592 out of range for an int32", + key_status.second.error_message()); + } + } + + op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU, + {"attr_name|string|'b'", "a|int|8589934592", + "b|list(int)|[-8589934592]"}); + get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"i_list"}); // no i32_list + EXPECT_EQ(std::vector<int64>({-8589934592ll}), get_attr_kernel->i_list); + for (const auto& key_status : 
get_attr_kernel->status) { + if (key_status.first == "i32_list") { + EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code()); + EXPECT_EQ("Attr b has value -8589934592 out of range for an int32", + key_status.second.error_message()); + } + } +} + +REGISTER_OP("GetAttrShape") + .Attr("attr_name: string") + .Attr("a: shape") + .Attr("b: list(shape)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Shape) { + std::unique_ptr<OpKernel> op_kernel = ExpectSuccess( + "GetAttrShape", DEVICE_CPU, + {"attr_name|string|'a'", "a|shape|{ dim { size: 3 } }", + "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"}); + auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"shape", "shape_proto"}); + EXPECT_EQ(get_attr_kernel->shape_proto.ShortDebugString(), "dim { size: 3 }"); + EXPECT_EQ("[3]", get_attr_kernel->shape.ShortDebugString()); + + op_kernel = ExpectSuccess( + "GetAttrShape", DEVICE_CPU, + {"attr_name|string|'b'", "a|shape|{ dim { size: 3 } }", + "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"}); + get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"shape_list", "shape_proto_list"}); + ASSERT_EQ(2, get_attr_kernel->shape_proto_list.size()); + EXPECT_EQ(get_attr_kernel->shape_proto_list[0].ShortDebugString(), + "dim { size: 2 }"); + EXPECT_EQ(get_attr_kernel->shape_proto_list[1].ShortDebugString(), + "dim { size: 4 }"); + ASSERT_EQ(2, get_attr_kernel->shape_list.size()); + EXPECT_EQ("[2]", get_attr_kernel->shape_list[0].ShortDebugString()); + EXPECT_EQ("[4]", get_attr_kernel->shape_list[1].ShortDebugString()); +} + +REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type"); +REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Type) { + std::unique_ptr<OpKernel> op_kernel = ExpectSuccess( + "GetAttrType", DEVICE_CPU, {"attr_name|string|'a'", "a|type|DT_FLOAT"}); + auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"type"}); + EXPECT_EQ(DT_FLOAT, get_attr_kernel->type); +} + +REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU), + GetAttrKernel); + +TEST_F(GetAttrTest, TypeList) { + std::unique_ptr<OpKernel> op_kernel = ExpectSuccess( + "GetAttrTypeList", DEVICE_CPU, + {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"}); + auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + + get_attr_kernel->ExpectOk({"type_list", "type_vector"}); + ASSERT_EQ(2, get_attr_kernel->type_list.size()); + EXPECT_EQ(DT_INT32, get_attr_kernel->type_list[0]); + EXPECT_EQ(DT_BOOL, get_attr_kernel->type_list[1]); + ASSERT_EQ(2, get_attr_kernel->type_vector.size()); + EXPECT_EQ(DT_INT32, get_attr_kernel->type_vector[0]); + EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]); +} + +REGISTER_OP("HostMemoryTest") + .Input("a: float") + .Input("b: T") + .Input("c: N * string") + .Output("o: N * T") + .Attr("T: type") + .Attr("N: int"); +REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("HostMemoryTest") + .Device(DEVICE_GPU) + .HostMemory("a") + .HostMemory("c") + .HostMemory("o"), + DummyKernel); + +TEST(MemoryTypesForNode, Simple) { + NodeDef node_def; + ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest") + .Input(FakeInput()) + .Input(FakeInput(DT_BOOL)) + 
.Input(FakeInput(3)) + .Finalize(&node_def)); + MemoryTypeVector input, output; + + EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def, + &input, &output)); + EXPECT_EQ(MemoryTypeVector(5, DEVICE_MEMORY), input); + EXPECT_EQ(MemoryTypeVector(3, DEVICE_MEMORY), output); + + EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def, + &input, &output)); + EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, + HOST_MEMORY, HOST_MEMORY}), + input); + EXPECT_EQ(MemoryTypeVector(3, HOST_MEMORY), output); +} + +class BaseKernel : public ::tensorflow::OpKernel { + public: + explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(::tensorflow::OpKernelContext* context) override {} + virtual int Which() const = 0; +}; + +template <int WHICH> +class LabeledKernel : public BaseKernel { + public: + using BaseKernel::BaseKernel; + int Which() const override { return WHICH; } +}; + +class LabelTest : public OpKernelBuilderTest {}; + +REGISTER_OP("LabeledKernel"); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU), + LabeledKernel<0>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"), + LabeledKernel<1>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"), + LabeledKernel<2>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"), + LabeledKernel<3>); + +TEST_F(LabelTest, Default) { + std::unique_ptr<OpKernel> op_kernel = + ExpectSuccess("LabeledKernel", DEVICE_CPU, {}); + auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get()); + EXPECT_EQ(0, get_labeled_kernel->Which()); +} + +TEST_F(LabelTest, Specified) { + std::unique_ptr<OpKernel> op_kernel = + ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"}); + auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get()); + EXPECT_EQ(1, get_labeled_kernel->Which()); +} + +TEST_F(LabelTest, Duplicate) { + ExpectFailure("LabeledKernel", DEVICE_CPU, {"_kernel|string|'dupe'"}, + error::INVALID_ARGUMENT); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc new file mode 100644 index 0000000000..a39bebd854 --- /dev/null +++ b/tensorflow/core/framework/op_segment.cc @@ -0,0 +1,86 @@ +#include "tensorflow/core/framework/op_segment.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" + +namespace tensorflow { + +OpSegment::Item::~Item() { + for (auto kv : name_kernel) delete kv.second; +} + +OpSegment::OpSegment() {} + +OpSegment::~OpSegment() { + for (auto kv : sessions_) delete kv.second; +} + +Status OpSegment::FindOrCreate(const string& session_handle, + const string& node_name, OpKernel** kernel, + CreateKernelFn create_fn) { + { + mutex_lock l(mu_); + auto item = gtl::FindPtrOrNull(sessions_, session_handle); + if (item == nullptr) { + return errors::NotFound("Session ", session_handle, " is not found."); + } + *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name); + if (*kernel != nullptr) { + return Status::OK(); + } + } + Status s = create_fn(kernel); + if (!s.ok()) { + LOG(ERROR) << "Create kernel failed: " << s; + return s; + } + { + mutex_lock l(mu_); + auto item = gtl::FindPtrOrNull(sessions_, session_handle); + if (item == nullptr) { + return 
errors::NotFound("Session ", session_handle, " is not found."); + } + OpKernel** p_kernel = &(item->name_kernel[node_name]); + if (*p_kernel == nullptr) { + *p_kernel = *kernel; // Inserts 'kernel' in the map. + } else { + delete *kernel; + *kernel = *p_kernel; + } + } + return Status::OK(); +} + +void OpSegment::AddHold(const string& session_handle) { + mutex_lock l(mu_); + Item** item = &sessions_[session_handle]; + if (*item == nullptr) { + *item = new Item; // num_holds == 1 + } else { + ++((*item)->num_holds); + } +} + +void OpSegment::RemoveHold(const string& session_handle) { + Item* item = nullptr; + { + mutex_lock l(mu_); + auto siter = sessions_.find(session_handle); + if (siter == sessions_.end()) { + VLOG(1) << "Session " << session_handle << " is not found."; + return; + } + item = siter->second; + if (--(item->num_holds) > 0) { + return; + } else { + sessions_.erase(siter); + } + } + delete item; +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h new file mode 100644 index 0000000000..55249d2a38 --- /dev/null +++ b/tensorflow/core/framework/op_segment.h @@ -0,0 +1,67 @@ +#ifndef TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ +#define TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ + +#include <string> +#include <unordered_map> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// OpSegment keeps track of OpKernels registered for sessions running +// on a device. +// +// The implementation maintains a two-level map. The 1st level maps +// session handle to the map of registered OpKernels. The 2nd level +// map maps node names to instantiated OpKernel objects. +// +// Each 2-nd level map is reference-counted and the caller can call +// AddHold to obtain a reference on all kernels of a session and +// ensure these kernels are alive until a corresponding RemoveHold is +// called on the same session. +class OpSegment { + public: + OpSegment(); + ~OpSegment(); + + // A hold can be placed on a session, preventing all its kernels + // from being deleted. + void AddHold(const string& session_handle); + void RemoveHold(const string& session_handle); + + // If the kernel for "node_name" has been created in the + // "session_handle", returns the existing op kernel in "*kernel". + // Otherwise, creates the kernel by calling create_fn(), cache it, + // and returns it in "*kernel". If create_fn() fails, returns the + // error. + // + // OpSegment keeps the ownership of the returned "*kernel". + typedef std::function<Status(OpKernel**)> CreateKernelFn; + Status FindOrCreate(const string& session_handle, const string& node_name, + OpKernel** kernel, CreateKernelFn create_fn); + + private: + // op name -> OpKernel + typedef std::unordered_map<string, OpKernel*> KernelMap; + struct Item { + int num_holds = 1; // Num of holds put on the session. + KernelMap name_kernel; // op name -> kernel. + ~Item(); + }; + + // session handle -> item. 
+ // Session handles are produced by strings::FpToString() + typedef std::unordered_map<string, Item*> SessionMap; + + mutable mutex mu_; + SessionMap sessions_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(OpSegment); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_ diff --git a/tensorflow/core/framework/op_segment_test.cc b/tensorflow/core/framework/op_segment_test.cc new file mode 100644 index 0000000000..6297718df8 --- /dev/null +++ b/tensorflow/core/framework/op_segment_test.cc @@ -0,0 +1,142 @@ +#include "tensorflow/core/framework/op_segment.h" + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include <gtest/gtest.h> +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { + +class OpSegmentTest : public ::testing::Test { + protected: + DeviceBase device_; + std::vector<NodeDef> int32_nodedefs_; + std::vector<NodeDef> float_nodedefs_; + + OpSegmentTest() : device_(Env::Default()) { + RequireDefaultOps(); + for (int i = 0; i < 10; ++i) { + NodeDef def; + TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul") + .Input("x", 0, DT_INT32) + .Input("y", 0, DT_INT32) + .Finalize(&def)); + int32_nodedefs_.push_back(def); + TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul") + .Input("x", 0, DT_FLOAT) + .Input("y", 0, DT_FLOAT) + .Finalize(&def)); + float_nodedefs_.push_back(def); + } + } + + void ValidateOpAndTypes(OpKernel* op, const NodeDef& expected, DataType dt) { + ASSERT_NE(op, nullptr); + EXPECT_EQ(expected.DebugString(), op->def().DebugString()); + EXPECT_EQ(2, op->num_inputs()); + EXPECT_EQ(dt, op->input_type(0)); + EXPECT_EQ(dt, op->input_type(1)); + EXPECT_EQ(1, op->num_outputs()); + EXPECT_EQ(dt, op->output_type(0)); + } + + OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) { + return [this, ndef](OpKernel** kernel) { + Status s; + auto created = + CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(), *ndef, &s); + if (s.ok()) { + *kernel = created.release(); + } + return s; + }; + } +}; + +TEST_F(OpSegmentTest, Basic) { + OpSegment opseg; + OpKernel* op; + + opseg.AddHold("A"); + opseg.AddHold("B"); + for (int i = 0; i < 10; ++i) { + // Register in session A. + auto* ndef = &float_nodedefs_[i]; + EXPECT_OK(opseg.FindOrCreate("A", ndef->name(), &op, GetFn(ndef))); + ValidateOpAndTypes(op, *ndef, DT_FLOAT); + + // Register in session B. + ndef = &int32_nodedefs_[i]; + EXPECT_OK(opseg.FindOrCreate("B", ndef->name(), &op, GetFn(ndef))); + ValidateOpAndTypes(op, *ndef, DT_INT32); + } + + auto reterr = [](OpKernel** kernel) { + return errors::Internal("Should not be called"); + }; + for (int i = 0; i < 10; ++i) { + // Lookup op in session A. + EXPECT_OK(opseg.FindOrCreate("A", strings::StrCat("op", i), &op, reterr)); + ValidateOpAndTypes(op, float_nodedefs_[i], DT_FLOAT); + + // Lookup op in session B. 
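+    // The kernel must come from the cache; "reterr" must never run.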
+ EXPECT_OK(opseg.FindOrCreate("B", strings::StrCat("op", i), &op, reterr)); + ValidateOpAndTypes(op, int32_nodedefs_[i], DT_INT32); + } + + opseg.RemoveHold("A"); + opseg.RemoveHold("B"); +} + +TEST_F(OpSegmentTest, SessionNotFound) { + OpSegment opseg; + OpKernel* op; + NodeDef def = float_nodedefs_[0]; + Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); + EXPECT_TRUE(errors::IsNotFound(s)) << s; +} + +TEST_F(OpSegmentTest, CreateFailure) { + OpSegment opseg; + OpKernel* op; + NodeDef def = float_nodedefs_[0]; + def.set_op("nonexistop"); + opseg.AddHold("A"); + Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); + EXPECT_TRUE(errors::IsNotFound(s)) << s; + opseg.RemoveHold("A"); +} + +TEST_F(OpSegmentTest, AddRemoveHolds) { + OpSegment opseg; + OpKernel* op; + const auto& ndef = int32_nodedefs_[0]; + + // No op. + opseg.RemoveHold("null"); + + // Thread1 register the op and wants to ensure it alive. + opseg.AddHold("foo"); + EXPECT_OK(opseg.FindOrCreate("foo", ndef.name(), &op, GetFn(&ndef))); + + // Thread2 starts some execution needs "op" to be alive. + opseg.AddHold("foo"); + + // Thread1 clears session "foo". E.g., a master sends CleanupGraph + // before an execution finishes. + opseg.RemoveHold("foo"); + + // Thread2 should still be able to access "op". + ValidateOpAndTypes(op, ndef, DT_INT32); + + // Thread2 then remove its hold on "foo". + opseg.RemoveHold("foo"); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/queue_interface.h b/tensorflow/core/framework/queue_interface.h new file mode 100644 index 0000000000..a765c211cb --- /dev/null +++ b/tensorflow/core/framework/queue_interface.h @@ -0,0 +1,77 @@ +#ifndef TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_ +#define TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_ + +#include <string> +#include <vector> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { + +// All implementations must be thread-safe. +class QueueInterface : public ResourceBase { + public: + typedef std::vector<Tensor> Tuple; + typedef AsyncOpKernel::DoneCallback DoneCallback; + typedef std::function<void(const Tuple&)> CallbackWithTuple; + + virtual Status ValidateTuple(const Tuple& tuple) = 0; + virtual Status ValidateManyTuple(const Tuple& tuple) = 0; + + // Stashes a function object for future execution, that will eventually + // enqueue the tuple of tensors into the queue, and returns immediately. The + // function object is guaranteed to call 'callback'. + virtual void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) = 0; + + // Same as above, but the component tensors are sliced along the 0th dimension + // to make multiple queue-element components. + virtual void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) = 0; + + // Stashes a function object for future execution, that will eventually + // dequeue an element from the queue and call 'callback' with that tuple + // element as argument. + virtual void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) = 0; + + // Same as above, but the stashed function object will attempt to dequeue + // num_elements items. 
+  virtual void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+                              CallbackWithTuple callback) = 0;
+
+  // Signals that no more elements will be enqueued, and optionally
+  // cancels pending Enqueue(Many) operations.
+  //
+  // After calling this function, subsequent calls to Enqueue(Many)
+  // will fail. If `cancel_pending_enqueues` is true, all pending
+  // calls to Enqueue(Many) will fail as well.
+  //
+  // After calling this function, all current and subsequent calls to
+  // Dequeue(Many) will fail instead of blocking (though they may
+  // succeed if they can be satisfied by the elements in the queue at
+  // the time it was closed).
+  virtual void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+                     DoneCallback callback) = 0;
+
+  // Assuming *this represents a shared queue, verify that it matches
+  // another instantiation indicated by node_def.
+  virtual Status MatchesNodeDef(const NodeDef& node_def) = 0;
+
+  // Returns the number of elements in the queue.
+  virtual int32 size() = 0;
+
+  virtual const DataTypeVector& component_dtypes() const = 0;
+
+  string DebugString() override { return "A queue"; }
+
+ protected:
+  virtual ~QueueInterface() {}
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
diff --git a/tensorflow/core/framework/reader_interface.h b/tensorflow/core/framework/reader_interface.h
new file mode 100644
index 0000000000..b307c37f01
--- /dev/null
+++ b/tensorflow/core/framework/reader_interface.h
@@ -0,0 +1,66 @@
+#ifndef TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
+#define TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+class QueueInterface;
+class ReaderInterface;
+
+// Readers are the mechanism for reading records from files in
+// TensorFlow graphs. Each supported file format has a corresponding
+// ReaderInterface descendant and a corresponding Op & OpKernel
+// (implemented using ReaderOpKernel from reader_op_kernel.h).
+//
+// To use a Reader, you first encode "work" (some string, typically a
+// filename) in the Reader's "work queue". It then processes the
+// "work" (reading records from the file) to produce key/value
+// strings. The methods of this class are called by ReaderFoo ops,
+// so see ../ops/io_ops.cc for detailed descriptions.
+//
+// All descendants of this class must be thread-safe.
+//
+// See the design document here:
+// https://docs.google.com/document/d/1UAgZOoeehYr20TdzW2CoZ30V-aqQphU4SwKXsW7eJv4/edit#
+
+// TODO(josh11b): Switch this to Async.
+class ReaderInterface : public ResourceBase {
+ public:
+  // Read a single record into *key / *value. May get more work from
+  // *queue if the current work is complete. Sets the status on
+  // *context with an OutOfRange Status if the current work is
+  // complete and the queue is done (closed and empty).
+  // This method may block.
+  virtual void Read(QueueInterface* queue, string* key, string* value,
+                    OpKernelContext* context) = 0;
+
+  // Restore this reader to its newly-constructed state.
+  virtual Status Reset() = 0;
+
+  // Accessors
+  virtual int64 NumRecordsProduced() = 0;
+  virtual int64 NumWorkUnitsCompleted() = 0;
+
+  // -- Serialization/Restoration support --
+  // Not all readers will support saving and restoring state.
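+  // Implementations that do not support it may simply return an error.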
+  virtual Status SerializeState(string* state) = 0;
+  // Note: Must Reset on error.
+  virtual Status RestoreState(const string& state) = 0;
+
+  string DebugString() override { return "a reader"; }
+
+ protected:
+  virtual ~ReaderInterface() {}
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
diff --git a/tensorflow/core/framework/reader_op_kernel.cc b/tensorflow/core/framework/reader_op_kernel.cc
new file mode 100644
index 0000000000..719f27d94b
--- /dev/null
+++ b/tensorflow/core/framework/reader_op_kernel.cc
@@ -0,0 +1,39 @@
+#include "tensorflow/core/framework/reader_op_kernel.h"
+
+namespace tensorflow {
+
+ReaderOpKernel::ReaderOpKernel(OpKernelConstruction* context)
+    : OpKernel(context), have_handle_(false) {
+  OP_REQUIRES_OK(context, context->allocate_persistent(
+                              tensorflow::DT_STRING,
+                              tensorflow::TensorShape({2}), &handle_, nullptr));
+}
+
+ReaderOpKernel::~ReaderOpKernel() {
+  if (have_handle_ && cinfo_.resource_is_private_to_kernel()) {
+    TF_CHECK_OK(cinfo_.resource_manager()->Delete<ReaderInterface>(
+        cinfo_.container(), cinfo_.name()));
+  }
+}
+
+void ReaderOpKernel::Compute(OpKernelContext* ctx) {
+  mutex_lock l(mu_);
+  if (!have_handle_) {
+    OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), false));
+    ReaderInterface* reader;
+    OP_REQUIRES_OK(ctx,
+                   cinfo_.resource_manager()->LookupOrCreate<ReaderInterface>(
+                       cinfo_.container(), cinfo_.name(), &reader,
+                       [this](ReaderInterface** ret) {
+                         *ret = factory_();
+                         return Status::OK();
+                       }));
+    auto h = handle_.AccessTensor(ctx)->flat<string>();
+    h(0) = cinfo_.container();
+    h(1) = cinfo_.name();
+    have_handle_ = true;
+  }
+  ctx->set_output_ref(0, &mu_, handle_.AccessTensor(ctx));
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/framework/reader_op_kernel.h b/tensorflow/core/framework/reader_op_kernel.h
new file mode 100644
index 0000000000..8e5cc50c9b
--- /dev/null
+++ b/tensorflow/core/framework/reader_op_kernel.h
@@ -0,0 +1,42 @@
+#ifndef TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
+#define TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
+
+#include <functional>
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/reader_interface.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Implementation for ops providing a Reader.
+class ReaderOpKernel : public OpKernel {
+ public:
+  explicit ReaderOpKernel(OpKernelConstruction* context);
+  ~ReaderOpKernel() override;
+
+  void Compute(OpKernelContext* context) override;
+
+  // Must be called by descendants before the first call to Compute()
+  // (typically called during construction). The factory must return a
+  // ReaderInterface descendant allocated with new; ReaderOpKernel
+  // takes ownership of it.
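+  //
+  // A typical subclass (illustrative sketch; "MyReader" is hypothetical)
+  // calls this from its constructor:
+  //   SetReaderFactory([]() { return new MyReader; });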
+  void SetReaderFactory(std::function<ReaderInterface*()> factory) {
+    mutex_lock l(mu_);
+    DCHECK(!have_handle_);
+    factory_ = factory;
+  }
+
+ private:
+  mutex mu_;
+  bool have_handle_ GUARDED_BY(mu_);
+  PersistentTensor handle_ GUARDED_BY(mu_);
+  ContainerInfo cinfo_;
+  std::function<ReaderInterface*()> factory_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
new file mode 100644
index 0000000000..18473aea2e
--- /dev/null
+++ b/tensorflow/core/framework/register_types.h
@@ -0,0 +1,90 @@
+#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
+#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
+// This file is used by cuda code and must remain compilable by nvcc.
+
+#include "tensorflow/core/platform/port.h"
+
+// Macros to apply another macro to lists of supported types. If you change
+// the lists of types, please also update the list in types.cc.
+//
+// See example uses of these macros in core/ops.
+//
+//
+// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple
+// times by passing each invocation a data type supported by TensorFlow.
+//
+// The different variations pass different subsets of the types.
+// TF_CALL_ALL_TYPES(m) applies "m" to all types supported by TensorFlow.
+// The set of types depends on the compilation platform.
+//
+// This can be used to register a different template instantiation of
+// an OpKernel for different signatures, e.g.:
+/*
+   #define REGISTER_PARTITION(type)                                  \
+     REGISTER_TF_OP_KERNEL("partition", DEVICE_CPU, #type ", int32", \
+                           PartitionOp<type>);
+   TF_CALL_ALL_TYPES(REGISTER_PARTITION)
+   #undef REGISTER_PARTITION
+*/
+
+#ifndef __ANDROID__
+
+// Call "m" for all number types that support the comparison operations "<" and
+// ">".
+#define TF_CALL_REAL_NUMBER_TYPES(m) \
+  m(float);                          \
+  m(double);                         \
+  m(int64);                          \
+  m(int32);                          \
+  m(uint8);                          \
+  m(int16);                          \
+  m(int8)
+
+#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
+  m(float);                                   \
+  m(double);                                  \
+  m(int64);                                   \
+  m(uint8);                                   \
+  m(int16);                                   \
+  m(int8)
+
+// Call "m" for all number types, including complex64.
+#define TF_CALL_NUMBER_TYPES(m) \
+  TF_CALL_REAL_NUMBER_TYPES(m); \
+  m(complex64)
+
+#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
+  TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m); \
+  m(complex64)
+
+// Call "m" on all types.
+#define TF_CALL_ALL_TYPES(m) \
+  TF_CALL_NUMBER_TYPES(m);   \
+  m(bool);                   \
+  m(string)
+
+// Call "m" on all types supported on GPU.
+#define TF_CALL_GPU_NUMBER_TYPES(m) \
+  m(float);                         \
+  m(double)
+
+#else  // __ANDROID__
+
+#define TF_CALL_REAL_NUMBER_TYPES(m) \
+  m(float);                          \
+  m(int32)
+
+#define TF_CALL_NUMBER_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+
+#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) m(float)
+
+#define TF_CALL_NUMBER_TYPES_NO_INT32(m) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)
+
+#define TF_CALL_ALL_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+
+// Maybe we could put an empty macro here for Android?
+#define TF_CALL_GPU_NUMBER_TYPES(m) m(float)
+
+#endif  // __ANDROID__
+
+#endif  // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc
new file mode 100644
index 0000000000..7f551ea65f
--- /dev/null
+++ b/tensorflow/core/framework/rendezvous.cc
@@ -0,0 +1,263 @@
+#include "tensorflow/core/framework/rendezvous.h"
+
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+/* static */
+string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation,
+                             const string& dst_device, const string& name,
+                             const FrameAndIter& frame_iter) {
+  // NOTE: ';' is not used in the device name's job name.
+  //
+  // We include both sender and receiver in the key to facilitate
+  // debugging. For correctness, we only need to encode the receiver.
+  //
+  // "src_incarnation" is used to distinguish a worker when it
+  // restarts.
+  return strings::StrCat(src_device, ";", strings::FpToString(src_incarnation),
+                         ";", dst_device, ";", name, ";", frame_iter.frame_id,
+                         ":", frame_iter.iter_id);
+}
+
+/* static */
+Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
+  // TODO(zhifengc): This code is not fast enough.
+  std::vector<string> parts = str_util::Split(key, ';');
+  if (parts.size() == 5 &&
+      DeviceNameUtils::ParseFullName(parts[0], &out->src) &&
+      strings::StringToFp(parts[1], &out->src_incarnation) &&
+      DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
+      !parts[3].empty()) {
+    out->src_device = parts[0];
+    out->dst_device = parts[2];
+    out->edge_name = parts[3];
+    return Status::OK();
+  }
+  return errors::InvalidArgument("Invalid rendezvous key: ", key);
+}
+
+Rendezvous::~Rendezvous() {}
+
+Status Rendezvous::Recv(const string& key, const Args& recv_args, Tensor* val,
+                        bool* is_dead) {
+  Status ret;
+  Notification n;
+  RecvAsync(key, recv_args,
+            [&ret, &n, val, is_dead](const Status& s, const Args& send_args,
+                                     const Args& recv_args, const Tensor& v,
+                                     const bool dead) {
+              ret = s;
+              *val = v;
+              *is_dead = dead;
+              n.Notify();
+            });
+  n.WaitForNotification();
+  return ret;
+}
+
+class LocalRendezvousImpl : public Rendezvous {
+ public:
+  explicit LocalRendezvousImpl(bool tolerate_dup_recv)
+      : tolerate_dup_recv_(tolerate_dup_recv) {}
+
+  Status Send(const string& key, const Args& send_args, const Tensor& val,
+              const bool is_dead) override {
+    VLOG(2) << "Send " << this << " " << key;
+    DoneCallback waiter = nullptr;
+    Args recv_args;
+    {
+      mutex_lock l(mu_);
+      if (!status_.ok()) {
+        return status_;
+      }
+      Item* item = nullptr;
+      Table::iterator iter = table_.find(key);
+      if (iter == table_.end()) {
+        // There is no waiter for this message. Insert the message
+        // into the waiters table. The waiter will pick it up when it
+        // arrives.
+        item = new Item;
+        item->waiter = nullptr;
+        item->value = val;
+        item->is_dead = is_dead;
+        if (send_args.device_context) {
+          send_args.device_context->Ref();
+          item->send_dev_context = send_args.device_context;
+        }
+        item->recv_dev_context = nullptr;
+
+        // The allocator attributes of item->value.
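+        // (Saved so that a later Recv on a non-CPU device can copy the
+        // value correctly.)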
+ item->send_alloc_attrs = send_args.alloc_attrs; + + CHECK(table_.insert({key, item}).second); + return Status::OK(); + } else { + item = iter->second; + if (item->waiter == nullptr) { + // There is already a message in the table under the key. + // Should not happen unless it has a waiter. + return errors::Aborted("Duplicated send: ", key); + } + // Mark item as complete. + item->has_been_recvd = true; + waiter = item->waiter; + item->waiter = nullptr; + // The ref on recv_dev_context transfers below. + recv_args.device_context = item->recv_dev_context; + recv_args.alloc_attrs = item->recv_alloc_attrs; + item->recv_dev_context = nullptr; + if (tolerate_dup_recv_) { + item->value = val; + item->is_dead = is_dead; + if (send_args.device_context) { + send_args.device_context->Ref(); + item->send_dev_context = send_args.device_context; + } + item->send_alloc_attrs = send_args.alloc_attrs; + } + } + } // mutex + // Notify the waiter by invoking its done closure, outside scope + // of the table lock. + waiter(Status::OK(), send_args, recv_args, val, is_dead); + if (recv_args.device_context) recv_args.device_context->Unref(); + return Status::OK(); + } + + void RecvAsync(const string& key, const Args& recv_args, + DoneCallback done) override { + VLOG(2) << "Recv " << this << " " << key; + mu_.lock(); + if (!status_.ok()) { + // Rendezvous has been aborted. + Status s = status_; + mu_.unlock(); + done(s, Args(), recv_args, Tensor(), false); + return; + } + Table::iterator iter = table_.find(key); + if (iter != table_.end()) { + Item* item = iter->second; + if (item->has_been_recvd && !tolerate_dup_recv_) { + mu_.unlock(); + done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args, + Tensor(), false); + } else if (item->waiter == nullptr || tolerate_dup_recv_) { + // A message has already arrived and is stored in the table + // under this key. Consumes the message and invokes the done + // closure. + Tensor v = item->value; + if (!tolerate_dup_recv_) { + item->value = Tensor(); + } + item->has_been_recvd = true; + // Before dropping the table lock, capture the item values. + // DeviceContext is only non-null for non-CPU devices. + // If we capture the send_dev_context, we need to hold a ref on + // it. Our caller will have a ref on the recv_dev_context, + // which is not in our table. + DeviceContext* send_dev_context = item->send_dev_context; + if (send_dev_context) send_dev_context->Ref(); + bool is_dead = item->is_dead; + mu_.unlock(); + Args send_args; + send_args.device_context = item->send_dev_context; + send_args.alloc_attrs = item->send_alloc_attrs; + done(Status::OK(), send_args, recv_args, v, is_dead); + if (send_dev_context) send_dev_context->Unref(); + } else { + // Already have a waiter in the waiters table under this key, + // which should not happen. + mu_.unlock(); + done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args, + Tensor(), false); + } + return; + } + // Waiting for a message that has not arrived yet. Insert into the + // waiting table. The done closure will be invoked when the + // message arrives. 
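+    // (The Item inserted below is only reclaimed by StartAbort() or the
+    // rendezvous destructor.)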
+    Item* item = new Item;
+    item->waiter = done;
+    if (recv_args.device_context) {
+      item->recv_dev_context = recv_args.device_context;
+      item->recv_alloc_attrs = recv_args.alloc_attrs;
+      item->recv_dev_context->Ref();
+    }
+    CHECK(table_.insert({key, item}).second);
+    mu_.unlock();
+    return;
+  }
+
+  void StartAbort(const Status& status) override {
+    CHECK(!status.ok());
+    std::vector<Item*> items;
+    {
+      mutex_lock l(mu_);
+      if (!status_.ok()) return;
+      status_ = status;
+      items.reserve(table_.size());
+      for (const auto& p : table_) items.push_back(p.second);
+      table_.clear();
+    }
+    for (Item* item : items) {
+      if (item->waiter != nullptr) {
+        item->waiter(status, Args(), Args(), Tensor(), false);
+      }
+      delete item;
+    }
+  }
+
+ private:
+  typedef LocalRendezvousImpl ME;
+  const bool tolerate_dup_recv_;
+
+  struct Item {
+    DoneCallback waiter = nullptr;
+    Tensor value;
+    bool is_dead = false;
+    bool has_been_recvd = false;
+    DeviceContext* send_dev_context = nullptr;
+    DeviceContext* recv_dev_context = nullptr;
+    AllocatorAttributes send_alloc_attrs;
+    AllocatorAttributes recv_alloc_attrs;
+
+    ~Item() {
+      if (send_dev_context) {
+        send_dev_context->Unref();
+      }
+      if (recv_dev_context) {
+        recv_dev_context->Unref();
+      }
+    }
+  };
+  typedef std::unordered_map<string, Item*> Table;
+
+  // TODO(zhifengc): shard table_.
+  mutex mu_;
+  Table table_ GUARDED_BY(mu_);
+  Status status_;
+
+  ~LocalRendezvousImpl() override {
+    for (auto i : table_) {
+      delete i.second;
+    }
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl);
+};
+
+Rendezvous* NewLocalRendezvous(bool tolerate_dup_recv) {
+  return new LocalRendezvousImpl(tolerate_dup_recv);
+}
+
+}  // end namespace tensorflow
diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h
new file mode 100644
index 0000000000..94fbfb2523
--- /dev/null
+++ b/tensorflow/core/framework/rendezvous.h
@@ -0,0 +1,102 @@
+#ifndef TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
+#define TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
+
+#include <string>
+
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+// A Rendezvous is an abstraction for passing a Tensor
+// from a producer to a consumer, where the consumer may safely
+// request the Tensor before or after it has been produced. A
+// producer never blocks when using a Rendezvous. A consumer has the
+// choice of making a blocking call or providing a callback: in either
+// case, the consumer receives the Tensor as soon as it is available.
+//
+// A Rendezvous key encodes a single <producer, consumer> pair. It is
+// an error to call Send() or Recv*() more than once with the same
+// key.
+class Rendezvous : public core::RefCounted {
+ public:
+  struct Args {
+    DeviceContext* device_context = nullptr;
+    AllocatorAttributes alloc_attrs;
+  };
+
+  // Constructs a rendezvous key for the tensor of "name" sent from
+  // "src_device" to "dst_device". The tensor is generated in the frame
+  // and iteration specified by "frame_iter".
+  static string CreateKey(const string& src_device, uint64 src_incarnation,
+                          const string& dst_device, const string& name,
+                          const FrameAndIter& frame_iter);
+
+  // Parses the key constructed by CreateKey, filling in the src/dst
+  // device names and their parsed structures.
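+  //
+  // For example (illustrative values), the key
+  //   "/job:mnist/replica:1/task:2/CPU:0;0000000000001ed2;
+  //    /job:mnist/replica:1/task:2/GPU:0;var0;0:0"
+  // parses into the two device names, the sender's incarnation (encoded
+  // with strings::FpToString), and the edge name "var0".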
+  struct ParsedKey {
+    string src_device;
+    DeviceNameUtils::ParsedName src;
+    uint64 src_incarnation = 0;
+    string dst_device;
+    DeviceNameUtils::ParsedName dst;
+    string edge_name;
+  };
+  static Status ParseKey(const string& key, ParsedKey* out);
+
+  // The caller is a tensor producer and it sends a message (a tensor
+  // "val" and a bool "is_dead") under the given "key".
+  //
+  // {val, is_dead} is bundled as a message sent and received.
+  // Typically, is_dead is set by some control flow nodes
+  // (e.g., a not-taken branch). args is passed by Send to the
+  // Recv function to communicate any information that the Recv
+  // function might need. This is typically only necessary for
+  // Send/Recv on the same worker.
+  //
+  // Send() never blocks.
+  virtual Status Send(const string& key, const Args& args, const Tensor& val,
+                      const bool is_dead) = 0;
+
+  // Callback provided by a tensor consumer waiting on the rendezvous.
+  // It will be invoked when the tensor is available, or when a non-OK
+  // status arises in the production of that tensor. It also gets
+  // two Rendezvous::Args, one provided by the sender, the other by the
+  // receiver, which may be needed when a non-CPU device is in use
+  // by either side.
+  typedef std::function<void(const Status&, const Args&, const Args&,
+                             const Tensor&, const bool)> DoneCallback;
+
+  virtual void RecvAsync(const string& key, const Args& args,
+                         DoneCallback done) = 0;
+
+  // Synchronous wrapper for RecvAsync.
+  Status Recv(const string& key, const Args& args, Tensor* val, bool* is_dead);
+
+  // Aborts all pending and future Send/Recv with the given "status".
+  //
+  // StartAbort() does not wait for ongoing calls to finish.
+  // REQUIRES: !status.ok()
+  virtual void StartAbort(const Status& status) = 0;
+
+ protected:
+  ~Rendezvous() override;
+};
+
+// Returns a Rendezvous instance that is limited to use only by
+// producers and consumers in the local process. The caller assumes
+// ownership of one Ref() on the returned object.
+//
+// If "tolerate_dup_recv" is true, then the Rendezvous will retain
+// already Recv'd values and make them available to duplicate Recv
+// calls. This may be useful if the RPC layer is not reliable, but
+// comes at the cost of higher memory consumption.
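+//
+// A minimal local-process sketch (illustrative; "key" is assumed to be a
+// string built with CreateKey() and "some_tensor" an existing Tensor):
+//
+//   Rendezvous* r = NewLocalRendezvous();
+//   TF_CHECK_OK(r->Send(key, Rendezvous::Args(), some_tensor, false));
+//   Tensor out;
+//   bool is_dead = false;
+//   TF_CHECK_OK(r->Recv(key, Rendezvous::Args(), &out, &is_dead));
+//   r->Unref();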
+Rendezvous* NewLocalRendezvous(bool tolerate_dup_recv = false); + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_ diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc new file mode 100644 index 0000000000..32011a468f --- /dev/null +++ b/tensorflow/core/framework/rendezvous_test.cc @@ -0,0 +1,314 @@ +#include "tensorflow/core/framework/rendezvous.h" + +#include <gtest/gtest.h> +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/env.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +TEST(RendezvousTest, Key) { + const string key = Rendezvous::CreateKey( + "/job:mnist/replica:1/task:2/CPU:0", 7890, + "/job:mnist/replica:1/task:2/GPU:0", "var0", FrameAndIter(0, 0)); + EXPECT_EQ(key, + "/job:mnist/replica:1/task:2/CPU:0;" + "0000000000001ed2;" // 7890 = 0x1ed2 + "/job:mnist/replica:1/task:2/GPU:0;" + "var0;" + "0:0"); + Rendezvous::ParsedKey parsed; + EXPECT_OK(Rendezvous::ParseKey(key, &parsed)); + EXPECT_EQ(parsed.src_device, "/job:mnist/replica:1/task:2/CPU:0"); + EXPECT_EQ(parsed.src_incarnation, 7890); + EXPECT_EQ(parsed.src.type, "CPU"); + EXPECT_EQ(parsed.dst_device, "/job:mnist/replica:1/task:2/GPU:0"); + EXPECT_EQ(parsed.dst.type, "GPU"); + + EXPECT_FALSE(Rendezvous::ParseKey("foo;bar;baz", &parsed).ok()); + EXPECT_FALSE(Rendezvous::ParseKey("/job:mnist/replica:1/task:2/CPU:0;" + "/job:mnist/replica:1/task:2/GPU:0;", + &parsed) + .ok()); + EXPECT_FALSE( + Rendezvous::ParseKey(strings::StrCat(key, ";", key), &parsed).ok()); +} + +class LocalRendezvousTest : public ::testing::Test { + public: + LocalRendezvousTest() + : threads_(new thread::ThreadPool(Env::Default(), "test", 16)) { + rendez_ = NewLocalRendezvous(); + } + + ~LocalRendezvousTest() override { + rendez_->Unref(); + delete threads_; + } + + void SchedClosure(std::function<void()> fn) { threads_->Schedule(fn); } + + Rendezvous* rendez_; + + private: + thread::ThreadPool* threads_; +}; + +// string -> Tensor<string> +Tensor V(const string& content) { + Tensor tensor(DT_STRING, TensorShape({})); + tensor.scalar<string>()() = content; + return tensor; +} + +// Tensor<string> -> string +string V(const Tensor& tensor) { + CHECK_EQ(tensor.dtype(), DT_STRING); + CHECK(TensorShapeUtils::IsScalar(tensor.shape())); + return tensor.scalar<string>()(); +} + +TEST_F(LocalRendezvousTest, SendRecv) { + Rendezvous::Args args; + ASSERT_OK(rendez_->Send("foo", args, V("hello"), false)); + EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, V("hello"), false))); + Tensor val(DT_STRING); + bool is_dead = false; + ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead)); + EXPECT_EQ("hello", V(val)); +} + +TEST_F(LocalRendezvousTest, RecvSend) { + SchedClosure([this]() { + Env::Default()->SleepForMicroseconds(10000); + Rendezvous::Args args; + 
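// The delayed Send below unblocks the main thread's Recv().
+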
ASSERT_OK(rendez_->Send("foo", args, V("hello"), false)); + }); + Tensor val(DT_STRING); + bool is_dead = false; + Rendezvous::Args args; + ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead)); + EXPECT_EQ("hello", V(val)); +} + +TEST_F(LocalRendezvousTest, DuplicateWaiterRecv) { + SchedClosure([this]() { + Tensor t(DT_STRING); + bool is_dead = false; + Rendezvous::Args args; + ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead)); + ASSERT_OK(rendez_->Send("bar", args, t, is_dead)); + }); + Env::Default()->SleepForMicroseconds(1000000); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead))); + ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead)); + ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead)); + EXPECT_EQ("secret msg", V(val)); +} + +TEST_F(LocalRendezvousTest, DuplicateSerialRecv) { + SchedClosure([this]() { + Tensor t(DT_STRING); + bool is_dead = false; + Rendezvous::Args args; + ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead)); + ASSERT_OK(rendez_->Send("bar", args, t, is_dead)); + }); + Env::Default()->SleepForMicroseconds(1000000); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead)); + ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead)); + EXPECT_EQ("secret msg", V(val)); + EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead))); +} + +// A simple structure that behaves a bit like a blocking counter. The +// user that decrements counter to 0 does done.Notify(), and the main +// thread waits for done to be notified. +struct BlockingState { + mutex lock; + int counter; + Notification done; +}; + +TEST_F(LocalRendezvousTest, RandomSendRecv) { + static const int N = 1000; + BlockingState state; + state.counter = N; + for (int i = 0; i < N; ++i) { + SchedClosure([this, i]() { + random::PhiloxRandom philox(testing::RandomSeed() + i, 17); + random::SimplePhilox rnd(&philox); + Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000)); + Rendezvous::Args args; + ASSERT_OK(rendez_->Send(strings::StrCat(i), args, V(strings::StrCat(i)), + false)); + }); + SchedClosure([this, &state, i]() { + random::PhiloxRandom philox(testing::RandomSeed() + N + i, 17); + random::SimplePhilox rnd(&philox); + Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000)); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + ASSERT_OK(rendez_->Recv(strings::StrCat(i), args, &val, &val_dead)); + EXPECT_EQ(strings::StrCat(i), V(val)); + bool done = false; + { + mutex_lock l(state.lock); + state.counter--; + if (state.counter == 0) { + done = true; + } + } + if (done) { + state.done.Notify(); + } + }); + } + + state.done.WaitForNotification(); +} + +TEST_F(LocalRendezvousTest, RecvAbort) { + rendez_->Ref(); + SchedClosure([this]() { + rendez_->StartAbort(errors::Aborted("")); // abort + rendez_->Unref(); + }); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + Status status = rendez_->Recv("foo", args, &val, &val_dead); + EXPECT_TRUE(errors::IsAborted(status)); +} + +// Similar to RecvAbort. But this test case ensures the main thread +// Recv() call happens after StartAbort(). 
+TEST_F(LocalRendezvousTest, RecvSleepAbort) { + rendez_->Ref(); + SchedClosure([this]() { + Env::Default()->SleepForMicroseconds(1000000); + rendez_->StartAbort(errors::Aborted("")); // abort + rendez_->Unref(); + }); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + Status status = rendez_->Recv("foo", args, &val, &val_dead); + EXPECT_TRUE(errors::IsAborted(status)); +} + +TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) { + rendez_->StartAbort(errors::Aborted("")); + Tensor val(DT_STRING); + bool val_dead = false; + Rendezvous::Args args; + EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, val, val_dead))); + EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead))); +} + +class DummyDeviceContext : public DeviceContext { + public: + explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {} + ~DummyDeviceContext() override {} + int stream_id() const { return stream_id_; } + + private: + const int stream_id_; +}; + +TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) { + Rendezvous::Args args; + args.device_context = new DummyDeviceContext(123); + + ASSERT_OK(rendez_->Send("foo", args, V("hello"), false)); + + Notification n; + Rendezvous::Args args1; + args1.device_context = new DummyDeviceContext(1); + rendez_->RecvAsync("foo", args1, [&n](const Status& s, + const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, + const Tensor& val, bool is_dead) { + CHECK_EQ(123, + dynamic_cast<const DummyDeviceContext*>(send_args.device_context) + ->stream_id()); + n.Notify(); + }); + + n.WaitForNotification(); + args.device_context->Unref(); + args1.device_context->Unref(); +} + +static void BM_SendRecv(int iters) { + Rendezvous* rendez = NewLocalRendezvous(); + Tensor orig = V("val"); + Tensor val(DT_STRING, TensorShape({})); + bool is_dead = false; + Rendezvous::Args args; + Status s; + if (iters > 0) { + while (iters--) { + s = rendez->Send("foo", args, orig, is_dead); + s = rendez->Recv("foo", args, &val, &is_dead); + } + CHECK_EQ(V(val), V(orig)); + } + rendez->Unref(); +} +BENCHMARK(BM_SendRecv); + +static void BM_RecvSend(int iters) { + thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1); + + // The main thread sends "foo" for iters/2 times and receives "bar" + // for iters/2 times. The other thread sends "bar" for iters/2 + // times and receives "foo" for iters/2 times. 
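+  // Each iteration thus completes one full round trip between the threads.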
+ Rendezvous* rendez = NewLocalRendezvous(); + pool->Schedule([rendez, iters]() { + Tensor bar = V("bar"); + Tensor foo(DT_STRING, TensorShape({})); + bool is_dead = false; + Rendezvous::Args args; + Status s; + for (int i = 0; i < iters / 2; ++i) { + s = rendez->Recv("foo", args, &foo, &is_dead); + s = rendez->Send("bar", args, bar, is_dead); + } + CHECK_EQ("foo", V(foo)); + }); + Tensor foo = V("foo"); + Tensor bar(DT_STRING, TensorShape({})); + bool is_dead = false; + Rendezvous::Args args; + Status s; + for (int i = 0; i < iters / 2; ++i) { + s = rendez->Send("foo", args, foo, is_dead); + s = rendez->Recv("bar", args, &bar, &is_dead); + } + CHECK_EQ("bar", V(bar)); + delete pool; +} +BENCHMARK(BM_RecvSend); + +} // namespace tensorflow diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc new file mode 100644 index 0000000000..42326f068e --- /dev/null +++ b/tensorflow/core/framework/resource_mgr.cc @@ -0,0 +1,146 @@ +#include "tensorflow/core/framework/resource_mgr.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { + +ResourceMgr::ResourceMgr() : default_container_("localhost") {} + +ResourceMgr::ResourceMgr(const string& default_container) + : default_container_(default_container) {} + +ResourceMgr::~ResourceMgr() { Clear(); } + +void ResourceMgr::Clear() { + mutex_lock l(mu_); + for (const auto& p : containers_) { + for (const auto& q : *p.second) { + q.second->Unref(); + } + delete p.second; + } + containers_.clear(); +} + +Status ResourceMgr::DoCreate(const string& container, std::type_index type, + const string& name, ResourceBase* resource) { + { + mutex_lock l(mu_); + Container** b = &containers_[container]; + if (*b == nullptr) { + *b = new Container; + } + if ((*b)->insert({{type, name}, resource}).second) { + return Status::OK(); + } + } + resource->Unref(); + return errors::AlreadyExists("Resource ", container, "/", name, "/", + type.name()); +} + +Status ResourceMgr::DoLookup(const string& container, std::type_index type, + const string& name, + ResourceBase** resource) const { + mutex_lock l(mu_); + const Container* b = gtl::FindPtrOrNull(containers_, container); + if (b == nullptr) { + return errors::NotFound("Container ", container, " does not exist."); + } + auto r = gtl::FindPtrOrNull(*b, {type, name}); + if (r == nullptr) { + return errors::NotFound("Resource ", container, "/", name, "/", type.name(), + " does not exist."); + } + *resource = const_cast<ResourceBase*>(r); + (*resource)->Ref(); + return Status::OK(); +} + +Status ResourceMgr::DoDelete(const string& container, std::type_index type, + const string& name) { + ResourceBase* base = nullptr; + { + mutex_lock l(mu_); + Container* b = gtl::FindPtrOrNull(containers_, container); + if (b == nullptr) { + return errors::NotFound("Container ", container, " does not exist."); + } + auto iter = b->find({type, name}); + if (iter == b->end()) { + return errors::NotFound("Resource ", container, "/", name, "/", + type.name(), " does not exist."); + } + base = iter->second; + b->erase(iter); + } + CHECK(base != nullptr); + base->Unref(); + return Status::OK(); +} + +Status ResourceMgr::Cleanup(const string& container) { + Container* b = nullptr; + { + mutex_lock l(mu_); + auto iter = containers_.find(container); + if (iter == containers_.end()) { + return 
errors::NotFound("Container ", container, " does not exist."); + } + b = iter->second; + containers_.erase(iter); + } + CHECK(b != nullptr); + for (const auto& p : *b) { + p.second->Unref(); + } + delete b; + return Status::OK(); +} + +Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef, + bool use_node_name_as_default) { + CHECK(rmgr); + rmgr_ = rmgr; + string attr_container; + TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container)); + static RE2 container_re("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*"); + if (!attr_container.empty() && + !RE2::FullMatch(attr_container, container_re)) { + return errors::InvalidArgument("container contains invalid characters: ", + attr_container); + } + string attr_shared_name; + TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &attr_shared_name)); + if (!attr_shared_name.empty() && (attr_shared_name[0] == '_')) { + return errors::InvalidArgument("shared_name cannot start with '_':", + attr_shared_name); + } + if (!attr_container.empty()) { + container_ = attr_container; + } else { + container_ = rmgr_->default_container(); + } + if (!attr_shared_name.empty()) { + name_ = attr_shared_name; + } else if (use_node_name_as_default) { + name_ = ndef.name(); + } else { + resource_is_private_to_kernel_ = true; + static std::atomic<int64> counter(0); + name_ = strings::StrCat("_", counter.fetch_add(1), "_", ndef.name()); + } + return Status::OK(); +} + +string ContainerInfo::DebugString() const { + return strings::StrCat("[", container(), ",", name(), ",", + resource_is_private_to_kernel() ? "private" : "public", + "]"); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h new file mode 100644 index 0000000000..65e859caf1 --- /dev/null +++ b/tensorflow/core/framework/resource_mgr.h @@ -0,0 +1,280 @@ +#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ +#define TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ + +#include <string> +#include <typeindex> +#include <typeinfo> +#include <unordered_map> + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// A ResourceMgr instance keeps track of named and typed resources +// grouped into containers. +// +// Each resource must be represented as a sub-class of ResourceBase, +// which is reference counted explicitly. Each named resource is +// registered with ResourceMgr under a named "container" name. At any +// time, there is at most one instance of a resource given the container +// name, the resource type and the resource name. +// +// All resources for a given container can be dropped by one call of +// Cleanup(). +// +// E.g., +// struct MyVar : public ResourceBase { +// mutex mu; +// Tensor val; +// } +// +// ResourceMgr rm; +// +// // Create a var. +// MyVar* my_var = new MyVar; +// my_var.val = Tensor(DT_FLOAT, my_shape); +// my_val.val.flat<float>().setZeros(); // 0 initialized. +// ctx->SetStatus(rm.Create("my_container", "my_name", my_val)); +// +// // += a variable. +// MyVar* my_var = nullptr; +// Status s = rm.Lookup("my_container", "my_name", &my_var); +// if (s.ok()) { +// my_var->val.flat<float>() += grad; +// } +// my_var->Unref(); // Or use ScopedUnref(). 
+// ctx->SetStatus(s); +class ResourceBase : public core::RefCounted { + public: + // Returns a debug string for *this. + virtual string DebugString() = 0; +}; + +class ResourceMgr { + public: + ResourceMgr(); + explicit ResourceMgr(const string& default_container); + ~ResourceMgr(); + + // Returns the default container name for *this. + const string& default_container() const { return default_container_; } + + // Creates a resource "name" in the "container". The caller transfers + // the ownership of one ref on "resource" to *this. + // + // REQUIRES: std::is_base_of<ResourceBase, T> + // REQUIRES: resource != nullptr. + template <typename T> + Status Create(const string& container, const string& name, + T* resource) TF_MUST_USE_RESULT; + + // If "container" has a resource "name", returns it in "*resource" and + // the caller takes the ownership of one ref on "*resource". + // + // REQUIRES: std::is_base_of<ResourceBase, T> + // REQUIRES: resource != nullptr + template <typename T> + Status Lookup(const string& container, const string& name, + T** resource) const TF_MUST_USE_RESULT; + + // If "container" has a resource "name", returns it in + // "*resource". Otherwise, invokes creator() to create the resource. + // The caller takes the ownership of one ref on "*resource". + // + // REQUIRES: std::is_base_of<ResourceBase, T> + // REQUIRES: resource != nullptr + template <typename T> + Status LookupOrCreate(const string& container, const string& name, + T** resource, + std::function<Status(T**)> creator) TF_MUST_USE_RESULT; + + // Deletes the resource "name" from the "container". + // + // REQUIRES: std::is_base_of<ResourceBase, T> + template <typename T> + Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT; + + // Deletes all resources from the "container" and removes the container. + Status Cleanup(const string& container) TF_MUST_USE_RESULT; + + // Deletes all resources in all containers. + void Clear(); + + private: + typedef std::pair<std::type_index, string> Key; + struct KeyHash { + std::size_t operator()(const Key& k) const { + return Hash64(k.second.data(), k.second.size(), k.first.hash_code()); + } + }; + struct KeyEqual { + bool operator()(const Key& x, const Key& y) const { + return (x.second == y.second) && (x.first == y.first); + } + }; + typedef std::unordered_map<Key, ResourceBase*, KeyHash, KeyEqual> Container; + + const string default_container_; + mutable mutex mu_; + std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_); + + Status DoCreate(const string& container, std::type_index type, + const string& name, + ResourceBase* resource) TF_MUST_USE_RESULT; + Status DoLookup(const string& container, std::type_index type, + const string& name, + ResourceBase** resource) const TF_MUST_USE_RESULT; + Status DoDelete(const string& container, std::type_index type, + const string& name) TF_MUST_USE_RESULT; + + TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr); +}; + +// Policy helper to decide which container/shared_name to use for a +// stateful kernel that accesses a shared resource. +class ContainerInfo { + public: + // Analyzes the node attributes of 'ndef' and decides the container and + // resource name the kernel should use for accessing the shared + // resource. + // + // 'ndef' is expected to have node attributes "container" and + // "shared_name". Returns non-OK if they are missing or invalid. + // + // The policy is as follows: + // * If the attribute "container" is non-empty, it is used as is.
+ // Otherwise, uses the resource manager's default container. + // * If the attribute "shared_name" is non-empty, it is used as is. + // Otherwise, if "use_node_name_as_default" is true, the kernel's + // node name is used as the resource name. Otherwise, a string + // unique to this process is used. + Status Init(ResourceMgr* rmgr, const NodeDef& ndef, + bool use_node_name_as_default); + Status Init(ResourceMgr* rmgr, const NodeDef& ndef) { + return Init(rmgr, ndef, false); + } + + // The policy decides that the kernel should access the resource in + // resource_manager(); the resource is in the container() and its + // name is name(). If resource_is_private_to_kernel() is true, the + // kernel should delete the resource when the kernel is deleted. + ResourceMgr* resource_manager() const { return rmgr_; } + const string& container() const { return container_; } + const string& name() const { return name_; } + bool resource_is_private_to_kernel() const { + return resource_is_private_to_kernel_; + } + + // Returns a readable string for *this. + string DebugString() const; + + private: + ResourceMgr* rmgr_ = nullptr; + string container_; + string name_; + bool resource_is_private_to_kernel_ = false; +}; + +// Helper for kernels to obtain 'resource' from the +// ctx->resource_manager(). +// +// "input_name" specifies the kernel's ref input, which gives a string +// tensor with two elements: the container and the resource name. +// +// Returns OK if the resource is found and transfers one ref of +// *resource to the caller. Otherwise, returns an error. +template <typename T> +Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name, + T** resource); + +// Implementation details below. + +template <typename T> +void CheckDeriveFromResourceBase() { + static_assert(std::is_base_of<ResourceBase, T>::value, + "T must derive from ResourceBase"); +} + +template <typename T> +Status ResourceMgr::Create(const string& container, const string& name, + T* resource) { + CheckDeriveFromResourceBase<T>(); + CHECK(resource != nullptr); + return DoCreate(container, std::type_index(typeid(T)), name, resource); +} + +template <typename T> +Status ResourceMgr::Lookup(const string& container, const string& name, + T** resource) const { + CheckDeriveFromResourceBase<T>(); + ResourceBase* found = nullptr; + Status s = DoLookup(container, std::type_index(typeid(T)), name, &found); + if (s.ok()) { + // It's safe to downcast 'found' to T* since + // typeid(T).hash_code() is part of the map key. + *resource = static_cast<T*>(found); + } + return s; +} + +template <typename T> +Status ResourceMgr::LookupOrCreate(const string& container, const string& name, + T** resource, + std::function<Status(T**)> creator) { + Status s; + *resource = nullptr; + while (*resource == nullptr) { + s = Lookup(container, name, resource); + if (s.ok()) break; + s = creator(resource); + if (!s.ok()) break; + s = Create(container, name, *resource); + if (s.ok()) { + (*resource)->Ref(); + break; + } + // Rare event. Concurrent racy creation. Redo the lookup.
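+ // Two callers can race past the Lookup() above and both run + // creator(); the loser's Create() fails with AlreadyExists, and + // DoCreate() has already unref'ed the losing resource, so we drop + // our pointer and look up the winner's resource instead.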
+ *resource = nullptr; + } + return s; +} + +template <typename T> +Status ResourceMgr::Delete(const string& container, const string& name) { + CheckDeriveFromResourceBase<T>(); + return DoDelete(container, std::type_index(typeid(T)), name); +} + +template <typename T> +Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name, + T** resource) { + string container; + string shared_name; + { + mutex* mu; + TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu)); + mutex_lock l(*mu); + Tensor tensor; + TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true)); + if (tensor.NumElements() != 2) { + return errors::InvalidArgument( + "Resource handle must have 2 elements, but had shape: ", + tensor.shape().DebugString()); + } + container = tensor.flat<string>()(0); + shared_name = tensor.flat<string>()(1); + } + return ctx->resource_manager()->Lookup(container, shared_name, resource); +} + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_ diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc new file mode 100644 index 0000000000..9f7ce3dde3 --- /dev/null +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -0,0 +1,173 @@ +#include "tensorflow/core/framework/resource_mgr.h" + +#include <gtest/gtest.h> +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +class Resource : public ResourceBase { + public: + explicit Resource(const string& label) : label_(label) {} + ~Resource() override {} + + string DebugString() { return strings::StrCat("R/", label_); } + + private: + string label_; +}; + +class Other : public ResourceBase { + public: + explicit Other(const string& label) : label_(label) {} + ~Other() override {} + + string DebugString() { return strings::StrCat("O/", label_); } + + private: + string label_; +}; + +template <typename T> +string Find(const ResourceMgr& rm, const string& container, + const string& name) { + T* r; + TF_CHECK_OK(rm.Lookup(container, name, &r)); + const string ret = r->DebugString(); + r->Unref(); + return ret; +} + +template <typename T> +string LookupOrCreate(ResourceMgr* rm, const string& container, + const string& name, const string& label) { + T* r; + TF_CHECK_OK(rm->LookupOrCreate<T>(container, name, &r, [&label](T** ret) { + *ret = new T(label); + return Status::OK(); + })); + const string ret = r->DebugString(); + r->Unref(); + return ret; +} + +static void HasError(const Status& s, const string& substr) { + EXPECT_TRUE(StringPiece(s.ToString()).contains(substr)) + << s << ", expected substring " << substr; +} + +template <typename T> +Status FindErr(const ResourceMgr& rm, const string& container, + const string& name) { + T* r; + Status s = rm.Lookup(container, name, &r); + CHECK(!s.ok()); + return s; +} + +TEST(ResourceMgrTest, Basic) { + ResourceMgr rm; + TF_CHECK_OK(rm.Create("foo", "bar", new Resource("cat"))); + TF_CHECK_OK(rm.Create("foo", "baz", new Resource("dog"))); + TF_CHECK_OK(rm.Create("foo", "bar", new Other("tiger"))); + + // Expected to fail. + HasError(rm.Create("foo", "bar", new Resource("kitty")), + "Already exists: Resource foo/bar"); + + // Expected to be found. + EXPECT_EQ("R/cat", Find<Resource>(rm, "foo", "bar")); + EXPECT_EQ("R/dog", Find<Resource>(rm, "foo", "baz")); + EXPECT_EQ("O/tiger", Find<Other>(rm, "foo", "bar")); + + // Expected to be not found. 
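+ // Lookups are keyed on (type, name), so asking for an Other under + // "foo"/"baz" fails below even though a Resource with that name + // exists.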
+ HasError(FindErr<Resource>(rm, "bar", "foo"), "Not found: Container bar"); + HasError(FindErr<Resource>(rm, "foo", "xxx"), "Not found: Resource foo/xxx"); + HasError(FindErr<Other>(rm, "foo", "baz"), "Not found: Resource foo/baz"); + + // Delete foo/bar/Resource. + TF_CHECK_OK(rm.Delete<Resource>("foo", "bar")); + HasError(FindErr<Resource>(rm, "foo", "bar"), "Not found: Resource foo/bar"); + + TF_CHECK_OK(rm.Create("foo", "bar", new Resource("kitty"))); + EXPECT_EQ("R/kitty", Find<Resource>(rm, "foo", "bar")); + + // Drop the whole container foo. + TF_CHECK_OK(rm.Cleanup("foo")); + HasError(FindErr<Resource>(rm, "foo", "bar"), "Not found: Container foo"); +} + +TEST(ResourceMgr, CreateOrLookup) { + ResourceMgr rm; + EXPECT_EQ("R/cat", LookupOrCreate<Resource>(&rm, "foo", "bar", "cat")); + EXPECT_EQ("R/cat", LookupOrCreate<Resource>(&rm, "foo", "bar", "dog")); + EXPECT_EQ("R/cat", Find<Resource>(rm, "foo", "bar")); + + EXPECT_EQ("O/tiger", LookupOrCreate<Other>(&rm, "foo", "bar", "tiger")); + EXPECT_EQ("O/tiger", LookupOrCreate<Other>(&rm, "foo", "bar", "lion")); + TF_CHECK_OK(rm.Delete<Other>("foo", "bar")); + HasError(FindErr<Other>(rm, "foo", "bar"), "Not found: Resource foo/bar"); +} + +Status ComputePolicy(const string& attr_container, + const string& attr_shared_name, + bool use_node_name_as_default, string* result) { + ContainerInfo cinfo; + ResourceMgr rmgr; + NodeDef ndef; + ndef.set_name("foo"); + if (attr_container != "none") { + AddNodeAttr("container", attr_container, &ndef); + } + if (attr_shared_name != "none") { + AddNodeAttr("shared_name", attr_shared_name, &ndef); + } + TF_RETURN_IF_ERROR(cinfo.Init(&rmgr, ndef, use_node_name_as_default)); + *result = cinfo.DebugString(); + return Status::OK(); +} + +string Policy(const string& attr_container, const string& attr_shared_name, + bool use_node_name_as_default) { + string ret; + TF_CHECK_OK(ComputePolicy(attr_container, attr_shared_name, + use_node_name_as_default, &ret)); + return ret; +} + +TEST(ContainerInfo, Basic) { + // Correct cases. + EXPECT_EQ(Policy("", "", false), "[localhost,_0_foo,private]"); + EXPECT_EQ(Policy("", "", true), "[localhost,foo,public]"); + EXPECT_EQ(Policy("", "bar", false), "[localhost,bar,public]"); + EXPECT_EQ(Policy("", "bar", true), "[localhost,bar,public]"); + EXPECT_EQ(Policy("cat", "", false), "[cat,_1_foo,private]"); + EXPECT_EQ(Policy("cat", "", true), "[cat,foo,public]"); + EXPECT_EQ(Policy("cat", "bar", false), "[cat,bar,public]"); + EXPECT_EQ(Policy("cat", "bar", true), "[cat,bar,public]"); +} + +Status WrongPolicy(const string& attr_container, const string& attr_shared_name, + bool use_node_name_as_default) { + string dbg; + auto s = ComputePolicy(attr_container, attr_shared_name, + use_node_name_as_default, &dbg); + CHECK(!s.ok()); + return s; +} + +TEST(ContainerInfo, Error) { + // Missing attribute. + HasError(WrongPolicy("none", "", false), "No attr"); + HasError(WrongPolicy("", "none", false), "No attr"); + HasError(WrongPolicy("none", "none", false), "No attr"); + + // Invalid container. + HasError(WrongPolicy("12$%", "", false), "container contains invalid char"); + + // Invalid shared name. 
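+ // Names starting with '_' are reserved for the "_<counter>_<node>" + // names that ContainerInfo::Init generates for kernel-private + // resources, hence the rejection below.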
+ HasError(WrongPolicy("", "_foo", false), "shared_name cannot start with '_'"); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/step_stats.proto b/tensorflow/core/framework/step_stats.proto new file mode 100644 index 0000000000..78610350ec --- /dev/null +++ b/tensorflow/core/framework/step_stats.proto @@ -0,0 +1,58 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/tensor_description.proto"; + +// TODO(tucker): The next 4 message defs are very similar to +// the *LogEntry messages in profile.proto. They should be +// unified in one place. + +message AllocatorMemoryUsed { + string allocator_name = 1; + int64 total_bytes = 2; + int64 peak_bytes = 3; +} + +enum AllocationType { + AT_NOTUSED = 0; // tensor was not filled in + AT_ALLOCATED = 1; // tensor was allocated by the Op + AT_EXISTING = 2; // tensor was set to share the value of an existing tensor + AT_REF = 3; // tensor was set to be a reference to an existing tensor +} + +// Output sizes recorded for a single execution of a graph node. +message NodeOutput { + int32 slot = 1; + // Was the tensor allocated by this Op or a previous computation + AllocationType allocation_type = 2; + TensorDescription tensor_description = 3; +}; + +// Time/size stats recorded for a single execution of a graph node. +message NodeExecStats { + // TODO(tucker): Use some more compact form of node identity than + // the full string name. Either all processes should agree on a + // global id (cost_id?) for each node, or we should use a hash of + // the name. + string node_name = 1; + int64 all_start_micros = 2; + int64 op_start_rel_micros = 3; + int64 op_end_rel_micros = 4; + int64 all_end_rel_micros = 5; + repeated AllocatorMemoryUsed memory = 6; + repeated NodeOutput output = 7; + string timeline_label = 8; + int64 scheduled_micros = 9; + uint32 thread_id = 10; +}; + +message DeviceStepStats { + string device = 1; + repeated NodeExecStats node_stats = 2; +} + +message StepStats { + repeated DeviceStepStats dev_stats = 1; +}; diff --git a/tensorflow/core/framework/summary.proto b/tensorflow/core/framework/summary.proto new file mode 100644 index 0000000000..0e6e659f2f --- /dev/null +++ b/tensorflow/core/framework/summary.proto @@ -0,0 +1,67 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +// Serialization format for histogram module in +// core/lib/histogram/histogram.h +message HistogramProto { + double min = 1; + double max = 2; + double num = 3; + double sum = 4; + double sum_squares = 5; + + // Parallel arrays encoding the bucket boundaries and the bucket values. + // bucket(i) is the count for the bucket i. The range for + // a bucket is: + // i == 0: -DBL_MAX .. bucket_limit(0) + // i != 0: bucket_limit(i-1) .. bucket_limit(i) + repeated double bucket_limit = 6 [packed = true]; + repeated double bucket = 7 [packed = true]; +}; + +// A Summary is a set of named values to be displayed by the +// visualizer. +// +// Summaries are produced regularly during training, as controlled by +// the "summary_interval_secs" attribute of the training operation. +// Summaries are also produced at the end of an evaluation. +message Summary { + message Image { + // Dimensions of the image. + int32 height = 1; + int32 width = 2; + // Valid colorspace values are + // 1 - grayscale + // 2 - grayscale + alpha + // 3 - RGB + // 4 - RGBA + // 5 - DIGITAL_YUV + // 6 - BGRA + int32 colorspace = 3; + // Image data in encoded format. 
All image formats supported by + // image_codec::CoderUtil can be stored here. + bytes encoded_image_string = 4; + } + + message Value { + // Tag name for the data. Will be used as the title of the graph + // in the visualizer. + // + // Tag is usually "op_name:value_name", where "op_name" itself can have + // structure to indicate grouping. + string tag = 1; + + // Value associated with the tag. + oneof value { + float simple_value = 2; + bytes obsolete_old_style_histogram = 3; + Image image = 4; + HistogramProto histo = 5; + } + } + + // Set of values for the summary. + repeated Value value = 1; +} diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc new file mode 100644 index 0000000000..4a1b65db97 --- /dev/null +++ b/tensorflow/core/framework/tensor.cc @@ -0,0 +1,570 @@ +// Implementation notes: +// +// Tensor.cc uses a few templated classes and structs to facilitate +// implementation of the Tensor class. +// +// * Buffer<T>: provides the implementation for a typed array T[n]. +// The array is allocated by the given allocator. It runs T's +// default constructors and destructors when T is not a simple type +// (e.g., string), and skips them otherwise. +// +// * Helper<T>: provides various routines given type T. The routines +// include running the constructor and destructor of T[], and encoding +// and decoding T[] into/from a Cord, etc. + +#include "tensorflow/core/public/tensor.h" + +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/tensor_coding.h" + +namespace tensorflow { +namespace { + +// Typed ref-counted buffer: T[n]. +template <typename T> +class Buffer : public TensorBuffer { + public: + Buffer(Allocator* a, int64 n); + + void* data() const override { return data_; } + size_t size() const override { return sizeof(T) * elem_; } + TensorBuffer* root_buffer() override { return this; } + void FillAllocationDescription(AllocationDescription* proto) const override { + int64 rb = size(); + proto->set_requested_bytes(rb); + proto->set_allocator_name(alloc_->Name()); + if (alloc_->TracksAllocationSizes()) { + int64 ab = alloc_->AllocatedSize(data_); + proto->set_allocated_bytes(ab); + } + } + + private: + Allocator* alloc_; + T* data_; + int64 elem_; + + ~Buffer() override; + + TF_DISALLOW_COPY_AND_ASSIGN(Buffer); +}; + +// is_simple<T>::value if T[] can be safely constructed and destructed +// without running T() and ~T(). We do not use std::is_trivial<T> +// directly because std::complex<float> is not trivial but its array +// can be constructed and destructed without running its default ctor +// and dtor. +template <typename T> +struct is_simple { + static const bool value = std::is_trivial<T>::value || + std::is_same<T, complex64>::value || + is_quantized<T>::value; +}; + +template <> +struct is_simple<bfloat16> { + static const bool value = true; +}; + +// A set of helper functions depending on T. +template <typename T> +struct Helper { + // By default, we assume T is a simple type (float, int32, etc.)
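+ // Non-simple types must provide their own specialization (see + // Helper<string> below, which runs real constructors and destructors).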
+ static_assert(is_simple<T>::value, "T is not a simple type."); + typedef protobuf::RepeatedField<T> RepeatedFieldType; + + // No constructor to run. + static void RunCtor(T* p, int n) {} + + // No destructor to run. + static void RunDtor(T* p, int n) {} + + // Encoder of simple type T to a string. We do a copy. + template <typename Destination> + static void Encode(TensorBuffer* in, int64 n, Destination* out) { + DCHECK_EQ(in->size(), sizeof(T) * n); + port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in, + out); + } + + // Decoder of simple type T. Copy the bytes from "in" into the + // tensor buffer. + template <typename Source> + static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) { + if (in.size() != sizeof(T) * n) { + LOG(ERROR) << "Input size was " << in.size() << " and expected " + << sizeof(T) * n; + return nullptr; + } + Buffer<T>* buf = new Buffer<T>(a, n); + port::CopyToArray(in, buf->template base<char>()); + return buf; + } + + // Memory usage. + static int64 TotalBytes(TensorBuffer* in, int64 n) { + DCHECK_EQ(in->size(), sizeof(T) * n); + return in->size(); + } +}; + +// Helper specialization for string (the only non-simple type we +// support). +template <> +struct Helper<string> { + // Proto message uses RepeatedFieldType to hold repeated T. + typedef protobuf::RepeatedPtrField<string> RepeatedFieldType; + + // Runs string's default constructor for p[0], p[1], ..., p[n-1]. + static void RunCtor(string* p, int n) { + for (int i = 0; i < n; ++p, ++i) new (p) string(); + } + + // Runs string's default destructor for p[0], p[1], ..., p[n-1]. + static void RunDtor(string* p, int n) { + for (int i = 0; i < n; ++p, ++i) p->~string(); + } + + // Encodes "n" elements of type string stored in "in" into Cord + // "out", which is usually the TensorProto::tensor_content. + template <typename Destination> + static void Encode(TensorBuffer* in, int64 n, Destination* out) { + port::EncodeStringList(in->base<const string>(), n, out); + } + + // Decodes "n" elements of type string from "in" and constructs a + // buffer out of it. Returns nullptr if the decoding fails. "in" is + // usually the TensorProto::tensor_content. + template <typename Source> + static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) { + Buffer<string>* buf = new Buffer<string>(a, n); + string* strings = buf->template base<string>(); + if (port::DecodeStringList(in, strings, n)) { + return buf; + } else { + buf->Unref(); + return nullptr; + } + } + + // Returns the estimated memory usage of "n" elements of type string + // stored in buffer "in". + static int64 TotalBytes(TensorBuffer* in, int n) { + int64 tot = in->size(); + DCHECK_EQ(tot, sizeof(string) * n); + const string* p = in->base<const string>(); + for (int i = 0; i < n; ++i, ++p) tot += p->size(); + return tot; + } +}; + +template <typename T> +struct ProtoHelper {}; + +// For a C++ type "T" (float, double, int32, etc.), the repeated field +// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float, +// int32, string, etc.) in the TensorProto is used for serializing the +// tensor of type "T".
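+// For example, the PROTO_TRAITS(uint8, int32, int) instantiation below +// serializes a uint8 tensor through the repeated int32 field "int_val".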
+#define PROTO_TRAITS(T, F, N) \ + template <> \ + struct ProtoHelper<T> { \ + typedef Helper<F>::RepeatedFieldType FieldType; \ + static FieldType::const_iterator Begin(const TensorProto& proto) { \ + return proto.N##_val().begin(); \ + } \ + static size_t NumElements(const TensorProto& proto) { \ + return proto.N##_val().size(); \ + } \ + static void Fill(const T* data, size_t n, TensorProto* proto) { \ + typename ProtoHelper<T>::FieldType copy(data, data + n); \ + proto->mutable_##N##_val()->Swap(&copy); \ + } \ + }; +PROTO_TRAITS(float, float, float); +PROTO_TRAITS(double, double, double); +PROTO_TRAITS(int32, int32, int); +PROTO_TRAITS(uint8, int32, int); +PROTO_TRAITS(int16, int32, int); +PROTO_TRAITS(int8, int32, int); +PROTO_TRAITS(int64, int64, int64); +PROTO_TRAITS(bool, bool, bool); +PROTO_TRAITS(string, string, string); +PROTO_TRAITS(qint8, int32, int); +PROTO_TRAITS(quint8, int32, int); +#undef PROTO_TRAITS + +template <> +struct ProtoHelper<complex64> { + typedef Helper<float>::RepeatedFieldType FieldType; + static const complex64* Begin(const TensorProto& proto) { + return reinterpret_cast<const complex64*>(proto.scomplex_val().data()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.scomplex_val().size() / 2; + } + static void Fill(const complex64* data, size_t n, TensorProto* proto) { + const float* p = reinterpret_cast<const float*>(data); + FieldType copy(p, p + n * 2); + proto->mutable_scomplex_val()->Swap(&copy); + } +}; + +template <> +struct ProtoHelper<qint32> { + typedef Helper<int32>::RepeatedFieldType FieldType; + static const qint32* Begin(const TensorProto& proto) { + return reinterpret_cast<const qint32*>(proto.int_val().data()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.int_val().size(); + } + static void Fill(const qint32* data, size_t n, TensorProto* proto) { + const int32* p = reinterpret_cast<const int32*>(data); + FieldType copy(p, p + n); + proto->mutable_int_val()->Swap(&copy); + } +}; + +template <> +struct ProtoHelper<bfloat16> { + typedef Helper<float>::RepeatedFieldType FieldType; + static const bfloat16* Begin(const TensorProto& proto) { + return reinterpret_cast<const bfloat16*>(proto.int_val().data()); + } + static size_t NumElements(const TensorProto& proto) { + return proto.int_val().size(); + } + static void Fill(const bfloat16* data, size_t n, TensorProto* proto) { + proto->mutable_int_val()->Reserve(n); + for (size_t i = 0; i < n; ++i) { + proto->mutable_int_val()->AddAlreadyReserved(data[i].value); + } + } +}; + +template <typename T> +Buffer<T>::Buffer(Allocator* a, int64 n) + : alloc_(a), data_(a->Allocate<T>(n)), elem_(n) { + if (data_) Helper<T>::RunCtor(data_, elem_); +} + +template <typename T> +Buffer<T>::~Buffer() { + if (data_) { + Helper<T>::RunDtor(data_, elem_); + alloc_->Deallocate<T>(data_); + } +} + +// Allocates a T[n] buffer. Fills in the buffer with repeated values +// in "in". If "in" has fewer values than "n", fills the rest of T[n] +// with the last value. If "in" has no values, fills T[n] with the +// default value for T. +// +// This routine uses the typed fields (float_val, etc.) in the +// tensor proto as opposed to the untyped binary representation +// (tensor_content). It is used when we expect the TensorProto to be +// consumed by a client program which may not know how to encode a tensor +// in the compact binary representation.
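+// E.g., decoding a DT_INT32 tensor of 4 elements from int_val = {1, 2} +// yields {1, 2, 2, 2}, and from an empty int_val it yields {0, 0, 0, 0}.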
+template <typename T> +TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) { + CHECK_GT(n, 0); + Buffer<T>* buf = new Buffer<T>(a, n); + T* data = buf->template base<T>(); + const int64 in_n = ProtoHelper<T>::NumElements(in); + auto begin = ProtoHelper<T>::Begin(in); + if (n <= in_n) { + std::copy_n(begin, n, data); + } else if (in_n > 0) { + std::copy_n(begin, in_n, data); + const T& last = *(data + in_n - 1); + std::fill_n(data + in_n, n - in_n, last); + } else { + std::fill_n(data, n, T()); + } + return buf; +} + +// Copies T[n] stored in the buffer "in" into the repeated field in +// "out" corresponding to type T. +template <typename T> +void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) { + const T* data = in.base<const T>(); + // NOTE: T may not be the same as + // ProtoHelper<T>::FieldType::value_type. E.g., T==int16, + // ProtoHelper<T>::FieldType::value_type==int32. If performance is + // critical, we can specialize T=float and do memcpy directly. + ProtoHelper<T>::Fill(data, n, out); +} + +void RefIfNonNull(core::RefCounted* buf) { + if (buf) buf->Ref(); +} + +void UnrefIfNonNull(core::RefCounted* buf) { + if (buf) buf->Unref(); +} + +} // end namespace + +Tensor::Tensor() : Tensor(DT_FLOAT) {} + +Tensor::Tensor(DataType type) : type_(type), shape_({0}), buf_(nullptr) {} + +Tensor::Tensor(const Tensor& other) + : type_(other.dtype()), shape_(other.shape()), buf_(other.buf_) { + RefIfNonNull(buf_); +} + +Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf) + : type_(type), shape_(shape), buf_(buf) { + RefIfNonNull(buf); +} + +bool Tensor::IsInitialized() const { + return buf_ != nullptr && buf_->data() != nullptr; +} + +Tensor::~Tensor() { UnrefIfNonNull(buf_); } + +void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) { + CHECK_EQ(shape.num_elements(), other.NumElements()); + type_ = other.dtype(); + shape_ = shape; + if (buf_ != other.buf_) { + UnrefIfNonNull(buf_); + buf_ = other.buf_; + RefIfNonNull(buf_); + } +} + +// The macro CASES() expands to a switch statement conditioned on +// TYPE_ENUM. Each case expands the STMTS after a typedef for T. +#define SINGLE_ARG(...) __VA_ARGS__ +#define CASE(TYPE, STMTS) \ + case DataTypeToEnum<TYPE>::value: { \ + typedef TYPE T; \ + STMTS; \ + break; \ + } +#define CASES(TYPE_ENUM, STMTS) \ + switch (TYPE_ENUM) { \ + CASE(float, SINGLE_ARG(STMTS)) \ + CASE(double, SINGLE_ARG(STMTS)) \ + CASE(int32, SINGLE_ARG(STMTS)) \ + CASE(uint8, SINGLE_ARG(STMTS)) \ + CASE(int16, SINGLE_ARG(STMTS)) \ + CASE(int8, SINGLE_ARG(STMTS)) \ + CASE(string, SINGLE_ARG(STMTS)) \ + CASE(complex64, SINGLE_ARG(STMTS)) \ + CASE(int64, SINGLE_ARG(STMTS)) \ + CASE(bool, SINGLE_ARG(STMTS)) \ + CASE(qint32, SINGLE_ARG(STMTS)) \ + CASE(quint8, SINGLE_ARG(STMTS)) \ + CASE(qint8, SINGLE_ARG(STMTS)) \ + CASE(bfloat16, SINGLE_ARG(STMTS)) \ + case DT_INVALID: \ + LOG(FATAL) << "Type not set"; \ + break; \ + default: \ + LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \ + break; \ + } + +Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape) + : type_(type), shape_(shape), buf_(nullptr) { + CHECK_NOTNULL(a); + if (shape_.num_elements() > 0) { + CASES(type, buf_ = new Buffer<T>(a, shape.num_elements())); + } +} + +Tensor::Tensor(DataType type, const TensorShape& shape) + : Tensor(cpu_allocator(), type, shape) {} + +template <typename T> +class SubBuffer : public TensorBuffer { + public: + // This buffer is an alias to buf[delta, delta + n).
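+ // It holds a ref on the root buffer so the underlying storage + // outlives any alias; Tensor::Slice() below relies on this to alias + // rows [start, limit) of dimension 0 without copying.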
+ SubBuffer(TensorBuffer* buf, int64 delta, int64 n) + : root_(buf->root_buffer()), data_(buf->base<T>() + delta), elem_(n) { + // Sanity check. The caller should ensure the sub buffer is valid. + CHECK_LE(root_->base<T>(), this->base<T>()); + T* root_limit = root_->base<T>() + root_->size() / sizeof(T); + CHECK_LE(this->base<T>(), root_limit); + CHECK_LE(this->base<T>() + n, root_limit); + // Hold a ref of the underlying root buffer. + // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer. + root_->Ref(); + } + + void* data() const override { return data_; } + size_t size() const override { return sizeof(T) * elem_; } + TensorBuffer* root_buffer() override { return root_; } + void FillAllocationDescription(AllocationDescription* proto) const override { + root_->FillAllocationDescription(proto); + } + + private: + TensorBuffer* root_; + T* data_; + int64 elem_; + + ~SubBuffer() override { root_->Unref(); } + + TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer); +}; + +Tensor Tensor::Slice(int64 start, int64 limit) const { + CHECK_GE(dims(), 1); + CHECK_LE(0, start); + CHECK_LE(start, limit); + int64 dim0_size = shape_.dim_size(0); + CHECK_LE(limit, dim0_size); + if ((start == 0) && (limit == dim0_size)) { + return *this; + } + Tensor ret; + ret.type_ = type_; + ret.shape_ = shape_; + ret.buf_ = nullptr; + if (dim0_size > 0) { + const int64 elems_per_dim0 = NumElements() / dim0_size; + const int64 delta = start * elems_per_dim0; + dim0_size = limit - start; + ret.shape_.set_dim(0, dim0_size); + const int64 num_elems = dim0_size * elems_per_dim0; + if (buf_) { + CASES(type_, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems)); + } + } + return ret; +} + +bool Tensor::FromProto(const TensorProto& proto) { + return FromProto(cpu_allocator(), proto); +} + +bool Tensor::FromProto(Allocator* a, const TensorProto& proto) { + CHECK_NOTNULL(a); + TensorBuffer* p = nullptr; + if (!TensorShape::IsValid(proto.tensor_shape())) return false; + if (proto.dtype() == DT_INVALID) return false; + TensorShape shape(proto.tensor_shape()); + const int64 N = shape.num_elements(); + if (N > 0 && proto.dtype()) { + if (!proto.tensor_content().empty()) { + const auto& content = proto.tensor_content(); + CASES(proto.dtype(), p = Helper<T>::Decode(a, content, N)); + } else { + CASES(proto.dtype(), p = FromProtoField<T>(a, proto, N)); + } + if (p == nullptr) return false; + } + type_ = proto.dtype(); + shape_ = shape; + UnrefIfNonNull(buf_); + buf_ = p; + return true; +} + +void Tensor::AsProtoField(TensorProto* proto) const { + proto->Clear(); + proto->set_dtype(dtype()); + shape_.AsProto(proto->mutable_tensor_shape()); + if (buf_) { + CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto)); + } +} + +void Tensor::AsProtoTensorContent(TensorProto* proto) const { + proto->Clear(); + proto->set_dtype(type_); + shape_.AsProto(proto->mutable_tensor_shape()); + if (buf_) { + CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(), + proto->mutable_tensor_content())); + } +} + +size_t Tensor::TotalBytes() const { + if (shape_.num_elements() == 0) return 0; + CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements(); + CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements())); + return 0; // Makes compiler happy. +} + +bool Tensor::CanUseDMA() const { + CASES(dtype(), return is_simple<T>::value); + return false; // Makes compiler happy. 
+} + +#undef CASES +#undef CASE + +string Tensor::SummarizeValue(int64 max_entries) const { + string ret; + for (int64 i = 0; i < std::min(max_entries, NumElements()); ++i) { + if (i > 0) strings::StrAppend(&ret, " "); + switch (dtype()) { + case DT_STRING: + strings::StrAppend(&ret, str_util::CEscape(flat<string>()(i))); + break; + case DT_BOOL: + strings::StrAppend(&ret, flat<bool>()(i) ? "True" : "False"); + break; + +#define CASE(DT_ENUM) \ + case DT_ENUM: \ + strings::StrAppend(&ret, flat<EnumToDataType<DT_ENUM>::Type>()(i)); \ + break + + CASE(DT_FLOAT); + CASE(DT_DOUBLE); + CASE(DT_INT32); + CASE(DT_UINT8); + CASE(DT_INT16); + CASE(DT_INT8); + CASE(DT_INT64); + +#undef CASE + default: + // TODO(zhifengc, josh11b): Pretty-print other types (complex64, + // quantized, bfloat16). + strings::StrAppend(&ret, " ?"); + } + } + if (max_entries < NumElements()) strings::StrAppend(&ret, "..."); + + return ret; +} + +StringPiece Tensor::tensor_data() const { + if (buf_ == nullptr) return StringPiece(); // Don't die for empty tensors + return StringPiece(static_cast<char*>(buf_->data()), TotalBytes()); +} + +string Tensor::DebugString() const { + return strings::StrCat("Tensor<type: ", DataTypeString(dtype()), " shape: ", + shape().ShortDebugString(), " values: ", + SummarizeValue(3), ">"); +} + +void Tensor::FillDescription(TensorDescription* description) const { + description->set_dtype(dtype()); + shape().AsProto(description->mutable_shape()); + buf_->FillAllocationDescription( + description->mutable_allocation_description()); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor.proto b/tensorflow/core/framework/tensor.proto new file mode 100644 index 0000000000..b42694afde --- /dev/null +++ b/tensorflow/core/framework/tensor.proto @@ -0,0 +1,57 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. TODO(mdevin): sort out the 0-rank issues. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_content" and + // the "xxx_val" attributes. We are not using oneof because oneofs cannot + // contain repeated fields; it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized content from TensorBase::Serialize(). This representation can be + // used for all tensor types. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex.
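+ // E.g., a tensor holding {1+2i, 3-4i} is stored as [1, 2, 3, -4].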
+ repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; +}; diff --git a/tensorflow/core/framework/tensor_description.proto b/tensorflow/core/framework/tensor_description.proto new file mode 100644 index 0000000000..1fff3ee155 --- /dev/null +++ b/tensorflow/core/framework/tensor_description.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +import "tensorflow/core/framework/types.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/allocation_description.proto"; + +message TensorDescription { + // Data type of tensor elements + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto shape = 2; + + // Information about the size and allocator used for the data + AllocationDescription allocation_description = 4; +}; diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc new file mode 100644 index 0000000000..3db2ffaaca --- /dev/null +++ b/tensorflow/core/framework/tensor_shape.cc @@ -0,0 +1,138 @@ +#include "tensorflow/core/public/tensor_shape.h" + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// An upper limit of the total number of elements in a tensor. +static const int64 kMaxElements = (1LL << 40); + +bool TensorShape::IsValid(const TensorShapeProto& proto) { + int64 num_elements = 1; + for (const auto& d : proto.dim()) { + if (d.size() < 0) return false; + num_elements *= d.size(); + if (num_elements > kMaxElements) return false; + } + return true; +} + +TensorShape::TensorShape(const TensorShapeProto& proto) { + dim_sizes_.reserve(proto.dim_size()); + num_elements_ = 1; + for (const auto& d : proto.dim()) { + AddDim(d.size()); + } +} + +TensorShape::TensorShape(gtl::ArraySlice<int64> dim_sizes) { + dim_sizes_.reserve(dim_sizes.size()); + num_elements_ = 1; + for (auto s : dim_sizes) { + AddDim(s); + } +} + +TensorShape::TensorShape() : num_elements_(1) {} + +void TensorShape::Clear() { + dim_sizes_.clear(); + num_elements_ = 1; +} + +void TensorShape::AddDim(int64 size) { + CHECK_GE(size, 0); + dim_sizes_.push_back(size); + num_elements_ *= size; + CHECK_LE(0, num_elements_); + CHECK_LE(num_elements_, kMaxElements); +} + +void TensorShape::AppendShape(const TensorShape& shape) { + for (auto d : shape) AddDim(d.size); +} + +void TensorShape::InsertDim(int d, int64 size) { + CHECK_GE(d, 0); + CHECK_LE(d, dims()); + CHECK_GE(size, 0); + dim_sizes_.insert(dim_sizes_.begin() + d, size); + num_elements_ *= size; + CHECK_LE(0, num_elements_); + CHECK_LE(num_elements_, kMaxElements); +} + +void TensorShape::set_dim(int d, int64 size) { + CHECK_GE(d, 0); + CHECK_LT(d, dims()); + CHECK_GE(size, 0); + + // Update the number of elements. num_elements_ is int64. + dim_sizes_[d] = size; + recompute_dims(); +} + +void TensorShape::RemoveDim(int d) { + CHECK_GE(d, 0); + CHECK_LT(d, dims()); + + // Update the number of elements and remove the dimension from the + // sizes. 
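+ // recompute_dims() re-derives num_elements_ from scratch; we cannot + // simply divide the old count by the removed size, since that size + // may be 0.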
+ dim_sizes_.erase(dim_sizes_.begin() + d); + recompute_dims(); +} + +void TensorShape::recompute_dims() { + num_elements_ = 1; + for (auto s : dim_sizes_) { + num_elements_ *= s; + CHECK_LE(0, num_elements_); + CHECK_LE(num_elements_, kMaxElements); + } +} + +bool TensorShape::IsSameSize(const TensorShape& b) const { + if (b.dims() != dims()) return false; + for (int d = 0; d < dims(); d++) { + if (dim_size(d) != b.dim_size(d)) return false; + } + return true; +} + +void TensorShape::AsProto(TensorShapeProto* proto) const { + proto->Clear(); + for (size_t d = 0; d < dim_sizes_.size(); ++d) { + auto* dim = proto->add_dim(); + dim->set_size(dim_sizes_[d]); + } +} + +TensorShapeIter TensorShape::begin() const { return TensorShapeIter(this, 0); } + +TensorShapeIter TensorShape::end() const { + return TensorShapeIter(this, dims()); +} + +string TensorShape::DebugString() const { + TensorShapeProto proto; + AsProto(&proto); + return proto.ShortDebugString(); +} + +string TensorShape::ShortDebugString() const { + return strings::StrCat( + "[", str_util::Join(gtl::ArraySlice<int64>(dim_sizes_), ","), "]"); +} + +bool TensorShapeUtils::StartsWith(const TensorShape& shape, + const TensorShape& prefix) { + if (shape.dims() < prefix.dims()) return false; + for (int i = 0; i < prefix.dims(); i++) { + if (shape.dim_size(i) != prefix.dim_size(i)) return false; + } + return true; +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_shape.proto b/tensorflow/core/framework/tensor_shape.proto new file mode 100644 index 0000000000..8fe7cce13d --- /dev/null +++ b/tensorflow/core/framework/tensor_shape.proto @@ -0,0 +1,29 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +// option cc_enable_arenas = true; + +package tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} for a 30 x + // 40 2D tensor. The names are optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to lay out the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + repeated Dim dim = 2; +}; diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc new file mode 100644 index 0000000000..adac1a4787 --- /dev/null +++ b/tensorflow/core/framework/tensor_shape_test.cc @@ -0,0 +1,75 @@ +#include "tensorflow/core/public/tensor_shape.h" + +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +TEST(TensorShapeTest, Default) { + // The default TensorShape constructor constructs a 0-dim shape + // with 1 element.
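+ // (A rank-0 shape describes a scalar, which has exactly one element.)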
+ TensorShape s; + EXPECT_EQ(s.dims(), 0); + EXPECT_EQ(s.num_elements(), 1); +} + +TEST(TensorShapeTest, set_dim) { + TensorShape s({10, 5}); + + s.set_dim(0, 20); + ASSERT_EQ(2, s.dims()); + EXPECT_EQ(20, s.dim_size(0)); + EXPECT_EQ(100, s.num_elements()); + + s.set_dim(1, 2); + ASSERT_EQ(2, s.dims()); + EXPECT_EQ(2, s.dim_size(1)); + EXPECT_EQ(40, s.num_elements()); +} + +TEST(TensorShapeTest, RemoveDim) { + TensorShape s({10, 5}); + s.RemoveDim(0); + EXPECT_EQ(5, s.num_elements()); + ASSERT_EQ(1, s.dims()); +} + +TEST(TensorShapeTest, RemoveAndAddDim) { + TensorShape s({10, 5, 20}); + s.RemoveDim(1); + s.AddDim(100); + + EXPECT_EQ(20000, s.num_elements()); + ASSERT_EQ(3, s.dims()); +} + +TEST(TensorShapeTest, InvalidShapeProto) { + TensorShapeProto proto; + EXPECT_TRUE(TensorShape::IsValid(proto)); + + proto.add_dim()->set_size(357); + proto.add_dim()->set_size(982); + EXPECT_TRUE(TensorShape::IsValid(proto)); + + proto.Clear(); + proto.add_dim()->set_size(-357); + proto.add_dim()->set_size(-982); + EXPECT_FALSE(TensorShape::IsValid(proto)); + + proto.Clear(); + proto.add_dim()->set_size(1LL << 20); + proto.add_dim()->set_size((1LL << 20) + 1); + EXPECT_FALSE(TensorShape::IsValid(proto)); +} + +TEST(TensorShapeTest, SetDimForEmptyTensor) { + TensorShape s({10, 5, 20}); + EXPECT_EQ(1000, s.num_elements()); + s.set_dim(1, 0); + EXPECT_EQ(0, s.num_elements()); + s.set_dim(1, 7); + EXPECT_EQ(1400, s.num_elements()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_slice.cc b/tensorflow/core/framework/tensor_slice.cc new file mode 100644 index 0000000000..473d9463ee --- /dev/null +++ b/tensorflow/core/framework/tensor_slice.cc @@ -0,0 +1,226 @@ +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); } + +TensorSlice::TensorSlice(const TensorSliceProto& proto) { + starts_.reserve(proto.extent_size()); + lengths_.reserve(proto.extent_size()); + for (const auto& e : proto.extent()) { + starts_.push_back(e.start()); + lengths_.push_back(GetExtentLength(e)); + } +} + +TensorSlice::TensorSlice(std::initializer_list<std::pair<int, int>> extents) { + starts_.reserve(extents.size()); + lengths_.reserve(extents.size()); + for (const auto& e : extents) { + starts_.push_back(e.first); + lengths_.push_back(e.second); + } +} + +Status TensorSlice::Parse(const string& str, TensorSlice* slice) { + std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty()); + slice->starts_.reserve(items.size()); + slice->lengths_.reserve(items.size()); + for (const string& x : items) { + int s, l; + if (x == "-") { + // "everything" + s = 0; + l = kFullExtent; + } else { + char junk; + if (sscanf(x.c_str(), "%d,%d%c", &s, &l, &junk) != 2) { + return errors::InvalidArgument( + "Expected a pair of numbers or '-' " + "but got '", + x, "': string = ", str); + } + if (s < 0 || l <= 0) { + return errors::InvalidArgument( + "Expected non-negative start and " + "positive length but got start = ", + s, ", length = ", l, ": string = ", str); + } + } + slice->starts_.push_back(s); + slice->lengths_.push_back(l); + } + + return Status::OK(); +} + +void TensorSlice::Clear() { + starts_.clear(); + lengths_.clear(); +} + +void TensorSlice::SetFullSlice(int dim) { + Clear(); + starts_.reserve(dim); + 
lengths_.reserve(dim); + for (int d = 0; d < dim; ++d) { + starts_.push_back(0); + lengths_.push_back(kFullExtent); + } +} + +void TensorSlice::Extend(int dim) { + int old_dim = dims(); + DCHECK_LE(old_dim, dim); + starts_.resize(dim); + lengths_.resize(dim); + for (int d = old_dim; d < dim; ++d) { + starts_[d] = 0; + lengths_[d] = kFullExtent; + } +} + +void TensorSlice::AsProto(TensorSliceProto* proto) const { + for (int d = 0; d < dims(); ++d) { + TensorSliceProto::Extent* e = proto->add_extent(); + // We only need to record the explicit slice for non-full slices + if (!IsFullAt(d)) { + e->set_start(starts_[d]); + e->set_length(lengths_[d]); + } + } +} + +string TensorSlice::DebugString() const { + string buffer; + bool first = true; + for (int d = 0; d < dims(); ++d) { + if (!first) { + buffer.append(":"); + } + string s; + if (IsFullAt(d)) { + buffer.append("-"); + } else { + strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]); + } + first = false; + } + return buffer; +} + +bool TensorSlice::Intersect(const TensorSlice& other, + TensorSlice* result) const { + // First, if two slices have different ranks, they obviously don't overlap + // -- in fact they are not compatible. + if (dims() != other.dims()) { + return false; + } + + // Setting the result to the right dimension + if (result) { + result->SetFullSlice(dims()); + } + // The two slices overlap if they overlap in all dimensions. + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + if (result) { + result->set_start(d, other.start(d)); + result->set_length(d, other.length(d)); + } + } else if (other.IsFullAt(d)) { + if (result) { + result->set_start(d, start(d)); + result->set_length(d, length(d)); + } + } else { + // If we have an intersection here, it should have a start that is the + // max of the two starts and an end that is the min of the two ends. + int s = std::max(start(d), other.start(d)); + int l = std::min(end(d), other.end(d)) - s; + if (l > 0) { + // We have a real intersection + if (result) { + result->set_start(d, s); + result->set_length(d, l); + } + } else { + // We don't have an intersection for this dimension -- thus we don't + // have any intersection at all. + if (result) { + result->Clear(); + } + return false; + } + } + } + // If we are here, we know there is overlap in every dimension. + return true; +} + +void TensorSlice::ComputeRelative(const TensorSlice& sub, + TensorSlice* relative) const { + DCHECK_EQ(dims(), sub.dims()); + relative->SetFullSlice(dims()); + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + relative->set_start(d, sub.start(d)); + relative->set_length(d, sub.length(d)); + } else { + // Otherwise the relative start is the difference between the start of + // sub and the start of base + relative->set_start(d, sub.start(d) - start(d)); + relative->set_length(d, sub.length(d)); + } + } +} + +// static +bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) { + return extent.has_length_case() == TensorSliceProto::Extent::kLength; +} + +// static +int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) { + if (!HasExtentLength(extent)) return -1; + return extent.length(); +} + +Status TensorSlice::SliceTensorShape(const TensorShape& shape, + TensorShape* result_shape) const { + result_shape->Clear(); + // Mismatching ranks: we can't apply the slice at all. 
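+ // E.g., the 2-dim slice "1,2:-" cannot be applied to a 3-dim shape, + // while applying "1,2:-:0,2" to the shape {3, 4, 5} yields {2, 4, 2}.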
+ if (shape.dims() != dims()) { + return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(), + ", slice = ", DebugString()); + } + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + result_shape->AddDim(shape.dim_size(d)); + } else { + // Check if the extent applies to the dimension + if (end(d) <= shape.dim_size(d)) { + // Yes: the end is within the range of the dim -- we adjust the result + // shape so that its size along this dimension is the length of the + // slice. + result_shape->AddDim(length(d)); + } else { + // The extent doesn't apply to the dimension + result_shape->Clear(); + return errors::Internal("Extent in dimension ", d, + " out of bounds: shape = ", shape.DebugString(), + ", slice = ", DebugString()); + } + } + } + // If we are here, we have successfully applied the slice to the shape. + return Status::OK(); +} + +const int TensorSlice::kFullExtent = -1; + +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_slice.h b/tensorflow/core/framework/tensor_slice.h new file mode 100644 index 0000000000..8e2f108c3f --- /dev/null +++ b/tensorflow/core/framework/tensor_slice.h @@ -0,0 +1,189 @@ +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_ + +#include <string> +#include "tensorflow/core/framework/tensor_slice.pb.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/tensor_shape.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +// A tensor slice represents a slice of a given tensor. It is represented by a +// list of (start, length) pairs, where the size of the list is the rank of the +// tensor. + +class TensorSlice { + public: + // Construct a tensor slice. There are a number of ways: + // -- creating an empty slice + // -- from just a dimension (in this case it will create a full slice) + // -- from an array of pairs of integers. + // -- from a TensorSliceProto protocol buffer + // -- from a string format of "start,length:start,length..." where each + // "start,length" pair represents the slice on one dimension. We allow a + // special "-" that means "everything for this dimension". One such example + // is: 0,10:-:14,1:-:- + TensorSlice() {} + explicit TensorSlice(int dim); + explicit TensorSlice(const TensorSliceProto& proto); + explicit TensorSlice(std::initializer_list<std::pair<int, int>> extents); + + static Status Parse(const string& str, TensorSlice* output); + static TensorSlice ParseOrDie(const string& str) { + TensorSlice ret; + Status s = Parse(str, &ret); + if (!s.ok()) { + LOG(FATAL) << "Could not parse TensorSlice"; + } + return ret; + } + + void Clear(); + + // Accessors + int dims() const { return starts_.size(); } + + int start(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return starts_[d]; + } + + int length(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return lengths_[d]; + } + + int end(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return start(d) + length(d); + } + + void set_start(int d, int x) { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + DCHECK_GE(x, 0); + starts_[d] = x; + } + + void set_length(int d, int x) { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + lengths_[d] = x; + } + + // Returns true if we have a full slice along dimension "d".
+ bool IsFullAt(int d) const { return lengths_[d] < 0; } + + // Set the slice to be a full slice of "dim" dimensions + void SetFullSlice(int dim); + + // Extend a slice to "dim" dimensions: all the added dimensions are full. + // Requires: dim >= dims(). + void Extend(int dim); + + // Conversion of a TensorSlice to other formats + void AsProto(TensorSliceProto* proto) const; + string DebugString() const; + + // Fill *indices and *sizes from *this (so that we can use the slice() + // function in Eigen::Tensor). We need a tensor shape in case some of the + // slices are full slices. + // We allow NDIMS to be greater than dims(), in which case we will pad the + // higher dimensions with trivial dimensions. + template <int NDIMS> + void FillIndicesAndSizes(const TensorShape& shape, + Eigen::DSizes<ptrdiff_t, NDIMS>* indices, + Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const; + + // Interaction with other TensorSlices. + + // Computes the intersection with another slice and, if "result" is not + // nullptr, stores the result in *result. Returns true if there is any real + // intersection. + bool Intersect(const TensorSlice& other, TensorSlice* result) const; + // A shorthand. + bool Overlaps(const TensorSlice& other) const { + return Intersect(other, nullptr); + } + + // Interaction with TensorShape. + + // Slices a shape and stores the result into *result_shape. + // Requires that the shape and *this have the same rank. + // For example, given a tensor shape of {3, 4, 5}, and a slice of + // 1,2:-:0,2, the result shape is {2, 4, 2}. + Status SliceTensorShape(const TensorShape& shape, + TensorShape* result_shape) const; + + // Given slice "sub" where "sub" is fully contained in *this, + // (meaning that the intersection of "sub" and *this equals "sub"), computes + // the "relative" slice of "sub" with respect to *this. + // + // In other words, if we use A>S to denote slicing a shape S with a slice A, + // then the function is computing a slice X such that: + // X > (this > S) = sub > S + // for any shape S. + // + // In general, along every dimension, the start of the relative slice is the + // start of the "sub" slice minus the start of *this; the length of the + // relative slice is the length of the "sub" slice. + // + // For example, say we have a shape of {3, 4, 5}, "this" is 0,2:-:1,2, and + // "sub" is 1,1:2,2:1,2, then the relative slice is 1,1:2,2:0,2. + // + // The caller needs to make sure that "sub" is indeed a sub-slice of *this; + // otherwise the result is undefined. + void ComputeRelative(const TensorSlice& sub, TensorSlice* relative) const; + + // Returns true if the length field was specified in an Extent. + static bool HasExtentLength(const TensorSliceProto::Extent& extent); + + // Returns the value of the length field in an Extent, or -1 if it + // is not present. + static int64 GetExtentLength(const TensorSliceProto::Extent& extent); + + private: + // A length value of kFullExtent (-1) means we have a full slice at this + // dimension. It's defined in tensor_slice.cc. + static const int kFullExtent; + + // TODO(yangke): switch to Eigen once it supports variable size arrays.
+ // A value of kFullExtent (-1) in lengths_ marks a full slice along that
+ // dimension.
+ gtl::InlinedVector<int, 4> starts_;
+ gtl::InlinedVector<int, 4> lengths_;
+};
+
+template <int NDIMS>
+void TensorSlice::FillIndicesAndSizes(
+ const TensorShape& shape, Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
+ Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const {
+ CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape "
+ << "and slice: shape = " << shape.DebugString()
+ << ", slice = " << DebugString();
+ CHECK_GE(NDIMS, dims()) << "Asking for a " << NDIMS << "-dim slice from "
+ << "a slice of dimension " << dims();
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ (*indices)[d] = 0;
+ (*sizes)[d] = shape.dim_size(d);
+ } else {
+ (*indices)[d] = starts_[d];
+ (*sizes)[d] = lengths_[d];
+ }
+ }
+ for (int d = dims(); d < NDIMS; ++d) {
+ (*indices)[d] = 0;
+ (*sizes)[d] = 1;
+ }
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
diff --git a/tensorflow/core/framework/tensor_slice.proto b/tensorflow/core/framework/tensor_slice.proto
new file mode 100644
index 0000000000..ca676bc766
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice.proto
@@ -0,0 +1,34 @@
+// Protocol buffer representing slices of a tensor
+
+syntax = "proto3";
+// option cc_enable_arenas = true;
+
+package tensorflow;
+
+// Can only be interpreted if you know the corresponding TensorShape.
+message TensorSliceProto {
+ // Extent of the slice in one dimension.
+ message Extent {
+ // Either both attributes must be set, or neither of them. When neither
+ // is set, it means: all data in that dimension.
+
+ // Start index of the slice, starting at 0.
+ int64 start = 1;
+
+ // Length of the slice: if the length is missing or -1 we will
+ // interpret this as "everything in this dimension". We use
+ // "oneof" to preserve information about whether the length is
+ // present without changing the serialization format from the
+ // prior proto2 version of this proto.
+ oneof has_length {
+ int64 length = 2;
+ }
+ };
+
+ // Extent of the slice in all tensor dimensions.
+ //
+ // Must have one entry for each dimension of the tensor that this
+ // slice belongs to. The order of sizes is the same as the order of
+ // dimensions in the TensorShape.
+ repeated Extent extent = 1;
+};
diff --git a/tensorflow/core/framework/tensor_slice_test.cc b/tensorflow/core/framework/tensor_slice_test.cc
new file mode 100644
index 0000000000..5f718a56b6
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice_test.cc
@@ -0,0 +1,246 @@
+#include "tensorflow/core/framework/tensor_slice.h"
+
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+// Basic tests
+TEST(TensorSliceTest, Basic) {
+ {
+ // Repeatedly setting FullSlice should work.
+ TensorSlice s(3);
+ EXPECT_EQ("-:-:-", s.DebugString());
+
+ s.SetFullSlice(4);
+ EXPECT_EQ("-:-:-:-", s.DebugString());
+ }
+}
+
+// Testing for serialization and parsing for the string format of slices.
+TEST(TensorSliceTest, Serialization) {
+ // Serialization
+ {
+ TensorSlice s({{0, -1}, {0, 10}, {14, 1}, {0, -1}});
+ EXPECT_EQ("-:0,10:14,1:-", s.DebugString());
+ }
+
+ {
+ TensorSliceProto proto;
+ // Define ptxt outside ASSERT_TRUE call to work around bug in some
+ // versions of gcc that breaks when you use raw string literals
+ // inside macro expansions.
+ // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971
+ const char* ptxt = R"PROTO(
+ extent { }
+ extent { start: 0 length: 10 }
+ extent { start: 14 length: 1 }
+ extent { }
+ )PROTO";
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto));
+ TensorSlice s(proto);
+ EXPECT_EQ("-:0,10:14,1:-", s.DebugString());
+ }
+
+ // Parsing
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5");
+ TensorSliceProto proto;
+ s.AsProto(&proto);
+ EXPECT_EQ(
+ "extent { } "
+ "extent { } "
+ "extent { start: 1 length: 3 } "
+ "extent { start: 4 length: 5 }",
+ proto.ShortDebugString());
+ }
+
+ // Failed parsing
+ {
+ TensorSlice slice;
+ Status s = TensorSlice::Parse("-:-:1,3:4:5", &slice);
+ EXPECT_EQ(
+ "Invalid argument: "
+ "Expected a pair of numbers or '-' but got '4': "
+ "string = -:-:1,3:4:5",
+ s.ToString());
+ }
+ {
+ TensorSlice slice;
+ Status s = TensorSlice::Parse("-:-1,3", &slice);
+ EXPECT_EQ(
+ "Invalid argument: "
+ "Expected non-negative start and positive length but got "
+ "start = -1, length = 3: string = -:-1,3",
+ s.ToString());
+ }
+}
+
+// Testing the slice intersection
+TEST(TensorSliceTest, Intersection) {
+ // "EVERYTHING" intersects with everything
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("-:-");
+ TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4");
+ TensorSlice c;
+ EXPECT_TRUE(a.Intersect(b, &c));
+ EXPECT_EQ("1,2:3,4", c.DebugString());
+ }
+
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("-:-");
+ TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4");
+ TensorSlice c;
+ EXPECT_TRUE(b.Intersect(a, &c));
+ EXPECT_EQ("1,2:3,4", c.DebugString());
+ }
+
+ // Overlap at all dimensions
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,5:2,6:3,7:5,10");
+ TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4:9,10:12,1");
+ TensorSlice c;
+ EXPECT_TRUE(a.Intersect(b, &c));
+ EXPECT_EQ("1,2:3,4:9,1:12,1", c.DebugString());
+ }
+
+ // A mixture of everything and non-trivial slices
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("-:1,1");
+ TensorSlice b = TensorSlice::ParseOrDie("-:0,2");
+ TensorSlice c;
+ EXPECT_TRUE(a.Intersect(b, &c));
+ EXPECT_EQ("-:1,1", c.DebugString());
+ }
+
+ // No overlap in dimension 1: "3,1" and "4,5" don't intersect
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:5,6");
+ TensorSlice b = TensorSlice::ParseOrDie("1,3:4,5:1,6");
+ TensorSlice c;
+ EXPECT_FALSE(a.Intersect(b, &c));
+ EXPECT_EQ("", c.DebugString());
+ }
+ // No intersection when the two slices have different ranks
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:-");
+ TensorSlice b = TensorSlice::ParseOrDie("-:-");
+ TensorSlice c;
+ EXPECT_FALSE(a.Intersect(b, &c));
+ EXPECT_EQ("", c.DebugString());
+ }
+}
+
+// Testing applying a slice to a tensor shape
+TEST(TensorSliceTest, SliceTensorShape) {
+ // A proper application
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,1:-:4,1:2,6");
+ TensorShape x({2, 4, 5, 8});
+ TensorShape y;
+ EXPECT_OK(a.SliceTensorShape(x, &y));
+ EXPECT_EQ(
+ "dim { size: 1 } "
+ "dim { size: 4 } "
+ "dim { size: 1 } "
+ "dim { size: 6 }",
+ y.DebugString());
+ }
+
+ // An invalid application -- dimension 1 is out of bounds
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,1:1,4:-:-");
+ TensorShape x({2, 4, 5, 8});
+ TensorShape y;
+ EXPECT_EQ(
+ "Internal: "
+ "Extent in dimension 1 out of bounds: "
+ "shape = dim { size: 2 } "
+ "dim { size: 4 } "
+ "dim { size: 5 } "
+ "dim { size: 8 }, "
+ "slice = 1,1:1,4:-:-",
+ a.SliceTensorShape(x, &y).ToString());
+ EXPECT_EQ("", y.DebugString());
+ }
+}
+
+// Testing the computation of
relative slices. +TEST(TensorSliceTest, ComputeRelative) { + // Easy case: base is "everything" + { + TensorSlice base = TensorSlice::ParseOrDie("-:-:-:-"); + TensorSlice sub = TensorSlice::ParseOrDie("-:1,2:-:3,4"); + TensorSlice relative; + base.ComputeRelative(sub, &relative); + EXPECT_EQ("-:1,2:-:3,4", relative.DebugString()); + } + + // A slightly more complicated case + { + TensorSlice base = TensorSlice::ParseOrDie("1,2:3,4:-:5,1"); + TensorSlice sub = TensorSlice::ParseOrDie("1,1:4,2:3,3:5,1"); + TensorSlice relative; + base.ComputeRelative(sub, &relative); + EXPECT_EQ("0,1:1,2:3,3:0,1", relative.DebugString()); + } +} + +TEST(TensorSliceTest, ExtentLength) { + TensorSliceProto proto; + // Define ptxt outside ASSERT_TRUE call to work around bug in some + // versions of gcc that breaks when you use raw string literals + // inside macro expansions. + // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971 + const char* ptxt = R"PROTO( + extent { } + extent { start: 0 length: 10 } + extent { start: 14 length: 1 } + extent { } + )PROTO"; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto)); + EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(0))); + EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(1))); + EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(2))); + EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(3))); + EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(0))); + EXPECT_EQ(10, TensorSlice::GetExtentLength(proto.extent(1))); + EXPECT_EQ(1, TensorSlice::GetExtentLength(proto.extent(2))); + EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(3))); +} + +TEST(TensorSliceTest, Deserialization) { + // Serialization of + // extent { length: 5 } + // extent { start: 0 length: 10 } + // extent { start: 14 length: 1 } + // extent { start: 1 } + // extent { } + // in proto2 and proto3: + const char pb2[] = + "\x0A\x02\x10\x05\x0A\x04\x08\x00" + "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00"; + const char pb3[] = + "\x0A\x02\x10\x05\x0A\x02" + "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00"; + // (The difference is that in the proto3 version, "start: 0" isn't included + // since 0 is start's default value.) + + TensorSliceProto proto2; + ASSERT_TRUE(proto2.ParseFromArray(pb2, sizeof(pb2) - 1)); + TensorSlice ts2(proto2); + + TensorSliceProto proto3; + ASSERT_TRUE(proto3.ParseFromArray(pb3, sizeof(pb3) - 1)); + TensorSlice ts3(proto3); + + // Both serializations should be interpreted the same. 
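+ // (Decoding the first bytes by hand, assuming standard proto wire format:
+ // 0x0A is field 1 "extent" with wire type 2, 0x02 is its byte length, and
+ // the payload 0x10 0x05 is varint field 2 "length" = 5 -- i.e. the leading
+ // extent { length: 5 } above.)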
+ EXPECT_EQ("0,5:0,10:14,1:-:-", ts2.DebugString()); + EXPECT_EQ("0,5:0,10:14,1:-:-", ts3.DebugString()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc new file mode 100644 index 0000000000..4963c2c219 --- /dev/null +++ b/tensorflow/core/framework/tensor_test.cc @@ -0,0 +1,551 @@ +#include "tensorflow/core/public/tensor.h" + +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +TEST(TensorTest, Default) { + Tensor t; + EXPECT_EQ(t.dtype(), DT_FLOAT); + EXPECT_EQ(t.dims(), 1); + EXPECT_EQ(t.NumElements(), 0); +} + +TEST(TensorTest, DataType_Traits) { + EXPECT_TRUE(std::is_trivial<float>::value); + EXPECT_TRUE(std::is_trivial<double>::value); + EXPECT_TRUE(std::is_trivial<int32>::value); + EXPECT_TRUE(std::is_trivial<uint8>::value); + EXPECT_TRUE(std::is_trivial<int16>::value); + EXPECT_TRUE(std::is_trivial<int8>::value); + EXPECT_TRUE(std::is_trivial<int64>::value); + EXPECT_TRUE(std::is_trivial<bool>::value); + EXPECT_FALSE(std::is_trivial<string>::value); + + EXPECT_EQ(sizeof(bool), 1); + + // Unfortunately. std::complex::complex() initializes (0, 0). + EXPECT_FALSE(std::is_trivial<complex64>::value); + EXPECT_FALSE(std::is_trivial<std::complex<double>>::value); + EXPECT_TRUE(std::is_trivial<float[2]>::value); + struct MyComplex { + float re, im; + }; + EXPECT_TRUE(std::is_trivial<MyComplex>::value); +} + +template <typename T> +void TestCopies(const Tensor& t) { + { + LOG(INFO) << "CopyFrom()"; + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.CopyFrom(t, t.shape())); + test::ExpectTensorEqual<T>(t, t2); + } + { + LOG(INFO) << "operator=()"; + Tensor t2(t.dtype()); + t2 = t; + test::ExpectTensorEqual<T>(t, t2); + } + { + LOG(INFO) << "deep copy"; + Tensor t2(t.dtype(), t.shape()); + t2.flat<T>() = t.flat<T>(); + test::ExpectTensorEqual<T>(t, t2); + } + { + LOG(INFO) << "AsProtoField()"; + TensorProto proto; + t.AsProtoField(&proto); + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.FromProto(proto)); + test::ExpectTensorEqual<T>(t, t2); + } + { + LOG(INFO) << "AsProtoTensorContent()"; + TensorProto proto; + t.AsProtoTensorContent(&proto); + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.FromProto(proto)); + test::ExpectTensorEqual<T>(t, t2); + // Make another copy via tensor_content field. 
+ *proto.mutable_tensor_content() = proto.tensor_content();
+ Tensor t3(t.dtype());
+ EXPECT_TRUE(t3.FromProto(proto));
+ test::ExpectTensorEqual<T>(t, t3);
+ }
+ {
+ LOG(INFO) << "AsTensor";
+ gtl::ArraySlice<T> values(t.flat<T>().data(), t.NumElements());
+ Tensor t2 = test::AsTensor(values, t.shape());
+ test::ExpectTensorEqual<T>(t, t2);
+ }
+}
+
+TEST(Tensor_Float, Simple) {
+ Tensor t(DT_FLOAT, TensorShape({10, 20}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 20})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<float>()(a, b) = static_cast<float>(a * b);
+ }
+ }
+ TestCopies<float>(t);
+}
+
+TEST(Tensor_QInt8, Simple) {
+ Tensor t(DT_QINT8, TensorShape({2, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<qint8>()(a, b) = qint8(a * b);
+ }
+ }
+ TestCopies<qint8>(t);
+}
+
+TEST(Tensor_QUInt8, Simple) {
+ Tensor t(DT_QUINT8, TensorShape({2, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<Eigen::QUInt8>()(a, b) = Eigen::QUInt8(a * b);
+ }
+ }
+ TestCopies<Eigen::QUInt8>(t);
+}
+
+TEST(Tensor_QInt32, Simple) {
+ Tensor t(DT_QINT32, TensorShape({2, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<qint32>()(a, b) = qint32(static_cast<int32>(a * b));
+ }
+ }
+ TestCopies<qint32>(t);
+}
+
+TEST(Tensor_Float, Reshape) {
+ Tensor t(DT_FLOAT, TensorShape({2, 3, 4, 5}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 3, 4, 5})));
+
+ {
+ auto tensor = t.tensor<float, 4>();
+ EXPECT_EQ(2, tensor.dimension(0));
+ EXPECT_EQ(3, tensor.dimension(1));
+ EXPECT_EQ(4, tensor.dimension(2));
+ EXPECT_EQ(5, tensor.dimension(3));
+
+ // Set first and last elements.
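+ // (In row-major order these are flat indices 0 and 2*3*4*5 - 1 = 119,
+ // which the reshaped views below read back.)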
+ tensor(0, 0, 0, 0) = 0.01f;
+ tensor(1, 2, 3, 4) = 0.02f;
+ }
+ {
+ auto shaped = t.shaped<float, 1>({120});
+ EXPECT_EQ(120, shaped.dimension(0));
+ EXPECT_EQ(shaped(0), 0.01f);
+ EXPECT_EQ(shaped(119), 0.02f);
+ }
+ {
+ auto shaped = t.shaped<float, 2>({6, 20});
+ EXPECT_EQ(6, shaped.dimension(0));
+ EXPECT_EQ(20, shaped.dimension(1));
+ EXPECT_EQ(shaped(0, 0), 0.01f);
+ EXPECT_EQ(shaped(5, 19), 0.02f);
+ }
+ {
+ auto shaped = t.shaped<float, 3>({6, 4, 5});
+ EXPECT_EQ(6, shaped.dimension(0));
+ EXPECT_EQ(4, shaped.dimension(1));
+ EXPECT_EQ(5, shaped.dimension(2));
+ EXPECT_EQ(shaped(0, 0, 0), 0.01f);
+ EXPECT_EQ(shaped(5, 3, 4), 0.02f);
+ }
+ {
+ auto shaped = t.shaped<float, 4>({2, 3, 4, 5});
+ EXPECT_EQ(2, shaped.dimension(0));
+ EXPECT_EQ(3, shaped.dimension(1));
+ EXPECT_EQ(4, shaped.dimension(2));
+ EXPECT_EQ(5, shaped.dimension(3));
+
+ EXPECT_EQ(shaped(0, 0, 0, 0), 0.01f);
+ EXPECT_EQ(shaped(1, 2, 3, 4), 0.02f);
+ }
+ {
+ auto flat = t.flat<float>();
+ EXPECT_EQ(120, flat.dimension(0));
+ EXPECT_EQ(flat(0), 0.01f);
+ EXPECT_EQ(flat(119), 0.02f);
+ }
+ {
+ auto flat_inner_dims = t.flat_inner_dims<float>();
+ EXPECT_EQ(24, flat_inner_dims.dimension(0));
+ EXPECT_EQ(5, flat_inner_dims.dimension(1));
+ EXPECT_EQ(flat_inner_dims(0, 0), 0.01f);
+ EXPECT_EQ(flat_inner_dims(23, 4), 0.02f);
+ }
+}
+
+TEST(Tensor_Scalar, Basics) {
+ {
+ Tensor t(DT_FLOAT, TensorShape({}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<float>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.scalar<float>()() = 123.45f;
+ EXPECT_FLOAT_EQ(123.45f, Tt());
+ }
+ {
+ Tensor t(DT_FLOAT, TensorShape({1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.vec<float>();
+ EXPECT_EQ(1, Tt.size());
+ t.vec<float>()(0) = 123.45f;
+ EXPECT_FLOAT_EQ(123.45f, Tt(0));
+ }
+ {
+ Tensor t(DT_FLOAT, TensorShape({1, 1, 1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<float>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.flat<float>()(0) = 123.45f;
+ EXPECT_FLOAT_EQ(123.45f, Tt());
+ }
+ {
+ Tensor t(DT_STRING, TensorShape({}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<string>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.scalar<string>()() = "foo";
+ EXPECT_EQ("foo", Tt());
+ }
+ {
+ Tensor t(DT_STRING, TensorShape({1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.vec<string>();
+ EXPECT_EQ(1, Tt.size());
+ t.flat<string>()(0) = "foo";
+ EXPECT_EQ("foo", Tt(0));
+ }
+ {
+ Tensor t(DT_STRING, TensorShape({1, 1, 1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<string>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.flat<string>()(0) = "bar";
+ EXPECT_EQ("bar", Tt());
+ }
+ {
+ Tensor t(DT_FLOAT, TensorShape({0, 1}));
+ EXPECT_EQ(0, t.NumElements());
+ auto Tt = t.flat<float>();
+ EXPECT_EQ(0, Tt.size());
+ auto Tm = t.matrix<float>();
+ EXPECT_EQ(0, Tm.size());
+ EXPECT_EQ(0, Tm.dimensions()[0]);
+ EXPECT_EQ(1, Tm.dimensions()[1]);
+ }
+}
+
+TEST(Tensor_Float, Reshape_And_Slice_Assignment) {
+ // A test exercising assignment to a subset of a tensor via a reshaped view.
+ Tensor t(DT_FLOAT, TensorShape({10, 4, 3, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 4, 3, 2})));
+
+ // Get the N dimensional tensor (N==4 here)
+ auto e_t = t.tensor<float, 4>();
+ // Reshape to view it as a two-dimensional tensor
+ auto e_2d = t.shaped<float, 2>({10, 4 * 3 * 2});
+ for (int i = 0; i < 10; i++) {
+ // Assign a 1 x 4*3*2 matrix (really vector) to a slice of size
+ // 1 x 4*3*2 in e_t.
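+ // (Row i of the rank-2 view e_2d aliases the dim-0 slice e_t(i, :, :, :),
+ // so writing the row writes that slice.)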
+ Eigen::Tensor<float, 2, Eigen::RowMajor> m(1, 4 * 3 * 2); + m.setConstant(i * 2.0); + + Eigen::DSizes<Eigen::DenseIndex, 2> indices(i, 0); + Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, 4 * 3 * 2); + e_2d.slice(indices, sizes) = m; + } + for (int i = 0; i < 10; i++) { + for (int j = 0; j < 4; j++) { + for (int k = 0; k < 3; k++) { + for (int l = 0; l < 2; l++) { + EXPECT_EQ(e_t(i, j, k, l), i * 2.0f); + LOG(INFO) << i << "," << j << "," << k << "," << l + << " &e_t(i, j, k, l): " << &e_t(i, j, k, l) << " = " + << e_t(i, j, k, l); + } + } + } + } +} + +TEST(Tensor_String, Simple) { + Tensor t = test::AsTensor<string>( + {"hello", "world", "machine", "learning", "new", "york"}, + TensorShape({3, 2})); + auto s = t.shape(); + ASSERT_EQ(s.dims(), 2); + ASSERT_EQ(s.dim_size(0), 3); + ASSERT_EQ(s.dim_size(1), 2); + auto m = t.matrix<string>(); + EXPECT_EQ(t.TotalBytes(), 3 * 2 * sizeof(string) + 5 + 5 + 7 + 8 + 3 + 4); + + EXPECT_EQ(m(0, 0), "hello"); + EXPECT_EQ(m(0, 1), "world"); + EXPECT_EQ(m(1, 0), "machine"); + EXPECT_EQ(m(1, 1), "learning"); + EXPECT_EQ(m(2, 0), "new"); + EXPECT_EQ(m(2, 1), "york"); + + TestCopies<string>(t); +} + +TEST(Tensor_Float, SimpleWithHelper) { + Tensor t1 = test::AsTensor<float>({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat<float>() = t1.flat<float>() * 2.0f; + Tensor t3 = test::AsTensor<float>({0, 2, 4, 6, 8, 10}, t1.shape()); + test::ExpectTensorEqual<float>(t2, t3); +} + +TEST(Tensor_Int32, SimpleWithHelper) { + Tensor t1 = test::AsTensor<int32>({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat<int32>() = t1.flat<int32>() * 2; + Tensor t3 = test::AsTensor<int32>({0, 2, 4, 6, 8, 10}, t1.shape()); + test::ExpectTensorEqual<int32>(t2, t3); +} + +TEST(Tensor_QInt8, SimpleWithHelper) { + Tensor t1 = test::AsTensor<qint8>({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat<qint8>() = t1.flat<qint8>() + qint8(-2); + Tensor t3 = test::AsTensor<qint8>({-2, -1, 0, 1, 2, 3}, {2, 3}); + test::ExpectTensorEqual<qint8>(t2, t3); +} + +TEST(Tensor_QUInt8, SimpleWithHelper) { + Tensor t1 = test::AsTensor<quint8>({0, 1, 2, 3, 4, 5}, {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat<quint8>() = t1.flat<quint8>() + quint8(2); + Tensor t3 = test::AsTensor<quint8>({2, 3, 4, 5, 6, 7}, {2, 3}); + test::ExpectTensorEqual<quint8>(t2, t3); +} + +TEST(Tensor_Int64, SimpleWithHelper) { + Tensor t1 = test::AsTensor<int64>( + {0LL << 48, 1LL << 48, 2LL << 48, 3LL << 48, 4LL << 48, 5LL << 48}, + {2, 3}); + Tensor t2(t1.dtype(), t1.shape()); + t2.flat<int64>() = t1.flat<int64>() * static_cast<int64>(2); + Tensor t3 = test::AsTensor<int64>( + {0LL << 48, 2LL << 48, 4LL << 48, 6LL << 48, 8LL << 48, 10LL << 48}, + {2, 3}); + test::ExpectTensorEqual<int64>(t2, t3); +} + +TEST(Tensor_String, SimpleWithHelper) { + Tensor t1 = test::AsTensor<string>({"0", "1", "2", "3", "4", "5"}, {2, 3}); + Tensor t2(DT_STRING, {2, 3}); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + t2.matrix<string>()(i, j) = strings::StrCat(i * 3 + j); + } + } + + // Test with helper. + test::ExpectTensorEqual<string>(t1, t2); +} + +TEST(Tensor_Bool, SimpleWithHelper) { + Tensor t1 = + test::AsTensor<bool>({false, true, false, true, false, true}, {2, 3}); + + Tensor t2(DT_BOOL, {2, 3}); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + t2.matrix<bool>()(i, j) = (((i + j) % 2) != 0); + } + } + + // Test with helper. 
+ test::ExpectTensorEqual<bool>(t1, t2);
+}
+
+TEST(Tensor_Complex, Simple) {
+ Tensor t(DT_COMPLEX64, {4, 5, 3, 7});
+ t.flat<complex64>().setRandom();
+ TestCopies<complex64>(t);
+}
+
+TEST(Tensor_Complex, SimpleWithHelper) {
+ {
+ Tensor t1 = test::AsTensor<complex64>({0,
+ {1, 1},
+ complex64(2),
+ complex64(3, 3),
+ complex64(0, 4),
+ complex64(2, 5)},
+ {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<complex64>() = t1.flat<complex64>() * complex64(0, 2);
+ Tensor t3 = test::AsTensor<complex64>(
+ {0, {-2, 2}, {0, 4}, {-6, 6}, {-8, 0}, {-10, 4}},
+ // shape
+ {2, 3});
+ test::ExpectTensorEqual<complex64>(t2, t3);
+ }
+
+ // Performs some numeric operations on complex numbers.
+ {
+ const float PI = std::acos(-1);
+ const complex64 rotate_45 = std::polar(1.0f, PI / 4);
+
+ // x contains all the 8-th roots of unity.
+ Tensor x(DT_COMPLEX64, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ x.vec<complex64>()(i) = std::pow(rotate_45, i);
+ }
+
+ // Shift the roots by 45 degrees.
+ Tensor y(DT_COMPLEX64, TensorShape({8}));
+ y.vec<complex64>() = x.vec<complex64>() * rotate_45;
+ Tensor y_expected(DT_COMPLEX64, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ y_expected.vec<complex64>()(i) = std::pow(rotate_45, i + 1);
+ }
+ test::ExpectTensorNear<complex64>(y, y_expected, 1e-5);
+
+ // Raise the roots to the power of 8.
+ Tensor z(DT_COMPLEX64, TensorShape({8}));
+ z.vec<complex64>() = x.vec<complex64>().pow(8);
+ Tensor z_expected(DT_COMPLEX64, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ z_expected.vec<complex64>()(i) = 1;
+ }
+ test::ExpectTensorNear<complex64>(z, z_expected, 1e-5);
+ }
+}
+
+// On the alignment.
+//
+// As of 2015/8, tensorflow::Tensor allocates its buffer with 32-byte
+// alignment. Tensor::tensor/flat/vec/matrix methods require that the
+// buffer satisfies Eigen::Aligned (e.g., 16-bytes aligned usually,
+// and 32-bytes for AVX). Tensor::Slice requires the caller to ensure
+// its result is aligned if the caller intends to use those methods.
+// In this test case, we simply make sure each slice is 32-byte
+// aligned: the stride between consecutive dim-0 slices is
+// sizeof(float) * 4 * 34 = 544 bytes, which is a multiple of 32.
+TEST(Tensor, Slice_Basic) {
+ Tensor saved;
+ { // General
+ Tensor x(DT_FLOAT, TensorShape({10, 4, 34}));
+ // Fills in known values.
+ for (int i = 0; i < 10; ++i) {
+ x.Slice(i, i + 1).flat<float>().setConstant(i * 1.f);
+ }
+ // A simple slice along dim0.
+ Tensor y = x.Slice(4, 8);
+ EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 4, 34})));
+ auto tx = x.tensor<float, 3>();
+ auto ty = y.tensor<float, 3>();
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 34; ++k) {
+ EXPECT_EQ(ty(i, j, k), 4.0 + i);
+ EXPECT_EQ(&tx(4 + i, j, k), &ty(i, j, k));
+ }
+ }
+ }
+ // A simple slice equivalent to identity.
+ TestCopies<float>(y);
+ y = x.Slice(0, 10);
+ test::ExpectTensorEqual<float>(x, y);
+ EXPECT_EQ(x.flat<float>().data(), y.flat<float>().data());
+
+ // A slice of a slice.
+ auto z = x.Slice(4, 8).Slice(2, 3);
+ auto tz = z.tensor<float, 3>();
+ EXPECT_EQ(1, z.dim_size(0));
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 34; ++k) {
+ EXPECT_EQ(tz(0, j, k), 6.0);
+ }
+ }
+
+ // x and y will go out of scope, but 'saved' should stay alive.
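+ // (Tensor assignment shares the underlying reference-counted buffer, so
+ // 'saved' keeps the storage reachable after this block exits.)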
+ saved = z;
+ }
+ {
+ EXPECT_EQ(1, saved.dim_size(0));
+ auto tsaved = saved.tensor<float, 3>();
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 34; ++k) {
+ EXPECT_EQ(tsaved(0, j, k), 6.0);
+ }
+ }
+ }
+ { // Empty
+ Tensor x(DT_FLOAT, TensorShape({10, 0, 34}));
+ x.flat<float>().setRandom();
+ Tensor y = x.Slice(4, 8);
+ EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 0, 34})));
+ }
+
+ {
+ // Test unaligned access via a Slice.
+ Tensor x(DT_FLOAT, TensorShape({30}));
+ x.flat<float>().setConstant(0.0);
+
+ // Take an unaligned slice.
+ Tensor y = x.Slice(1, 13);
+ y.unaligned_flat<float>().setConstant(1.0);
+ for (int64 i = 0; i < y.NumElements(); ++i) {
+ EXPECT_EQ(1.0, y.unaligned_flat<float>()(i));
+ }
+ }
+}
+
+static void BM_CreateAndDestroy(int iters) {
+ TensorShape shape({10, 20});
+ while (--iters) {
+ Tensor t(DT_FLOAT, shape);
+ }
+}
+BENCHMARK(BM_CreateAndDestroy);
+
+static void BM_Assign(int iters) {
+ Tensor a(DT_FLOAT, TensorShape({10, 20}));
+ Tensor b(DT_FLOAT, TensorShape({10, 20}));
+ bool a_to_b = true;
+ while (--iters) {
+ if (a_to_b) {
+ b = a;
+ } else {
+ a = b;
+ }
+ a_to_b = !a_to_b;
+ }
+}
+BENCHMARK(BM_Assign);
+
+// Ensure tensor_data() works on empty tensors
+TEST(Tensor, EmptyTensorData) {
+ Tensor empty;
+ EXPECT_EQ(empty.tensor_data().size(), 0);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_testutil.cc b/tensorflow/core/framework/tensor_testutil.cc
new file mode 100644
index 0000000000..b6cd12a864
--- /dev/null
+++ b/tensorflow/core/framework/tensor_testutil.cc
@@ -0,0 +1,43 @@
+#include <cmath>
+#include "tensorflow/core/framework/tensor_testutil.h"
+
+namespace tensorflow {
+namespace test {
+
+template <typename T>
+bool IsClose(const T& x, const T& y, double atol, double rtol) {
+ return fabs(x - y) < atol + rtol * fabs(x);
+}
+
+template <typename T>
+void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
+ auto Tx = x.flat<T>();
+ auto Ty = y.flat<T>();
+ for (int i = 0; i < Tx.size(); ++i) {
+ if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
+ LOG(ERROR) << "x = " << x.DebugString();
+ LOG(ERROR) << "y = " << y.DebugString();
+ LOG(ERROR) << "atol = " << atol << " rtol = " << rtol
+ << " tol = " << atol + rtol * std::fabs(Tx(i));
+ EXPECT_TRUE(false) << i << "-th element is not close " << Tx(i) << " vs. "
+ << Ty(i);
+ }
+ }
+}
+
+void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
+ internal::AssertSameTypeDims(x, y);
+ switch (x.dtype()) {
+ case DT_FLOAT:
+ ExpectClose<float>(x, y, atol, rtol);
+ break;
+ case DT_DOUBLE:
+ ExpectClose<double>(x, y, atol, rtol);
+ break;
+ default:
+ LOG(FATAL) << "Unexpected type : " << DataTypeString(x.dtype());
+ }
+}
+
+} // end namespace test
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_testutil.h b/tensorflow/core/framework/tensor_testutil.h
new file mode 100644
index 0000000000..53d6da0fb2
--- /dev/null
+++ b/tensorflow/core/framework/tensor_testutil.h
@@ -0,0 +1,189 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
+
+#include <algorithm>
+#include <complex>
+#include <functional>
+#include <numeric>
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace test {
+
+// Constructs a scalar tensor with 'val'.
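+// E.g. (a usage sketch):
+//   Tensor t = test::AsScalar<int32>(7);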
+template <typename T>
+Tensor AsScalar(const T& val) {
+ Tensor ret(DataTypeToEnum<T>::value, {});
+ ret.scalar<T>()() = val;
+ return ret;
+}
+
+// Constructs a flat tensor with 'vals'.
+template <typename T>
+Tensor AsTensor(gtl::ArraySlice<T> vals) {
+ Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64>(vals.size())});
+ std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
+ return ret;
+}
+
+// Constructs a tensor of "shape" with values "vals".
+template <typename T>
+Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
+ Tensor ret;
+ CHECK(ret.CopyFrom(AsTensor(vals), shape));
+ return ret;
+}
+
+// Fills in '*tensor' with 'vals'. E.g.,
+//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+//   test::FillValues<float>(&x, {11, 21, 21, 22});
+template <typename T>
+void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
+ auto flat = tensor->flat<T>();
+ CHECK_EQ(flat.size(), vals.size());
+ if (flat.size() > 0) {
+ std::copy_n(vals.data(), vals.size(), flat.data());
+ }
+}
+
+// Fills in '*tensor' with the sequence of values val, val+1, val+2, ...
+//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+//   test::FillIota<float>(&x, 1.0);
+template <typename T>
+void FillIota(Tensor* tensor, const T& val) {
+ auto flat = tensor->flat<T>();
+ std::iota(flat.data(), flat.data() + flat.size(), val);
+}
+
+// Fills in '*tensor' with the sequence of values fn(0), fn(1), ...
+//   Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+//   test::FillFn<float>(&x, [](int i)->float { return i*i; });
+template <typename T>
+void FillFn(Tensor* tensor, std::function<T(int)> fn) {
+ auto flat = tensor->flat<T>();
+ for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
+}
+
+// Expects "x" and "y" are tensors of the same type, same shape, and
+// identical values.
+template <typename T>
+void ExpectTensorEqual(const Tensor& x, const Tensor& y);
+
+// Expects "x" and "y" are tensors of the same type, same shape, and
+// approximately equal values, each within "abs_err".
+template <typename T>
+void ExpectTensorNear(const Tensor& x, const Tensor& y, const T& abs_err);
+
+// Expects "x" and "y" are tensors of the same type (float or double),
+// same shape and element-wise difference between x and y is no more
+// than atol + rtol * abs(x).
+void ExpectClose(const Tensor& x, const Tensor& y, double atol = 1e-6,
+ double rtol = 1e-6);
+
+// Implementation details.
+
+namespace internal {
+
+template <typename T>
+struct is_floating_point_type {
+ static const bool value = std::is_same<T, float>::value ||
+ std::is_same<T, double>::value ||
+ std::is_same<T, std::complex<float> >::value ||
+ std::is_same<T, std::complex<double> >::value;
+};
+
+template <typename T>
+static void ExpectEqual(const T& a, const T& b) {
+ EXPECT_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<float>(const float& a, const float& b) {
+ EXPECT_FLOAT_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<double>(const double& a, const double& b) {
+ EXPECT_DOUBLE_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<complex64>(const complex64& a, const complex64& b) {
+ EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b;
+ EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. 
" << b; +} + +inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), y.dtype()); + ASSERT_TRUE(x.IsSameSize(y)) + << "x.shape [" << x.shape().DebugString() << "] vs " + << "y.shape [ " << y.shape().DebugString() << "]"; +} + +template <typename T, bool is_fp = is_floating_point_type<T>::value> +struct Expector; + +template <typename T> +struct Expector<T, false> { + static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } + + static void Equal(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v()); + AssertSameTypeDims(x, y); + auto a = x.flat<T>(); + auto b = y.flat<T>(); + for (int i = 0; i < a.size(); ++i) { + ExpectEqual(a(i), b(i)); + } + } +}; + +// Partial specialization for float and double. +template <typename T> +struct Expector<T, true> { + static void Equal(const T& a, const T& b) { ExpectEqual(a, b); } + + static void Equal(const Tensor& x, const Tensor& y) { + ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v()); + AssertSameTypeDims(x, y); + auto a = x.flat<T>(); + auto b = y.flat<T>(); + for (int i = 0; i < a.size(); ++i) { + ExpectEqual(a(i), b(i)); + } + } + + static void Near(const T& a, const T& b, const double abs_err) { + if (a != b) { // Takes care of inf. + EXPECT_LE(std::abs(a - b), abs_err) << "a = " << a << " b = " << b; + } + } + + static void Near(const Tensor& x, const Tensor& y, const double abs_err) { + ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v()); + AssertSameTypeDims(x, y); + auto a = x.flat<T>(); + auto b = y.flat<T>(); + for (int i = 0; i < a.size(); ++i) { + Near(a(i), b(i), abs_err); + } + } +}; + +} // namespace internal + +template <typename T> +void ExpectTensorEqual(const Tensor& x, const Tensor& y) { + internal::Expector<T>::Equal(x, y); +} + +template <typename T> +void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) { + static_assert(internal::is_floating_point_type<T>::value, + "T is not a floating point types."); + internal::Expector<T>::Near(x, y, abs_err); +} + +} // namespace test +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_ diff --git a/tensorflow/core/framework/tensor_types.h b/tensorflow/core/framework/tensor_types.h new file mode 100644 index 0000000000..077d86d442 --- /dev/null +++ b/tensorflow/core/framework/tensor_types.h @@ -0,0 +1,92 @@ +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +// Helper to define Tensor types given that the scalar is of type T. +template <typename T, int NDIMS = 1> +struct TTypes { + // Rank-<NDIMS> tensor of scalar type T. + typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor>, + Eigen::Aligned> Tensor; + typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor>, + Eigen::Aligned> ConstTensor; + + // Unaligned Rank-<NDIMS> tensor of scalar type T. + typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor> > + UnalignedTensor; + typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor> > + UnalignedConstTensor; + + typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>, + Eigen::Aligned> Tensor32Bit; + + // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. 
+ typedef Eigen::TensorMap< + Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor>, + Eigen::Aligned> Scalar; + typedef Eigen::TensorMap< + Eigen::TensorFixedSize<const T, Eigen::Sizes<>, Eigen::RowMajor>, + Eigen::Aligned> ConstScalar; + + // Unaligned Scalar tensor of scalar type T. + typedef Eigen::TensorMap<Eigen::TensorFixedSize< + T, Eigen::Sizes<>, Eigen::RowMajor> > UnalignedScalar; + typedef Eigen::TensorMap<Eigen::TensorFixedSize< + const T, Eigen::Sizes<>, Eigen::RowMajor> > UnalignedConstScalar; + + // Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, Eigen::Aligned> + Flat; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>, + Eigen::Aligned> ConstFlat; + typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, Eigen::Aligned> + Vec; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>, + Eigen::Aligned> ConstVec; + + // Unaligned Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor> > UnalignedFlat; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor> > + UnalignedConstFlat; + typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor> > UnalignedVec; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor> > + UnalignedConstVec; + + // Rank-2 tensor (matrix) of scalar type T. + typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned> + Matrix; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, + Eigen::Aligned> ConstMatrix; + + // Unaligned Rank-2 tensor (matrix) of scalar type T. + typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor> > + UnalignedMatrix; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor> > + UnalignedConstMatrix; +}; + +typedef typename TTypes<float, 1>::Tensor32Bit::Index Index32; + +template <typename DSizes> +Eigen::DSizes<Index32, DSizes::count> To32BitDims(const DSizes& in) { + Eigen::DSizes<Index32, DSizes::count> out; + for (int i = 0; i < DSizes::count; ++i) { + out[i] = in[i]; + } + return out; +} + +template <typename TensorType> +typename TTypes<typename TensorType::Scalar, + TensorType::NumIndices>::Tensor32Bit +To32Bit(TensorType in) { + typedef typename TTypes<typename TensorType::Scalar, + TensorType::NumIndices>::Tensor32Bit RetType; + return RetType(in.data(), To32BitDims(in.dimensions())); +} + +} // namespace tensorflow +#endif // TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_ diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc new file mode 100644 index 0000000000..7353191c74 --- /dev/null +++ b/tensorflow/core/framework/tensor_util.cc @@ -0,0 +1,28 @@ +#include "tensorflow/core/framework/tensor_util.h" + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace tensor { + +Tensor DeepCopy(const Tensor& other) { + Tensor tmp = Tensor(other.dtype(), other.shape()); + if (DataTypeCanUseMemcpy(other.dtype())) { + StringPiece other_data = other.tensor_data(); + + // We use StringPiece as a convenient map over the tensor buffer, + // but we cast the type to get to the underlying buffer to do the + // copy. 
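+ // (The const_cast below is safe here: 'tmp' was freshly allocated above
+ // and uniquely owns its buffer, so no other tensor aliases it.)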
+ StringPiece tmp_data = tmp.tensor_data(); + memcpy(const_cast<char*>(tmp_data.data()), other_data.data(), + other_data.size()); + } else { + CHECK_EQ(DT_STRING, other.dtype()); + tmp.flat<string>() = other.flat<string>(); + } + return tmp; +} + +} // namespace tensor +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h new file mode 100644 index 0000000000..a8dde1d0ca --- /dev/null +++ b/tensorflow/core/framework/tensor_util.h @@ -0,0 +1,21 @@ +#ifndef TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_ +#define TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_ + +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace tensor { + +// DeepCopy returns a tensor whose contents are a deep copy of the +// contents of 'other'. This function is intended only for +// convenience, not speed. +// +// REQUIRES: 'other' must point to data stored in CPU memory. +// REQUIRES: 'other' must be a Tensor of a copy-able type if +// 'other' is not appropriately memory-aligned. +Tensor DeepCopy(const Tensor& other); + +} // namespace tensor +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_ diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc new file mode 100644 index 0000000000..fef7468151 --- /dev/null +++ b/tensorflow/core/framework/tensor_util_test.cc @@ -0,0 +1,124 @@ +#include "tensorflow/core/framework/tensor_util.h" + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/public/tensor.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +TEST(TensorUtil, DeepCopy0d) { + Tensor x(DT_FLOAT, TensorShape({})); + x.scalar<float>()() = 10.0; + + // Make y a deep copy of x and then change it. + Tensor y = tensor::DeepCopy(x); + y.scalar<float>()() = 20.0; + + // x doesn't change + EXPECT_EQ(10.0, x.scalar<float>()()); + + // Change x. + x.scalar<float>()() = 30.0; + + // Y doesn't change. + EXPECT_EQ(20.0, y.scalar<float>()()); + + Tensor z = tensor::DeepCopy(y); + + // Change y. + y.scalar<float>()() = 40.0; + + // The final states should all be different. + EXPECT_EQ(20.0, z.scalar<float>()()); + EXPECT_EQ(30.0, x.scalar<float>()()); + EXPECT_EQ(40.0, y.scalar<float>()()); + + // Should have the same shape and type. + EXPECT_EQ(TensorShape({}), x.shape()); + EXPECT_EQ(TensorShape({}), y.shape()); + EXPECT_EQ(TensorShape({}), z.shape()); + + EXPECT_EQ(DT_FLOAT, x.dtype()); + EXPECT_EQ(DT_FLOAT, y.dtype()); + EXPECT_EQ(DT_FLOAT, z.dtype()); +} + +TEST(TensorUtil, DeepCopy) { + Tensor x(DT_FLOAT, TensorShape({1})); + x.flat<float>()(0) = 10.0; + + // Make y a deep copy of x and then change it. + Tensor y = tensor::DeepCopy(x); + y.flat<float>()(0) = 20.0; + + // x doesn't change + EXPECT_EQ(10.0, x.flat<float>()(0)); + + // Change x. + x.flat<float>()(0) = 30.0; + + // Y doesn't change. + EXPECT_EQ(20.0, y.flat<float>()(0)); + + Tensor z = tensor::DeepCopy(y); + + // Change y. + y.flat<float>()(0) = 40.0; + + // The final states should all be different. + EXPECT_EQ(20.0, z.flat<float>()(0)); + EXPECT_EQ(30.0, x.flat<float>()(0)); + EXPECT_EQ(40.0, y.flat<float>()(0)); + + // Should have the same shape and type. 
+ EXPECT_EQ(TensorShape({1}), x.shape()); + EXPECT_EQ(TensorShape({1}), y.shape()); + EXPECT_EQ(TensorShape({1}), z.shape()); + + EXPECT_EQ(DT_FLOAT, x.dtype()); + EXPECT_EQ(DT_FLOAT, y.dtype()); + EXPECT_EQ(DT_FLOAT, z.dtype()); + + // Test string deep copy + Tensor str1(DT_STRING, TensorShape({2})); + str1.flat<string>()(0) = "foo1"; + str1.flat<string>()(1) = "foo2"; + Tensor str2 = tensor::DeepCopy(str1); + str2.flat<string>()(0) = "bar1"; + str2.flat<string>()(1) = "bar2"; + EXPECT_NE(str2.flat<string>()(0), str1.flat<string>()(0)); +} + +TEST(TensorUtil, DeepCopySlice) { + Tensor x(DT_INT32, TensorShape({10})); + x.flat<int32>().setConstant(1); + + // Slice 'x' -- y still refers to the same buffer. + Tensor y = x.Slice(2, 6); + + // Do a deep copy of y, which is a slice. + Tensor z = tensor::DeepCopy(y); + + // Set x to be different. + x.flat<int32>().setConstant(2); + + EXPECT_EQ(TensorShape({10}), x.shape()); + EXPECT_EQ(TensorShape({4}), y.shape()); + EXPECT_EQ(TensorShape({4}), z.shape()); + EXPECT_EQ(DT_INT32, x.dtype()); + EXPECT_EQ(DT_INT32, y.dtype()); + EXPECT_EQ(DT_INT32, z.dtype()); + + // x and y should now all be '2', but z should be '1'. + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(2, x.flat<int32>()(i)); + } + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(2, y.unaligned_flat<int32>()(i)); + EXPECT_EQ(1, z.flat<int32>()(i)); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/tracking_allocator.cc b/tensorflow/core/framework/tracking_allocator.cc new file mode 100644 index 0000000000..78311ded19 --- /dev/null +++ b/tensorflow/core/framework/tracking_allocator.cc @@ -0,0 +1,100 @@ +#include "tensorflow/core/framework/tracking_allocator.h" + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +TrackingAllocator::TrackingAllocator(Allocator* allocator) + : allocator_(allocator), + ref_(1), + allocated_(0), + high_watermark_(0), + total_bytes_(0) {} + +void* TrackingAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { + void* ptr = allocator_->AllocateRaw(alignment, num_bytes); + // If memory is exhausted AllocateRaw returns nullptr, and we should + // pass this through to the caller + if (nullptr == ptr) { + return ptr; + } + if (allocator_->TracksAllocationSizes()) { + size_t allocated_bytes = allocator_->AllocatedSize(ptr); + { + mutex_lock lock(mu_); + allocated_ += allocated_bytes; + high_watermark_ = std::max(high_watermark_, allocated_); + total_bytes_ += allocated_bytes; + ++ref_; + } + } else { + mutex_lock lock(mu_); + total_bytes_ += num_bytes; + ++ref_; + } + return ptr; +} + +void TrackingAllocator::DeallocateRaw(void* ptr) { + // freeing a null ptr is a no-op + if (nullptr == ptr) { + return; + } + bool should_delete; + // fetch the following outside the lock in case the call to + // AllocatedSize is slow + bool tracks_allocation_sizes = allocator_->TracksAllocationSizes(); + size_t allocated_bytes = 0; + if (tracks_allocation_sizes) { + allocated_bytes = allocator_->AllocatedSize(ptr); + } + Allocator* allocator = allocator_; + { + mutex_lock lock(mu_); + if (tracks_allocation_sizes) { + CHECK_GE(allocated_, allocated_bytes); + allocated_ -= allocated_bytes; + } + should_delete = UnRef(); + } + allocator->DeallocateRaw(ptr); + if (should_delete) { + delete this; + } +} + +bool TrackingAllocator::TracksAllocationSizes() { + return allocator_->TracksAllocationSizes(); +} + +size_t TrackingAllocator::RequestedSize(void* ptr) { + return allocator_->RequestedSize(ptr); +} + +size_t 
TrackingAllocator::AllocatedSize(void* ptr) {
+ return allocator_->AllocatedSize(ptr);
+}
+
+std::pair<size_t, size_t> TrackingAllocator::GetSizesAndUnRef() {
+ size_t high_watermark;
+ size_t total_bytes;
+ bool should_delete;
+ {
+ mutex_lock lock(mu_);
+ high_watermark = high_watermark_;
+ total_bytes = total_bytes_;
+ should_delete = UnRef();
+ }
+ if (should_delete) {
+ delete this;
+ }
+ return std::make_pair(total_bytes, high_watermark);
+}
+
+bool TrackingAllocator::UnRef() {
+ CHECK_GE(ref_, 1);
+ --ref_;
+ return (ref_ == 0);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h
new file mode 100644
index 0000000000..f809e3822c
--- /dev/null
+++ b/tensorflow/core/framework/tracking_allocator.h
@@ -0,0 +1,80 @@
+#ifndef TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
+#define TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+// TrackingAllocator is a wrapper for an Allocator. It keeps a running
+// count of the number of bytes allocated through the wrapper. It is
+// used by the Executor to "charge" allocations to particular Op
+// executions. Each Op gets a separate TrackingAllocator wrapper
+// around the underlying allocator.
+//
+// The implementation assumes the invariant that all calls to
+// AllocateRaw by an Op (or work items spawned by the Op) will occur
+// before the Op's Compute method returns. Thus the high watermark is
+// established once Compute returns.
+//
+// DeallocateRaw can be called long after the Op has finished,
+// e.g. when an output tensor is deallocated, and the wrapper cannot
+// be deleted until the last of these calls has occurred. The
+// TrackingAllocator keeps track of outstanding calls using a
+// reference count, and deletes itself once the last call has been
+// received and the high watermark has been retrieved.
+class TrackingAllocator : public Allocator {
+ public:
+ explicit TrackingAllocator(Allocator* allocator);
+ string Name() override { return allocator_->Name(); }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+ bool TracksAllocationSizes() override;
+ size_t RequestedSize(void* ptr) override;
+ size_t AllocatedSize(void* ptr) override;
+
+ // If the underlying allocator tracks allocation sizes, this returns
+ // a pair where the first value is the total number of bytes
+ // allocated through this wrapper, and the second value is the high
+ // watermark of bytes allocated through this wrapper. If the
+ // underlying allocator does not track allocation sizes the first
+ // value is the total number of bytes requested through this wrapper
+ // and the second is 0.
+ //
+ // After GetSizesAndUnRef is called, the only further calls allowed
+ // on this wrapper are calls to DeallocateRaw with pointers that
+ // were allocated by this wrapper and have not yet been
+ // deallocated. After this call completes and all allocated pointers
+ // have been deallocated the wrapper will delete itself.
+ std::pair<size_t, size_t> GetSizesAndUnRef();
+
+ private:
+ ~TrackingAllocator() override {}
+ bool UnRef() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ Allocator* allocator_; // not owned.
+ mutex mu_; + // the number of calls to AllocateRaw that have not yet been matched + // by a corresponding call to DeAllocateRaw, plus 1 if the Executor + // has not yet read out the high watermark. + int ref_ GUARDED_BY(mu_); + // the current number of outstanding bytes that have been allocated + // by this wrapper, or 0 if the underlying allocator does not track + // allocation sizes. + size_t allocated_ GUARDED_BY(mu_); + // the maximum number of outstanding bytes that have been allocated + // by this wrapper, or 0 if the underlying allocator does not track + // allocation sizes. + size_t high_watermark_ GUARDED_BY(mu_); + // the total number of bytes that have been allocated by this + // wrapper if the underlying allocator tracks allocation sizes, + // otherwise the total number of bytes that have been requested by + // this allocator. + size_t total_bytes_ GUARDED_BY(mu_); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_ diff --git a/tensorflow/core/framework/tracking_allocator_test.cc b/tensorflow/core/framework/tracking_allocator_test.cc new file mode 100644 index 0000000000..90ce851775 --- /dev/null +++ b/tensorflow/core/framework/tracking_allocator_test.cc @@ -0,0 +1,115 @@ +#include "tensorflow/core/framework/tracking_allocator.h" + +#include <unordered_map> + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/logging.h" +#include <gtest/gtest.h> + +namespace tensorflow { + +class TestableSizeTrackingAllocator : public Allocator { + public: + string Name() override { return "test"; } + void* AllocateRaw(size_t /*alignment*/, size_t num_bytes) override { + void* ptr = malloc(num_bytes); + size_map_[ptr] = num_bytes; + return ptr; + } + void DeallocateRaw(void* ptr) override { + const auto& iter = size_map_.find(ptr); + EXPECT_NE(size_map_.end(), iter); + size_map_.erase(iter); + free(ptr); + } + bool TracksAllocationSizes() override { return true; } + size_t RequestedSize(void* ptr) override { + const auto& iter = size_map_.find(ptr); + EXPECT_NE(size_map_.end(), iter); + return iter->second; + } + + private: + std::unordered_map<void*, size_t> size_map_; +}; + +class NoMemoryAllocator : public Allocator { + public: + string Name() override { return "test"; } + void* AllocateRaw(size_t /*alignment*/, size_t num_bytes) override { + return nullptr; + } + void DeallocateRaw(void* ptr) override {} + bool TracksAllocationSizes() override { return true; } +}; + +TEST(TrackingAllocatorTest, SimpleNoTracking) { + Allocator* a = cpu_allocator(); + + EXPECT_FALSE(a->TracksAllocationSizes()); + + TrackingAllocator* ta = new TrackingAllocator(a); + + void* p1 = ta->AllocateRaw(4, 4); + ta->Deallocate(p1); + void* p2 = ta->AllocateRaw(4, 12); + + std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef(); + + EXPECT_EQ(16, sizes.first); + EXPECT_EQ(0, sizes.second); + + ta->Deallocate(p2); +} + +TEST(TrackingAllocatorTest, SimpleTracking) { + TestableSizeTrackingAllocator a = TestableSizeTrackingAllocator(); + + EXPECT_TRUE(a.TracksAllocationSizes()); + + TrackingAllocator* ta = new TrackingAllocator(&a); + + void* p1 = ta->AllocateRaw(4, 12); + ta->Deallocate(p1); + void* p2 = ta->AllocateRaw(4, 4); + + std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef(); + + EXPECT_EQ(16, sizes.first); + EXPECT_EQ(12, sizes.second); + + ta->Deallocate(p2); +} + +TEST(TrackingAllocatorTest, OutOfMemory) { + NoMemoryAllocator a; + + EXPECT_TRUE(a.TracksAllocationSizes()); + + TrackingAllocator* ta = new TrackingAllocator(&a); + 
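+ // Every allocation fails below, so the tracker should record no bytes.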
+ void* p1 = ta->AllocateRaw(4, 12);
+ EXPECT_EQ(nullptr, p1);
+
+ std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
+
+ EXPECT_EQ(0, sizes.first);
+ EXPECT_EQ(0, sizes.second);
+}
+
+TEST(TrackingAllocatorTest, FreeNullPtr) {
+ NoMemoryAllocator a;
+
+ EXPECT_TRUE(a.TracksAllocationSizes());
+
+ TrackingAllocator* ta = new TrackingAllocator(&a);
+
+ ta->DeallocateRaw(nullptr);
+
+ std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
+
+ EXPECT_EQ(0, sizes.first);
+ EXPECT_EQ(0, sizes.second);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/type_traits.h b/tensorflow/core/framework/type_traits.h
new file mode 100644
index 0000000000..d87b6ff49b
--- /dev/null
+++ b/tensorflow/core/framework/type_traits.h
@@ -0,0 +1,69 @@
+#ifndef TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
+#define TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
+
+#include <limits>
+#include <utility>
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Traits that define the quantization attribute of types.
+struct true_type {
+ static const bool value = true;
+};
+struct false_type {
+ static const bool value = false;
+};
+
+// Default is_quantized is false.
+template <typename T>
+struct is_quantized : false_type {};
+
+// Specialize the quantized types.
+template <>
+struct is_quantized<qint8> : true_type {};
+template <>
+struct is_quantized<quint8> : true_type {};
+template <>
+struct is_quantized<qint32> : true_type {};
+
+// All types not specialized are marked invalid.
+template <class T>
+struct IsValidDataType {
+ static constexpr bool value = false;
+};
+
+// Extra validity checking; not part of public API.
+struct TestIsValidDataType {
+ static_assert(IsValidDataType<int64>::value, "Incorrect impl for int64");
+ static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32");
+};
+
+} // namespace tensorflow
+
+// Define numeric limits for our quantized types as subclasses of the
+// standard types.
+namespace std {
+template <>
+class numeric_limits<tensorflow::qint8>
+ : public numeric_limits<tensorflow::int8> {};
+template <>
+class numeric_limits<tensorflow::quint8>
+ : public numeric_limits<tensorflow::uint8> {};
+template <>
+class numeric_limits<tensorflow::qint32>
+ : public numeric_limits<tensorflow::int32> {};
+
+// Specialize is_signed for quantized types.
+template <> +struct is_signed<tensorflow::qint8> : public is_signed<tensorflow::int8> {}; +template <> +struct is_signed<tensorflow::quint8> : public is_signed<tensorflow::uint8> {}; +template <> +struct is_signed<tensorflow::qint32> : public is_signed<tensorflow::int32> {}; + +} // namespace std + +#endif // TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_ diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc new file mode 100644 index 0000000000..01b9fca3b6 --- /dev/null +++ b/tensorflow/core/framework/types.cc @@ -0,0 +1,210 @@ +#include "tensorflow/core/framework/types.h" + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +bool DeviceType::operator<(const DeviceType& other) const { + return type_ < other.type_; +} + +bool DeviceType::operator==(const DeviceType& other) const { + return type_ == other.type_; +} + +std::ostream& operator<<(std::ostream& os, const DeviceType& d) { + os << d.type(); + return os; +} + +const char* const DEVICE_CPU = "CPU"; +const char* const DEVICE_GPU = "GPU"; + +string DataTypeString(DataType dtype) { + if (IsRefType(dtype)) { + DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset); + return strings::StrCat(DataTypeString(non_ref), "_ref"); + } + switch (dtype) { + case DT_INVALID: + return "INVALID"; + case DT_FLOAT: + return "float"; + case DT_DOUBLE: + return "double"; + case DT_INT32: + return "int32"; + case DT_UINT8: + return "uint8"; + case DT_INT16: + return "int16"; + case DT_INT8: + return "int8"; + case DT_STRING: + return "string"; + case DT_COMPLEX64: + return "complex64"; + case DT_INT64: + return "int64"; + case DT_BOOL: + return "bool"; + case DT_QINT8: + return "qint8"; + case DT_QUINT8: + return "quint8"; + case DT_QINT32: + return "qint32"; + case DT_BFLOAT16: + return "bfloat16"; + default: + LOG(FATAL) << "Unrecognized DataType enum value " << dtype; + return ""; + } +} + +bool DataTypeFromString(StringPiece sp, DataType* dt) { + if (sp.ends_with("_ref")) { + sp.remove_suffix(4); + DataType non_ref; + if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) { + *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset); + return true; + } else { + return false; + } + } + + if (sp == "float" || sp == "float32") { + *dt = DT_FLOAT; + return true; + } else if (sp == "double" || sp == "float64") { + *dt = DT_DOUBLE; + return true; + } else if (sp == "int32") { + *dt = DT_INT32; + return true; + } else if (sp == "uint8") { + *dt = DT_UINT8; + return true; + } else if (sp == "int16") { + *dt = DT_INT16; + return true; + } else if (sp == "int8") { + *dt = DT_INT8; + return true; + } else if (sp == "string") { + *dt = DT_STRING; + return true; + } else if (sp == "complex64") { + *dt = DT_COMPLEX64; + return true; + } else if (sp == "int64") { + *dt = DT_INT64; + return true; + } else if (sp == "bool") { + *dt = DT_BOOL; + return true; + } else if (sp == "qint8") { + *dt = DT_QINT8; + return true; + } else if (sp == "quint8") { + *dt = DT_QUINT8; + return true; + } else if (sp == "qint32") { + *dt = DT_QINT32; + return true; + } else if (sp == "bfloat16") { + *dt = DT_BFLOAT16; + return true; + } + return false; +} + +string DeviceTypeString(DeviceType device_type) { return device_type.type(); } + +string DataTypeSliceString(const DataTypeSlice types) { + string out; + for (auto it = types.begin(); it != types.end(); ++it) { + strings::StrAppend(&out, ((it == types.begin()) ? 
"" : ", "), + DataTypeString(*it)); + } + return out; +} + +DataTypeVector AllTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, + DT_INT8, DT_STRING, DT_COMPLEX64, DT_INT64, DT_BOOL, + DT_QINT8, DT_QUINT8, DT_QINT32}; +} + +#ifndef __ANDROID__ + +DataTypeVector RealNumberTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, DT_INT8}; +} + +DataTypeVector QuantizedTypes() { return {DT_QINT8, DT_QUINT8, DT_QINT32}; } + +DataTypeVector RealAndQuantizedTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, + DT_INT16, DT_INT8, DT_QINT8, DT_QUINT8, DT_QINT32}; +} + +DataTypeVector NumberTypes() { + return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, + DT_INT8, DT_COMPLEX64, DT_QINT8, DT_QUINT8, DT_QINT32}; +} + +#else // __ANDROID__ + +DataTypeVector RealNumberTypes() { return {DT_FLOAT, DT_INT32}; } + +DataTypeVector NumberTypes() { + return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32}; +} + +DataTypeVector QuantizedTypes() { return {DT_QINT8, DT_QUINT8, DT_QINT32}; } + +DataTypeVector RealAndQuantizedTypes() { + return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32}; +} + +#endif // __ANDROID__ + +// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying +// is_simple<T> in tensor.cc (and possible choose a more general name?) +bool DataTypeCanUseMemcpy(DataType dt) { + switch (dt) { + case DT_FLOAT: + case DT_DOUBLE: + case DT_INT32: + case DT_UINT8: + case DT_INT16: + case DT_INT8: + case DT_COMPLEX64: + case DT_INT64: + case DT_BOOL: + case DT_QINT8: + case DT_QUINT8: + case DT_QINT32: + case DT_BFLOAT16: + return true; + default: + return false; + } +} + +bool DataTypeIsQuantized(DataType dt) { + switch (dt) { + case DT_QINT8: + case DT_QUINT8: + case DT_QINT32: + return true; + default: + return false; + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h new file mode 100644 index 0000000000..2d417cf076 --- /dev/null +++ b/tensorflow/core/framework/types.h @@ -0,0 +1,168 @@ +#ifndef TENSORFLOW_FRAMEWORK_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_TYPES_H_ + +#include <map> +#include <set> +#include <string> + +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint" + +namespace tensorflow { + +// MemoryType is used to describe whether input or output Tensors of +// an OpKernel should reside in "Host memory" (e.g., CPU memory) or +// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU +// devices). 
+enum MemoryType {
+  DEVICE_MEMORY = 0,
+  HOST_MEMORY = 1,
+};
+
+// A DeviceType is just a string, but we wrap it up in a class to give
+// some type checking as we're passing these around.
+class DeviceType {
+ public:
+  DeviceType(const char* type)  // NOLINT(runtime/explicit)
+      : type_(type) {}
+
+  explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {}
+
+  const char* type() const { return type_.c_str(); }
+
+  bool operator<(const DeviceType& other) const;
+  bool operator==(const DeviceType& other) const;
+  bool operator!=(const DeviceType& other) const { return !(*this == other); }
+
+ private:
+  string type_;
+};
+std::ostream& operator<<(std::ostream& os, const DeviceType& d);
+
+// Convenient constants that can be passed to a DeviceType constructor
+extern const char* const DEVICE_CPU;  // "CPU"
+extern const char* const DEVICE_GPU;  // "GPU"
+
+typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
+
+typedef gtl::InlinedVector<DataType, 4> DataTypeVector;
+typedef gtl::ArraySlice<DataType> DataTypeSlice;
+
+typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector;
+
+// Convert the enums to strings for errors:
+string DataTypeString(DataType dtype);
+string DeviceTypeString(DeviceType device_type);
+string DataTypeSliceString(const DataTypeSlice dtypes);
+inline string DataTypeVectorString(const DataTypeVector& dtypes) {
+  return DataTypeSliceString(dtypes);
+}
+
+// If "sp" names a valid type, store it in "*dt" and return true.  Otherwise,
+// return false.
+bool DataTypeFromString(StringPiece sp, DataType* dt);
+
+// DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc.
+enum { kDataTypeRefOffset = 100 };
+inline bool IsRefType(DataType dtype) {
+  return dtype > static_cast<DataType>(kDataTypeRefOffset);
+}
+inline DataType MakeRefType(DataType dtype) {
+  DCHECK(!IsRefType(dtype));
+  return static_cast<DataType>(dtype + kDataTypeRefOffset);
+}
+inline DataType RemoveRefType(DataType dtype) {
+  DCHECK(IsRefType(dtype));
+  return static_cast<DataType>(dtype - kDataTypeRefOffset);
+}
+inline DataType BaseType(DataType dtype) {
+  return IsRefType(dtype) ? RemoveRefType(dtype) : dtype;
+}
+
+// Returns true if the actual type is the same as or ref of the expected type.
+inline bool TypesCompatible(DataType expected, DataType actual) {
+  return expected == actual || expected == BaseType(actual);
+}
+
+// Does not include _ref types.
+DataTypeVector AllTypes();
+
+// Return the list of all numeric types.
+// NOTE: On Android, we only include the float and int32 types for now.
+DataTypeVector RealNumberTypes();  // Types that support '<' and '>'.
+DataTypeVector NumberTypes();      // Includes complex and quantized types.
+
+DataTypeVector QuantizedTypes();
+DataTypeVector RealAndQuantizedTypes();  // Types that support '<' and
+                                         // '>', including quantized
+                                         // types.
+
+// Validates type T for whether it is a supported DataType.
+template <class T>
+struct IsValidDataType;
+
+// DataTypeToEnum<T>::v() and DataTypeToEnum<T>::value are the DataType
+// constants for T, e.g. DataTypeToEnum<float>::v() is DT_FLOAT.
+template <class T>
+struct DataTypeToEnum {
+  static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
+};  // Specializations below
+
+// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
+// EnumToDataType<DT_FLOAT>::Type is float.
+template <DataType VALUE>
+struct EnumToDataType {};  // Specializations below
+
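The ref-offset helpers above are plain arithmetic on the enum, so a ref type and its base type stay a fixed distance apart. A hedged illustration of how they compose; RefRoundTrip is a hypothetical function, using CHECK from the platform logging header:

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"

namespace example {

void RefRoundTrip() {
  using namespace tensorflow;
  DataType ref = MakeRefType(DT_FLOAT);  // DT_FLOAT_REF (1 + 100 = 101)
  CHECK(IsRefType(ref));
  CHECK_EQ(DT_FLOAT, RemoveRefType(ref));
  // A ref tensor is accepted where its base type is expected.
  CHECK(TypesCompatible(DT_FLOAT, ref));
}

}  // namespace example

+// Template specialization for both DataTypeToEnum and EnumToDataType.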
+#define MATCH_TYPE_AND_ENUM(TYPE, ENUM)                 \
+  template <>                                           \
+  struct DataTypeToEnum<TYPE> {                         \
+    static DataType v() { return ENUM; }                \
+    static DataType ref() { return MakeRefType(ENUM); } \
+    static constexpr DataType value = ENUM;             \
+  };                                                    \
+  template <>                                           \
+  struct IsValidDataType<TYPE> {                        \
+    static constexpr bool value = true;                 \
+  };                                                    \
+  template <>                                           \
+  struct EnumToDataType<ENUM> {                         \
+    typedef TYPE Type;                                  \
+  }
+
+// We use Eigen's QInt implementations for our quantized int types.
+typedef Eigen::QInt8 qint8;
+typedef Eigen::QUInt8 quint8;
+typedef Eigen::QInt32 qint32;
+
+MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
+MATCH_TYPE_AND_ENUM(double, DT_DOUBLE);
+MATCH_TYPE_AND_ENUM(int32, DT_INT32);
+MATCH_TYPE_AND_ENUM(uint8, DT_UINT8);
+MATCH_TYPE_AND_ENUM(int16, DT_INT16);
+MATCH_TYPE_AND_ENUM(int8, DT_INT8);
+MATCH_TYPE_AND_ENUM(string, DT_STRING);
+MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64);
+MATCH_TYPE_AND_ENUM(int64, DT_INT64);
+MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
+MATCH_TYPE_AND_ENUM(qint8, DT_QINT8);
+MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8);
+MATCH_TYPE_AND_ENUM(qint32, DT_QINT32);
+MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16);
+
+#undef MATCH_TYPE_AND_ENUM
+
+bool DataTypeCanUseMemcpy(DataType dt);
+
+bool DataTypeIsQuantized(DataType dt);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_FRAMEWORK_TYPES_H_
diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto
new file mode 100644
index 0000000000..e5dc9c45a0
--- /dev/null
+++ b/tensorflow/core/framework/types.proto
@@ -0,0 +1,48 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+enum DataType {
+  // Not a legal value for DataType.  Used to indicate a DataType field
+  // has not been set.
+  DT_INVALID = 0;
+
+  // Data types that all computation devices are expected to be
+  // capable of supporting.
+  DT_FLOAT = 1;
+  DT_DOUBLE = 2;
+  DT_INT32 = 3;
+  DT_UINT8 = 4;
+  DT_INT16 = 5;
+  DT_INT8 = 6;
+  DT_STRING = 7;
+  DT_COMPLEX64 = 8;  // Single-precision complex
+  DT_INT64 = 9;
+  DT_BOOL = 10;
+  DT_QINT8 = 11;     // Quantized int8
+  DT_QUINT8 = 12;    // Quantized uint8
+  DT_QINT32 = 13;    // Quantized int32
+  DT_BFLOAT16 = 14;  // Float32 truncated to 16 bits.  Only for cast ops.
+
+  // TODO(josh11b): DT_GENERIC_PROTO = ??;
+  // TODO(jeff,josh11b): DT_UINT64?  DT_UINT32?  DT_UINT16?
+  // TODO(zhifengc): DT_COMPLEX128 (double-precision complex)?
+
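Looking back at the MATCH_TYPE_AND_ENUM table in types.h above: it stamps out a two-way, compile-time mapping between C++ types and DataType constants. A small sketch of what that buys; the static_asserts only restate facts from the specializations, and TagOf is a hypothetical helper:

#include <type_traits>

#include "tensorflow/core/framework/types.h"

namespace example {

static_assert(tensorflow::DataTypeToEnum<float>::value == tensorflow::DT_FLOAT,
              "float maps to DT_FLOAT");
static_assert(
    std::is_same<tensorflow::EnumToDataType<tensorflow::DT_INT32>::Type,
                 tensorflow::int32>::value,
    "DT_INT32 maps back to int32");

// Typical use in templated kernel code: recover the DataType tag for T.
template <typename T>
tensorflow::DataType TagOf() {
  return tensorflow::DataTypeToEnum<T>::v();
}

}  // namespace example

+  // Do not use!  These are only for parameters.  Every enum above
+  // should have a corresponding value below (verified by types_test).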
+  DT_FLOAT_REF = 101;
+  DT_DOUBLE_REF = 102;
+  DT_INT32_REF = 103;
+  DT_UINT8_REF = 104;
+  DT_INT16_REF = 105;
+  DT_INT8_REF = 106;
+  DT_STRING_REF = 107;
+  DT_COMPLEX64_REF = 108;
+  DT_INT64_REF = 109;
+  DT_BOOL_REF = 110;
+  DT_QINT8_REF = 111;
+  DT_QUINT8_REF = 112;
+  DT_QINT32_REF = 113;
+  DT_BFLOAT16_REF = 114;
+}
diff --git a/tensorflow/core/framework/types_test.cc b/tensorflow/core/framework/types_test.cc
new file mode 100644
index 0000000000..eb92600397
--- /dev/null
+++ b/tensorflow/core/framework/types_test.cc
@@ -0,0 +1,117 @@
+#include "tensorflow/core/framework/types.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(TypesTest, DeviceTypeName) {
+  EXPECT_EQ("CPU", DeviceTypeString(DeviceType(DEVICE_CPU)));
+  EXPECT_EQ("GPU", DeviceTypeString(DeviceType(DEVICE_GPU)));
+}
+
+TEST(TypesTest, kDataTypeRefOffset) {
+  // Basic sanity check
+  EXPECT_EQ(DT_FLOAT + kDataTypeRefOffset, DT_FLOAT_REF);
+
+  // Use the meta-data provided by proto2 to iterate through the basic
+  // types and validate that adding kDataTypeRefOffset gives the
+  // corresponding reference type.
+  const auto* enum_descriptor = DataType_descriptor();
+  int e = DataType_MIN;
+  if (e == DT_INVALID) ++e;
+  int e_ref = e + kDataTypeRefOffset;
+  EXPECT_FALSE(DataType_IsValid(e_ref - 1))
+      << "Reference enum "
+      << enum_descriptor->FindValueByNumber(e_ref - 1)->name()
+      << " without corresponding base enum with value " << e - 1;
+  for (;
+       DataType_IsValid(e) && DataType_IsValid(e_ref) && e_ref <= DataType_MAX;
+       ++e, ++e_ref) {
+    string enum_name = enum_descriptor->FindValueByNumber(e)->name();
+    string enum_ref_name = enum_descriptor->FindValueByNumber(e_ref)->name();
+    EXPECT_EQ(enum_name + "_REF", enum_ref_name)
+        << enum_name << "_REF should have value " << e_ref << " not "
+        << enum_ref_name;
+    // Validate DataTypeString() as well.
+    DataType dt_e = static_cast<DataType>(e);
+    DataType dt_e_ref = static_cast<DataType>(e_ref);
+    EXPECT_EQ(DataTypeString(dt_e) + "_ref", DataTypeString(dt_e_ref));
+
+    // Test DataTypeFromString reverse conversion
+    DataType dt_e2, dt_e2_ref;
+    EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e), &dt_e2));
+    EXPECT_EQ(dt_e, dt_e2);
+    EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e_ref), &dt_e2_ref));
+    EXPECT_EQ(dt_e_ref, dt_e2_ref);
+  }
+  ASSERT_FALSE(DataType_IsValid(e))
+      << "Should define " << enum_descriptor->FindValueByNumber(e)->name()
+      << "_REF to be " << e_ref;
+  ASSERT_FALSE(DataType_IsValid(e_ref))
+      << "Extra reference enum "
+      << enum_descriptor->FindValueByNumber(e_ref)->name()
+      << " without corresponding base enum with value " << e;
+  ASSERT_LT(DataType_MAX, e_ref) << "Gap in reference types, missing value for "
+                                 << e_ref;
+
+  // Make sure there are no enums defined after the last regular type before
+  // the first reference type.
+  for (; e < DataType_MIN + kDataTypeRefOffset; ++e) {
+    EXPECT_FALSE(DataType_IsValid(e))
+        << "Discontinuous enum value "
+        << enum_descriptor->FindValueByNumber(e)->name() << " = " << e;
+  }
+}
+
+TEST(TypesTest, DataTypeFromString) {
+  DataType dt;
+  ASSERT_TRUE(DataTypeFromString("int32", &dt));
+  EXPECT_EQ(DT_INT32, dt);
+  ASSERT_TRUE(DataTypeFromString("int32_ref", &dt));
+  EXPECT_EQ(DT_INT32_REF, dt);
+  EXPECT_FALSE(DataTypeFromString("int32_ref_ref", &dt));
+  EXPECT_FALSE(DataTypeFromString("foo", &dt));
+  EXPECT_FALSE(DataTypeFromString("foo_ref", &dt));
+  ASSERT_TRUE(DataTypeFromString("int64", &dt));
+  EXPECT_EQ(DT_INT64, dt);
+  ASSERT_TRUE(DataTypeFromString("int64_ref", &dt));
+  EXPECT_EQ(DT_INT64_REF, dt);
+  ASSERT_TRUE(DataTypeFromString("quint8_ref", &dt));
+  EXPECT_EQ(DT_QUINT8_REF, dt);
+  ASSERT_TRUE(DataTypeFromString("bfloat16", &dt));
+  EXPECT_EQ(DT_BFLOAT16, dt);
+}
+
+template <typename T>
+static bool GetQuantized() {
+  return is_quantized<T>::value;
+}
+
+TEST(TypesTest, QuantizedTypes) {
+  // NOTE: GUnit cannot parse is_quantized<TYPE>::value within the
+  // EXPECT_TRUE() clause, so we delegate through a template function.
+  EXPECT_TRUE(GetQuantized<qint8>());
+  EXPECT_TRUE(GetQuantized<quint8>());
+  EXPECT_TRUE(GetQuantized<qint32>());
+
+  EXPECT_FALSE(GetQuantized<int8>());
+  EXPECT_FALSE(GetQuantized<uint8>());
+  EXPECT_FALSE(GetQuantized<int16>());
+  EXPECT_FALSE(GetQuantized<int32>());
+
+  EXPECT_TRUE(DataTypeIsQuantized(DT_QINT8));
+  EXPECT_TRUE(DataTypeIsQuantized(DT_QUINT8));
+  EXPECT_TRUE(DataTypeIsQuantized(DT_QINT32));
+
+  EXPECT_FALSE(DataTypeIsQuantized(DT_INT8));
+  EXPECT_FALSE(DataTypeIsQuantized(DT_UINT8));
+  EXPECT_FALSE(DataTypeIsQuantized(DT_INT16));
+  EXPECT_FALSE(DataTypeIsQuantized(DT_INT32));
+  EXPECT_FALSE(DataTypeIsQuantized(DT_BFLOAT16));
+}
+
+}  // namespace
+}  // namespace tensorflow
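To close, a sketch tying type_traits.h and types.h together: the is_quantized trait lets generic code branch on quantization, and DataTypeToEnum recovers the matching enum tag. Not part of the diff; Describe is a hypothetical helper:

#include <iostream>

#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"

// Hypothetical helper: report whether T is a quantized type and print the
// string form of its DataType enum.
template <typename T>
void Describe(const char* name) {
  std::cout << name << (tensorflow::is_quantized<T>::value ? " is" : " is not")
            << " quantized; enum = "
            << tensorflow::DataTypeString(tensorflow::DataTypeToEnum<T>::v())
            << "\n";
}

int main() {
  Describe<tensorflow::quint8>("quint8");  // quantized; enum = quint8
  Describe<float>("float");                // not quantized; enum = float
}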