-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathyaecl.hpp
More file actions
388 lines (383 loc) · 13.1 KB
/
yaecl.hpp
File metadata and controls
388 lines (383 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
#ifndef INCLUDE_YAECL_HPP_
#define INCLUDE_YAECL_HPP_
#include <algorithm>
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <limits>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include <pybind11/pybind11.h>
namespace yaecl {
/* BitStream: growable bit buffer used by the coders in this file.
 *
 * Bits are packed MSB-first into bytes. The write end supports push and pop
 * (stack style, used by rANS); a separate read cursor (_fpos) supports
 * front-to-back reading (queue style, used by arithmetic coding).
 *
 * Invariant kept by every member: _data holds exactly ceil(_pos / 8) bytes
 * and the unused low bits of the last byte are zero. This is what makes a
 * push after a pop land in the right place.
 *
 * Positions are stored as int, so a stream is limited to INT_MAX bits.
 */
class BitStream {
public:
    BitStream() = default;
    ~BitStream() = default;

    /* Append one bit at the write end. */
    void push_back(bool bit){
        if(_pos % 8 == 0) _data.push_back(0);
        if(bit) _data[_pos / 8] |= static_cast<uint8_t>(1u << (7 - (_pos % 8)));
        _pos++;
    }

    /* Aligned push for fast ANS: append a whole byte.
     * The stream must currently be byte aligned.
     */
    void push_back_byte(uint8_t byte){
        assert(_pos % 8 == 0);
        _data.push_back(byte);
        _pos += 8;
    }

    /* Read the bit at absolute position pos (0 <= pos < size()). */
    bool get(int pos) const {
        assert(pos >= 0 && pos < _pos);
        return ((_data[pos / 8] >> (7 - (pos % 8))) & 1u) != 0;
    }

    /* Queue style pop, use only with ac.
     * Returns the bit under the read cursor and advances it; once the cursor
     * reaches the write end, keeps returning 0.
     */
    bool pop_front(){
        if(_fpos >= _pos) return false;
        const bool bit = get(_fpos);
        _fpos++;
        return bit;
    }

    /* Stack style pop, use only with ans.
     * Removes and returns the most recently pushed bit; returns 0 when the
     * stream is empty.
     */
    bool pop_back(){
        if(_pos <= 0) return false;
        _pos--;
        const uint8_t mask = static_cast<uint8_t>(1u << (7 - (_pos % 8)));
        const bool bit = (_data[_pos / 8] & mask) != 0;
        // Drop the storage of the popped bit so that a later push cannot
        // OR into a stale 1 or append behind a stale byte.
        if(_pos % 8 == 0) _data.pop_back();
        else _data[_pos / 8] &= static_cast<uint8_t>(~mask);
        return bit;
    }

    /* Aligned pop for fast ANS: remove and return the last byte.
     * The stream must be byte aligned. Returns 0 when fewer than 8 bits are
     * stored (mirrors pop_back on an empty stream) instead of reading out of
     * bounds.
     */
    uint8_t pop_back_byte(){
        assert(_pos % 8 == 0);
        if(_pos < 8) return 0;
        _pos -= 8;
        const uint8_t byte = _data[_pos / 8];
        _data.resize(_byte_count());
        return byte;
    }

    /* Number of bits currently stored. */
    int size() const { return _pos; }

    /* Write the stored bits to fpath, padded with zero bits to a whole byte. */
    void save(const std::string &fpath) const {
        std::ofstream of(fpath, std::ios::out | std::ios::binary);
        of.write(reinterpret_cast<const char*>(_data.data()), _byte_count());
    }

    /* Append the whole content of the file fpath to the stream (8 bits per
     * byte read). The stream must be byte aligned; normally it is empty.
     *
     * Throws std::runtime_error when the file cannot be opened or sized, and
     * std::length_error when the result would exceed the int bit counter.
     */
    void load(const std::string &fpath){
        assert(_pos % 8 == 0);
        std::ifstream rf(fpath, std::ios::in | std::ios::binary);
        if(!rf) throw std::runtime_error("BitStream::load: cannot open " + fpath);
        rf.seekg(0, std::ios::end);
        const std::streamoff flen = rf.tellg();
        rf.seekg(0, std::ios::beg);
        if(flen < 0 || !rf)
            throw std::runtime_error("BitStream::load: cannot determine size of " + fpath);
        const std::streamoff room = (std::numeric_limits<int>::max() - _pos) / 8;
        if(flen > room)
            throw std::length_error("BitStream::load: file too large: " + fpath);
        const std::size_t old_size = _data.size();
        _data.resize(old_size + static_cast<std::size_t>(flen));
        rf.read(reinterpret_cast<char*>(_data.data() + old_size), flen);
        // Keep only what was really read, so a short read never leaves
        // uninitialised padding in the stream.
        const std::streamsize got = rf.gcount();
        _data.resize(old_size + static_cast<std::size_t>(got));
        _pos += static_cast<int>(got) * 8;
    }

    /* Copy of the stored bits as Python bytes, zero padded to a whole byte. */
    pybind11::bytes getData() const {
        return pybind11::bytes(reinterpret_cast<const char*>(_data.data()), _byte_count());
    }

    /* Replace the content with the given bytes and rewind the read cursor. */
    void setData(const pybind11::bytes &data) {
        const std::string s = data;
        _data.assign(s.begin(), s.end());
        _pos = static_cast<int>(_data.size()) * 8;
        _fpos = 0;
    }

private:
    /* Bytes needed to hold _pos bits. */
    int _byte_count() const { return _pos / 8 + int(_pos % 8 != 0); }

    std::vector<uint8_t> _data; // packed bits, MSB first
    int _pos = 0;               // write end: number of valid bits
    int _fpos = 0;              // read cursor used by pop_front
};
template <typename T_in, typename T_out>
/* T_in:
 *   * integer type used for the internal interval arithmetic
 *   * default: uint64_t
 * T_out:
 *   * integer type of symbols and cdf entries at the interface
 *   * default: uint32_t for cxx, int32 for python
 */
class ArithmeticCodingEncoder {
    /* Interval-halving arithmetic coder, following the paper
     * ARITHMETIC CODING FOR DATA COMPRESSION.
     * Encoded bits are appended to the public member bit_stream.
     */
public:
    BitStream bit_stream;

    ArithmeticCodingEncoder(const int &precision){
        /* precision:
         *   * aka Code_value_bits, width in bits of the coding interval
         *   * must satisfy 2 <= precision < number of value bits of T_in
         *   * the paper requires f <= c - 2 && f + c <= p; that limit is
         *     kept in _max_total and checked by encode()
         *   * default: 32
         */
        assert(precision >= 2 && precision < std::numeric_limits<decltype(_full_range)>::digits);
        const T_in one = 1;
        _precision = precision;
        _full_range = (one << _precision) - 1;
        _quarter_range = (_full_range >> 2) + 1;
        _half_range = _quarter_range + _quarter_range;
        _three_forth_range = _half_range + _quarter_range;
        int max_total_bits = std::min(_precision - 2, std::numeric_limits<decltype(_full_range)>::digits - _precision);
        _max_total = (one << max_total_bits) - 1;
        // Start from the full interval with no deferred bits.
        _pending_bits = 0;
        _high = _full_range;
        _low = 0;
    }

    ~ArithmeticCodingEncoder(){}

    void encode(const T_out &sym, const T_out *cdf, const int &cdf_bits){
        /* sym:
         *   * symbol to encode, 0 <= sym < sym_cnt (alphabet size)
         * cdf:
         *   * cumulative table with sym_cnt + 1 entries,
         *     pmf(sym) = cdf[sym + 1] - cdf[sym], which must not be zero
         *   * the last entry must equal 2 ** cdf_bits
         * cdf_bits:
         *   * log2 of the cdf total; 2 ** cdf_bits must not exceed _max_total
         */
        assert(_low < _high);
        assert((_low & _full_range) == _low);
        assert((_high & _full_range) == _high);
        const T_in base = _low;
        const T_in width = _high - base + 1;
        T_in c_total = static_cast<T_in>(1) << cdf_bits;
        assert(c_total <= _max_total);
        const T_in c_low = cdf[sym];
        const T_in c_high = cdf[sym + 1];
        assert(c_low != c_high);
        // Narrow [low, high] to the slice of the interval owned by sym.
        _low = base + ((c_low * width) >> cdf_bits);
        _high = base + ((c_high * width) >> cdf_bits) - 1;
        _renormalize();
    }

    void flush(){
        /* call once after the last encode(): writes the bits that pin the
         * final interval */
        _pending_bits++;
        _emit(static_cast<bool>(_low >= _quarter_range));
    }

private:
    void _emit(bool bit){
        /* write bit, then every deferred bit as its complement */
        bit_stream.push_back(bit);
        while (_pending_bits > 0) {
            bit_stream.push_back(!bit);
            _pending_bits--;
        }
    }

    void _renormalize(){
        /* While the interval lies entirely in the lower half, the upper half
         * or the middle half, output (or defer) one bit and double it. */
        for (;;) {
            T_in offset;
            if (_high < _half_range) {
                _emit(false);
                offset = 0;
            } else if (_low >= _half_range) {
                _emit(true);
                offset = _half_range;
            } else if (_low >= _quarter_range && _high < _three_forth_range) {
                // Straddles the midpoint: the bit value is not known yet.
                assert(_pending_bits < std::numeric_limits<decltype(_pending_bits)>::max());
                _pending_bits++;
                offset = _quarter_range;
            } else {
                return;
            }
            _low = (_low - offset) << 1;
            _high = ((_high - offset) << 1) + 1;
        }
    }

    T_in _precision;
    T_in _full_range;
    T_in _half_range;
    T_in _quarter_range;
    T_in _three_forth_range;
    T_in _max_total;
    T_in _low;
    T_in _high;
    T_in _pending_bits;
};
template <typename T_in, typename T_out>
/* template args: see ArithmeticCodingEncoder */
class ArithmeticCodingDecoder{
    /* Decoder matching ArithmeticCodingEncoder. It keeps its own copy of the
     * encoded BitStream and consumes it front to back with pop_front(), which
     * yields 0 once the stream is exhausted.
     */
public:
    BitStream bit_stream;

    ArithmeticCodingDecoder(const int &precision, const BitStream &encode_bit_stream){
        /* precision:
         *   * See ArithmeticCodingEncoder; must be the value used for encoding
         * encode_bit_stream:
         *   * the BitStream to decode, from the encoder or read from file
         */
        const int digits = std::numeric_limits<T_in>::digits;
        assert(precision >= 2 && precision < digits);
        bit_stream = encode_bit_stream;
        _precision = precision;
        _full_range = (static_cast<T_in>(1) << _precision) - 1;
        _quarter_range = (_full_range >> 2) + 1;
        _half_range = _quarter_range * 2;
        _three_forth_range = _quarter_range * 3;
        // Same bound as ArithmeticCodingEncoder, so that every cdf_bits the
        // encoder accepts is accepted here too. It keeps
        // range * cdf value <= 2 ** (precision + cdf_bits) inside T_in.
        const int max_total_bits = std::min(precision - 2, digits - precision);
        _max_total = (static_cast<T_in>(1) << max_total_bits) - 1;
        _low = 0;
        _high = _full_range;
        // Prime the code register with the first `precision` bits.
        _code = 0;
        for(int i = 0; i < precision; i++){
            _code <<= 1;
            _code += static_cast<int>(bit_stream.pop_front());
        }
    }

    ~ArithmeticCodingDecoder(){}

    T_out decode(const int &sym_cnt, const T_out *cdf, const int &cdf_bits){
        /* sym_cnt:
         *   * alphabet size, must be positive
         * cdf:
         *   * cumulative table with sym_cnt + 1 entries,
         *     pmf(sym) = cdf[sym + 1] - cdf[sym]
         *   * the last entry must equal 2 ** cdf_bits
         * cdf_bits:
         *   * log2 of the cdf total; 2 ** cdf_bits must not exceed _max_total
         * returns:
         *   * the decoded symbol, 0 <= sym < sym_cnt
         */
        assert(sym_cnt > 0);
        assert(cdf_bits >= 0);
        const T_in c_total = static_cast<T_in>(1) << cdf_bits;
        assert(c_total <= _max_total);
        (void)c_total; // only read by the asserts
        const T_in range = _high - _low + 1;
        const T_in scaled_range = _code - _low;
        // Map the code's position inside the interval back to the cdf scale.
        const T_in scaled_value = (((scaled_range + 1) << cdf_bits) - 1) / range;
        assert(scaled_value < c_total);
        // Binary search for the last index whose cdf entry is <= scaled_value.
        T_in start = 0;
        T_in end = sym_cnt;
        while (end - start > 1) {
            const T_in middle = (start + end) >> 1;
            if (cdf[middle] > scaled_value)
                end = middle;
            else
                start = middle;
        }
        assert(start + 1 == end);
        const T_in sym = start;
        const T_in c_low = cdf[sym];
        const T_in c_high = cdf[sym + 1];
        assert(c_low < c_high);
        // Same interval update as the encoder; _high must use the old _low.
        _high = _low + ((c_high * range) >> cdf_bits) - 1;
        _low = _low + ((c_low * range) >> cdf_bits);
        _renormalize();
        return static_cast<T_out>(sym);
    }

private:
    void _renormalize(){
        /* Mirror of the encoder's renormalisation: apply the same interval
         * shifts to _low, _high and _code, pulling one new bit per doubling. */
        while(1){
            if(_high < _half_range){
                // lower half: nothing to subtract
            }else if(_low >= _half_range){
                _code -= _half_range;
                _low -= _half_range;
                _high -= _half_range;
            }else if(_low >= _quarter_range && _high < _three_forth_range){
                _code -= _quarter_range;
                _low -= _quarter_range;
                _high -= _quarter_range;
            }else{
                break;
            }
            _high = (_high << 1) + 1;
            _low <<= 1;
            _code = (_code << 1) + static_cast<int>(bit_stream.pop_front());
        }
    }

    T_in _precision;
    T_in _full_range;
    T_in _half_range;
    T_in _quarter_range;
    T_in _three_forth_range;
    T_in _max_total;
    T_in _low;
    T_in _high;
    T_in _code;
};
template <typename T_in, typename T_out>
/* template args: see ArithmeticCodingEncoder */
class RANSCodec {
    /* Range-ANS coder with byte-wise renormalisation.
     *
     * The state is kept in [2 ** (h_precision - t_precision), 2 ** h_precision).
     * Encoding pushes bytes onto bit_stream and decoding pops them from the
     * back, so symbols are decoded in the reverse of the order they were
     * encoded.
     */
public:
    BitStream bit_stream;

    RANSCodec(const int &h_precision, const int &t_precision){
        /* Encoder constructor.
         * h_precision:
         *   * precision for head part (the state), divisible by 8
         *   * t_precision < h_precision <= t_precision * 2
         *   * must fit in T_in
         *   * the larger it is, the more precise the pdf you can have
         *   * but also leads to more overhead in flush()
         *   * default: 64
         * t_precision:
         *   * precision for tail part (bits moved per renormalisation),
         *     divisible by 8
         *   * the larger it is, the faster rans is
         *   * but also leads to more overhead in flush()
         *   * default: 32
         */
        _set_precision(h_precision, t_precision);
        _state = _h_min; // smallest normalised state
    }

    RANSCodec(const int &h_precision, const int &t_precision, const BitStream &encode_bit_stream){
        /* Decoder constructor.
         * h_precision, t_precision:
         *   * must be the values used for encoding
         * encode_bit_stream:
         *   * flushed encoder output; it is copied, and the final state
         *     (h_precision bits) is read back from its end
         */
        _set_precision(h_precision, t_precision);
        bit_stream = encode_bit_stream;
        assert(bit_stream.size() % 8 == 0);
        assert(bit_stream.size() >= h_precision);
        _state = 0;
        // flush() wrote the state least significant byte first, so popping
        // from the back yields the most significant byte first.
        for(T_in i = 0; i < _h_precision / 8; i++){
            const uint8_t byte = bit_stream.pop_back_byte();
            _state = (_state << 8) | byte;
        }
    }

    ~RANSCodec(){}

    void encode(const T_out &sym, const T_out *cdf, const int &cdf_bits){
        /* args: See ArithmeticCodingEncoder
         * cdf_bits must not exceed h_precision - t_precision, otherwise the
         * state can leave its normalised interval.
         */
        assert(cdf_bits >= 0);
        assert(static_cast<T_in>(cdf_bits) <= _h_precision - _t_precision);
        const T_in c_low = cdf[sym];
        const T_in c_range = cdf[sym + 1] - c_low;
        assert(c_range != 0); // a zero-width symbol cannot be encoded
        T_in state = _state;
        // Renormalise when state >= c_range * 2 ** shift. The test is written
        // as a right shift of the state because the product itself overflows
        // T_in when c_range == 2 ** cdf_bits and h_precision is the full
        // width of T_in.
        const T_in shift = _h_precision - static_cast<T_in>(cdf_bits);
        const T_in digits = std::numeric_limits<T_in>::digits;
        if(shift < digits && (state >> shift) >= c_range){
            const T_in mask = 0xff;
            for(T_in i = 0; i < _t_precision / 8; i++){
                bit_stream.push_back_byte(static_cast<uint8_t>(state & mask));
                state >>= 8;
            }
            assert((state >> shift) < c_range);
        }
        _state = ((state / c_range) << cdf_bits) + (state % c_range) + c_low;
    }

    void flush(){
        /* call once after the last encode(): writes the final state
         * (h_precision bits, least significant byte first) */
        const T_in mask = 0xff;
        for(T_in i = 0; i < _h_precision / 8; i++){
            bit_stream.push_back_byte(static_cast<uint8_t>(_state & mask));
            _state >>= 8;
        }
    }

    T_out decode(const int &sym_cnt, const T_out *cdf, const int &cdf_bits){
        /* args: See ArithmeticCodingDecoder */
        assert(sym_cnt > 0);
        assert(cdf_bits >= 0);
        // The low cdf_bits bits of the state select the symbol slot.
        const T_in scaled_value = _state & ((static_cast<T_in>(1) << cdf_bits) - 1);
        // Binary search for the last index whose cdf entry is <= scaled_value.
        T_in start = 0;
        T_in end = sym_cnt;
        while (end - start > 1) {
            const T_in middle = (start + end) >> 1;
            if (cdf[middle] > scaled_value)
                end = middle;
            else
                start = middle;
        }
        assert(start + 1 == end);
        const T_in sym = start;
        const T_in c_low = cdf[sym];
        const T_in c_range = cdf[sym + 1] - c_low;
        T_in state = c_range * (_state >> cdf_bits) + scaled_value - c_low;
        if (state < _h_min){
            // Undo the encoder's renormalisation: pull t_precision bits back.
            for(T_in i = 0; i < _t_precision / 8; i++){
                const uint8_t byte = bit_stream.pop_back_byte();
                state = (state << 8) | byte;
            }
            assert(state >= _h_min);
        }
        _state = state;
        return static_cast<T_out>(sym);
    }

private:
    void _set_precision(const int &h_precision, const int &t_precision){
        /* validate the precisions and derive the lower bound of the state */
        assert(t_precision > 0);
        assert(h_precision % 8 == 0);
        assert(t_precision % 8 == 0);
        assert(t_precision < h_precision);
        assert(h_precision <= t_precision * 2);
        assert(h_precision <= std::numeric_limits<T_in>::digits);
        _h_precision = h_precision;
        _t_precision = t_precision;
        _h_min = static_cast<T_in>(1) << (_h_precision - _t_precision);
    }

    T_in _state;
    T_in _h_precision;
    T_in _t_precision;
    T_in _h_min;
};
}
#endif