aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/js/ops/ts_op_gen_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/js/ops/ts_op_gen_test.cc')
-rw-r--r--tensorflow/js/ops/ts_op_gen_test.cc138
1 files changed, 86 insertions, 52 deletions
diff --git a/tensorflow/js/ops/ts_op_gen_test.cc b/tensorflow/js/ops/ts_op_gen_test.cc
index 9a85c021b0..03241689b5 100644
--- a/tensorflow/js/ops/ts_op_gen_test.cc
+++ b/tensorflow/js/ops/ts_op_gen_test.cc
@@ -36,7 +36,6 @@ void ExpectDoesNotContainStr(StringPiece s, StringPiece expected) {
<< "'" << s << "' does not contain '" << expected << "'";
}
-// TODO(kreeger): Add multiple outputs here?
constexpr char kBaseOpDef[] = R"(
op {
name: "Foo"
@@ -79,50 +78,15 @@ op {
summary: "Summary for op Foo."
description: "Description for op Foo."
}
-op {
- name: "DeprecatedFoo"
- input_arg {
- name: "input"
- description: "Description for input."
- type: DT_FLOAT
- }
- output_arg {
- name: "output"
- description: "Description for output."
- type: DT_FLOAT
- }
- deprecation {
- explanation: "Deprecated."
- }
-}
-op {
- name: "MultiOutputFoo"
- input_arg {
- name: "input"
- description: "Description for input."
- type: DT_FLOAT
- }
- output_arg {
- name: "output1"
- description: "Description for output 1."
- type: DT_FLOAT
- }
- output_arg {
- name: "output2"
- description: "Description for output 2."
- type: DT_FLOAT
- }
- summary: "Summary for op MultiOutputFoo."
- description: "Description for op MultiOutputFoo."
-}
)";
// Generate TypeScript code
-// @param api_def_str TODO doc me.
-void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
+void GenerateTsOpFileText(const string& op_def_str, const string& api_def_str,
+ string* ts_file_text) {
Env* env = Env::Default();
OpList op_defs;
- protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs);
+ protobuf::TextFormat::ParseFromString(
+ op_def_str.empty() ? kBaseOpDef : op_def_str, &op_defs);
ApiDefMap api_def_map(op_defs);
if (!api_def_str.empty()) {
@@ -138,11 +102,11 @@ void GenerateTsOpFileText(const string& api_def_str, string* ts_file_text) {
TEST(TsOpGenTest, TestImports) {
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText("", "", &ts_file_text);
const string expected = R"(
import * as tfc from '@tensorflow/tfjs-core';
-import {createTypeOpAttr, getTFDTypeForInputs, nodeBackend} from './op_utils';
+import {createTensorsTypeOpAttr, nodeBackend} from './op_utils';
)";
ExpectContainsStr(ts_file_text, expected);
}
@@ -160,12 +124,10 @@ op {
)";
string ts_file_text;
- GenerateTsOpFileText(api_def, &ts_file_text);
+ GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
- return null;
-}
)";
ExpectContainsStr(ts_file_text, expected);
}
@@ -179,34 +141,106 @@ op {
)";
string ts_file_text;
- GenerateTsOpFileText(api_def, &ts_file_text);
+ GenerateTsOpFileText("", api_def, &ts_file_text);
const string expected = R"(
export function Foo(images: tfc.Tensor[], dim: tfc.Tensor): tfc.Tensor {
- return null;
-}
)";
ExpectDoesNotContainStr(ts_file_text, expected);
}
TEST(TsOpGenTest, SkipDeprecated) {
+ const string op_def = R"(
+op {
+ name: "DeprecatedFoo"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ description: "Description for input."
+ }
+ output_arg {
+ name: "output"
+ description: "Description for output."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for input"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ deprecation {
+ explanation: "Deprecated."
+ }
+}
+)";
+
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText(op_def, "", &ts_file_text);
ExpectDoesNotContainStr(ts_file_text, "DeprecatedFoo");
}
TEST(TsOpGenTest, MultiOutput) {
+ const string op_def = R"(
+op {
+ name: "MultiOutputFoo"
+ input_arg {
+ name: "input"
+ description: "Description for input."
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output1"
+ description: "Description for output 1."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output2"
+ description: "Description for output 2."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for input"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ summary: "Summary for op MultiOutputFoo."
+ description: "Description for op MultiOutputFoo."
+}
+)";
+
string ts_file_text;
- GenerateTsOpFileText("", &ts_file_text);
+ GenerateTsOpFileText(op_def, "", &ts_file_text);
const string expected = R"(
export function MultiOutputFoo(input: tfc.Tensor): tfc.Tensor[] {
- return null;
-}
)";
ExpectContainsStr(ts_file_text, expected);
}
+TEST(TsOpGenTest, OpAttrs) {
+ string ts_file_text;
+ GenerateTsOpFileText("", "", &ts_file_text);
+
+ const string expectedFooAttrs = R"(
+ const opAttrs = [
+ createTensorsTypeOpAttr('T', images),
+ {name: 'N', type: nodeBackend().binding.TF_ATTR_INT, value: images.length}
+ ];
+)";
+
+ ExpectContainsStr(ts_file_text, expectedFooAttrs);
+}
+
} // namespace
} // namespace tensorflow