# structure.py
import struct
import numpy
from ctypes import *


######################################################################
#                                                                    #
#                         SwiftStruct                                #
#                                                                    #
######################################################################
class SwiftStruct(struct.Struct):
    """
    Abstract class constructing a wrapper around a C structure

    Subclasses implement :attr:`struct_name` (ordered attribute names) and
    :attr:`struct_format` (one :mod:`struct` format entry per name). Every
    name in :attr:`struct_name` can then be read and written as a normal
    attribute: scalar fields are returned as Python values, fields with a
    repeat count (e.g. ``3d``) as an :class:`ArrayStruct` whose item
    assignment writes back into this structure.

    Parameters
    ----------

    struct_format: str
        Text defining the data type in the structure (see struct module for
        more information)

    data: bytes, dict
        Bytes string containing all the data of a structure or
        dictionary of attributes (all the attributes should be present;
        array attributes are given as sequences of values)

    parent: :class:`SwiftStruct`
        Parent containing the data (useful for :class:`ArrayStruct`)

    Raises
    ------

    ValueError
        If ``data`` is neither bytes-like nor a dict.

    KeyError
        If ``data`` is a dict missing one of the attributes.
    """
    def __init__(self, struct_format, data, parent):
        super().__init__(struct_format)
        self.parent = parent  # parent for ArrayStruct

        if isinstance(data, (bytes, bytearray, memoryview)):
            self.data = bytes(data)  # bytes string data
        elif isinstance(data, dict):
            # Parse the explicit format argument rather than the
            # ``struct_format`` property: subclasses such as ArrayStruct
            # only make that property usable after this constructor returns.
            counts, _ = self._parse_format(struct_format)
            values = []
            for name, n in zip(self.struct_name, counts):
                if n > 1:
                    # pack() expects one argument per scalar, so array
                    # attributes are flattened into the argument list
                    values.extend(data[name])
                else:
                    values.append(data[name])
            self.data = self.pack(*values)
        else:
            raise ValueError("Data should be either bytes or dict, "
                             "received %s" % type(data))

    @property
    def struct_name(self):
        """
        List of the structure attribute names.

        Should be implemented for each sub class
        """
        raise NotImplementedError("SwiftStruct should not be used")

    @property
    def struct_format(self):
        """
        String containing the data type of each attribute

        Should be implemented for each sub class
        """
        raise NotImplementedError("SwiftStruct should not be used")

    def _getInfoFromName(self, name):
        """
        Compute index, format and size from an attribute name.

        Parameters
        ----------

        name: str
            Name of an attribute

        Returns
        -------

        i: int, slice
            index of the attribute in the unpacked tuple (a slice when the
            attribute holds more than one value)

        form: char
            data type (see module struct for more information)

        n: int
            number of element in attribute

        Raises
        ------

        ValueError
            If ``name`` is not in :attr:`struct_name`.
        """
        struct_size, form = self.struct_size_format
        i = self.struct_name.index(name)
        form = form[i]
        n = struct_size[i]
        # position of the first value = values produced by earlier fields
        i = sum(struct_size[:i])

        if n > 1:
            i = slice(i, i + n)

        return i, form, n

    @staticmethod
    def _parse_format(form):
        """
        Split a :mod:`struct` format string into per-attribute counts and
        type codes.

        Parameters
        ----------

        form: str, bytes
            Format string, e.g. ``"q3d4c"``

        Returns
        -------

        out_nber: list
            number of values produced by each attribute

        out_form: list
            type code of each attribute
        """
        if isinstance(form, bytes):
            form = form.decode("ascii")
        # a leading byte-order/size/alignment character is not an attribute
        if form[:1] in ("@", "=", "<", ">", "!"):
            form = form[1:]

        out_nber = []
        out_form = []
        count = ""
        for char in form:
            if char.isspace():
                # struct ignores whitespace between format codes
                continue
            if char.isdigit():
                # repeat counts may span several digits, e.g. "10d"
                count += char
                continue
            n = int(count) if count else 1
            count = ""
            if char in ("s", "p"):
                # for strings the count is a length: "10s" is a single value
                n = 1
            out_nber.append(n)
            out_form.append(char)

        return out_nber, out_form

    @property
    def struct_size_format(self):
        """
        Compute size and format for each field

        Returns
        -------

        out_nber: list
            number of element for each attribute

        out_form: list
            format for each attribute
        """
        return self._parse_format(self.struct_format)

    def __str__(self):
        data = self.unpack(self.data)
        lines = ["%s:" % type(self)]
        for name in self.struct_name:
            i = self._getInfoFromName(name)[0]
            lines.append("\t%s: %s" % (name, data[i]))
        return "\n".join(lines) + "\n"

    def __getattr__(self, name):
        """
        Read a structure attribute from the packed data.

        Only called when normal attribute lookup fails. Names outside
        :attr:`struct_name` raise the usual AttributeError. Scalar
        attributes are returned directly. Array attributes are returned as
        an :class:`ArrayStruct` linked to this object.
        """
        # case where the attribute is not in the structure
        if name not in self.struct_name:
            return object.__getattribute__(self, name)

        # case where the attribute is in the structure
        i, form, n = self._getInfoFromName(name)
        data = self.unpack(self.data)
        if n == 1:
            return data[i]

        # ``i`` is already the slice covering exactly the n array values
        nform = str(n) + form
        packed = struct.pack(nform, *data[i])
        return ArrayStruct(nform, packed, self, name)

    def __setattr__(self, name, value):
        """
        Write a structure attribute into the packed data.

        Names outside :attr:`struct_name` are stored as ordinary instance
        attributes. Array attributes accept a sequence of values or an
        :class:`ArrayStruct`.
        """
        # case where the attribute is not in the structure
        if name not in self.struct_name:
            object.__setattr__(self, name, value)
        # case where the attribute is in the structure
        else:
            data = list(self.unpack(self.data))

            i, form, n = self._getInfoFromName(name)

            if isinstance(value, ArrayStruct):
                value = value.getArray()

            data[i] = value
            self.data = self.pack(*data)


######################################################################
#                                                                    #
#                         ArrayStruct                                #
#                                                                    #
######################################################################
class ArrayStruct(SwiftStruct):
    """
    View on an array attribute (a field with a repeat count) of a
    :class:`SwiftStruct`.

    Item assignment updates both this view and, when present, the parent
    structure, so ``obj.arr[0] = x`` modifies ``obj``.

    Parameters
    ----------

    struct_format: str
        Format of the array, e.g. ``"3d"``

    data: bytes
        Packed array values

    parent: :class:`SwiftStruct`
        Structure owning the array (may be None for a standalone array)

    name: str
        Name of the array attribute in ``parent``
    """
    def __init__(self, struct_format, data, parent, name):
        super().__init__(struct_format, data, parent)
        self._name = name
        self._format = struct_format

    def __getitem__(self, ii):
        return self.unpack(self.data)[ii]

    def __setitem__(self, ii, value):
        data = list(self.unpack(self.data))
        data[ii] = value
        # Update this view too; otherwise reading back through the same
        # ArrayStruct would return the stale value. Packing first also
        # rejects an invalid value before the parent is touched.
        self.data = self.pack(*data)
        if self.parent is not None:
            setattr(self.parent, self._name, data)

    def getArray(self):
        """Return the array values as a tuple."""
        return self.unpack(self.data)

    @property
    def struct_name(self):
        return [
            "array_data"
        ]

    @property
    def struct_format(self):
        return self._format

    
######################################################################
#                                                                    #
#                         UnitSystem                                 #
#                                                                    #
######################################################################
class UnitSystem(SwiftStruct):
    """
    Structure made of five ``double`` fields, named ``UnitMass_in_cgs``
    through ``UnitTemperature_in_cgs`` (see :attr:`struct_name`).

    Parameters
    ----------

    data: bytes, dict
        Packed structure or dictionary holding every attribute

    parent: :class:`SwiftStruct`, optional
        Structure containing this one (default: None)
    """

    # attribute names, in the order they appear in the packed data
    _FIELD_NAMES = (
        "UnitMass_in_cgs",
        "UnitLength_in_cgs",
        "UnitTime_in_cgs",
        "UnitCurrent_in_cgs",
        "UnitTemperature_in_cgs",
    )

    def __init__(self, data, parent=None):
        fmt = self.struct_format
        super().__init__(fmt, data, parent)

    @property
    def struct_format(self):
        # one double per field, native byte order/size/alignment
        return "ddddd"

    @property
    def struct_name(self):
        # hand out a fresh list so callers cannot alter the class constant
        return list(self._FIELD_NAMES)


######################################################################
#                                                                    #
#                          Part                                      #
#                                                                    #
######################################################################
class Part(SwiftStruct):
    """
    Wrapper around a packed structure, presumably a particle record (the
    fields include id, pos, vel, mass) — confirm against the C definition.

    Each name in :attr:`struct_name` matches, in order, one entry of
    :attr:`struct_format`.

    Parameters
    ----------

    data: bytes, dict
        Packed structure or dictionary holding every attribute

    parent: :class:`SwiftStruct`, optional
        Structure containing this one (default: None)
    """
    def __init__(self, data, parent=None):
        super().__init__(self.struct_format, data, parent)

    @property
    def struct_format(self):
        # The original author flagged the trailing ``density`` (7f) and
        # ``time_bin`` (4c) layout as needing a fix. This property is read
        # on every attribute access, so use a warning instead of print():
        # the default warnings filter reports each call site only once.
        import warnings
        warnings.warn("ERROR, need to fix density/time_bin", stacklevel=2)
        return "qP3d3f3ffffffN7f4c"

    @property
    def struct_name(self):
        return [
            "id",
            "gpart",
            "pos",
            "vel",
            "a_hydro",
            "h",
            "mass",
            "rho",
            "entropy",
            "entropy_dt",
            "last_offset",
            "density",
            "time_bin"
        ]