-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathexample_ranking_din.py
More file actions
276 lines (238 loc) · 9.25 KB
/
example_ranking_din.py
File metadata and controls
276 lines (238 loc) · 9.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
"""
深度兴趣网络(DIN)模型示例
文件说明:
本示例演示如何使用 NextRec 框架训练 DIN (Deep Interest Network) 排序模型。
DIN 模型通过注意力机制对用户历史行为序列进行建模,能够捕捉用户对不同商品的兴趣强度,
在点击率预估等排序任务中表现优异。
主要功能:
- 数据加载与预处理
- 特征定义(稠密特征、稀疏特征、序列特征)
- 行为序列与候选物品的注意力建模
- DIN 模型构建与训练
- 模型评估与预测
使用方法:
直接运行此脚本:
python tutorials/example_ranking_din.py
测试数据格式:
- user_id: 用户ID
- item_id: 候选物品ID
- label: 标签 (1表示点击, 0表示未点击)
- dense_*: 稠密特征
- sparse_*: 稀疏特征
- sequence_*: 序列特征(用户历史行为序列,字符串格式的列表)
模型架构:
使用 DIN 模型:
- 注意力机制: 计算历史行为序列中每个物品与候选物品的相关性
- 加权池化: 根据注意力权重对历史行为进行加权求和
- MLP 预测: 将特征和加权后的行为表示输入 MLP 进行点击率预估
输出:
- 训练好的模型
- 预测结果
- 评估指标(AUC、GAUC、LogLoss)
作者: NextRec Team
创建日期: 2026
最后更新: 2026-01-28
"""
import pandas as pd
from sklearn.model_selection import train_test_split
from nextrec.models.ranking.din import DIN
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
from nextrec.data.preprocessor import DataProcessor
# ==============================================================================
# 1. 数据加载和预处理
# ==============================================================================
# 加载排序任务数据集
df = pd.read_csv("dataset/ranking_task.csv")
# 将序列特征从字符串格式转换为列表格式
# 数据集中序列特征以字符串形式存储,需要使用 eval 转换为 Python 列表
for col in df.columns:
if "sequence" in col:
df[col] = df[col].apply(lambda x: eval(x) if isinstance(x, str) else x)
# 同一特征做多种预处理变换
# 变换后会新增列名: {column}_{preprocess_method}
processor = DataProcessor()
# dense_0 同时做 minmax 和 robust
processor.add_numeric_feature("dense_0", scaler="minmax")
processor.add_numeric_feature("dense_0", scaler="robust")
# dense_1 同时做 standard 和 log
processor.add_numeric_feature("dense_1", scaler="standard")
processor.add_numeric_feature("dense_1", scaler="log")
# sparse_0 同时做 label 和 hash
processor.add_sparse_feature("sparse_0", encode_method="label")
processor.add_sparse_feature("sparse_0", encode_method="hash", hash_size=5000)
# sparse_1 同时做 label 和 hash
processor.add_sparse_feature("sparse_1", encode_method="label")
processor.add_sparse_feature("sparse_1", encode_method="hash", hash_size=5000)
processor.fit(df)
df = processor.transform(df, return_dict=False).to_pandas()
vocab_sizes = processor.get_vocab_sizes()
# ==============================================================================
# 2. 数据集划分
# ==============================================================================
# 划分训练集和验证集(80% 训练, 20% 验证)
train_df, valid_df = train_test_split(df, test_size=0.2, random_state=2024)
# ==============================================================================
# 3. 特征定义
# ==============================================================================
# 定义稠密特征:
# 原始 dense_0~dense_7 + 多变换新增列
dense_cols = [f"dense_{i}" for i in range(8)] + [
"dense_0_minmax",
"dense_0_robust",
"dense_1_standard",
"dense_1_log",
]
dense_features = [DenseFeature(name=col, input_dim=1) for col in dense_cols]
# 定义稀疏特征
# user_id 和 item_id 使用较大的 embedding 维度(32)
sparse_features = [
SparseFeature(
name="user_id",
embedding_name="user_emb", # embedding 名称,用于权重共享
vocab_size=int(df["user_id"].max() + 1),
embedding_dim=32,
),
SparseFeature(
name="item_id",
embedding_name="item_emb", # 与序列特征共享 embedding
vocab_size=int(df["item_id"].max() + 1),
embedding_dim=32,
),
]
# 添加其他稀疏特征(10个)
# 注意: sparse_0、sparse_1 已做多重预处理并使用新增列,
# 这里仅保留未做处理的原始稀疏列.
sparse_features.extend(
[
SparseFeature(
name=f"sparse_{i}",
embedding_name=f"sparse_{i}_emb",
vocab_size=int(df[f"sparse_{i}"].max() + 1),
embedding_dim=32,
)
for i in range(2, 10)
]
)
# 多变换后的稀疏特征
sparse_features.extend(
[
SparseFeature(
name="sparse_0_label",
embedding_name="sparse_0_label_emb",
vocab_size=vocab_sizes["sparse_0_label"],
embedding_dim=32,
),
SparseFeature(
name="sparse_0_hash",
embedding_name="sparse_0_hash_emb",
vocab_size=vocab_sizes["sparse_0_hash"],
embedding_dim=32,
),
SparseFeature(
name="sparse_1_label",
embedding_name="sparse_1_label_emb",
vocab_size=vocab_sizes["sparse_1_label"],
embedding_dim=32,
),
SparseFeature(
name="sparse_1_hash",
embedding_name="sparse_1_hash_emb",
vocab_size=vocab_sizes["sparse_1_hash"],
embedding_dim=32,
),
]
)
# 定义序列特征
# sequence_0: 用户历史浏览物品序列,与 item_id 共享 embedding (item_emb)
sequence_features = [
SequenceFeature(
name="sequence_0",
vocab_size=int(df["sequence_0"].apply(lambda x: max(x)).max() + 1),
embedding_dim=32,
padding_idx=0, # 填充索引
embedding_name="item_emb", # 与 item_id 共享 embedding
),
SequenceFeature(
name="sequence_1",
vocab_size=int(df["sequence_1"].apply(lambda x: max(x)).max() + 1),
embedding_dim=16,
padding_idx=0,
embedding_name="sequence_1_emb",
),
]
# ==============================================================================
# 4. 模型构建
# ==============================================================================
# 定义 MLP 参数(深度神经网络部分)
mlp_params = {
"hidden_dims": [256, 128, 64], # 隐藏层维度
"activation": "relu", # 激活函数
"dropout": 0.3, # Dropout 比例
}
# 创建 DIN 模型
model = DIN(
dense_features=dense_features,
sparse_features=sparse_features,
sequence_features=sequence_features,
behavior_feature_name="sequence_0", # 行为序列特征名称(用户历史浏览物品)
candidate_feature_name="item_id", # 候选物品特征名称(当前物品)
mlp_params=mlp_params, # MLP 参数
attention_mlp_params={ # 注意力网络参数
"hidden_dims": [80, 40], # 注意力 MLP 的隐藏层维度
"activation": "dice", # 使用 DICE 激活函数
"dropout": 0.2,
},
attention_use_softmax=True, # 是否使用 softmax 归一化注意力权重
target=["label"], # 目标列名
device="cpu",
session_id="din_tutorial",
)
# ==============================================================================
# 5. 模型编译
# ==============================================================================
# 编译模型:配置优化器、学习率调度器和损失函数
model.compile(
optimizer="adam",
optimizer_params={"lr": 1e-3, "weight_decay": 1e-5}, # Adam 优化器参数
scheduler="step", # 使用 StepLR 学习率调度器
scheduler_params={"step_size": 3, "gamma": 0.5}, # 每3轮学习率衰减为原来的0.5倍
loss="focal", # 使用 Focal Loss 缓解类别不平衡
loss_params={"gamma": 2.0, "alpha": 0.25}, # Focal Loss 参数
)
# ==============================================================================
# 6. 模型训练
# ==============================================================================
model.fit(
train_data=train_df,
valid_data=valid_df,
metrics=["auc", "gauc", "logloss"], # 评估指标: AUC、GAUC(分组AUC)、对数损失
epochs=1, # 训练轮数
batch_size=512, # 批次大小
shuffle=True, # 是否打乱训练数据
group_id="user_id", # 指定分组列,用于计算 GAUC
)
print("Training Complete!")
# ==============================================================================
# 7. 模型预测
# ==============================================================================
print("Prediction")
# 对验证集进行预测
predictions = model.predict(valid_df, batch_size=512, return_dataframe=True)
print(f"Prediction shape: {predictions.shape}")
print(f"Prediction sample: {predictions[:10]}")
# ==============================================================================
# 8. 模型评估
# ==============================================================================
# 对验证集进行评估
metrics = model.evaluate(
valid_df,
metrics=["auc", "gauc", "logloss"],
batch_size=512,
group_id="user_id",
)
# 打印评估指标
for name, value in metrics.items():
print(f"{name}: {value:.6f}")
print("")
print("DIN Example Complete!")
print("")