diff --git a/mapshader/multifile.py b/mapshader/multifile.py index f88c0bc..986b6be 100644 --- a/mapshader/multifile.py +++ b/mapshader/multifile.py @@ -2,9 +2,7 @@ import geopandas as gpd from glob import glob import itertools -import numpy as np import os -from rasterio.enums import Resampling import rioxarray # noqa: F401 from rioxarray.merge import merge_arrays from shapely.geometry import Polygon @@ -12,6 +10,7 @@ import xarray as xr from .mercator import MercatorTileDefinition +from .overview import create_single_band_overview from .transforms import get_transform_by_name @@ -160,7 +159,6 @@ def _create_overviews(self, raster_overviews, transforms, force_recreate_overvie levels_and_resolutions = raster_overviews["args"]["levels"] # dict[int, int] tuple_keys = itertools.product(levels_and_resolutions.keys(), self._bands) self._overviews = dict.fromkeys(tuple_keys, None) - band_limits = dict.fromkeys(self._bands, [None, None]) for level, resolution in levels_and_resolutions.items(): if not force_recreate_overviews: @@ -188,65 +186,9 @@ def _create_overviews(self, raster_overviews, transforms, force_recreate_overvie print(f"Overview already exists {overview_filename}", flush=True) continue - self._create_single_band_overview( - overview_shape, overview_transform, overview_crs, band, overview_filename, - transforms, band_limits[band]) - - def _create_single_band_overview(self, overview_shape, overview_transform, overview_crs, band, - overview_filename, transforms, band_limits): - # Open a block of files at a time for writing to overview DataArray. - # Block size of one file initially. - # Each file needs transforms applied before it can be resampled/reprojected. - calc_limits = band_limits[0] is None or band_limits[1] is None - overview = None - for filename in self._grid.filename: - with xr.open_dataset(filename, chunks=dict(y=512, x=512)) as ds: - da = ds[band] - crs = self._get_crs(ds) - da.rio.set_crs(crs, inplace=True) - - da = self._apply_transforms(da, transforms) - - if calc_limits: - min_ = da.min().item() - max_ = da.max().item() - # Update limits in place. - band_limits[0] = min_ if band_limits[0] is None else min(band_limits[0], min_) - band_limits[1] = max_ if band_limits[1] is None else max(band_limits[1], max_) - - # Reproject to same grid as overview. - da = da.rio.reproject( - dst_crs=overview_crs, - shape=overview_shape, - transform=overview_transform, - # resampling=Resampling.average, # Prefer this, but gives missing pixels. - resampling=Resampling.bilinear, - nodata=np.nan) - - if overview is None: - overview = da - else: - # Elementwise maximum taking into account nans. - overview = xr.where( - np.logical_and(np.isfinite(overview), ~(overview > da)), - overview, - da) - - # Remove attrs that can cause problem serializing xarrays. - for key in ["grid_mapping"]: - if key in overview.attrs: - del overview.attrs[key] - - overview.attrs["limits"] = band_limits - - # Save overview as geotiff. - print(f"Writing overview {overview_filename}", flush=True) - try: - overview.rio.to_raster(overview_filename, tags=dict(hello="Ian")) - except: # noqa: E722 - if os.path.isfile(overview_filename): - os.remove(overview_filename) - raise + create_single_band_overview( + self._grid.filename, overview_shape, overview_transform, overview_crs, band, + overview_filename, transforms) def _get_crs(self, ds): crs = ds.rio.crs diff --git a/mapshader/overview.py b/mapshader/overview.py new file mode 100644 index 0000000..1914198 --- /dev/null +++ b/mapshader/overview.py @@ -0,0 +1,83 @@ +import dask.bag as db +import numpy as np +import os +from rasterio.enums import Resampling +import xarray as xr + +from .transforms import get_transform_by_name + + +# There is some code duplication here with MultiFileRaster which should be refactored. + +def _apply_transforms(da, transforms): + # This may be called with either a single xr.DataArray that is a single band of a single + # NetCDF file, or with the merged output from a number of files called from load_bounds(). + for trans in transforms: + transform_name = trans['name'] + func = get_transform_by_name(transform_name) + args = trans.get('args', {}) + + if 'overviews' in transform_name: + pass + else: + da = func(da, **args) + + return da + +def _get_crs(ds): + crs = ds.rio.crs + if not crs: + # Fallback for reading spatial_ref written in strange way. + crs = ds.spatial_ref.spatial_ref + return crs + +def _overview_combine(da1, da2): + # Elementwise maximum taking into account nans. + return xr.where(np.logical_and(np.isfinite(da1), ~(da1 > da2)), da1, da2) + +def _overview_map(filename, band, overview_crs, overview_shape, overview_transform, transforms): + with xr.open_dataset(filename, chunks=dict(y=512, x=512)) as ds: + da = ds[band] + da = da.squeeze() + crs = _get_crs(ds) + da.rio.set_crs(crs, inplace=True) + + da = _apply_transforms(da, transforms) + + # Reproject to same grid as overview. + da = da.rio.reproject( + dst_crs=overview_crs, + shape=overview_shape, + transform=overview_transform, + # resampling=Resampling.average, # Prefer this, but gives missing pixels. + resampling=Resampling.bilinear, + nodata=np.nan) + + return da + +def create_single_band_overview(filenames, overview_shape, overview_transform, overview_crs, band, + overview_filename, transforms): + bag = db.from_sequence(filenames) + + # Map from filename to reprojected xr.DataArray. + bag = bag.map(lambda filename: _overview_map( + filename, band, overview_crs, overview_shape, overview_transform, transforms)) + + # Combine xr.DataArrays using elementwise maximum taking into account nans. + bag = bag.fold(_overview_combine) + + overview = bag.compute() + + # Remove attrs that can cause problem serializing xarrays. + for key in ["grid_mapping"]: + if key in overview.attrs: + del overview.attrs[key] + + # Save overview as geotiff. + print(f"Writing overview {overview_filename}", flush=True) + try: + overview.rio.to_raster(overview_filename) + except: # noqa: E722 + if os.path.isfile(overview_filename): + os.remove(overview_filename) + raise