Source code for nexgen.tools.vds_tools

"""
Tools to write Virtual DataSets
"""

from __future__ import annotations

import logging
import operator
from dataclasses import dataclass
from functools import reduce
from pathlib import Path
from typing import List, Tuple

import h5py
import numpy as np
from numpy.typing import DTypeLike

from ..utils import MAX_FRAMES_PER_DATASET
from .constants import jungfrau_fill_value, jungfrau_gap_size, jungfrau_mod_size

vds_logger = logging.getLogger("nexgen.VDSWriter")


@dataclass
class Dataset:
    """
    Description of one source dataset and of the slice of it that is copied
    into a Virtual DataSet.

    ``dest_shape`` is always recomputed in ``__post_init__`` from the other
    fields, so any value passed for it at construction time is overwritten.
    """

    name: str
    # The full shape of the source, regardless of start index
    source_shape: Tuple[int]
    # The index of the first frame that should be copied
    start_index: int = 0
    # The index where copying stops. Defaults to maximum for an Eiger
    stop_index: int = MAX_FRAMES_PER_DATASET
    # The shape of the destination, accounting for start_index and stop_index
    dest_shape: Tuple[int] = None

    def __post_init__(self):
        """Derive ``dest_shape`` from the source shape and the start/stop indices."""
        total_frames = self.source_shape[0]
        kept_frames = total_frames - self.start_index

        stops_early = (
            self.stop_index < MAX_FRAMES_PER_DATASET
            or self.stop_index < total_frames
        )
        if stops_early:
            # Also drop the frames lying between stop_index and the end of the source.
            kept_frames = kept_frames - (total_frames - self.stop_index)

        self.dest_shape = (kept_frames, *self.source_shape[1:])

    def __add__(self, x):
        """Returns a dataset that has the same start index and shape as if the two
        were appended to each other.

        The result has an empty name; frame counts, start indices and stop indices
        are summed, while the trailing dimensions are taken from ``self``.
        """
        combined_shape = (
            self.source_shape[0] + x.source_shape[0],
            *self.source_shape[1:],
        )
        combined_start = self.start_index + x.start_index
        combined_stop = self.stop_index + x.stop_index
        return Dataset(
            "",
            source_shape=combined_shape,
            start_index=combined_start,
            stop_index=combined_stop,
        )
def find_datasets_in_file(nxdata: h5py.Group) -> List:
    """
    Collect the names of the source datasets linked from a NeXus group.

    Only members whose link object is an ``h5py.ExternalLink`` are considered
    source datasets.

    Args:
        nxdata (h5py.Group): Group where the data should be linked.

    Raises:
        KeyError: If the group contains no ExternalLink members.

    Returns:
        dsets (List): Names of the externally linked source datasets.
    """
    # FIXME for now this assumes that the source datasets are always links
    linked_names = [
        key
        for key in nxdata.keys()
        if isinstance(nxdata.get(key, getlink=True), h5py.ExternalLink)
    ]
    if linked_names:
        return linked_names
    raise KeyError(
        f"No External Link datasets found in NeXus file under {nxdata.name}"
    )
def split_datasets(
    dsets,
    data_shape: Tuple[int, int, int],
    start_idx: int = 0,
    vds_shape: Tuple[int, int, int] = None,
) -> List[Dataset]:
    """
    Splits the full data shape and start index up into values per dataset, given that
    each dataset holds at most MAX_FRAMES_PER_DATASET frames.

    Args:
        dsets: Iterable of source dataset names; each one becomes the ``name`` of a Dataset.
        data_shape (Tuple[int, int, int]): Shape of the data, usually defined as (num_frames, *image_size).
        start_idx (int, optional): The start point for the source data. Defaults to 0.
        vds_shape (Tuple, optional): Desired shape of the VDS, usually defined as (num_frames, *image_size).
            The number of frames must be smaller or equal to the one in data_shape.
            Defaults to None, in which case it is (data_shape[0] - start_idx, *data_shape[1:]).

    Raises:
        ValueError: If the passed start index value is higher than the dataset length.
        ValueError: If the passed start index value is negative.

    Returns:
        List[Dataset]: One Dataset per source that contributes frames to the VDS. Sources
            that are skipped entirely by the start index, or that lie past the last
            requested frame, are left out.
    """
    if start_idx > data_shape[0]:
        raise ValueError(
            f"Start index {start_idx} must be less than full dataset length {data_shape[0]}"
        )
    if start_idx < 0:
        raise ValueError("Start index must be positive")

    # Non-int values are only reported here; the casts happen further down.
    if not isinstance(data_shape[0], int):
        vds_logger.warning("Datashape not passed as int, will attempt to cast")
    if not isinstance(start_idx, int):
        vds_logger.warning("VDS start index not passed as int, will attempt to cast")
    if vds_shape and not isinstance(vds_shape[0], int):
        # This message used to be a copy of the start index one above.
        vds_logger.warning("VDS shape not passed as int, will attempt to cast")

    if vds_shape is None:
        vds_logger.debug(
            "VDS shape not chosen, it will be calculated from the full data shape and the chosen start index."
        )
        vds_shape = (data_shape[0] - start_idx, *data_shape[1:])

    full_frames = int(data_shape[0])
    # Number of frames left out at the end of the collection.
    end_cut_frames = int(full_frames - vds_shape[0]) - int(start_idx)

    result = []
    for dset_name in dsets:
        # start_idx and full_frames are expressed relative to the beginning of the
        # current source dataset; both are shifted by one dataset per iteration.
        dset = Dataset(
            name=dset_name,
            source_shape=(min(MAX_FRAMES_PER_DATASET, full_frames), *data_shape[1:]),
            start_index=min(MAX_FRAMES_PER_DATASET, max(int(start_idx), 0)),
            stop_index=min(MAX_FRAMES_PER_DATASET, (full_frames - end_cut_frames)),
        )
        # If start index == MAX_FRAMES_PER_DATASET then that source dataset is not
        # used and we should not pass it on to use as a source for the VDS.
        # Same goes if all datasets from last files are not used.
        if dset.start_index != MAX_FRAMES_PER_DATASET:
            if dset.stop_index > 0 and dset.stop_index <= MAX_FRAMES_PER_DATASET:
                result.append(dset)

        start_idx -= MAX_FRAMES_PER_DATASET
        full_frames -= MAX_FRAMES_PER_DATASET

    return result
def create_virtual_layout(datasets: List[Dataset], data_type: DTypeLike):
    """
    Build an h5py virtual layout that stacks the given datasets along the first axis.

    The overall layout shape is the ``dest_shape`` of the sum of all datasets. Every
    source is referenced in the same file (".") under "/entry/data/<name>".

    Args:
        datasets (List[Dataset]): A list of datasets that are to be merged.
        data_type (DTypeLike): The type of the input data.

    Returns:
        layout (h5py.VirtualLayout): Virtual layout.
    """
    merged: Dataset = reduce(operator.add, datasets)
    layout = h5py.VirtualLayout(shape=merged.dest_shape, dtype=data_type)

    offset = 0
    for dset in datasets:
        source_frames = dset.source_shape[0]
        if dset.stop_index != source_frames:
            upper = offset + dset.dest_shape[0]
        else:
            # The source is read up to its last frame.
            upper = offset + source_frames - dset.start_index

        source = h5py.VirtualSource(
            ".", "/entry/data/" + dset.name, shape=dset.source_shape
        )
        layout[offset:upper, :, :] = source[dset.start_index : dset.stop_index, :, :]

        # The next source continues where this one ended.
        offset = upper

    return layout
def image_vds_writer(
    nxsfile: h5py.File,
    full_data_shape: Tuple | List,
    start_index: int = 0,
    vds_shape: Tuple | List | None = None,
    data_type: DTypeLike = np.uint16,
    entry_key: str = "data",
):
    """
    Write a Virtual DataSet for image data inside the "/entry/data" group of a NeXus file.

    Args:
        nxsfile (h5py.File): Handle to NeXus file being written.
        full_data_shape (Tuple | List): Shape of the full dataset, usually defined as (num_frames, *image_size).
        start_index (int): The start point for the source data. Defaults to 0.
        vds_shape (Tuple, optional): Desired shape of the VDS, usually defined as (num_frames, *image_size).
            The number of frames must be smaller or equal to the one in full_data_shape.
            Defaults to None, in which case every frame from start_index onwards is used.
        data_type (DTypeLike, optional): The type of the input data. Defaults to np.uint16.
        entry_key (str, optional): Entry key for the Virtual DataSet name. Defaults to data.
    """
    vds_logger.debug("Start creating VDS ...")

    # Where the vds will go
    nxdata = nxsfile["/entry/data"]
    dset_names = find_datasets_in_file(nxdata)

    if vds_shape is None:
        vds_shape = (full_data_shape[0] - start_index, *full_data_shape[1:])
    else:
        vds_shape = tuple(vds_shape)

    # Hack for datasets with no maximum number of frames (eg. Singla):
    # a single source holding more than MAX_FRAMES_PER_DATASET frames.
    # nxdata[dset_names[0]].shape[0] > MAX_FRAMES_PER_DATASET
    # .maxshape[0] is None
    single_large_source = (
        len(dset_names) == 1 and full_data_shape[0] > MAX_FRAMES_PER_DATASET
    )
    if not single_large_source:
        datasets = split_datasets(dset_names, full_data_shape, start_index, vds_shape)
    else:
        whole_source = Dataset(
            name=dset_names[0],
            source_shape=full_data_shape,
            start_index=start_index,
            stop_index=full_data_shape[0],
        )
        datasets = [whole_source]

    layout = create_virtual_layout(datasets, data_type)

    # Write a Virtual Dataset in NeXus file
    nxdata.create_virtual_dataset(entry_key, layout, fillvalue=-1)
    vds_logger.debug("VDS correctly written to NeXus file.")
def jungfrau_vds_writer(
    nxsfile: h5py.File,
    vds_shape: Tuple | List,
    data_type: DTypeLike = np.uint16,
    source_dsets: List[str] | None = None,
):
    """
    Write VDS for Jungfrau 1M use case, with a tiled layout.

    Args:
        nxsfile (h5py.File): Handle to the NeXus file; the VDS is created as "data" in "/entry/data".
        vds_shape (Tuple | List): Shape of the VDS; its first element is used as the number of
            frames of every source.
        data_type (DTypeLike, optional): The type of the data. Defaults to np.uint16.
        source_dsets (List[str] | None, optional): Source file names, each read at dataset "data".
            If empty or None, the external links found in "/entry/data" are used as sources
            within the NeXus file itself. Defaults to None.
    """
    entry_key = "data"
    frames = vds_shape[0]
    nxdata = nxsfile["/entry/data"]

    # Sources passed by the caller are treated as external files, otherwise
    # fall back to the links already present in the NeXus file.
    external_dsets = bool(source_dsets)
    if not external_dsets:
        source_dsets = find_datasets_in_file(nxdata)

    module_shape = (frames, *jungfrau_mod_size)
    if external_dsets:
        sources = [
            h5py.VirtualSource(dset, entry_key, shape=module_shape)
            for dset in source_dsets
        ]
    else:
        sources = [
            h5py.VirtualSource(".", f"/entry/data/{dset}", shape=module_shape)
            for dset in source_dsets
        ]

    layout = h5py.VirtualLayout(shape=vds_shape, dtype=data_type)

    # The first one is the upper one
    # sources[1] fills the first jungfrau_mod_size[0] rows, sources[0] the rows after the gap.
    lower_start = jungfrau_mod_size[0] + jungfrau_gap_size[0]
    layout[:, : jungfrau_mod_size[0], :] = sources[1][:, :, :]
    layout[:, lower_start:, :] = sources[0][:, :, :]

    nxdata.create_virtual_dataset(entry_key, layout, fillvalue=jungfrau_fill_value)
def vds_file_writer(
    nxsfile: h5py.File,
    datafiles: List[Path],
    data_shape: Tuple | List,
    data_type: DTypeLike = np.uint16,
    entry_key: str = "data",
):
    """
    Write a Virtual DataSet _vds.h5 file for image data and link it from the NeXus file.

    The VDS file is created next to the NeXus file as "<nexus stem>_vds.h5". It holds the
    virtual dataset under "data", and "/entry/data/data" in the NeXus file becomes an
    external link to it.

    Args:
        nxsfile (h5py.File): NeXus file being written.
        datafiles (List[Path]): List of paths to source files; only the file name of each is
            stored in the virtual source.
        data_shape (Tuple | List): Shape of the dataset, usually defined as (num_frames, *image_size).
        data_type (DTypeLike, optional): Dtype. Defaults to np.uint16.
        entry_key (str): Name of the dataset read from each source file. Defaults to data.
    """
    vds_logger.debug("Start creating VDS file ...")

    # Where the vds will go
    nxdata = nxsfile["/entry/data"]

    # Number of frames held by each source dataset.
    # Once again, it is assumed that the maximum number of frames per dataset is 1000
    total_frames = data_shape[0]
    frames = [1000] * (total_frames // 1000) + [total_frames % 1000]
    image_dims = data_shape[1:]
    sshape = [(count, *image_dims) for count in frames]

    # Create virtual layout, one consecutive slab of frames per source file
    layout = h5py.VirtualLayout(shape=data_shape, dtype=data_type)
    first = 0
    for idx, datafile in enumerate(datafiles):
        last = first + frames[idx]
        # Source definition
        source = h5py.VirtualSource(datafile.name, entry_key, shape=sshape[idx])
        layout[first:last:1, :, :] = source
        first = last

    # Create a _vds.h5 file and add link to nexus file
    nexus_path = Path(nxsfile.filename).expanduser().resolve()
    vds_filename = nexus_path.parent / f"{nexus_path.stem}_vds.h5"
    with h5py.File(vds_filename, "w") as vds:
        vds.create_virtual_dataset("data", layout, fillvalue=-1)

    nxdata["data"] = h5py.ExternalLink(vds_filename.name, "data")
    vds_logger.debug(f"{vds_filename} written and link added to NeXus file.")