-
Notifications
You must be signed in to change notification settings - Fork 31
/
numpy_tests.ml
190 lines (181 loc) · 6.87 KB
/
numpy_tests.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
(* Round-trip test: a 1-D float64 bigarray exposed to Python through
   [Numpy.of_bigarray]. Python reads the values, then overwrites them; the
   writes must be observable from OCaml afterwards. *)
let () =
  Pyml_tests_common.add_test ~title:"of_bigarray"
    (fun () ->
      match Py.Import.try_import_module "numpy" with
      | None -> Pyml_tests_common.Disabled "numpy is not available"
      | Some _ ->
          let initial = [| 1.; 2. |] in
          let data =
            Bigarray.Array1.of_array Bigarray.float64 Bigarray.c_layout
              initial in
          let py_array =
            Numpy.of_bigarray (Bigarray.genarray_of_array1 data) in
          let test_module = Py.Import.add_module "test" in
          Py.Module.set test_module "array" py_array;
          assert (Py.Run.simple_string "
from test import array
assert len(array) == 2
assert array[0] == 1.
assert array[1] == 2.
array[0] = 42.
array[1] = 43.
");
          (* The Python-side assignments went to the very buffer [data]. *)
          assert (Bigarray.Array1.get data 0 = 42.);
          assert (Bigarray.Array1.get data 1 = 43.);
          Pyml_tests_common.Passed)
(* Round-trip test for a 2x3 float64 matrix (containing a NaN) exposed to
   Python: Python checks shape and contents, then writes three cells that
   OCaml must see afterwards. *)
let () =
  Pyml_tests_common.add_test ~title:"of_bigarray2"
    (fun () ->
      match Py.Import.try_import_module "numpy" with
      | None -> Pyml_tests_common.Disabled "numpy is not available"
      | Some _ ->
          let rows =
            [| [| 1.; 2.; 3. |];
               [| -1.23; Stdcompat.Float.nan; 2.72 |] |] in
          let matrix =
            Bigarray.Array2.of_array Bigarray.float64 Bigarray.c_layout rows in
          let py_array =
            Numpy.of_bigarray (Bigarray.genarray_of_array2 matrix) in
          let test_module = Py.Import.add_module "test" in
          Py.Module.set test_module "array" py_array;
          assert (Py.Run.simple_string "
from test import array
import numpy
assert list(array.shape) == [2, 3]
numpy.testing.assert_almost_equal(array[0], [1, 2, 3])
assert(numpy.isnan(array[1, 1]))
array[0, 0] = 42.
array[0, 1] = 43.
array[1, 1] = 1.
");
          (* Cell (i, j) of the shared buffer must now hold [v]. *)
          let expect i j v = assert (Bigarray.Array2.get matrix i j = v) in
          expect 0 0 42.;
          expect 0 1 43.;
          expect 1 1 1.;
          Pyml_tests_common.Passed)
(* Python calls back into OCaml with numpy.array([0,1,2,3]); the callback
   converts it with [Numpy.to_bigarray] as a nativeint, C-layout array and
   checks its dimensions and contents.
   NOTE(review): this presumably relies on numpy's default integer dtype
   matching OCaml's nativeint on the test platform — confirm. *)
let () =
  Pyml_tests_common.add_test ~title:"to_bigarray"
    (fun () ->
      match Py.Import.try_import_module "numpy" with
      | None -> Pyml_tests_common.Disabled "numpy is not available"
      | Some _ ->
          let test_module = Py.Import.add_module "test" in
          let callback args =
            let genarray =
              Numpy.to_bigarray Bigarray.nativeint Bigarray.c_layout
                args.(0) in
            assert (Bigarray.Genarray.dims genarray = [| 4 |]);
            let vector = Bigarray.array1_of_genarray genarray in
            for i = 0 to 3 do
              assert (Bigarray.Array1.get vector i = Nativeint.of_int i)
            done;
            Py.none in
          Py.Module.set test_module "callback"
            (Py.Callable.of_function callback);
          assert (Py.Run.simple_string "
from test import callback
import numpy
callback(numpy.array([0,1,2,3]))
");
          Pyml_tests_common.Passed)
(** [assert_almost_eq ?eps f1 f2] succeeds when [f1] and [f2] are equal or
    differ by at most [eps] (default [1e-7]).
    @raise Failure when the difference exceeds [eps], or when either value
    is NaN (NaN is never considered close to anything, itself included). *)
let assert_almost_eq ?(eps = 1e-7) f1 f2 =
  (* Exact equality is checked first so that equal infinities pass, since
     [infinity -. infinity] is NaN. The tolerance test is written as
     [not (d <= eps)] rather than [d > eps]: every comparison involving NaN
     is false, so the [>] form silently accepted NaN operands. *)
  if f1 <> f2 && not (abs_float (f1 -. f2) <= eps) then
    (* %.17g prints enough digits to show differences below [eps]; the old
       %f (6 decimals) could print two identical numbers. *)
    failwith (Printf.sprintf "%.17g <> %.17g (eps = %g)" f1 f2 eps)
(* Python calls back into OCaml with a 2x4 float32 numpy array (one entry
   is NaN); the callback converts it with [Numpy.to_bigarray] as a float32,
   C-layout array and checks every cell. Float32 rounding is absorbed by
   the tolerance of the file-level [assert_almost_eq]. *)
let () =
  Pyml_tests_common.add_test ~title:"to_bigarray2"
    (fun () ->
      match Py.Import.try_import_module "numpy" with
      | None -> Pyml_tests_common.Disabled "numpy is not available"
      | Some _ ->
          let test_module = Py.Import.add_module "test" in
          let callback args =
            let genarray =
              Numpy.to_bigarray Bigarray.float32 Bigarray.c_layout
                args.(0) in
            assert (Bigarray.Genarray.dims genarray = [| 2; 4 |]);
            let matrix = Bigarray.array2_of_genarray genarray in
            (* Named [close_to] so it does not shadow the top-level
               [assert_almost_eq] that it delegates to. *)
            let close_to i j expected =
              assert_almost_eq (Bigarray.Array2.get matrix i j) expected in
            close_to 0 0 0.12;
            close_to 0 1 1.23;
            close_to 0 2 2.34;
            close_to 0 3 3.45;
            close_to 1 0 (-1.);
            assert (Stdcompat.Float.is_nan (Bigarray.Array2.get matrix 1 1));
            close_to 1 2 1.;
            close_to 1 3 0.;
            Py.none in
          Py.Module.set test_module "callback"
            (Py.Callable.of_function callback);
          assert (Py.Run.simple_string "
from test import callback
import numpy
callback(numpy.array([[0.12,1.23,2.34,3.45],[-1.,numpy.nan,1.,0.]], dtype=numpy.float32))
");
          Pyml_tests_common.Passed)
(** [assert_invalid_argument f] runs [f ()] and succeeds only if that call
    raises [Invalid_argument].
    @raise Assert_failure if [f ()] returns normally; any other exception
    raised by [f] propagates unchanged. *)
let assert_invalid_argument f =
  let raised =
    try
      let () = f () in
      false
    with Invalid_argument _ -> true in
  assert raised
(* [Numpy.to_bigarray] must reject, with [Invalid_argument], both non-array
   Python objects and arrays whose kind or layout differ from the requested
   ones. An array with exactly the requested kind and layout must be
   accepted. *)
let () =
  Pyml_tests_common.add_test ~title:"to_bigarray invalid type"
    (fun () ->
      match Py.Import.try_import_module "numpy" with
      | None -> Pyml_tests_common.Disabled "numpy is not available"
      | Some _ ->
          let convert kind layout obj =
            ignore (Numpy.to_bigarray kind layout obj) in
          (* Python objects that are not numpy arrays at all. *)
          assert_invalid_argument (fun () ->
            convert Bigarray.float64 Bigarray.c_layout Py.none);
          assert_invalid_argument (fun () ->
            convert Bigarray.float64 Bigarray.c_layout (Py.Int.of_int 0));
          let vector =
            Bigarray.Array1.of_array Bigarray.float64 Bigarray.c_layout
              [| 1.; 2. |] in
          let py_array =
            Numpy.of_bigarray (Bigarray.genarray_of_array1 vector) in
          (* Matching kind and layout: accepted. *)
          convert Bigarray.float64 Bigarray.c_layout py_array;
          (* Wrong element kind, then wrong layout: both rejected. *)
          assert_invalid_argument (fun () ->
            convert Bigarray.float32 Bigarray.c_layout py_array);
          assert_invalid_argument (fun () ->
            convert Bigarray.float64 Bigarray.fortran_layout py_array);
          Pyml_tests_common.Passed)
(* Same scenario as "to_bigarray", but through the continuation-passing
   [Numpy.to_bigarray_k]. The continuation receives the array's kind and
   layout, checks that they are nativeint / C layout, and then recovers a
   typed bigarray with [Numpy.check_kind_and_layout]. *)
let () =
  Pyml_tests_common.add_test ~title:"to_bigarray_k"
    (fun () ->
      match Py.Import.try_import_module "numpy" with
      | None -> Pyml_tests_common.Disabled "numpy is not available"
      | Some _ ->
          let test_module = Py.Import.add_module "test" in
          let callback args =
            let k { Numpy.kind; layout; array } =
              assert (Numpy.compare_kind kind Bigarray.nativeint = 0);
              assert (Numpy.compare_layout layout Bigarray.c_layout = 0);
              let genarray =
                Stdcompat.Option.get
                  (Numpy.check_kind_and_layout
                     Bigarray.nativeint Bigarray.c_layout array) in
              assert (Bigarray.Genarray.dims genarray = [| 4 |]);
              let vector = Bigarray.array1_of_genarray genarray in
              for i = 0 to 3 do
                assert (Bigarray.Array1.get vector i = Nativeint.of_int i)
              done in
            Numpy.to_bigarray_k { Numpy.f = k } args.(0);
            Py.none in
          Py.Module.set test_module "callback"
            (Py.Callable.of_function callback);
          assert (Py.Run.simple_string "
from test import callback
import numpy
callback(numpy.array([0,1,2,3]))
");
          Pyml_tests_common.Passed)
(* Run the registered tests when built as an executable; do nothing when
   the file is loaded into an interactive toplevel. *)
let () =
  match !Sys.interactive with
  | true -> ()
  | false -> Pyml_tests_common.main ()