From aefd2b198f531cccdc026b1c4a9552511eb86e54 Mon Sep 17 00:00:00 2001 From: dengbiao Date: Tue, 19 Nov 2024 21:16:28 +0800 Subject: [PATCH] =?UTF-8?q?add=20aes=20=E5=8A=A0=E5=AF=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/lib/aes/check.go | 26 +++++++ app/lib/aes/get_ip.go | 84 +++++++++++++++++++++ app/lib/aes/md/md_key.go | 6 ++ app/lib/aes/sign.go | 26 +++++++ app/lib/aes/utils.go | 74 +++++++++++++++++++ app/mw/mw_check_sign.go | 37 ++++++++++ app/svc/svc_common.go | 10 +++ app/svc/svc_go_routine.go | 72 ++++++++++++++++++ app/svc/svc_redis_mutex_lock.go | 100 +++++++++++++++++++++++++ app/svc/svc_user_invitecode.go | 70 ++++++++++++++++++ app/utils/sign_check.go | 125 -------------------------------- 11 files changed, 505 insertions(+), 125 deletions(-) create mode 100644 app/lib/aes/check.go create mode 100644 app/lib/aes/get_ip.go create mode 100644 app/lib/aes/md/md_key.go create mode 100644 app/lib/aes/sign.go create mode 100644 app/lib/aes/utils.go create mode 100644 app/mw/mw_check_sign.go create mode 100644 app/svc/svc_common.go create mode 100644 app/svc/svc_go_routine.go create mode 100644 app/svc/svc_redis_mutex_lock.go create mode 100644 app/svc/svc_user_invitecode.go delete mode 100644 app/utils/sign_check.go diff --git a/app/lib/aes/check.go b/app/lib/aes/check.go new file mode 100644 index 0000000..33e800b --- /dev/null +++ b/app/lib/aes/check.go @@ -0,0 +1,26 @@ +package aes + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" +) + +func CheckBody(request *http.Request, key string) { + body, _ := ioutil.ReadAll(request.Body) + if string(body) != "" { + fmt.Println("check_", string(body)) + str := AesDecryptByECB(key, string(body)) + if str != "" { + request.Body = ioutil.NopCloser(bytes.NewBuffer([]byte(str))) + } + } + return +} + +func CheckError(err error) { + if err != nil { + panic(err) + } +} diff --git a/app/lib/aes/get_ip.go b/app/lib/aes/get_ip.go new file mode 100644 index 0000000..be456d2 --- /dev/null +++ b/app/lib/aes/get_ip.go @@ -0,0 +1,84 @@ +package aes + +import ( + "applet/app/cfg" + "applet/app/md" + "applet/app/utils" + "code.fnuoos.com/go_rely_warehouse/zyos_go_mq.git/rabbit" + "fmt" + "net" + "net/http" + "strings" +) + +func GetIpAddress(request *http.Request, masterId, pvd string) { + //geoIp2db, _ := geoip2db.NewGeoipDbByStatik() + //defer geoIp2db.Close() + ip := GetRemoteClientIp(request) + if ip != "" { + //record, _ := geoIp2db.City(net.ParseIP(ip)) + //if record.Country.Names != nil && record.Subdivisions != nil && record.City.Names != nil { + + // aesData := fmt.Sprintf( + // `{"country":"%s","province":"%s","city":"%s","ip":"%s", "master_id":"%s", "url":"%s", "pvd":"%s"}`, + // record.Country.Names["zh-CN"], record.Subdivisions[0].Names["zh-CN"], record.City.Names["zh-CN"], ip, masterId, request.Host, pvd) + + aesData := fmt.Sprintf( + `{"ip":"%s", "master_id":"%s", "url":"%s", "pvd":"%s"}`, + ip, masterId, request.Host, pvd) + fmt.Println(aesData) + if cfg.Prd { + //TODO::正式环境,推入rabbitMq + ch, _ := rabbit.Cfg.Pool.GetChannel() + defer ch.Release() + ch.Publish(md.UserVisitIpAddress, aesData, "") + } + //} + } + + return +} + +//func GetIp() (ip string) { +// conn, _ := net.Dial("udp", "google.com:80") +// if err != nil { +// return +// } +// defer conn.Close() +// ip = strings.Split(conn.LocalAddr().String(), ":")[0] +// return +//} + +// GetRemoteClientIp 获取远程客户端IP +func GetRemoteClientIp(r *http.Request) string { + remoteIp := r.RemoteAddr + if ip := r.Header.Get("X-Real-IP"); ip != "" { + if strings.Contains(ip, "172.20") { + ips := strings.Split(r.Header.Get("X-Forwarded-For"), ",") + if len(ips) == 2 { + remoteIp = ips[0] + } + } else { + remoteIp = ip + } + } else if ip = r.Header.Get("X-Forwarded-For"); ip != "" { + remoteIp = ip + } else { + remoteIp, _, _ = net.SplitHostPort(remoteIp) + //remoteIp = "" + } + + //本地ip + if remoteIp == "127.0.0.1" || remoteIp == "::1" { + remoteIp = "221.4.210.1651" + } + + if r.Host == "h5.99813608.zhiyingos.com" { + utils.FilePutContents("GetRemoteClientIp", utils.SerializeStr(map[string]interface{}{ + "header": r.Header, + "remoteIp": remoteIp, + })) + } + fmt.Println(">>>>>>>>>>>>>>>>>>>>GetRemoteClientIp<<<<<<<<<<<<<<<<<<", remoteIp) + return remoteIp +} diff --git a/app/lib/aes/md/md_key.go b/app/lib/aes/md/md_key.go new file mode 100644 index 0000000..9b13116 --- /dev/null +++ b/app/lib/aes/md/md_key.go @@ -0,0 +1,6 @@ +package md + +const AesKey = "zhiyingos@qq.com" + +const LimiterRequestIdPrefix = "request_domain_limiter_request_id:%s" +const LimiterLock = "request_domain_limiter_lock:%s" // 限流器锁 diff --git a/app/lib/aes/sign.go b/app/lib/aes/sign.go new file mode 100644 index 0000000..33e800b --- /dev/null +++ b/app/lib/aes/sign.go @@ -0,0 +1,26 @@ +package aes + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" +) + +func CheckBody(request *http.Request, key string) { + body, _ := ioutil.ReadAll(request.Body) + if string(body) != "" { + fmt.Println("check_", string(body)) + str := AesDecryptByECB(key, string(body)) + if str != "" { + request.Body = ioutil.NopCloser(bytes.NewBuffer([]byte(str))) + } + } + return +} + +func CheckError(err error) { + if err != nil { + panic(err) + } +} diff --git a/app/lib/aes/utils.go b/app/lib/aes/utils.go new file mode 100644 index 0000000..dcb648c --- /dev/null +++ b/app/lib/aes/utils.go @@ -0,0 +1,74 @@ +package aes + +import ( + "bytes" + "crypto/aes" + "encoding/base64" +) + +// 加密 +func AesEncryptByECB(key, data string) string { + // 判断key长度 + keyLenMap := map[int]struct{}{16: {}, 24: {}, 32: {}} + if _, ok := keyLenMap[len(key)]; !ok { + panic("key长度必须是 16、24、32 其中一个") + } + // 密钥和待加密数据转成[]byte + originByte := []byte(data) + keyByte := []byte(key) + // 创建密码组,长度只能是16、24、32 字节 + block, _ := aes.NewCipher(keyByte) + // 获取密钥长度 + blockSize := block.BlockSize() + // 补码 + originByte = PKCS7Padding(originByte, blockSize) + // 创建保存加密变量 + encryptResult := make([]byte, len(originByte)) + // CEB是把整个明文分成若干段相同的小段,然后对每一小段进行加密 + for bs, be := 0, blockSize; bs < len(originByte); bs, be = bs+blockSize, be+blockSize { + block.Encrypt(encryptResult[bs:be], originByte[bs:be]) + } + res := base64.StdEncoding.EncodeToString(encryptResult) + return res +} + +// 补码 +func PKCS7Padding(originByte []byte, blockSize int) []byte { + // 计算补码长度 + padding := blockSize - len(originByte)%blockSize + // 生成补码 + padText := bytes.Repeat([]byte{byte(padding)}, padding) + // 追加补码 + return append(originByte, padText...) +} + +// 解密 +func AesDecryptByECB(key, data string) string { + // 判断key长度 + keyLenMap := map[int]struct{}{16: {}, 24: {}, 32: {}} + if _, ok := keyLenMap[len(key)]; !ok { + panic("key长度必须是 16、24、32 其中一个") + } + // 反解密码base64 + originByte, _ := base64.StdEncoding.DecodeString(data) + // 密钥和待加密数据转成[]byte + keyByte := []byte(key) + // 创建密码组,长度只能是16、24、32字节 + block, _ := aes.NewCipher(keyByte) + // 获取密钥长度 + blockSize := block.BlockSize() + // 创建保存解密变量 + decrypted := make([]byte, len(originByte)) + for bs, be := 0, blockSize; bs < len(originByte); bs, be = bs+blockSize, be+blockSize { + block.Decrypt(decrypted[bs:be], originByte[bs:be]) + } + // 解码 + return string(PKCS7UNPadding(decrypted)) +} + +// 解码 +func PKCS7UNPadding(originDataByte []byte) []byte { + length := len(originDataByte) + unpadding := int(originDataByte[length-1]) + return originDataByte[:(length - unpadding)] +} diff --git a/app/mw/mw_check_sign.go b/app/mw/mw_check_sign.go new file mode 100644 index 0000000..599c512 --- /dev/null +++ b/app/mw/mw_check_sign.go @@ -0,0 +1,37 @@ +package mw + +import ( + "applet/app/e" + "applet/app/utils" + "bytes" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io/ioutil" +) + +// CheckSign is 中间件 用来检查签名 +func CheckSign(c *gin.Context) { + if utils.SignCheck(c) == false { + e.OutErr(c, 400, errors.New("请求失败~~")) + return + } + c.Next() +} + +func CheckBody(c *gin.Context) { + c.Set("api_version", "1") + if utils.GetApiVersion(c) > 0 { + body, _ := ioutil.ReadAll(c.Request.Body) + fmt.Println("check_", c.GetString("mid"), string(body)) + if string(body) != "" { + str := utils.ResultAesDecrypt(c, string(body)) + fmt.Println("check_de", c.GetString("mid"), str) + if str != "" { + c.Set("body_str", str) + c.Request.Body = ioutil.NopCloser(bytes.NewBuffer([]byte(str))) + } + } + } + c.Next() +} diff --git a/app/svc/svc_common.go b/app/svc/svc_common.go new file mode 100644 index 0000000..358a00e --- /dev/null +++ b/app/svc/svc_common.go @@ -0,0 +1,10 @@ +package svc + +import "applet/app/utils/logx" + +// 简单的recover +func Rev() { + if err := recover(); err != nil { + _ = logx.Error(err) + } +} diff --git a/app/svc/svc_go_routine.go b/app/svc/svc_go_routine.go new file mode 100644 index 0000000..1f3c083 --- /dev/null +++ b/app/svc/svc_go_routine.go @@ -0,0 +1,72 @@ +package svc + +import ( + "applet/app/db" + "applet/app/db/model" + "applet/app/utils/logx" + "fmt" + "time" +) + +// RoutineInsertUserRelate is 协程 当关联上级用户时,需要查询当前用户的所有关联下级,并新增关联上级与当前用户下级关系 +func RoutineInsertUserRelate(puid, uid, addlv int) { + defer Rev() + urs, err := db.UserRelatesByPuid(db.Db, uid, 0, 0) + if err != nil { + logx.Warn(err) + } + // fmt.Println(*urs) + for _, item := range *urs { + _, err = db.UserRelateInsert(db.Db, + &model.UserRelate{ + ParentUid: puid, + Uid: item.Uid, + Level: item.Level + addlv, + InviteTime: time.Now(), + }) + if err != nil { + continue + } + logx.Info(fmt.Sprintf("关联pid(%v) -> uid(%v) ,lv:%v", puid, item.Uid, item.Level+addlv)) + + } +} + +// RoutineMultiRelate is 多级关联 +func RoutineMultiRelate(pid int, uid int, lv int) { + userDb := db.UserDb{} + userDb.Set() + + defer Rev() + for { + if pid == 0 { + break + } + m, err := userDb.GetUser(pid) + if err != nil || m == nil { + logx.Warn(err) + break + } + if m.ParentUid == 0 { + break + } + lv++ + ur := new(model.UserRelate) + ur.ParentUid = m.ParentUid + ur.Uid = uid + ur.Level = lv + ur.InviteTime = time.Now() + _, err = db.UserRelateInsert(db.Db, ur) + if err != nil { + logx.Warn(err) + break + } + // 还要关联当前的用户的所有下级,注意关联等级 + RoutineInsertUserRelate(m.ParentUid, uid, lv) + // 下级关联上上级 + // 继续查询 + logx.Info(fmt.Sprintf("关联pid(%v) -> uid(%v),lv:%v", ur.ParentUid, ur.Uid, lv)) + logx.Info("继续查询") + pid = m.ParentUid + } +} diff --git a/app/svc/svc_redis_mutex_lock.go b/app/svc/svc_redis_mutex_lock.go new file mode 100644 index 0000000..f35e0f9 --- /dev/null +++ b/app/svc/svc_redis_mutex_lock.go @@ -0,0 +1,100 @@ +package svc + +import ( + "applet/app/md" + "applet/app/utils" + "applet/app/utils/cache" + "errors" + "fmt" + "math/rand" + "reflect" + "time" +) + +const redisMutexLockExpTime = 15 + +// TryGetDistributedLock 分布式锁获取 +// requestId 用于标识请求客户端,可以是随机字符串,需确保唯一 +func TryGetDistributedLock(lockKey, requestId string, isNegative bool) bool { + if isNegative { // 多次尝试获取 + retry := 1 + for { + ok, err := cache.Do("SET", lockKey, requestId, "EX", redisMutexLockExpTime, "NX") + // 获取锁成功 + if err == nil && ok == "OK" { + return true + } + // 尝试多次没获取成功 + if retry > 10 { + return false + } + time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000))) + retry += 1 + } + } else { // 只尝试一次 + ok, err := cache.Do("SET", lockKey, requestId, "EX", redisMutexLockExpTime, "NX") + // 获取锁成功 + if err == nil && ok == "OK" { + return true + } + + return false + } +} + +// ReleaseDistributedLock 释放锁,通过比较requestId,用于确保客户端只释放自己的锁,使用lua脚本保证操作的原子型 +func ReleaseDistributedLock(lockKey, requestId string) (bool, error) { + luaScript := ` + if redis.call("get",KEYS[1]) == ARGV[1] + then + return redis.call("del",KEYS[1]) + else + return 0 + end` + + do, err := cache.Do("eval", luaScript, 1, lockKey, requestId) + fmt.Println(reflect.TypeOf(do)) + fmt.Println(do) + + if utils.AnyToInt64(do) == 1 { + return true, err + } else { + return false, err + } +} + +func GetDistributedLockRequestId(prefix string) string { + return prefix + utils.IntToStr(rand.Intn(100000000)) +} + +// HandleBalanceDistributedLock 处理余额更新时获取锁和释放锁 如果加锁成功,使用语句 ` defer cb() ` 释放锁 +func HandleBalanceDistributedLock(masterId, uid, requestIdPrefix string) (cb func(), err error) { + // 获取余额更新锁 + balanceLockKey := fmt.Sprintf(md.UserFinValidUpdateLock, masterId, uid) + requestId := GetDistributedLockRequestId(requestIdPrefix) + balanceLockOk := TryGetDistributedLock(balanceLockKey, requestId, true) + if !balanceLockOk { + return nil, errors.New("系统繁忙,请稍后再试") + } + + cb = func() { + _, _ = ReleaseDistributedLock(balanceLockKey, requestId) + } + + return cb, nil +} + +func HandleLimiterDistributedLock(masterId, ip, requestIdPrefix string) (cb func(), err error) { + balanceLockKey := fmt.Sprintf(md.AppLimiterLock, masterId, ip) + requestId := GetDistributedLockRequestId(requestIdPrefix) + balanceLockOk := TryGetDistributedLock(balanceLockKey, requestId, true) + if !balanceLockOk { + return nil, errors.New("系统繁忙,请稍后再试") + } + + cb = func() { + _, _ = ReleaseDistributedLock(balanceLockKey, requestId) + } + + return cb, nil +} diff --git a/app/svc/svc_user_invitecode.go b/app/svc/svc_user_invitecode.go new file mode 100644 index 0000000..ef3103e --- /dev/null +++ b/app/svc/svc_user_invitecode.go @@ -0,0 +1,70 @@ +package svc + +import ( + "applet/app/db" + "math/rand" + "unicode" +) + +func ReturnCode(l, types, num int) string { + if num > 5 { + return "" + } + //循环3次判断是否存在该邀请码 + var code string + var ( + codes []string + ) + for i := 0; i < 3; i++ { + oneCode := GetRandomString(l, types) + codes = append(codes, oneCode) + } + + //判断是不是存在邀请码了 + tmp, _ := db.UserProfileFindByInviteCodes(db.Db, codes...) + + //循环生成的邀请码 判断tmp里有没有这个邀请码 如果邀请码没有就赋值 再判断是否存在 存在就清空 + for _, v := range codes { + if code != "" { //如果存在并且数据库没有就跳过 + continue + } + code = v + for _, v1 := range *tmp { + //如果存在就清空 + if v1.InviteCode == v { + code = "" + } + } + } + //如果都没有就继续加一位继续查 + if code == "" { + return ReturnCode(l+1, types, num+1) + } + return code +} + +// 随机生成指定位数的大写字母和数字的组合 +func GetRandomString(l, isLetter int) string { + str := "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" + if isLetter != 1 { + str = "0123456789" + } + strs := []rune(str) + result := make([]rune, l) + for i := range result { + result[i] = strs[rand.Intn(len(strs))] + } + if IsLetter(string(result)) && isLetter == 1 { + return GetRandomString(l, isLetter) + } + return string(result) +} + +func IsLetter(s string) bool { + for _, r := range s { + if !unicode.IsLetter(r) { + return false + } + } + return true +} diff --git a/app/utils/sign_check.go b/app/utils/sign_check.go deleted file mode 100644 index 798f63d..0000000 --- a/app/utils/sign_check.go +++ /dev/null @@ -1,125 +0,0 @@ -package utils - -import ( - "applet/app/utils/logx" - "fmt" - "github.com/forgoer/openssl" - "github.com/gin-gonic/gin" - "github.com/syyongx/php2go" - "strings" -) - -var publicKey = []byte(`-----BEGIN PUBLIC KEY----- -MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCFQD7RL2tDNuwdg0jTfV0zjAzh -WoCWfGrcNiucy2XUHZZU2oGhHv1N10qu3XayTDD4pu4sJ73biKwqR6ZN7IS4Sfon -vrzaXGvrTG4kmdo3XrbrkzmyBHDLTsJvv6pyS2HPl9QPSvKDN0iJ66+KN8QjBpw1 -FNIGe7xbDaJPY733/QIDAQAB ------END PUBLIC KEY-----`) - -var privateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- -MIICXAIBAAKBgQCFQD7RL2tDNuwdg0jTfV0zjAzhWoCWfGrcNiucy2XUHZZU2oGh -Hv1N10qu3XayTDD4pu4sJ73biKwqR6ZN7IS4SfonvrzaXGvrTG4kmdo3Xrbrkzmy -BHDLTsJvv6pyS2HPl9QPSvKDN0iJ66+KN8QjBpw1FNIGe7xbDaJPY733/QIDAQAB -AoGADi14wY8XDY7Bbp5yWDZFfV+QW0Xi2qAgSo/k8gjeK8R+I0cgdcEzWF3oz1Q2 -9d+PclVokAAmfj47e0AmXLImqMCSEzi1jDBUFIRoJk9WE1YstE94mrCgV0FW+N/u -+L6OgZcjmF+9dHKprnpaUGQuUV5fF8j0qp8S2Jfs3Sw+dOECQQCQnHALzFjmXXIR -Ez3VSK4ZoYgDIrrpzNst5Hh6AMDNZcG3CrCxlQrgqjgTzBSr3ZSavvkfYRj42STk -TqyX1tQFAkEA6+O6UENoUTk2lG7iO/ta7cdIULnkTGwQqvkgLIUjk6w8E3sBTIfw -rerTEmquw5F42HHE+FMrRat06ZN57lENmQJAYgUHlZevcoZIePZ35Qfcqpbo4Gc8 -Fpm6vwKr/tZf2Vlt0qo2VkhWFS6L0C92m4AX6EQmDHT+Pj7BWNdS+aCuGQJBAOkq -NKPZvWdr8jNOV3mKvxqB/U0uMigIOYGGtvLKt5vkh42J7ILFbHW8w95UbWMKjDUG -X/hF3WQEUo//Imsa2yECQHSZIpJxiTRueoDiyRt0LH+jdbYFUu/6D0UIYXhFvP/p -EZX+hfCfUnNYX59UVpRjSZ66g0CbCjuBPOhmOD+hDeQ= ------END RSA PRIVATE KEY-----`) - -func GetApiVersion(c *gin.Context) int { - var apiVersion = c.GetHeader("apiVersion") - if StrToInt(apiVersion) == 0 { //没有版本号先不校验 - apiVersion = c.GetHeader("Apiversion") - } - if StrToInt(apiVersion) == 0 { //没有版本号先不校验 - apiVersion = c.GetHeader("api_version") - } - return StrToInt(apiVersion) -} - -//签名校验 -func SignCheck(c *gin.Context) bool { - var apiVersion = GetApiVersion(c) - if apiVersion == 0 { //没有版本号先不校验 - return true - } - //1.通过rsa 解析出 aes - var key = c.GetHeader("key") - - //拼接对应参数 - var uri = c.Request.RequestURI - var query = GetQueryParam(uri) - fmt.Println(query) - query["timestamp"] = c.GetHeader("timestamp") - query["nonce"] = c.GetHeader("nonce") - query["key"] = key - token := c.GetHeader("Authorization") - if token != "" { - // 按空格分割 - parts := strings.SplitN(token, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" { - token = parts[1] - } - } - query["token"] = token - //2.query参数按照 ASCII 码从小到大排序 - str := JoinStringsInASCII(query, "&", false, false, "") - //3.拼上密钥 - secret := "" - if InArr(c.GetHeader("platform"), []string{"android", "ios"}) { - secret = c.GetString("app_api_secret_key") - } else if c.GetHeader("platform") == "wap" { - secret = c.GetString("h5_api_secret_key") - } else { - secret = c.GetString("applet_api_secret_key") - } - str = fmt.Sprintf("%s&secret=%s", str, secret) - fmt.Println(str) - //4.md5加密 转小写 - sign := strings.ToLower(Md5(str)) - //5.判断跟前端传来的sign是否一致 - if sign != c.GetHeader("sign") { - return false - } - return true -} - -func ResultAes(c *gin.Context, raw []byte) string { - var key = c.GetHeader("key") - base, _ := php2go.Base64Decode(key) - aes, err := RsaDecrypt([]byte(base), privateKey) - if err != nil { - logx.Info(err) - return "" - } - - str, _ := openssl.AesECBEncrypt(raw, aes, openssl.PKCS7_PADDING) - value := php2go.Base64Encode(string(str)) - fmt.Println(value) - - return value -} - -func ResultAesDecrypt(c *gin.Context, raw string) string { - var key = c.GetHeader("key") - base, _ := php2go.Base64Decode(key) - aes, err := RsaDecrypt([]byte(base), privateKey) - if err != nil { - logx.Info(err) - return "" - } - fmt.Println(raw) - value1, _ := php2go.Base64Decode(raw) - if value1 == "" { - return "" - } - str1, _ := openssl.AesECBDecrypt([]byte(value1), aes, openssl.PKCS7_PADDING) - - return string(str1) -}