Skip to content

Commit 087939b

Browse files
committed
fix: Use validity in rolling fit
1 parent 0a629c7 commit 087939b

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

src/fpm_risk_model/rolling_factor_risk_model.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import json
2+
from datetime import datetime
23
from multiprocessing import cpu_count
34
from multiprocessing.pool import ThreadPool
45
from os import makedirs
56
from os.path import join
6-
from typing import Optional
7+
from typing import Dict, Optional
78

89
from pandas import DataFrame, Timestamp
910

1011
from .factor_risk_model import FactorRiskModel
12+
from .risk_model import RiskModel
1113
from .rolling_risk_model import RollingRiskModel
1214

1315

@@ -20,11 +22,37 @@ class RollingFactorRiskModel(RollingRiskModel):
2022
returns.
2123
"""
2224

23-
def __init__(self, **kwargs):
25+
def __init__(
26+
self,
27+
model: Optional[RiskModel] = None,
28+
window: Optional[int] = None,
29+
show_progress: Optional[bool] = False,
30+
values: Optional[Dict[datetime, RiskModel]] = None,
31+
):
2432
"""
2533
Constructor.
34+
35+
Parameters
36+
----------
37+
model: Optional[RiskModel]
38+
Risk model object to fit in rolling basis.
39+
40+
window: Optional[int]
41+
Number of rolling windows to use from the returns.
42+
Must be provided in fitting the model.
43+
44+
show_progress: Optional[bool]
45+
Indicate to show progress bar in running.
46+
47+
values: Optional[Dict[datetime, RiskModel]]
48+
Rolling risk models values.
2649
"""
27-
super().__init__(**kwargs)
50+
super().__init__(
51+
model=model,
52+
window=window,
53+
show_progress=show_progress,
54+
values=values,
55+
)
2856

2957
def transform(
3058
self,

src/fpm_risk_model/rolling_risk_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,19 @@ def fit(
129129
values = {}
130130

131131
T = X.shape[0]
132+
start_index = 0
133+
if validity is not None:
134+
try:
135+
start_index = X.index.get_loc(validity.index[0])
136+
except: # noqa: E722
137+
raise ValueError(
138+
f"Validity index (e.g. {validity.index[0]}) should "
139+
"exist in X index"
140+
)
141+
142+
if self._config.window is None:
143+
raise ValueError("The window must be provided in the config.")
144+
132145
iterator = range(0, T)
133146
if self._config.show_progress:
134147
from tqdm import tqdm

0 commit comments

Comments
 (0)