Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions launch_rn_babi.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,21 @@ def init_weights(m):

if args.epochs > 0:
print("Start training")
avg_train_losses, avg_train_accuracies, val_losses, val_accuracies = train(train_stories, validation_stories, args.epochs, lstm, rn, criterion, optimizer, args.no_save, device, result_folder, args.batch_size)
avg_train_losses, avg_train_accuracies, val_losses, val_accuracies = train(train_stories, validation_stories, args.epochs, lstm, rn, criterion, optimizer, args.no_save, device, result_folder, args.batch_size, dict_size)
print("End training!")

if not args.test_on_test:
test_stories = validation_stories

if args.test_jointly:
print("Testing jointly...")
avg_test_loss, avg_test_accuracy = test(test_stories, lstm, rn, criterion, device, args.batch_size)
avg_test_loss, avg_test_accuracy = test(test_stories, lstm, rn, criterion, device, args.batch_size, dict_size)

print("Test accuracy: ", avg_test_accuracy)
print("Test loss: ", avg_test_loss)
else:
print("Testing separately...")
avg_test_accuracy = test_separately(test_stories, lstm, rn, device, args.batch_size)
avg_test_accuracy = test_separately(test_stories, lstm, rn, device, args.batch_size, dict_size)
avg_test_loss = None
print("Test accuracy: ", avg_test_accuracy)

Expand Down
6 changes: 3 additions & 3 deletions launch_rrn_babi.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,21 @@ def init_weights(m):

if args.epochs > 0:
print("Start training")
avg_train_losses, avg_train_accuracies, val_losses, val_accuracies = train(train_stories, validation_stories, args.epochs, lstm, rrn, criterion, optimizer, args.batch_size, args.no_save, device, result_folder)
avg_train_losses, avg_train_accuracies, val_losses, val_accuracies = train(train_stories, validation_stories, args.epochs, lstm, rrn, criterion, optimizer, args.batch_size, args.no_save, device, result_folder, dict_size)
print("End training!")

if not args.test_on_test:
test_stories = validation_stories

if args.test_jointly:
print("Testing jointly...")
avg_test_loss, avg_test_accuracy = test(test_stories, lstm, rrn, criterion, device, args.batch_size)
avg_test_loss, avg_test_accuracy = test(test_stories, lstm, rrn, criterion, device, args.batch_size, dict_size)

print("Test accuracy: ", avg_test_accuracy)
print("Test loss: ", avg_test_loss)
else:
print("Testing separately...")
avg_test_accuracy = test_separately(test_stories, lstm, rrn, device, args.batch_size)
avg_test_accuracy = test_separately(test_stories, lstm, rrn, device, args.batch_size, dict_size)
avg_test_loss = None
print("Test accuracy: ", avg_test_accuracy)

Expand Down
6 changes: 3 additions & 3 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __getitem__(self, idx):

return (question, answer, facts, label, ordering)

def batchify(data_batch):
def batchify(data_batch, dict_size):
'''
Custom collate_fn for dataset

Expand Down Expand Up @@ -73,12 +73,12 @@ def batchify(data_batch):
#lengths_q = torch.tensor([el.size(0) for el in q_s]).long() # number of words

rows, columns = max([ el.size(0) for el in f_s] ), max([ el.size(1) for el in f_s] )
ff = torch.ones(len(f_s), rows, columns, dtype=torch.long)*157 # len(dictionary) as pad value
ff = torch.ones(len(f_s), rows, columns, dtype=torch.long)*dict_size # len(dictionary) as pad value
for i, t in enumerate(f_s):
r, c = t.size(0), t.size(1)
ff[i, :r, :c] = t

return ( pad_sequence(q_s, batch_first=True, padding_value=157), torch.stack(a_s, dim=0), ff.view(-1, ff.size(2)), torch.stack(l_s, dim=0), o_s)
return ( pad_sequence(q_s, batch_first=True, padding_value=dict_size), torch.stack(a_s, dim=0), ff.view(-1, ff.size(2)), torch.stack(l_s, dim=0), o_s)

def save_stories(stories, valid, name):
if valid:
Expand Down
14 changes: 7 additions & 7 deletions task/babi_task/rn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import defaultdict


def train(train_stories, validation_stories, epochs, lstm, rn, criterion, optimizer, no_save, device, result_folder, batch_size):
def train(train_stories, validation_stories, epochs, lstm, rn, criterion, optimizer, no_save, device, result_folder, batch_size, dict_size):

train_babi_dataset = BabiDataset(train_stories)
best_acc = 0.
Expand All @@ -20,7 +20,7 @@ def train(train_stories, validation_stories, epochs, lstm, rn, criterion, optimi
train_accuracies = []
train_losses = []

train_dataset = DataLoader(train_babi_dataset, batch_size=batch_size, shuffle=True, collate_fn=batchify, drop_last=True)
train_dataset = DataLoader(train_babi_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda b: batchify(data_batch=b, dict_size=dict_size), drop_last=True)

rn.train()
lstm.train()
Expand Down Expand Up @@ -61,7 +61,7 @@ def train(train_stories, validation_stories, epochs, lstm, rn, criterion, optimi
avg_train_losses.append(sum(train_losses)/len(train_losses))
avg_train_accuracies.append(sum(train_accuracies)/len(train_accuracies))

val_loss, val_accuracy = test(validation_stories,lstm,rn,criterion, device, batch_size)
val_loss, val_accuracy = test(validation_stories,lstm,rn,criterion, device, batch_size, dict_size)
val_accuracies.append(val_accuracy)
val_losses.append(val_loss)

Expand All @@ -87,7 +87,7 @@ def train(train_stories, validation_stories, epochs, lstm, rn, criterion, optimi
return avg_train_losses, avg_train_accuracies, val_losses, val_accuracies


def test(stories, lstm, rn, criterion, device, batch_size):
def test(stories, lstm, rn, criterion, device, batch_size, dict_size):

with torch.no_grad():

Expand All @@ -98,7 +98,7 @@ def test(stories, lstm, rn, criterion, device, batch_size):
lstm.eval()

test_babi_dataset = BabiDataset(stories)
test_dataset = DataLoader(test_babi_dataset, batch_size=batch_size, shuffle=False, collate_fn=batchify, drop_last=True)
test_dataset = DataLoader(test_babi_dataset, batch_size=batch_size, shuffle=False, collate_fn=lambda b: batchify(data_batch=b, dict_size=dict_size), drop_last=True)


for batch_id, (question_batch,answer_batch,facts_batch,_,_) in enumerate(test_dataset):
Expand Down Expand Up @@ -130,7 +130,7 @@ def test(stories, lstm, rn, criterion, device, batch_size):
return test_loss / float(len(test_dataset)), test_accuracy / float(len(test_dataset))


def test_separately(stories, lstm, rn, device, batch_size):
def test_separately(stories, lstm, rn, device, batch_size, dict_size):


with torch.no_grad():
Expand All @@ -141,7 +141,7 @@ def test_separately(stories, lstm, rn, device, batch_size):
lstm.eval()

test_babi_dataset = BabiDataset(stories)
test_dataset = DataLoader(test_babi_dataset, batch_size=batch_size, shuffle=False, collate_fn=batchify, drop_last=True)
test_dataset = DataLoader(test_babi_dataset, batch_size=batch_size, shuffle=False, collate_fn=lambda b: batchify(data_batch=b, dict_size=dict_size), drop_last=True)


for batch_id, (question_batch,answer_batch,facts_batch,task_label,_) in enumerate(test_dataset):
Expand Down
14 changes: 7 additions & 7 deletions task/babi_task/rrn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


REASONING_STEPS = 3
def train(train_stories, validation_stories, epochs, lstm, rrn, criterion, optimizer, batch_size, no_save, device, result_folder):
def train(train_stories, validation_stories, epochs, lstm, rrn, criterion, optimizer, batch_size, no_save, device, result_folder, dict_size):

train_babi_dataset = BabiDataset(train_stories)
best_acc = 0.
Expand All @@ -22,7 +22,7 @@ def train(train_stories, validation_stories, epochs, lstm, rrn, criterion, optim
train_accuracies = []
train_losses = []

train_dataset = DataLoader(train_babi_dataset, batch_size=batch_size, shuffle=True, collate_fn=batchify, drop_last=True)
train_dataset = DataLoader(train_babi_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda b: batchify(data_batch=b, dict_size=dict_size), drop_last=True)

rrn.train()
lstm.train()
Expand Down Expand Up @@ -77,7 +77,7 @@ def train(train_stories, validation_stories, epochs, lstm, rrn, criterion, optim
avg_train_accuracies.append(sum(train_accuracies)/len(train_accuracies))


val_loss, val_accuracy = test(validation_stories,lstm,rrn,criterion, device, batch_size)
val_loss, val_accuracy = test(validation_stories,lstm,rrn,criterion, device, batch_size, dict_size)
val_accuracies.append(val_accuracy)
val_losses.append(val_loss)

Expand All @@ -101,7 +101,7 @@ def train(train_stories, validation_stories, epochs, lstm, rrn, criterion, optim
return avg_train_losses, avg_train_accuracies, val_losses, val_accuracies


def test(stories, lstm, rrn, criterion, device, batch_size):
def test(stories, lstm, rrn, criterion, device, batch_size, dict_size):

lstm.eval()
rrn.eval()
Expand All @@ -112,7 +112,7 @@ def test(stories, lstm, rrn, criterion, device, batch_size):
test_accuracy = 0.
test_loss = 0.

test_dataset = DataLoader(test_babi_dataset, batch_size=batch_size, shuffle=False, collate_fn=batchify, drop_last=True)
test_dataset = DataLoader(test_babi_dataset, batch_size=batch_size, shuffle=False, collate_fn=lambda b: batchify(data_batch=b, dict_size=dict_size), drop_last=True)


for batch_id, (question_batch,answer_batch,facts_batch,_,_) in enumerate(test_dataset):
Expand Down Expand Up @@ -151,15 +151,15 @@ def test(stories, lstm, rrn, criterion, device, batch_size):
return test_loss / float(len(test_dataset)), test_accuracy / float(len(test_dataset))


def test_separately(stories, lstm, rrn, device, batch_size):
def test_separately(stories, lstm, rrn, device, batch_size, dict_size):
lstm.eval()
rrn.eval()

with torch.no_grad():
accuracies = defaultdict(list)

test_babi_dataset = BabiDataset(stories)
test_dataset = DataLoader(test_babi_dataset, batch_size=batch_size, shuffle=False, collate_fn=batchify, drop_last=True)
test_dataset = DataLoader(test_babi_dataset, batch_size=batch_size, shuffle=False, collate_fn=lambda b: batchify(data_batch=b, dict_size=dict_size), drop_last=True)


for batch_id, (question_batch,answer_batch,facts_batch,task_label,_) in enumerate(test_dataset):
Expand Down