aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/function_test.cc
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2017-05-16 17:01:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-16 17:05:15 -0700
commit43db5c623f748b6f9704e9e9be5a5a11fa2a4c1a (patch)
tree985844ec8f6653f36e38592f9700dcaba66d94f2 /tensorflow/core/common_runtime/function_test.cc
parent7ab0c2eff12ea79648f6717dae8558d6669e5c27 (diff)
Automated g4 rollback of changelist 156244933
PiperOrigin-RevId: 156251356
Diffstat (limited to 'tensorflow/core/common_runtime/function_test.cc')
-rw-r--r--tensorflow/core/common_runtime/function_test.cc40
1 files changed, 12 insertions, 28 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index e27fc3898d..dfa1ed8a7e 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -40,7 +40,6 @@ limitations under the License.
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
-namespace {
typedef FunctionDefHelper FDH;
@@ -59,29 +58,13 @@ void HasError(const Status& s, const string& substr) {
<< s << ", expected substring " << substr;
}
-// A helper class to make AttrSlice from initializer lists
-class Attrs {
- public:
- Attrs(const std::initializer_list< // NOLINT(runtime/explicit)
- std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) {
- for (const auto& aval : attrs) {
- map_.insert({aval.first, aval.second.proto});
- }
- }
-
- operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit)
-
- private:
- AttrValueMap map_;
-};
-
class FunctionTest : public ::testing::Test {
protected:
FunctionTest()
: device_(DeviceFactory::NewDevice("CPU", {},
"/job:localhost/replica:0/task:0")) {}
- void Create(const FunctionDef& fdef, Attrs attrs) {
+ void Create(const FunctionDef& fdef, InstantiateAttrValueSlice attrs) {
exec_ = nullptr;
InstantiationResult result;
TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result));
@@ -168,8 +151,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
fdef_lib_ = lib_def_->ToProto();
}
- Status Run(const string& name, Attrs attrs, const std::vector<Tensor>& args,
- std::vector<Tensor*> rets) {
+ Status Run(const string& name, InstantiateAttrValueSlice attrs,
+ const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
FunctionLibraryRuntime::Handle handle;
Status status = lib_->Instantiate(name, attrs, &handle);
if (!status.ok()) {
@@ -205,7 +188,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return Status::OK();
}
- std::unique_ptr<Graph> GetFuncBody(const string& name, Attrs attrs) {
+ std::unique_ptr<Graph> GetFuncBody(const string& name,
+ InstantiateAttrValueSlice attrs) {
FunctionLibraryRuntime::Handle handle;
Status status = lib_->Instantiate(name, attrs, &handle);
if (!status.ok()) {
@@ -219,7 +203,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return ret;
}
- std::unique_ptr<Graph> GetGradBody(const string& func, Attrs attrs) {
+ std::unique_ptr<Graph> GetGradBody(const string& func,
+ InstantiateAttrValueSlice attrs) {
FunctionLibraryRuntime::Handle handle;
Status status = lib_->Instantiate(func, attrs, &handle);
if (!status.ok()) {
@@ -630,14 +615,13 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) {
// Instantiating "XTimesTwo" should fail.
FunctionLibraryRuntime::Handle handle;
- HasError(lib_->Instantiate("XTimesTwo", Attrs({{"T", DT_FLOAT}}), &handle),
+ HasError(lib_->Instantiate("XTimesTwo", {{"T", DT_FLOAT}}, &handle),
"Not found: type attr not found");
// But XTimesFour and XTimes16 instantiation should succeed. Only
// when they run, they fail because XTimesTwo is bad.
- TF_CHECK_OK(
- lib_->Instantiate("XTimesFour", Attrs({{"T", DT_FLOAT}}), &handle));
- TF_CHECK_OK(lib_->Instantiate("XTimes16", Attrs({{"T", DT_FLOAT}}), &handle));
+ TF_CHECK_OK(lib_->Instantiate("XTimesFour", {{"T", DT_FLOAT}}, &handle));
+ TF_CHECK_OK(lib_->Instantiate("XTimes16", {{"T", DT_FLOAT}}, &handle));
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
@@ -944,7 +928,8 @@ bool DoNothing(Graph* g) { return false; }
GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
const FunctionDef& fdef) {
InstantiationResult result;
- TF_CHECK_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
+ InstantiateAttrValueMap empty;
+ TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result));
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
@@ -1263,5 +1248,4 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
}
-} // end namespace
} // end namespace tensorflow