/* pyswiftsim_tools.c */
#include "pyswiftsim_tools.h"

/* include swift */
#include <part.h>
#include <units.h>
#include <parser.h>
#include <physical_constants.h>
#include <cooling_struct.h>

#include <Python.h>
#include <numpy/arrayobject.h>


/* Size in bytes of each C structure that can be exchanged with Python,
   indexed by class id. Entry order must match class_name below. */
const size_t class_size[class_count] = {
  sizeof (struct unit_system),           /* "UnitSystem"          */
  sizeof (struct part),                  /* "Part"                */
  sizeof (struct swift_params),          /* "SwiftParams"         */
  sizeof (struct phys_const),            /* "PhysConst"           */
  sizeof (struct cooling_function_data)  /* "CoolingFunctionData" */
};

/* Python class name (looked up in pyswiftsim.structure by the functions in
   this file) for each C structure, indexed by class id. Entry order must
   match class_size above. */
const char *class_name[class_count] = {
  "UnitSystem",           /* struct unit_system           */
  "Part",                 /* struct part                  */
  "SwiftParams",          /* struct swift_params          */
  "PhysConst",            /* struct phys_const            */
  "CoolingFunctionData"   /* struct cooling_function_data */
};


/**
 * Import a Python module and fetch one of its attributes.
 *
 * @param module_name  Dotted name of the module to import.
 * @param object_name  Name of the attribute (e.g. a class) to fetch from it.
 * @return A NEW reference to the attribute (the caller must Py_DECREF it),
 *         or the error path of pyerror (NOTE(review): pyerror is assumed to
 *         set a Python exception and return from this function).
 */
PyObject* pytools_import(char* module_name, char* object_name)
{
  /* load module (new reference) */
  PyObject *module = PyImport_ImportModule(module_name);

  if (module == NULL)
    {
      pyerror("Failed to import module '%s'.", module_name);
    }

  /* Fetch the object as a NEW reference. The previous implementation went
     through PyModule_GetDict / PyDict_GetItemString, which both return
     borrowed references, and then decref'd them (and let callers decref the
     result), corrupting reference counts. */
  PyObject *python_obj = PyObject_GetAttrString(module, object_name);

  /* the module is only needed for the lookup above */
  Py_DECREF(module);

  if (python_obj == NULL)
    {
      pyerror("Object %s does not exist in module %s", object_name, module_name);
    }

  return python_obj;
}


PyObject* pytools_return(void *p, int class)
{

  PyObject *python_class;
  size_t nber_bytes;
  char module_name[STRING_SIZE] = "pyswiftsim.structure";
  char *class_pyname;

  /* check class */
  if (class >= class_count)
    pyerror("Class %i does not exists", class);

  /* get class information */
  nber_bytes = class_size[class];
  class_pyname = class_name[class];

  /* import python class */
  python_class = pytools_import(module_name, class_pyname);

  if (python_class == NULL)
    return NULL;
      
  if (!PyCallable_Check(python_class))
    {
      Py_DECREF(python_class);
      pyerror("Unable to import class %s from %s", class_pyname, module_name);
    }

  /* create object */
  PyObject *object, *args;
  
  args = PyTuple_Pack(1, PyBytes_FromStringAndSize((char *) p, nber_bytes));

  object = PyObject_CallObject(python_class, args);

  Py_DECREF(args);
  Py_DECREF(python_class);

  return object;
  
}

/**
 * Return str(type(obj)) as a C string (e.g. for error messages).
 *
 * @param obj  Any Python object.
 * @return Pointer to a static, NUL-terminated buffer holding the type's string
 *         representation (truncated if longer than the buffer). The buffer is
 *         overwritten by the next call and must not be freed. On failure,
 *         pyerror is invoked.
 */
char* pytools_get_type_name(PyObject *obj)
{
  /* The UTF-8 data returned by PyUnicode_AsUTF8 is owned by the unicode
     object, so it must be copied before that object is released; a static
     buffer keeps the existing "caller does not free" contract. */
  enum { TYPE_NAME_SIZE = 256 };
  static char name_buffer[TYPE_NAME_SIZE];

  /* get object type (new reference) */
  PyObject *type = PyObject_Type(obj);
  if (type == NULL)
    {
      /* nothing to release: Py_DECREF(NULL) would crash */
      pyerror("Unable to get type");
    }

  /* get object name (new reference) */
  PyObject *recv = PyObject_Str(type);
  Py_DECREF(type);

  if (recv == NULL)
    {
      pyerror("Unable to get string representation");
    }

  /* transform to C */
  const char *utf8 = PyUnicode_AsUTF8(recv);
  if (utf8 == NULL)
    {
      Py_DECREF(recv);
      pyerror("Unable to convert string to char");
    }

  /* copy while recv (the owner of utf8) is still alive */
  snprintf(name_buffer, sizeof name_buffer, "%s", utf8);
  Py_DECREF(recv);

  return name_buffer;
}

/**
 * Get the raw C structure stored in a pyswiftsim.structure Python object.
 *
 * Checks that obj is an instance of the Python class class_name[class] and
 * that its `data` attribute is a bytes object of exactly class_size[class]
 * bytes, then returns a pointer to those bytes.
 *
 * @param obj    Python object to convert.
 * @param class  Class id, index into class_size / class_name.
 * @return Pointer to the internal buffer of obj.data (must not be freed), or
 *         NULL with a Python error set. NOTE(review): the buffer stays valid
 *         only while obj keeps a reference to that bytes object (i.e. `data`
 *         is a stored attribute, not a property returning a fresh object) —
 *         confirm on the Python side.
 */
char* pytools_construct(PyObject* obj, int class)
{
  char *module_name = "pyswiftsim.structure";

  /* check python class; negative ids would index before the tables */
  if (class < 0 || class >= class_count)
    {
      pyerror("Class %i does not exists", class);
    }

  /* get class information (pytools_import takes a non-const char*) */
  char *class_pyname = (char *) class_name[class];

  /* import class */
  PyObject *pyclass = pytools_import(module_name, class_pyname);
  if (pyclass == NULL)
    return NULL;

  /* check if classes correspond; PyObject_IsInstance returns -1 on error */
  int is_instance = PyObject_IsInstance(obj, pyclass);
  Py_DECREF(pyclass);

  if (is_instance < 0)
    return NULL;

  if (!is_instance)
    {
      char *recv = pytools_get_type_name(obj);
      if (recv == NULL)
        return NULL;
      pyerror("Expecting class %s, received %s", class_pyname, recv);
    }

  /* copy python class' data to C */
  PyObject *data = PyObject_GetAttrString(obj, "data");

  if (data == NULL)
    {
      pyerror("Unable to get the attribute 'data'");
    }

  /* NULL (with TypeError set) if data is not a bytes object */
  char *ret = PyBytes_AsString(data);
  Py_ssize_t data_size = (ret == NULL) ? -1 : PyBytes_Size(data);

  Py_DECREF(data);

  if (ret == NULL)
    return NULL;

  /* callers read class_size[class] bytes from ret: refuse shorter/longer data */
  if (data_size != (Py_ssize_t) class_size[class])
    {
      pyerror("Attribute 'data' of %s has %zd bytes, expected %zu",
              class_pyname, data_size, class_size[class]);
    }

  return ret;
}


/**
 * Validate that obj is a NumPy array with the requested rank and dtype.
 *
 * @param obj   Object to validate (checked with PyArray_Check even though
 *              it is typed as PyArrayObject*).
 * @param dim   Required number of dimensions (compared to PyArray_NDIM).
 * @param type  Required NumPy type number (compared to PyArray_TYPE).
 * @return SUCCESS when every check passes; a failed check is reported via
 *         pyerror (NOTE(review): assumed to set a Python exception and
 *         return from this function).
 */
int pytools_check_array(PyArrayObject *obj, int dim, int type)
{
  /* the NumPy C API must be initialised before any PyArray_* call */
  IMPORT_ARRAY();

  const int is_array = PyArray_Check(obj);
  if (!is_array)
    {
      pyerror("Expecting a numpy array");
    }

  const int ndim = PyArray_NDIM(obj);
  if (ndim != dim)
    {
      pyerror("Array should be a %i dimensional object", dim);
    }

  const int dtype = PyArray_TYPE(obj);
  if (dtype != type)
    {
      pyerror("Wrong array type");
    }

  return SUCCESS;
}