This repository was archived by the owner on Sep 17, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 111
Expand file tree
/
Copy pathnodejs_kernel_backend_test.ts
More file actions
65 lines (59 loc) · 2.38 KB
/
nodejs_kernel_backend_test.ts
File metadata and controls
65 lines (59 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as tf from '@tensorflow/tfjs-core';
import {Tensor5D} from '@tensorflow/tfjs-core';
import {test_util} from '@tensorflow/tfjs-core';
import {NodeJSKernelBackend} from './nodejs_kernel_backend';
describe('delayed upload', () => {
  it('should handle data before op execution', async () => {
    // Read a tensor's values first, then feed that same tensor into an op and
    // check the op's result.
    const x = tf.tensor1d([1, 2, 3]);
    const xValues = await x.data();
    test_util.expectArraysClose(xValues, [1, 2, 3]);

    const y = tf.tensor1d([4, 5, 6]);
    const sum = x.add(y);
    test_util.expectArraysClose(await sum.data(), [5, 7, 9]);
  });

  it('Should not cache tensors in the tensor map for device support. ', () => {
    // Every later dataSync() read of the softmax output must agree,
    // element by element, with the first read.
    const probs = tf.softmax(tf.tensor1d([1, 2, 3]));
    const firstRead = probs.dataSync();
    for (let i = 0; i < firstRead.length; i++) {
      expect(probs.dataSync()[i]).toEqual(firstRead[i]);
    }
  });
});
describe('type casting', () => {
  it('exp support int32', async () => {
    // The original spec only checked that the call did not throw. Also verify
    // that the result for an int32 input actually equals e^2, so a truncated
    // or otherwise wrong value fails the spec.
    const result = tf.exp(tf.scalar(2, 'int32'));
    test_util.expectArraysClose(await result.data(), [Math.exp(2)]);
  });
});
describe('conv3d dilations', () => {
  /** Whether the active backend reports that it comes from the GPU package. */
  const isGPUPackage = () =>
      (tf.backend() as NodeJSKernelBackend).isGPUPackage;

  it('CPU should throw error on dilations >1', () => {
    // The GPU package handles dilations > 1 (see the next spec), so the
    // expected failure only applies when the GPU package is not in use.
    if (isGPUPackage()) {
      return;
    }
    const input = tf.ones([1, 2, 2, 2, 1]) as Tensor5D;
    const filter = tf.ones([1, 1, 1, 1, 1]) as Tensor5D;
    expect(() => {
      tf.conv3d(input, filter, 1, 'same', 'NDHWC', [2, 2, 2]);
    }).toThrowError();
  });

  it('GPU should handle dilations >1', async () => {
    // This test can only run locally with CUDA bindings and GPU package
    // installed.
    if (!isGPUPackage()) {
      return;
    }
    const input = tf.ones([1, 2, 2, 2, 1]) as Tensor5D;
    const filter = tf.ones([1, 1, 1, 1, 1]) as Tensor5D;
    const result = tf.conv3d(input, filter, 1, 'same', 'NDHWC', [2, 2, 2]);
    // A 1x1x1 all-ones filter with 'same' padding and stride 1 reproduces the
    // all-ones input. Dilation does not change the footprint of a size-1
    // filter, so the expected output is known exactly.
    expect(result.shape).toEqual([1, 2, 2, 2, 1]);
    test_util.expectArraysClose(
        await result.data(), [1, 1, 1, 1, 1, 1, 1, 1]);
  });
});