Skip to content

Commit 2ef58f2

Browse files
committed
clib.conversion: Add type hints and improve docstrings for dataarray_to_matrix
1 parent 054d148 commit 2ef58f2

1 file changed

Lines changed: 32 additions & 32 deletions

File tree

pygmt/clib/conversion.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,40 @@
77
from collections.abc import Sequence
88

99
import numpy as np
10+
import xarray as xr
1011
from pygmt.exceptions import GMTInvalidInput
1112

1213

13-
def dataarray_to_matrix(grid):
14+
def dataarray_to_matrix(
15+
grid: xr.DataArray,
16+
) -> tuple[np.ndarray, list[float], list[float]]:
1417
"""
15-
Transform an xarray.DataArray into a data 2-D array and metadata.
18+
Transform an xarray.DataArray into a 2-D numpy array and metadata.
1619
17-
Use this to extract the underlying numpy array of data and the region and
18-
increment for the grid.
20+
Use this to extract the underlying numpy array of data and the region and increment
21+
for the grid.
1922
20-
Only allows grids with two dimensions and constant grid spacing (GMT
21-
doesn't allow variable grid spacing). If the latitude and/or longitude
22-
increments of the input grid are negative, the output matrix will be
23-
sorted by the DataArray coordinates to yield positive increments.
23+
Only allows grids with two dimensions and constant grid spacings (GMT doesn't allow
24+
variable grid spacings). If the latitude and/or longitude increments of the input
25+
grid are negative, the output matrix will be sorted by the DataArray coordinates to
26+
yield positive increments.
2427
25-
If the underlying data array is not C contiguous, for example if it's a
26-
slice of a larger grid, a copy will need to be generated.
28+
If the underlying data array is not C contiguous, for example if it's a slice of a
29+
larger grid, a copy will need to be generated.
2730
2831
Parameters
2932
----------
30-
grid : xarray.DataArray
31-
The input grid as a DataArray instance. Information is retrieved from
32-
the coordinate arrays, not from headers.
33+
grid
34+
The input grid as a DataArray instance. Information is retrieved from the
35+
coordinate arrays, not from headers.
3336
3437
Returns
3538
-------
36-
matrix : 2-D array
39+
matrix
3740
The 2-D array of data from the grid.
38-
region : list
41+
region
3942
The West, East, South, North boundaries of the grid.
40-
inc : list
43+
inc
4144
The grid spacing in East-West and North-South, respectively.
4245
4346
Raises
@@ -62,8 +65,8 @@ def dataarray_to_matrix(grid):
6265
(180, 360)
6366
>>> matrix.flags.c_contiguous
6467
True
65-
>>> # Using a slice of the grid, the matrix will be copied to guarantee
66-
>>> # that it's C-contiguous in memory. The increment should be unchanged.
68+
>>> # Using a slice of the grid, the matrix will be copied to guarantee that it's
69+
>>> # C-contiguous in memory. The increment should be unchanged.
6770
>>> matrix, region, inc = dataarray_to_matrix(grid[10:41, 30:101])
6871
>>> matrix.flags.c_contiguous
6972
True
@@ -73,7 +76,7 @@ def dataarray_to_matrix(grid):
7376
[-150.0, -79.0, -80.0, -49.0]
7477
>>> print(inc)
7578
[1.0, 1.0]
76-
>>> # but not if only taking every other grid point.
79+
>>> # The increment should change acoordingly if taking every other grid point.
7780
>>> matrix, region, inc = dataarray_to_matrix(grid[10:41:2, 30:101:2])
7881
>>> matrix.flags.c_contiguous
7982
True
@@ -85,21 +88,19 @@ def dataarray_to_matrix(grid):
8588
[2.0, 2.0]
8689
"""
8790
if len(grid.dims) != 2:
88-
raise GMTInvalidInput(
89-
f"Invalid number of grid dimensions '{len(grid.dims)}'. Must be 2."
90-
)
91+
msg = f"Invalid number of grid dimensions 'len({grid.dims})'. Must be 2."
92+
raise GMTInvalidInput(msg)
93+
9194
# Extract region and inc from the grid
92-
region = []
93-
inc = []
94-
# Reverse the dims because it is rows, columns ordered. In geographic
95-
# grids, this would be North-South, East-West. GMT's region and inc are
96-
# East-West, North-South.
95+
region, inc = [], []
96+
# Reverse the dims because it is rows, columns ordered. In geographic grids, this
97+
# would be North-South, East-West. GMT's region and inc are East-West, North-South.
9798
for dim in grid.dims[::-1]:
9899
coord = grid.coords[dim].to_numpy()
99-
coord_incs = coord[1:] - coord[0:-1]
100+
coord_incs = coord[1:] - coord[:-1]
100101
coord_inc = coord_incs[0]
101102
if not np.allclose(coord_incs, coord_inc):
102-
# calculate the increment if irregular spacing is found
103+
# Calculate the increment if irregular spacing is found.
103104
coord_inc = (coord[-1] - coord[0]) / (coord.size - 1)
104105
msg = (
105106
f"Grid may have irregular spacing in the '{dim}' dimension, "
@@ -108,9 +109,8 @@ def dataarray_to_matrix(grid):
108109
)
109110
warnings.warn(msg, category=RuntimeWarning, stacklevel=2)
110111
if coord_inc == 0:
111-
raise GMTInvalidInput(
112-
f"Grid has a zero increment in the '{dim}' dimension."
113-
)
112+
msg = f"Grid has a zero increment in the '{dim}' dimension."
113+
raise GMTInvalidInput(msg)
114114
region.extend(
115115
[
116116
coord.min() - coord_inc / 2 * grid.gmt.registration,

0 commit comments

Comments
 (0)