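Problem: when running GraphSAGE with a Dataloader, the job hangs on the last batch (batch_size=128). I checked the data, and the last batch contains only 1 node. Setting the Dataloader's num_workers to 1 also hangs, and the run never finishes. What could cause this?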
(1) Data part:
num_nodes = data.graph.num_nodes
predict_nodes = range(0, data.graph.num_nodes)
labels = [0] * num_nodes
test_ds = ShardedDataset(predict_nodes, labels)
collate_fn = partial(batch_fn, graph=graph, samples=args.samples)
test_loader = Dataloader(
    test_ds,
    batch_size=args.batch_size,  # 128
    shuffle=False,
    num_workers=args.sample_workers,  # originally 5, later changed to 1; hangs either way
    collate_fn=collate_fn)
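To confirm that the last batch really holds a single node, the batch sizes can be computed directly. This is a minimal sketch, assuming the loader (with shuffle=False and ShardedDataset passing every node through unchanged, as in single-process prediction) emits consecutive slices of batch_size items and keeps the final partial batch. NUM_NODES and split_into_batches are hypothetical stand-ins, not PGL APIs; NUM_NODES is chosen so that NUM_NODES % 128 == 1, matching the report.

```python
# Sketch: how a non-shuffled loader would split node ids into batches.
# NUM_NODES is a hypothetical value picked so that the remainder is 1.
NUM_NODES = 2689  # hypothetical: 21 * 128 + 1
BATCH_SIZE = 128


def split_into_batches(items, batch_size):
    """Yield consecutive slices of `items`, keeping the short final slice."""
    items = list(items)
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


batches = list(split_into_batches(range(NUM_NODES), BATCH_SIZE))
print("number of batches:", len(batches))    # number of batches: 22
print("last batch size:", len(batches[-1]))  # last batch size: 1
print("remainder:", NUM_NODES % BATCH_SIZE)  # remainder: 1
```

In general the last batch has `num_nodes % batch_size` nodes whenever that remainder is non-zero, so any graph whose node count is one more than a multiple of 128 produces a 1-node final batch.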
=================
Collate function called by the loader:
def batch_fn(batch_ex, graph, samples):
    """ batch_fn """
    batch_train_samples = []
    batch_train_labels = []
    for i, l in batch_ex:
        batch_train_samples.append(i)
        batch_train_labels.append(l)
    log.info("=====graphsage_sample============")
    subgraphs = graphsage_sample(graph, batch_train_samples, samples)
    log.info("=====graphsage_sample end============")
    subgraph, sample_index, node_index = subgraphs[0]
    node_label = np.array(batch_train_labels, dtype="int64").reshape([-1, 1])
    return [subgraph, sample_index, node_index, node_label]
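One possibility worth ruling out is an exception raised inside batch_fn (for example in graphsage_sample on the 1-node batch) that never reaches the main process, so the loader simply waits. Whether this Dataloader propagates worker exceptions is an assumption here, not something confirmed. A minimal debugging sketch: call the collate function on every batch in the main process, bypassing the loader, and log any traceback. log_exceptions and run_collate_serially are hypothetical helpers, not PGL APIs.

```python
import functools
import logging
import traceback

# Configure logging once so the helpers below print to stderr.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("collate_debug")


def log_exceptions(fn):
    """Wrap a collate function so any exception is logged with its
    traceback (and the batch size) before being re-raised."""
    @functools.wraps(fn)
    def wrapper(batch_ex, *args, **kwargs):
        try:
            return fn(batch_ex, *args, **kwargs)
        except Exception:
            log.error("collate failed on batch of size %d:\n%s",
                      len(batch_ex), traceback.format_exc())
            raise
    return wrapper


def run_collate_serially(items, batch_size, collate):
    """Call `collate` on each consecutive batch in the current process,
    logging progress, and return the list of results."""
    items = list(items)
    results = []
    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        log.info("collating batch starting at %d (size %d)", start, len(batch))
        results.append(collate(batch))
    return results
```

With the setup above this would be called roughly as `run_collate_serially(zip(predict_nodes, labels), args.batch_size, partial(log_exceptions(batch_fn), graph=graph, samples=args.samples))`, assuming each dataset item is a `(node_id, label)` pair as batch_fn expects. If this serial run also stalls or fails on the last batch, the problem lies in sampling rather than in the loader's worker processes.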
=================
(2) Model configuration:
model = GraphSage(
    input_size=data.feature.shape[-1],
    num_class=data.num_classes,
    hidden_size=args.hidden_size,
    num_layers=len(args.samples))
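If the single-node batch itself turns out to be the trigger (not confirmed in the report), dropping the last batch is not an option for prediction, because every node needs an output. One possible workaround, sketched under that assumption, is to pad the final batch with other, distinct node ids and discard their predictions afterwards. pad_node_batch, strip_padding and the min_size=2 threshold are hypothetical; inside batch_fn the padding would be applied to batch_train_samples and batch_train_labels before graphsage_sample is called.

```python
import numpy as np


def pad_node_batch(node_ids, labels, filler_ids, min_size=2, filler_label=0):
    """Append distinct filler node ids (with a dummy label) until the batch
    has at least `min_size` entries. Returns (node_ids, labels, n_real),
    where n_real is the number of original entries."""
    node_ids, labels = list(node_ids), list(labels)
    n_real = len(node_ids)
    seen = set(node_ids)
    for fid in filler_ids:
        if len(node_ids) >= min_size:
            break
        if fid not in seen:  # avoid duplicate ids within one batch
            node_ids.append(fid)
            labels.append(filler_label)
            seen.add(fid)
    if len(node_ids) < min_size:
        raise ValueError("not enough distinct filler ids to reach min_size")
    return node_ids, labels, n_real


def strip_padding(predictions, n_real):
    """Keep only the rows of `predictions` that belong to real nodes."""
    return np.asarray(predictions)[:n_real]


# Example: a 1-node last batch (hypothetical node id 2688), padded with
# node 0 taken from the start of the node range.
nodes, labs, n_real = pad_node_batch([2688], [0], filler_ids=range(10))
print(nodes, labs, n_real)  # [2688, 0] [0, 0] 1
```

Padding with distinct ids rather than repeating the last node avoids feeding duplicate ids to the sampler, whose behavior on duplicates is not established here.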
graphsage_bug.md