package utils

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"fmt"
)

// AesEncrypt encrypts rawData with AES in CBC mode and returns the ciphertext.
//
// key is handed to aes.NewCipher, which accepts 16, 24 or 32 byte keys
// (AES-128/192/256); its error is returned unchanged for any other length.
// The plaintext is PKCS#5/PKCS#7 padded before encryption, so the result is
// len(rawData) rounded up to the next multiple of the block size, with a full
// extra block when the input is already aligned.
//
// The first blockSize bytes of key are used as the IV. This matches what
// AesDecrypt expects, but it means the same plaintext under the same key
// always produces the same ciphertext.
//
// rawData is never modified and the returned slice does not alias it.
func AesEncrypt(rawData, key []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	blockSize := block.BlockSize()

	// Build the padded plaintext in a fresh buffer. Appending the padding to
	// rawData itself could write into spare capacity of the caller's backing
	// array and clobber whatever the caller keeps after the slice.
	padLen := blockSize - len(rawData)%blockSize
	buf := make([]byte, len(rawData)+padLen)
	copy(buf, rawData)
	for i := len(rawData); i < len(buf); i++ {
		buf[i] = byte(padLen)
	}

	// CryptBlocks allows dst and src to overlap entirely, so the private
	// buffer is encrypted in place instead of allocating a second one.
	cipher.NewCBCEncrypter(block, key[:blockSize]).CryptBlocks(buf, buf)
	return buf, nil
}

// AesDecrypt reverses AesEncrypt: it decrypts encrypted with AES in CBC mode,
// using the first blockSize bytes of key as the IV, and strips the trailing
// PKCS#5/PKCS#7 padding.
//
// An error is returned, rather than a panic raised, when
//   - key is rejected by aes.NewCipher,
//   - encrypted is empty or not a whole number of cipher blocks, or
//   - the pad count in the last decrypted byte is larger than the data.
//
// The padding check only guarantees the result can be sliced safely; it is not
// an integrity check, so a wrong key can still yield garbage without an error.
// encrypted is not modified.
func AesDecrypt(encrypted, key []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	blockSize := block.BlockSize()

	// CryptBlocks panics on partial blocks, and an empty input leaves no
	// padding byte to read, so reject both up front.
	if len(encrypted) == 0 || len(encrypted)%blockSize != 0 {
		return nil, fmt.Errorf("aes decrypt: ciphertext length %d is not a positive multiple of block size %d", len(encrypted), blockSize)
	}

	rawData := make([]byte, len(encrypted))
	cipher.NewCBCDecrypter(block, key[:blockSize]).CryptBlocks(rawData, encrypted)

	// The last byte says how many padding bytes to drop.
	padLen := int(rawData[len(rawData)-1])
	if padLen > len(rawData) {
		return nil, fmt.Errorf("aes decrypt: padding length %d exceeds data length %d", padLen, len(rawData))
	}
	return rawData[:len(rawData)-padLen], nil
}

// ZeroPadding returns a copy of cipherText extended with 0x00 bytes so that its
// length is a multiple of blockSize.
//
// Between 1 and blockSize bytes are always added, so input that is already
// aligned grows by one whole block. blockSize must be positive.
//
// The result is always a freshly allocated slice. Appending to cipherText
// directly could write into spare capacity of the caller's backing array.
func ZeroPadding(cipherText []byte, blockSize int) []byte {
	padding := blockSize - len(cipherText)%blockSize
	padText := bytes.Repeat([]byte{0}, padding)

	padded := make([]byte, 0, len(cipherText)+padding)
	padded = append(padded, cipherText...)
	return append(padded, padText...)
}

// ZeroUnPadding undoes ZeroPadding by removing every trailing 0x00 byte from
// rawData.
//
// Zero padding stores no length, so the trailing zeros themselves are the only
// marker: plaintext that legitimately ends in 0x00 loses those bytes as well.
// Empty input yields an empty result. The returned slice shares rawData's
// backing array.
func ZeroUnPadding(rawData []byte) []byte {
	// The last byte of zero-padded data is always 0, so it cannot be read as a
	// pad count the way PKCS#5 does; trim the zero run instead.
	return bytes.TrimRight(rawData, "\x00")
}

// PKCS5Padding returns a copy of cipherText with PKCS#5/PKCS#7 padding added:
// n bytes, each of value n, where n brings the length up to a multiple of
// blockSize.
//
// n is always between 1 and blockSize, so input that is already aligned grows
// by one whole block. blockSize must be positive, and the pad value is stored
// in a single byte, so block sizes above 255 cannot be represented.
//
// The result is always a freshly allocated slice. Appending to cipherText
// directly could write into spare capacity of the caller's backing array.
func PKCS5Padding(cipherText []byte, blockSize int) []byte {
	padding := blockSize - len(cipherText)%blockSize
	padText := bytes.Repeat([]byte{byte(padding)}, padding)

	padded := make([]byte, 0, len(cipherText)+padding)
	padded = append(padded, cipherText...)
	return append(padded, padText...)
}

// PKCS5UnPadding strips PKCS#5/PKCS#7 padding: it reads the last byte of
// rawData as the pad count n and returns rawData without its final n bytes.
//
// The function has no error result, so malformed input is reported as a nil
// slice instead of a panic: nil is returned when rawData is empty or when n is
// larger than len(rawData). The count is not validated beyond that. A
// successful result shares rawData's backing array; a message that consisted
// only of padding comes back as an empty, non-nil slice.
func PKCS5UnPadding(rawData []byte) []byte {
	length := len(rawData)
	if length == 0 {
		return nil
	}
	// Drop the final unPadding bytes, where unPadding is the last byte's value.
	unPadding := int(rawData[length-1])
	if unPadding > length {
		return nil
	}
	return rawData[:length-unPadding]
}

// zeroFill normalizes *key in place to a length of 16, 24 or 32 bytes.
//
// A key that already has one of those lengths is left alone. A shorter key is
// extended with ASCII '0' characters up to the next of those lengths, and a
// key longer than 32 bytes is cut down to its first 32 bytes.
func zeroFill(key *string) {
	size := len(*key)

	var target int
	switch {
	case size == 16, size == 24, size == 32:
		return
	case size > 32:
		*key = (*key)[:32]
		return
	case size < 16:
		target = 16
	case size < 24:
		target = 24
	default:
		target = 32
	}

	// "%0*d" with value 0 renders exactly target-size '0' characters.
	*key += fmt.Sprintf("%0*d", target-size, 0)
}

// AesCrypt holds the key and initialization vector used by its Encrypt and
// Decrypt methods, which run AES in CBC mode.
type AesCrypt struct {
	// Key is passed to aes.NewCipher by both Encrypt and Decrypt.
	Key []byte
	// Iv is passed as the IV to the CBC encrypter and decrypter.
	Iv  []byte
}

// Encrypt pads data with PKCS#5/PKCS#7 padding and encrypts it with AES in CBC
// mode using a.Key and a.Iv.
//
// It returns the error from aes.NewCipher when a.Key is not accepted, and an
// error when a.Iv is not exactly one cipher block long (the CBC constructor
// would panic in that case). Errors are only returned, never printed. data is
// not modified.
func (a *AesCrypt) Encrypt(data []byte) ([]byte, error) {
	aesBlockEncrypt, err := aes.NewCipher(a.Key)
	if err != nil {
		return nil, err
	}
	blockSize := aesBlockEncrypt.BlockSize()
	if len(a.Iv) != blockSize {
		return nil, fmt.Errorf("aes encrypt: IV length %d does not match block size %d", len(a.Iv), blockSize)
	}

	// Cap the slice at its length so the padding helper's append has to
	// allocate; otherwise it could write the pad bytes into spare capacity of
	// the caller's backing array.
	content := pKCS5Padding(data[:len(data):len(data)], blockSize)
	cipherBytes := make([]byte, len(content))
	aesEncrypt := cipher.NewCBCEncrypter(aesBlockEncrypt, a.Iv)
	aesEncrypt.CryptBlocks(cipherBytes, content)
	return cipherBytes, nil
}

// Decrypt decrypts src with AES in CBC mode using a.Key and a.Iv and removes
// the trailing PKCS#5/PKCS#7 padding.
//
// An error is returned, rather than a panic raised, when
//   - a.Key is rejected by aes.NewCipher,
//   - a.Iv is not exactly one cipher block long,
//   - src is empty or not a whole number of cipher blocks, or
//   - the pad count in the last decrypted byte is larger than the data.
//
// Errors are only returned, never printed. The padding check only keeps the
// trimming in range; it is not an integrity check. src is not modified.
func (a *AesCrypt) Decrypt(src []byte) (data []byte, err error) {
	aesBlockDecrypt, err := aes.NewCipher(a.Key)
	if err != nil {
		return nil, err
	}
	blockSize := aesBlockDecrypt.BlockSize()
	if len(a.Iv) != blockSize {
		return nil, fmt.Errorf("aes decrypt: IV length %d does not match block size %d", len(a.Iv), blockSize)
	}
	// CryptBlocks panics on partial blocks, and an empty input leaves no
	// padding byte to read.
	if len(src) == 0 || len(src)%blockSize != 0 {
		return nil, fmt.Errorf("aes decrypt: ciphertext length %d is not a positive multiple of block size %d", len(src), blockSize)
	}

	decrypted := make([]byte, len(src))
	aesDecrypt := cipher.NewCBCDecrypter(aesBlockDecrypt, a.Iv)
	aesDecrypt.CryptBlocks(decrypted, src)

	// Make sure the pad count can actually be removed before trimming.
	if padding := int(decrypted[len(decrypted)-1]); padding > len(decrypted) {
		return nil, fmt.Errorf("aes decrypt: padding length %d exceeds data length %d", padding, len(decrypted))
	}
	return pKCS5Trimming(decrypted), nil
}

// pKCS5Padding returns a copy of cipherText with PKCS#5/PKCS#7 padding added:
// n bytes, each of value n, where n brings the length up to a multiple of
// blockSize.
//
// n is always between 1 and blockSize, so aligned input grows by one whole
// block. blockSize must be positive.
//
// The result is always a freshly allocated slice. Appending to cipherText
// directly could write into spare capacity of the caller's backing array.
func pKCS5Padding(cipherText []byte, blockSize int) []byte {
	padding := blockSize - len(cipherText)%blockSize
	padText := bytes.Repeat([]byte{byte(padding)}, padding)

	padded := make([]byte, 0, len(cipherText)+padding)
	padded = append(padded, cipherText...)
	return append(padded, padText...)
}

// pKCS5Trimming strips PKCS#5/PKCS#7 padding: it reads the last byte of encrypt
// as the pad count n and returns encrypt without its final n bytes.
//
// The function has no error result, so malformed input is reported as a nil
// slice instead of a panic: nil is returned when encrypt is empty or when n is
// larger than len(encrypt). A successful result shares encrypt's backing array.
func pKCS5Trimming(encrypt []byte) []byte {
	size := len(encrypt)
	if size == 0 {
		return nil
	}
	padding := int(encrypt[size-1])
	if padding > size {
		return nil
	}
	return encrypt[:size-padding]
}