-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathBiTrie.cpp
More file actions
176 lines (174 loc) · 4.16 KB
/
BiTrie.cpp
File metadata and controls
176 lines (174 loc) · 4.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
// Memory optimized version: nodes live in one flat vector and are addressed
// by index instead of by pointer.
// src: AI generated based on the original source code
// USAGE: Trie trie(n * Trie::B + 5);
// The constructor argument is only an initial capacity: the pool grows on
// demand, and erased paths stay linked and are reused by later inserts, so the
// pool never exceeds (number of distinct values ever inserted) * B + 2 nodes.
// Only bits B-1 .. 0 of each value are stored, so values are expected to be
// non-negative.
struct Trie {
    static const int B = 31;  // number of bits stored per value
    int sz = 1;               // index of the last allocated node (node 1 is the root)
    struct node {
        int32_t nxt[2];       // child indices; 0 means "no child"
        int sz;               // how many stored values pass through this node
        node() {
            nxt[0] = nxt[1] = 0;
            sz = 0;
        }
    };
    // data[0] is a sentinel that never gets children; data[1] is the root.
    vector<node> data;

    // Allocates a fresh node index, growing the pool when the initial capacity
    // is exhausted (previously this silently wrote past the end of `data`).
    int new_node() {
        ++sz;
        while ((int)data.size() <= sz) data.emplace_back();
        return sz;
    }

    // mxSz: initial pool capacity in nodes. At least two slots are always
    // created, for the sentinel (0) and the root (1).
    Trie(int mxSz) {
        data.assign(mxSz < 2 ? 2 : mxSz, node());
        sz = 1;
    }

    // True if v is a real node that still has at least one stored value below it.
    bool alive(int v) const { return v != 0 && data[v].sz > 0; }

    // Adds one occurrence of val.
    void insert(int val) {
        int curr = 1;
        data[curr].sz++;
        for (int i = B - 1; i >= 0; i--) {
            int b = val >> i & 1;
            if (data[curr].nxt[b] == 0) {
                // new_node() may reallocate `data`; obtain the index first so no
                // reference into the old buffer is held across the call.
                int fresh = new_node();
                data[curr].nxt[b] = fresh;
            }
            curr = data[curr].nxt[b];
            data[curr].sz++;
        }
    }

    // Removes one occurrence of val; does nothing if val is not present.
    // Emptied nodes stay linked with sz == 0, so re-inserting the same prefix
    // reuses them instead of consuming new pool slots.
    void erase(int val) {
        if (get_min(val) != 0) return; // Value not present in Trie
        int curr = 1;
        data[curr].sz--; // Decrement root size
        for (int i = B - 1; i >= 0; i--) {
            curr = data[curr].nxt[val >> i & 1];
            data[curr].sz--;
        }
    }

    // Number of stored values s.t. val ^ x < k.
    // Walking into an emptied (sz == 0) subtree is harmless: every node below
    // it also has sz == 0, so it contributes nothing.
    int query(int x, int k) {
        int cur = 1;
        int ans = 0;
        for (int i = B - 1; i >= 0; i--) {
            if (cur == 0) break;
            int b1 = x >> i & 1, b2 = k >> i & 1;
            if (b2 == 1) {
                if (data[cur].nxt[b1]) ans += data[data[cur].nxt[b1]].sz;
                cur = data[cur].nxt[!b1];
            } else cur = data[cur].nxt[b1];
        }
        return ans;
    }

    // Returns maximum of val ^ x over stored values (0 if the trie is empty).
    int get_max(int x) {
        int curr = 1;
        int ans = 0;
        for (int i = B - 1; i >= 0; i--) {
            int k = x >> i & 1;
            ans <<= 1;
            if (alive(data[curr].nxt[!k])) {
                curr = data[curr].nxt[!k];
                ans++;
            } else {
                curr = data[curr].nxt[k];
            }
        }
        return ans;
    }

    // Returns minimum of val ^ x over stored values ((1 << B) - 1 if the trie
    // is empty). A result of 0 means x itself is stored; erase relies on this.
    int get_min(int x) {
        int curr = 1;
        int ans = 0;
        for (int i = B - 1; i >= 0; i--) {
            int k = x >> i & 1;
            ans <<= 1;
            if (alive(data[curr].nxt[k])) {
                curr = data[curr].nxt[k];
            } else {
                curr = data[curr].nxt[!k];
                ans++;
            }
        }
        return ans;
    }
};
// src: https://github.com/ShahjalalShohag/code-library/blob/main/Data%20Structures/Trie.cpp
// Pointer-based binary trie over the low B bits of each (non-negative) value.
// Each node counts how many stored values pass through it, so duplicates are
// supported.
struct Trie {
    static const int B = 31;  // number of bits stored per value
    struct node {
        node* nxt[2];         // children for bit 0 / bit 1; nullptr if absent
        int sz;               // how many stored values pass through this node
        node() {
            nxt[0] = nxt[1] = nullptr;
            sz = 0;
        }
    };
    node* root;               // owned; nullptr only in a moved-from Trie
    Trie() {
        root = new node();
    }
    // The trie owns its nodes, so the implicit copies would double-free them.
    // Copying is disabled; ownership can be transferred by move. A moved-from
    // Trie may only be destroyed or assigned to.
    Trie(const Trie&) = delete;
    Trie& operator=(const Trie&) = delete;
    Trie(Trie&& other) noexcept : root(other.root) { other.root = nullptr; }
    Trie& operator=(Trie&& other) noexcept {
        if (this != &other) {
            del(root);
            root = other.root;
            other.root = nullptr;
        }
        return *this;
    }
    // Child b of v, treating a missing node as having no children. This lets
    // the greedy walks below run on an empty trie without dereferencing null.
    static node* child(node* v, int b) { return v ? v->nxt[b] : nullptr; }
    // Adds one occurrence of val.
    void insert(int val) {
        node* cur = root;
        cur -> sz++;
        for (int i = B - 1; i >= 0; i--) {
            int b = val >> i & 1;
            if (cur -> nxt[b] == nullptr) cur -> nxt[b] = new node();
            cur = cur -> nxt[b];
            cur -> sz++;
        }
    }
    // Removes one occurrence of val; does nothing if val is not present
    // (including on an empty trie). Frees the branch that becomes empty.
    void erase(int val) {
        if (get_min(val)) return ; // value not present
        node *cur = root;
        cur -> sz--;
        for (int i = B - 1; i >= 0; i--) {
            int b = val >> i & 1;
            node* nxtNode = cur -> nxt[b];
            nxtNode -> sz--;
            if (nxtNode -> sz == 0) {
                cur -> nxt[b] = nullptr;
                del(nxtNode);
                return;
            }
            cur = nxtNode;
        }
    }
    int query(int x, int k) { // number of values s.t. val ^ x < k
        node* cur = root;
        int ans = 0;
        for (int i = B - 1; i >= 0; i--) {
            if (cur == nullptr) break;
            int b1 = x >> i & 1, b2 = k >> i & 1;
            if (b2 == 1) {
                if (cur -> nxt[b1]) ans += cur -> nxt[b1] -> sz;
                cur = cur -> nxt[!b1];
            } else cur = cur -> nxt[b1];
        }
        return ans;
    }
    // Returns maximum of val ^ x over stored values (0 if the trie is empty).
    int get_max(int x) {
        node* cur = root;
        int ans = 0;
        for (int i = B - 1; i >= 0; i--) {
            int k = x >> i & 1;
            ans <<= 1;
            if (node* want = child(cur, !k)) cur = want, ans++;
            else cur = child(cur, k);
        }
        return ans;
    }
    // Returns minimum of val ^ x over stored values ((1 << B) - 1 if the trie
    // is empty). A result of 0 means x itself is stored; erase relies on this.
    int get_min(int x) {
        node* cur = root;
        int ans = 0;
        for (int i = B - 1; i >= 0; i--) {
            int k = x >> i & 1;
            ans <<= 1;
            if (node* same = child(cur, k)) cur = same;
            else cur = child(cur, !k), ans++;
        }
        return ans;
    }
    // Recursively frees the subtree rooted at cur (depth is at most B + 1).
    void del(node* cur) {
        if (cur == nullptr) return;
        for (int i = 0; i < 2; i++) del(cur -> nxt[i]);
        delete cur;
    }
    ~Trie() { del(root); }
};