
import os
import platform
import sys
import traceback
from ctypes import *
from enum import Enum

import numpy as np

from PyBta.PyStatus import PyBtaStatus
from PyBta.Bta import BtaConfig
from PyBta.Bta import BtaFrame
from PyBta.Bta import InfoEventCb_Type
from PyBta.Frame import PyBtaChannelId
from PyBta.Frame import PyBtaDataFormat
from PyBta.PyFrame import PyBtaFrame
from PyBta.PyFrame import PyBtaChannel

class BtaException(Exception):
    """Error raised by the BTA wrapper.

    ``str(exc)`` is the message passed in; the BTA status code that caused
    the error is available as ``exc.status``.
    """

    def __init__(self, msg, status):
        # Keep the raw status code so callers can branch on it.
        self.status = status
        super(BtaException, self).__init__(msg)


class BtaWrapper(object):
    """Class-level ctypes wrapper around the native BltTofApi (BTA) library.

    All state lives on the class: ``dll`` is the loaded library (set by
    :meth:`init`) and ``btaHandle`` is the handle of the single open
    connection (set by the ``open*`` methods, cleared by :meth:`close`).
    Every BTA call returning a status other than ``PyBtaStatus.Ok`` is turned
    into a :class:`BtaException` via :meth:`throw`.
    """
    dll = None
    btaHandle = None

    @classmethod
    def throw(cls, msg, status):
        """Raise a BtaException whose message is prefixed with the status text.

        The status text comes from ``BTAstatusToString`` when the library is
        loaded. Before :meth:`init` succeeded it falls back to ``str(status)``
        so the intended exception is raised instead of an AttributeError.
        """
        if cls.dll is None:
            statusText = str(status)
        else:
            statusStr = create_string_buffer(b"", 30)
            cls.dll.BTAstatusToString(status, statusStr, len(statusStr))
            statusText = statusStr.value.decode("utf-8", "replace")
        raise BtaException(statusText + ": " + msg, status)

    @classmethod
    def _check(cls, status, funcName):
        """Raise via :meth:`throw` ("<funcName> failed!") unless status is Ok."""
        if status != PyBtaStatus.Ok:
            cls.throw(funcName + " failed!", status)

    @classmethod
    def init(cls):
        """Load the native BTA library for the current platform.

        Returns:
            A human-readable string with the library version and build date.

        Raises:
            BtaException: if init was already called or BTAgetVersion fails.
            RuntimeError: if the platform is unsupported or the library file
                is missing or cannot be loaded.
        """
        if cls.dll is not None:
            cls.throw("init was already called!", PyBtaStatus.RuntimeError)

        libDir = os.path.join(os.path.dirname(__file__), "..", "..", "lib")
        opsys = sys.platform
        if opsys.startswith('linux'):
            file = os.path.join(libDir, "libbta.so")
            if not os.path.isfile(file):
                raise RuntimeError(file + " not found!")
            try:
                cls.dll = CDLL(file)
            except OSError:
                raise RuntimeError(file + " or a dependent could not be loaded!")
        elif opsys.startswith('win32'):
            plat = "x86" if platform.architecture()[0].startswith("32") else "x64"
            path = os.path.join(libDir, "Win_" + plat)
            #path += "_debug"
            file = "BltTofApi.dll"
            # Temporarily chdir into the library folder so dependent DLLs next
            # to BltTofApi.dll can be found, and always chdir back afterwards,
            # even when loading fails.
            cwd = os.getcwd()
            os.chdir(path)
            try:
                if not os.path.isfile(file):
                    raise RuntimeError(file + " not found!")
                try:
                    # Load by absolute path: since Python 3.8 a bare DLL name
                    # is no longer resolved from the current directory.
                    cls.dll = CDLL(os.path.abspath(file))
                except OSError:
                    raise RuntimeError(file + " or a dependent could not be loaded!")
            finally:
                os.chdir(cwd)
        else:
            raise RuntimeError("Platform " + opsys + " not supported!")

        verMaj = c_uint()
        verMin = c_uint()
        verNonFun = c_uint()
        buildDateTime = create_string_buffer(b"", 128)
        status = cls.dll.BTAgetVersion(byref(verMaj), byref(verMin), byref(verNonFun), buildDateTime, len(buildDateTime), 0, 0)
        cls._check(status, "BTAgetVersion")
        return "Loaded BltTofApi v{0}.{1}.{2} built on {3}".format(verMaj.value, verMin.value, verNonFun.value, buildDateTime.value.decode("utf-8"))

    @staticmethod
    def _parse_ipv4(address, name):
        """Split a dotted-quad IPv4 string into a tuple of four ints in 0..255.

        Raises:
            RuntimeError: naming the parameter *name* if *address* is malformed
                or an octet is out of range (it would silently wrap in uint8).
        """
        try:
            octets = tuple(int(part) for part in address.split('.'))
        except ValueError:
            octets = ()
        if len(octets) != 4 or not all(0 <= octet <= 255 for octet in octets):
            raise RuntimeError(name + " is not in the format ###.###.###.###!")
        return octets

    @classmethod
    def _new_config(cls, deviceType, queueLength, queueMode):
        """Return a BtaConfig initialised by BTAinitConfig with the common fields set.

        Raises:
            RuntimeError: if :meth:`init` has not loaded the library yet.
        """
        if cls.dll is None:
            raise RuntimeError("init must be called before opening a connection!")
        config = BtaConfig()
        cls.dll.BTAinitConfig(byref(config))
        config.deviceType.uint32 = deviceType
        config.frameQueueLength.uint16 = queueLength
        config.frameQueueMode.uint32 = queueMode
        # The info-event callback / verbosity are intentionally not applied:
        # registering the callback "sometimes causes access violation bug".
        # NOTE(review): a plausible cause is the InfoEventCb_Type object being
        # garbage-collected while the library still calls it; if re-enabled,
        # keep a reference to it (e.g. on cls) while the connection is open.
        # For debug output, the library also accepts:
        #   config.verbosity.uint8 = verbosity
        #   config.infoEventFilename.p_char = c_char_p("infoEvents.txt".encode('utf-8'))
        return config

    @classmethod
    def _open_with_config(cls, config):
        """Call BTAopen and store the handle only if the call succeeded.

        On failure ``btaHandle`` stays None, so a later open can be retried.
        """
        handle = c_void_p()
        status = cls.dll.BTAopen(byref(config), byref(handle))
        cls._check(status, "BTAopen")
        cls.btaHandle = handle

    @classmethod
    def openBltstream(cls, bltstreamFilename, queueLength, queueMode, infoEventHandler, verbosity):
        """Open a recorded bltstream file (deviceType 15) as the connection.

        ``infoEventHandler`` and ``verbosity`` are accepted but currently unused.

        Raises:
            RuntimeError: if a connection is already open or init was not called.
            BtaException: if BTAopen fails.
        """
        if cls.btaHandle is not None:
            raise RuntimeError("Open was already called! Call close before opening another connection!")

        config = cls._new_config(15, queueLength, queueMode)
        # The c_char_p is referenced by config, which outlives the BTAopen call.
        config.bltstreamFilename.p_char = c_char_p(bltstreamFilename.encode('utf-8'))
        cls._open_with_config(config)

    @classmethod
    def openEth(cls, udpDataIpAddr, udpDataPort, deviceIpAddr, udpControlPort, tcpControlPort, queueLength, queueMode, infoEventHandler, verbosity):
        """Open an Ethernet device (deviceType 1) as the connection.

        ``udpDataIpAddr``/``udpDataPort`` must be given together. When
        ``deviceIpAddr`` is given, exactly one of ``udpControlPort`` and
        ``tcpControlPort`` must be given. ``infoEventHandler`` and
        ``verbosity`` are accepted but currently unused.

        Raises:
            RuntimeError: on inconsistent or malformed arguments, if a
                connection is already open, or if init was not called.
            BtaException: if BTAopen fails.
        """
        if cls.btaHandle is not None:
            raise RuntimeError("Open was already called! Call close before opening another connection!")

        config = cls._new_config(1, queueLength, queueMode)

        if udpDataIpAddr is not None and udpDataPort is not None:
            config.udpDataIpAddr.p_uint8 = (c_uint8 * 4)(*cls._parse_ipv4(udpDataIpAddr, "udpDataIpAddr"))
            config.udpDataIpAddrLen.uint8 = 4
            config.udpDataPort.uint16 = udpDataPort
        elif udpDataIpAddr is not None or udpDataPort is not None:
            raise RuntimeError("udpDataIpAddr and udpDataPort must either both be given or both be None")

        if deviceIpAddr is not None:
            if udpControlPort is not None and tcpControlPort is not None:
                raise RuntimeError("Please use at most one of udpControlPort and tcpControlPort")
            if udpControlPort is None and tcpControlPort is None:
                raise RuntimeError("deviceIpAddr is given, but none of udpControlPort and tcpControlPort")
            deviceIp = (c_uint8 * 4)(*cls._parse_ipv4(deviceIpAddr, "deviceIpAddr"))
            if tcpControlPort is not None:
                config.tcpDeviceIpAddr.p_uint8 = deviceIp
                config.tcpDeviceIpAddrLen.uint8 = 4
                config.tcpControlPort.uint16 = tcpControlPort
            else:
                config.udpControlOutIpAddr.p_uint8 = deviceIp
                config.udpControlOutIpAddrLen.uint8 = 4
                config.udpControlPort.uint16 = udpControlPort
        elif udpControlPort is not None or tcpControlPort is not None:
            raise RuntimeError("With udpControlPort and tcpControlPort always also specify deviceIpAddr")

        cls._open_with_config(config)

    @classmethod
    def close(cls):
        """Close the open connection; does nothing if no connection is open.

        Raises:
            BtaException: if BTAclose fails (the handle is dropped regardless).
        """
        if cls.btaHandle is None:
            return
        status = cls.dll.BTAclose(byref(cls.btaHandle))
        cls.btaHandle = None
        cls._check(status, "BTAclose")

    @classmethod
    def set_frame_mode(cls, frameMode):
        """Set the frame mode; accepts a plain int or an Enum member (its value is used)."""
        if isinstance(frameMode, Enum):
            frameMode = frameMode.value
        status = cls.dll.BTAsetFrameMode(cls.btaHandle, c_uint32(frameMode))
        cls._check(status, "BTAsetFrameMode")

    @classmethod
    def read_register(cls, address, count):
        """Read up to *count* consecutive registers starting at *address*.

        Returns:
            A list of register values. The count is passed by reference; if
            the library reports fewer registers read, only those are returned.
            NOTE(review): assumes registerCount is an in/out parameter, as the
            length returned by write_register suggests.
        """
        if address < 0:
            raise RuntimeError("address must be positive")
        if count <= 0:
            raise RuntimeError("count must be > 0")
        register_count = c_uint32(count)
        buffer = (c_uint32 * count)()
        status = cls.dll.BTAreadRegister(cls.btaHandle, c_uint32(address), buffer, byref(register_count))
        cls._check(status, "BTAreadRegister")
        return list(buffer[:register_count.value])

    @classmethod
    def write_register(cls, address, data):
        """Write registers starting at *address*.

        *data* is a single python integer or a non-empty list/tuple of python
        integers. Returns the length value reported back by BTAwriteRegister.
        """
        if address < 0:
            raise RuntimeError("address must be positive")
        if isinstance(data, (list, tuple)):
            if len(data) == 0:
                raise RuntimeError("data must be provided")
            values = data
        else:
            # single integer: previously len(data) raised TypeError here
            values = (data,)
        length = c_uint32(len(values))
        buffer = (c_uint32 * len(values))(*values)
        status = cls.dll.BTAwriteRegister(cls.btaHandle, c_uint32(address), buffer, byref(length))
        cls._check(status, "BTAwriteRegister")
        return length.value

    @classmethod
    def get_lib_param(cls, id):
        """Return the float value of library parameter *id*."""
        value = c_float()
        status = cls.dll.BTAgetLibParam(cls.btaHandle, c_uint32(id), byref(value))
        cls._check(status, "BTAgetLibParam")
        return value.value

    @classmethod
    def set_lib_param(cls, id, value):
        """Set library parameter *id* to the float *value*."""
        status = cls.dll.BTAsetLibParam(cls.btaHandle, c_uint32(id), c_float(value))
        cls._check(status, "BTAsetLibParam")

    @classmethod
    def get_integration_time(cls):
        """Return the integration time as reported by BTAgetIntegrationTime."""
        integrationTime = c_uint32()
        status = cls.dll.BTAgetIntegrationTime(cls.btaHandle, byref(integrationTime))
        cls._check(status, "BTAgetIntegrationTime")
        return integrationTime.value

    @classmethod
    def set_integration_time(cls, integration_time):
        """Set the integration time via BTAsetIntegrationTime (uint32)."""
        status = cls.dll.BTAsetIntegrationTime(cls.btaHandle, c_uint32(integration_time))
        cls._check(status, "BTAsetIntegrationTime")

    @classmethod
    def get_modulation_frequency(cls):
        """Return the modulation frequency as reported by BTAgetModulationFrequency."""
        modulationFrequency = c_uint32()
        status = cls.dll.BTAgetModulationFrequency(cls.btaHandle, byref(modulationFrequency))
        cls._check(status, "BTAgetModulationFrequency")
        return modulationFrequency.value

    @classmethod
    def set_modulation_frequency(cls, modulation_frequency):
        """Set the modulation frequency via BTAsetModulationFrequency (uint32)."""
        status = cls.dll.BTAsetModulationFrequency(cls.btaHandle, c_uint32(modulation_frequency))
        cls._check(status, "BTAsetModulationFrequency")

    @classmethod
    def get_frame_rate(cls):
        """Return the frame rate (float) as reported by BTAgetFrameRate."""
        frameRate = c_float()
        status = cls.dll.BTAgetFrameRate(cls.btaHandle, byref(frameRate))
        cls._check(status, "BTAgetFrameRate")
        return frameRate.value

    @classmethod
    def set_frame_rate(cls, frame_rate):
        """Set the frame rate via BTAsetFrameRate (float)."""
        status = cls.dll.BTAsetFrameRate(cls.btaHandle, c_float(frame_rate))
        cls._check(status, "BTAsetFrameRate")

    @classmethod
    def get_frame(cls, timeout):
        """Fetch one frame (waiting up to *timeout*) and convert it to a PyBtaFrame.

        The native frame is always released, even if the conversion fails.
        """
        btaFrame = POINTER(BtaFrame)()
        status = cls.dll.BTAgetFrame(cls.btaHandle, byref(btaFrame), c_uint32(timeout))
        cls._check(status, "BTAgetFrame")
        try:
            return to_py_bta_frame(btaFrame)
        finally:
            cls.free_frame(btaFrame)

    @classmethod
    def free_frame(cls, frame):
        """Release a native frame previously obtained from BTAgetFrame."""
        status = cls.dll.BTAfreeFrame(byref(frame))
        cls._check(status, "BTAfreeFrame")

    @classmethod
    def generate_planar_view(cls, ch_x, ch_y, ch_z, ch_amp, res_x, res_y, planar_view_res_x, planar_view_res_y, planar_view_scale, planar_view_z, planar_view_amp):
        """Call BTAgeneratePlanarView and return its output buffers.

        C signature: void BTAgeneratePlanarView(int16_t *chX, int16_t *chY,
        int16_t *chZ, uint16_t *chAmp, int resX, int resY, int planarViewResX,
        int planarViewResY, float planarViewScale, int16_t *planarViewZ,
        uint16_t *planarViewAmp)

        All channel/output arguments must be lists or tuples. The results are
        returned as ``(planar_view_z, planar_view_amp)`` lists and are also
        written back in place into *planar_view_z* / *planar_view_amp* when
        those are lists.
        NOTE(review): the output sequences are presumably expected to hold
        planar_view_res_x * planar_view_res_y elements — confirm with the API.
        """
        for arg in (ch_x, ch_y, ch_z, ch_amp, planar_view_z, planar_view_amp):
            if not isinstance(arg, (list, tuple)):
                raise RuntimeError("must be list")
        chX = (c_int16 * len(ch_x))(*ch_x)
        chY = (c_int16 * len(ch_y))(*ch_y)
        chZ = (c_int16 * len(ch_z))(*ch_z)
        chAmp = (c_uint16 * len(ch_amp))(*ch_amp)
        planarViewZ = (c_int16 * len(planar_view_z))(*planar_view_z)
        planarViewAmp = (c_uint16 * len(planar_view_amp))(*planar_view_amp)
        cls.dll.BTAgeneratePlanarView(chX, chY, chZ, chAmp,
                                      c_int32(res_x), c_int32(res_y),
                                      c_int32(planar_view_res_x), c_int32(planar_view_res_y),
                                      # ctypes cannot pass a bare Python float
                                      c_float(planar_view_scale),
                                      planarViewZ, planarViewAmp)
        resultZ = list(planarViewZ)
        resultAmp = list(planarViewAmp)
        if isinstance(planar_view_z, list):
            planar_view_z[:] = resultZ
        if isinstance(planar_view_amp, list):
            planar_view_amp[:] = resultAmp
        return resultZ, resultAmp






def _channel_pixel_layout(dataFormat):
    """Return ``(ctypes element type, values per pixel)`` for a BTA data format.

    The element type drives both the size check against the channel's
    ``dataLen`` and the dtype of the resulting numpy array.

    Raises:
        NotImplementedError: for Jpeg channels, which are not decoded here.
    """
    if dataFormat == PyBtaDataFormat.UInt8:
        return c_uint8, 1
    if dataFormat in (PyBtaDataFormat.UInt16,
                      PyBtaDataFormat.UInt16Mlx1C11S,
                      PyBtaDataFormat.UInt16Mlx12S,
                      PyBtaDataFormat.UInt16Mlx1C11U,
                      PyBtaDataFormat.UInt16Mlx12U):
        return c_uint16, 1
    if dataFormat == PyBtaDataFormat.UInt32:
        return c_uint32, 1
    if dataFormat == PyBtaDataFormat.SInt16:
        return c_int16, 1
    if dataFormat == PyBtaDataFormat.SInt32:
        return c_int32, 1
    if dataFormat == PyBtaDataFormat.Float32:
        return c_float, 1
    if dataFormat == PyBtaDataFormat.Float64:
        return c_double, 1
    if dataFormat == PyBtaDataFormat.Rgb565:
        return c_uint16, 1
    if dataFormat == PyBtaDataFormat.Rgb24:
        # three uint8 samples per pixel, returned as one flat array
        return c_uint8, 3
    if dataFormat == PyBtaDataFormat.YUV422:
        return c_uint16, 1
    if dataFormat == PyBtaDataFormat.Jpeg:
        raise NotImplementedError("Jpeg channels are not supported")
    # Unrecognised formats fall back to one uint8 per pixel.
    return c_uint8, 1


def _to_py_bta_channel(channel):
    """Convert one native channel structure into a PyBtaChannel.

    The pixel data is copied into a Python-owned buffer and exposed as a flat
    numpy array, so the result stays valid after the native frame is freed.
    Metadata is not converted yet (``metadata`` is always None).
    """
    pyBtaChannel = PyBtaChannel()
    pyBtaChannel.id = PyBtaChannelId(channel.id)
    pyBtaChannel.xRes = channel.xRes
    pyBtaChannel.yRes = channel.yRes
    pyBtaChannel.dataFormat = channel.dataFormat
    pyBtaChannel.unit = channel.unit
    pyBtaChannel.integrationTime = channel.integrationTime
    pyBtaChannel.modulationFrequency = channel.modulationFrequency
    pyBtaChannel.data = None
    pyBtaChannel.metadata = None  # TODO: implement Metadata
    pyBtaChannel.lensIndex = channel.lensIndex
    pyBtaChannel.flags = channel.flags
    pyBtaChannel.sequenceCounter = channel.sequenceCounter
    pyBtaChannel.gain = channel.gain

    c_pixel_type, values_per_pixel = _channel_pixel_layout(pyBtaChannel.dataFormat)
    dataLen = channel.dataLen
    # allocate a ctypes array and copy data
    c_array = (c_pixel_type * (pyBtaChannel.xRes * pyBtaChannel.yRes * values_per_pixel))()
    if sizeof(c_array) != dataLen:
        BtaWrapper.throw("Input buffer size is not equal to output buffer size", PyBtaStatus.RuntimeError)
    memmove(c_array, channel.data, dataLen)
    # wrap the array as a numpy array (dtype follows the ctypes element type)
    pyBtaChannel.data = np.ctypeslib.as_array(c_array)
    return pyBtaChannel


def to_py_bta_frame(btaFrame):
    """Convert a native frame (a ctypes ``POINTER(BtaFrame)``) into a PyBtaFrame.

    Header fields are read from the structure and every channel is converted
    with its pixel data copied, so the caller may free the native frame
    afterwards.

    Raises:
        BtaException: if a channel's dataLen does not match its resolution/format.
        NotImplementedError: for Jpeg channels.
    """
    frame = btaFrame[0]
    pyBtaFrame = PyBtaFrame()
    pyBtaFrame.firmwareVersionMajor = frame.firmwareVersionMajor
    pyBtaFrame.firmwareVersionMinor = frame.firmwareVersionMinor
    pyBtaFrame.firmwareVersionNonFunc = frame.firmwareVersionNonFunc
    pyBtaFrame.mainTemp = frame.mainTemp
    pyBtaFrame.ledTemp = frame.ledTemp
    pyBtaFrame.genericTemp = frame.genericTemp
    pyBtaFrame.frameCounter = frame.frameCounter
    pyBtaFrame.timeStamp = frame.timeStamp
    pyBtaFrame.channels = [_to_py_bta_channel(frame.channels[i][0])
                           for i in range(frame.channelsLen)]
    return pyBtaFrame