aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function_test.cc
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2017-05-17 09:23:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-17 09:27:36 -0700
commit73882f257ffb1bc9e1a828571c085d080b1d9266 (patch)
tree8adcefa226f95d6c6ce067ee45528d76794e55fb /tensorflow/core/common_runtime/function_test.cc
parent9a47c258c9c2286ae2c14a0da6458055f3b691d3 (diff)
Automated g4 rollback of changelist 156251356
PiperOrigin-RevId: 156315860
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r--tensorflow/core/common_runtime/function_test.cc40
1 files changed, 28 insertions, 12 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index dfa1ed8a7e..e27fc3898d 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
+namespace {
typedef FunctionDefHelper FDH;
@@ -58,13 +59,29 @@ void HasError(const Status& s, const string& substr) {
<< s << ", expected substring " << substr;
}
+// A helper class to make AttrSlice from initializer lists
+class Attrs {
+ public:
+ Attrs(const std::initializer_list< // NOLINT(runtime/explicit)
+ std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) {
+ for (const auto& aval : attrs) {
+ map_.insert({aval.first, aval.second.proto});
+ }
+ }
+
+ operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit)
+
+ private:
+ AttrValueMap map_;
+};
+
class FunctionTest : public ::testing::Test {
protected:
FunctionTest()
: device_(DeviceFactory::NewDevice("CPU", {},
"/job:localhost/replica:0/task:0")) {}
- void Create(const FunctionDef& fdef, InstantiateAttrValueSlice attrs) {
+ void Create(const FunctionDef& fdef, Attrs attrs) {
exec_ = nullptr;
InstantiationResult result;
TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result));
@@ -151,8 +168,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
fdef_lib_ = lib_def_->ToProto();
}
- Status Run(const string& name, InstantiateAttrValueSlice attrs,
- const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+ Status Run(const string& name, Attrs attrs, const std::vector<Tensor>& args,
+ std::vector<Tensor*> rets) {
FunctionLibraryRuntime::Handle handle;
Status status = lib_->Instantiate(name, attrs, &handle);
if (!status.ok()) {
@@ -188,8 +205,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return Status::OK();
}
- std::unique_ptr<Graph> GetFuncBody(const string& name,
- InstantiateAttrValueSlice attrs) {
+ std::unique_ptr<Graph> GetFuncBody(const string& name, Attrs attrs) {
FunctionLibraryRuntime::Handle handle;
Status status = lib_->Instantiate(name, attrs, &handle);
if (!status.ok()) {
@@ -203,8 +219,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return ret;
}
- std::unique_ptr<Graph> GetGradBody(const string& func,
- InstantiateAttrValueSlice attrs) {
+ std::unique_ptr<Graph> GetGradBody(const string& func, Attrs attrs) {
FunctionLibraryRuntime::Handle handle;
Status status = lib_->Instantiate(func, attrs, &handle);
if (!status.ok()) {
@@ -615,13 +630,14 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) {
// Instantiating "XTimesTwo" should fail.
FunctionLibraryRuntime::Handle handle;
- HasError(lib_->Instantiate("XTimesTwo", {{"T", DT_FLOAT}}, &handle),
+ HasError(lib_->Instantiate("XTimesTwo", Attrs({{"T", DT_FLOAT}}), &handle),
"Not found: type attr not found");
// But XTimesFour and XTimes16 instantiation should succeed. Only
// when they run, they fail because XTimesTwo is bad.
- TF_CHECK_OK(lib_->Instantiate("XTimesFour", {{"T", DT_FLOAT}}, &handle));
- TF_CHECK_OK(lib_->Instantiate("XTimes16", {{"T", DT_FLOAT}}, &handle));
+ TF_CHECK_OK(
+ lib_->Instantiate("XTimesFour", Attrs({{"T", DT_FLOAT}}), &handle));
+ TF_CHECK_OK(lib_->Instantiate("XTimes16", Attrs({{"T", DT_FLOAT}}), &handle));
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
@@ -928,8 +944,7 @@ bool DoNothing(Graph* g) { return false; }
GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
const FunctionDef& fdef) {
InstantiationResult result;
- InstantiateAttrValueMap empty;
- TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result));
+ TF_CHECK_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
@@ -1248,4 +1263,5 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
}
+} // end namespace
} // end namespace tensorflow