-
Notifications
You must be signed in to change notification settings - Fork 0
/
drift_main.cpp
95 lines (83 loc) · 2.81 KB
/
drift_main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <stdexcept>
#include <string>
#include <vector>

#include "vqtree.cpp"
// Print the tree only when it is small enough to be readable (maxSize <= 100).
// Expands in terms of a variable that must be named `tree` at the call site.
// Wrapped in do { ... } while (0) so `PRINT();` behaves as a single statement,
// e.g. inside an unbraced `if (...) PRINT(); else ...` (the previous bare-`if`
// expansion left the trailing `;` dangling and broke such an `else`).
#define PRINT() do { if (tree.maxSize <= 100) { tree.printTree(); } } while (0)
// Overwrite data[0..dim) with pseudo-random values in [-1, 1].
// Draws from the global rand() stream, so the sequence is determined by the
// most recent srand() seed. A non-positive dim leaves data untouched.
void randVec(double *data, int dim) {
    int idx = 0;
    while (idx < dim) {
        const double unit = rand() / static_cast<double>(RAND_MAX);  // in [0, 1]
        data[idx] = 2.0 * (unit - 0.5);                              // rescale to [-1, 1]
        ++idx;
    }
}
// Synthetic regression target with two "concepts", used to simulate drift.
// Reads d[0..3], so d must point to at least four doubles.
//   partA == true : 0.1*d0 + 0.2*d1 + 0.3*d2 + 0.4*d3
//   partA == false: 0.4*d0 + 0.3*d1 + 0.2*d2 + 0.1*d3 - 10
// NOTE: an earlier header linked MATLAB's peaks() documentation
// (https://nl.mathworks.com/help/matlab/ref/peaks.html), but the function is a
// plain linear combination of the first four inputs, not the peaks surface.
double f(double *d, bool partA) {
    const double x0 = d[0];
    const double x1 = d[1];
    const double x2 = d[2];
    const double x3 = d[3];
    if (!partA) {
        // Second concept: reversed weights plus a large constant offset.
        return 0.4*x0 + 0.3*x1 + 0.2*x2 + 0.1*x3 - 10;
    }
    return 0.1*x0 + 0.2*x1 + 0.3*x2 + 0.4*x3;
}
// Noisy observation of f(d, partA): adds uniform noise in [-0.0025, 0.0025].
// Consumes one value from the global rand() stream per call.
double f_noisy(double *d, bool partA) {
    const double clean = f(d, partA);
    const double jitter = 0.005 * (rand() / static_cast<double>(RAND_MAX) - 0.5);
    return clean + jitter;
}
// Two-phase drift experiment on `tree` (seeded with srand(10) for repeatability).
//
// Phase 1 ("Block1"): inserts tree.maxSize random points labelled with
//   f_noisy(d, true).
// Phase 2 ("Block2"): for another tree.maxSize random points labelled with
//   f_noisy(d, false), first queries the tree, prints the prediction and its
//   error, accumulates absolute/squared error, and only then inserts the point.
//
// After every insert, a positive tree.driftCount is added to a running
// `removed` total and reported. Finally prints MAE and MSE averaged over the
// phase-2 queries. Trees with maxSize <= 100 are printed after each phase.
void test(KTree& tree) {
    const int dim = static_cast<int>(tree.dim);
    const int size = static_cast<int>(tree.maxSize);
    std::vector<double> d(static_cast<std::size_t>(dim));  // reused sample buffer
    srand(10);
    long removed = 0;

    // Report drift events triggered by the most recent tree.add(). The casts
    // make the arguments match %ld regardless of the members' declared types.
    auto reportDrift = [&](const char *block, int i) {
        if (tree.driftCount > 0) {
            removed += tree.driftCount;
            printf("%s Drift %d: %ld (%ld)\n", block, i,
                   static_cast<long>(tree.driftCount),
                   static_cast<long>(tree.size() - removed));
        }
    };

    for (int i = 0; i < size; i++) {
        randVec(d.data(), dim);
        const double target = f_noisy(d.data(), true);
        tree.add(d.data(), target);
        reportDrift("Block1", i);
    }
    puts("Block1 complete");
    PRINT();

    double sumAbs = 0, sumSq = 0;
    for (int i = 0; i < size; i++) {
        randVec(d.data(), dim);
        const double target = f_noisy(d.data(), false);
        const double query = tree.query(d.data());  // predict before training
        const double diff = query - target;
        printf("query:%.4f diff:%.4f\n", query, diff);
        sumSq += diff * diff;
        sumAbs += std::abs(diff);
        tree.add(d.data(), target);
        reportDrift("Block2", i);
    }
    puts("Block2 complete");
    PRINT();

    // Average over the number of error terms actually summed (one per phase-2
    // query). Dividing by tree.size() was wrong: it is unrelated to that count.
    const double queries = size > 0 ? static_cast<double>(size) : 1.0;
    printf("MAE: %f\n", sumAbs / queries);
    printf("MSE: %f\n", sumSq / queries);
}
void test(size_t dim, size_t maxSize, size_t maxLeafSize=64, size_t branchFactor=16, size_t minLeaves=100, size_t minN=100, int searchType=6, double spill=-1., bool removeDups=true, size_t driftHistLen=100, size_t driftThreshold=100) {
KTree tree(dim, maxSize, maxLeafSize, branchFactor, minLeaves, minN, searchType, spill, removeDups, driftHistLen, driftThreshold);
//VQTree(size_t dim, size_t maxSize, size_t maxLeafSize=64, size_t branchFactor=16, size_t minLeaves=100, size_t minN=100, int searchType=6, double spill=-1., bool removeDups=true, size_t driftHistLen=5, size_t driftThreshold=4) :
test(tree);
}
// Entry point: parses command-line options and runs the drift experiment.
//   usage: prog size [dims=4] [leafSize=64] [branchFactor=16]
// Every supplied value must be a positive integer with no trailing characters.
// On a missing or invalid argument a message goes to stderr and -1 is
// returned. Previously std::stoi's exceptions escaped and terminated the
// program, "12abc" was accepted as 12, and negative values were forwarded to
// test()'s size_t parameters, where they wrap to huge sizes.
int main(int argc, char *argv[]) {
    const char *prog = (argc > 0 && argv[0] != nullptr) ? argv[0] : "drift_main";
    if (argc < 2) {
        fprintf(stderr, "usage: %s size [dims=4] [leafSize=64] [branchFactor=16]\n", prog);
        return -1;
    }

    // Parse argv[idx] into `out`; returns false (after reporting) on bad input.
    auto parsePositive = [&](int idx, const char *what, int &out) -> bool {
        const std::string text(argv[idx]);
        try {
            std::size_t consumed = 0;
            const int value = std::stoi(text, &consumed);
            if (consumed == text.size() && value > 0) {
                out = value;
                return true;
            }
        } catch (const std::invalid_argument &) {
            // not a number: reported below
        } catch (const std::out_of_range &) {
            // does not fit in int: reported below
        }
        fprintf(stderr, "%s: invalid %s '%s' (expected a positive integer)\n",
                prog, what, argv[idx]);
        return false;
    };

    int size = 0;
    int dims = 4, leafSize = 64, branchFactor = 16;
    if (!parsePositive(1, "size", size)) {
        return -1;
    }
    if (argc >= 3 && !parsePositive(2, "dims", dims)) {
        return -1;
    }
    if (argc >= 4 && !parsePositive(3, "leafSize", leafSize)) {
        return -1;
    }
    if (argc >= 5 && !parsePositive(4, "branchFactor", branchFactor)) {
        return -1;
    }
    printf("size:%d dims:%d leafSize:%d branchFactor:%d\n", size, dims, leafSize, branchFactor);
    test(dims, size, leafSize, branchFactor);
    return 0;
}