diff --git a/e/code.go b/e/code.go index 172e7eb..8d241c9 100644 --- a/e/code.go +++ b/e/code.go @@ -237,6 +237,6 @@ var MsgFlags = map[int]string{ ERR_ALIPAY_ORDER_ERR: "订单创建错误", ERR_PAY_ERR: "未找到支付方式", ERR_SEARCH_ERR: "暂无该分类商品", - ERR_LEVEL_REACH_TOP: "已经是最高等级 ", + ERR_LEVEL_REACH_TOP: "已经是最高等级", ERR_USER_CHECK_ERR: "校验失败", } diff --git a/e/msg.go b/e/msg.go new file mode 100644 index 0000000..0105125 --- /dev/null +++ b/e/msg.go @@ -0,0 +1,124 @@ +package e + +import ( + "code.fnuoos.com/go_rely_warehouse/zyos_go_tool.git/utils" + "code.fnuoos.com/go_rely_warehouse/zyos_go_tool.git/utils/logx" + "encoding/json" + "net/http" + + "github.com/gin-gonic/gin" +) + +// GetMsg get error information based on Code +// 因为这里code是自己控制的, 因此没考虑报错信息 +func GetMsg(code int) (int, string) { + if msg, ok := MsgFlags[code]; ok { + return code / 1000, msg + } + if http.StatusText(code) == "" { + code = 200 + } + return code, MsgFlags[ERR_BAD_REQUEST] +} + +// 成功输出, fields 是额外字段, 与code, msg同级 +func OutSuc(c *gin.Context, data interface{}, fields map[string]interface{}) { + res := gin.H{ + "code": 1, + "msg": "ok", + "data": data, + } + if fields != nil { + for k, v := range fields { + res[k] = v + } + } + jsonData, _ := json.Marshal(res) + strs := string(jsonData) + if c.GetString("translate_open") != "" { + strs = utils.ReadReverse(c, string(jsonData), "") + } + if utils.GetApiVersion(c) > 0 { //加了签名校验只返回加密的字符串 + strs = utils.ResultAes(c, []byte(strs)) + } + c.Writer.WriteString(strs) + +} + +func OutSucPure(c *gin.Context, data interface{}, fields map[string]interface{}) { + res := gin.H{ + "code": 1, + "msg": "ok", + "data": data, + } + if fields != nil { + for k, v := range fields { + res[k] = v + } + } + jsonData, _ := json.Marshal(res) + strs := string(jsonData) + if c.GetString("translate_open") != "" { + strs = utils.ReadReverse(c, string(jsonData), "") + } + if utils.GetApiVersion(c) > 0 { //加了签名校验只返回加密的字符串 + strs = utils.ResultAes(c, []byte(strs)) + } + c.Writer.WriteString(strs) + +} + +// 错误输出 +func OutErr(c *gin.Context, code int, err ...interface{}) { + statusCode, msg := GetMsg(code) + if len(err) > 0 && err[0] != nil { + e := err[0] + switch v := e.(type) { + case E: + statusCode = v.Code / 1000 + msg = v.Error() + logx.Error(v.msg + ": " + v.st) // 记录堆栈信息 + case error: + logx.Error(v) + break + case string: + msg = v + case int: + if _, ok := MsgFlags[v]; ok { + msg = MsgFlags[v] + } + } + } + if c.GetString("translate_open") != "" { + msg = utils.ReadReverse(c, msg, "") + } + if utils.GetApiVersion(c) > 0 { //加了签名校验只返回加密的字符串 + jsonData, _ := json.Marshal(gin.H{ + "code": code, + "msg": msg, + "data": []struct{}{}, + }) + str := utils.ResultAes(c, jsonData) + if code > 100000 { + code = int(utils.FloatFormat(float64(code/1000), 0)) + } + c.Status(500) + c.Writer.WriteString(str) + } else { + c.AbortWithStatusJSON(statusCode, gin.H{ + "code": code, + "msg": msg, + "data": []struct{}{}, + }) + } + +} + +// 重定向 +func OutRedirect(c *gin.Context, code int, loc string) { + if code < 301 || code > 308 { + code = 303 + } + c.Redirect(code, loc) + c.Abort() +} diff --git a/e/set_cache.go b/e/set_cache.go new file mode 100644 index 0000000..45337a1 --- /dev/null +++ b/e/set_cache.go @@ -0,0 +1,8 @@ +package e + +func SetCache(cacheTime int64) map[string]interface{} { + if cacheTime == 0 { + return map[string]interface{}{"cache_time": cacheTime} + } + return map[string]interface{}{"cache_time": cacheTime} +} diff --git a/utils/base64.go b/utils/base64.go new file mode 100644 index 0000000..97ec434 --- /dev/null +++ b/utils/base64.go @@ -0,0 +1,118 @@ +package utils + +import ( + "encoding/base64" + "fmt" + "regexp" +) + +const ( + Base64Std = iota + Base64Url + Base64RawStd + Base64RawUrl +) + +func Base64StdEncode(str interface{}) string { + return Base64Encode(str, Base64Std) +} + +func Base64StdDecode(str interface{}) string { + return Base64Decode(str, Base64Std) +} + +func Base64UrlEncode(str interface{}) string { + return Base64Encode(str, Base64Url) +} + +func Base64UrlDecode(str interface{}) string { + return Base64Decode(str, Base64Url) +} + +func Base64RawStdEncode(str interface{}) string { + return Base64Encode(str, Base64RawStd) +} + +func Base64RawStdDecode(str interface{}) string { + return Base64Decode(str, Base64RawStd) +} + +func Base64RawUrlEncode(str interface{}) string { + return Base64Encode(str, Base64RawUrl) +} + +func Base64RawUrlDecode(str interface{}) string { + return Base64Decode(str, Base64RawUrl) +} + +func Base64Encode(str interface{}, encode int) string { + newEncode := base64Encode(encode) + if newEncode == nil { + return "" + } + switch v := str.(type) { + case string: + return newEncode.EncodeToString([]byte(v)) + case []byte: + return newEncode.EncodeToString(v) + } + return newEncode.EncodeToString([]byte(fmt.Sprint(str))) +} + +func Base64Decode(str interface{}, encode int) string { + var err error + var b []byte + newEncode := base64Encode(encode) + if newEncode == nil { + return "" + } + switch v := str.(type) { + case string: + b, err = newEncode.DecodeString(v) + case []byte: + b, err = newEncode.DecodeString(string(v)) + default: + return "" + } + if err != nil { + return "" + } + return string(b) +} + +func base64Encode(encode int) *base64.Encoding { + switch encode { + case Base64Std: + return base64.StdEncoding + case Base64Url: + return base64.URLEncoding + case Base64RawStd: + return base64.RawStdEncoding + case Base64RawUrl: + return base64.RawURLEncoding + default: + return nil + } +} + + +func JudgeBase64(str string) bool { + pattern := "^([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{4}|[A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)$" + matched, err := regexp.MatchString(pattern, str) + if err != nil { + return false + } + if !(len(str)%4 == 0 && matched) { + return false + } + unCodeStr, err := base64.StdEncoding.DecodeString(str) + if err != nil { + return false + } + tranStr := base64.StdEncoding.EncodeToString(unCodeStr) + //return str==base64.StdEncoding.EncodeToString(unCodeStr) + if str == tranStr { + return true + } + return false +} diff --git a/utils/boolean.go b/utils/boolean.go new file mode 100644 index 0000000..d64c876 --- /dev/null +++ b/utils/boolean.go @@ -0,0 +1,26 @@ +package utils + +import "reflect" + +// 检验一个值是否为空 +func Empty(val interface{}) bool { + v := reflect.ValueOf(val) + switch v.Kind() { + case reflect.String, reflect.Array: + return v.Len() == 0 + case reflect.Map, reflect.Slice: + return v.Len() == 0 || v.IsNil() + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + + return reflect.DeepEqual(val, reflect.Zero(v.Type()).Interface()) +} \ No newline at end of file diff --git a/utils/cache/base.go b/utils/cache/base.go new file mode 100644 index 0000000..64648dd --- /dev/null +++ b/utils/cache/base.go @@ -0,0 +1,421 @@ +package cache + +import ( + "errors" + "fmt" + "strconv" + "time" +) + +const ( + redisDialTTL = 10 * time.Second + redisReadTTL = 3 * time.Second + redisWriteTTL = 3 * time.Second + redisIdleTTL = 10 * time.Second + redisPoolTTL = 10 * time.Second + redisPoolSize int = 512 + redisMaxIdleConn int = 64 + redisMaxActive int = 512 +) + +var ( + ErrNil = errors.New("nil return") + ErrWrongArgsNum = errors.New("args num error") + ErrNegativeInt = errors.New("redis cluster: unexpected value for Uint64") +) + +// 以下为提供类型转换 + +func Int(reply interface{}, err error) (int, error) { + if err != nil { + return 0, err + } + switch reply := reply.(type) { + case int: + return reply, nil + case int8: + return int(reply), nil + case int16: + return int(reply), nil + case int32: + return int(reply), nil + case int64: + x := int(reply) + if int64(x) != reply { + return 0, strconv.ErrRange + } + return x, nil + case uint: + n := int(reply) + if n < 0 { + return 0, strconv.ErrRange + } + return n, nil + case uint8: + return int(reply), nil + case uint16: + return int(reply), nil + case uint32: + n := int(reply) + if n < 0 { + return 0, strconv.ErrRange + } + return n, nil + case uint64: + n := int(reply) + if n < 0 { + return 0, strconv.ErrRange + } + return n, nil + case []byte: + data := string(reply) + if len(data) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseInt(data, 10, 0) + return int(n), err + case string: + if len(reply) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseInt(reply, 10, 0) + return int(n), err + case nil: + return 0, ErrNil + case error: + return 0, reply + } + return 0, fmt.Errorf("redis cluster: unexpected type for Int, got type %T", reply) +} + +func Int64(reply interface{}, err error) (int64, error) { + if err != nil { + return 0, err + } + switch reply := reply.(type) { + case int: + return int64(reply), nil + case int8: + return int64(reply), nil + case int16: + return int64(reply), nil + case int32: + return int64(reply), nil + case int64: + return reply, nil + case uint: + n := int64(reply) + if n < 0 { + return 0, strconv.ErrRange + } + return n, nil + case uint8: + return int64(reply), nil + case uint16: + return int64(reply), nil + case uint32: + return int64(reply), nil + case uint64: + n := int64(reply) + if n < 0 { + return 0, strconv.ErrRange + } + return n, nil + case []byte: + data := string(reply) + if len(data) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseInt(data, 10, 64) + return n, err + case string: + if len(reply) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseInt(reply, 10, 64) + return n, err + case nil: + return 0, ErrNil + case error: + return 0, reply + } + return 0, fmt.Errorf("redis cluster: unexpected type for Int64, got type %T", reply) +} + +func Uint64(reply interface{}, err error) (uint64, error) { + if err != nil { + return 0, err + } + switch reply := reply.(type) { + case uint: + return uint64(reply), nil + case uint8: + return uint64(reply), nil + case uint16: + return uint64(reply), nil + case uint32: + return uint64(reply), nil + case uint64: + return reply, nil + case int: + if reply < 0 { + return 0, ErrNegativeInt + } + return uint64(reply), nil + case int8: + if reply < 0 { + return 0, ErrNegativeInt + } + return uint64(reply), nil + case int16: + if reply < 0 { + return 0, ErrNegativeInt + } + return uint64(reply), nil + case int32: + if reply < 0 { + return 0, ErrNegativeInt + } + return uint64(reply), nil + case int64: + if reply < 0 { + return 0, ErrNegativeInt + } + return uint64(reply), nil + case []byte: + data := string(reply) + if len(data) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseUint(data, 10, 64) + return n, err + case string: + if len(reply) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseUint(reply, 10, 64) + return n, err + case nil: + return 0, ErrNil + case error: + return 0, reply + } + return 0, fmt.Errorf("redis cluster: unexpected type for Uint64, got type %T", reply) +} + +func Float64(reply interface{}, err error) (float64, error) { + if err != nil { + return 0, err + } + + var value float64 + err = nil + switch v := reply.(type) { + case float32: + value = float64(v) + case float64: + value = v + case int: + value = float64(v) + case int8: + value = float64(v) + case int16: + value = float64(v) + case int32: + value = float64(v) + case int64: + value = float64(v) + case uint: + value = float64(v) + case uint8: + value = float64(v) + case uint16: + value = float64(v) + case uint32: + value = float64(v) + case uint64: + value = float64(v) + case []byte: + data := string(v) + if len(data) == 0 { + return 0, ErrNil + } + value, err = strconv.ParseFloat(string(v), 64) + case string: + if len(v) == 0 { + return 0, ErrNil + } + value, err = strconv.ParseFloat(v, 64) + case nil: + err = ErrNil + case error: + err = v + default: + err = fmt.Errorf("redis cluster: unexpected type for Float64, got type %T", v) + } + + return value, err +} + +func Bool(reply interface{}, err error) (bool, error) { + if err != nil { + return false, err + } + switch reply := reply.(type) { + case bool: + return reply, nil + case int64: + return reply != 0, nil + case []byte: + data := string(reply) + if len(data) == 0 { + return false, ErrNil + } + + return strconv.ParseBool(data) + case string: + if len(reply) == 0 { + return false, ErrNil + } + + return strconv.ParseBool(reply) + case nil: + return false, ErrNil + case error: + return false, reply + } + return false, fmt.Errorf("redis cluster: unexpected type for Bool, got type %T", reply) +} + +func Bytes(reply interface{}, err error) ([]byte, error) { + if err != nil { + return nil, err + } + switch reply := reply.(type) { + case []byte: + if len(reply) == 0 { + return nil, ErrNil + } + return reply, nil + case string: + data := []byte(reply) + if len(data) == 0 { + return nil, ErrNil + } + return data, nil + case nil: + return nil, ErrNil + case error: + return nil, reply + } + return nil, fmt.Errorf("redis cluster: unexpected type for Bytes, got type %T", reply) +} + +func String(reply interface{}, err error) (string, error) { + if err != nil { + return "", err + } + + value := "" + err = nil + switch v := reply.(type) { + case string: + if len(v) == 0 { + return "", ErrNil + } + + value = v + case []byte: + if len(v) == 0 { + return "", ErrNil + } + + value = string(v) + case int: + value = strconv.FormatInt(int64(v), 10) + case int8: + value = strconv.FormatInt(int64(v), 10) + case int16: + value = strconv.FormatInt(int64(v), 10) + case int32: + value = strconv.FormatInt(int64(v), 10) + case int64: + value = strconv.FormatInt(v, 10) + case uint: + value = strconv.FormatUint(uint64(v), 10) + case uint8: + value = strconv.FormatUint(uint64(v), 10) + case uint16: + value = strconv.FormatUint(uint64(v), 10) + case uint32: + value = strconv.FormatUint(uint64(v), 10) + case uint64: + value = strconv.FormatUint(v, 10) + case float32: + value = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + value = strconv.FormatFloat(v, 'f', -1, 64) + case bool: + value = strconv.FormatBool(v) + case nil: + err = ErrNil + case error: + err = v + default: + err = fmt.Errorf("redis cluster: unexpected type for String, got type %T", v) + } + + return value, err +} + +func Strings(reply interface{}, err error) ([]string, error) { + if err != nil { + return nil, err + } + switch reply := reply.(type) { + case []interface{}: + result := make([]string, len(reply)) + for i := range reply { + if reply[i] == nil { + continue + } + switch subReply := reply[i].(type) { + case string: + result[i] = subReply + case []byte: + result[i] = string(subReply) + default: + return nil, fmt.Errorf("redis cluster: unexpected element type for String, got type %T", reply[i]) + } + } + return result, nil + case []string: + return reply, nil + case nil: + return nil, ErrNil + case error: + return nil, reply + } + return nil, fmt.Errorf("redis cluster: unexpected type for Strings, got type %T", reply) +} + +func Values(reply interface{}, err error) ([]interface{}, error) { + if err != nil { + return nil, err + } + switch reply := reply.(type) { + case []interface{}: + return reply, nil + case nil: + return nil, ErrNil + case error: + return nil, reply + } + return nil, fmt.Errorf("redis cluster: unexpected type for Values, got type %T", reply) +} diff --git a/utils/cache/cache/cache.go b/utils/cache/cache/cache.go new file mode 100644 index 0000000..e43c5f0 --- /dev/null +++ b/utils/cache/cache/cache.go @@ -0,0 +1,107 @@ +package cache + +import ( + "fmt" + "time" +) + +var c Cache + +type Cache interface { + // get cached value by key. + Get(key string) interface{} + // GetMulti is a batch version of Get. + GetMulti(keys []string) []interface{} + // set cached value with key and expire time. + Put(key string, val interface{}, timeout time.Duration) error + // delete cached value by key. + Delete(key string) error + // increase cached int value by key, as a counter. + Incr(key string) error + // decrease cached int value by key, as a counter. + Decr(key string) error + // check if cached value exists or not. + IsExist(key string) bool + // clear all cache. + ClearAll() error + // start gc routine based on config string settings. + StartAndGC(config string) error +} + +// Instance is a function create a new Cache Instance +type Instance func() Cache + +var adapters = make(map[string]Instance) + +// Register makes a cache adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Instance) { + if adapter == nil { + panic("cache: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("cache: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewCache Create a new cache driver by adapter name and config string. +// config need to be correct JSON as string: {"interval":360}. +// it will start gc automatically. +func NewCache(adapterName, config string) (adapter Cache, err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + adapter = instanceFunc() + err = adapter.StartAndGC(config) + if err != nil { + adapter = nil + } + return +} + +func InitCache(adapterName, config string) (err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + c = instanceFunc() + err = c.StartAndGC(config) + if err != nil { + c = nil + } + return +} + +func Get(key string) interface{} { + return c.Get(key) +} + +func GetMulti(keys []string) []interface{} { + return c.GetMulti(keys) +} +func Put(key string, val interface{}, ttl time.Duration) error { + return c.Put(key, val, ttl) +} +func Delete(key string) error { + return c.Delete(key) +} +func Incr(key string) error { + return c.Incr(key) +} +func Decr(key string) error { + return c.Decr(key) +} +func IsExist(key string) bool { + return c.IsExist(key) +} +func ClearAll() error { + return c.ClearAll() +} +func StartAndGC(cfg string) error { + return c.StartAndGC(cfg) +} diff --git a/utils/cache/cache/conv.go b/utils/cache/cache/conv.go new file mode 100644 index 0000000..6b700ae --- /dev/null +++ b/utils/cache/cache/conv.go @@ -0,0 +1,86 @@ +package cache + +import ( + "fmt" + "strconv" +) + +// GetString convert interface to string. +func GetString(v interface{}) string { + switch result := v.(type) { + case string: + return result + case []byte: + return string(result) + default: + if v != nil { + return fmt.Sprint(result) + } + } + return "" +} + +// GetInt convert interface to int. +func GetInt(v interface{}) int { + switch result := v.(type) { + case int: + return result + case int32: + return int(result) + case int64: + return int(result) + default: + if d := GetString(v); d != "" { + value, _ := strconv.Atoi(d) + return value + } + } + return 0 +} + +// GetInt64 convert interface to int64. +func GetInt64(v interface{}) int64 { + switch result := v.(type) { + case int: + return int64(result) + case int32: + return int64(result) + case int64: + return result + default: + + if d := GetString(v); d != "" { + value, _ := strconv.ParseInt(d, 10, 64) + return value + } + } + return 0 +} + +// GetFloat64 convert interface to float64. +func GetFloat64(v interface{}) float64 { + switch result := v.(type) { + case float64: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseFloat(d, 64) + return value + } + } + return 0 +} + +// GetBool convert interface to bool. +func GetBool(v interface{}) bool { + switch result := v.(type) { + case bool: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseBool(d) + return value + } + } + return false +} diff --git a/utils/cache/cache/file.go b/utils/cache/cache/file.go new file mode 100644 index 0000000..5c4e366 --- /dev/null +++ b/utils/cache/cache/file.go @@ -0,0 +1,241 @@ +package cache + +import ( + "bytes" + "crypto/md5" + "encoding/gob" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "strconv" + "time" +) + +// FileCacheItem is basic unit of file cache adapter. +// it contains data and expire time. +type FileCacheItem struct { + Data interface{} + LastAccess time.Time + Expired time.Time +} + +// FileCache Config +var ( + FileCachePath = "cache" // cache directory + FileCacheFileSuffix = ".bin" // cache file suffix + FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files. + FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever. +) + +// FileCache is cache adapter for file storage. +type FileCache struct { + CachePath string + FileSuffix string + DirectoryLevel int + EmbedExpiry int +} + +// NewFileCache Create new file cache with no config. +// the level and expiry need set in method StartAndGC as config string. +func NewFileCache() Cache { + // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} + return &FileCache{} +} + +// StartAndGC will start and begin gc for file cache. +// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0} +func (fc *FileCache) StartAndGC(config string) error { + + var cfg map[string]string + json.Unmarshal([]byte(config), &cfg) + if _, ok := cfg["CachePath"]; !ok { + cfg["CachePath"] = FileCachePath + } + if _, ok := cfg["FileSuffix"]; !ok { + cfg["FileSuffix"] = FileCacheFileSuffix + } + if _, ok := cfg["DirectoryLevel"]; !ok { + cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) + } + if _, ok := cfg["EmbedExpiry"]; !ok { + cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) + } + fc.CachePath = cfg["CachePath"] + fc.FileSuffix = cfg["FileSuffix"] + fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"]) + fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"]) + + fc.Init() + return nil +} + +// Init will make new dir for file cache if not exist. +func (fc *FileCache) Init() { + if ok, _ := exists(fc.CachePath); !ok { // todo : error handle + _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle + } +} + +// get cached file name. it's md5 encoded. +func (fc *FileCache) getCacheFileName(key string) string { + m := md5.New() + io.WriteString(m, key) + keyMd5 := hex.EncodeToString(m.Sum(nil)) + cachePath := fc.CachePath + switch fc.DirectoryLevel { + case 2: + cachePath = filepath.Join(cachePath, keyMd5[0:2], keyMd5[2:4]) + case 1: + cachePath = filepath.Join(cachePath, keyMd5[0:2]) + } + + if ok, _ := exists(cachePath); !ok { // todo : error handle + _ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle + } + + return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)) +} + +// Get value from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) Get(key string) interface{} { + fileData, err := FileGetContents(fc.getCacheFileName(key)) + if err != nil { + return "" + } + var to FileCacheItem + GobDecode(fileData, &to) + if to.Expired.Before(time.Now()) { + return "" + } + return to.Data +} + +// GetMulti gets values from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) GetMulti(keys []string) []interface{} { + var rc []interface{} + for _, key := range keys { + rc = append(rc, fc.Get(key)) + } + return rc +} + +// Put value into file cache. +// timeout means how long to keep this file, unit of ms. +// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever. +func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { + gob.Register(val) + + item := FileCacheItem{Data: val} + if timeout == FileCacheEmbedExpiry { + item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years + } else { + item.Expired = time.Now().Add(timeout) + } + item.LastAccess = time.Now() + data, err := GobEncode(item) + if err != nil { + return err + } + return FilePutContents(fc.getCacheFileName(key), data) +} + +// Delete file cache value. +func (fc *FileCache) Delete(key string) error { + filename := fc.getCacheFileName(key) + if ok, _ := exists(filename); ok { + return os.Remove(filename) + } + return nil +} + +// Incr will increase cached int value. +// fc value is saving forever unless Delete. +func (fc *FileCache) Incr(key string) error { + data := fc.Get(key) + var incr int + if reflect.TypeOf(data).Name() != "int" { + incr = 0 + } else { + incr = data.(int) + 1 + } + fc.Put(key, incr, FileCacheEmbedExpiry) + return nil +} + +// Decr will decrease cached int value. +func (fc *FileCache) Decr(key string) error { + data := fc.Get(key) + var decr int + if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 { + decr = 0 + } else { + decr = data.(int) - 1 + } + fc.Put(key, decr, FileCacheEmbedExpiry) + return nil +} + +// IsExist check value is exist. +func (fc *FileCache) IsExist(key string) bool { + ret, _ := exists(fc.getCacheFileName(key)) + return ret +} + +// ClearAll will clean cached files. +// not implemented. +func (fc *FileCache) ClearAll() error { + return nil +} + +// check file exist. +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +// FileGetContents Get bytes to file. +// if non-exist, create this file. +func FileGetContents(filename string) (data []byte, e error) { + return ioutil.ReadFile(filename) +} + +// FilePutContents Put bytes to file. +// if non-exist, create this file. +func FilePutContents(filename string, content []byte) error { + return ioutil.WriteFile(filename, content, os.ModePerm) +} + +// GobEncode Gob encodes file cache item. +func GobEncode(data interface{}) ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(data) + if err != nil { + return nil, err + } + return buf.Bytes(), err +} + +// GobDecode Gob decodes file cache item. +func GobDecode(data []byte, to *FileCacheItem) error { + buf := bytes.NewBuffer(data) + dec := gob.NewDecoder(buf) + return dec.Decode(&to) +} + +func init() { + Register("file", NewFileCache) +} diff --git a/utils/cache/cache/memory.go b/utils/cache/cache/memory.go new file mode 100644 index 0000000..0cc5015 --- /dev/null +++ b/utils/cache/cache/memory.go @@ -0,0 +1,239 @@ +package cache + +import ( + "encoding/json" + "errors" + "sync" + "time" +) + +var ( + // DefaultEvery means the clock time of recycling the expired cache items in memory. + DefaultEvery = 60 // 1 minute +) + +// MemoryItem store memory cache item. +type MemoryItem struct { + val interface{} + createdTime time.Time + lifespan time.Duration +} + +func (mi *MemoryItem) isExpire() bool { + // 0 means forever + if mi.lifespan == 0 { + return false + } + return time.Now().Sub(mi.createdTime) > mi.lifespan +} + +// MemoryCache is Memory cache adapter. +// it contains a RW locker for safe map storage. +type MemoryCache struct { + sync.RWMutex + dur time.Duration + items map[string]*MemoryItem + Every int // run an expiration check Every clock time +} + +// NewMemoryCache returns a new MemoryCache. +func NewMemoryCache() Cache { + cache := MemoryCache{items: make(map[string]*MemoryItem)} + return &cache +} + +// Get cache from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) Get(name string) interface{} { + bc.RLock() + defer bc.RUnlock() + if itm, ok := bc.items[name]; ok { + if itm.isExpire() { + return nil + } + return itm.val + } + return nil +} + +// GetMulti gets caches from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) GetMulti(names []string) []interface{} { + var rc []interface{} + for _, name := range names { + rc = append(rc, bc.Get(name)) + } + return rc +} + +// Put cache to memory. +// if lifespan is 0, it will be forever till restart. +func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { + bc.Lock() + defer bc.Unlock() + bc.items[name] = &MemoryItem{ + val: value, + createdTime: time.Now(), + lifespan: lifespan, + } + return nil +} + +// Delete cache in memory. +func (bc *MemoryCache) Delete(name string) error { + bc.Lock() + defer bc.Unlock() + if _, ok := bc.items[name]; !ok { + return errors.New("key not exist") + } + delete(bc.items, name) + if _, ok := bc.items[name]; ok { + return errors.New("delete key error") + } + return nil +} + +// Incr increase cache counter in memory. +// it supports int,int32,int64,uint,uint32,uint64. +func (bc *MemoryCache) Incr(key string) error { + bc.RLock() + defer bc.RUnlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch itm.val.(type) { + case int: + itm.val = itm.val.(int) + 1 + case int32: + itm.val = itm.val.(int32) + 1 + case int64: + itm.val = itm.val.(int64) + 1 + case uint: + itm.val = itm.val.(uint) + 1 + case uint32: + itm.val = itm.val.(uint32) + 1 + case uint64: + itm.val = itm.val.(uint64) + 1 + default: + return errors.New("item val is not (u)int (u)int32 (u)int64") + } + return nil +} + +// Decr decrease counter in memory. +func (bc *MemoryCache) Decr(key string) error { + bc.RLock() + defer bc.RUnlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch itm.val.(type) { + case int: + itm.val = itm.val.(int) - 1 + case int64: + itm.val = itm.val.(int64) - 1 + case int32: + itm.val = itm.val.(int32) - 1 + case uint: + if itm.val.(uint) > 0 { + itm.val = itm.val.(uint) - 1 + } else { + return errors.New("item val is less than 0") + } + case uint32: + if itm.val.(uint32) > 0 { + itm.val = itm.val.(uint32) - 1 + } else { + return errors.New("item val is less than 0") + } + case uint64: + if itm.val.(uint64) > 0 { + itm.val = itm.val.(uint64) - 1 + } else { + return errors.New("item val is less than 0") + } + default: + return errors.New("item val is not int int64 int32") + } + return nil +} + +// IsExist check cache exist in memory. +func (bc *MemoryCache) IsExist(name string) bool { + bc.RLock() + defer bc.RUnlock() + if v, ok := bc.items[name]; ok { + return !v.isExpire() + } + return false +} + +// ClearAll will delete all cache in memory. +func (bc *MemoryCache) ClearAll() error { + bc.Lock() + defer bc.Unlock() + bc.items = make(map[string]*MemoryItem) + return nil +} + +// StartAndGC start memory cache. it will check expiration in every clock time. +func (bc *MemoryCache) StartAndGC(config string) error { + var cf map[string]int + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["interval"]; !ok { + cf = make(map[string]int) + cf["interval"] = DefaultEvery + } + dur := time.Duration(cf["interval"]) * time.Second + bc.Every = cf["interval"] + bc.dur = dur + go bc.vacuum() + return nil +} + +// check expiration. +func (bc *MemoryCache) vacuum() { + bc.RLock() + every := bc.Every + bc.RUnlock() + + if every < 1 { + return + } + for { + <-time.After(bc.dur) + if bc.items == nil { + return + } + if keys := bc.expiredKeys(); len(keys) != 0 { + bc.clearItems(keys) + } + } +} + +// expiredKeys returns key list which are expired. +func (bc *MemoryCache) expiredKeys() (keys []string) { + bc.RLock() + defer bc.RUnlock() + for key, itm := range bc.items { + if itm.isExpire() { + keys = append(keys, key) + } + } + return +} + +// clearItems removes all the items which key in keys. +func (bc *MemoryCache) clearItems(keys []string) { + bc.Lock() + defer bc.Unlock() + for _, key := range keys { + delete(bc.items, key) + } +} + +func init() { + Register("memory", NewMemoryCache) +} diff --git a/utils/cache/redis.go b/utils/cache/redis.go new file mode 100644 index 0000000..4e5f047 --- /dev/null +++ b/utils/cache/redis.go @@ -0,0 +1,403 @@ +package cache + +import ( + "encoding/json" + "errors" + "log" + "strings" + "time" + + redigo "github.com/gomodule/redigo/redis" +) + +// configuration +type Config struct { + Server string + Password string + MaxIdle int // Maximum number of idle connections in the pool. + + // Maximum number of connections allocated by the pool at a given time. + // When zero, there is no limit on the number of connections in the pool. + MaxActive int + + // Close connections after remaining idle for this duration. If the value + // is zero, then idle connections are not closed. Applications should set + // the timeout to a value less than the server's timeout. + IdleTimeout time.Duration + + // If Wait is true and the pool is at the MaxActive limit, then Get() waits + // for a connection to be returned to the pool before returning. + Wait bool + KeyPrefix string // prefix to all keys; example is "dev environment name" + KeyDelimiter string // delimiter to be used while appending keys; example is ":" + KeyPlaceholder string // placeholder to be parsed using given arguments to obtain a final key; example is "?" +} + +var pool *redigo.Pool +var conf *Config + +func NewRedis(addr string) { + if addr == "" { + panic("\nredis connect string cannot be empty\n") + } + pool = &redigo.Pool{ + MaxIdle: redisMaxIdleConn, + IdleTimeout: redisIdleTTL, + MaxActive: redisMaxActive, + // MaxConnLifetime: redisDialTTL, + Wait: true, + Dial: func() (redigo.Conn, error) { + c, err := redigo.Dial("tcp", addr, + redigo.DialConnectTimeout(redisDialTTL), + redigo.DialReadTimeout(redisReadTTL), + redigo.DialWriteTimeout(redisWriteTTL), + ) + if err != nil { + log.Println("Redis Dial failed: ", err) + return nil, err + } + return c, err + }, + TestOnBorrow: func(c redigo.Conn, t time.Time) error { + _, err := c.Do("PING") + if err != nil { + log.Println("Unable to ping to redis server:", err) + } + return err + }, + } + conn := pool.Get() + defer conn.Close() + if conn.Err() != nil { + println("\nredis connect " + addr + " error: " + conn.Err().Error()) + } else { + println("\nredis connect " + addr + " success!\n") + } +} + +func Do(cmd string, args ...interface{}) (reply interface{}, err error) { + conn := pool.Get() + defer conn.Close() + return conn.Do(cmd, args...) +} + +func GetPool() *redigo.Pool { + return pool +} + +func ParseKey(key string, vars []string) (string, error) { + arr := strings.Split(key, conf.KeyPlaceholder) + actualKey := "" + if len(arr) != len(vars)+1 { + return "", errors.New("redis/connection.go: Insufficient arguments to parse key") + } else { + for index, val := range arr { + if index == 0 { + actualKey = arr[index] + } else { + actualKey += vars[index-1] + val + } + } + } + return getPrefixedKey(actualKey), nil +} + +func getPrefixedKey(key string) string { + return conf.KeyPrefix + conf.KeyDelimiter + key +} +func StripEnvKey(key string) string { + return strings.TrimLeft(key, conf.KeyPrefix+conf.KeyDelimiter) +} +func SplitKey(key string) []string { + return strings.Split(key, conf.KeyDelimiter) +} +func Expire(key string, ttl int) (interface{}, error) { + return Do("EXPIRE", key, ttl) +} +func Persist(key string) (interface{}, error) { + return Do("PERSIST", key) +} + +func Del(key string) (interface{}, error) { + return Do("DEL", key) +} +func Set(key string, data interface{}) (interface{}, error) { + // set + return Do("SET", key, data) +} +func SetNX(key string, data interface{}) (interface{}, error) { + return Do("SETNX", key, data) +} +func SetEx(key string, data interface{}, ttl int) (interface{}, error) { + return Do("SETEX", key, ttl, data) +} + +func SetJson(key string, data interface{}, ttl int) bool { + c, err := json.Marshal(data) + if err != nil { + return false + } + if ttl < 1 { + _, err = Set(key, c) + } else { + _, err = SetEx(key, c, ttl) + } + if err != nil { + return false + } + return true +} + +func GetJson(key string, dst interface{}) error { + b, err := GetBytes(key) + if err != nil { + return err + } + if err = json.Unmarshal(b, dst); err != nil { + return err + } + return nil +} + +func Get(key string) (interface{}, error) { + // get + return Do("GET", key) +} +func GetTTL(key string) (time.Duration, error) { + ttl, err := redigo.Int64(Do("TTL", key)) + return time.Duration(ttl) * time.Second, err +} +func GetBytes(key string) ([]byte, error) { + return redigo.Bytes(Do("GET", key)) +} +func GetString(key string) (string, error) { + return redigo.String(Do("GET", key)) +} +func GetStringMap(key string) (map[string]string, error) { + return redigo.StringMap(Do("GET", key)) +} +func GetInt(key string) (int, error) { + return redigo.Int(Do("GET", key)) +} +func GetInt64(key string) (int64, error) { + return redigo.Int64(Do("GET", key)) +} +func GetStringLength(key string) (int, error) { + return redigo.Int(Do("STRLEN", key)) +} +func ZAdd(key string, score float64, data interface{}) (interface{}, error) { + return Do("ZADD", key, score, data) +} +func ZAddNX(key string, score float64, data interface{}) (interface{}, error) { + return Do("ZADD", key, "NX", score, data) +} +func ZRem(key string, data interface{}) (interface{}, error) { + return Do("ZREM", key, data) +} +func ZRange(key string, start int, end int, withScores bool) ([]interface{}, error) { + if withScores { + return redigo.Values(Do("ZRANGE", key, start, end, "WITHSCORES")) + } + return redigo.Values(Do("ZRANGE", key, start, end)) +} +func ZRemRangeByScore(key string, start int64, end int64) ([]interface{}, error) { + return redigo.Values(Do("ZREMRANGEBYSCORE", key, start, end)) +} +func ZCard(setName string) (int64, error) { + return redigo.Int64(Do("ZCARD", setName)) +} +func ZScan(setName string) (int64, error) { + return redigo.Int64(Do("ZCARD", setName)) +} +func SAdd(setName string, data interface{}) (interface{}, error) { + return Do("SADD", setName, data) +} +func SCard(setName string) (int64, error) { + return redigo.Int64(Do("SCARD", setName)) +} +func SIsMember(setName string, data interface{}) (bool, error) { + return redigo.Bool(Do("SISMEMBER", setName, data)) +} +func SMembers(setName string) ([]string, error) { + return redigo.Strings(Do("SMEMBERS", setName)) +} +func SRem(setName string, data interface{}) (interface{}, error) { + return Do("SREM", setName, data) +} +func HSet(key string, HKey string, data interface{}) (interface{}, error) { + return Do("HSET", key, HKey, data) +} + +func HGet(key string, HKey string) (interface{}, error) { + return Do("HGET", key, HKey) +} + +func HMGet(key string, hashKeys ...string) ([]interface{}, error) { + ret, err := Do("HMGET", key, hashKeys) + if err != nil { + return nil, err + } + reta, ok := ret.([]interface{}) + if !ok { + return nil, errors.New("result not an array") + } + return reta, nil +} + +func HMSet(key string, hashKeys []string, vals []interface{}) (interface{}, error) { + if len(hashKeys) == 0 || len(hashKeys) != len(vals) { + var ret interface{} + return ret, errors.New("bad length") + } + input := []interface{}{key} + for i, v := range hashKeys { + input = append(input, v, vals[i]) + } + return Do("HMSET", input...) +} + +func HGetString(key string, HKey string) (string, error) { + return redigo.String(Do("HGET", key, HKey)) +} +func HGetFloat(key string, HKey string) (float64, error) { + f, err := redigo.Float64(Do("HGET", key, HKey)) + return f, err +} +func HGetInt(key string, HKey string) (int, error) { + return redigo.Int(Do("HGET", key, HKey)) +} +func HGetInt64(key string, HKey string) (int64, error) { + return redigo.Int64(Do("HGET", key, HKey)) +} +func HGetBool(key string, HKey string) (bool, error) { + return redigo.Bool(Do("HGET", key, HKey)) +} +func HDel(key string, HKey string) (interface{}, error) { + return Do("HDEL", key, HKey) +} + +func HGetAll(key string) (map[string]interface{}, error) { + vals, err := redigo.Values(Do("HGETALL", key)) + if err != nil { + return nil, err + } + num := len(vals) / 2 + result := make(map[string]interface{}, num) + for i := 0; i < num; i++ { + key, _ := redigo.String(vals[2*i], nil) + result[key] = vals[2*i+1] + } + return result, nil +} + +func FlushAll() bool { + res, _ := redigo.String(Do("FLUSHALL")) + if res == "" { + return false + } + return true +} + +// NOTE: Use this in production environment with extreme care. +// Read more here:https://redigo.io/commands/keys +func Keys(pattern string) ([]string, error) { + return redigo.Strings(Do("KEYS", pattern)) +} + +func HKeys(key string) ([]string, error) { + return redigo.Strings(Do("HKEYS", key)) +} + +func Exists(key string) bool { + count, err := redigo.Int(Do("EXISTS", key)) + if count == 0 || err != nil { + return false + } + return true +} + +func Incr(key string) (int64, error) { + return redigo.Int64(Do("INCR", key)) +} + +func Decr(key string) (int64, error) { + return redigo.Int64(Do("DECR", key)) +} + +func IncrBy(key string, incBy int64) (int64, error) { + return redigo.Int64(Do("INCRBY", key, incBy)) +} + +func DecrBy(key string, decrBy int64) (int64, error) { + return redigo.Int64(Do("DECRBY", key)) +} + +func IncrByFloat(key string, incBy float64) (float64, error) { + return redigo.Float64(Do("INCRBYFLOAT", key, incBy)) +} + +func DecrByFloat(key string, decrBy float64) (float64, error) { + return redigo.Float64(Do("DECRBYFLOAT", key, decrBy)) +} + +// use for message queue +func LPush(key string, data interface{}) (interface{}, error) { + // set + return Do("LPUSH", key, data) +} + +func LPop(key string) (interface{}, error) { + return Do("LPOP", key) +} + +func LPopString(key string) (string, error) { + return redigo.String(Do("LPOP", key)) +} +func LPopFloat(key string) (float64, error) { + f, err := redigo.Float64(Do("LPOP", key)) + return f, err +} +func LPopInt(key string) (int, error) { + return redigo.Int(Do("LPOP", key)) +} +func LPopInt64(key string) (int64, error) { + return redigo.Int64(Do("LPOP", key)) +} + +func RPush(key string, data interface{}) (interface{}, error) { + // set + return Do("RPUSH", key, data) +} + +func RPop(key string) (interface{}, error) { + return Do("RPOP", key) +} + +func RPopString(key string) (string, error) { + return redigo.String(Do("RPOP", key)) +} +func RPopFloat(key string) (float64, error) { + f, err := redigo.Float64(Do("RPOP", key)) + return f, err +} +func RPopInt(key string) (int, error) { + return redigo.Int(Do("RPOP", key)) +} +func RPopInt64(key string) (int64, error) { + return redigo.Int64(Do("RPOP", key)) +} + +func Scan(cursor int64, pattern string, count int64) (int64, []string, error) { + var items []string + var newCursor int64 + + values, err := redigo.Values(Do("SCAN", cursor, "MATCH", pattern, "COUNT", count)) + if err != nil { + return 0, nil, err + } + values, err = redigo.Scan(values, &newCursor, &items) + if err != nil { + return 0, nil, err + } + return newCursor, items, nil +} diff --git a/utils/cache/redis_cluster.go b/utils/cache/redis_cluster.go new file mode 100644 index 0000000..901f30c --- /dev/null +++ b/utils/cache/redis_cluster.go @@ -0,0 +1,622 @@ +package cache + +import ( + "strconv" + "time" + + "github.com/go-redis/redis" +) + +var pools *redis.ClusterClient + +func NewRedisCluster(addrs []string) error { + opt := &redis.ClusterOptions{ + Addrs: addrs, + PoolSize: redisPoolSize, + PoolTimeout: redisPoolTTL, + IdleTimeout: redisIdleTTL, + DialTimeout: redisDialTTL, + ReadTimeout: redisReadTTL, + WriteTimeout: redisWriteTTL, + } + pools = redis.NewClusterClient(opt) + if err := pools.Ping().Err(); err != nil { + return err + } + return nil +} + +func RCGet(key string) (interface{}, error) { + res, err := pools.Get(key).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func RCSet(key string, value interface{}) error { + err := pools.Set(key, value, 0).Err() + return convertError(err) +} +func RCGetSet(key string, value interface{}) (interface{}, error) { + res, err := pools.GetSet(key, value).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func RCSetNx(key string, value interface{}) (int64, error) { + res, err := pools.SetNX(key, value, 0).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func RCSetEx(key string, value interface{}, timeout int64) error { + _, err := pools.Set(key, value, time.Duration(timeout)*time.Second).Result() + if err != nil { + return convertError(err) + } + return nil +} + +// nil表示成功,ErrNil表示数据库内已经存在这个key,其他表示数据库发生错误 +func RCSetNxEx(key string, value interface{}, timeout int64) error { + res, err := pools.SetNX(key, value, time.Duration(timeout)*time.Second).Result() + if err != nil { + return convertError(err) + } + if res { + return nil + } + return ErrNil +} +func RCMGet(keys ...string) ([]interface{}, error) { + res, err := pools.MGet(keys...).Result() + return res, convertError(err) +} + +// 为确保多个key映射到同一个slot,每个key最好加上hash tag,如:{test} +func RCMSet(kvs map[string]interface{}) error { + pairs := make([]string, 0, len(kvs)*2) + for k, v := range kvs { + val, err := String(v, nil) + if err != nil { + return err + } + pairs = append(pairs, k, val) + } + return convertError(pools.MSet(pairs).Err()) +} + +// 为确保多个key映射到同一个slot,每个key最好加上hash tag,如:{test} +func RCMSetNX(kvs map[string]interface{}) (bool, error) { + pairs := make([]string, 0, len(kvs)*2) + for k, v := range kvs { + val, err := String(v, nil) + if err != nil { + return false, err + } + pairs = append(pairs, k, val) + } + res, err := pools.MSetNX(pairs).Result() + return res, convertError(err) +} +func RCExpireAt(key string, timestamp int64) (int64, error) { + res, err := pools.ExpireAt(key, time.Unix(timestamp, 0)).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func RCDel(keys ...string) (int64, error) { + args := make([]interface{}, 0, len(keys)) + for _, key := range keys { + args = append(args, key) + } + res, err := pools.Del(keys...).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCIncr(key string) (int64, error) { + res, err := pools.Incr(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCIncrBy(key string, delta int64) (int64, error) { + res, err := pools.IncrBy(key, delta).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCExpire(key string, duration int64) (int64, error) { + res, err := pools.Expire(key, time.Duration(duration)*time.Second).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func RCExists(key string) (bool, error) { + res, err := pools.Exists(key).Result() + if err != nil { + return false, convertError(err) + } + if res > 0 { + return true, nil + } + return false, nil +} +func RCHGet(key string, field string) (interface{}, error) { + res, err := pools.HGet(key, field).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func RCHLen(key string) (int64, error) { + res, err := pools.HLen(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCHSet(key string, field string, val interface{}) error { + value, err := String(val, nil) + if err != nil && err != ErrNil { + return err + } + _, err = pools.HSet(key, field, value).Result() + if err != nil { + return convertError(err) + } + return nil +} +func RCHDel(key string, fields ...string) (int64, error) { + args := make([]interface{}, 0, len(fields)+1) + args = append(args, key) + for _, field := range fields { + args = append(args, field) + } + res, err := pools.HDel(key, fields...).Result() + if err != nil { + return 0, convertError(err) + } + return res, nil +} + +func RCHMGet(key string, fields ...string) (interface{}, error) { + args := make([]interface{}, 0, len(fields)+1) + args = append(args, key) + for _, field := range fields { + args = append(args, field) + } + if len(fields) == 0 { + return nil, ErrNil + } + res, err := pools.HMGet(key, fields...).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCHMSet(key string, kvs ...interface{}) error { + if len(kvs) == 0 { + return nil + } + if len(kvs)%2 != 0 { + return ErrWrongArgsNum + } + var err error + v := map[string]interface{}{} // todo change + v["field"], err = String(kvs[0], nil) + if err != nil && err != ErrNil { + return err + } + v["value"], err = String(kvs[1], nil) + if err != nil && err != ErrNil { + return err + } + pairs := make([]string, 0, len(kvs)-2) + if len(kvs) > 2 { + for _, kv := range kvs[2:] { + kvString, err := String(kv, nil) + if err != nil && err != ErrNil { + return err + } + pairs = append(pairs, kvString) + } + } + v["paris"] = pairs + _, err = pools.HMSet(key, v).Result() + if err != nil { + return convertError(err) + } + return nil +} + +func RCHKeys(key string) ([]string, error) { + res, err := pools.HKeys(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCHVals(key string) ([]interface{}, error) { + res, err := pools.HVals(key).Result() + if err != nil { + return nil, convertError(err) + } + rs := make([]interface{}, 0, len(res)) + for _, res := range res { + rs = append(rs, res) + } + return rs, nil +} +func RCHGetAll(key string) (map[string]string, error) { + vals, err := pools.HGetAll(key).Result() + if err != nil { + return nil, convertError(err) + } + return vals, nil +} +func RCHIncrBy(key, field string, delta int64) (int64, error) { + res, err := pools.HIncrBy(key, field, delta).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCZAdd(key string, kvs ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(kvs)+1) + args = append(args, key) + args = append(args, kvs...) + if len(kvs) == 0 { + return 0, nil + } + if len(kvs)%2 != 0 { + return 0, ErrWrongArgsNum + } + zs := make([]redis.Z, len(kvs)/2) + for i := 0; i < len(kvs); i += 2 { + idx := i / 2 + score, err := Float64(kvs[i], nil) + if err != nil && err != ErrNil { + return 0, err + } + zs[idx].Score = score + zs[idx].Member = kvs[i+1] + } + res, err := pools.ZAdd(key, zs...).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCZRem(key string, members ...string) (int64, error) { + args := make([]interface{}, 0, len(members)) + args = append(args, key) + for _, member := range members { + args = append(args, member) + } + res, err := pools.ZRem(key, members).Result() + if err != nil { + return res, convertError(err) + } + return res, err +} + +func RCZRange(key string, min, max int64, withScores bool) (interface{}, error) { + res := make([]interface{}, 0) + if withScores { + zs, err := pools.ZRangeWithScores(key, min, max).Result() + if err != nil { + return nil, convertError(err) + } + for _, z := range zs { + res = append(res, z.Member, strconv.FormatFloat(z.Score, 'f', -1, 64)) + } + } else { + ms, err := pools.ZRange(key, min, max).Result() + if err != nil { + return nil, convertError(err) + } + for _, m := range ms { + res = append(res, m) + } + } + return res, nil +} +func RCZRangeByScoreWithScore(key string, min, max int64) (map[string]int64, error) { + opt := new(redis.ZRangeBy) + opt.Min = strconv.FormatInt(int64(min), 10) + opt.Max = strconv.FormatInt(int64(max), 10) + opt.Count = -1 + opt.Offset = 0 + vals, err := pools.ZRangeByScoreWithScores(key, *opt).Result() + if err != nil { + return nil, convertError(err) + } + res := make(map[string]int64, len(vals)) + for _, val := range vals { + key, err := String(val.Member, nil) + if err != nil && err != ErrNil { + return nil, err + } + res[key] = int64(val.Score) + } + return res, nil +} +func RCLRange(key string, start, stop int64) (interface{}, error) { + res, err := pools.LRange(key, start, stop).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCLSet(key string, index int, value interface{}) error { + err := pools.LSet(key, int64(index), value).Err() + return convertError(err) +} +func RCLLen(key string) (int64, error) { + res, err := pools.LLen(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCLRem(key string, count int, value interface{}) (int, error) { + val, _ := value.(string) + res, err := pools.LRem(key, int64(count), val).Result() + if err != nil { + return int(res), convertError(err) + } + return int(res), nil +} +func RCTTl(key string) (int64, error) { + duration, err := pools.TTL(key).Result() + if err != nil { + return int64(duration.Seconds()), convertError(err) + } + return int64(duration.Seconds()), nil +} +func RCLPop(key string) (interface{}, error) { + res, err := pools.LPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCRPop(key string) (interface{}, error) { + res, err := pools.RPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCBLPop(key string, timeout int) (interface{}, error) { + res, err := pools.BLPop(time.Duration(timeout)*time.Second, key).Result() + if err != nil { + // 兼容redis 2.x + if err == redis.Nil { + return nil, ErrNil + } + return nil, err + } + return res[1], nil +} +func RCBRPop(key string, timeout int) (interface{}, error) { + res, err := pools.BRPop(time.Duration(timeout)*time.Second, key).Result() + if err != nil { + // 兼容redis 2.x + if err == redis.Nil { + return nil, ErrNil + } + return nil, convertError(err) + } + return res[1], nil +} +func RCLPush(key string, value ...interface{}) error { + args := make([]interface{}, 0, len(value)+1) + args = append(args, key) + args = append(args, value...) + vals := make([]string, 0, len(value)) + for _, v := range value { + val, err := String(v, nil) + if err != nil && err != ErrNil { + return err + } + vals = append(vals, val) + } + _, err := pools.LPush(key, vals).Result() // todo ... + if err != nil { + return convertError(err) + } + return nil +} +func RCRPush(key string, value ...interface{}) error { + args := make([]interface{}, 0, len(value)+1) + args = append(args, key) + args = append(args, value...) + vals := make([]string, 0, len(value)) + for _, v := range value { + val, err := String(v, nil) + if err != nil && err != ErrNil { + if err == ErrNil { + continue + } + return err + } + if val == "" { + continue + } + vals = append(vals, val) + } + _, err := pools.RPush(key, vals).Result() // todo ... + if err != nil { + return convertError(err) + } + return nil +} + +// 为确保srcKey跟destKey映射到同一个slot,srcKey和destKey需要加上hash tag,如:{test} +func RCBRPopLPush(srcKey string, destKey string, timeout int) (interface{}, error) { + res, err := pools.BRPopLPush(srcKey, destKey, time.Duration(timeout)*time.Second).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} + +// 为确保srcKey跟destKey映射到同一个slot,srcKey和destKey需要加上hash tag,如:{test} +func RCRPopLPush(srcKey string, destKey string) (interface{}, error) { + res, err := pools.RPopLPush(srcKey, destKey).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCSAdd(key string, members ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(members)+1) + args = append(args, key) + args = append(args, members...) + ms := make([]string, 0, len(members)) + for _, member := range members { + m, err := String(member, nil) + if err != nil && err != ErrNil { + return 0, err + } + ms = append(ms, m) + } + res, err := pools.SAdd(key, ms).Result() // todo ... + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCSPop(key string) ([]byte, error) { + res, err := pools.SPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func RCSIsMember(key string, member interface{}) (bool, error) { + m, _ := member.(string) + res, err := pools.SIsMember(key, m).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCSRem(key string, members ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(members)+1) + args = append(args, key) + args = append(args, members...) + ms := make([]string, 0, len(members)) + for _, member := range members { + m, err := String(member, nil) + if err != nil && err != ErrNil { + return 0, err + } + ms = append(ms, m) + } + res, err := pools.SRem(key, ms).Result() // todo ... + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCSMembers(key string) ([]string, error) { + res, err := pools.SMembers(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCScriptLoad(luaScript string) (interface{}, error) { + res, err := pools.ScriptLoad(luaScript).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCEvalSha(sha1 string, numberKeys int, keysArgs ...interface{}) (interface{}, error) { + vals := make([]interface{}, 0, len(keysArgs)+2) + vals = append(vals, sha1, numberKeys) + vals = append(vals, keysArgs...) + keys := make([]string, 0, numberKeys) + args := make([]string, 0, len(keysArgs)-numberKeys) + for i, value := range keysArgs { + val, err := String(value, nil) + if err != nil && err != ErrNil { + return nil, err + } + if i < numberKeys { + keys = append(keys, val) + } else { + args = append(args, val) + } + } + res, err := pools.EvalSha(sha1, keys, args).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCEval(luaScript string, numberKeys int, keysArgs ...interface{}) (interface{}, error) { + vals := make([]interface{}, 0, len(keysArgs)+2) + vals = append(vals, luaScript, numberKeys) + vals = append(vals, keysArgs...) + keys := make([]string, 0, numberKeys) + args := make([]string, 0, len(keysArgs)-numberKeys) + for i, value := range keysArgs { + val, err := String(value, nil) + if err != nil && err != ErrNil { + return nil, err + } + if i < numberKeys { + keys = append(keys, val) + } else { + args = append(args, val) + } + } + res, err := pools.Eval(luaScript, keys, args).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCGetBit(key string, offset int64) (int64, error) { + res, err := pools.GetBit(key, offset).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCSetBit(key string, offset uint32, value int) (int, error) { + res, err := pools.SetBit(key, int64(offset), value).Result() + return int(res), convertError(err) +} +func RCGetClient() *redis.ClusterClient { + return pools +} +func convertError(err error) error { + if err == redis.Nil { + // 为了兼容redis 2.x,这里不返回 ErrNil,ErrNil在调用redis_cluster_reply函数时才返回 + return nil + } + return err +} diff --git a/utils/cache/redis_pool.go b/utils/cache/redis_pool.go new file mode 100644 index 0000000..ca38b3f --- /dev/null +++ b/utils/cache/redis_pool.go @@ -0,0 +1,324 @@ +package cache + +import ( + "errors" + "log" + "strings" + "time" + + redigo "github.com/gomodule/redigo/redis" +) + +type RedisPool struct { + *redigo.Pool +} + +func NewRedisPool(cfg *Config) *RedisPool { + return &RedisPool{&redigo.Pool{ + MaxIdle: cfg.MaxIdle, + IdleTimeout: cfg.IdleTimeout, + MaxActive: cfg.MaxActive, + Wait: cfg.Wait, + Dial: func() (redigo.Conn, error) { + c, err := redigo.Dial("tcp", cfg.Server) + if err != nil { + log.Println("Redis Dial failed: ", err) + return nil, err + } + if cfg.Password != "" { + if _, err := c.Do("AUTH", cfg.Password); err != nil { + c.Close() + log.Println("Redis AUTH failed: ", err) + return nil, err + } + } + return c, err + }, + TestOnBorrow: func(c redigo.Conn, t time.Time) error { + _, err := c.Do("PING") + if err != nil { + log.Println("Unable to ping to redis server:", err) + } + return err + }, + }} +} + +func (p *RedisPool) Do(cmd string, args ...interface{}) (reply interface{}, err error) { + conn := pool.Get() + defer conn.Close() + return conn.Do(cmd, args...) +} + +func (p *RedisPool) GetPool() *redigo.Pool { + return pool +} + +func (p *RedisPool) ParseKey(key string, vars []string) (string, error) { + arr := strings.Split(key, conf.KeyPlaceholder) + actualKey := "" + if len(arr) != len(vars)+1 { + return "", errors.New("redis/connection.go: Insufficient arguments to parse key") + } else { + for index, val := range arr { + if index == 0 { + actualKey = arr[index] + } else { + actualKey += vars[index-1] + val + } + } + } + return getPrefixedKey(actualKey), nil +} + +func (p *RedisPool) getPrefixedKey(key string) string { + return conf.KeyPrefix + conf.KeyDelimiter + key +} +func (p *RedisPool) StripEnvKey(key string) string { + return strings.TrimLeft(key, conf.KeyPrefix+conf.KeyDelimiter) +} +func (p *RedisPool) SplitKey(key string) []string { + return strings.Split(key, conf.KeyDelimiter) +} +func (p *RedisPool) Expire(key string, ttl int) (interface{}, error) { + return Do("EXPIRE", key, ttl) +} +func (p *RedisPool) Persist(key string) (interface{}, error) { + return Do("PERSIST", key) +} + +func (p *RedisPool) Del(key string) (interface{}, error) { + return Do("DEL", key) +} +func (p *RedisPool) Set(key string, data interface{}) (interface{}, error) { + // set + return Do("SET", key, data) +} +func (p *RedisPool) SetNX(key string, data interface{}) (interface{}, error) { + return Do("SETNX", key, data) +} +func (p *RedisPool) SetEx(key string, data interface{}, ttl int) (interface{}, error) { + return Do("SETEX", key, ttl, data) +} +func (p *RedisPool) Get(key string) (interface{}, error) { + // get + return Do("GET", key) +} +func (p *RedisPool) GetStringMap(key string) (map[string]string, error) { + // get + return redigo.StringMap(Do("GET", key)) +} + +func (p *RedisPool) GetTTL(key string) (time.Duration, error) { + ttl, err := redigo.Int64(Do("TTL", key)) + return time.Duration(ttl) * time.Second, err +} +func (p *RedisPool) GetBytes(key string) ([]byte, error) { + return redigo.Bytes(Do("GET", key)) +} +func (p *RedisPool) GetString(key string) (string, error) { + return redigo.String(Do("GET", key)) +} +func (p *RedisPool) GetInt(key string) (int, error) { + return redigo.Int(Do("GET", key)) +} +func (p *RedisPool) GetStringLength(key string) (int, error) { + return redigo.Int(Do("STRLEN", key)) +} +func (p *RedisPool) ZAdd(key string, score float64, data interface{}) (interface{}, error) { + return Do("ZADD", key, score, data) +} +func (p *RedisPool) ZRem(key string, data interface{}) (interface{}, error) { + return Do("ZREM", key, data) +} +func (p *RedisPool) ZRange(key string, start int, end int, withScores bool) ([]interface{}, error) { + if withScores { + return redigo.Values(Do("ZRANGE", key, start, end, "WITHSCORES")) + } + return redigo.Values(Do("ZRANGE", key, start, end)) +} +func (p *RedisPool) SAdd(setName string, data interface{}) (interface{}, error) { + return Do("SADD", setName, data) +} +func (p *RedisPool) SCard(setName string) (int64, error) { + return redigo.Int64(Do("SCARD", setName)) +} +func (p *RedisPool) SIsMember(setName string, data interface{}) (bool, error) { + return redigo.Bool(Do("SISMEMBER", setName, data)) +} +func (p *RedisPool) SMembers(setName string) ([]string, error) { + return redigo.Strings(Do("SMEMBERS", setName)) +} +func (p *RedisPool) SRem(setName string, data interface{}) (interface{}, error) { + return Do("SREM", setName, data) +} +func (p *RedisPool) HSet(key string, HKey string, data interface{}) (interface{}, error) { + return Do("HSET", key, HKey, data) +} + +func (p *RedisPool) HGet(key string, HKey string) (interface{}, error) { + return Do("HGET", key, HKey) +} + +func (p *RedisPool) HMGet(key string, hashKeys ...string) ([]interface{}, error) { + ret, err := Do("HMGET", key, hashKeys) + if err != nil { + return nil, err + } + reta, ok := ret.([]interface{}) + if !ok { + return nil, errors.New("result not an array") + } + return reta, nil +} + +func (p *RedisPool) HMSet(key string, hashKeys []string, vals []interface{}) (interface{}, error) { + if len(hashKeys) == 0 || len(hashKeys) != len(vals) { + var ret interface{} + return ret, errors.New("bad length") + } + input := []interface{}{key} + for i, v := range hashKeys { + input = append(input, v, vals[i]) + } + return Do("HMSET", input...) +} + +func (p *RedisPool) HGetString(key string, HKey string) (string, error) { + return redigo.String(Do("HGET", key, HKey)) +} +func (p *RedisPool) HGetFloat(key string, HKey string) (float64, error) { + f, err := redigo.Float64(Do("HGET", key, HKey)) + return float64(f), err +} +func (p *RedisPool) HGetInt(key string, HKey string) (int, error) { + return redigo.Int(Do("HGET", key, HKey)) +} +func (p *RedisPool) HGetInt64(key string, HKey string) (int64, error) { + return redigo.Int64(Do("HGET", key, HKey)) +} +func (p *RedisPool) HGetBool(key string, HKey string) (bool, error) { + return redigo.Bool(Do("HGET", key, HKey)) +} +func (p *RedisPool) HDel(key string, HKey string) (interface{}, error) { + return Do("HDEL", key, HKey) +} +func (p *RedisPool) HGetAll(key string) (map[string]interface{}, error) { + vals, err := redigo.Values(Do("HGETALL", key)) + if err != nil { + return nil, err + } + num := len(vals) / 2 + result := make(map[string]interface{}, num) + for i := 0; i < num; i++ { + key, _ := redigo.String(vals[2*i], nil) + result[key] = vals[2*i+1] + } + return result, nil +} + +// NOTE: Use this in production environment with extreme care. +// Read more here:https://redigo.io/commands/keys +func (p *RedisPool) Keys(pattern string) ([]string, error) { + return redigo.Strings(Do("KEYS", pattern)) +} + +func (p *RedisPool) HKeys(key string) ([]string, error) { + return redigo.Strings(Do("HKEYS", key)) +} + +func (p *RedisPool) Exists(key string) (bool, error) { + count, err := redigo.Int(Do("EXISTS", key)) + if count == 0 { + return false, err + } else { + return true, err + } +} + +func (p *RedisPool) Incr(key string) (int64, error) { + return redigo.Int64(Do("INCR", key)) +} + +func (p *RedisPool) Decr(key string) (int64, error) { + return redigo.Int64(Do("DECR", key)) +} + +func (p *RedisPool) IncrBy(key string, incBy int64) (int64, error) { + return redigo.Int64(Do("INCRBY", key, incBy)) +} + +func (p *RedisPool) DecrBy(key string, decrBy int64) (int64, error) { + return redigo.Int64(Do("DECRBY", key)) +} + +func (p *RedisPool) IncrByFloat(key string, incBy float64) (float64, error) { + return redigo.Float64(Do("INCRBYFLOAT", key, incBy)) +} + +func (p *RedisPool) DecrByFloat(key string, decrBy float64) (float64, error) { + return redigo.Float64(Do("DECRBYFLOAT", key, decrBy)) +} + +// use for message queue +func (p *RedisPool) LPush(key string, data interface{}) (interface{}, error) { + // set + return Do("LPUSH", key, data) +} + +func (p *RedisPool) LPop(key string) (interface{}, error) { + return Do("LPOP", key) +} + +func (p *RedisPool) LPopString(key string) (string, error) { + return redigo.String(Do("LPOP", key)) +} +func (p *RedisPool) LPopFloat(key string) (float64, error) { + f, err := redigo.Float64(Do("LPOP", key)) + return float64(f), err +} +func (p *RedisPool) LPopInt(key string) (int, error) { + return redigo.Int(Do("LPOP", key)) +} +func (p *RedisPool) LPopInt64(key string) (int64, error) { + return redigo.Int64(Do("LPOP", key)) +} + +func (p *RedisPool) RPush(key string, data interface{}) (interface{}, error) { + // set + return Do("RPUSH", key, data) +} + +func (p *RedisPool) RPop(key string) (interface{}, error) { + return Do("RPOP", key) +} + +func (p *RedisPool) RPopString(key string) (string, error) { + return redigo.String(Do("RPOP", key)) +} +func (p *RedisPool) RPopFloat(key string) (float64, error) { + f, err := redigo.Float64(Do("RPOP", key)) + return float64(f), err +} +func (p *RedisPool) RPopInt(key string) (int, error) { + return redigo.Int(Do("RPOP", key)) +} +func (p *RedisPool) RPopInt64(key string) (int64, error) { + return redigo.Int64(Do("RPOP", key)) +} + +func (p *RedisPool) Scan(cursor int64, pattern string, count int64) (int64, []string, error) { + var items []string + var newCursor int64 + + values, err := redigo.Values(Do("SCAN", cursor, "MATCH", pattern, "COUNT", count)) + if err != nil { + return 0, nil, err + } + values, err = redigo.Scan(values, &newCursor, &items) + if err != nil { + return 0, nil, err + } + + return newCursor, items, nil +} diff --git a/utils/cache/redis_pool_cluster.go b/utils/cache/redis_pool_cluster.go new file mode 100644 index 0000000..cd1911b --- /dev/null +++ b/utils/cache/redis_pool_cluster.go @@ -0,0 +1,617 @@ +package cache + +import ( + "strconv" + "time" + + "github.com/go-redis/redis" +) + +type RedisClusterPool struct { + client *redis.ClusterClient +} + +func NewRedisClusterPool(addrs []string) (*RedisClusterPool, error) { + opt := &redis.ClusterOptions{ + Addrs: addrs, + PoolSize: 512, + PoolTimeout: 10 * time.Second, + IdleTimeout: 10 * time.Second, + DialTimeout: 10 * time.Second, + ReadTimeout: 3 * time.Second, + WriteTimeout: 3 * time.Second, + } + c := redis.NewClusterClient(opt) + if err := c.Ping().Err(); err != nil { + return nil, err + } + return &RedisClusterPool{client: c}, nil +} + +func (p *RedisClusterPool) Get(key string) (interface{}, error) { + res, err := p.client.Get(key).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func (p *RedisClusterPool) Set(key string, value interface{}) error { + err := p.client.Set(key, value, 0).Err() + return convertError(err) +} +func (p *RedisClusterPool) GetSet(key string, value interface{}) (interface{}, error) { + res, err := p.client.GetSet(key, value).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func (p *RedisClusterPool) SetNx(key string, value interface{}) (int64, error) { + res, err := p.client.SetNX(key, value, 0).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func (p *RedisClusterPool) SetEx(key string, value interface{}, timeout int64) error { + _, err := p.client.Set(key, value, time.Duration(timeout)*time.Second).Result() + if err != nil { + return convertError(err) + } + return nil +} + +// nil表示成功,ErrNil表示数据库内已经存在这个key,其他表示数据库发生错误 +func (p *RedisClusterPool) SetNxEx(key string, value interface{}, timeout int64) error { + res, err := p.client.SetNX(key, value, time.Duration(timeout)*time.Second).Result() + if err != nil { + return convertError(err) + } + if res { + return nil + } + return ErrNil +} +func (p *RedisClusterPool) MGet(keys ...string) ([]interface{}, error) { + res, err := p.client.MGet(keys...).Result() + return res, convertError(err) +} + +// 为确保多个key映射到同一个slot,每个key最好加上hash tag,如:{test} +func (p *RedisClusterPool) MSet(kvs map[string]interface{}) error { + pairs := make([]string, 0, len(kvs)*2) + for k, v := range kvs { + val, err := String(v, nil) + if err != nil { + return err + } + pairs = append(pairs, k, val) + } + return convertError(p.client.MSet(pairs).Err()) +} + +// 为确保多个key映射到同一个slot,每个key最好加上hash tag,如:{test} +func (p *RedisClusterPool) MSetNX(kvs map[string]interface{}) (bool, error) { + pairs := make([]string, 0, len(kvs)*2) + for k, v := range kvs { + val, err := String(v, nil) + if err != nil { + return false, err + } + pairs = append(pairs, k, val) + } + res, err := p.client.MSetNX(pairs).Result() + return res, convertError(err) +} +func (p *RedisClusterPool) ExpireAt(key string, timestamp int64) (int64, error) { + res, err := p.client.ExpireAt(key, time.Unix(timestamp, 0)).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func (p *RedisClusterPool) Del(keys ...string) (int64, error) { + args := make([]interface{}, 0, len(keys)) + for _, key := range keys { + args = append(args, key) + } + res, err := p.client.Del(keys...).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) Incr(key string) (int64, error) { + res, err := p.client.Incr(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) IncrBy(key string, delta int64) (int64, error) { + res, err := p.client.IncrBy(key, delta).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) Expire(key string, duration int64) (int64, error) { + res, err := p.client.Expire(key, time.Duration(duration)*time.Second).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func (p *RedisClusterPool) Exists(key string) (bool, error) { // todo (bool, error) + res, err := p.client.Exists(key).Result() + if err != nil { + return false, convertError(err) + } + if res > 0 { + return true, nil + } + return false, nil +} +func (p *RedisClusterPool) HGet(key string, field string) (interface{}, error) { + res, err := p.client.HGet(key, field).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func (p *RedisClusterPool) HLen(key string) (int64, error) { + res, err := p.client.HLen(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) HSet(key string, field string, val interface{}) error { + value, err := String(val, nil) + if err != nil && err != ErrNil { + return err + } + _, err = p.client.HSet(key, field, value).Result() + if err != nil { + return convertError(err) + } + return nil +} +func (p *RedisClusterPool) HDel(key string, fields ...string) (int64, error) { + args := make([]interface{}, 0, len(fields)+1) + args = append(args, key) + for _, field := range fields { + args = append(args, field) + } + res, err := p.client.HDel(key, fields...).Result() + if err != nil { + return 0, convertError(err) + } + return res, nil +} + +func (p *RedisClusterPool) HMGet(key string, fields ...string) (interface{}, error) { + args := make([]interface{}, 0, len(fields)+1) + args = append(args, key) + for _, field := range fields { + args = append(args, field) + } + if len(fields) == 0 { + return nil, ErrNil + } + res, err := p.client.HMGet(key, fields...).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) HMSet(key string, kvs ...interface{}) error { + if len(kvs) == 0 { + return nil + } + if len(kvs)%2 != 0 { + return ErrWrongArgsNum + } + var err error + v := map[string]interface{}{} // todo change + v["field"], err = String(kvs[0], nil) + if err != nil && err != ErrNil { + return err + } + v["value"], err = String(kvs[1], nil) + if err != nil && err != ErrNil { + return err + } + pairs := make([]string, 0, len(kvs)-2) + if len(kvs) > 2 { + for _, kv := range kvs[2:] { + kvString, err := String(kv, nil) + if err != nil && err != ErrNil { + return err + } + pairs = append(pairs, kvString) + } + } + v["paris"] = pairs + _, err = p.client.HMSet(key, v).Result() + if err != nil { + return convertError(err) + } + return nil +} + +func (p *RedisClusterPool) HKeys(key string) ([]string, error) { + res, err := p.client.HKeys(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) HVals(key string) ([]interface{}, error) { + res, err := p.client.HVals(key).Result() + if err != nil { + return nil, convertError(err) + } + rs := make([]interface{}, 0, len(res)) + for _, res := range res { + rs = append(rs, res) + } + return rs, nil +} +func (p *RedisClusterPool) HGetAll(key string) (map[string]string, error) { + vals, err := p.client.HGetAll(key).Result() + if err != nil { + return nil, convertError(err) + } + return vals, nil +} +func (p *RedisClusterPool) HIncrBy(key, field string, delta int64) (int64, error) { + res, err := p.client.HIncrBy(key, field, delta).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) ZAdd(key string, kvs ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(kvs)+1) + args = append(args, key) + args = append(args, kvs...) + if len(kvs) == 0 { + return 0, nil + } + if len(kvs)%2 != 0 { + return 0, ErrWrongArgsNum + } + zs := make([]redis.Z, len(kvs)/2) + for i := 0; i < len(kvs); i += 2 { + idx := i / 2 + score, err := Float64(kvs[i], nil) + if err != nil && err != ErrNil { + return 0, err + } + zs[idx].Score = score + zs[idx].Member = kvs[i+1] + } + res, err := p.client.ZAdd(key, zs...).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) ZRem(key string, members ...string) (int64, error) { + args := make([]interface{}, 0, len(members)) + args = append(args, key) + for _, member := range members { + args = append(args, member) + } + res, err := p.client.ZRem(key, members).Result() + if err != nil { + return res, convertError(err) + } + return res, err +} + +func (p *RedisClusterPool) ZRange(key string, min, max int64, withScores bool) (interface{}, error) { + res := make([]interface{}, 0) + if withScores { + zs, err := p.client.ZRangeWithScores(key, min, max).Result() + if err != nil { + return nil, convertError(err) + } + for _, z := range zs { + res = append(res, z.Member, strconv.FormatFloat(z.Score, 'f', -1, 64)) + } + } else { + ms, err := p.client.ZRange(key, min, max).Result() + if err != nil { + return nil, convertError(err) + } + for _, m := range ms { + res = append(res, m) + } + } + return res, nil +} +func (p *RedisClusterPool) ZRangeByScoreWithScore(key string, min, max int64) (map[string]int64, error) { + opt := new(redis.ZRangeBy) + opt.Min = strconv.FormatInt(int64(min), 10) + opt.Max = strconv.FormatInt(int64(max), 10) + opt.Count = -1 + opt.Offset = 0 + vals, err := p.client.ZRangeByScoreWithScores(key, *opt).Result() + if err != nil { + return nil, convertError(err) + } + res := make(map[string]int64, len(vals)) + for _, val := range vals { + key, err := String(val.Member, nil) + if err != nil && err != ErrNil { + return nil, err + } + res[key] = int64(val.Score) + } + return res, nil +} +func (p *RedisClusterPool) LRange(key string, start, stop int64) (interface{}, error) { + res, err := p.client.LRange(key, start, stop).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) LSet(key string, index int, value interface{}) error { + err := p.client.LSet(key, int64(index), value).Err() + return convertError(err) +} +func (p *RedisClusterPool) LLen(key string) (int64, error) { + res, err := p.client.LLen(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) LRem(key string, count int, value interface{}) (int, error) { + val, _ := value.(string) + res, err := p.client.LRem(key, int64(count), val).Result() + if err != nil { + return int(res), convertError(err) + } + return int(res), nil +} +func (p *RedisClusterPool) TTl(key string) (int64, error) { + duration, err := p.client.TTL(key).Result() + if err != nil { + return int64(duration.Seconds()), convertError(err) + } + return int64(duration.Seconds()), nil +} +func (p *RedisClusterPool) LPop(key string) (interface{}, error) { + res, err := p.client.LPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) RPop(key string) (interface{}, error) { + res, err := p.client.RPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) BLPop(key string, timeout int) (interface{}, error) { + res, err := p.client.BLPop(time.Duration(timeout)*time.Second, key).Result() + if err != nil { + // 兼容redis 2.x + if err == redis.Nil { + return nil, ErrNil + } + return nil, err + } + return res[1], nil +} +func (p *RedisClusterPool) BRPop(key string, timeout int) (interface{}, error) { + res, err := p.client.BRPop(time.Duration(timeout)*time.Second, key).Result() + if err != nil { + // 兼容redis 2.x + if err == redis.Nil { + return nil, ErrNil + } + return nil, convertError(err) + } + return res[1], nil +} +func (p *RedisClusterPool) LPush(key string, value ...interface{}) error { + args := make([]interface{}, 0, len(value)+1) + args = append(args, key) + args = append(args, value...) + vals := make([]string, 0, len(value)) + for _, v := range value { + val, err := String(v, nil) + if err != nil && err != ErrNil { + return err + } + vals = append(vals, val) + } + _, err := p.client.LPush(key, vals).Result() // todo ... + if err != nil { + return convertError(err) + } + return nil +} +func (p *RedisClusterPool) RPush(key string, value ...interface{}) error { + args := make([]interface{}, 0, len(value)+1) + args = append(args, key) + args = append(args, value...) + vals := make([]string, 0, len(value)) + for _, v := range value { + val, err := String(v, nil) + if err != nil && err != ErrNil { + if err == ErrNil { + continue + } + return err + } + if val == "" { + continue + } + vals = append(vals, val) + } + _, err := p.client.RPush(key, vals).Result() // todo ... + if err != nil { + return convertError(err) + } + return nil +} + +// 为确保srcKey跟destKey映射到同一个slot,srcKey和destKey需要加上hash tag,如:{test} +func (p *RedisClusterPool) BRPopLPush(srcKey string, destKey string, timeout int) (interface{}, error) { + res, err := p.client.BRPopLPush(srcKey, destKey, time.Duration(timeout)*time.Second).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} + +// 为确保srcKey跟destKey映射到同一个slot,srcKey和destKey需要加上hash tag,如:{test} +func (p *RedisClusterPool) RPopLPush(srcKey string, destKey string) (interface{}, error) { + res, err := p.client.RPopLPush(srcKey, destKey).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) SAdd(key string, members ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(members)+1) + args = append(args, key) + args = append(args, members...) + ms := make([]string, 0, len(members)) + for _, member := range members { + m, err := String(member, nil) + if err != nil && err != ErrNil { + return 0, err + } + ms = append(ms, m) + } + res, err := p.client.SAdd(key, ms).Result() // todo ... + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) SPop(key string) ([]byte, error) { + res, err := p.client.SPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func (p *RedisClusterPool) SIsMember(key string, member interface{}) (bool, error) { + m, _ := member.(string) + res, err := p.client.SIsMember(key, m).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) SRem(key string, members ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(members)+1) + args = append(args, key) + args = append(args, members...) + ms := make([]string, 0, len(members)) + for _, member := range members { + m, err := String(member, nil) + if err != nil && err != ErrNil { + return 0, err + } + ms = append(ms, m) + } + res, err := p.client.SRem(key, ms).Result() // todo ... + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) SMembers(key string) ([]string, error) { + res, err := p.client.SMembers(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) ScriptLoad(luaScript string) (interface{}, error) { + res, err := p.client.ScriptLoad(luaScript).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) EvalSha(sha1 string, numberKeys int, keysArgs ...interface{}) (interface{}, error) { + vals := make([]interface{}, 0, len(keysArgs)+2) + vals = append(vals, sha1, numberKeys) + vals = append(vals, keysArgs...) + keys := make([]string, 0, numberKeys) + args := make([]string, 0, len(keysArgs)-numberKeys) + for i, value := range keysArgs { + val, err := String(value, nil) + if err != nil && err != ErrNil { + return nil, err + } + if i < numberKeys { + keys = append(keys, val) + } else { + args = append(args, val) + } + } + res, err := p.client.EvalSha(sha1, keys, args).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) Eval(luaScript string, numberKeys int, keysArgs ...interface{}) (interface{}, error) { + vals := make([]interface{}, 0, len(keysArgs)+2) + vals = append(vals, luaScript, numberKeys) + vals = append(vals, keysArgs...) + keys := make([]string, 0, numberKeys) + args := make([]string, 0, len(keysArgs)-numberKeys) + for i, value := range keysArgs { + val, err := String(value, nil) + if err != nil && err != ErrNil { + return nil, err + } + if i < numberKeys { + keys = append(keys, val) + } else { + args = append(args, val) + } + } + res, err := p.client.Eval(luaScript, keys, args).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) GetBit(key string, offset int64) (int64, error) { + res, err := p.client.GetBit(key, offset).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) SetBit(key string, offset uint32, value int) (int, error) { + res, err := p.client.SetBit(key, int64(offset), value).Result() + return int(res), convertError(err) +} +func (p *RedisClusterPool) GetClient() *redis.ClusterClient { + return pools +} diff --git a/utils/convert.go b/utils/convert.go new file mode 100644 index 0000000..5fc1076 --- /dev/null +++ b/utils/convert.go @@ -0,0 +1,366 @@ +package utils + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" +) + +func ToString(raw interface{}, e error) (res string) { + if e != nil { + return "" + } + return AnyToString(raw) +} + +func ToInt64(raw interface{}, e error) int64 { + if e != nil { + return 0 + } + return AnyToInt64(raw) +} + +func AnyToBool(raw interface{}) bool { + switch i := raw.(type) { + case float32, float64, int, int64, uint, uint8, uint16, uint32, uint64, int8, int16, int32: + return i != 0 + case []byte: + return i != nil + case string: + if i == "false" { + return false + } + return i != "" + case error: + return false + case nil: + return true + } + val := fmt.Sprint(raw) + val = strings.TrimLeft(val, "&") + if strings.TrimLeft(val, "{}") == "" { + return false + } + if strings.TrimLeft(val, "[]") == "" { + return false + } + // ptr type + b, err := json.Marshal(raw) + if err != nil { + return false + } + if strings.TrimLeft(string(b), "\"\"") == "" { + return false + } + if strings.TrimLeft(string(b), "{}") == "" { + return false + } + return true +} + +func AnyToInt64(raw interface{}) int64 { + switch i := raw.(type) { + case string: + res, _ := strconv.ParseInt(i, 10, 64) + return res + case []byte: + return BytesToInt64(i) + case int: + return int64(i) + case int64: + return i + case uint: + return int64(i) + case uint8: + return int64(i) + case uint16: + return int64(i) + case uint32: + return int64(i) + case uint64: + return int64(i) + case int8: + return int64(i) + case int16: + return int64(i) + case int32: + return int64(i) + case float32: + return int64(i) + case float64: + return int64(i) + case error: + return 0 + case bool: + if i { + return 1 + } + return 0 + } + return 0 +} + +func AnyToString(raw interface{}) string { + switch i := raw.(type) { + case []byte: + return string(i) + case int: + return strconv.FormatInt(int64(i), 10) + case int64: + return strconv.FormatInt(i, 10) + case float32: + return Float64ToStr(float64(i)) + case float64: + return Float64ToStr(i) + case uint: + return strconv.FormatInt(int64(i), 10) + case uint8: + return strconv.FormatInt(int64(i), 10) + case uint16: + return strconv.FormatInt(int64(i), 10) + case uint32: + return strconv.FormatInt(int64(i), 10) + case uint64: + return strconv.FormatInt(int64(i), 10) + case int8: + return strconv.FormatInt(int64(i), 10) + case int16: + return strconv.FormatInt(int64(i), 10) + case int32: + return strconv.FormatInt(int64(i), 10) + case string: + return i + case error: + return i.Error() + case bool: + return strconv.FormatBool(i) + } + return fmt.Sprintf("%#v", raw) +} + +func AnyToFloat64(raw interface{}) float64 { + switch i := raw.(type) { + case []byte: + f, _ := strconv.ParseFloat(string(i), 64) + return f + case int: + return float64(i) + case int64: + return float64(i) + case float32: + return float64(i) + case float64: + return i + case uint: + return float64(i) + case uint8: + return float64(i) + case uint16: + return float64(i) + case uint32: + return float64(i) + case uint64: + return float64(i) + case int8: + return float64(i) + case int16: + return float64(i) + case int32: + return float64(i) + case string: + f, _ := strconv.ParseFloat(i, 64) + return f + case bool: + if i { + return 1 + } + } + return 0 +} + +func ToByte(raw interface{}, e error) []byte { + if e != nil { + return []byte{} + } + switch i := raw.(type) { + case string: + return []byte(i) + case int: + return Int64ToBytes(int64(i)) + case int64: + return Int64ToBytes(i) + case float32: + return Float32ToByte(i) + case float64: + return Float64ToByte(i) + case uint: + return Int64ToBytes(int64(i)) + case uint8: + return Int64ToBytes(int64(i)) + case uint16: + return Int64ToBytes(int64(i)) + case uint32: + return Int64ToBytes(int64(i)) + case uint64: + return Int64ToBytes(int64(i)) + case int8: + return Int64ToBytes(int64(i)) + case int16: + return Int64ToBytes(int64(i)) + case int32: + return Int64ToBytes(int64(i)) + case []byte: + return i + case error: + return []byte(i.Error()) + case bool: + if i { + return []byte("true") + } + return []byte("false") + } + return []byte(fmt.Sprintf("%#v", raw)) +} + +func Int64ToBytes(i int64) []byte { + var buf = make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(i)) + return buf +} + +func BytesToInt64(buf []byte) int64 { + return int64(binary.BigEndian.Uint64(buf)) +} + +func StrToInt(s string) int { + res, _ := strconv.Atoi(s) + return res +} + +func StrToInt64(s string) int64 { + res, _ := strconv.ParseInt(s, 10, 64) + return res +} + +func Float32ToByte(float float32) []byte { + bits := math.Float32bits(float) + bytes := make([]byte, 4) + binary.LittleEndian.PutUint32(bytes, bits) + + return bytes +} + +func ByteToFloat32(bytes []byte) float32 { + bits := binary.LittleEndian.Uint32(bytes) + return math.Float32frombits(bits) +} + +func Float64ToByte(float float64) []byte { + bits := math.Float64bits(float) + bytes := make([]byte, 8) + binary.LittleEndian.PutUint64(bytes, bits) + return bytes +} + +func ByteToFloat64(bytes []byte) float64 { + bits := binary.LittleEndian.Uint64(bytes) + return math.Float64frombits(bits) +} + +func Float64ToStr(f float64) string { + return strconv.FormatFloat(f, 'f', 2, 64) +} +func Float64ToStrPrec1(f float64) string { + return strconv.FormatFloat(f, 'f', 1, 64) +} +func Float64ToStrByPrec(f float64, prec int) string { + return strconv.FormatFloat(f, 'f', prec, 64) +} + +func Float32ToStr(f float32) string { + return Float64ToStr(float64(f)) +} + +func StrToFloat64(s string) float64 { + res, err := strconv.ParseFloat(s, 64) + if err != nil { + return 0 + } + return res +} +func StrToFormat(s string, prec int) string { + ex := strings.Split(s, ".") + if len(ex) == 2 { + if StrToFloat64(ex[1]) == 0 { //小数点后面为空就是不要小数点了 + return ex[0] + } + //看取多少位 + str := ex[1] + str1 := str + if prec < len(str) { + str1 = str[0:prec] + } else { + for i := 0; i < prec-len(str); i++ { + str1 += "0" + } + } + if prec > 0 { + return ex[0] + "." + str1 + } else { + return ex[0] + } + } + return s +} + +func StrToFloat32(s string) float32 { + res, err := strconv.ParseFloat(s, 32) + if err != nil { + return 0 + } + return float32(res) +} + +func StrToBool(s string) bool { + b, _ := strconv.ParseBool(s) + return b +} + +func BoolToStr(b bool) string { + if b { + return "true" + } + return "false" +} + +func FloatToInt64(f float64) int64 { + return int64(f) +} + +func IntToStr(i int) string { + return strconv.Itoa(i) +} + +func Int64ToStr(i int64) string { + return strconv.FormatInt(i, 10) +} + +func IntToFloat64(i int) float64 { + s := strconv.Itoa(i) + res, err := strconv.ParseFloat(s, 64) + if err != nil { + return 0 + } + return res +} +func Int64ToFloat64(i int64) float64 { + s := strconv.FormatInt(i, 10) + res, err := strconv.ParseFloat(s, 64) + if err != nil { + return 0 + } + return res +} diff --git a/utils/crypto.go b/utils/crypto.go new file mode 100644 index 0000000..56289c5 --- /dev/null +++ b/utils/crypto.go @@ -0,0 +1,19 @@ +package utils + +import ( + "crypto/md5" + "encoding/base64" + "fmt" +) + +func GetMd5(raw []byte) string { + h := md5.New() + h.Write(raw) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func GetBase64Md5(raw []byte) string { + h := md5.New() + h.Write(raw) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} diff --git a/utils/curl.go b/utils/curl.go new file mode 100644 index 0000000..0a45607 --- /dev/null +++ b/utils/curl.go @@ -0,0 +1,209 @@ +package utils + +import ( + "bytes" + "crypto/tls" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "sort" + "strings" + "time" +) + +var CurlDebug bool + +func CurlGet(router string, header map[string]string) ([]byte, error) { + return curl(http.MethodGet, router, nil, header) +} +func CurlGetJson(router string, body interface{}, header map[string]string) ([]byte, error) { + return curl_new(http.MethodGet, router, body, header) +} + +// 只支持form 与json 提交, 请留意body的类型, 支持string, []byte, map[string]string +func CurlPost(router string, body interface{}, header map[string]string) ([]byte, error) { + return curl(http.MethodPost, router, body, header) +} + +func CurlPut(router string, body interface{}, header map[string]string) ([]byte, error) { + return curl(http.MethodPut, router, body, header) +} + +// 只支持form 与json 提交, 请留意body的类型, 支持string, []byte, map[string]string +func CurlPatch(router string, body interface{}, header map[string]string) ([]byte, error) { + return curl(http.MethodPatch, router, body, header) +} + +// CurlDelete is curl delete +func CurlDelete(router string, body interface{}, header map[string]string) ([]byte, error) { + return curl(http.MethodDelete, router, body, header) +} + +func curl(method, router string, body interface{}, header map[string]string) ([]byte, error) { + var reqBody io.Reader + contentType := "application/json" + switch v := body.(type) { + case string: + reqBody = strings.NewReader(v) + case []byte: + reqBody = bytes.NewReader(v) + case map[string]string: + val := url.Values{} + for k, v := range v { + val.Set(k, v) + } + reqBody = strings.NewReader(val.Encode()) + contentType = "application/x-www-form-urlencoded" + case map[string]interface{}: + val := url.Values{} + for k, v := range v { + val.Set(k, v.(string)) + } + reqBody = strings.NewReader(val.Encode()) + contentType = "application/x-www-form-urlencoded" + } + if header == nil { + header = map[string]string{"Content-Type": contentType} + } + if _, ok := header["Content-Type"]; !ok { + header["Content-Type"] = contentType + } + resp, er := CurlReq(method, router, reqBody, header) + if er != nil { + return nil, er + } + res, err := ioutil.ReadAll(resp.Body) + if CurlDebug { + blob := SerializeStr(body) + if contentType != "application/json" { + blob = HttpBuild(body) + } + fmt.Printf("\n\n=====================\n[url]: %s\n[time]: %s\n[method]: %s\n[content-type]: %v\n[req_header]: %s\n[req_body]: %#v\n[resp_err]: %v\n[resp_header]: %v\n[resp_body]: %v\n=====================\n\n", + router, + time.Now().Format("2006-01-02 15:04:05.000"), + method, + contentType, + HttpBuildQuery(header), + blob, + err, + SerializeStr(resp.Header), + string(res), + ) + } + resp.Body.Close() + return res, err +} + +func curl_new(method, router string, body interface{}, header map[string]string) ([]byte, error) { + var reqBody io.Reader + contentType := "application/json" + + if header == nil { + header = map[string]string{"Content-Type": contentType} + } + if _, ok := header["Content-Type"]; !ok { + header["Content-Type"] = contentType + } + resp, er := CurlReq(method, router, reqBody, header) + if er != nil { + return nil, er + } + res, err := ioutil.ReadAll(resp.Body) + if CurlDebug { + blob := SerializeStr(body) + if contentType != "application/json" { + blob = HttpBuild(body) + } + fmt.Printf("\n\n=====================\n[url]: %s\n[time]: %s\n[method]: %s\n[content-type]: %v\n[req_header]: %s\n[req_body]: %#v\n[resp_err]: %v\n[resp_header]: %v\n[resp_body]: %v\n=====================\n\n", + router, + time.Now().Format("2006-01-02 15:04:05.000"), + method, + contentType, + HttpBuildQuery(header), + blob, + err, + SerializeStr(resp.Header), + string(res), + ) + } + resp.Body.Close() + return res, err +} + +func CurlReq(method, router string, reqBody io.Reader, header map[string]string) (*http.Response, error) { + req, _ := http.NewRequest(method, router, reqBody) + if header != nil { + for k, v := range header { + req.Header.Set(k, v) + } + } + // 绕过github等可能因为特征码返回503问题 + // https://www.imwzk.com/posts/2021-03-14-why-i-always-get-503-with-golang/ + defaultCipherSuites := []uint16{0xc02f, 0xc030, 0xc02b, 0xc02c, 0xcca8, 0xcca9, 0xc013, 0xc009, + 0xc014, 0xc00a, 0x009c, 0x009d, 0x002f, 0x0035, 0xc012, 0x000a} + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + CipherSuites: append(defaultCipherSuites[8:], defaultCipherSuites[:8]...), + }, + }, + // 获取301重定向 + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + return client.Do(req) +} + +// 组建get请求参数,sortAsc true为小到大,false为大到小,nil不排序 a=123&b=321 +func HttpBuildQuery(args map[string]string, sortAsc ...bool) string { + str := "" + if len(args) == 0 { + return str + } + if len(sortAsc) > 0 { + keys := make([]string, 0, len(args)) + for k := range args { + keys = append(keys, k) + } + if sortAsc[0] { + sort.Strings(keys) + } else { + sort.Sort(sort.Reverse(sort.StringSlice(keys))) + } + for _, k := range keys { + str += "&" + k + "=" + args[k] + } + } else { + for k, v := range args { + str += "&" + k + "=" + v + } + } + return str[1:] +} + +func HttpBuild(body interface{}, sortAsc ...bool) string { + params := map[string]string{} + if args, ok := body.(map[string]interface{}); ok { + for k, v := range args { + params[k] = AnyToString(v) + } + return HttpBuildQuery(params, sortAsc...) + } + if args, ok := body.(map[string]string); ok { + for k, v := range args { + params[k] = AnyToString(v) + } + return HttpBuildQuery(params, sortAsc...) + } + if args, ok := body.(map[string]int); ok { + for k, v := range args { + params[k] = AnyToString(v) + } + return HttpBuildQuery(params, sortAsc...) + } + return AnyToString(body) +} diff --git a/utils/debug.go b/utils/debug.go new file mode 100644 index 0000000..bb2e9d3 --- /dev/null +++ b/utils/debug.go @@ -0,0 +1,25 @@ +package utils + +import ( + "fmt" + "os" + "strconv" + "time" +) + +func Debug(args ...interface{}) { + s := "" + l := len(args) + if l < 1 { + fmt.Println("please input some data") + os.Exit(0) + } + i := 1 + for _, v := range args { + s += fmt.Sprintf("【"+strconv.Itoa(i)+"】: %#v\n", v) + i++ + } + s = "******************** 【DEBUG - " + time.Now().Format("2006-01-02 15:04:05") + "】 ********************\n" + s + "******************** 【DEBUG - END】 ********************\n" + fmt.Println(s) + os.Exit(0) +} diff --git a/utils/distance.go b/utils/distance.go new file mode 100644 index 0000000..2a6923b --- /dev/null +++ b/utils/distance.go @@ -0,0 +1,16 @@ +package utils + +import "math" + +//返回单位为:千米 +func GetDistance(lat1, lat2, lng1, lng2 float64) float64 { + radius := 6371000.0 //6378137.0 + rad := math.Pi / 180.0 + lat1 = lat1 * rad + lng1 = lng1 * rad + lat2 = lat2 * rad + lng2 = lng2 * rad + theta := lng2 - lng1 + dist := math.Acos(math.Sin(lat1)*math.Sin(lat2) + math.Cos(lat1)*math.Cos(lat2)*math.Cos(theta)) + return dist * radius / 1000 +} diff --git a/utils/duplicate.go b/utils/duplicate.go new file mode 100644 index 0000000..17cea88 --- /dev/null +++ b/utils/duplicate.go @@ -0,0 +1,37 @@ +package utils + +func RemoveDuplicateString(elms []string) []string { + res := make([]string, 0, len(elms)) + temp := map[string]struct{}{} + for _, item := range elms { + if _, ok := temp[item]; !ok { + temp[item] = struct{}{} + res = append(res, item) + } + } + return res +} + +func RemoveDuplicateInt(elms []int) []int { + res := make([]int, 0, len(elms)) + temp := map[int]struct{}{} + for _, item := range elms { + if _, ok := temp[item]; !ok { + temp[item] = struct{}{} + res = append(res, item) + } + } + return res +} + +func RemoveDuplicateInt64(elms []int64) []int64 { + res := make([]int64, 0, len(elms)) + temp := map[int64]struct{}{} + for _, item := range elms { + if _, ok := temp[item]; !ok { + temp[item] = struct{}{} + res = append(res, item) + } + } + return res +} diff --git a/utils/file.go b/utils/file.go new file mode 100644 index 0000000..93ed08f --- /dev/null +++ b/utils/file.go @@ -0,0 +1,22 @@ +package utils + +import ( + "os" + "path" + "strings" + "time" +) + +// 获取文件后缀 +func FileExt(fname string) string { + return strings.ToLower(strings.TrimLeft(path.Ext(fname), ".")) +} + +func FilePutContents(fileName string, content string) { + fd, _ := os.OpenFile("./tmp/"+fileName+".log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) + fd_time := time.Now().Format("2006-01-02 15:04:05") + fd_content := strings.Join([]string{"[", fd_time, "] ", content, "\n"}, "") + buf := []byte(fd_content) + fd.Write(buf) + fd.Close() +} diff --git a/utils/file_and_dir.go b/utils/file_and_dir.go new file mode 100644 index 0000000..93141f9 --- /dev/null +++ b/utils/file_and_dir.go @@ -0,0 +1,29 @@ +package utils + +import "os" + +// 判断所给路径文件、文件夹是否存在 +func Exists(path string) bool { + _, err := os.Stat(path) //os.Stat获取文件信息 + if err != nil { + if os.IsExist(err) { + return true + } + return false + } + return true +} + +// 判断所给路径是否为文件夹 +func IsDir(path string) bool { + s, err := os.Stat(path) + if err != nil { + return false + } + return s.IsDir() +} + +// 判断所给路径是否为文件 +func IsFile(path string) bool { + return !IsDir(path) +} diff --git a/utils/format.go b/utils/format.go new file mode 100644 index 0000000..e5e2d40 --- /dev/null +++ b/utils/format.go @@ -0,0 +1,66 @@ +package utils + +import ( + "math" +) + +func CouponFormat(data string) string { + switch data { + case "0.00", "0", "": + return "" + default: + return Int64ToStr(FloatToInt64(StrToFloat64(data))) + } +} +func CommissionFormat(data string) string { + if StrToFloat64(data) > 0 { + return data + } + + return "" +} + +func HideString(src string, hLen int) string { + str := []rune(src) + if hLen == 0 { + hLen = 4 + } + hideStr := "" + for i := 0; i < hLen; i++ { + hideStr += "*" + } + hideLen := len(str) / 2 + showLen := len(str) - hideLen + if hideLen == 0 || showLen == 0 { + return hideStr + } + subLen := showLen / 2 + if subLen == 0 { + return string(str[:showLen]) + hideStr + } + s := string(str[:subLen]) + s += hideStr + s += string(str[len(str)-subLen:]) + return s +} + +//SaleCountFormat is 格式化销量 +func SaleCountFormat(s string) string { + if StrToInt(s) > 0 { + if StrToInt(s) >= 10000 { + num := FloatFormat(StrToFloat64(s)/10000, 2) + s = Float64ToStr(num) + "w" + } + return s + "已售" + } + return "" +} + +// 小数格式化 +func FloatFormat(f float64, i int) float64 { + if i > 14 { + return f + } + p := math.Pow10(i) + return float64(int64((f+0.000000000000009)*p)) / p +} diff --git a/utils/json.go b/utils/json.go new file mode 100644 index 0000000..2905833 --- /dev/null +++ b/utils/json.go @@ -0,0 +1,37 @@ +package utils + +import ( + "bytes" + "encoding/json" + "regexp" +) + +func JsonMarshal(interface{}) { + +} + +// 不科学计数法 +func JsonDecode(data []byte, v interface{}) error { + d := json.NewDecoder(bytes.NewReader(data)) + d.UseNumber() + return d.Decode(v) +} + +// json字符串驼峰命名格式 转为 下划线命名格式 +// c :json字符串 +func MarshalJSONCamelCase2JsonSnakeCase(c string) []byte { + // Regexp definitions + var keyMatchRegex = regexp.MustCompile(`\"(\w+)\":`) + var wordBarrierRegex = regexp.MustCompile(`(\w)([A-Z])`) + marshalled := []byte(c) + converted := keyMatchRegex.ReplaceAllFunc( + marshalled, + func(match []byte) []byte { + return bytes.ToLower(wordBarrierRegex.ReplaceAll( + match, + []byte(`${1}_${2}`), + )) + }, + ) + return converted +} diff --git a/utils/logx/log.go b/utils/logx/log.go new file mode 100644 index 0000000..ca11223 --- /dev/null +++ b/utils/logx/log.go @@ -0,0 +1,245 @@ +package logx + +import ( + "os" + "strings" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type LogConfig struct { + AppName string `yaml:"app_name" json:"app_name" toml:"app_name"` + Level string `yaml:"level" json:"level" toml:"level"` + StacktraceLevel string `yaml:"stacktrace_level" json:"stacktrace_level" toml:"stacktrace_level"` + IsStdOut bool `yaml:"is_stdout" json:"is_stdout" toml:"is_stdout"` + TimeFormat string `yaml:"time_format" json:"time_format" toml:"time_format"` // second, milli, nano, standard, iso, + Encoding string `yaml:"encoding" json:"encoding" toml:"encoding"` // console, json + Skip int `yaml:"skip" json:"skip" toml:"skip"` + + IsFileOut bool `yaml:"is_file_out" json:"is_file_out" toml:"is_file_out"` + FileDir string `yaml:"file_dir" json:"file_dir" toml:"file_dir"` + FileName string `yaml:"file_name" json:"file_name" toml:"file_name"` + FileMaxSize int `yaml:"file_max_size" json:"file_max_size" toml:"file_max_size"` + FileMaxAge int `yaml:"file_max_age" json:"file_max_age" toml:"file_max_age"` +} + +var ( + l *LogX = defaultLogger() + conf *LogConfig +) + +// default logger setting +func defaultLogger() *LogX { + conf = &LogConfig{ + Level: "debug", + StacktraceLevel: "error", + IsStdOut: true, + TimeFormat: "standard", + Encoding: "console", + Skip: 2, + } + writers := []zapcore.WriteSyncer{os.Stdout} + lg, lv := newZapLogger(setLogLevel(conf.Level), setLogLevel(conf.StacktraceLevel), conf.Encoding, conf.TimeFormat, conf.Skip, zapcore.NewMultiWriteSyncer(writers...)) + zap.RedirectStdLog(lg) + return &LogX{logger: lg, atomLevel: lv} +} + +// initial standard log, if you don't init, it will use default logger setting +func InitDefaultLogger(cfg *LogConfig) { + var writers []zapcore.WriteSyncer + if cfg.IsStdOut || (!cfg.IsStdOut && !cfg.IsFileOut) { + writers = append(writers, os.Stdout) + } + if cfg.IsFileOut { + writers = append(writers, NewRollingFile(cfg.FileDir, cfg.FileName, cfg.FileMaxSize, cfg.FileMaxAge)) + } + + lg, lv := newZapLogger(setLogLevel(cfg.Level), setLogLevel(cfg.StacktraceLevel), cfg.Encoding, cfg.TimeFormat, cfg.Skip, zapcore.NewMultiWriteSyncer(writers...)) + zap.RedirectStdLog(lg) + if cfg.AppName != "" { + lg = lg.With(zap.String("app", cfg.AppName)) // 加上应用名称 + } + l = &LogX{logger: lg, atomLevel: lv} +} + +// create a new logger +func NewLogger(cfg *LogConfig) *LogX { + var writers []zapcore.WriteSyncer + if cfg.IsStdOut || (!cfg.IsStdOut && !cfg.IsFileOut) { + writers = append(writers, os.Stdout) + } + if cfg.IsFileOut { + writers = append(writers, NewRollingFile(cfg.FileDir, cfg.FileName, cfg.FileMaxSize, cfg.FileMaxAge)) + } + + lg, lv := newZapLogger(setLogLevel(cfg.Level), setLogLevel(cfg.StacktraceLevel), cfg.Encoding, cfg.TimeFormat, cfg.Skip, zapcore.NewMultiWriteSyncer(writers...)) + zap.RedirectStdLog(lg) + if cfg.AppName != "" { + lg = lg.With(zap.String("app", cfg.AppName)) // 加上应用名称 + } + return &LogX{logger: lg, atomLevel: lv} +} + +// create a new zaplog logger +func newZapLogger(level, stacktrace zapcore.Level, encoding, timeType string, skip int, output zapcore.WriteSyncer) (*zap.Logger, *zap.AtomicLevel) { + encCfg := zapcore.EncoderConfig{ + TimeKey: "T", + LevelKey: "L", + NameKey: "N", + CallerKey: "C", + MessageKey: "M", + StacktraceKey: "S", + LineEnding: zapcore.DefaultLineEnding, + EncodeCaller: zapcore.ShortCallerEncoder, + EncodeDuration: zapcore.NanosDurationEncoder, + EncodeLevel: zapcore.LowercaseLevelEncoder, + } + setTimeFormat(timeType, &encCfg) // set time type + atmLvl := zap.NewAtomicLevel() // set level + atmLvl.SetLevel(level) + encoder := zapcore.NewJSONEncoder(encCfg) // 确定encoder格式 + if encoding == "console" { + encoder = zapcore.NewConsoleEncoder(encCfg) + } + return zap.New(zapcore.NewCore(encoder, output, atmLvl), zap.AddCaller(), zap.AddStacktrace(stacktrace), zap.AddCallerSkip(skip)), &atmLvl +} + +// set log level +func setLogLevel(lvl string) zapcore.Level { + switch strings.ToLower(lvl) { + case "panic": + return zapcore.PanicLevel + case "fatal": + return zapcore.FatalLevel + case "error": + return zapcore.ErrorLevel + case "warn", "warning": + return zapcore.WarnLevel + case "info": + return zapcore.InfoLevel + default: + return zapcore.DebugLevel + } +} + +// set time format +func setTimeFormat(timeType string, z *zapcore.EncoderConfig) { + switch strings.ToLower(timeType) { + case "iso": // iso8601 standard + z.EncodeTime = zapcore.ISO8601TimeEncoder + case "sec": // only for unix second, without millisecond + z.EncodeTime = func(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + enc.AppendInt64(t.Unix()) + } + case "second": // unix second, with millisecond + z.EncodeTime = zapcore.EpochTimeEncoder + case "milli", "millisecond": // millisecond + z.EncodeTime = zapcore.EpochMillisTimeEncoder + case "nano", "nanosecond": // nanosecond + z.EncodeTime = zapcore.EpochNanosTimeEncoder + default: // standard format + z.EncodeTime = func(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + enc.AppendString(t.Format("2006-01-02 15:04:05.000")) + } + } +} + +func GetLevel() string { + switch l.atomLevel.Level() { + case zapcore.PanicLevel: + return "panic" + case zapcore.FatalLevel: + return "fatal" + case zapcore.ErrorLevel: + return "error" + case zapcore.WarnLevel: + return "warn" + case zapcore.InfoLevel: + return "info" + default: + return "debug" + } +} + +func SetLevel(lvl string) { + l.atomLevel.SetLevel(setLogLevel(lvl)) +} + +// temporary add call skip +func AddCallerSkip(skip int) *LogX { + l.logger.WithOptions(zap.AddCallerSkip(skip)) + return l +} + +// permanent add call skip +func AddDepth(skip int) *LogX { + l.logger = l.logger.WithOptions(zap.AddCallerSkip(skip)) + return l +} + +// permanent add options +func AddOptions(opts ...zap.Option) *LogX { + l.logger = l.logger.WithOptions(opts...) + return l +} + +func AddField(k string, v interface{}) { + l.logger.With(zap.Any(k, v)) +} + +func AddFields(fields map[string]interface{}) *LogX { + for k, v := range fields { + l.logger.With(zap.Any(k, v)) + } + return l +} + +// Normal log +func Debug(e interface{}, args ...interface{}) error { + return l.Debug(e, args...) +} +func Info(e interface{}, args ...interface{}) error { + return l.Info(e, args...) +} +func Warn(e interface{}, args ...interface{}) error { + return l.Warn(e, args...) +} +func Error(e interface{}, args ...interface{}) error { + return l.Error(e, args...) +} +func Panic(e interface{}, args ...interface{}) error { + return l.Panic(e, args...) +} +func Fatal(e interface{}, args ...interface{}) error { + return l.Fatal(e, args...) +} + +// Format logs +func Debugf(format string, args ...interface{}) error { + return l.Debugf(format, args...) +} +func Infof(format string, args ...interface{}) error { + return l.Infof(format, args...) +} +func Warnf(format string, args ...interface{}) error { + return l.Warnf(format, args...) +} +func Errorf(format string, args ...interface{}) error { + return l.Errorf(format, args...) +} +func Panicf(format string, args ...interface{}) error { + return l.Panicf(format, args...) +} +func Fatalf(format string, args ...interface{}) error { + return l.Fatalf(format, args...) +} + +func formatFieldMap(m FieldMap) []Field { + var res []Field + for k, v := range m { + res = append(res, zap.Any(k, v)) + } + return res +} diff --git a/utils/logx/output.go b/utils/logx/output.go new file mode 100644 index 0000000..ef33f0b --- /dev/null +++ b/utils/logx/output.go @@ -0,0 +1,105 @@ +package logx + +import ( + "bytes" + "io" + "os" + "path/filepath" + "time" + + "gopkg.in/natefinch/lumberjack.v2" +) + +// output interface +type WriteSyncer interface { + io.Writer + Sync() error +} + +// split writer +func NewRollingFile(dir, filename string, maxSize, MaxAge int) WriteSyncer { + s, err := os.Stat(dir) + if err != nil || !s.IsDir() { + os.RemoveAll(dir) + if err := os.MkdirAll(dir, 0766); err != nil { + panic(err) + } + } + return newLumberjackWriteSyncer(&lumberjack.Logger{ + Filename: filepath.Join(dir, filename), + MaxSize: maxSize, // megabytes, MB + MaxAge: MaxAge, // days + LocalTime: true, + Compress: false, + }) +} + +type lumberjackWriteSyncer struct { + *lumberjack.Logger + buf *bytes.Buffer + logChan chan []byte + closeChan chan interface{} + maxSize int +} + +func newLumberjackWriteSyncer(l *lumberjack.Logger) *lumberjackWriteSyncer { + ws := &lumberjackWriteSyncer{ + Logger: l, + buf: bytes.NewBuffer([]byte{}), + logChan: make(chan []byte, 5000), + closeChan: make(chan interface{}), + maxSize: 1024, + } + go ws.run() + return ws +} + +func (l *lumberjackWriteSyncer) run() { + ticker := time.NewTicker(1 * time.Second) + + for { + select { + case <-ticker.C: + if l.buf.Len() > 0 { + l.sync() + } + case bs := <-l.logChan: + _, err := l.buf.Write(bs) + if err != nil { + continue + } + if l.buf.Len() > l.maxSize { + l.sync() + } + case <-l.closeChan: + l.sync() + return + } + } +} + +func (l *lumberjackWriteSyncer) Stop() { + close(l.closeChan) +} + +func (l *lumberjackWriteSyncer) Write(bs []byte) (int, error) { + b := make([]byte, len(bs)) + for i, c := range bs { + b[i] = c + } + l.logChan <- b + return 0, nil +} + +func (l *lumberjackWriteSyncer) Sync() error { + return nil +} + +func (l *lumberjackWriteSyncer) sync() error { + defer l.buf.Reset() + _, err := l.Logger.Write(l.buf.Bytes()) + if err != nil { + return err + } + return nil +} diff --git a/utils/logx/sugar.go b/utils/logx/sugar.go new file mode 100644 index 0000000..ab380fc --- /dev/null +++ b/utils/logx/sugar.go @@ -0,0 +1,192 @@ +package logx + +import ( + "errors" + "fmt" + "strconv" + + "go.uber.org/zap" +) + +type LogX struct { + logger *zap.Logger + atomLevel *zap.AtomicLevel +} + +type Field = zap.Field +type FieldMap map[string]interface{} + +// 判断其他类型--start +func getFields(msg string, format bool, args ...interface{}) (string, []Field) { + var str []interface{} + var fields []zap.Field + if len(args) > 0 { + for _, v := range args { + if f, ok := v.(Field); ok { + fields = append(fields, f) + } else if f, ok := v.(FieldMap); ok { + fields = append(fields, formatFieldMap(f)...) + } else { + str = append(str, AnyToString(v)) + } + } + if format { + return fmt.Sprintf(msg, str...), fields + } + str = append([]interface{}{msg}, str...) + return fmt.Sprintln(str...), fields + } + return msg, []Field{} +} + +func (l *LogX) Debug(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Debug(msg, field...) + } + return e +} +func (l *LogX) Info(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Info(msg, field...) + } + return e +} +func (l *LogX) Warn(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Warn(msg, field...) + } + return e +} +func (l *LogX) Error(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Error(msg, field...) + } + return e +} +func (l *LogX) DPanic(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.DPanic(msg, field...) + } + return e +} +func (l *LogX) Panic(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Panic(msg, field...) + } + return e +} +func (l *LogX) Fatal(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Fatal(msg, field...) + } + return e +} + +func checkErr(s interface{}) (string, error) { + switch e := s.(type) { + case error: + return e.Error(), e + case string: + return e, errors.New(e) + case []byte: + return string(e), nil + default: + return "", nil + } +} + +func (l *LogX) LogError(err error) error { + return l.Error(err.Error()) +} + +func (l *LogX) Debugf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Debug(s, f...) + return errors.New(s) +} + +func (l *LogX) Infof(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Info(s, f...) + return errors.New(s) +} + +func (l *LogX) Warnf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Warn(s, f...) + return errors.New(s) +} + +func (l *LogX) Errorf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Error(s, f...) + return errors.New(s) +} + +func (l *LogX) DPanicf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.DPanic(s, f...) + return errors.New(s) +} + +func (l *LogX) Panicf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Panic(s, f...) + return errors.New(s) +} + +func (l *LogX) Fatalf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Fatal(s, f...) + return errors.New(s) +} + +func AnyToString(raw interface{}) string { + switch i := raw.(type) { + case []byte: + return string(i) + case int: + return strconv.FormatInt(int64(i), 10) + case int64: + return strconv.FormatInt(i, 10) + case float32: + return strconv.FormatFloat(float64(i), 'f', 2, 64) + case float64: + return strconv.FormatFloat(i, 'f', 2, 64) + case uint: + return strconv.FormatInt(int64(i), 10) + case uint8: + return strconv.FormatInt(int64(i), 10) + case uint16: + return strconv.FormatInt(int64(i), 10) + case uint32: + return strconv.FormatInt(int64(i), 10) + case uint64: + return strconv.FormatInt(int64(i), 10) + case int8: + return strconv.FormatInt(int64(i), 10) + case int16: + return strconv.FormatInt(int64(i), 10) + case int32: + return strconv.FormatInt(int64(i), 10) + case string: + return i + case error: + return i.Error() + } + return fmt.Sprintf("%#v", raw) +} diff --git a/utils/map_and_struct.go b/utils/map_and_struct.go new file mode 100644 index 0000000..34904ce --- /dev/null +++ b/utils/map_and_struct.go @@ -0,0 +1,341 @@ +package utils + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" +) + +func Map2Struct(vals map[string]interface{}, dst interface{}) (err error) { + return Map2StructByTag(vals, dst, "json") +} + +func Map2StructByTag(vals map[string]interface{}, dst interface{}, structTag string) (err error) { + defer func() { + e := recover() + if e != nil { + if v, ok := e.(error); ok { + err = fmt.Errorf("Panic: %v", v.Error()) + } else { + err = fmt.Errorf("Panic: %v", e) + } + } + }() + + pt := reflect.TypeOf(dst) + pv := reflect.ValueOf(dst) + + if pv.Kind() != reflect.Ptr || pv.Elem().Kind() != reflect.Struct { + return fmt.Errorf("not a pointer of struct") + } + + var f reflect.StructField + var ft reflect.Type + var fv reflect.Value + + for i := 0; i < pt.Elem().NumField(); i++ { + f = pt.Elem().Field(i) + fv = pv.Elem().Field(i) + ft = f.Type + + if f.Anonymous || !fv.CanSet() { + continue + } + + tag := f.Tag.Get(structTag) + + name, option := parseTag(tag) + + if name == "-" { + continue + } + + if name == "" { + name = strings.ToLower(f.Name) + } + val, ok := vals[name] + + if !ok { + if option == "required" { + return fmt.Errorf("'%v' not found", name) + } + if len(option) != 0 { + val = option // default value + } else { + //fv.Set(reflect.Zero(ft)) // TODO set zero value or just ignore it? + continue + } + } + + // convert or set value to field + vv := reflect.ValueOf(val) + vt := reflect.TypeOf(val) + + if vt.Kind() != reflect.String { + // try to assign and convert + if vt.AssignableTo(ft) { + fv.Set(vv) + continue + } + + if vt.ConvertibleTo(ft) { + fv.Set(vv.Convert(ft)) + continue + } + + return fmt.Errorf("value type not match: field=%v(%v) value=%v(%v)", f.Name, ft.Kind(), val, vt.Kind()) + } + s := strings.TrimSpace(vv.String()) + if len(s) == 0 && option == "required" { + return fmt.Errorf("value of required argument can't not be empty") + } + fk := ft.Kind() + + // convert string to value + if fk == reflect.Ptr && ft.Elem().Kind() == reflect.String { + fv.Set(reflect.ValueOf(&s)) + continue + } + if fk == reflect.Ptr || fk == reflect.Struct { + err = convertJsonValue(s, name, fv) + } else if fk == reflect.Slice { + err = convertSlice(s, f.Name, ft, fv) + } else { + err = convertValue(fk, s, f.Name, fv) + } + + if err != nil { + return err + } + continue + } + + return nil +} + +func Struct2Map(s interface{}) map[string]interface{} { + return Struct2MapByTag(s, "json") +} +func Struct2MapByTag(s interface{}, tagName string) map[string]interface{} { + t := reflect.TypeOf(s) + v := reflect.ValueOf(s) + + if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct { + t = t.Elem() + v = v.Elem() + } + + if v.Kind() != reflect.Struct { + return nil + } + + m := make(map[string]interface{}) + + for i := 0; i < t.NumField(); i++ { + fv := v.Field(i) + ft := t.Field(i) + + if !fv.CanInterface() { + continue + } + + if ft.PkgPath != "" { // unexported + continue + } + + var name string + var option string + tag := ft.Tag.Get(tagName) + if tag != "" { + ts := strings.Split(tag, ",") + if len(ts) == 1 { + name = ts[0] + } else if len(ts) > 1 { + name = ts[0] + option = ts[1] + } + if name == "-" { + continue // skip this field + } + if name == "" { + name = strings.ToLower(ft.Name) + } + if option == "omitempty" { + if isEmpty(&fv) { + continue // skip empty field + } + } + } else { + name = strings.ToLower(ft.Name) + } + + if ft.Anonymous && fv.Kind() == reflect.Ptr && fv.IsNil() { + continue + } + if (ft.Anonymous && fv.Kind() == reflect.Struct) || + (ft.Anonymous && fv.Kind() == reflect.Ptr && fv.Elem().Kind() == reflect.Struct) { + + // embedded struct + embedded := Struct2MapByTag(fv.Interface(), tagName) + for embName, embValue := range embedded { + m[embName] = embValue + } + } else if option == "string" { + kind := fv.Kind() + if kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 || kind == reflect.Int32 || kind == reflect.Int64 { + m[name] = strconv.FormatInt(fv.Int(), 10) + } else if kind == reflect.Uint || kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 || kind == reflect.Uint64 { + m[name] = strconv.FormatUint(fv.Uint(), 10) + } else if kind == reflect.Float32 || kind == reflect.Float64 { + m[name] = strconv.FormatFloat(fv.Float(), 'f', 2, 64) + } else { + m[name] = fv.Interface() + } + } else { + m[name] = fv.Interface() + } + } + + return m +} + +func isEmpty(v *reflect.Value) bool { + k := v.Kind() + if k == reflect.Bool { + return v.Bool() == false + } else if reflect.Int < k && k < reflect.Int64 { + return v.Int() == 0 + } else if reflect.Uint < k && k < reflect.Uintptr { + return v.Uint() == 0 + } else if k == reflect.Float32 || k == reflect.Float64 { + return v.Float() == 0 + } else if k == reflect.Array || k == reflect.Map || k == reflect.Slice || k == reflect.String { + return v.Len() == 0 + } else if k == reflect.Interface || k == reflect.Ptr { + return v.IsNil() + } + return false +} + +func convertSlice(s string, name string, ft reflect.Type, fv reflect.Value) error { + var err error + et := ft.Elem() + + if et.Kind() == reflect.Ptr || et.Kind() == reflect.Struct { + return convertJsonValue(s, name, fv) + } + + ss := strings.Split(s, ",") + + if len(s) == 0 || len(ss) == 0 { + return nil + } + + fs := reflect.MakeSlice(ft, 0, len(ss)) + + for _, si := range ss { + ev := reflect.New(et).Elem() + + err = convertValue(et.Kind(), si, name, ev) + if err != nil { + return err + } + fs = reflect.Append(fs, ev) + } + + fv.Set(fs) + + return nil +} + +func convertJsonValue(s string, name string, fv reflect.Value) error { + var err error + d := StringToSlice(s) + + if fv.Kind() == reflect.Ptr { + if fv.IsNil() { + fv.Set(reflect.New(fv.Type().Elem())) + } + } else { + fv = fv.Addr() + } + + err = json.Unmarshal(d, fv.Interface()) + + if err != nil { + return fmt.Errorf("invalid json '%v': %v, %v", name, err.Error(), s) + } + + return nil +} + +func convertValue(kind reflect.Kind, s string, name string, fv reflect.Value) error { + if !fv.CanAddr() { + return fmt.Errorf("can not addr: %v", name) + } + + if kind == reflect.String { + fv.SetString(s) + return nil + } + + if kind == reflect.Bool { + switch s { + case "true": + fv.SetBool(true) + case "false": + fv.SetBool(false) + case "1": + fv.SetBool(true) + case "0": + fv.SetBool(false) + default: + return fmt.Errorf("invalid bool: %v value=%v", name, s) + } + return nil + } + + if reflect.Int <= kind && kind <= reflect.Int64 { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return fmt.Errorf("invalid int: %v value=%v", name, s) + } + fv.SetInt(i) + + } else if reflect.Uint <= kind && kind <= reflect.Uint64 { + i, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return fmt.Errorf("invalid int: %v value=%v", name, s) + } + fv.SetUint(i) + + } else if reflect.Float32 == kind || kind == reflect.Float64 { + i, err := strconv.ParseFloat(s, 64) + + if err != nil { + return fmt.Errorf("invalid float: %v value=%v", name, s) + } + + fv.SetFloat(i) + } else { + // not support or just ignore it? + // return fmt.Errorf("type not support: field=%v(%v) value=%v(%v)", name, ft.Kind(), val, vt.Kind()) + } + return nil +} + +func parseTag(tag string) (string, string) { + tags := strings.Split(tag, ",") + + if len(tags) <= 0 { + return "", "" + } + + if len(tags) == 1 { + return tags[0], "" + } + + return tags[0], tags[1] +} diff --git a/utils/md5.go b/utils/md5.go new file mode 100644 index 0000000..52c108d --- /dev/null +++ b/utils/md5.go @@ -0,0 +1,12 @@ +package utils + +import ( + "crypto/md5" + "encoding/hex" +) + +func Md5(str string) string { + h := md5.New() + h.Write([]byte(str)) + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/utils/qrcode/decodeFile.go b/utils/qrcode/decodeFile.go new file mode 100644 index 0000000..f50fb28 --- /dev/null +++ b/utils/qrcode/decodeFile.go @@ -0,0 +1,33 @@ +package qrcode + +import ( + "image" + _ "image/jpeg" + _ "image/png" + "os" + + "github.com/makiuchi-d/gozxing" + "github.com/makiuchi-d/gozxing/qrcode" +) + +func DecodeFile(fi string) (string, error) { + file, err := os.Open(fi) + if err != nil { + return "", err + } + img, _, err := image.Decode(file) + if err != nil { + return "", err + } + // prepare BinaryBitmap + bmp, err := gozxing.NewBinaryBitmapFromImage(img) + if err != nil { + return "", err + } + // decode image + result, err := qrcode.NewQRCodeReader().Decode(bmp, nil) + if err != nil { + return "", err + } + return result.String(), nil +} diff --git a/utils/qrcode/getBase64.go b/utils/qrcode/getBase64.go new file mode 100644 index 0000000..2d0fe75 --- /dev/null +++ b/utils/qrcode/getBase64.go @@ -0,0 +1,55 @@ +package qrcode + +// 生成登录二维码图片, 方便在网页上显示 + +import ( + "bytes" + "encoding/base64" + "image/jpeg" + "image/png" + "io/ioutil" + "net/http" + + "github.com/boombuler/barcode" + "github.com/boombuler/barcode/qr" +) + +func GetJPGBase64(content string, edges ...int) string { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + + emptyBuff := bytes.NewBuffer(nil) // 开辟一个新的空buff缓冲区 + jpeg.Encode(emptyBuff, img, nil) + dist := make([]byte, 50000) // 开辟存储空间 + base64.StdEncoding.Encode(dist, emptyBuff.Bytes()) // buff转成base64 + return "data:image/png;base64," + string(dist) // 输出图片base64(type = []byte) +} + +func GetPNGBase64(content string, edges ...int) string { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + + emptyBuff := bytes.NewBuffer(nil) // 开辟一个新的空buff缓冲区 + png.Encode(emptyBuff, img) + dist := make([]byte, 50000) // 开辟存储空间 + base64.StdEncoding.Encode(dist, emptyBuff.Bytes()) // buff转成base64 + return string(dist) // 输出图片base64(type = []byte) +} +func GetFileBase64(content string) string { + res, err := http.Get(content) + if err != nil { + return "" + } + defer res.Body.Close() + data, _ := ioutil.ReadAll(res.Body) + imageBase64 := base64.StdEncoding.EncodeToString(data) + return imageBase64 +} diff --git a/utils/qrcode/saveFile.go b/utils/qrcode/saveFile.go new file mode 100644 index 0000000..4854783 --- /dev/null +++ b/utils/qrcode/saveFile.go @@ -0,0 +1,85 @@ +package qrcode + +// 生成登录二维码图片 + +import ( + "errors" + "image" + "image/jpeg" + "image/png" + "os" + "path/filepath" + "strings" + + "github.com/boombuler/barcode" + "github.com/boombuler/barcode/qr" +) + +func SaveJpegFile(filePath, content string, edges ...int) error { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + + return writeFile(filePath, img, "jpg") +} + +func SavePngFile(filePath, content string, edges ...int) error { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + + return writeFile(filePath, img, "png") +} + +func writeFile(filePath string, img image.Image, format string) error { + if err := createDir(filePath); err != nil { + return err + } + file, err := os.Create(filePath) + defer file.Close() + if err != nil { + return err + } + switch strings.ToLower(format) { + case "png": + err = png.Encode(file, img) + break + case "jpg": + err = jpeg.Encode(file, img, nil) + default: + return errors.New("format not accept") + } + if err != nil { + return err + } + return nil +} + +func createDir(filePath string) error { + var err error + // filePath, _ = filepath.Abs(filePath) + dirPath := filepath.Dir(filePath) + dirInfo, err := os.Stat(dirPath) + if err != nil { + if !os.IsExist(err) { + err = os.MkdirAll(dirPath, 0777) + if err != nil { + return err + } + } else { + return err + } + } else { + if dirInfo.IsDir() { + return nil + } + return errors.New("directory is a file") + } + return nil +} diff --git a/utils/qrcode/writeWeb.go b/utils/qrcode/writeWeb.go new file mode 100644 index 0000000..57e1e92 --- /dev/null +++ b/utils/qrcode/writeWeb.go @@ -0,0 +1,39 @@ +package qrcode + +import ( + "bytes" + "image/jpeg" + "image/png" + "net/http" + + "github.com/boombuler/barcode" + "github.com/boombuler/barcode/qr" +) + +func WritePng(w http.ResponseWriter, content string, edges ...int) error { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + buff := bytes.NewBuffer(nil) + png.Encode(buff, img) + w.Header().Set("Content-Type", "image/png") + _, err := w.Write(buff.Bytes()) + return err +} + +func WriteJpg(w http.ResponseWriter, content string, edges ...int) error { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + buff := bytes.NewBuffer(nil) + jpeg.Encode(buff, img, nil) + w.Header().Set("Content-Type", "image/jpg") + _, err := w.Write(buff.Bytes()) + return err +} diff --git a/utils/rand.go b/utils/rand.go new file mode 100644 index 0000000..5c52ae3 --- /dev/null +++ b/utils/rand.go @@ -0,0 +1,46 @@ +package utils + +import ( + crand "crypto/rand" + "fmt" + "math/big" + "math/rand" + "time" +) + +func RandString(l int, c ...string) string { + var ( + chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + str string + num *big.Int + ) + if len(c) > 0 { + chars = c[0] + } + chrLen := int64(len(chars)) + for len(str) < l { + num, _ = crand.Int(crand.Reader, big.NewInt(chrLen)) + str += string(chars[num.Int64()]) + } + return str +} + +func RandNum() string { + seed := time.Now().UnixNano() + rand.Int63() + return fmt.Sprintf("%05v", rand.New(rand.NewSource(seed)).Int31n(1000000)) +} + +//x的y次方 +func RandPow(l int) string { + var i = "1" + for j := 0; j < l; j++ { + i += "0" + } + k := StrToInt64(i) + n := rand.New(rand.NewSource(time.Now().UnixNano())).Int63n(k) + ls := "%0" + IntToStr(l) + "v" + str := fmt.Sprintf(ls, n) + //min := int(math.Pow10(l - 1)) + //max := int(math.Pow10(l) - 1) + return str +} diff --git a/utils/rsa.go b/utils/rsa.go new file mode 100644 index 0000000..fb8274a --- /dev/null +++ b/utils/rsa.go @@ -0,0 +1,170 @@ +package utils + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "log" + "os" +) + +// 生成私钥文件 TODO 未指定路径 +func RsaKeyGen(bits int) error { + privateKey, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return err + } + derStream := x509.MarshalPKCS1PrivateKey(privateKey) + block := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: derStream, + } + priFile, err := os.Create("private.pem") + if err != nil { + return err + } + err = pem.Encode(priFile, block) + priFile.Close() + if err != nil { + return err + } + // 生成公钥文件 + publicKey := &privateKey.PublicKey + derPkix, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return err + } + block = &pem.Block{ + Type: "PUBLIC KEY", + Bytes: derPkix, + } + pubFile, err := os.Create("public.pem") + if err != nil { + return err + } + err = pem.Encode(pubFile, block) + pubFile.Close() + if err != nil { + return err + } + return nil +} + +// 生成私钥文件, 返回 privateKey , publicKey, error +func RsaKeyGenText(bits int) (string, string, error) { // bits 字节位 1024/2048 + privateKey, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return "", "", err + } + derStream := x509.MarshalPKCS1PrivateKey(privateKey) + block := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: derStream, + } + priBuff := bytes.NewBuffer(nil) + err = pem.Encode(priBuff, block) + if err != nil { + return "", "", err + } + // 生成公钥文件 + publicKey := &privateKey.PublicKey + derPkix, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return "", "", err + } + block = &pem.Block{ + Type: "PUBLIC KEY", + Bytes: derPkix, + } + pubBuff := bytes.NewBuffer(nil) + err = pem.Encode(pubBuff, block) + if err != nil { + return "", "", err + } + return priBuff.String(), pubBuff.String(), nil +} + +// 加密 +func RsaEncrypt(rawData, publicKey []byte) ([]byte, error) { + block, _ := pem.Decode(publicKey) + if block == nil { + return nil, errors.New("public key error") + } + pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, err + } + pub := pubInterface.(*rsa.PublicKey) + return rsa.EncryptPKCS1v15(rand.Reader, pub, rawData) +} + +// 公钥加密 +func RsaEncrypts(data, keyBytes []byte) []byte { + //解密pem格式的公钥 + block, _ := pem.Decode(keyBytes) + if block == nil { + panic(errors.New("public key error")) + } + // 解析公钥 + pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + panic(err) + } + // 类型断言 + pub := pubInterface.(*rsa.PublicKey) + //加密 + ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, pub, data) + if err != nil { + panic(err) + } + return ciphertext +} + +// 解密 +func RsaDecrypt(cipherText, privateKey []byte) ([]byte, error) { + block, _ := pem.Decode(privateKey) + if block == nil { + return nil, errors.New("private key error") + } + priv, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + return rsa.DecryptPKCS1v15(rand.Reader, priv, cipherText) +} + +// 从证书获取公钥 +func OpensslPemGetPublic(pathOrString string) (interface{}, error) { + var certPem []byte + var err error + if IsFile(pathOrString) && Exists(pathOrString) { + certPem, err = ioutil.ReadFile(pathOrString) + if err != nil { + return nil, err + } + if string(certPem) == "" { + return nil, errors.New("empty pem file") + } + } else { + if pathOrString == "" { + return nil, errors.New("empty pem string") + } + certPem = StringToSlice(pathOrString) + } + block, rest := pem.Decode(certPem) + if block == nil || block.Type != "PUBLIC KEY" { + //log.Fatal("failed to decode PEM block containing public key") + return nil, errors.New("failed to decode PEM block containing public key") + } + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Got a %T, with remaining data: %q", pub, rest) + return pub, nil +} diff --git a/utils/serialize.go b/utils/serialize.go new file mode 100644 index 0000000..1ac4d80 --- /dev/null +++ b/utils/serialize.go @@ -0,0 +1,23 @@ +package utils + +import ( + "encoding/json" +) + +func Serialize(data interface{}) []byte { + res, err := json.Marshal(data) + if err != nil { + return []byte{} + } + return res +} + +func Unserialize(b []byte, dst interface{}) { + if err := json.Unmarshal(b, dst); err != nil { + dst = nil + } +} + +func SerializeStr(data interface{}, arg ...interface{}) string { + return string(Serialize(data)) +} diff --git a/utils/shuffle.go b/utils/shuffle.go new file mode 100644 index 0000000..2c845a8 --- /dev/null +++ b/utils/shuffle.go @@ -0,0 +1,48 @@ +package utils + +import ( + "math/rand" + "time" +) + +// 打乱随机字符串 +func ShuffleString(s *string) { + if len(*s) > 1 { + b := []byte(*s) + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(b), func(x, y int) { + b[x], b[y] = b[y], b[x] + }) + *s = string(b) + } +} + +// 打乱随机slice +func ShuffleSliceBytes(b []byte) { + if len(b) > 1 { + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(b), func(x, y int) { + b[x], b[y] = b[y], b[x] + }) + } +} + +// 打乱slice int +func ShuffleSliceInt(i []int) { + if len(i) > 1 { + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(i), func(x, y int) { + i[x], i[y] = i[y], i[x] + }) + } +} + +// 打乱slice interface +func ShuffleSliceInterface(i []interface{}) { + if len(i) > 1 { + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(i), func(x, y int) { + i[x], i[y] = i[y], i[x] + }) + } +} diff --git a/utils/sign_check.go b/utils/sign_check.go new file mode 100644 index 0000000..7eada16 --- /dev/null +++ b/utils/sign_check.go @@ -0,0 +1,129 @@ +package utils + +import ( + "code.fnuoos.com/go_rely_warehouse/zyos_go_tool.git/utils/logx" + "fmt" + "github.com/forgoer/openssl" + "github.com/gin-gonic/gin" + "github.com/syyongx/php2go" + "strings" +) + +var publicKey = []byte(`-----BEGIN PUBLIC KEY----- +MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCFQD7RL2tDNuwdg0jTfV0zjAzh +WoCWfGrcNiucy2XUHZZU2oGhHv1N10qu3XayTDD4pu4sJ73biKwqR6ZN7IS4Sfon +vrzaXGvrTG4kmdo3XrbrkzmyBHDLTsJvv6pyS2HPl9QPSvKDN0iJ66+KN8QjBpw1 +FNIGe7xbDaJPY733/QIDAQAB +-----END PUBLIC KEY-----`) + +var privateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQCFQD7RL2tDNuwdg0jTfV0zjAzhWoCWfGrcNiucy2XUHZZU2oGh +Hv1N10qu3XayTDD4pu4sJ73biKwqR6ZN7IS4SfonvrzaXGvrTG4kmdo3Xrbrkzmy +BHDLTsJvv6pyS2HPl9QPSvKDN0iJ66+KN8QjBpw1FNIGe7xbDaJPY733/QIDAQAB +AoGADi14wY8XDY7Bbp5yWDZFfV+QW0Xi2qAgSo/k8gjeK8R+I0cgdcEzWF3oz1Q2 +9d+PclVokAAmfj47e0AmXLImqMCSEzi1jDBUFIRoJk9WE1YstE94mrCgV0FW+N/u ++L6OgZcjmF+9dHKprnpaUGQuUV5fF8j0qp8S2Jfs3Sw+dOECQQCQnHALzFjmXXIR +Ez3VSK4ZoYgDIrrpzNst5Hh6AMDNZcG3CrCxlQrgqjgTzBSr3ZSavvkfYRj42STk +TqyX1tQFAkEA6+O6UENoUTk2lG7iO/ta7cdIULnkTGwQqvkgLIUjk6w8E3sBTIfw +rerTEmquw5F42HHE+FMrRat06ZN57lENmQJAYgUHlZevcoZIePZ35Qfcqpbo4Gc8 +Fpm6vwKr/tZf2Vlt0qo2VkhWFS6L0C92m4AX6EQmDHT+Pj7BWNdS+aCuGQJBAOkq +NKPZvWdr8jNOV3mKvxqB/U0uMigIOYGGtvLKt5vkh42J7ILFbHW8w95UbWMKjDUG +X/hF3WQEUo//Imsa2yECQHSZIpJxiTRueoDiyRt0LH+jdbYFUu/6D0UIYXhFvP/p +EZX+hfCfUnNYX59UVpRjSZ66g0CbCjuBPOhmOD+hDeQ= +-----END RSA PRIVATE KEY-----`) + +func GetApiVersion(c *gin.Context) int { + var apiVersion = c.GetHeader("apiVersion") + if StrToInt(apiVersion) == 0 { //没有版本号先不校验 + apiVersion = c.GetHeader("Apiversion") + } + if StrToInt(apiVersion) == 0 { //没有版本号先不校验 + apiVersion = c.GetHeader("api_version") + } + return StrToInt(apiVersion) +} + +//签名校验 +func SignCheck(c *gin.Context) bool { + var apiVersion = GetApiVersion(c) + if apiVersion == 0 { //没有版本号先不校验 + return true + } + //1.通过rsa 解析出 aes + var key = c.GetHeader("key") + + //拼接对应参数 + var uri = c.Request.RequestURI + var query = GetQueryParam(uri) + fmt.Println(query) + query["timestamp"] = c.GetHeader("timestamp") + query["nonce"] = c.GetHeader("nonce") + query["key"] = key + token := c.GetHeader("Authorization") + if token != "" { + // 按空格分割 + parts := strings.SplitN(token, " ", 2) + if len(parts) == 2 && parts[0] == "Bearer" { + token = parts[1] + } + } + query["token"] = token + //2.query参数按照 ASCII 码从小到大排序 + str := JoinStringsInASCII(query, "&", false, false, "") + //3.拼上密钥 + secret := "" + if InArr(c.GetHeader("platform"), []string{"android", "ios"}) { + secret = c.GetString("app_api_secret_key") + } else if c.GetHeader("platform") == "wap" { + secret = c.GetString("h5_api_secret_key") + } else { + secret = c.GetString("applet_api_secret_key") + } + + str = fmt.Sprintf("%s&secret=%s", str, secret) + fmt.Println(str) + //4.md5加密 转小写 + sign := strings.ToLower(Md5(str)) + //5.判断跟前端传来的sign是否一致 + if sign != c.GetHeader("sign") { + return false + } + return true +} + +func ResultAes(c *gin.Context, raw []byte) string { + var key = c.GetHeader("key") + base, _ := php2go.Base64Decode(key) + aes, err := RsaDecrypt([]byte(base), privateKey) + if err != nil { + logx.Info(err) + return "" + } + fmt.Println("============aes============") + fmt.Println(string(aes)) + fmt.Println(string(raw)) + str, _ := openssl.AesECBEncrypt(raw, aes, openssl.PKCS7_PADDING) + value := php2go.Base64Encode(string(str)) + fmt.Println(value) + + return value +} + +func ResultAesDecrypt(c *gin.Context, raw string) string { + var key = c.GetHeader("key") + base, _ := php2go.Base64Decode(key) + aes, err := RsaDecrypt([]byte(base), privateKey) + if err != nil { + logx.Info(err) + return "" + } + fmt.Println(raw) + value1, _ := php2go.Base64Decode(raw) + if value1 == "" { + return "" + } + str1, _ := openssl.AesECBDecrypt([]byte(value1), aes, openssl.PKCS7_PADDING) + fmt.Println("==========解码=========") + fmt.Println(string(str1)) + return string(str1) +} diff --git a/utils/slice.go b/utils/slice.go new file mode 100644 index 0000000..fd86081 --- /dev/null +++ b/utils/slice.go @@ -0,0 +1,13 @@ +package utils + +// ContainsString is 字符串是否包含在字符串切片里 +func ContainsString(array []string, val string) (index int) { + index = -1 + for i := 0; i < len(array); i++ { + if array[i] == val { + index = i + return + } + } + return +} diff --git a/utils/slice_and_string.go b/utils/slice_and_string.go new file mode 100644 index 0000000..3ae6946 --- /dev/null +++ b/utils/slice_and_string.go @@ -0,0 +1,47 @@ +package utils + +import ( + "fmt" + "reflect" + "strings" + "unsafe" +) + +// string与slice互转,零copy省内存 + +// zero copy to change slice to string +func Slice2String(b []byte) (s string) { + pBytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + pString := (*reflect.StringHeader)(unsafe.Pointer(&s)) + pString.Data = pBytes.Data + pString.Len = pBytes.Len + return +} + +// no copy to change string to slice +func StringToSlice(s string) (b []byte) { + pBytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + pString := (*reflect.StringHeader)(unsafe.Pointer(&s)) + pBytes.Data = pString.Data + pBytes.Len = pString.Len + pBytes.Cap = pString.Len + return +} + +// 任意slice合并 +func SliceJoin(sep string, elems ...interface{}) string { + l := len(elems) + if l == 0 { + return "" + } + if l == 1 { + s := fmt.Sprint(elems[0]) + sLen := len(s) - 1 + if s[0] == '[' && s[sLen] == ']' { + return strings.Replace(s[1:sLen], " ", sep, -1) + } + return s + } + sep = strings.Replace(fmt.Sprint(elems), " ", sep, -1) + return sep[1 : len(sep)-1] +} diff --git a/utils/string.go b/utils/string.go new file mode 100644 index 0000000..fcf9043 --- /dev/null +++ b/utils/string.go @@ -0,0 +1,181 @@ +package utils + +import ( + "fmt" + "github.com/syyongx/php2go" + "reflect" + "regexp" + "sort" + "strings" + "unicode" +) + +func Implode(glue string, args ...interface{}) string { + data := make([]string, len(args)) + for i, s := range args { + data[i] = fmt.Sprint(s) + } + return strings.Join(data, glue) +} + +//字符串是否在数组里 +func InArr(target string, strArray []string) bool { + for _, element := range strArray { + if target == element { + return true + } + } + return false +} + +//把数组的值放到key里 +func ArrayColumn(array interface{}, key string) (result map[string]interface{}, err error) { + result = make(map[string]interface{}) + t := reflect.TypeOf(array) + v := reflect.ValueOf(array) + if t.Kind() != reflect.Slice { + return nil, nil + } + if v.Len() == 0 { + return nil, nil + } + for i := 0; i < v.Len(); i++ { + indexv := v.Index(i) + if indexv.Type().Kind() != reflect.Struct { + return nil, nil + } + mapKeyInterface := indexv.FieldByName(key) + if mapKeyInterface.Kind() == reflect.Invalid { + return nil, nil + } + mapKeyString, err := InterfaceToString(mapKeyInterface.Interface()) + if err != nil { + return nil, err + } + result[mapKeyString] = indexv.Interface() + } + return result, err +} + +//转string +func InterfaceToString(v interface{}) (result string, err error) { + switch reflect.TypeOf(v).Kind() { + case reflect.Int64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: + result = fmt.Sprintf("%v", v) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + result = fmt.Sprintf("%v", v) + case reflect.String: + result = v.(string) + default: + err = nil + } + return result, err +} + +func HideTrueName(name string) string { + res := "**" + if name != "" { + runs := []rune(name) + leng := len(runs) + if leng <= 3 { + res = string(runs[0:1]) + res + } else if leng < 5 { + res = string(runs[0:2]) + res + } else if leng < 10 { + res = string(runs[0:2]) + "***" + string(runs[leng-2:leng]) + } else if leng < 16 { + res = string(runs[0:3]) + "****" + string(runs[leng-3:leng]) + } else { + res = string(runs[0:4]) + "*****" + string(runs[leng-4:leng]) + } + } + return res +} + +//是否有中文 +func IsChineseChar(str string) bool { + for _, r := range str { + if unicode.Is(unicode.Scripts["Han"], r) || (regexp.MustCompile("[\u3002\uff1b\uff0c\uff1a\u201c\u201d\uff08\uff09\u3001\uff1f\u300a\u300b]").MatchString(string(r))) { + return true + } + } + return false +} + +//清空前面是0的 +func LeadZeros(str string) string { + bytes := []byte(str) + var index int + for i, b := range bytes { + if b != byte(48) { + index = i + break + } + } + i := bytes[index:len(bytes)] + return string(i) +} +func GetQueryParam(uri string) map[string]string { + //根据问号分割路由还是query参数 + uriList := strings.Split(uri, "?") + var query = make(map[string]string, 0) + //有参数才处理 + if len(uriList) == 2 { + //分割query参数 + var queryList = strings.Split(uriList[1], "&") + if len(queryList) > 0 { + //key value 分别赋值 + for _, v := range queryList { + var valueList = strings.Split(v, "=") + if len(valueList) == 2 { + value, _ := php2go.URLDecode(valueList[1]) + if value == "" { + value = valueList[1] + } + query[valueList[0]] = value + } + } + } + } + return query +} + +//JoinStringsInASCII 按照规则,参数名ASCII码从小到大排序后拼接 +//data 待拼接的数据 +//sep 连接符 +//onlyValues 是否只包含参数值,true则不包含参数名,否则参数名和参数值均有 +//includeEmpty 是否包含空值,true则包含空值,否则不包含,注意此参数不影响参数名的存在 +//exceptKeys 被排除的参数名,不参与排序及拼接 +func JoinStringsInASCII(data map[string]string, sep string, onlyValues, includeEmpty bool, exceptKeys ...string) string { + var list []string + var keyList []string + m := make(map[string]int) + if len(exceptKeys) > 0 { + for _, except := range exceptKeys { + m[except] = 1 + } + } + for k := range data { + if _, ok := m[k]; ok { + continue + } + value := data[k] + if !includeEmpty && value == "" { + continue + } + if onlyValues { + keyList = append(keyList, k) + } else { + list = append(list, fmt.Sprintf("%s=%s", k, value)) + } + } + if onlyValues { + sort.Strings(keyList) + for _, v := range keyList { + list = append(list, AnyToString(data[v])) + } + } else { + sort.Strings(list) + } + return strings.Join(list, sep) +} diff --git a/utils/struct2UrlParams.go b/utils/struct2UrlParams.go new file mode 100644 index 0000000..7f8ac05 --- /dev/null +++ b/utils/struct2UrlParams.go @@ -0,0 +1,43 @@ +package utils + +import "sort" + +func Struct2UrlParams(obj interface{}) string { + var str = "" + mapVal := Struct2Map(obj) + var keys []string + for key := range mapVal { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + str += key + "=" + AnyToString(mapVal[key]) + "&" + } + /*t := reflect.TypeOf(origin) + v := reflect.ValueOf(origin) + + for i := 0; i < t.NumField(); i++ { + tag := strings.ToLower(t.Field(i).Tag.Get("json")) + if tag != "sign" { + switch v.Field(i).Kind(){ + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str += tag + "=" + utils.Int64ToStr(v.Field(i).Int()) + "&" + break + case reflect.String: + str += tag + "=" + v.Field(i).String() + "&" + break + case reflect.Bool: + str += tag + "=" + utils.BoolToStr(v.Field(i).Bool()) + "&" + break + case reflect.Float32, reflect.Float64: + str += tag + "=" + utils.Float64ToStr(v.Field(i).Float()) + break + case reflect.Array, reflect.Struct, reflect.Slice: + str += ToUrlKeyValue(v.Field(i).Interface()) + default: + break + } + } + }*/ + return str +} diff --git a/utils/time.go b/utils/time.go new file mode 100644 index 0000000..16fb214 --- /dev/null +++ b/utils/time.go @@ -0,0 +1,208 @@ +package utils + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" +) + +func StrToTime(s string) (int64, error) { + // delete all not int characters + if s == "" { + return time.Now().Unix(), nil + } + r := make([]rune, 14) + l := 0 + // 过滤除数字以外的字符 + for _, v := range s { + if '0' <= v && v <= '9' { + r[l] = v + l++ + if l == 14 { + break + } + } + } + for l < 14 { + r[l] = '0' // 补0 + l++ + } + t, err := time.Parse("20060102150405", string(r)) + if err != nil { + return 0, err + } + return t.Unix(), nil +} + +func TimeToStr(unixSecTime interface{}, layout ...string) string { + i := AnyToInt64(unixSecTime) + if i == 0 { + return "" + } + f := "2006-01-02 15:04:05" + if len(layout) > 0 { + f = layout[0] + } + return time.Unix(i, 0).Format(f) +} + +func FormatNanoUnix() string { + return strings.Replace(time.Now().Format("20060102150405.0000000"), ".", "", 1) +} + +func TimeParse(format, src string) (time.Time, error) { + return time.ParseInLocation(format, src, time.Local) +} + +func TimeParseStd(src string) time.Time { + t, _ := TimeParse("2006-01-02 15:04:05", src) + return t +} + +func TimeStdParseUnix(src string) int64 { + t, err := TimeParse("2006-01-02 15:04:05", src) + if err != nil { + return 0 + } + return t.Unix() +} + +// 获取一个当前时间 时间间隔 时间戳 +func GetTimeInterval(unit string, amount int) (startTime, endTime int64) { + t := time.Now() + nowTime := t.Unix() + tmpTime := int64(0) + switch unit { + case "years": + tmpTime = time.Date(t.Year()+amount, t.Month(), t.Day(), t.Hour(), 0, 0, 0, t.Location()).Unix() + case "months": + tmpTime = time.Date(t.Year(), t.Month()+time.Month(amount), t.Day(), t.Hour(), 0, 0, 0, t.Location()).Unix() + case "days": + tmpTime = time.Date(t.Year(), t.Month(), t.Day()+amount, t.Hour(), 0, 0, 0, t.Location()).Unix() + case "hours": + tmpTime = time.Date(t.Year(), t.Month(), t.Day(), t.Hour()+amount, 0, 0, 0, t.Location()).Unix() + } + if amount > 0 { + startTime = nowTime + endTime = tmpTime + } else { + startTime = tmpTime + endTime = nowTime + } + return +} + +// 几天前 +func TimeInterval(newTime int) string { + now := time.Now().Unix() + newTime64 := AnyToInt64(newTime) + if newTime64 >= now { + return "刚刚" + } + interval := now - newTime64 + switch { + case interval < 60: + return AnyToString(interval) + "秒前" + case interval < 60*60: + return AnyToString(interval/60) + "分前" + case interval < 60*60*24: + return AnyToString(interval/60/60) + "小时前" + case interval < 60*60*24*30: + return AnyToString(interval/60/60/24) + "天前" + case interval < 60*60*24*30*12: + return AnyToString(interval/60/60/24/30) + "月前" + default: + return AnyToString(interval/60/60/24/30/12) + "年前" + } +} + +// 时分秒字符串转时间戳,传入示例:8:40 or 8:40:10 +func HmsToUnix(str string) (int64, error) { + t := time.Now() + arr := strings.Split(str, ":") + if len(arr) < 2 { + return 0, errors.New("Time format error") + } + h, _ := strconv.Atoi(arr[0]) + m, _ := strconv.Atoi(arr[1]) + s := 0 + if len(arr) == 3 { + s, _ = strconv.Atoi(arr[3]) + } + formatted1 := fmt.Sprintf("%d%02d%02d%02d%02d%02d", t.Year(), t.Month(), t.Day(), h, m, s) + res, err := time.ParseInLocation("20060102150405", formatted1, time.Local) + if err != nil { + return 0, err + } else { + return res.Unix(), nil + } +} + +// 获取特定时间范围 +func GetTimeRange(s string) map[string]int64 { + t := time.Now() + var stime, etime time.Time + switch s { + case "today": + stime = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) + etime = time.Date(t.Year(), t.Month(), t.Day()+1, 0, 0, 0, 0, t.Location()) + case "yesterday": + stime = time.Date(t.Year(), t.Month(), t.Day()-1, 0, 0, 0, 0, t.Location()) + etime = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) + case "within_seven_days": + // 前6天0点 + stime = time.Date(t.Year(), t.Month(), t.Day()-6, 0, 0, 0, 0, t.Location()) + // 明天 0点 + etime = time.Date(t.Year(), t.Month(), t.Day()+1, 0, 0, 0, 0, t.Location()) + case "current_month": + stime = GetFirstDateOfMonth(t) + etime = time.Now() + case "last_month": + etime = GetFirstDateOfMonth(t) + monthTimes := TimeStdParseUnix(etime.Format("2006-01-02 15:04:05")) - 86400 + times, _ := UnixToTime(Int64ToStr(monthTimes)) + stime = GetFirstDateOfMonth(times) + + } + + return map[string]int64{ + "start": stime.Unix(), + "end": etime.Unix(), + } +} + +//时间戳转时间格式 +func UnixToTime(e string) (datatime time.Time, err error) { + data, err := strconv.ParseInt(e, 10, 64) + datatime = time.Unix(data, 0) + return +} + +//获取传入的时间所在月份的第一天,即某月第一天的0点。如传入time.Now(), 返回当前月份的第一天0点时间。 +func GetFirstDateOfMonth(d time.Time) time.Time { + d = d.AddDate(0, 0, -d.Day()+1) + return GetZeroTime(d) +} + +//获取传入的时间所在月份的最后一天,即某月最后一天的0点。如传入time.Now(), 返回当前月份的最后一天0点时间。 +func GetLastDateOfMonth(d time.Time) time.Time { + return GetFirstDateOfMonth(d).AddDate(0, 1, -1) +} + +//获取某一天的0点时间 +func GetZeroTime(d time.Time) time.Time { + return time.Date(d.Year(), d.Month(), d.Day(), 0, 0, 0, 0, d.Location()) +} + +//获取当月某天的某个时间的时间 +func GetDayToTime(day, timeStr string) string { + if timeStr == "" { + timeStr = "00:00:00" + } + year := time.Now().Year() + month := time.Now().Format("01") + times := fmt.Sprintf("%s-%s-%s %s", IntToStr(year), month, day, timeStr) + return times +} diff --git a/utils/translate.go b/utils/translate.go new file mode 100644 index 0000000..3ad5bfe --- /dev/null +++ b/utils/translate.go @@ -0,0 +1,109 @@ +package utils + +import ( + "code.fnuoos.com/go_rely_warehouse/zyos_go_tool.git/utils/cache" + "encoding/json" + "github.com/gin-gonic/gin" + "github.com/go-creed/sat" + "strings" +) + +func ReadReverse(c *gin.Context, str,WxappletFilepathURL string) string { + if c.GetString("translate_open") == "zh_Hant_" { //繁体先不改 + sat.InitDefaultDict(sat.SetPath(WxappletFilepathURL + "/" + "sat.txt")) //使用自定义词库 + sat := sat.DefaultDict() + res := sat.ReadReverse(str) + list := strings.Split(res, "http") + imgList := []string{".png", ".jpg", ".jpeg", ".gif"} + for _, v := range list { + for _, v1 := range imgList { + if strings.Contains(v, v1) { //判断是不是有图片 有图片就截取 替换简繁体 + strs := strings.Split(v, v1) + if len(strs) > 0 { + oldStr := strs[0] + newStr := sat.Read(oldStr) + res = strings.ReplaceAll(res, oldStr, newStr) + } + } + } + } + return res + } + if c.GetString("translate_open") != "zh_Hant_" { //除了繁体,其他都走这里 + //简体---其他语言 + cTouString, err := cache.GetString("multi_language_c_to_" + c.GetString("translate_open")) + if err != nil { + return str + } + var cTou = make(map[string]string) + json.Unmarshal([]byte(cTouString), &cTou) + if len(cTou) == 0 { + return str + } + //其他语言--简体 + getString1, err1 := cache.GetString("multi_language_" + c.GetString("translate_open") + "_to_c") + if err1 != nil { + return str + } + var uToc = make(map[string]string) + json.Unmarshal([]byte(getString1), &uToc) + if len(uToc) == 0 { + return str + } + res := str + for k, v := range cTou { + res = strings.ReplaceAll(res, k, v) + } + list := strings.Split(res, "http") + imgList := []string{".png", ".jpg", ".jpeg", ".gif"} + for _, v := range list { + for _, v1 := range imgList { + if strings.Contains(v, v1) { //判断是不是有图片 有图片就截取 替换简繁体 + strs := strings.Split(v, v1) + if len(strs) > 0 { + oldStr := strs[0] + newStr := oldStr + for k2, v2 := range uToc { + newStr = strings.ReplaceAll(oldStr, k2, v2) + } + res = strings.ReplaceAll(res, oldStr, newStr) + } + } + } + } + return res + } + return str + +} + +func ReadReverse1(str, types string) string { + res := map[string]map[string]string{} + err := cache.GetJson("multi_language", &res) + if err != nil { + return str + } + for k, v := range res { + str = strings.ReplaceAll(str, k, v[types]) + } + resStr := str + list := strings.Split(resStr, "http") + imgList := []string{".png", ".jpg", ".jpeg", ".gif"} + for _, v := range list { + for _, v1 := range imgList { + if strings.Contains(v, v1) { //判断是不是有图片 有图片就截取 替换简繁体 + strs := strings.Split(v, v1) + if len(strs) > 0 { + oldStr := strs[0] + for k2, v2 := range res { + if v2[types] == oldStr { + resStr = strings.ReplaceAll(resStr, oldStr, k2) + } + } + //res = strings.ReplaceAll(res, oldStr, newStr) + } + } + } + } + return resStr +} diff --git a/utils/trim_html.go b/utils/trim_html.go new file mode 100644 index 0000000..f796f32 --- /dev/null +++ b/utils/trim_html.go @@ -0,0 +1,38 @@ +package utils + +import ( + "regexp" + "strings" +) + +func TrimHtml(src string) string { + re, _ := regexp.Compile("<[\\S\\s]+?>") + src = re.ReplaceAllStringFunc(src, strings.ToLower) + + re, _ = regexp.Compile("") + src = re.ReplaceAllString(src, "") + + re, _ = regexp.Compile("") + src = re.ReplaceAllString(src, "") + + re, _ = regexp.Compile("<[\\S\\s]+?>") + src = re.ReplaceAllString(src, "\n") + + re, _ = regexp.Compile("\\s{2,}") + src = re.ReplaceAllString(src, "\n") + + return strings.TrimSpace(src) +} + +func TrimScript(src string) string { + re, _ := regexp.Compile("<[\\S\\s]+?>") + src = re.ReplaceAllStringFunc(src, strings.ToLower) + re, _ = regexp.Compile("") + src = re.ReplaceAllString(src, "") + + re, _ = regexp.Compile("") + + src = re.ReplaceAllString(src, "") + + return src +} diff --git a/utils/uuid.go b/utils/uuid.go new file mode 100644 index 0000000..346b7b2 --- /dev/null +++ b/utils/uuid.go @@ -0,0 +1,61 @@ +package utils + +import ( + "fmt" + "math/rand" + "time" +) + +const ( + KC_RAND_KIND_NUM = 0 // 纯数字 + KC_RAND_KIND_LOWER = 1 // 小写字母 + KC_RAND_KIND_UPPER = 2 // 大写字母 + KC_RAND_KIND_ALL = 3 // 数字、大小写字母 +) + +func newUUID() *[16]byte { + u := &[16]byte{} + rand.Read(u[:16]) + u[8] = (u[8] | 0x80) & 0xBf + u[6] = (u[6] | 0x40) & 0x4f + return u +} + +func UUIDString() string { + u := newUUID() + return fmt.Sprintf("%x-%x-%x-%x-%x", u[:4], u[4:6], u[6:8], u[8:10], u[10:]) +} + +func UUIDHexString() string { + u := newUUID() + return fmt.Sprintf("%x%x%x%x%x", u[:4], u[4:6], u[6:8], u[8:10], u[10:]) +} +func UUIDBinString() string { + u := newUUID() + return fmt.Sprintf("%s", [16]byte(*u)) +} + +func Krand(size int, kind int) []byte { + ikind, kinds, result := kind, [][]int{[]int{10, 48}, []int{26, 97}, []int{26, 65}}, make([]byte, size) + isAll := kind > 2 || kind < 0 + rand.Seed(time.Now().UnixNano()) + for i := 0; i < size; i++ { + if isAll { // random ikind + ikind = rand.Intn(3) + } + scope, base := kinds[ikind][0], kinds[ikind][1] + result[i] = uint8(base + rand.Intn(scope)) + } + return result +} + +// OrderUUID is only num for uuid +func OrderUUID(uid int) string { + ustr := IntToStr(uid) + tstr := Int64ToStr(time.Now().Unix()) + ulen := len(ustr) + tlen := len(tstr) + rlen := 18 - ulen - tlen + krb := Krand(rlen, KC_RAND_KIND_NUM) + return ustr + tstr + string(krb) +}