#!/bin/python3
'''
    SPDX-FileCopyrightText: 2024 Agata Cacko <cacko.azh@gmail.com>

    This file is part of Fast Sketch Cleanup Plugin  for Krita

    SPDX-License-Identifier: GPL-3.0-or-later
'''

import numpy as np
from PyQt5.QtGui import QImage
from PyQt5.QtCore import QBuffer, QByteArray


from PIL import Image
import sys

def invert(data: np.ndarray) -> np.ndarray:
    """Return the complement ``1.0 - data`` (element-wise for arrays)."""
    complement = 1.0 - data
    return complement

def isNormalizedShapeNumpy(array: np.ndarray) -> bool:
    """Return True when ``array`` has the 4-D ``(1, 1, A, B)`` layout.

    Only the shape is inspected; the values are not looked at.
    """
    dims = array.shape
    return len(dims) == 4 and dims[:2] == (1, 1)

def isNormalizedNumpy(array: np.ndarray) -> bool:
    """Return True when ``array`` looks like normalized single-channel data.

    Requires the ``(1, 1, A, B)`` layout and values within ``[-1, 2]``. The
    value range is deliberately loose: it is meant to tell 0.0-1.0 data apart
    from 0-255 data, not to enforce an exact range.
    """
    dims = array.shape
    if not (len(dims) == 4 and dims[0] == 1 and dims[1] == 1):
        return False

    # A maximum above 2 most likely means a 0-255 range rather than 0.0-1.0.
    outOfRange = np.max(array) > 2 or np.min(array) < -1
    return not outOfRange
    



def ensureSizeDivisableBy(size, divisableBy) -> int:
    """Round ``size`` up to the next multiple of ``divisableBy``.

    A size that is already a multiple is returned unchanged, e.g. (10, 4) -> 12
    and (8, 4) -> 8. The quotient uses ``int(size / divisableBy)``, which
    truncates toward zero.
    """
    fullChunks = int(size / divisableBy)
    hasPartialChunk = int(size % divisableBy) > 0
    chunkCount = fullChunks + 1 if hasPartialChunk else fullChunks
    return int(chunkCount * divisableBy)


def clipOrExtendToSize(numpyArray: np.ndarray, expectedShape: tuple) -> np.ndarray:
    """Crop and/or pad a normalized array so its last two axes match ``expectedShape``.

    Only ``expectedShape[2]`` and ``expectedShape[3]`` determine the result. The
    full tuple is used only for the "already the right shape" fast path.
    Missing entries along axes 2/3 are filled by repeating the last
    row/column (``np.pad`` in ``'edge'`` mode). Surplus entries are cut off from
    the end.

    Args:
        numpyArray: array accepted by ``isNormalizedNumpy``, i.e. shape
            ``(1, 1, A, B)`` with values within ``[-1, 2]``.
        expectedShape: target shape; entries 2 and 3 give the target sizes.

    Returns:
        ``numpyArray`` itself if its shape already equals ``expectedShape``,
        otherwise the padded and/or cropped array.

    Raises:
        AssertionError: if ``numpyArray`` is not normalized. It is raised
            explicitly rather than through ``assert``, so the check is not
            stripped under ``python -O``, where the function would otherwise
            silently return ``None``.
    """
    if not isNormalizedNumpy(numpyArray):
        # np.max / np.min raise on zero-size arrays, so only report them when there is data.
        if numpyArray.size > 0:
            valueRange = f"max = {np.max(numpyArray)}, min = {np.min(numpyArray)}"
        else:
            valueRange = "array is empty"
        raise AssertionError(
            f"Only use clipOrExtendToSize on normalized arrays, current shape: {numpyArray.shape}, "
            f"expectedShape was: {expectedShape}, {valueRange}"
        )

    if numpyArray.shape != expectedShape:
        targetDim2 = expectedShape[2]
        targetDim3 = expectedShape[3]
        missingDim2 = max(0, targetDim2 - numpyArray.shape[2])
        missingDim3 = max(0, targetDim3 - numpyArray.shape[3])

        if missingDim2 or missingDim3:
            # One pad call over both axes; 'edge' mode gives the same values as
            # padding axis 2 and then axis 3 separately.
            numpyArray = np.pad(
                numpyArray,
                ((0, 0), (0, 0), (0, missingDim2), (0, missingDim3)),
                'edge',
            )

        # Crop any surplus (a no-op along axes that were just padded to size).
        numpyArray = numpyArray[:, :, 0:targetDim2, 0:targetDim3]

    return numpyArray


def cutToSamples(numpyArray: np.ndarray, expectedShape: tuple, samplesCount: int):
    """Split a normalized array into a ``samplesCount`` x ``samplesCount`` grid of tiles.

    The input is first cropped/padded (via ``clipOrExtendToSize``) to exactly
    ``samplesCount`` tiles of size ``expectedShape[2]`` x ``expectedShape[3]``
    along axes 2 and 3.

    Returns:
        ``(tiles, fullExpectedShape)``. ``tiles`` is a list of ``(i, j, tile)``
        in row-major grid order, where ``tile`` is a slice (numpy view) of the
        resized array. ``fullExpectedShape`` is the shape the input was resized to.
    """
    tileDim2 = expectedShape[2]
    tileDim3 = expectedShape[3]
    fullExpectedShape = (1, 1, tileDim2*samplesCount, tileDim3*samplesCount)
    resized = clipOrExtendToSize(numpyArray, fullExpectedShape)

    tiles = [
        (i, j, resized[:, :, i*tileDim2:(i + 1)*tileDim2, j*tileDim3:(j + 1)*tileDim3])
        for i in range(samplesCount)
        for j in range(samplesCount)
    ]
    return (tiles, fullExpectedShape)




def extendToBeDivisable(numpyArray: np.ndarray, divisableBy: int) -> np.ndarray:
    """Pad a ``(1, 1, A, B)`` array so that A and B become multiples of ``divisableBy``.

    Padding repeats the edge values (see ``clipOrExtendToSize``). Arrays that do
    not have the ``(1, 1, A, B)`` layout are returned unchanged.

    Raises:
        AssertionError: propagated from ``clipOrExtendToSize`` when the values
            are not normalized.
    """
    if not isNormalizedShapeNumpy(numpyArray):
        return numpyArray

    # Use the real (1, 1, ...) target shape: it compares equal to the input's
    # shape when no padding is needed, and it shows up correctly in error messages.
    targetShape = (
        1,
        1,
        ensureSizeDivisableBy(numpyArray.shape[2], divisableBy),
        ensureSizeDivisableBy(numpyArray.shape[3], divisableBy),
    )
    return clipOrExtendToSize(numpyArray, targetShape)
    



def byteArrayToString(data, maxBytes=20):
    """Format up to ``maxBytes`` leading elements of ``data`` for debug output.

    Each element is rendered with ``str`` and followed by a single space, so a
    non-empty result ends with a trailing space. If ``data`` holds fewer than
    ``maxBytes`` elements, all of them are shown instead of raising ``IndexError``.

    Args:
        data: any indexable sequence (bytes, bytearray, numpy array, ...).
        maxBytes: maximum number of elements to include.
    """
    parts = []
    for i in range(maxBytes):
        try:
            value = data[i]
        except IndexError:
            # Shorter than maxBytes: show what there is rather than crash.
            break
        parts.append(f"{value} ")
    # join avoids the quadratic cost of repeated string concatenation.
    return "".join(parts)

def numpyArrayShortened(data: np.ndarray, maxBytes = 20):
    """Return the first ``maxBytes`` values of ``data`` (flattened) as a debug string.

    Each value is followed by a single space. Arrays with fewer than
    ``maxBytes`` elements are shown in full.
    """
    flat = data.reshape((-1))
    shown = min(maxBytes, len(flat))
    return "".join(f"{flat[index]} " for index in range(shown))


def convertImageToNumpy(data, width, height, strides):
    """Turn a raw 8-bit single-channel pixel buffer into a normalized float array.

    Args:
        data: buffer of ``uint8`` pixels, ``strides[1]`` scanlines of
            ``strides[0]`` bytes each.
        width: number of visible pixels per scanline; padding beyond it is dropped.
        height: number of visible scanlines.
        strides: ``[bytesPerLine, rowCount, channelCount]``. This is how
            ``convertQImageToNumpy`` in this file calls it. ``channelCount``
            must be 1, because that axis is squeezed away.

    Returns:
        ``float32`` array of shape ``(1, 1, height, width)`` with values in 0.0-1.0.
    """
    bytesPerLine, rowCount, channelCount = strides[0], strides[1], strides[2]
    raw = np.frombuffer(data, dtype=np.uint8).reshape((rowCount, bytesPerLine, channelCount))

    # Keep only the visible region (drops scanline padding), as float32.
    pixels = raw[0:height, 0:width].astype(dtype=np.float32)

    # (rows, cols, 1) -> (1, 1, rows, cols, 1) -> (1, 1, rows, cols)
    pixels = np.squeeze(pixels[np.newaxis, np.newaxis], 4)
    return pixels / 255.0




def saveAsImage(data, filename):
    """Write ``data`` to ``filename`` via Pillow (``Image.fromarray`` then ``save``)."""
    Image.fromarray(data).save(filename)

def convertOutputToRGBANumpy(output):
    """Turn a ``(1, 1, A, B)`` model output into an opaque grey RGBA ``uint8`` image.

    Values are scaled by 255 and clipped to 0..255. The float -> uint8 store
    truncates. The same grey value is written to R, G and B, and alpha is
    always 255.

    NOTE(review): a comment here previously spoke of black pixels with the grey
    value encoding ``1 - transparency``. The code produces opaque grey instead;
    confirm which is intended.

    Returns:
        ``np.uint8`` array of shape ``(output.shape[2], output.shape[3], 4)``.
    """
    # Drop the two leading singleton axes, scaling and clipping in between.
    grey = np.clip(output.squeeze(0) * 255.0, 0, 255).squeeze(0)

    rgba = np.zeros((output.shape[2], output.shape[3], 4), dtype=np.uint8)
    rgba[:, :, :3] = grey[:, :, np.newaxis]  # same grey in R, G and B
    rgba[:, :, 3] = 255                      # fully opaque
    return rgba

def convertOutputRGBAToLayerData(outputRGBA):
    """Serialize an RGBA array to raw bytes in C (row-major) order via ``tobytes``."""
    layerBytes = outputRGBA.tobytes()
    return layerBytes


def convertOutputToLayerData(output):
    """Convert a ``(1, 1, A, B)`` model output straight to raw RGBA layer bytes."""
    rgba = convertOutputToRGBANumpy(output)
    return convertOutputRGBAToLayerData(rgba)



def convertNumpyToPillow(numpyInput: np.ndarray) -> Image.Image:
    """Convert a ``(1, 1, A, B)`` array (presumably 0.0-1.0) into an 8-bit greyscale Pillow image.

    Values are scaled by 255, clipped to 0..255 and truncated to ``uint8``.

    Returns:
        A mode ``"L"`` image. For inputs that are not 4-D, an error is printed
        and an empty (0 x 0) mode ``"L"`` image is returned.
    """
    if len(numpyInput.shape) == 4:
        squeezed = np.squeeze(np.squeeze(numpyInput, 0), 0)
        squeezed = squeezed*255
        squeezed = squeezed.clip(0, 255)
        # Must be unsigned: casting 128..255 floats to int8 is out of range,
        # which numpy treats as undefined (and warns about since 1.24).
        squeezed = squeezed.astype(dtype=np.uint8)

        pillowImage = Image.fromarray(squeezed, mode = "L")
        return pillowImage
    else:
        print(f"ERROR: convertNumpyToPillow: The shape of numpy array is: {numpyInput.shape}")
        # `Image` is the PIL module and is not callable; build an empty image explicitly.
        return Image.new("L", (0, 0))

def convertPillowToNumpy(pillowImage) -> np.ndarray:
    """Convert a single-channel 8-bit image into a ``(1, 1, rows, cols)`` float array in 0.0-1.0.

    Anything ``np.asarray`` accepts with a 2-D ``(rows, cols)`` shape works,
    e.g. a mode ``"L"`` Pillow image.
    """
    pixels = np.asarray(pillowImage)
    rows, cols = pixels.shape[0], pixels.shape[1]
    scaled = pixels.reshape((1, 1, rows, cols)).astype(dtype=float) / 255.0
    return scaled.clip(0.0, 1.0)


def convertQImageToNumpy(image):
    """Convert a QImage into the normalized ``(1, 1, rows, cols)`` float layout.

    The pixels are converted to ``QImage.Format_Grayscale8`` on a copy, so
    ``image`` itself is left unmodified. They are then passed to
    ``convertImageToNumpy``, which copies them into a new float array.

    Returns:
        ``(array, width, height)``: the array from ``convertImageToNumpy`` and
        the QImage width and height in pixels.
    """
    # convertToFormat returns a new image. QImage.convertTo (Qt >= 5.13 only)
    # would instead convert the caller's image in place as a side effect.
    grayscale = image.convertToFormat(QImage.Format_Grayscale8)
    width = grayscale.width()
    height = grayscale.height()

    bytesPerLine = grayscale.bytesPerLine()
    sizeInBytes = grayscale.sizeInBytes()

    pixelBuffer = grayscale.bits()
    pixelBuffer.setsize(sizeInBytes)

    # Scanline count derived from the buffer size, as convertImageToNumpy expects
    # strides = [bytesPerLine, rowCount, channelCount].
    rowCount = sizeInBytes // bytesPerLine
    array = convertImageToNumpy(pixelBuffer, width, height, [bytesPerLine, rowCount, 1])
    return (array, width, height)


def convertNumpyToQImage(numpyArray) -> QImage:
    """Build a QImage from a numpy array.

    Supported layouts:

    * ``(1, 1, rows, cols)``: the normalized layout used elsewhere in this
      file, presumably 0.0-1.0 values. Values are scaled by 255, clipped to
      0..255 and stored as ``Format_Grayscale8`` with width ``cols`` and
      height ``rows``, backed by a bytes copy of the data.
    * 3-D ``(a, b, 4)``: wrapped as ``Format_RGBA8888`` with width ``a`` and
      height ``b``.
    * 3-D ``(a, b, 1)``: wrapped as ``Format_Grayscale8`` with width ``a`` and
      height ``b``.

    NOTE(review): the 3-D branches pass ``shape[0]`` as the QImage width and no
    bytesPerLine. Confirm that callers lay their arrays out that way.

    Any other shape prints an error and returns a null ``QImage()``.
    """
    shape = numpyArray.shape
    preview = QImage()

    # Test the 4-D layout first. For (1, 1, rows, cols) arrays shape[2] is the
    # row count, which must not be mistaken for a channel count of 1 or 4.
    if len(shape) == 4 and shape[0] == 1 and shape[1] == 1:
        gray = numpyArray.squeeze(0).squeeze(0)
        rows, cols = gray.shape
        # Clip *before* casting: converting out-of-range floats to uint8 is
        # undefined in numpy and would otherwise wrap (e.g. 1.1 * 255).
        gray = np.clip(gray * 255, 0, 255).astype(np.uint8)
        # Row-major bytes: `rows` scanlines of `cols` bytes each.
        preview = QImage(gray.tobytes(), cols, rows, cols, QImage.Format_Grayscale8)
    elif len(shape) == 3 and shape[2] == 4:
        preview = QImage(numpyArray.data, shape[0], shape[1], QImage.Format_RGBA8888)
    elif len(shape) == 3 and shape[2] == 1:
        preview = QImage(numpyArray.data, shape[0], shape[1], QImage.Format_Grayscale8)
    else:
        channels = shape[2] if len(shape) >= 3 else None
        print(f"Error in convertNumpyToQImage: Number of channels: {channels}, shape = {numpyArray.shape}")

    return preview
