Skip to content

Partition

This module provides partitioning strategies for raster data.

It includes functions for block-based, tile-based, window-based, and core-halo iteration, as well as a TileStitcher class for reassembling raster tiles into a single raster file on disk.

TileStitcher

Reassembles Raster tiles into a single file on disk.

Acts as a context manager to ensure the output file is closed properly. Each tile is validated against its target window before writing, so that tiles land exactly where they are intended.

Source code in src/phytospatial/raster/partition.py
class TileStitcher:
    """
    Reassembles Raster tiles into a single file on disk.

    Acts as a context manager to ensure the output file is closed properly.
    Before each write the tile is checked against its target window: the
    tile's dimensions must match the window, and the window must lie
    entirely inside the output raster, so tiles are written exactly where
    they are intended.
    """

    def __init__(
        self,
        output_path: Union[str, Path],
        profile: Dict[str, Any],
        **profile_overrides: Any
        ):
        """
        Open a new raster file for writing.

        Args:
            output_path (Union[str, Path]): Destination path. Missing parent
                directories are created.
            profile (Dict[str, Any]): Rasterio profile (metadata). It is copied,
                so the caller's mapping is not modified.
            **profile_overrides (Any): Changes to the profile (dtype, compression, etc.).
                These take precedence over the keys in ``profile``.

        Raises:
            IOError: If the file cannot be created/opened.
        """
        self.output_path = Path(output_path)
        self.profile = profile.copy()
        self.profile.update(profile_overrides)

        self.output_path.parent.mkdir(parents=True, exist_ok=True)

        # Initialise state before opening; _dst stays None until open succeeds.
        self._dst = None
        self._tiles_written = 0

        try:
            self._dst = rasterio.open(self.output_path, 'w', **self.profile)
        except Exception as e:
            raise IOError(f"Failed to initialize stitcher at {self.output_path}: {e}") from e

    def add_tile(
        self,
        window: Window,
        tile: Raster,
        indexes: Optional[List[int]] = None
        ):
        """
        Write a tile to the output file.

        Args:
            window (Window): The specific window in the output file where this data goes.
            tile (Raster): The Raster object containing the data.
            indexes (Optional[List[int]]): Specific bands to write to (passed
                through to the dataset's ``write``).

        Raises:
            RuntimeError: If the stitcher is closed.
            ValueError: If the tile shape does not match the window shape, or the
                window does not lie entirely inside the output raster.
            IOError: If the underlying write fails.
        """
        dst = self._dst
        if dst is None:
            raise RuntimeError("Attempted to write to a closed TileStitcher.")

        if (tile.height != window.height) or (tile.width != window.width):
            raise ValueError(
                f"Dimension Mismatch: Window is {window.width}x{window.height}, "
                f"but Tile is {tile.width}x{tile.height}."
            )

        # Reject windows that extend past any edge of the output grid up front,
        # with an explicit message, rather than deferring to the write call.
        col_end = window.col_off + window.width
        row_end = window.row_off + window.height
        if (
            window.col_off < 0
            or window.row_off < 0
            or col_end > dst.width
            or row_end > dst.height
        ):
            raise ValueError(
                f"Out of Bounds: Window {window} does not fit inside the "
                f"{dst.width}x{dst.height} output raster."
            )

        try:
            dst.write(tile.data, window=window, indexes=indexes)
        except Exception as e:
            raise IOError(f"Failed to write tile to {window}: {e}") from e
        self._tiles_written += 1

    def finalize(self):
        """
        Flush and close the file.

        Safe to call more than once; calls after the first are no-ops. The
        stitcher is marked closed before ``close()`` is attempted, so if closing
        fails the error propagates but later writes are still rejected.
        """
        dst = self._dst
        if dst is None:
            return
        self._dst = None
        dst.close()
        log.debug(f"Stitcher closed. Total tiles written: {self._tiles_written}")

    def __enter__(self):
        """Return the stitcher itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Always finalize the output; exceptions from the block are not suppressed."""
        self.finalize()

__init__(output_path, profile, **profile_overrides)

Open a new raster file for writing.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `output_path` | `Union[str, Path]` | Destination path. | *required* |
| `profile` | `Dict[str, Any]` | Rasterio profile (metadata). | *required* |
| `**profile_overrides` | `Any` | Changes to the profile (dtype, compression, etc.). | `{}` |

Raises:

| Type | Description |
| --- | --- |
| `IOError` | If the file cannot be created/opened. |

Source code in src/phytospatial/raster/partition.py
def __init__(
    self,
    output_path: Union[str, Path],
    profile: Dict[str, Any],
    **profile_overrides: Any
    ):
    """
    Create the destination raster and keep it open for writing.

    Args:
        output_path (Union[str, Path]): Where the stitched raster is written.
            Missing parent directories are created first.
        profile (Dict[str, Any]): Rasterio profile (metadata). A copy is taken,
            so the caller's object is left untouched.
        **profile_overrides (Any): Keys that replace or extend the copied
            profile (dtype, compression, etc.).

    Raises:
        IOError: If the file cannot be created/opened.
    """
    self.output_path = destination = Path(output_path)
    self.profile = merged_profile = profile.copy()
    merged_profile.update(profile_overrides)

    destination.parent.mkdir(parents=True, exist_ok=True)

    # Start in a "not yet open" state; the handle is stored only once opening succeeds.
    self._dst = None
    self._tiles_written = 0

    try:
        handle = rasterio.open(destination, 'w', **merged_profile)
    except Exception as e:
        raise IOError(f"Failed to initialize stitcher at {self.output_path}: {e}") from e
    self._dst = handle

add_tile(window, tile, indexes=None)

Write a tile to the output file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `window` | `Window` | The specific window in the output file where this data goes. | *required* |
| `tile` | `Raster` | The Raster object containing the data. | *required* |
| `indexes` | `Optional[List[int]]` | Specific bands to write to. | `None` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the tile shape does not match the window shape. |
| `RuntimeError` | If the stitcher is closed. |

Source code in src/phytospatial/raster/partition.py
def add_tile(
    self,
    window: Window,
    tile: Raster,
    indexes: Optional[List[int]] = None
    ):
    """
    Write a tile to the output file.

    Args:
        window (Window): The specific window in the output file where this data goes.
        tile (Raster): The Raster object containing the data.
        indexes (Optional[List[int]]): Specific bands to write to (passed
            through to the dataset's ``write``).

    Raises:
        RuntimeError: If the stitcher is closed.
        ValueError: If the tile shape does not match the window shape, or the
            window does not lie entirely inside the output raster.
        IOError: If the underlying write fails.
    """
    dst = self._dst
    if dst is None:
        raise RuntimeError("Attempted to write to a closed TileStitcher.")

    if (tile.height != window.height) or (tile.width != window.width):
        raise ValueError(
            f"Dimension Mismatch: Window is {window.width}x{window.height}, "
            f"but Tile is {tile.width}x{tile.height}."
        )

    # Reject windows that extend past any edge of the output grid up front,
    # with an explicit message, rather than deferring to the write call.
    col_end = window.col_off + window.width
    row_end = window.row_off + window.height
    if (
        window.col_off < 0
        or window.row_off < 0
        or col_end > dst.width
        or row_end > dst.height
    ):
        raise ValueError(
            f"Out of Bounds: Window {window} does not fit inside the "
            f"{dst.width}x{dst.height} output raster."
        )

    try:
        dst.write(tile.data, window=window, indexes=indexes)
    except Exception as e:
        raise IOError(f"Failed to write tile to {window}: {e}") from e
    self._tiles_written += 1

finalize()

Flush and close the file.

Source code in src/phytospatial/raster/partition.py
def finalize(self):
    """
    Flush and close the file.

    Safe to call more than once; calls after the first are no-ops. The
    stitcher is marked closed before ``close()`` is attempted, so if closing
    fails the error propagates but later writes are still rejected.
    """
    dst = self._dst
    if dst is None:
        return
    self._dst = None
    dst.close()
    log.debug(f"Stitcher closed. Total tiles written: {self._tiles_written}")

iter_blocks(source, bands=None)

Stream data using the file's native internal block structure.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `source` | `Union[str, Path, DatasetReader]` | An open rasterio.DatasetReader, or a path to the raster file. | *required* |
| `bands` | `Optional[Union[int, List[int]]]` | Specific band(s) to load (None=all, int=single, list=subset). | `None` |

Yields:

| Type | Description |
| --- | --- |
| `Tuple[Window, Raster]` | A window and the corresponding Raster object. |

Source code in src/phytospatial/raster/partition.py
def iter_blocks(
    source: Union[str, Path, rasterio.DatasetReader],
    bands: Optional[Union[int, List[int]]] = None
    ) -> Iterator[Tuple[Window, Raster]]:
    """
    Stream data using the file's native internal block structure.

    One item is produced per block reported by ``block_windows`` for band 1.

    Args:
        source (Union[str, Path, rasterio.DatasetReader]): An open rasterio.DatasetReader,
            or a path to the raster file. A path is opened and closed here; an
            already-open dataset is used as-is and left open for the caller.
        bands (Optional[Union[int, List[int]]]): Specific band(s) to load (None=all, int=single, list=subset).

    Yields:
        Tuple[Window, Raster]: A window and corresponding Raster object.

    Raises:
        FileNotFoundError: If ``source`` is a path whose resolved file does not exist.
        IOError: If rasterio raises ``RasterioIOError`` while iterating a path source.
    """
    def _read_blocks(dataset: rasterio.DatasetReader) -> Iterator[Tuple[Window, Raster]]:
        band_idx = extract_band_indices(dataset, bands)
        for _block_id, block_window in dataset.block_windows(1):
            pixels = dataset.read(indexes=band_idx, window=block_window)
            block_transform = dataset.window_transform(block_window)
            names = extract_band_names(dataset, band_idx)
            tile = Raster(
                data=pixels,
                transform=block_transform,
                crs=dataset.crs,
                nodata=dataset.nodata,
                band_names=names
            )
            yield block_window, tile

    if not isinstance(source, (str, Path)):
        # Already-open dataset: the caller owns it, so iterate without a context manager.
        yield from _read_blocks(source)
        return

    path = resolve_envi_path(Path(source))
    if not path.exists():
        raise FileNotFoundError(f"Source file not found: {path}")

    try:
        with rasterio.open(path) as src:
            yield from _read_blocks(src)
    except rasterio.RasterioIOError as e:
        raise IOError(f"Block iteration failed for {path}: {e}") from e

iter_core_halo(source, tile_mode='auto', tile_size=1024, overlap=64)

Streams spatial data using a Core-Halo architecture. Provides overlapping read buffers to prevent boundary truncation, while supplying a strict core bounding box for geometry deduplication.

Source code in src/phytospatial/raster/partition.py
def iter_core_halo(
    source: Union[str, Path, Raster],
    tile_mode: str = "auto",
    tile_size: int = 1024,
    overlap: int = 64
    ) -> Iterator[Tuple[np.ndarray, rasterio.Affine, Optional[box], Optional[box]]]:
    """
    Streams spatial data using a Core-Halo architecture.

    Provides overlapping read buffers to prevent boundary truncation, while supplying
    a strict core bounding box for geometry deduplication. Only the first band is read.

    Args:
        source (Union[str, Path, Raster]): An in-memory Raster, or a path to a raster
            file. Paths are passed through ``ensure_tiled_raster`` and then
            ``determine_strategy`` to choose the processing mode.
        tile_mode (str): Requested processing mode, forwarded to ``determine_strategy``
            as ``user_mode``.
        tile_size (int): Edge length in pixels of each core tile (tiles on the right and
            bottom edges are clipped to the raster). Also forwarded as ``block_size``
            to ``ensure_tiled_raster``.
        overlap (int): Halo width in pixels added on every side of the core, clipped at
            the raster edges.

    Yields:
        Tuple[np.ndarray, Affine, Optional[box], Optional[box]]: ``(data, transform,
        core_box, read_box)``. In in-memory mode a single item covering the whole
        raster is yielded with both boxes set to ``None``. Otherwise one item is
        yielded per core tile: ``data`` is band 1 over the halo-padded read window,
        ``transform`` is that window's affine transform, ``core_box`` is the core
        tile's footprint and ``read_box`` is the read window's footprint.

    Raises:
        ValueError: If ``tile_size`` is not positive or ``overlap`` is negative
            (raised when iteration starts, since this is a generator).
    """
    # Validate before touching the filesystem: a non-positive tile_size makes the
    # grid loops empty or invalid, and a negative overlap would shrink the read
    # window so that it no longer contains the core.
    if tile_size <= 0:
        raise ValueError(f"tile_size ({tile_size}) must be a positive integer")
    if overlap < 0:
        raise ValueError(f"Overlap ({overlap}) must be non-negative")

    if isinstance(source, Raster):
        yield source.data[0], source.transform, None, None
        return

    path = ensure_tiled_raster(source, block_size=tile_size)
    report = determine_strategy(path, user_mode=tile_mode)

    if report.mode == ProcessingMode.IN_MEMORY:
        full_raster = load(path)
        yield full_raster.data[0], full_raster.transform, None, None
        return

    with rasterio.open(path) as src:
        for row_off in range(0, src.height, tile_size):
            for col_off in range(0, src.width, tile_size):
                core_window = Window(
                    col_off,
                    row_off,
                    min(tile_size, src.width - col_off),
                    min(tile_size, src.height - row_off),
                )

                # Pad the core by `overlap` on each side, clipped to the raster extent.
                read_col_off = max(0, col_off - overlap)
                read_row_off = max(0, row_off - overlap)
                read_width = min(
                    src.width - read_col_off,
                    core_window.width + overlap + (col_off - read_col_off),
                )
                read_height = min(
                    src.height - read_row_off,
                    core_window.height + overlap + (row_off - read_row_off),
                )
                read_window = Window(read_col_off, read_row_off, read_width, read_height)

                data = src.read(1, window=read_window)
                transform = src.window_transform(read_window)

                # Map both core corners to world coordinates; min/max keeps the box
                # well-formed even when the transform flips an axis.
                core_transform = src.window_transform(core_window)
                x0, y0 = core_transform * (0, 0)
                x1, y1 = core_transform * (core_window.width, core_window.height)
                core_box = box(min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1))
                read_box = box(*src.window_bounds(read_window))

                yield data, transform, core_box, read_box

iter_tiles(source, tile_size=512, overlap=0, bands=None)

Stream data using a virtual grid of fixed-size tiles.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `source` | `Union[str, Path, DatasetReader]` | An open rasterio.DatasetReader, or a path to the raster file. | *required* |
| `tile_size` | `Union[int, Tuple[int, int]]` | Dimensions (width, height) or single int for square tiles. | `512` |
| `overlap` | `int` | Pixels of overlap between tiles. | `0` |
| `bands` | `Optional[Union[int, List[int]]]` | Specific band(s) to load (None=all, int=single, list=subset). | `None` |

Yields:

| Type | Description |
| --- | --- |
| `Tuple[Window, Raster]` | A window and the corresponding Raster object. |

Source code in src/phytospatial/raster/partition.py
def iter_tiles(
    source: Union[str, Path, rasterio.DatasetReader],
    tile_size: Union[int, Tuple[int, int]] = 512,
    overlap: int = 0,
    bands: Optional[Union[int, List[int]]] = None
    ) -> Iterator[Tuple[Window, Raster]]:
    """
    Stream data using a virtual grid of fixed-size tiles.

    Tiles start every ``tile - overlap`` pixels along each axis. Tiles on the right
    and bottom edges are clipped to the raster, so they may be smaller than
    ``tile_size``.

    Args:
        source (Union[str, Path, rasterio.DatasetReader]): An open rasterio.DatasetReader,
            or a path to the raster file. A path is opened and closed here; an
            already-open dataset is used as-is and left open for the caller.
        tile_size (Union[int, Tuple[int, int]]): Dimensions (width, height) or single int for square tiles.
        overlap (int): Pixels of overlap between tiles.
        bands (Optional[Union[int, List[int]]]): Specific band(s) to load (None=all, int=single, list=subset).

    Yields:
        Tuple[Window, Raster]: A window and corresponding Raster object.

    Raises:
        ValueError: If a tile dimension is not positive, ``overlap`` is negative, or
            ``overlap`` is not smaller than the smaller tile dimension (raised when
            iteration starts, since this is a generator).
        FileNotFoundError: If ``source`` is a path whose resolved file does not exist.
        IOError: If rasterio raises ``RasterioIOError`` while iterating a path source.
    """
    if isinstance(tile_size, int):
        t_width, t_height = tile_size, tile_size
    else:
        t_width, t_height = tile_size

    if t_width <= 0 or t_height <= 0:
        raise ValueError(f"Tile dimensions must be positive, got {t_width}x{t_height}")
    # A negative overlap would make the step larger than the tile and silently
    # skip pixels between neighbouring tiles.
    if overlap < 0:
        raise ValueError(f"Overlap ({overlap}) must be non-negative")
    if overlap >= min(t_width, t_height):
        raise ValueError(f"Overlap ({overlap}) must be smaller than tile dimensions")

    step_w = t_width - overlap
    step_h = t_height - overlap

    def _generator(src: rasterio.DatasetReader) -> Iterator[Tuple[Window, Raster]]:
        indices = extract_band_indices(src, bands)

        for row_off in range(0, src.height, step_h):
            for col_off in range(0, src.width, step_w):
                # Clip edge tiles to the raster extent.
                width = min(t_width, src.width - col_off)
                height = min(t_height, src.height - row_off)

                window = Window(col_off, row_off, width, height)

                data = src.read(indexes=indices, window=window)
                tile_transform = src.window_transform(window)
                band_names = extract_band_names(src, indices)

                yield window, Raster(
                    data=data,
                    transform=tile_transform,
                    crs=src.crs,
                    nodata=src.nodata,
                    band_names=band_names
                )

    if isinstance(source, (str, Path)):
        path = resolve_envi_path(Path(source))
        if not path.exists():
            raise FileNotFoundError(f"Source file not found: {path}")

        try:
            with rasterio.open(path) as src:
                yield from _generator(src)
        except rasterio.RasterioIOError as e:
            raise IOError(f"Tile iteration failed for {path}: {e}") from e
    else:
        # If an open dataset is passed directly, bypass the context manager
        yield from _generator(source)

iter_windows(raster, tile_size=512, overlap=0)

Partition an in-memory Raster object into smaller Raster tiles.

Useful for batch processing a loaded raster, notably for neural networks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `raster` | `Raster` | The source Raster object (already in memory). | *required* |
| `tile_size` | `Union[int, Tuple[int, int]]` | Dimensions (width, height) or single int for square tiles. | `512` |
| `overlap` | `int` | Pixels of overlap. | `0` |

Yields:

| Type | Description |
| --- | --- |
| `Tuple[Window, Raster]` | The tile's window and a deep copy of the sliced data as a new Raster. |

Source code in src/phytospatial/raster/partition.py
def iter_windows(
    raster: Raster,
    tile_size: Union[int, Tuple[int, int]] = 512,
    overlap: int = 0
    ) -> Iterator[Tuple[Window, Raster]]:
    """
    Partition an in-memory Raster object into smaller Raster tiles.

    Useful for batch processing a loaded raster, notably for neural networks.
    Tiles start every ``tile - overlap`` pixels along each axis. Tiles on the right
    and bottom edges are clipped to the raster and may be smaller than ``tile_size``.

    Args:
        raster (Raster): The source Raster object (already in memory).
        tile_size (Union[int, Tuple[int, int]]): Dimensions (width, height) or single int for square tiles.
        overlap (int): Pixels of overlap.

    Yields:
        Tuple[Window, Raster]: The tile's window in ``raster`` pixel coordinates and a
        new Raster holding a copy of the sliced data.

    Raises:
        ValueError: If a tile dimension is not positive, ``overlap`` is negative, or
            ``overlap`` is not smaller than the smaller tile dimension (raised when
            iteration starts, since this is a generator).
    """
    if isinstance(tile_size, int):
        t_width, t_height = tile_size, tile_size
    else:
        t_width, t_height = tile_size

    if t_width <= 0 or t_height <= 0:
        raise ValueError(f"Tile dimensions must be positive, got {t_width}x{t_height}")
    # A negative overlap would make the step larger than the tile and silently
    # skip pixels between neighbouring tiles.
    if overlap < 0:
        raise ValueError(f"Overlap ({overlap}) must be non-negative")
    if overlap >= min(t_width, t_height):
        raise ValueError(f"Overlap ({overlap}) must be smaller than tile dimensions")

    step_w = t_width - overlap
    step_h = t_height - overlap

    for row_off in range(0, raster.height, step_h):
        for col_off in range(0, raster.width, step_w):
            # Clip edge tiles to the raster extent.
            width = min(t_width, raster.width - col_off)
            height = min(t_height, raster.height - row_off)

            window = Window(
                col_off=col_off,
                row_off=row_off,
                width=width,
                height=height
            )

            # raster.data is indexed as (band, row, col); .copy() gives each tile
            # its own buffer rather than a slice of the parent array.
            tile_data = raster.data[
                :,
                row_off : row_off + height,
                col_off : col_off + width
            ].copy()

            tile_transform = compute_window_transform(window, raster.transform)

            # NOTE(review): assumes raster.band_names is a list-like with .copy();
            # confirm Raster never stores None here.
            tile_raster = Raster(
                data=tile_data,
                transform=tile_transform,
                crs=raster.crs,
                nodata=raster.nodata,
                band_names=raster.band_names.copy()
            )

            yield window, tile_raster