from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union
import metatensor.torch as mts
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import ModelCapabilities, System
class Writer(ABC):
    """
    Abstract base class for writers of model predictions.

    Subclasses must implement :py:meth:`write`. :py:meth:`finish` has a default
    no-op implementation and only needs to be overridden when a subclass has work
    to do after the last call to :py:meth:`write`.

    :param filename: Path of the output file.
    :param capabilities: Capabilities of the model whose predictions are written,
        if available.
    :param append: Stored as-is on the instance; its interpretation (presumably
        whether to append to existing output) is left to each subclass.
    """

    def __init__(
        self,
        filename: Union[str, Path],
        capabilities: Optional[ModelCapabilities] = None,
        append: Optional[bool] = None,
    ):
        self.filename = filename
        self.capabilities = capabilities
        self.append = append

    @abstractmethod
    def write(self, systems: List[System], predictions: Dict[str, TensorMap]) -> None:
        """
        Write a batch of systems and their predictions.

        :param systems: List of systems to write.
        :param predictions: Dictionary of TensorMaps with predictions for the systems.
        """
        ...

    def finish(self) -> None:
        """
        Called once after all writes.

        The default implementation does nothing; subclasses may override it.
        """
 
def _shift_system_samples(samples: Labels, offset: int) -> Labels:
    """Return ``samples`` with ``offset`` added to its ``"system"`` column, if any."""
    if offset == 0 or "system" not in samples.names:
        return samples
    values = samples.values
    shift = torch.zeros(values.size(-1), dtype=values.dtype, device=values.device)
    shift[samples.names.index("system")] = offset
    # adding the same constant to one column cannot create duplicate entries
    return Labels(samples.names, values + shift, assume_unique=True)


def _offset_block(block: TensorBlock, offset: int) -> TensorBlock:
    """Copy ``block`` and its gradients with ``"system"`` samples shifted by ``offset``."""
    new_block = TensorBlock(
        samples=_shift_system_samples(block.samples, offset),
        components=block.components,
        properties=block.properties,
        values=block.values,
    )
    for gradient_name, gradient_block in block.gradients():
        # The first gradient sample column ("sample") indexes rows of the parent
        # block, which are not reordered here, so it must stay untouched; only a
        # "system" column, when the gradient has one, is shifted.
        new_block.add_gradient(
            gradient_name,
            TensorBlock(
                samples=_shift_system_samples(gradient_block.samples, offset),
                components=gradient_block.components,
                properties=gradient_block.properties,
                values=gradient_block.values,
            ),
        )
    return new_block


def _split_tensormaps(
    systems: List[System],
    batch_predictions: Dict[str, TensorMap],
    istart_system: Optional[int] = 0,
) -> List[Dict[str, TensorMap]]:
    """
    Split batched predictions into one dictionary of TensorMaps per system.

    Every TensorMap in ``batch_predictions`` is split along its ``"system"`` sample
    dimension, so that entry ``i`` of the output only holds the samples belonging to
    the ``i``-th system of the batch. The ``"system"`` sample values of the blocks
    (and of their gradients, when they have such a column) are then shifted by
    ``istart_system`` to turn batch-local indices into global ones.

    :param systems: List of systems in the batch. Only its length is used.
    :param batch_predictions: Dictionary of TensorMaps with batch predictions. Their
        blocks are split on the ``"system"`` sample dimension.
    :param istart_system: Index of the first system of this batch; added to every
        ``"system"`` sample value. ``None`` is treated as ``0``.
    :return: List with one entry per system; each entry maps the keys of
        ``batch_predictions`` to the corresponding single-system TensorMap.
    """
    if istart_system is None:
        istart_system = 0
    if len(batch_predictions) == 0:
        # nothing to split, but keep the "one dict per system" contract
        return [{} for _ in range(len(systems))]

    device = next(iter(batch_predictions.values()))[0].values.device
    split_selection = [
        Labels("system", torch.tensor([[i]], device=device))
        for i in range(len(systems))
    ]
    batch_predictions_split = {
        key: mts.split(tensormap, "samples", split_selection)
        for key, tensormap in batch_predictions.items()
    }

    out_tensormaps: List[Dict[str, TensorMap]] = []
    for i in range(len(systems)):
        # build a per-system dict
        tensormaps: Dict[str, TensorMap] = {}
        for key, split_maps in batch_predictions_split.items():
            system_map = split_maps[i]
            tensormaps[key] = TensorMap(
                keys=system_map.keys,
                blocks=[_offset_block(block, istart_system) for block in system_map],
            )
        out_tensormaps.append(tensormaps)
    return out_tensormaps