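Question: graphsage hangs on the last batch when run with the Dataloader (batch_size=128); looking at the data, the last batch has only 1 node left. Setting the Dataloader workers to 1 also hangs and never finishes? #561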

@niushixiong

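Question: graphsage hangs on the last batch when run with the Dataloader (batch_size=128). Looking at the data, the last batch has only 1 node left. Setting the Dataloader workers to 1 also hangs, and the run never finishes.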

(1) Data:
num_nodes = data.graph.num_nodes
predict_nodes = range(0, data.graph.num_nodes)
labels = [0] * num_nodes
test_ds = ShardedDataset(predict_nodes, labels)

collate_fn = partial(batch_fn, graph=graph, samples=args.samples)

test_loader = Dataloader(
    test_ds,
    batch_size=args.batch_size,  # 128
    shuffle=False,
    num_workers=args.sample_workers,  # 5; later changed to 1, still hangs
    collate_fn=collate_fn)
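
To confirm that the final batch really contains a single node, the batch split can be checked without the Dataloader. This is a minimal sketch that assumes the nodes are batched in order with no shuffling, as in the configuration above; `batch_sizes` and the node count are hypothetical and only illustrate the arithmetic.

```python
def batch_sizes(num_items, batch_size):
    """Return the size of each batch when num_items are split in order."""
    full, remainder = divmod(num_items, batch_size)
    sizes = [batch_size] * full
    if remainder:
        # The leftover items form one final, shorter batch.
        sizes.append(remainder)
    return sizes


# Hypothetical node count, chosen so that exactly one node is left over.
sizes = batch_sizes(num_items=128 * 3 + 1, batch_size=128)
print(len(sizes), sizes[-1])
```

Running the same check with the real `data.graph.num_nodes` would show whether the hang coincides with a one-node batch or with the end of iteration in general.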

=================
The collate function it calls:
def batch_fn(batch_ex, graph, samples):
    """ batch_fn """
    batch_train_samples = []
    batch_train_labels = []
    for i, l in batch_ex:
        batch_train_samples.append(i)
        batch_train_labels.append(l)

    log.info("=====graphsage_sample============")
    subgraphs = graphsage_sample(graph, batch_train_samples, samples)
    log.info("=====graphsage_sample end============")
    subgraph, sample_index, node_index = subgraphs[0]

    node_label = np.array(batch_train_labels, dtype="int64").reshape([-1, 1])

    return [subgraph, sample_index, node_index, node_label]

=================
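
One way to narrow this down is to build the final batch by hand and pass it to the collate function in the main process, with no worker processes involved. The sketch below only shows the idea: `last_batch` and `toy_collate` are hypothetical stand-ins, and `toy_collate` does no graph sampling. In the real script the partial `collate_fn` defined above would be called on the same tail batch, which would show whether `graphsage_sample` itself blocks on a one-node batch.

```python
from functools import partial


def last_batch(dataset, batch_size):
    """Return the examples that would form the final, possibly short, batch."""
    n = len(dataset)
    start = ((n - 1) // batch_size) * batch_size if n else 0
    return [dataset[i] for i in range(start, n)]


def toy_collate(batch_ex, samples):
    """Stand-in for batch_fn: splits (node, label) pairs, no graph sampling."""
    nodes = [i for i, _ in batch_ex]
    labels = [l for _, l in batch_ex]
    return nodes, labels


# Hypothetical dataset of (node_id, label) pairs with one leftover node.
dataset = [(i, 0) for i in range(128 * 2 + 1)]
tail = last_batch(dataset, batch_size=128)

# Call the collate function directly, outside any worker process.
nodes, labels = partial(toy_collate, samples=[25, 10])(tail)
print(len(tail), nodes, labels)
```

If the direct call returns promptly, the sampling step is probably not the cause, and attention can shift to how the loader's workers shut down.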

(2) Model configuration:
model = GraphSage(
    input_size=data.feature.shape[-1],
    num_class=data.num_classes,
    hidden_size=args.hidden_size,
    num_layers=len(args.samples))
graphsage_bug.md