aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/js/ops
diff options
context:
space:
mode:
authorGravatar Nick Kreeger <kreeger@google.com>2018-08-20 19:39:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 19:46:03 -0700
commit49115abfd39d30506679d9fdc572ccd2f7c22dbe (patch)
tree62c2a07584a7f861db39b875718c910eb79bff15 /tensorflow/js/ops
parentdebd8b6b4e6b7bc16eb58b07135861702d75fb15 (diff)
Introduce basic CC library for generating TypeScript files for TensorFlow.js from registered Ops.
This initial change provides the very basics to start generating TypeScript. Non-deprecated and visible Ops are exported as a typescript function using internal functionality that is used the @tensorflow/tfjs-node repo (https://github.com/tensorflow/tfjs-node). Future changes will introduce more code generation + tests. This initial change will help set the foundation for those upcoming changes. PiperOrigin-RevId: 209528126
Diffstat (limited to 'tensorflow/js/ops')
-rw-r--r--tensorflow/js/ops/ts_op_gen.cc199
-rw-r--r--tensorflow/js/ops/ts_op_gen.h31
-rw-r--r--tensorflow/js/ops/ts_op_gen_test.cc212
3 files changed, 442 insertions, 0 deletions
diff --git a/tensorflow/js/ops/ts_op_gen.cc b/tensorflow/js/ops/ts_op_gen.cc
new file mode 100644
index 0000000000..babf55cd5f
--- /dev/null
+++ b/tensorflow/js/ops/ts_op_gen.cc
@@ -0,0 +1,199 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/js/ops/ts_op_gen.h"
+#include <unordered_map>
+
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+namespace {
+
+static bool IsListAttr(const OpDef_ArgDef& arg) {
+ return !arg.type_list_attr().empty() || !arg.number_attr().empty();
+}
+
+// Struct to hold a combo OpDef and ArgDef for a given Op argument:
+struct ArgDefs {
+ ArgDefs(const OpDef::ArgDef& op_def_arg, const ApiDef::Arg& api_def_arg)
+ : op_def_arg(op_def_arg), api_def_arg(api_def_arg) {}
+
+ const OpDef::ArgDef& op_def_arg;
+ const ApiDef::Arg& api_def_arg;
+};
+
+// Helper class to generate TypeScript code for a given OpDef:
+class GenTypeScriptOp {
+ public:
+ GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def);
+ ~GenTypeScriptOp();
+
+ // Returns the generated code as a string:
+ string Code();
+
+ private:
+ void ProcessArgs();
+
+ void AddMethodSignature();
+ void AddMethodReturnAndClose();
+
+ const OpDef& op_def_;
+ const ApiDef& api_def_;
+
+ // Placeholder string for all generated code:
+ string result_;
+
+ // Holds in-order vector of Op inputs:
+ std::vector<ArgDefs> input_op_args_;
+
+ // Holds number of outputs:
+ int num_outputs_;
+};
+
+GenTypeScriptOp::GenTypeScriptOp(const OpDef& op_def, const ApiDef& api_def)
+ : op_def_(op_def), api_def_(api_def), num_outputs_(0) {}
+
+GenTypeScriptOp::~GenTypeScriptOp() {}
+
+string GenTypeScriptOp::Code() {
+ ProcessArgs();
+
+ // Generate exported function for Op:
+ AddMethodSignature();
+ AddMethodReturnAndClose();
+
+ strings::StrAppend(&result_, "\n");
+ return result_;
+}
+
+void GenTypeScriptOp::ProcessArgs() {
+ for (int i = 0; i < api_def_.arg_order_size(); i++) {
+ auto op_def_arg = FindInputArg(api_def_.arg_order(i), op_def_);
+ if (op_def_arg == nullptr) {
+ LOG(WARNING) << "Could not find OpDef::ArgDef for "
+ << api_def_.arg_order(i);
+ continue;
+ }
+ auto api_def_arg = FindInputArg(api_def_.arg_order(i), api_def_);
+ if (api_def_arg == nullptr) {
+ LOG(WARNING) << "Could not find ApiDef::Arg for "
+ << api_def_.arg_order(i);
+ continue;
+ }
+ input_op_args_.push_back(ArgDefs(*op_def_arg, *api_def_arg));
+ }
+
+ num_outputs_ = api_def_.out_arg_size();
+}
+
+void GenTypeScriptOp::AddMethodSignature() {
+ strings::StrAppend(&result_, "export function ", api_def_.endpoint(0).name(),
+ "(");
+
+ bool is_first = true;
+ for (auto& in_arg : input_op_args_) {
+ if (is_first) {
+ is_first = false;
+ } else {
+ strings::StrAppend(&result_, ", ");
+ }
+
+ auto op_def_arg = in_arg.op_def_arg;
+
+ strings::StrAppend(&result_, op_def_arg.name(), ": ");
+ if (IsListAttr(op_def_arg)) {
+ strings::StrAppend(&result_, "tfc.Tensor[]");
+ } else {
+ strings::StrAppend(&result_, "tfc.Tensor");
+ }
+ }
+
+ if (num_outputs_ == 1) {
+ strings::StrAppend(&result_, "): tfc.Tensor {\n");
+ } else {
+ strings::StrAppend(&result_, "): tfc.Tensor[] {\n");
+ }
+}
+
+void GenTypeScriptOp::AddMethodReturnAndClose() {
+ strings::StrAppend(&result_, " return null;\n}\n");
+}
+
+void WriteTSOp(const OpDef& op_def, const ApiDef& api_def, WritableFile* ts) {
+ GenTypeScriptOp ts_op(op_def, api_def);
+ TF_CHECK_OK(ts->Append(GenTypeScriptOp(op_def, api_def).Code()));
+}
+
+void StartFile(WritableFile* ts_file) {
+ const string header =
+ R"header(/**
+ * @license
+ * Copyright 2018 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+// This file is MACHINE GENERATED! Do not edit
+
+import * as tfc from '@tensorflow/tfjs-core';
+import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
+
+)header";
+
+ TF_CHECK_OK(ts_file->Append(header));
+}
+
+} // namespace
+
+void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
+ const string& ts_filename) {
+ Env* env = Env::Default();
+
+ std::unique_ptr<WritableFile> ts_file = nullptr;
+ TF_CHECK_OK(env->NewWritableFile(ts_filename, &ts_file));
+
+ StartFile(ts_file.get());
+
+ for (const auto& op_def : ops.op()) {
+ // Skip deprecated ops
+ if (op_def.has_deprecation() &&
+ op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) {
+ continue;
+ }
+
+ const auto* api_def = api_def_map.GetApiDef(op_def.name());
+ if (api_def->visibility() == ApiDef::VISIBLE) {
+ WriteTSOp(op_def, *api_def, ts_file.get());
+ }
+ }
+
+ TF_CHECK_OK(ts_file->Close());
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/js/ops/ts_op_gen.h b/tensorflow/js/ops/ts_op_gen.h
new file mode 100644
index 0000000000..fcd46a17a7
--- /dev/null
+++ b/tensorflow/js/ops/ts_op_gen.h
@@ -0,0 +1,31 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JS_OPS_TS_OP_GEN_H_
+#define TENSORFLOW_JS_OPS_TS_OP_GEN_H_
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Generated code is written to the file ts_filename:
+void WriteTSOps(const OpList& ops, const ApiDefMap& api_def_map,
+ const string& ts_filename);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_JS_OPS_TS_OP_GEN_H_
diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc
new file mode 100644
index 0000000000..9a85c021b0
--- /dev/null
+++ b/tensorflow/js/ops/ts_op_gen_test.cc
@@ -0,0 +1,212 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/js/ops/ts_op_gen.h"
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+void ExpectContainsStr(StringPiece s, StringPiece expected) {
+ EXPECT_TRUE(str_util::StrContains(s, expected))
+ << "'" << s << "' does not contain '" << expected << "'";
+}
+
+void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
+ EXPECT_FALSE(str_util::StrContains(s, expected))
+ << "'" << s << "' does not contain '" << expected << "'";
+}
+
+// TODO(kreeger): Add multiple outputs here?
+constexpr char kBaseOpDef[] = R"(
+op {
+ name: "Foo"
+ input_arg {
+ name: "images"
+ type_attr: "T"
+ number_attr: "N"
+ description: "Images to process."
+ }
+ input_arg {
+ name: "dim"
+ description: "Description for dim."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output"
+ description: "Description for output."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for images"
+ allowed_values {
+ list {
+ type: DT_UINT8
+ type: DT_INT8
+ }
+ }
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Summary for op Foo."
+ description: "Description for op Foo."
+}
+op {
+ name: "DeprecatedFoo"
+ input_arg {
+ name: "input"
+ description: "Description for input."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output"
+ description: "Description for output."
+ type: DT_FLOAT
+ }
+ deprecation {
+ explanation: "Deprecated."
+ }
+}
+op {
+ name: "MultiOutputFoo"
+ input_arg {
+ name: "input"
+ description: "Description for input."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output1"
+ description: "Description for output 1."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output2"
+ description: "Description for output 2."
+ type: DT_FLOAT
+ }
+ summary: "Summary for op MultiOutputFoo."
+ description: "Description for op MultiOutputFoo."
+}
+)";
+
+// Generate TypeScript code
+// @param api_def_str TODO doc me.
+void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
+ Env* env = Env::Default();
+ OpList op_defs;
+ protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
+ ApiDefMap api_def_map(op_defs);
+
+ if (!api_def_str.empty()) {
+ TF_ASSERT_OK(api_def_map.LoadApiDef(api_def_str));
+ }
+
+ const string& tmpdir = testing::TmpDir();
+ const auto ts_file_path = io::JoinPath(tmpdir, "test.ts");
+
+ WriteTSOps(op_defs, api_def_map, ts_file_path);
+ TF_ASSERT_OK(ReadFileToString(env, ts_file_path, ts_file_text));
+}
+
+TEST(TsOpGenTest, TestImports) {
+ string ts_file_text;
+ GenerateTsOpFileText("", &ts_file_text);
+
+ const string expected = R"(
+import * as tfc from '@tensorflow/tfjs-core';
+import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
+)";
+ ExpectContainsStr(ts_file_text, expected);
+}
+
+TEST(TsOpGenTest, InputSingleAndList) {
+ const string api_def = R"(
+op {
+ name: "Foo"
+ input_arg {
+ name: "images"
+ type_attr: "T"
+ number_attr: "N"
+ }
+}
+)";
+
+ string ts_file_text;
+ GenerateTsOpFileText(api_def, &ts_file_text);
+
+ const string expected = R"(
+export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
+ return null;
+}
+)";
+ ExpectContainsStr(ts_file_text, expected);
+}
+
+TEST(TsOpGenTest, TestVisibility) {
+ const string api_def = R"(
+op {
+ graph_op_name: "Foo"
+ visibility: HIDDEN
+}
+)";
+
+ string ts_file_text;
+ GenerateTsOpFileText(api_def, &ts_file_text);
+
+ const string expected = R"(
+export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
+ return null;
+}
+)";
+ ExpectDoesNotContainStr(ts_file_text, expected);
+}
+
+TEST(TsOpGenTest, SkipDeprecated) {
+ string ts_file_text;
+ GenerateTsOpFileText("", &ts_file_text);
+
+ ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
+}
+
+TEST(TsOpGenTest, MultiOutput) {
+ string ts_file_text;
+ GenerateTsOpFileText("", &ts_file_text);
+
+ const string expected = R"(
+export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
+ return null;
+}
+)";
+ ExpectContainsStr(ts_file_text, expected);
+}
+
+} // namespace
+} // namespace tensorflow