-
Notifications
You must be signed in to change notification settings - Fork 558
Expand file tree
/
Copy pathpubmed.py
More file actions
124 lines (102 loc) · 5.02 KB
/
pubmed.py
File metadata and controls
124 lines (102 loc) · 5.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import tqdm
import numpy as np
import urllib
import tarfile
from tf_euler.python.dataset.pubmed_utils import *
from tf_euler.python.dataset.base_dataset import DataSet
current_dir = os.path.dirname(os.path.realpath(__file__))
pubmed_dir = os.path.join(current_dir, 'Pubmed-Diabetes')
class pubmed(DataSet):
    """Pubmed-Diabetes citation graph dataset for tf_euler.

    Downloads the LINQS Pubmed-Diabetes archive, extracts it, and converts
    the node/edge tab files into the Euler JSON graph format. Nodes carry a
    dense one-hot 'label' feature (3 classes) and a dense 500-dim TF-IDF
    'feature' vector (L1-normalized).
    """

    def __init__(self, data_dir=pubmed_dir, data_type='all'):
        super(pubmed, self).__init__(data_dir, data_type)
        # Where to fetch the raw archive from.
        self.source_url = 'https://linqs-data.soe.ucsc.edu/public/Pubmed-Diabetes.tgz'
        # Raw files expected inside the extracted archive.
        self.origin_files = ['data/Pubmed-Diabetes.NODE.paper.tab',
                             'data/Pubmed-Diabetes.DIRECTED.cites.tab']
        self.max_node_id = 19717
        self.train_node_type = ['train']
        self.train_edge_type = ['train']
        self.total_size = 19717
        # -1 means "all node types" downstream.
        self.all_node_type = -1
        self.all_edge_type = ['train', 'train_removed']
        # File listing the ids of test nodes, written during conversion.
        self.id_file = os.path.join(data_dir, 'pubmed_test.id')
        self.feature_idx = 'feature'
        self.feature_dim = 500
        self.label_idx = 'label'
        self.label_dim = 3
        self.num_classes = 3
        # Nodes with id >= test_start_num are treated as the test split.
        self.test_start_num = 18717

    def download_data(self, source_url, out_dir):
        """Download the Pubmed-Diabetes tgz and safely extract it.

        Args:
            source_url: URL of the .tgz archive.
            out_dir: directory the archive is downloaded into; extraction
                targets its parent so the archive's top-level
                'Pubmed-Diabetes/' folder lands next to it.
        """
        pubmed_tgz_dir = os.path.join(out_dir, 'Pubmed-Diabetes.tgz')
        # The tarball contains a 'Pubmed-Diabetes/' root; extract one level
        # up so the expected directory layout is produced.
        out_dir = os.path.join(out_dir, "..")
        DataSet.download_file(source_url, pubmed_tgz_dir)
        with tarfile.open(pubmed_tgz_dir) as pubmed_file:
            print('unzip data..')

            def is_within_directory(directory, target):
                # True iff `target` resolves to a path inside `directory`.
                abs_directory = os.path.abspath(directory)
                abs_target = os.path.abspath(target)
                prefix = os.path.commonprefix([abs_directory, abs_target])
                return prefix == abs_directory

            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
                # Guard against path-traversal entries (CVE-2007-4559):
                # refuse any member that would escape `path`.
                for member in tar.getmembers():
                    member_path = os.path.join(path, member.name)
                    if not is_within_directory(path, member_path):
                        raise Exception("Attempted Path Traversal in Tar File")
                tar.extractall(path, members, numeric_owner=numeric_owner)

            safe_extract(pubmed_file, out_dir)

    def convert2json(self, convert_dir, out_dir):
        """Convert the raw tab files into a single Euler JSON graph file.

        Also writes the test-node id list to ``self.id_file``, one id per
        line, for every node whose type is 'test'.

        Args:
            convert_dir: directory holding the extracted raw files.
            out_dir: path of the JSON file to write.
        """
        def add_node(node_id, node_type, weight, label, feature):
            # Build one node record with two dense features: the one-hot
            # label and the L1-normalized feature vector.
            feature = feature.astype(float)
            # Small epsilon avoids division by zero for all-zero vectors.
            feature /= np.sum(feature) + 1e-7
            return {
                "id": node_id,
                "type": node_type,
                "weight": weight,
                "features": [
                    {
                        "name": "label",
                        "type": "dense",
                        "value": label.astype(float).tolist(),
                    },
                    {
                        "name": "feature",
                        "type": "dense",
                        "value": feature.tolist(),
                    },
                ],
            }

        def add_edge(src, dst, edge_type, weight):
            # Build one (featureless) edge record.
            return {
                "src": src,
                "dst": dst,
                "type": edge_type,
                "weight": weight,
                "features": [],
            }

        node_ids, node_type, node_label, node_feature, edge_src, edge_dst, edge_type = \
            parse_graph_file(convert_dir, self.num_classes, 'Pubmed-Diabetes', self.feature_dim + 2, self.test_start_num)
        with open(out_dir, 'w') as out, open(self.id_file, 'w') as out_test:
            buf = {"nodes": [], "edges": []}
            for one_node, one_type, one_label, one_feature in zip(node_ids, node_type, node_label, node_feature):
                buf["nodes"].append(add_node(one_node, one_type, 1, one_label, one_feature))
                if one_type == "test":
                    out_test.write(str(one_node) + "\n")
            for one_src, one_dst, one_type in zip(edge_src, edge_dst, edge_type):
                buf["edges"].append(add_edge(one_src, one_dst, one_type, 1))
            out.write(json.dumps(buf))