Skip to content

Commit d14c5dc

Browse files
refactor(models): optimize to_yaml method
1 parent 6430e9c commit d14c5dc

1 file changed

Lines changed: 6 additions & 13 deletions

File tree

models/strategy/travserse_strategy.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass, field
1+
from dataclasses import dataclass, field, fields
22

33
from models.strategy.base_strategy import BaseStrategy
44

@@ -23,17 +23,10 @@ class TraverseStrategy(BaseStrategy):
2323
isolated_node_strategy: str = "add" # "add" or "ignore"
2424
# 难度顺序 ["easy", "medium", "hard"], ["hard", "medium", "easy"], ["medium", "medium", "medium"]
2525
difficulty_order: list = field(default_factory=lambda: ["medium", "medium", "medium"])
26+
loss_strategy: str = "only_edge" # only_edge, both
2627

2728
def to_yaml(self):
28-
return {
29-
"traverse_strategy": {
30-
"expand_method": self.expand_method,
31-
"bidirectional": self.bidirectional,
32-
"max_extra_edges": self.max_extra_edges,
33-
"max_tokens": self.max_tokens,
34-
"max_depth": self.max_depth,
35-
"edge_sampling": self.edge_sampling,
36-
"isolated_node_strategy": self.isolated_node_strategy,
37-
"difficulty_order": self.difficulty_order
38-
}
39-
}
29+
strategy_dict = {}
30+
for f in fields(self):
31+
strategy_dict[f.name] = getattr(self, f.name)
32+
return {"traverse_strategy": strategy_dict}

0 commit comments

Comments
 (0)