"""
Face database module with K-Means cluster indexing.
Maps cluster IDs to lists of encrypted enrollment entries.
"""
import pickle
import os
from typing import List, Dict, Tuple, Optional, Union
from collections import defaultdict
import numpy as np


class FaceDatabase:
"""
Encrypted face database with K-Means cluster routing.
Structure:
{
cluster_id (int): [
{
'person_id': str,
'enc_template': <GGSW ciphertext>,
'feature': np.ndarray, # retained for evaluation only
},
...
],
...
}
"""

    def __init__(self, num_tables: int = 10, use_kmeans: bool = True):
        self.data: Dict[Union[str, int], List[Dict]] = defaultdict(list)
        self.person_to_id: Dict[str, Union[str, int]] = {}
        self.total_entries = 0
        self.use_kmeans = use_kmeans
        self.num_tables = num_tables
        self.table_data: List[Dict[str, List[Dict]]] = [defaultdict(list) for _ in range(num_tables)]

    def add_entry(self,
                  index_id: Union[str, int],
                  person_id: str,
                  enc_template,
                  feature: Optional[np.ndarray] = None,
                  table_hashes: Optional[List[str]] = None):
        """
        Add an encrypted entry to the database.

        Args:
            index_id: Cluster ID (int) or LSH bucket ID (str)
            person_id: Identity label
            enc_template: GGSW ciphertext
            feature: Plaintext feature (optional; for evaluation only)
            table_hashes: Per-table hash values for multi-table LSH indexing
        """
        entry = {
            'person_id': person_id,
            'enc_template': enc_template,
            'feature': feature,
            'entry_id': self.total_entries
        }
        self.data[index_id].append(entry)
        self.person_to_id[person_id] = index_id
        self.total_entries += 1
        if table_hashes is not None and len(table_hashes) == self.num_tables:
            for table_idx, table_hash in enumerate(table_hashes):
                self.table_data[table_idx][table_hash].append(entry)

    def get_candidates(self, index_id: Union[str, int]) -> List[Dict]:
        """Return all entries in a given cluster/bucket."""
        return self.data.get(index_id, [])

    def get_candidates_top_k(self, cluster_ids: List[int]) -> Tuple[List[Dict], Dict]:
        """
        Retrieve all candidates across a list of cluster IDs.

        Args:
            cluster_ids: Top-K cluster indices

        Returns:
            candidates: Merged list of entries
            stats: Counts per cluster and total
        """
        candidates = []
        per_cluster_count = {}
        for cluster_id in cluster_ids:
            cluster_candidates = self.data.get(cluster_id, [])
            candidates.extend(cluster_candidates)
            per_cluster_count[cluster_id] = len(cluster_candidates)
        stats = {
            'total': len(candidates),
            'per_cluster': per_cluster_count,
            'num_clusters': len(cluster_ids)
        }
        return candidates, stats

    def get_candidates_multi_table(self, table_hashes: List[str]) -> List[Dict]:
        """
        Multi-table OR-query: return entries matching any table hash (deduplicated).
        """
        seen_entry_ids = set()
        candidates = []
        for table_idx, table_hash in enumerate(table_hashes):
            if table_idx >= len(self.table_data):
                continue
            table_candidates = self.table_data[table_idx].get(table_hash, [])
            for entry in table_candidates:
                entry_id = entry['entry_id']
                if entry_id not in seen_entry_ids:
                    seen_entry_ids.add(entry_id)
                    candidates.append(entry)
        return candidates
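
    # Example (hypothetical hashes): querying with table_hashes=["a1", "b7"]
    # returns the union of table 0's bucket "a1" and table 1's bucket "b7";
    # an entry that matches more than one table is returned only once thanks
    # to the entry_id deduplication above.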

    def get_by_person_id(self, person_id: str) -> Optional[Dict]:
        """Look up an entry by person ID (for evaluation)."""
        index_id = self.person_to_id.get(person_id)
        if index_id is None:
            return None
        candidates = self.data[index_id]
        for entry in candidates:
            if entry['person_id'] == person_id:
                return entry
        return None

    def get_statistics(self) -> Dict:
        """Return summary statistics about the database."""
        bucket_sizes = [len(entries) for entries in self.data.values()]
        stats = {
            'total_entries': self.total_entries,
            'num_buckets': len(self.data),
            'num_unique_persons': len(self.person_to_id),
            'avg_bucket_size': np.mean(bucket_sizes) if bucket_sizes else 0,
            'max_bucket_size': max(bucket_sizes) if bucket_sizes else 0,
            'min_bucket_size': min(bucket_sizes) if bucket_sizes else 0,
            'use_kmeans': self.use_kmeans,
        }
        return stats

    def print_statistics(self):
        """Print database statistics."""
        stats = self.get_statistics()
        print("\n" + "=" * 60)
        print("Database Statistics")
        print("=" * 60)
        print(f"Total entries: {stats['total_entries']:,}")
        print(f"Unique identities: {stats['num_unique_persons']:,}")
        print(f"Active buckets: {stats['num_buckets']:,}")
        print(f"Avg bucket size: {stats['avg_bucket_size']:.2f}")
        print(f"Max bucket size: {stats['max_bucket_size']}")
        print(f"Min bucket size: {stats['min_bucket_size']}")
        print("=" * 60 + "\n")

    def save(self, filepath: str):
        """Serialize database to disk."""
        # Guard against a bare filename: os.makedirs('') would raise.
        dirname = os.path.dirname(filepath)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        table_data_serializable = [dict(table) for table in self.table_data]
        save_data = {
            'data': dict(self.data),
            'person_to_id': self.person_to_id,
            'total_entries': self.total_entries,
            'num_tables': self.num_tables,
            'table_data': table_data_serializable,
            'use_kmeans': self.use_kmeans
        }
        with open(filepath, 'wb') as f:
            pickle.dump(save_data, f)
        print(f"Database saved: {filepath} ({self.total_entries:,} entries)")

    @classmethod
    def load(cls, filepath: str):
        """Load a serialized database."""
        with open(filepath, 'rb') as f:
            save_data = pickle.load(f)
        num_tables = save_data.get('num_tables', 10)
        use_kmeans = save_data.get('use_kmeans', False)
        db = cls(num_tables=num_tables, use_kmeans=use_kmeans)
        db.data = defaultdict(list, save_data['data'])
        if 'person_to_id' in save_data:
            db.person_to_id = save_data['person_to_id']
        elif 'person_to_lsh' in save_data:
            db.person_to_id = save_data['person_to_lsh']
        else:
            db.person_to_id = {}
        db.total_entries = save_data['total_entries']
        if 'table_data' in save_data:
            db.table_data = [defaultdict(list, table) for table in save_data['table_data']]
        else:
            db.table_data = [defaultdict(list) for _ in range(num_tables)]
        mode = "K-Means" if use_kmeans else "LSH"
        print(f"Database loaded: {filepath} (mode={mode}, entries={db.total_entries:,})")
        return db
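

if __name__ == "__main__":
    # Minimal usage sketch / smoke test. The ciphertexts below are hypothetical
    # placeholders (plain strings): the database stores `enc_template` as an
    # opaque, picklable object, so any stand-in exercises the indexing and
    # persistence paths without a real GGSW encryption backend.
    import tempfile

    db = FaceDatabase(num_tables=2, use_kmeans=True)
    rng = np.random.default_rng(0)
    for i, person in enumerate(["alice", "bob", "carol"]):
        db.add_entry(
            index_id=i % 2,                            # stand-in K-Means cluster ID
            person_id=person,
            enc_template=f"ct-{person}",               # placeholder ciphertext
            feature=rng.standard_normal(512),          # plaintext feature (evaluation only)
            table_hashes=[f"h{i % 2}", f"g{i % 3}"],   # one hash per LSH table
        )

    db.print_statistics()

    # Cluster routing: merge candidates from the top-K clusters.
    candidates, stats = db.get_candidates_top_k([0, 1])
    print("top-K candidates:", [c['person_id'] for c in candidates], stats)

    # Multi-table OR-query: deduplicated union across per-table buckets.
    multi = db.get_candidates_multi_table(["h0", "g1"])
    print("multi-table candidates:", [c['person_id'] for c in multi])

    # Save / load round trip.
    path = os.path.join(tempfile.mkdtemp(), "face_db.pkl")
    db.save(path)
    db_loaded = FaceDatabase.load(path)
    assert db_loaded.total_entries == db.total_entries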