Skip to content

Commit 5581b79

Browse files
feat: add augment_from_xarray for EOProduct (#113)
Co-authored-by: Sylvain Brunato <sylvain.brunato@c-s.fr>
1 parent 4e890c4 commit 5581b79

File tree

5 files changed

+714
-1
lines changed

5 files changed

+714
-1
lines changed

eodag_cube/api/product/_product.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from eodag_cube.api.product._assets import AssetsDict
4747
from eodag_cube.types import XarrayDict
4848
from eodag_cube.utils.exceptions import DatasetCreationError
49+
from eodag_cube.utils.metadata import build_bands, build_cube_metadata, merge_bands
4950
from eodag_cube.utils.xarray import try_open_dataset
5051

5152
logger = logging.getLogger("eodag-cube.api.product")
@@ -348,3 +349,61 @@ def to_xarray(
348349
xd.sort()
349350

350351
return xd
352+
353+
def augment_from_xarray(
354+
self,
355+
roles: Iterable[str] = {"data", "data-mask"},
356+
) -> EOProduct:
357+
"""
358+
Annotate the product properties and assets with STAC metadata got by fetching its xarray representation.
359+
360+
:param roles: (optional) roles of assets that must be fetched
361+
:returns: updated EOProduct
362+
"""
363+
if not self.assets:
364+
try:
365+
xd = self.to_xarray(roles=roles)
366+
except Exception:
367+
return self
368+
369+
dimensions, variables, proj_info = build_cube_metadata(xd)
370+
self.properties["cube:dimensions"] = dimensions
371+
self.properties["cube:variables"] = variables
372+
self.properties["bands"] = build_bands(xd)
373+
for key, value in proj_info.items():
374+
self.properties[key] = value
375+
376+
else:
377+
# have roles been set in assets ?
378+
roles_exist = any("roles" in a for a in self.assets.values())
379+
for asset_key, asset in self.assets.items():
380+
try:
381+
asset_roles = asset.get("roles", [])
382+
if (
383+
roles
384+
and asset_roles
385+
and not any(r in asset_roles for r in roles)
386+
or not roles
387+
or not roles_exist
388+
):
389+
continue
390+
xd = self.to_xarray(asset_key=asset_key, roles=roles)
391+
except Exception:
392+
continue
393+
394+
dimensions, variables, proj_info = build_cube_metadata(xd)
395+
asset["cube:dimensions"] = dimensions
396+
asset["cube:variables"] = variables
397+
for key, value in proj_info.items():
398+
asset[key] = value
399+
400+
has_band_data = any("band_data" in ds.data_vars for ds in xd.values())
401+
402+
if has_band_data:
403+
generated_bands = build_bands(xd)
404+
if "bands" in asset:
405+
asset["bands"] = merge_bands(asset["bands"], generated_bands)
406+
else:
407+
asset["bands"] = generated_bands
408+
409+
return self

eodag_cube/utils/metadata.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2026, CS GROUP - France, http://www.c-s.fr
3+
#
4+
# This file is part of EODAG project
5+
# https://www.github.com/CS-SI/EODAG
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
"""Metadata-related utilities for eodag-cube."""
19+
20+
from math import isnan
21+
from typing import Any, Union
22+
23+
import numpy as np
24+
from xarray import DataArray, Dataset
25+
26+
from eodag_cube.types import XarrayDict
27+
28+
29+
def extract_projection_info(ds: Dataset) -> dict[str, Any]:
30+
"""
31+
Extract projection information from a :class:`xarray.Dataset`.
32+
33+
:param ds: :class:`xarray.Dataset` to extract projection information from
34+
:return: dictionary with projection information
35+
"""
36+
proj_info: dict[str, Any] = {}
37+
38+
epsg_code = 4326
39+
proj_bbox = None
40+
41+
if hasattr(ds, "rio") and ds.rio.crs is not None:
42+
epsg_code = ds.rio.crs.to_epsg() or 4326
43+
try:
44+
proj_bbox = list(ds.rio.bounds())
45+
except Exception:
46+
proj_bbox = None
47+
48+
proj_info["proj:code"] = f"EPSG:{epsg_code}"
49+
if proj_bbox is not None:
50+
proj_info["proj:bbox"] = proj_bbox
51+
proj_info["proj:shape"] = list(ds.sizes.values())
52+
return proj_info
53+
54+
55+
def _get_nodata_value(var: DataArray) -> Union[float, str, None]:
56+
"""
57+
Get nodata value from a variable's attributes or return a default value.
58+
59+
:param var: variable to get nodata value from
60+
:return: nodata value
61+
"""
62+
if "nodata" in var.attrs:
63+
value = var.attrs["nodata"]
64+
elif "_FillValue" in var.encoding:
65+
value = var.encoding["_FillValue"]
66+
elif "missing_value" in var.encoding:
67+
value = var.encoding["missing_value"]
68+
elif hasattr(var, "rio"):
69+
value = getattr(var.rio, "encoded_nodata", None)
70+
if value is None:
71+
value = getattr(var.rio, "nodata", None)
72+
else:
73+
return None
74+
75+
if value is None:
76+
return None
77+
78+
# handle NaN
79+
value = float(value)
80+
if isnan(value):
81+
return str(value)
82+
83+
return value
84+
85+
86+
def set_variables(ds: Dataset) -> dict[str, Any]:
87+
"""
88+
Set variables metadata from a :class:`xarray.Dataset`.
89+
90+
:param ds: :class:`xarray.Dataset` to extract variables metadata from
91+
:return: dictionary with variables metadata
92+
"""
93+
variables: dict[str, dict] = {}
94+
auxiliary_geo_vars: dict[str, str] = {
95+
"latitude": "Latitude",
96+
"longitude": "Longitude",
97+
}
98+
for var_name, var in ds.data_vars.items():
99+
variables[str(var_name)] = {
100+
"dimensions": list(var.dims),
101+
"type": "data",
102+
"data_type": str(var.dtype),
103+
}
104+
if desc := var.attrs.get("description"):
105+
variables[str(var_name)]["description"] = desc
106+
variables[str(var_name)]["nodata"] = _get_nodata_value(var)
107+
108+
for aux_name, desc in auxiliary_geo_vars.items():
109+
if aux_name in ds:
110+
var = ds[aux_name]
111+
112+
if aux_name in variables:
113+
continue
114+
if aux_name in ds.dims:
115+
continue
116+
117+
variables[aux_name] = {
118+
"dimensions": list(var.dims),
119+
"type": "auxiliary",
120+
"description": desc,
121+
"data_type": str(var.dtype),
122+
}
123+
variables[aux_name]["nodata"] = _get_nodata_value(var)
124+
125+
return variables
126+
127+
128+
def build_cube_metadata(ds_dict: XarrayDict) -> tuple[dict, dict, dict]:
129+
"""
130+
Build datacube and projection metadata from a dict of :class:`xarray.Dataset`.
131+
132+
:param ds_dict: input xarray dict
133+
:return: tuple of 3 dicts for cube dimensions, cube variables and projection info
134+
"""
135+
dimensions: dict[str, dict] = {}
136+
variables: dict[str, dict] = {}
137+
138+
for ds in ds_dict.values():
139+
proj_info: dict[str, Any] = extract_projection_info(ds)
140+
141+
# Dimensions
142+
for dim_name in ds.sizes.keys():
143+
dim_name_str = str(dim_name)
144+
145+
# Type
146+
dim_type = (
147+
"spatial"
148+
if dim_name_str in ("x", "y", "lon", "lat")
149+
else "temporal"
150+
if dim_name_str == "time"
151+
else "other"
152+
)
153+
154+
dim_entry: dict[str, Any] = {"type": dim_type}
155+
156+
if dim_type == "spatial":
157+
# Axis
158+
if dim_name_str in ("x", "lon"):
159+
dim_entry["axis"] = "x"
160+
elif dim_name_str in ("y", "lat"):
161+
dim_entry["axis"] = "y"
162+
elif dim_name_str == "z":
163+
dim_entry["axis"] = "z"
164+
165+
proj_code = proj_info.get("proj:code", "EPSG:4326")
166+
try:
167+
dim_entry["reference_system"] = int(proj_code.split(":")[-1])
168+
except ValueError:
169+
pass
170+
171+
if dim_name_str in ds.coords:
172+
values = ds[dim_name_str].values
173+
if values.ndim == 1:
174+
if values.size <= 10:
175+
dim_entry["values"] = values.tolist()
176+
else:
177+
dim_entry["extent"] = (
178+
[float(values.min()), float(values.max())]
179+
if np.issubdtype(values.dtype, np.number)
180+
else [str(values.min()), str(values.max())]
181+
)
182+
diffs = np.diff(values)
183+
if np.allclose(diffs, diffs[0]):
184+
dim_entry["step"] = (
185+
float(diffs[0]) if np.issubdtype(values.dtype, np.number) else str(diffs[0])
186+
)
187+
else:
188+
dim_entry["extent"] = [float(np.nanmin(values)), float(np.nanmax(values))]
189+
190+
dimensions[dim_name_str] = dim_entry
191+
192+
# Variables
193+
var_ds = set_variables(ds)
194+
variables.update(var_ds)
195+
196+
return dimensions, variables, proj_info
197+
198+
199+
def build_bands(xd: XarrayDict) -> list[dict]:
200+
"""
201+
Build STAC bands metadata from xarray datasets.
202+
203+
If names are not available, use generic band names.
204+
205+
:param xd: input xarray dict
206+
:return: list of bands metadata
207+
"""
208+
band_count = 0
209+
210+
for ds in xd.values():
211+
for var in ds.data_vars.values():
212+
for dim in var.dims:
213+
if str(dim).lower() in ("band", "bands"):
214+
band_count = ds.sizes[dim]
215+
break
216+
if band_count:
217+
break
218+
219+
if band_count:
220+
break
221+
222+
if band_count == 0:
223+
band_count = len(next(iter(xd.values())).data_vars)
224+
225+
return [{"name": f"band{i + 1}"} for i in range(band_count)]
226+
227+
228+
def merge_bands(existing_bands: list[dict], new_bands: list[dict]) -> list[dict]:
229+
"""
230+
Merge existing bands metadata with newly generated ones from xarray.
231+
232+
Existing bands metadata take precedence over generated ones.
233+
234+
:param existing_bands: existing bands metadata
235+
:param new_bands: newly generated bands metadata
236+
:return: merged bands metadata
237+
"""
238+
merged = []
239+
240+
for i, band in enumerate(existing_bands):
241+
band = dict(band)
242+
band.setdefault("name", f"band{i + 1}")
243+
merged.append(band)
244+
245+
for i in range(len(existing_bands), len(new_bands)):
246+
merged.append(new_bands[i])
247+
248+
return merged

0 commit comments

Comments
 (0)