1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
|
#include "tensorflow/core/framework/kernel_def_builder.h"

#include <memory>

#include <gtest/gtest.h>
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
namespace {
// Checks that a builder given only an op name and a device type produces a
// KernelDef whose DebugString matches the equivalent text-format proto.
TEST(KernelDefBuilderTest, Basic) {
  // The test owns the object returned by Build() (it was previously released
  // with a manual `delete`). Holding it in a unique_ptr releases it on every
  // exit path, including the early return taken by a failing ASSERT_*.
  const std::unique_ptr<const KernelDef> def(
      KernelDefBuilder("A").Device(DEVICE_CPU).Build());

  KernelDef expected;
  // Fail right here if the expectation text itself does not parse, instead of
  // letting a half-filled `expected` surface as a confusing mismatch below.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "op: 'A' device_type: 'CPU'", &expected));

  EXPECT_EQ(def->DebugString(), expected.DebugString());
}
// Checks that type constraints added through the builder show up in the built
// KernelDef. Three independent cases are covered:
//   1. one templated constraint,
//   2. two templated constraints added back to back,
//   3. one constraint given an explicit list of dtypes.
//
// Every case lives in its own scope with a fresh `expected` message and a
// unique_ptr-owned KernelDef, so no state is carried from one case to the next
// and nothing leaks if an ASSERT_* returns early. The parse result is stored
// in a local before being asserted so that the multi-line raw string literal
// is never placed inside a macro argument list.
TEST(KernelDefBuilderTest, TypeConstraint) {
  {
    // Case 1: attr "T" restricted to float via TypeConstraint<float>.
    const std::unique_ptr<const KernelDef> def(KernelDefBuilder("B")
                                                   .Device(DEVICE_GPU)
                                                   .TypeConstraint<float>("T")
                                                   .Build());
    KernelDef expected;
    const bool parsed = protobuf::TextFormat::ParseFromString(R"proto(
op: 'B' device_type: 'GPU'
constraint { name: 'T' allowed_values { list { type: DT_FLOAT } } } )proto",
                                                              &expected);
    ASSERT_TRUE(parsed);
    EXPECT_EQ(def->DebugString(), expected.DebugString());
  }

  {
    // Case 2: two separate attrs ("U" and "V"), each with its own constraint.
    const std::unique_ptr<const KernelDef> def(KernelDefBuilder("C")
                                                   .Device(DEVICE_GPU)
                                                   .TypeConstraint<int32>("U")
                                                   .TypeConstraint<bool>("V")
                                                   .Build());
    KernelDef expected;
    const bool parsed = protobuf::TextFormat::ParseFromString(R"proto(
op: 'C' device_type: 'GPU'
constraint { name: 'U' allowed_values { list { type: DT_INT32 } } }
constraint { name: 'V' allowed_values { list { type: DT_BOOL } } } )proto",
                                                              &expected);
    ASSERT_TRUE(parsed);
    EXPECT_EQ(def->DebugString(), expected.DebugString());
  }

  {
    // Case 3: a single attr "W" allowed to take either of two listed dtypes.
    const std::unique_ptr<const KernelDef> def(
        KernelDefBuilder("D")
            .Device(DEVICE_CPU)
            .TypeConstraint("W", {DT_DOUBLE, DT_STRING})
            .Build());
    KernelDef expected;
    const bool parsed = protobuf::TextFormat::ParseFromString(R"proto(
op: 'D' device_type: 'CPU'
constraint { name: 'W'
allowed_values { list { type: [DT_DOUBLE, DT_STRING] } } } )proto",
                                                              &expected);
    ASSERT_TRUE(parsed);
    EXPECT_EQ(def->DebugString(), expected.DebugString());
  }
}
// Checks that each HostMemory() call on the builder is recorded as a
// host_memory_arg entry in the built KernelDef, in the order the calls were
// made ("in" first, then "out").
TEST(KernelDefBuilderTest, HostMemory) {
  // The test owns the object returned by Build() (it was previously released
  // with a manual `delete`). Holding it in a unique_ptr releases it on every
  // exit path, including the early return taken by a failing ASSERT_*.
  const std::unique_ptr<const KernelDef> def(KernelDefBuilder("E")
                                                 .Device(DEVICE_GPU)
                                                 .HostMemory("in")
                                                 .HostMemory("out")
                                                 .Build());

  KernelDef expected;
  // Fail right here if the expectation text itself does not parse, instead of
  // letting a half-filled `expected` surface as a confusing mismatch below.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "op: 'E' device_type: 'GPU' "
      "host_memory_arg: ['in', 'out']",
      &expected));

  EXPECT_EQ(def->DebugString(), expected.DebugString());
}
} // namespace
} // namespace tensorflow
|