-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbaseline.py
More file actions
executable file
·70 lines (59 loc) · 2.47 KB
/
baseline.py
File metadata and controls
executable file
·70 lines (59 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
from models.BertClassifier import load_bert_model, build_bert_model
from models.BigClone import load_BigClone_model, build_BigClone_model
from dataloaders.bigclone import bigClone_get_dataloader
from dataloaders.imdb import imdb_get_dataloader
from dataloaders.snli import snli_get_dataloader
class ModelBase:
    """Thin wrapper that builds or loads a classifier and runs a forward pass.

    Dispatches on ``model_type``. The supported values are ``'bert'`` and
    ``'BigCloneModel'``. Any other value raises ``ValueError`` when a method
    that depends on it is called. This avoids silently leaving ``self.model``
    unset, or failing later with an unrelated ``UnboundLocalError``.

    Attributes:
        model_type: Architecture selector (``'bert'`` or ``'BigCloneModel'``).
        number_classes: Output size passed to the BERT builder/loader. It is
            ignored for ``'BigCloneModel'``, which is always created with 2
            classes.
        device: Device the model is created on and inputs are moved to.
        model: Populated by ``build_model`` or ``load_model``. It does not
            exist before one of them has been called.
    """

    def __init__(self, model_type: str, number_classes: int, device: str = 'cuda'):
        self.model_type = model_type
        self.number_classes = number_classes
        self.device = device

    def _unsupported(self) -> ValueError:
        """Return the error used for an unrecognised ``model_type``."""
        return ValueError(
            f"Unsupported model_type {self.model_type!r}; "
            "expected 'bert' or 'BigCloneModel'"
        )

    def build_model(self) -> None:
        """Create a freshly initialised model and store it in ``self.model``.

        Raises:
            ValueError: if ``model_type`` is not supported.
        """
        if self.model_type == 'bert':
            self.model = build_bert_model(self.number_classes, self.device)
        elif self.model_type == 'BigCloneModel':
            # Always 2 outputs, independent of number_classes
            # (presumably clone / not-clone; confirm with the model definition).
            self.model = build_BigClone_model(2, self.device)
        else:
            raise self._unsupported()

    def load_model(self, path_pretrain) -> None:
        """Load pretrained weights from ``path_pretrain`` into ``self.model``.

        Args:
            path_pretrain: Checkpoint location, forwarded unchanged to the
                model-specific loader.

        Raises:
            ValueError: if ``model_type`` is not supported.
        """
        if self.model_type == 'bert':
            self.model = load_bert_model(path_pretrain, self.number_classes, self.device)
        elif self.model_type == 'BigCloneModel':
            # Same fixed class count as build_model.
            self.model = load_BigClone_model(path_pretrain, 2, self.device)
        else:
            raise self._unsupported()

    def inference(self, data):
        """Run ``self.model`` on one batch and return ``(predictions, labels)``.

        Gradients are not disabled here. Wrap the call in ``torch.no_grad()``
        at the call site for pure evaluation.

        Args:
            data: For ``'bert'``, a mapping with entries ``'ids'``,
                ``'attention_mask'``, ``'token_type_ids'`` and ``'label'``.
                For ``'BigCloneModel'``, a ``(sample, labels)`` pair. Every
                entry must support ``.to(device)`` (e.g. torch tensors) and is
                moved to ``self.device`` before use.

        Returns:
            Tuple of the raw model output and the labels on ``self.device``.

        Raises:
            ValueError: if ``model_type`` is not supported.
        """
        if self.model_type == 'bert':
            ids = data['ids'].to(self.device)
            attention_mask = data['attention_mask'].to(self.device)
            token_type_ids = data['token_type_ids'].to(self.device)
            labels = data['label'].to(self.device)
            predictions = self.model(ids, attention_mask, token_type_ids)
        elif self.model_type == 'BigCloneModel':
            sample, labels = data
            sample = sample.to(self.device)
            labels = labels.to(self.device)
            predictions = self.model(sample)
        else:
            raise self._unsupported()
        return predictions, labels
class DataBase:
    """Select the dataloader factory for a named dataset.

    Supported ``type_data`` values are ``'bigclone'``, ``'imdb'`` and
    ``'snli'``.
    """

    def __init__(self, type_data: str):
        self.type_data = type_data

    def get_dataloader(self, df, batch_size, mode, num_workers=0):
        """Build a dataloader for ``df`` using the dataset-specific factory.

        Args:
            df: Dataset rows, forwarded unchanged to the factory (presumably a
                pandas DataFrame; confirm against the dataloaders package).
            batch_size: Batch size forwarded to the factory.
            mode: Forwarded to the factory, which defines the accepted values.
            num_workers: Worker count forwarded to the factory (default 0).

        Returns:
            Whatever the selected factory returns.

        Raises:
            ValueError: if ``type_data`` is not a supported dataset. Without
                this check the method would silently return ``None``.
        """
        if self.type_data == 'bigclone':
            factory = bigClone_get_dataloader
        elif self.type_data == 'imdb':
            factory = imdb_get_dataloader
        elif self.type_data == 'snli':
            factory = snli_get_dataloader
        else:
            raise ValueError(
                f"Unsupported type_data {self.type_data!r}; "
                "expected 'bigclone', 'imdb' or 'snli'"
            )
        # All three factories are called with the same keyword arguments.
        return factory(
            df=df,
            batch_size=batch_size,
            mode=mode,
            num_workers=num_workers,
        )