commit 860a7b8be8f50cf96e7660c9d18c98f1f1aedabd Author: DengBiao <2319963317@qq.com> Date: Mon Jun 26 13:55:31 2023 +0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9a8f686 --- /dev/null +++ b/.gitignore @@ -0,0 +1,46 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib +# Test binary, built with `go test -c` +*.test +# Output of the go coverage tool, specifically when used with LiteIDE +*.out +.idea +.vscode +*.log +.DS_Store +Thumbs.db +*.swp +*.swn +*.swo +*.swm +*.7z +*.zip +*.rar +*.tar +*.tar.gz +go.sum +/etc/cfg.yaml +images +test/test.json +etc/cfg.yml +t.json +t1.json +t2.json +t3.json +t.go +wait-for-it.sh +test.go +xorm +test.csv +nginx.conf +.devcontainer +.devcontainer/Dockerfile +.devcontainer/sources.list +/t1.go +/tmp/* +.idea/* +/.idea/modules.xml diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..c5905d1 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +# 多重构建,减少镜像大小 +# 构建:使用golang:1.15版本 +FROM golang:1.15 as build + +# 容器环境变量添加,会覆盖默认的变量值 +ENV GO111MODULE=on +ENV GOPROXY=https://goproxy.cn,direct +ENV TZ="Asia/Shanghai" +# 设置工作区 +WORKDIR /go/release + +# 把全部文件添加到/go/release目录 +ADD . . + +# 编译:把main.go编译成可执行的二进制文件,命名为zyos +RUN GOOS=linux CGO_ENABLED=0 GOARCH=amd64 go build -tags netgo -ldflags="-s -w" -installsuffix cgo -o zyos main.go + +FROM ubuntu:xenial as prod +LABEL maintainer="wuhanqin" +ENV TZ="Asia/Shanghai" + +COPY static/html static/html +# 时区纠正 +RUN rm -f /etc/localtime \ + && ln -sv /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ + && echo "Asia/Shanghai" > /etc/timezone +# 在build阶段复制可执行的go二进制文件app +COPY --from=build /go/release/zyos ./zyos + +COPY --from=build /go/release/etc/cfg.yml /var/zyos/cfg.yml + +# 启动服务 +CMD ["./zyos","-c","/var/zyos/cfg.yml"] + diff --git a/Dockerfile-task b/Dockerfile-task new file mode 100644 index 0000000..905efe7 --- /dev/null +++ b/Dockerfile-task @@ -0,0 +1,34 @@ +# 多重构建,减少镜像大小 +# 构建:使用golang:1.15版本 +FROM golang:1.15 as build + +# 容器环境变量添加,会覆盖默认的变量值 +ENV GO111MODULE=on +ENV GOPROXY=https://goproxy.cn,direct +ENV TZ="Asia/Shanghai" +# 设置工作区 +WORKDIR /go/release + +# 把全部文件添加到/go/release目录 +ADD . . + +# 编译:把main.go编译成可执行的二进制文件,命名为zyos +RUN GOOS=linux CGO_ENABLED=0 GOARCH=amd64 go build -tags netgo -ldflags="-s -w" -installsuffix cgo -o zyos_mall_task cmd/task/main.go + +FROM ubuntu:xenial as prod +LABEL maintainer="wuhanqin" +ENV TZ="Asia/Shanghai" + +COPY static/html static/html +# 时区纠正 +RUN rm -f /etc/localtime \ + && ln -sv /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ + && echo "Asia/Shanghai" > /etc/timezone +# 在build阶段复制可执行的go二进制文件app +COPY --from=build /go/release/zyos_mall_task ./zyos_mall_task + +COPY --from=build /go/release/etc/task.yml /var/zyos/task.yml + +# 启动服务 +CMD ["./zyos_mall_task","-c","/var/zyos/task.yml"] + diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e7e30c2 --- /dev/null +++ b/Makefile @@ -0,0 +1,32 @@ +.PHONY: build clean tool lint help + +APP=applet + +all: build + +build: + go build -o ./bin/$(APP) ./cmd/main.go + +lite: + go build -ldflags "-s -w" -o ./bin/$(APP) ./cmd/main.go + +install: + #@go build -v . + go install ./cmd/... + +tool: + go vet ./...; true + gofmt -w . + +lint: + golint ./... + +clean: + rm -rf go-gin-example + go clean -i . + +help: + @echo "make: compile packages and dependencies" + @echo "make tool: run specified go tool" + @echo "make lint: golint ./..." + @echo "make clean: remove object files and cached files" \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..ca65483 --- /dev/null +++ b/README.md @@ -0,0 +1,48 @@ +# applet + +## 要看 nginx.conf 和 wap conf + +## 层级介绍 + +- hdl 做接收数据的报错, 数据校验 +- svc 做数据处理的报错, 数据转换 +- lib 只抛出错误给hdl或者svc进行处理, 不做数据校验 +- db 可以处理db错误,其它错误返回给svc进行处理 +- mw 中间件 +- md 结构体 + +#### 介绍 +基于gin的接口小程序 + +#### 软件架构 + +软件架构说明 + +#### 安装教程 + +1. xxxx +2. xxxx +3. xxxx + +#### 使用说明 + +1. xxxx +2. xxxx +3. xxxx + +#### 参与贡献 + +1. Fork 本仓库 +2. 新建 Feat_xxx 分支 +3. 提交代码 +4. 新建 Pull Request + +## swagger + +``` +// 参考:https://segmentfault.com/a/1190000013808421 +// 安装命令行 +go get -u github.com/swaggo/swag/cmd/swag +// 生成 +swag init +``` \ No newline at end of file diff --git a/app/cfg/cfg_app.go b/app/cfg/cfg_app.go new file mode 100644 index 0000000..53d6113 --- /dev/null +++ b/app/cfg/cfg_app.go @@ -0,0 +1,42 @@ +package cfg + +import ( + "time" +) + +type Config struct { + Debug bool `yaml:"debug"` + Prd bool `yaml:"prd"` + SrvAddr string `yaml:"srv_addr"` + RedisAddr string `yaml:"redis_addr"` + DB DBCfg `yaml:"db"` + Log LogCfg `yaml:"log"` +} + +//数据库配置结构体 +type DBCfg struct { + Host string `yaml:"host"` //ip及端口 + Name string `yaml:"name"` //库名 + User string `yaml:"user"` //用户 + Psw string `yaml:"psw"` //密码 + ShowLog bool `yaml:"show_log"` //是否显示SQL语句 + MaxLifetime time.Duration `yaml:"max_lifetime"` + MaxOpenConns int `yaml:"max_open_conns"` + MaxIdleConns int `yaml:"max_idle_conns"` + Path string `yaml:"path"` //日志文件存放路径 +} + +//日志配置结构体 +type LogCfg struct { + AppName string `yaml:"app_name" ` + Level string `yaml:"level"` + IsStdOut bool `yaml:"is_stdout"` + TimeFormat string `yaml:"time_format"` // second, milli, nano, standard, iso, + Encoding string `yaml:"encoding"` // console, json + + IsFileOut bool `yaml:"is_file_out"` + FileDir string `yaml:"file_dir"` + FileName string `yaml:"file_name"` + FileMaxSize int `yaml:"file_max_size"` + FileMaxAge int `yaml:"file_max_age"` +} diff --git a/app/cfg/cfg_cache_key.go b/app/cfg/cfg_cache_key.go new file mode 100644 index 0000000..c091909 --- /dev/null +++ b/app/cfg/cfg_cache_key.go @@ -0,0 +1,3 @@ +package cfg + +// 统一管理缓存 diff --git a/app/cfg/init_cache.go b/app/cfg/init_cache.go new file mode 100644 index 0000000..873657f --- /dev/null +++ b/app/cfg/init_cache.go @@ -0,0 +1,9 @@ +package cfg + +import ( + "applet/app/utils/cache" +) + +func InitCache() { + cache.NewRedis(RedisAddr) +} diff --git a/app/cfg/init_cfg.go b/app/cfg/init_cfg.go new file mode 100644 index 0000000..4fddb6c --- /dev/null +++ b/app/cfg/init_cfg.go @@ -0,0 +1,46 @@ +package cfg + +import ( + "flag" + "io/ioutil" + + "gopkg.in/yaml.v2" +) + +//配置文件数据,全局变量 +var ( + Debug bool + Prd bool + SrvAddr string + RedisAddr string + DB *DBCfg + Log *LogCfg +) + +//初始化配置文件,将cfg.yml读入到内存 +func InitCfg() { + //用指定的名称、默认值、使用信息注册一个string类型flag。 + path := flag.String("c", "etc/cfg.yml", "config file") + //解析命令行参数写入注册的flag里。 + //解析之后,flag的值可以直接使用。 + flag.Parse() + var ( + c []byte + err error + conf *Config + ) + if c, err = ioutil.ReadFile(*path); err != nil { + panic(err) + } + //yaml.Unmarshal反序列化映射到Config + if err = yaml.Unmarshal(c, &conf); err != nil { + panic(err) + } + //数据读入内存 + Prd = conf.Prd + Debug = conf.Debug + DB = &conf.DB + Log = &conf.Log + RedisAddr = conf.RedisAddr + SrvAddr = conf.SrvAddr +} diff --git a/app/cfg/init_log.go b/app/cfg/init_log.go new file mode 100644 index 0000000..0f31eb5 --- /dev/null +++ b/app/cfg/init_log.go @@ -0,0 +1,20 @@ +package cfg + +import "applet/app/utils/logx" + +func InitLog() { + logx.InitDefaultLogger(&logx.LogConfig{ + AppName: Log.AppName, + Level: Log.Level, + StacktraceLevel: "error", + IsStdOut: Log.IsStdOut, + TimeFormat: Log.TimeFormat, + Encoding: Log.Encoding, + IsFileOut: Log.IsFileOut, + FileDir: Log.FileDir, + FileName: Log.FileName, + FileMaxSize: Log.FileMaxSize, + FileMaxAge: Log.FileMaxAge, + Skip: 2, + }) +} diff --git a/app/cfg/init_task.go b/app/cfg/init_task.go new file mode 100644 index 0000000..0eec20e --- /dev/null +++ b/app/cfg/init_task.go @@ -0,0 +1,42 @@ +package cfg + +import ( + "flag" + "io/ioutil" + + "gopkg.in/yaml.v2" + + mc "applet/app/utils/cache/cache" + "applet/app/utils/logx" +) + +func InitTaskCfg() { + path := flag.String("c", "etc/task.yml", "config file") + flag.Parse() + var ( + c []byte + err error + conf *Config + ) + if c, err = ioutil.ReadFile(*path); err != nil { + panic(err) + } + if err = yaml.Unmarshal(c, &conf); err != nil { + panic(err) + } + Prd = conf.Prd + Debug = conf.Debug + DB = &conf.DB + Log = &conf.Log + RedisAddr = conf.RedisAddr +} + +var MemCache mc.Cache + +func InitMemCache() { + var err error + MemCache, err = mc.NewCache("memory", `{"interval":60}`) + if err != nil { + logx.Fatal(err.Error()) + } +} diff --git a/app/db/db.go b/app/db/db.go new file mode 100644 index 0000000..ea6f235 --- /dev/null +++ b/app/db/db.go @@ -0,0 +1,114 @@ +package db + +import ( + "database/sql" + "fmt" + "os" + + _ "github.com/go-sql-driver/mysql" + "xorm.io/xorm" + "xorm.io/xorm/log" + + "applet/app/cfg" + "applet/app/utils/logx" +) + +var Db *xorm.Engine + +func InitDB(c *cfg.DBCfg) error { + var err error + if Db, err = xorm.NewEngine("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4", c.User, c.Psw, c.Host, c.Name)); err != nil { + return err + } + Db.SetConnMaxLifetime(c.MaxLifetime) + Db.SetMaxOpenConns(c.MaxOpenConns) + Db.SetMaxIdleConns(c.MaxIdleConns) + if err = Db.Ping(); err != nil { + return err + } + if c.ShowLog { + Db.ShowSQL(true) + Db.Logger().SetLevel(0) + f, err := os.OpenFile(c.Path, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0777) + if err != nil { + os.RemoveAll(c.Path) + if f, err = os.OpenFile(c.Path, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0777); err != nil { + return err + } + } + logger := log.NewSimpleLogger(f) + logger.ShowSQL(true) + Db.SetLogger(logger) + } + return nil +} + +/********************************************* 公用方法 *********************************************/ + +// DbInsertBatch 数据批量插入 +func DbInsertBatch(Db *xorm.Engine, m ...interface{}) error { + if len(m) == 0 { + return nil + } + id, err := Db.Insert(m...) + if id == 0 || err != nil { + return logx.Warn("cannot insert data :", err) + } + return nil +} + +// QueryNativeString 查询原生sql +func QueryNativeString(Db *xorm.Engine, sql string, args ...interface{}) ([]map[string]string, error) { + results, err := Db.SQL(sql, args...).QueryString() + return results, err +} +func QueryNativeStringSess(sess *xorm.Session, sql string, args ...interface{}) ([]map[string]string, error) { + results, err := sess.SQL(sql, args...).QueryString() + return results, err +} + +// CommonInsert 插入一条或多条数据 +func CommonInsert(Db *xorm.Engine, data interface{}) (int64, error) { + row, err := Db.Insert(data) + return row, err +} + +// UpdateComm 根据主键更新 +func UpdateComm(Db *xorm.Engine, id interface{}, model interface{}) (int64, error) { + row, err := Db.ID(id).Update(model) + return row, err +} + +// InsertOneComm 插入一条数据 +func InsertOneComm(Db *xorm.Engine, model interface{}) (int64, error) { + row, err := Db.InsertOne(model) + return row, err +} + +// GetComm 获取一条数据 +// payload *model +// return *model,has,err +func GetComm(Db *xorm.Engine, model interface{}) (interface{}, bool, error) { + has, err := Db.Get(model) + if err != nil { + _ = logx.Warn(err) + return nil, false, err + } + return model, has, nil +} + +// ExecuteOriginalSql 执行原生sql +func ExecuteOriginalSql(Db *xorm.Engine, sql string) (sql.Result, error) { + result, err := Db.Exec(sql) + if err != nil { + _ = logx.Warn(err) + return nil, err + } + return result, nil +} + +// InsertCommWithSession common insert +func InsertCommWithSession(session *xorm.Session, model interface{}) (int64, error) { + row, err := session.InsertOne(model) + return row, err +} diff --git a/app/db/db_admin.go b/app/db/db_admin.go new file mode 100644 index 0000000..22dc828 --- /dev/null +++ b/app/db/db_admin.go @@ -0,0 +1,39 @@ +package db + +import ( + "applet/app/db/model" + "applet/app/utils/logx" + "xorm.io/xorm" +) + +type AdminDb struct { + Db *xorm.Engine `json:"db"` +} + +func (adminDb *AdminDb) Set() { // set方法 + adminDb.Db = Db +} + +func (adminDb *AdminDb) GetAdmin(id int) (m *model.Admin, err error) { + m = new(model.Admin) + has, err := adminDb.Db.Where("adm_id =?", id).Get(m) + if err != nil { + return nil, logx.Error(err) + } + if has == false { + return nil, nil + } + return m, nil +} + +func (adminDb *AdminDb) GetAdminByUserName(userName string) (m *model.Admin, err error) { + m = new(model.Admin) + has, err := adminDb.Db.Where("username =?", userName).Get(m) + if err != nil { + return nil, logx.Error(err) + } + if has == false { + return nil, nil + } + return m, nil +} diff --git a/app/db/db_qrcode.go b/app/db/db_qrcode.go new file mode 100644 index 0000000..c84e098 --- /dev/null +++ b/app/db/db_qrcode.go @@ -0,0 +1,63 @@ +package db + +import ( + "applet/app/db/model" + "applet/app/enum" + "applet/app/utils/logx" + "xorm.io/xorm" +) + +type QrcodeDb struct { + Db *xorm.Engine `json:"db"` +} + +func (qrcodeDb *QrcodeDb) Set() { // set方法 + qrcodeDb.Db = Db +} + +func (qrcodeDb *QrcodeDb) GetQrcode(id int) (m *model.Qrcode, err error) { + m = new(model.Qrcode) + has, err := qrcodeDb.Db.Where("id =?", id).Get(m) + if err != nil { + return nil, logx.Error(err) + } + if has == false { + return nil, nil + } + return m, nil +} + +func (qrcodeDb *QrcodeDb) GetQrcodeForAllowUse() (m *model.Qrcode, err error) { + m = new(model.Qrcode) + has, err := qrcodeDb.Db.Where("state =?", enum.QrcodeSateAllowUse).Get(m) + if err != nil { + return nil, logx.Error(err) + } + if has == false { + return nil, nil + } + return m, nil +} + +func (qrcodeDb *QrcodeDb) FindQrcodeForAllowUse() (m []*model.Qrcode, total int64, err error) { + total, err = qrcodeDb.Db.Where("state =?", enum.QrcodeSateAllowUse).FindAndCount(&m) + return +} + +func (qrcodeDb *QrcodeDb) BatchAddQrcode(data []*model.Qrcode) (int64, error) { + affected, err := qrcodeDb.Db.Insert(data) + if err != nil { + return 0, err + } + return affected, nil +} + +func (qrcodeDb *QrcodeDb) BatchUpdateQrcodeBySession(session *xorm.Session, ids []int, state int32) (int64, error) { + m := new(model.Qrcode) + m.State = state + affected, err := session.In("id", ids).Cols("state").Update(m) + if err != nil { + return 0, err + } + return affected, nil +} diff --git a/app/db/db_qrcode_batch.go b/app/db/db_qrcode_batch.go new file mode 100644 index 0000000..cb6dc90 --- /dev/null +++ b/app/db/db_qrcode_batch.go @@ -0,0 +1,67 @@ +package db + +import ( + "applet/app/db/model" + "applet/app/utils/logx" + "xorm.io/xorm" +) + +type QrcodeBatchDb struct { + Db *xorm.Engine `json:"db"` +} + +func (qrcodeBatchDb *QrcodeBatchDb) Set() { // set方法 + qrcodeBatchDb.Db = Db +} + +func (qrcodeBatchDb *QrcodeBatchDb) GetQrcodeBatchById(id int) (m *model.QrcodeBatch, err error) { + m = new(model.QrcodeBatch) + has, err := qrcodeBatchDb.Db.Where("id =?", id).Get(m) + if err != nil { + return nil, logx.Error(err) + } + if has == false { + return nil, nil + } + return m, nil +} + +func (qrcodeBatchDb *QrcodeBatchDb) DeleteQrcodeBatchBySession(session *xorm.Session, id int) (delResult int64, err error) { + m := new(model.QrcodeBatch) + delResult, err = session.Where("id =?", id).Delete(m) + return +} + +func (qrcodeBatchDb *QrcodeBatchDb) GeLastId() (m *model.QrcodeBatch, err error) { + m = new(model.QrcodeBatch) + has, err := qrcodeBatchDb.Db.OrderBy("id Desc").Get(m) + if err != nil { + return nil, logx.Error(err) + } + if has == false { + return nil, nil + } + return m, nil +} + +func (qrcodeBatchDb *QrcodeBatchDb) GetQrcodeBatchByName(name string) (m *model.QrcodeBatch, err error) { + m = new(model.QrcodeBatch) + has, err := qrcodeBatchDb.Db.Where("name =?", name).Get(m) + if err != nil { + return nil, logx.Error(err) + } + if has == false { + return nil, nil + } + return m, nil +} + +func (qrcodeBatchDb *QrcodeBatchDb) List(page, limit int) (m []*model.QrcodeBatch, total int64, err error) { + total, err = qrcodeBatchDb.Db.Desc("id").Limit(limit, (page-1)*limit).FindAndCount(&m) + return +} + +func (qrcodeBatchDb *QrcodeBatchDb) AddBySession(session *xorm.Session, m *model.QrcodeBatch) (err error) { + _, err = session.InsertOne(m) + return +} diff --git a/app/db/db_qrcode_with_batch_records.go b/app/db/db_qrcode_with_batch_records.go new file mode 100644 index 0000000..732eb40 --- /dev/null +++ b/app/db/db_qrcode_with_batch_records.go @@ -0,0 +1,66 @@ +package db + +import ( + "applet/app/db/model" + "applet/app/utils/logx" + "xorm.io/xorm" +) + +type QrcodeWithBatchRecordsDb struct { + Db *xorm.Engine `json:"db"` +} + +func (qrcodeWithBatchRecordsDb *QrcodeWithBatchRecordsDb) Set() { // set方法 + qrcodeWithBatchRecordsDb.Db = Db +} + +func (qrcodeWithBatchRecordsDb *QrcodeWithBatchRecordsDb) GetQrcodeWithBatchRecordsById(id int) (m *model.QrcodeWithBatchRecords, err error) { + m = new(model.QrcodeWithBatchRecords) + has, err := qrcodeWithBatchRecordsDb.Db.Where("id =?", id).Get(m) + if err != nil { + return nil, logx.Error(err) + } + if has == false { + return nil, nil + } + return m, nil +} + +func (qrcodeWithBatchRecordsDb *QrcodeWithBatchRecordsDb) FindQrcodeWithBatchRecordsByState(state int32) (m []*model.QrcodeWithBatchRecords, err error) { + err = qrcodeWithBatchRecordsDb.Db.Where("state =?", state).Find(&m) + if err != nil { + return nil, logx.Error(err) + } + return m, nil +} + +func (qrcodeWithBatchRecordsDb *QrcodeWithBatchRecordsDb) BatchAddQrcodeWithBatchRecordsBySession(session *xorm.Session, data []*model.QrcodeWithBatchRecords) (int64, error) { + affected, err := session.Insert(data) + if err != nil { + return 0, err + } + return affected, nil +} + +func (qrcodeWithBatchRecordsDb *QrcodeWithBatchRecordsDb) FindQrcodeWithBatchRecordsById(batchId int) (m []*model.QrcodeWithBatchRecords, total int64, err error) { + total, err = qrcodeWithBatchRecordsDb.Db.Where("batch_id =?", batchId).FindAndCount(&m) + return +} + +func (qrcodeWithBatchRecordsDb *QrcodeWithBatchRecordsDb) DeleteQrcodeWithBatchRecordsBySession(session *xorm.Session, batchId int) (delResult int64, err error) { + m := new(model.QrcodeWithBatchRecords) + delResult, err = session.Where("batch_id =?", batchId).Delete(m) + return +} + +type QrcodeWithBatchRecords struct { + model.QrcodeWithBatchRecords `xorm:"extends"` + model.Qrcode `xorm:"extends"` +} + +func (qrcodeWithBatchRecordsDb *QrcodeWithBatchRecordsDb) FindQrcodeWithBatchRecordsLeftJoinQrcode(batchId int) (m []*QrcodeWithBatchRecords, total int64, err error) { + total, err = qrcodeWithBatchRecordsDb.Db.Where("batch_id =?", batchId). + Join("LEFT", "qrcode", "qrcode_with_batch_records.qrcode_id = qrcode.id"). + FindAndCount(&m) + return +} diff --git a/app/db/db_sys_cfg.go b/app/db/db_sys_cfg.go new file mode 100644 index 0000000..0a71e32 --- /dev/null +++ b/app/db/db_sys_cfg.go @@ -0,0 +1,119 @@ +package db + +import ( + "applet/app/db/model" + "applet/app/md" + "applet/app/utils/cache" + "applet/app/utils/logx" + "fmt" + "xorm.io/xorm" +) + +type SysCfgDb struct { + Db *xorm.Engine `json:"db"` +} + +func (sysCfgDb *SysCfgDb) Set() { // set方法 + sysCfgDb.Db = Db +} + +func (sysCfgDb *SysCfgDb) SysCfgGetAll() (*[]model.SysCfg, error) { + var cfgList []model.SysCfg + if err := Db.Cols("key,val,memo").Find(&cfgList); err != nil { + return nil, logx.Error(err) + } + return &cfgList, nil +} + +func (sysCfgDb *SysCfgDb) SysCfgGetOneNoDataNoErr(key string) (*model.SysCfg, error) { + var cfgList model.SysCfg + has, err := Db.Where("`key`=?", key).Get(&cfgList) + if err != nil { + return nil, logx.Error(err) + } + if !has { + return nil, nil + } + return &cfgList, nil +} + +func (sysCfgDb *SysCfgDb) SysCfgGetOne(key string) (*model.SysCfg, error) { + var cfgList model.SysCfg + if has, err := Db.Where("`key`=?", key).Get(&cfgList); err != nil || has == false { + return nil, logx.Error(err) + } + return &cfgList, nil +} + +func (sysCfgDb *SysCfgDb) SysCfgInsert(key, val, memo string) bool { + cfg := model.SysCfg{Key: key, Val: val, Memo: memo} + _, err := Db.InsertOne(&cfg) + if err != nil { + logx.Error(err) + return false + } + return true +} + +func (sysCfgDb *SysCfgDb) SysCfgUpdate(key, val string) bool { + cfg := model.SysCfg{Key: key, Val: val} + _, err := Db.Where("`key`=?", key).Cols("val").Update(&cfg) + if err != nil { + logx.Error(err) + return false + } + sysCfgDb.SysCfgDel(key) + return true +} + +func (sysCfgDb *SysCfgDb) SysCfgGetWithDb(HKey string) string { + cacheKey := fmt.Sprintf(md.AppCfgCacheKey, HKey[0:1]) + get, err := cache.HGetString(cacheKey, HKey) + if err != nil || get == "" { + cfg, err := sysCfgDb.SysCfgGetOne(HKey) + if err != nil || cfg == nil { + _ = logx.Error(err) + return "" + } + + // key是否存在 + cacheKeyExist := false + if cache.Exists(cacheKey) { + cacheKeyExist = true + } + + // 设置缓存 + _, err = cache.HSet(cacheKey, HKey, cfg.Val) + if err != nil { + _ = logx.Error(err) + return "" + } + if !cacheKeyExist { // 如果是首次设置 设置过期时间 + _, err := cache.Expire(cacheKey, md.CfgCacheTime) + if err != nil { + _ = logx.Error(err) + return "" + } + } + return cfg.Val + } + return get +} + +func (sysCfgDb *SysCfgDb) SysCfgDel(HKey string) error { + cacheKey := fmt.Sprintf(md.AppCfgCacheKey, HKey[0:1]) + _, err := cache.HDel(cacheKey, HKey) + if err != nil { + return err + } + return nil +} + +func (sysCfgDb *SysCfgDb) SysCfgFindWithDb(keys ...string) map[string]string { + res := map[string]string{} + for _, v := range keys { + val := sysCfgDb.SysCfgGetWithDb(v) + res[v] = val + } + return res +} diff --git a/app/db/db_user_follow_wx_official_account.go b/app/db/db_user_follow_wx_official_account.go new file mode 100644 index 0000000..5dc29ea --- /dev/null +++ b/app/db/db_user_follow_wx_official_account.go @@ -0,0 +1,27 @@ +package db + +import ( + "applet/app/db/model" + "applet/app/utils/logx" + "xorm.io/xorm" +) + +type UserFollowWxOfficialAccountDb struct { + Db *xorm.Engine `json:"db"` +} + +func (userFollowWxOfficialAccountDb *UserFollowWxOfficialAccountDb) Set() { // set方法 + userFollowWxOfficialAccountDb.Db = Db +} + +func (userFollowWxOfficialAccountDb *UserFollowWxOfficialAccountDb) GetUserFollowWxOfficialAccountByOpenId(openId string) (m *model.UserFollowWxOfficialAccount, err error) { + m = new(model.UserFollowWxOfficialAccount) + has, err := userFollowWxOfficialAccountDb.Db.Where("user_wx_open_id =?", openId).Get(m) + if err != nil { + return nil, logx.Error(err) + } + if has == false { + return nil, nil + } + return m, nil +} diff --git a/app/db/db_user_use_qrcode_records.go b/app/db/db_user_use_qrcode_records.go new file mode 100644 index 0000000..2d91e00 --- /dev/null +++ b/app/db/db_user_use_qrcode_records.go @@ -0,0 +1,47 @@ +package db + +import ( + "applet/app/db/model" + "applet/app/utils/logx" + "xorm.io/xorm" +) + +type UserUseQrcodeRecordsDb struct { + Db *xorm.Engine `json:"db"` +} + +func (userUseQrcodeRecordsDb *UserUseQrcodeRecordsDb) Set() { // set方法 + userUseQrcodeRecordsDb.Db = Db +} + +func (userUseQrcodeRecordsDb *UserUseQrcodeRecordsDb) GetUserUseQrcodeRecordsById(recordsId int) (m *model.UserUseQrcodeRecords, err error) { + m = new(model.UserUseQrcodeRecords) + has, err := userUseQrcodeRecordsDb.Db.Where("records_id =?", recordsId).Get(m) + if err != nil { + return nil, logx.Error(err) + } + if has == false { + return nil, nil + } + return m, nil +} + +func (userUseQrcodeRecordsDb *UserUseQrcodeRecordsDb) GetUserUseQrcodeRecordsByOpenId(openId string) (m *model.UserUseQrcodeRecords, err error) { + m = new(model.UserUseQrcodeRecords) + has, err := userUseQrcodeRecordsDb.Db.Where("user_wx_open_id =?", openId).Get(m) + if err != nil { + return nil, logx.Error(err) + } + if has == false { + return nil, nil + } + return m, nil +} + +func (userUseQrcodeRecordsDb *UserUseQrcodeRecordsDb) InsertUserUseQrcodeRecords(m *model.UserUseQrcodeRecords) (int64, error) { + _, err := userUseQrcodeRecordsDb.Db.InsertOne(m) + if err != nil { + return 0, err + } + return m.Id, nil +} diff --git a/app/db/model/admin.go b/app/db/model/admin.go new file mode 100644 index 0000000..bff047a --- /dev/null +++ b/app/db/model/admin.go @@ -0,0 +1,10 @@ +package model + +type Admin struct { + AdmId int `json:"adm_id" xorm:"not null comment('管理员id') INT(11)"` + Username string `json:"username" xorm:"not null default '' comment('用户名') VARCHAR(255)"` + Password string `json:"password" xorm:"not null default '' comment('密码') VARCHAR(255)"` + State int32 `json:"state" xorm:"not null default 1 comment('状态') TINYINT(1)"` + CreateAt string `json:"create_at" xorm:"not null default CURRENT_TIMESTAMP comment('创建时间') TIMESTAMP"` + UpdateAt string `json:"update_at" xorm:"not null default CURRENT_TIMESTAMP comment('更新时间') TIMESTAMP"` +} diff --git a/app/db/model/qrcode.go b/app/db/model/qrcode.go new file mode 100644 index 0000000..b01be06 --- /dev/null +++ b/app/db/model/qrcode.go @@ -0,0 +1,10 @@ +package model + +type Qrcode struct { + Id int `json:"id" xorm:"not null pk autoincr INT(11)"` + Url string `json:"url" xorm:"not null default '' comment('url地址') VARCHAR(255)"` + State int32 `json:"state" xorm:"not null default 1 comment('状态(1:可用 2:不可用) ') TINYINT(1)"` + Index string `json:"index" xorm:"not null default '' comment('唯一标识符(随机6位字母+数字)') CHAR(50)"` + CreateAt string `json:"create_at" xorm:"not null default CURRENT_TIMESTAMP comment('创建时间') TIMESTAMP"` + UpdateAt string `json:"update_at" xorm:"not null default CURRENT_TIMESTAMP comment('更新时间') TIMESTAMP"` +} diff --git a/app/db/model/qrcode_batch.go b/app/db/model/qrcode_batch.go new file mode 100644 index 0000000..2e5597f --- /dev/null +++ b/app/db/model/qrcode_batch.go @@ -0,0 +1,13 @@ +package model + +type QrcodeBatch struct { + Id int `json:"id" xorm:"not null pk autoincr INT(11)"` + Name string `json:"name" xorm:"not null default '' comment('名称') VARCHAR(255)"` + TotalNum int `json:"total_num" xorm:"not null default 0' comment('总数量') INT(11)"` + TotalAmount string `json:"total_amount" xorm:"not null default 0.00 comment('总金额') DECIMAL(6,2)"` + State int32 `json:"state" xorm:"not null default 1 comment('状态(1:使用中 2:使用完 3:已过期 4:已作废)') TINYINT(1)"` + ExpireDate string `json:"expire_date" xorm:"not null default 0000-00-00 comment('截止日期') CHAR(50)"` + Memo string `json:"memo" xorm:"not null default '' comment('备注') VARCHAR(255)"` + CreateAt string `json:"create_at" xorm:"not null default CURRENT_TIMESTAMP comment('创建时间') TIMESTAMP"` + UpdateAt string `json:"update_at" xorm:"not null default CURRENT_TIMESTAMP comment('更新时间') TIMESTAMP"` +} diff --git a/app/db/model/qrcode_with_batch_records.go b/app/db/model/qrcode_with_batch_records.go new file mode 100644 index 0000000..7490eb9 --- /dev/null +++ b/app/db/model/qrcode_with_batch_records.go @@ -0,0 +1,11 @@ +package model + +type QrcodeWithBatchRecords struct { + Id int64 `json:"id" xorm:"not null pk autoincr BIGINT(32)"` + QrcodeId int `json:"qrcode_id" xorm:"not null default 0' comment('二维码id') INT(11)"` + BatchId int `json:"batch_id" xorm:"not null default 0' comment('批次id') INT(11)"` + Amount string `json:"amount" xorm:"not null default 0.00 comment('金额') DECIMAL(6,2)"` + State int32 `json:"state" xorm:"not null default 1 comment('状态(1:待使用 2:已使用 3:已过期 4:已作废)') TINYINT(1)"` + CreateAt string `json:"create_at" xorm:"not null default CURRENT_TIMESTAMP comment('创建时间') TIMESTAMP"` + UpdateAt string `json:"update_at" xorm:"not null default CURRENT_TIMESTAMP comment('更新时间') TIMESTAMP"` +} diff --git a/app/db/model/sys_cfg.go b/app/db/model/sys_cfg.go new file mode 100644 index 0000000..22d906b --- /dev/null +++ b/app/db/model/sys_cfg.go @@ -0,0 +1,7 @@ +package model + +type SysCfg struct { + Key string `json:"key" xorm:"not null pk comment('键') VARCHAR(127)"` + Val string `json:"val" xorm:"comment('值') TEXT"` + Memo string `json:"memo" xorm:"not null default '' comment('备注') VARCHAR(255)"` +} diff --git a/app/db/model/user_follow_wx_official_account.go b/app/db/model/user_follow_wx_official_account.go new file mode 100644 index 0000000..b536314 --- /dev/null +++ b/app/db/model/user_follow_wx_official_account.go @@ -0,0 +1,8 @@ +package model + +type UserFollowWxOfficialAccount struct { + Id int64 `json:"id" xorm:"not null pk autoincr BIGINT(32)"` + UserWxOpenId string `json:"user_wx_open_id" xorm:"not null default '' comment('用户微信open_id') VARCHAR(255)"` + CreateAt string `json:"create_at" xorm:"not null default CURRENT_TIMESTAMP comment('创建时间') TIMESTAMP"` + UpdateAt string `json:"update_at" xorm:"not null default CURRENT_TIMESTAMP comment('更新时间') TIMESTAMP"` +} diff --git a/app/db/model/user_use_qrcode_records.go b/app/db/model/user_use_qrcode_records.go new file mode 100644 index 0000000..b7f4342 --- /dev/null +++ b/app/db/model/user_use_qrcode_records.go @@ -0,0 +1,10 @@ +package model + +type UserUseQrcodeRecords struct { + Id int64 `json:"id" xorm:"not null pk autoincr BIGINT(32)"` + UserWxOpenId string `json:"user_wx_open_id" xorm:"not null default '' comment('用户微信open_id') VARCHAR(255)"` + RecordsId int64 `json:"records_id" xorm:"not null default 0 comment('二维码记录id') BIGINT(32)"` + State int32 `json:"state" xorm:"not null default 1 comment('状态(0:未发送 1:已发送 2:发送失败)') TINYINT(1)"` + CreateAt string `json:"create_at" xorm:"not null default CURRENT_TIMESTAMP comment('创建时间') TIMESTAMP"` + UpdateAt string `json:"update_at" xorm:"not null default CURRENT_TIMESTAMP comment('更新时间') TIMESTAMP"` +} diff --git a/app/e/code.go b/app/e/code.go new file mode 100644 index 0000000..cc8be46 --- /dev/null +++ b/app/e/code.go @@ -0,0 +1,236 @@ +package e + +const ( + // 200 因为部分第三方接口不能返回错误头,因此在此定义部分错误 + ERR_FILE_SAVE = 200001 + // 400 系列 + ERR_BAD_REQUEST = 400000 + ERR_INVALID_ARGS = 400001 + ERR_API_RESPONSE = 400002 + ERR_NO_DATA = 400003 + ERR_MOBILE_NIL = 400004 + ERR_MOBILE_MATH = 400005 + ERR_FILE_EXT = 400006 + ERR_FILE_MAX_SIZE = 400007 + ERR_SIGN = 400008 + ERR_PASSWORD_MATH = 400009 + ERR_PROVIDER_RESPONSE = 400010 + ERR_AES_ENCODE = 400011 + ERR_ADMIN_API = 400012 + ERR_QINIUAPI_RESPONSE = 400013 + ERR_URL_TURNCHAIN = 400014 + + // 401 未授权 + ERR_UNAUTHORIZED = 401000 + ERR_NOT_AUTH = 401001 + ERR_SMS_AUTH = 401002 + ERR_TOKEN_AUTH = 401003 + ERR_TOKEN_FORMAT = 401004 + ERR_TOKEN_GEN = 401005 + ERR_CACHE_SET = 401006 + // 403 禁止 + ERR_FORBIDEN = 403000 + ERR_PLATFORM = 403001 + ERR_MOBILE_EXIST = 403002 + ERR_USER_NO_EXIST = 403003 + ERR_MOBILE_NO_EXIST = 403004 + ERR_FORBIDEN_VALID = 403005 + ERR_RELATE_ERR = 403006 + ERR_REPEAT_RELATE = 403007 + ERR_MOB_FORBIDEN = 403008 + ERR_MOB_SMS_NO_AVA = 403009 + ERR_USER_IS_REG = 403010 + ERR_MASTER_ID = 403011 + ERR_CASH_OUT_TIME = 403012 + ERR_CASH_OUT_FEE = 403013 + ERR_CASH_OUT_USER_NOT_FOUND = 403014 + ERR_CASH_OUT_FAIL = 403015 + ERR_CASH_OUT_TIMES = 403016 + ERR_CASH_OUT_MINI = 403017 + ERR_CASH_OUT_MUT = 403018 + ERR_CASH_OUT_NOT_DECIMAL = 403019 + ERR_CASH_OUT_NOT_DAY_AVA = 403020 + ERR_USER_LEVEL_PAY_CHECK_TASK_NO_DONE = 403021 + ERR_USER_LEVEL_PAY_CHECK_NO_CROSS = 403022 + ERR_USER_LEVEL_ORD_EXP = 403023 + ERR_IS_BIND_THIRDPARTY = 403024 + ERR_USER_LEVEL_UPDATE_CHECK_TASK_NO_DONE = 403025 + ERR_USER_LEVEL_UPDATE_CHECK_NOT_FOUND_ORDER = 403026 + ERR_USER_LEVEL_UPDATE_REPEAT = 403027 + ERR_USER_NO_ACTIVE = 403028 + ERR_USER_IS_BAN = 403029 + ERR_ALIPAY_SETTING = 403030 + ERR_ALIPAY_ORDERTYPE = 403031 + ERR_CLIPBOARD_UNSUP = 403032 + ERR_SYSUNION_CONFIG = 403033 + ERR_WECAHT_MINI = 403034 + ERR_WECAHT_MINI_CACHE = 403035 + ERR_WECAHT_MINI_DECODE = 403036 + ERR_WECHAT_MINI_ACCESSTOKEN = 403037 + ERR_CURRENT_VIP_LEVEL_AUDITING = 403038 + ERR_LEVEL_RENEW_SHOULD_KEEP_CURRENT = 403039 + ERR_LEVEL_UPGRADE_APPLY_AUDITTING = 403040 + ERR_LEVEL_TASK_PAY_TYPE = 403041 + ERR_BALANCE_NOT_ENOUGH = 403042 + ERR_ADMIN_PUSH = 403043 + ERR_PLAN = 403044 + ERR_MOB_CONFIG = 403045 + ERR_BAlANCE_PAY_ORDERTYPE = 403046 + ERR_PHONE_EXISTED = 403047 + ERR_NOT_RESULT = 403048 + ERR_REVIEW = 403049 + ERR_USER_LEVEL_HAS_PAID = 403050 + ERR_USER_BIND_OWN = 403051 + ERR_PARENTUID_ERR = 403052 + ERR_USER_DEL = 403053 + ERR_SEARCH_ERR = 403054 + ERR_LEVEL_REACH_TOP = 403055 + ERR_USER_CHECK_ERR = 403056 + ERR_PASSWORD_ERR = 403057 + // 404 + ERR_USER_NOTFOUND = 404001 + ERR_SUP_NOTFOUND = 404002 + ERR_LEVEL_MAP = 404003 + ERR_MOD_NOTFOUND = 404004 + ERR_CLIPBOARD_PARSE = 404005 + ERR_NOT_FAN = 404006 + ERR_USER_LEVEL = 404007 + ERR_LACK_PAY_CFG = 404008 + ERR_NOT_LEVEL_TASK = 404009 + ERR_ITEM_NOT_FOUND = 404010 + ERR_WX_CHECKFILE_NOTFOUND = 404011 + + // 429 请求频繁 + ERR_TOO_MANY_REQUESTS = 429000 + // 500 系列 + ERR = 500000 + ERR_UNMARSHAL = 500001 + ERR_UNKNOWN = 500002 + ERR_SMS = 500003 + ERR_ARKID_REGISTER = 500004 + ERR_ARKID_WHITELIST = 500005 + ERR_ARKID_LOGIN = 500006 + ERR_CFG = 500007 + ERR_DB_ORM = 500008 + ERR_CFG_CACHE = 500009 + ERR_ZHIMENG_CONVERT_ERR = 500010 + ERR_ALIPAY_ERR = 500011 + ERR_ALIPAY_ORDER_ERR = 500012 + ERR_PAY_ERR = 500013 + ERR_IS_BIND_THIRDOTHER = 500014 +) + +var MsgFlags = map[int]string{ + // 200 + ERR_FILE_SAVE: "文件保存失败", + // 400 + ERR_BAD_REQUEST: "请求失败", + ERR_INVALID_ARGS: "请求参数错误", + ERR_API_RESPONSE: "API错误", + ERR_QINIUAPI_RESPONSE: "七牛请求API错误", + ERR_URL_TURNCHAIN: "转链失败", + ERR_NO_DATA: "暂无数据", + ERR_MOBILE_NIL: "电话号码不能为空", + ERR_MOBILE_MATH: "电话号码输入有误", + ERR_FILE_MAX_SIZE: "文件上传大小超限", + ERR_FILE_EXT: "文件类型不支持", + ERR_SIGN: "签名校验失败", + ERR_PROVIDER_RESPONSE: "提供商接口错误", + ERR_AES_ENCODE: "加解密错误", + ERR_ADMIN_API: "后台接口请求失败", + // 401 + ERR_NOT_AUTH: "请登录后操作", + ERR_SMS_AUTH: "验证码过期或无效", + ERR_UNAUTHORIZED: "验证用户失败", + ERR_TOKEN_FORMAT: "Token格式不对", + ERR_TOKEN_GEN: "生成Token失败", + ERR_CACHE_SET: "生成缓存失败", + // 403 + ERR_FORBIDEN: "禁止访问", + ERR_PLATFORM: "平台不支持", + ERR_MOBILE_EXIST: "该号码已注册过", + ERR_USER_NO_EXIST: "用户没有注册或账号密码不正确", + ERR_PASSWORD_ERR: "输入两次密码不一致", + ERR_RELATE_ERR: "推荐人不能是自己的粉丝", + ERR_PARENTUID_ERR: "推荐人不存在", + ERR_TOKEN_AUTH: "登录信息失效,请重新登录", + ERR_MOB_SMS_NO_AVA: "短信余额不足或智盟短信配置失败", + ERR_USER_IS_REG: "用户已注册", + ERR_MASTER_ID: "找不到对应站长的数据库", + ERR_CASH_OUT_TIME: "非可提现时间段", + ERR_CASH_OUT_USER_NOT_FOUND: "收款账号不存在", + ERR_CASH_OUT_FAIL: "提现失败", + ERR_CASH_OUT_FEE: "提现金额必须大于手续费", + ERR_CASH_OUT_TIMES: "当日提现次数已达上线", + ERR_CASH_OUT_MINI: "申请提现金额未达到最低金额要求", + ERR_CASH_OUT_MUT: "申请提现金额未达到整数倍要求", + ERR_CASH_OUT_NOT_DECIMAL: "提现申请金额只能是整数", + ERR_CASH_OUT_NOT_DAY_AVA: "不在可提现日期范围内", + ERR_USER_LEVEL_PAY_CHECK_TASK_NO_DONE: "请先完成其他任务", + ERR_USER_LEVEL_PAY_CHECK_NO_CROSS: "无法跨越升级", + ERR_USER_LEVEL_ORD_EXP: "付费订单已失效", + ERR_IS_BIND_THIRDPARTY: "该用户已经绑定了", + ERR_IS_BIND_THIRDOTHER: "该账号已经被绑定了", + ERR_USER_LEVEL_UPDATE_CHECK_TASK_NO_DONE: "请完成指定任务", + ERR_USER_LEVEL_UPDATE_CHECK_NOT_FOUND_ORDER: "没有找到对应的订单", + ERR_USER_LEVEL_UPDATE_REPEAT: "不允许重复升级", + ERR_USER_NO_ACTIVE: "账户没激活", + ERR_USER_IS_BAN: "账户已被冻结", + ERR_SYSUNION_CONFIG: "联盟设置错误,请检查配置", + ERR_WECAHT_MINI: "小程序响应错误,请检查小程序配置", + ERR_WECAHT_MINI_CACHE: "获取小程序缓存失败", + ERR_WECAHT_MINI_DECODE: "小程序解密失败", + ERR_WECHAT_MINI_ACCESSTOKEN: "无法获取accesstoekn", + ERR_CURRENT_VIP_LEVEL_AUDITING: "当前等级正在审核中", + ERR_LEVEL_RENEW_SHOULD_KEEP_CURRENT: "续费只能在当前等级续费", + ERR_LEVEL_UPGRADE_APPLY_AUDITTING: "已有申请正在审核中,暂时不能申请", + ERR_LEVEL_TASK_PAY_TYPE: "任务付费类型错误", + ERR_BALANCE_NOT_ENOUGH: "余额不足", + ERR_ADMIN_PUSH: "后台MOB推送错误", + ERR_PLAN: "分拥方案出错", + ERR_MOB_CONFIG: "Mob 配置错误", + ERR_BAlANCE_PAY_ORDERTYPE: "无效余额支付订单类型", + ERR_PHONE_EXISTED: "手机号码已存在", + ERR_NOT_RESULT: "已加载完毕", + ERR_REVIEW: "审核模板错误", + ERR_USER_LEVEL_HAS_PAID: "该等级已经付过款", + // 404 + ERR_USER_NOTFOUND: "用户不存在", + ERR_USER_DEL: "账号被删除,如有疑问请联系客服", + ERR_SUP_NOTFOUND: "上级用户不存在", + ERR_LEVEL_MAP: "无等级映射关系", + ERR_MOD_NOTFOUND: "没有找到对应模块", + ERR_CLIPBOARD_PARSE: "无法解析剪切板内容", + ERR_NOT_FAN: "没有粉丝", + ERR_CLIPBOARD_UNSUP: "不支持该平台", + ERR_USER_LEVEL: "该等级已不存在", + ERR_LACK_PAY_CFG: "支付配置不完整", + ERR_NOT_LEVEL_TASK: "等级任务查找错误", + ERR_ITEM_NOT_FOUND: "找不到对应商品", + ERR_WX_CHECKFILE_NOTFOUND: "找不到微信校验文件", + ERR_USER_BIND_OWN: "不能填写自己的邀请码", + // 429 + ERR_TOO_MANY_REQUESTS: "请求频繁,请稍后重试", + // 500 内部错误 + ERR: "接口错误", + ERR_SMS: "短信发送出错", + ERR_CFG: "服务器配置错误", + ERR_UNMARSHAL: "JSON解码错误", + ERR_UNKNOWN: "未知错误", + ERR_ARKID_LOGIN: "登录失败", + ERR_MOBILE_NO_EXIST: "该用户未设定手机号", + ERR_FORBIDEN_VALID: "验证码错误", + ERR_CFG_CACHE: "获取配置缓存失败", + ERR_DB_ORM: "数据操作失败", + ERR_REPEAT_RELATE: "重复关联", + ERR_ZHIMENG_CONVERT_ERR: "智盟转链失败", + ERR_MOB_FORBIDEN: "Mob调用失败", + ERR_ALIPAY_ERR: "支付宝参数错误", + ERR_ALIPAY_SETTING: "请在后台正确配置支付宝", + ERR_ALIPAY_ORDERTYPE: "无效支付宝订单类型", + ERR_ALIPAY_ORDER_ERR: "订单创建错误", + ERR_PAY_ERR: "未找到支付方式", + ERR_SEARCH_ERR: "暂无该分类商品", + ERR_LEVEL_REACH_TOP: "已经是最高等级", + ERR_USER_CHECK_ERR: "校验失败", +} diff --git a/app/e/error.go b/app/e/error.go new file mode 100644 index 0000000..2564174 --- /dev/null +++ b/app/e/error.go @@ -0,0 +1,72 @@ +package e + +import ( + "fmt" + "path" + "runtime" +) + +type E struct { + Code int // 错误码 + msg string // 报错代码 + st string // 堆栈信息 +} + +func NewErrCode(code int) error { + if msg, ok := MsgFlags[code]; ok { + return E{code, msg, stack(3)} + } + return E{ERR_UNKNOWN, "unknown", stack(3)} +} + +func NewErr(code int, msg string) error { + return E{code, msg, stack(3)} +} + +func NewErrf(code int, msg string, args ...interface{}) error { + return E{code, fmt.Sprintf(msg, args), stack(3)} +} + +func (e E) Error() string { + return e.msg +} + +func stack(skip int) string { + stk := make([]uintptr, 32) + str := "" + l := runtime.Callers(skip, stk[:]) + for i := 0; i < l; i++ { + f := runtime.FuncForPC(stk[i]) + name := f.Name() + file, line := f.FileLine(stk[i]) + str += fmt.Sprintf("\n%-30s[%s:%d]", name, path.Base(file), line) + } + return str +} + +// ErrorIsAccountBan is 检查这个账号是否被禁用的错误 +func ErrorIsAccountBan(e error) bool { + err, ok := e.(E) + if ok && err.Code == 403029 { + return true + } + return false +} + +// ErrorIsAccountNoActive is 检查这个账号是否被禁用的错误 +func ErrorIsAccountNoActive(e error) bool { + err, ok := e.(E) + if ok && err.Code == 403028 { + return true + } + return false +} + +// ErrorIsUserDel is 检查这个账号是否被删除 +func ErrorIsUserDel(e error) bool { + err, ok := e.(E) + if ok && err.Code == 403053 { + return true + } + return false +} diff --git a/app/e/msg.go b/app/e/msg.go new file mode 100644 index 0000000..ed226ae --- /dev/null +++ b/app/e/msg.go @@ -0,0 +1,110 @@ +package e + +import ( + "applet/app/utils" + "encoding/json" + "net/http" + + "github.com/gin-gonic/gin" + + "applet/app/utils/logx" +) + +// GetMsg get error information based on Code +// 因为这里code是自己控制的, 因此没考虑报错信息 +func GetMsg(code int) (int, string) { + if msg, ok := MsgFlags[code]; ok { + return code / 1000, msg + } + if http.StatusText(code) == "" { + code = 200 + } + return code, MsgFlags[ERR_BAD_REQUEST] +} + +// 成功输出, fields 是额外字段, 与code, msg同级 +func OutSuc(c *gin.Context, data interface{}, fields map[string]interface{}) { + res := gin.H{ + "code": 1, + "msg": "ok", + "data": data, + } + if fields != nil { + for k, v := range fields { + res[k] = v + } + } + if utils.GetApiVersion(c) > 0 { //加了签名校验只返回加密的字符串 + jsonData, _ := json.Marshal(res) + str := utils.ResultAes(c, jsonData) + c.Writer.WriteString(str) + } else { + c.AbortWithStatusJSON(200, res) + } +} + +func OutSucPure(c *gin.Context, data interface{}, fields map[string]interface{}) { + res := gin.H{ + "code": 1, + "msg": "ok", + "data": data, + } + if fields != nil { + for k, v := range fields { + res[k] = v + } + } + c.Abort() + c.PureJSON(200, res) +} + +// 错误输出 +func OutErr(c *gin.Context, code int, err ...interface{}) { + statusCode, msg := GetMsg(code) + if len(err) > 0 && err[0] != nil { + e := err[0] + switch v := e.(type) { + case E: + statusCode = v.Code / 1000 + msg = v.Error() + logx.Error(v.msg + ": " + v.st) // 记录堆栈信息 + case error: + logx.Error(v) + break + case string: + msg = v + case int: + if _, ok := MsgFlags[v]; ok { + msg = MsgFlags[v] + } + } + } + if utils.GetApiVersion(c) > 0 { //加了签名校验只返回加密的字符串 + jsonData, _ := json.Marshal(gin.H{ + "code": code, + "msg": msg, + "data": []struct{}{}, + }) + str := utils.ResultAes(c, jsonData) + if code > 100000 { + code = int(utils.FloatFormat(float64(code/1000), 0)) + } + c.Status(500) + c.Writer.WriteString(str) + } else { + c.AbortWithStatusJSON(statusCode, gin.H{ + "code": code, + "msg": msg, + "data": []struct{}{}, + }) + } +} + +// 重定向 +func OutRedirect(c *gin.Context, code int, loc string) { + if code < 301 || code > 308 { + code = 303 + } + c.Redirect(code, loc) + c.Abort() +} diff --git a/app/e/set_cache.go b/app/e/set_cache.go new file mode 100644 index 0000000..45337a1 --- /dev/null +++ b/app/e/set_cache.go @@ -0,0 +1,8 @@ +package e + +func SetCache(cacheTime int64) map[string]interface{} { + if cacheTime == 0 { + return map[string]interface{}{"cache_time": cacheTime} + } + return map[string]interface{}{"cache_time": cacheTime} +} diff --git a/app/enum/enum_qrcode.go b/app/enum/enum_qrcode.go new file mode 100644 index 0000000..12ab106 --- /dev/null +++ b/app/enum/enum_qrcode.go @@ -0,0 +1,67 @@ +package enum + +type QrcodeBatchState int32 + +const ( + QrcodeBatchStateForUseIng = 1 + QrcodeBatchStateForUseAlready = 2 + QrcodeBatchStateForExpire = 3 + QrcodeBatchStateForCancel = 4 +) + +func (gt QrcodeBatchState) String() string { + switch gt { + case QrcodeBatchStateForUseIng: + return "使用中" + case QrcodeBatchStateForUseAlready: + return "使用完" + case QrcodeBatchStateForExpire: + return "已过期" + case QrcodeBatchStateForCancel: + return "已作废" + default: + return "未知" + } +} + +type QrcodeWithBatchRecordsSate int32 + +const ( + QrcodeWithBatchRecordsStateForWait = 1 + QrcodeWithBatchRecordsStateForAlready = 2 + QrcodeWithBatchRecordsStateForExpire = 3 + QrcodeWithBatchRecordsStateForCancel = 4 +) + +func (gt QrcodeWithBatchRecordsSate) String() string { + switch gt { + case QrcodeWithBatchRecordsStateForWait: + return "待使用" + case QrcodeWithBatchRecordsStateForAlready: + return "已使用" + case QrcodeWithBatchRecordsStateForExpire: + return "已过期" + case QrcodeWithBatchRecordsStateForCancel: + return "已作废" + default: + return "未知" + } +} + +type QrcodeSate int32 + +const ( + QrcodeSateAllowUse = 1 + QrcodeSateAllowNotUse = 2 +) + +func (gt QrcodeSate) String() string { + switch gt { + case QrcodeSateAllowUse: + return "可使用" + case QrcodeSateAllowNotUse: + return "不可用" + default: + return "未知" + } +} diff --git a/app/enum/enum_sys_cfg.go b/app/enum/enum_sys_cfg.go new file mode 100644 index 0000000..b951c85 --- /dev/null +++ b/app/enum/enum_sys_cfg.go @@ -0,0 +1,28 @@ +package enum + +type SysCfg string + +const ( + WxMchApiV3Key = "wx_mch_api_v3_key" + WxMchCertificateSerialNumber = "wx_mch_certificate_serial_number" + WxMchId = "wx_mch_id" + WxOfficialAccountAppId = "wx_official_account_app_id" + WxOfficialAccountAppSecret = "wx_official_account_app_secret" +) + +func (gt SysCfg) String() string { + switch gt { + case WxMchApiV3Key: + return "微信商户APIv3密钥" + case WxMchCertificateSerialNumber: + return "微信商户证书序列号" + case WxMchId: + return "微信商户号" + case WxOfficialAccountAppId: + return "微信公众号appId" + case WxOfficialAccountAppSecret: + return "微信公众号appSecret" + default: + return "未知" + } +} diff --git a/app/enum/enum_wx_official_account.go b/app/enum/enum_wx_official_account.go new file mode 100644 index 0000000..e982771 --- /dev/null +++ b/app/enum/enum_wx_official_account.go @@ -0,0 +1,19 @@ +package enum + +type WxOfficialAccountRequest string + +const ( + GetAccessToken = "cgi-bin/token" + QrcodeCreate = "cgi-bin/qrcode/create" +) + +func (gt WxOfficialAccountRequest) String() string { + switch gt { + case GetAccessToken: + return "获取 Access token" + case QrcodeCreate: + return "生成带参二维码" + default: + return "未知" + } +} diff --git a/app/hdl/hdl_admin.go b/app/hdl/hdl_admin.go new file mode 100644 index 0000000..49caca5 --- /dev/null +++ b/app/hdl/hdl_admin.go @@ -0,0 +1,13 @@ +package hdl + +import ( + "applet/app/e" + "applet/app/svc" + "github.com/gin-gonic/gin" +) + +func UserInfo(c *gin.Context) { + admInfo := svc.GetUser(c) + e.OutSuc(c, admInfo, nil) + return +} diff --git a/app/hdl/hdl_demo.go b/app/hdl/hdl_demo.go new file mode 100644 index 0000000..dfe4f95 --- /dev/null +++ b/app/hdl/hdl_demo.go @@ -0,0 +1,81 @@ +package hdl + +import ( + "applet/app/e" + //"applet/app/utils" + "applet/app/utils/logx" + "fmt" + "github.com/gin-gonic/gin" +) + +// Demo 测试 +func Demo(c *gin.Context) { + str := `{"appid":"wx598aaef252cd78e4","bank_type":"OTHERS","cash_fee":"1","fee_type":"CNY","is_subscribe":"N","master_id":"22255132","mch_id":"1534243971","nonce_str":"xiUZXdrEkpY9UdfCGEcBSE2jy7yWmQsk","openid":"odmKs6kNQBnujHv_S8YyME8g0-6c","order_type":"mall_goods","out_trade_no":"570761162512383595","pay_method":"wxpay","result_code":"SUCCESS","return_code":"SUCCESS","sign":"A5C7B43A8437E6AD72BB4FDAA8532A59","time_end":"20210701151722","total_fee":"1","trade_type":"APP","transaction_id":"4200001143202107010591333162"}` + c.Set("data", str) + var tmp map[string]interface{} + err := c.ShouldBindJSON(&tmp) + if err != nil { + _ = logx.Error(err) + return + } + fmt.Println(tmp["master_id"]) + + e.OutSuc(c, "hello mall", nil) +} + +func Demo1(c *gin.Context) { + //eg := commDb.DBs[c.GetString("mid")] + //sess := eg.NewSession() + ////r, err := eg.Table("user_profile").Where("uid=21699").Incr("fin_valid", 10).Exec() + //sql := "update user_profile set fin_valid=fin_valid+? WHERE uid=?" + //r, err := sess.Exec(sql, 10, 21699) + //if err != nil { + // return + //} + //sess.Commit() + // + //fmt.Println("res",utils.SerializeStr(r)) + + + + + /*engine := commDb.DBs[c.GetString("mid")] + now := time.Now() //获取当前时间 + var startDate = now.Format("2006-01-02 15:00:00") + var endDate = now.Add(time.Hour * 2).Format("2006-01-02 15:00:00") + res := svc2.HandleSecondsKillForDate(engine, c.GetString("mid"), startDate, endDate) + startTime := utils.AnyToString(now.Hour()) + endTime := utils.AnyToString(now.Add(time.Hour * 2).Hour()) + res = svc2.HandleSecondsKillForTime(engine, c.GetString("mid"), startDate, endDate) + + res = svc2.HandleSecondsKillForDateTime(engine, c.GetString("mid"), startDate, endDate, startTime, endTime)*/ + //reqList := make([]*md.CommissionReq, 0, 10) + // + //req := md.CommissionReq{ + // CommissionParam: md.CommissionParam{Commission: "10.00"}, + // Uid: "21699", + // IsShare: 0, + // Provider: "mall_goods", + // IsAllLevelReturn: 0, + // GoodsId: "3", + //} + // + //for i := 0; i < 10; i++ { + // req := req + // req.GoodsId = utils.AnyToString(i + 1) + // reqList = append(reqList, &req) + //} + // + //fmt.Println(utils.SerializeStr(reqList)) + // + //api, err := svc.BatchGetCommissionByCommApi("123456", reqList) + //if err != nil { + // _ = logx.Error(err) + // fmt.Println(err) + // e.OutErr(c, e.ERR, err) + // return + //} + + //e.OutSuc(c, res, nil) + +} diff --git a/app/hdl/hdl_login.go b/app/hdl/hdl_login.go new file mode 100644 index 0000000..f402741 --- /dev/null +++ b/app/hdl/hdl_login.go @@ -0,0 +1,45 @@ +package hdl + +import ( + "applet/app/db" + "applet/app/e" + "applet/app/lib/validate" + "applet/app/md" + "applet/app/svc" + "applet/app/utils" + "fmt" + "github.com/gin-gonic/gin" +) + +func Login(c *gin.Context) { + var req md.LoginReq + err := c.ShouldBindJSON(&req) + if err != nil { + err = validate.HandleValidateErr(err) + err1 := err.(e.E) + e.OutErr(c, err1.Code, err1.Error()) + return + } + adminDb := db.AdminDb{} + adminDb.Set() + admin, err := adminDb.GetAdminByUserName(req.UserName) + if err != nil { + e.OutErr(c, e.ERR_DB_ORM, err) + return + } + if utils.Md5(req.PassWord) != admin.Password { + e.OutErr(c, e.ERR_INVALID_ARGS, "密码错误") + return + } + ip := utils.GetIP(c.Request) + key := fmt.Sprintf(md.AdminJwtTokenKey, ip, utils.AnyToString(admin.AdmId)) + token, err := svc.HandleLoginToken(key, admin) + if err != nil { + e.OutErr(c, e.ERR, err.Error()) + return + } + e.OutSuc(c, md.LoginResponse{ + Token: token, + }, nil) + return +} diff --git a/app/hdl/hdl_qrcode.go b/app/hdl/hdl_qrcode.go new file mode 100644 index 0000000..5cfb322 --- /dev/null +++ b/app/hdl/hdl_qrcode.go @@ -0,0 +1,326 @@ +package hdl + +import ( + "applet/app/db" + "applet/app/db/model" + "applet/app/e" + "applet/app/enum" + "applet/app/lib/validate" + "applet/app/md" + "applet/app/svc" + "applet/app/utils" + "github.com/360EntSecGroup-Skylar/excelize" + "github.com/gin-gonic/gin" + "github.com/shopspring/decimal" + "strconv" + "time" +) + +func QrcodeBatchList(c *gin.Context) { + var req md.QrcodeBatchListReq + err := c.ShouldBindJSON(&req) + if err != nil { + err = validate.HandleValidateErr(err) + err1 := err.(e.E) + e.OutErr(c, err1.Code, err1.Error()) + return + } + qrcodeBatchDb := db.QrcodeBatchDb{} + qrcodeBatchDb.Set() + list, total, err := qrcodeBatchDb.List(req.Page, req.Limit) + if err != nil { + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + qrcodeTotalNums, waitUseQrcodeNums, alreadyUseQrcodeNums, allowCreateQrcodeNums, err := svc.StatisticsQrcodeData() + if err != nil { + e.OutErr(c, e.ERR, err.Error()) + return + } + e.OutSuc(c, map[string]interface{}{ + "list": list, + "total": total, + "batch_state_list": []map[string]interface{}{ + { + "name": enum.QrcodeBatchState(enum.QrcodeBatchStateForUseIng).String(), + "value": enum.QrcodeBatchStateForUseIng, + }, + { + "name": enum.QrcodeBatchState(enum.QrcodeBatchStateForUseAlready).String(), + "value": enum.QrcodeBatchStateForUseAlready, + }, + { + "name": enum.QrcodeBatchState(enum.QrcodeBatchStateForExpire).String(), + "value": enum.QrcodeBatchStateForExpire, + }, + { + "name": enum.QrcodeBatchState(enum.QrcodeBatchStateForCancel).String(), + "value": enum.QrcodeBatchStateForCancel, + }, + }, + "statistics_qrcode_data": map[string]interface{}{ + "qrcode_total_nums": qrcodeTotalNums, + "wait_use_qrcode_nums": waitUseQrcodeNums, + "already_use_qrcode_nums": alreadyUseQrcodeNums, + "allow_create_qrcode_nums": allowCreateQrcodeNums, + }, + }, nil) + return +} + +func QrcodeBatchAdd(c *gin.Context) { + var req md.QrcodeBatchAddReq + err := c.ShouldBindJSON(&req) + if err != nil { + err = validate.HandleValidateErr(err) + err1 := err.(e.E) + e.OutErr(c, err1.Code, err1.Error()) + return + } + + var totalNum int + var totalAmount decimal.Decimal + for _, v := range req.List { + totalNum += v.Num + amount, _ := decimal.NewFromString(v.Amount) + num := decimal.NewFromInt(int64(v.Num)) + totalAmount = totalAmount.Add(amount.Mul(num)) + } + session := db.Db.NewSession() + defer session.Close() + session.Begin() + now := time.Now() + + //1、新增批次数据 `qrcode_batch` + var qrcodeBatch = model.QrcodeBatch{ + Name: req.Name, + TotalNum: totalNum, + TotalAmount: totalAmount.String(), + State: enum.QrcodeBatchStateForUseIng, + ExpireDate: req.ExpireDate, + Memo: req.Memo, + CreateAt: now.Format("2006-01-02 15:04:05"), + UpdateAt: now.Format("2006-01-02 15:04:05"), + } + qrcodeBatchDb := db.QrcodeBatchDb{} + err = qrcodeBatchDb.AddBySession(session, &qrcodeBatch) + if err != nil { + _ = session.Rollback() + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + + //2、获取 qrcode 表中是否有可用二维码 + qrcodeDb := db.QrcodeDb{} + qrcodeDb.Set() + _, allowUseQrcodeTotal, err := qrcodeDb.FindQrcodeForAllowUse() + if err != nil { + _ = session.Rollback() + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + diffQrcodeNum := totalNum - int(allowUseQrcodeTotal) + if diffQrcodeNum > 0 { + //TODO::为避免频繁请求微信二维码接口 + if diffQrcodeNum > 1000 { + e.OutErr(c, e.ERR, "为保证二维码数据准确性,每批次新增二维码不宜操过1000张") + return + } + //3、不够用,新增二维码 + err := svc.CreateQrcode(diffQrcodeNum) + if err != nil { + _ = session.Rollback() + e.OutErr(c, e.ERR, err.Error()) + return + } + } + + //4、生成 "二维码-批次" 记录 + err = svc.OperateQrcode(qrcodeBatch.Id, totalNum, req, session) + if err != nil { + _ = session.Rollback() + e.OutErr(c, e.ERR, err.Error()) + return + } + + err = session.Commit() + if err != nil { + _ = session.Rollback() + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + e.OutSuc(c, "success", nil) + return +} + +func GetBatchAddName(c *gin.Context) { + var name = "第【1】批" + qrcodeBatchDb := db.QrcodeBatchDb{} + qrcodeBatchDb.Set() + qrcodeBatch, err := qrcodeBatchDb.GeLastId() + if err != nil { + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + if qrcodeBatch != nil { + name = "第【" + utils.IntToStr(qrcodeBatch.Id+1) + "】批" + } + e.OutSuc(c, map[string]string{ + "name": name, + }, nil) + return +} + +func QrcodeBatchDetail(c *gin.Context) { + batchId := c.DefaultQuery("id", "") + qrcodeBatchDb := db.QrcodeBatchDb{} + qrcodeBatchDb.Set() + qrcodeBatch, err := qrcodeBatchDb.GetQrcodeBatchById(utils.StrToInt(batchId)) + if err != nil { + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + if qrcodeBatch == nil { + e.OutErr(c, e.ERR_NO_DATA, "未查询到对应的批次记录") + return + } + + qrcodeWithBatchRecordsDb := db.QrcodeWithBatchRecordsDb{} + qrcodeWithBatchRecordsDb.Set() + data, _, err := qrcodeWithBatchRecordsDb.FindQrcodeWithBatchRecordsById(utils.StrToInt(batchId)) + if err != nil { + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + var list = map[string]*md.QrcodeBatchAddReqListDetail{} + for _, v := range data { + if list[v.Amount] == nil { + list[v.Amount] = &md.QrcodeBatchAddReqListDetail{} + } + list[v.Amount].Num++ + list[v.Amount].Amount = v.Amount + switch v.State { + case enum.QrcodeWithBatchRecordsStateForWait: + list[v.Amount].WaitUseNum++ + break + case enum.QrcodeWithBatchRecordsStateForAlready: + list[v.Amount].UsedNum++ + break + case enum.QrcodeWithBatchRecordsStateForExpire: + list[v.Amount].ExpiredNum++ + break + case enum.QrcodeWithBatchRecordsStateForCancel: + list[v.Amount].CancelNum++ + break + } + } + var resultList []*md.QrcodeBatchAddReqListDetail + for _, v := range list { + resultList = append(resultList, v) + } + + e.OutSuc(c, map[string]interface{}{ + "info": qrcodeBatch, + "list": resultList, + }, nil) + return +} + +func QrcodeBatchDelete(c *gin.Context) { + batchId := c.Param("id") + session := db.Db.NewSession() + defer session.Close() + session.Begin() + + //1、删除 `qrcode_batch` 记录 + qrcodeBatchDb := db.QrcodeBatchDb{} + qrcodeBatchDb.Set() + _, err := qrcodeBatchDb.DeleteQrcodeBatchBySession(session, utils.StrToInt(batchId)) + if err != nil { + _ = session.Rollback() + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + + //2、将所关联的 `qrcode` 状态改为 "可用" + qrcodeWithBatchRecordsDb := db.QrcodeWithBatchRecordsDb{} + qrcodeWithBatchRecordsDb.Set() + data, _, err := qrcodeWithBatchRecordsDb.FindQrcodeWithBatchRecordsById(utils.StrToInt(batchId)) + if err != nil { + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + var updateQrcodeIds []int + for _, v := range data { + updateQrcodeIds = append(updateQrcodeIds, v.QrcodeId) + } + qrcodeDb := db.QrcodeDb{} + qrcodeDb.Set() + _, err = qrcodeDb.BatchUpdateQrcodeBySession(session, updateQrcodeIds, enum.QrcodeSateAllowUse) + if err != nil { + _ = session.Rollback() + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + + //3、删除 `qrcode_with_batch_records` 记录 + _, err = qrcodeWithBatchRecordsDb.DeleteQrcodeWithBatchRecordsBySession(session, utils.StrToInt(batchId)) + if err != nil { + _ = session.Rollback() + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + + err = session.Commit() + if err != nil { + _ = session.Rollback() + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + e.OutSuc(c, "success", nil) + return +} + +func QrcodeBatchDownload(c *gin.Context) { + batchId := c.DefaultQuery("id", "") + qrcodeBatchDb := db.QrcodeBatchDb{} + qrcodeBatchDb.Set() + qrcodeBatch, err := qrcodeBatchDb.GetQrcodeBatchById(utils.StrToInt(batchId)) + if err != nil { + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + if qrcodeBatch == nil { + e.OutErr(c, e.ERR_NO_DATA, "未查询到对应的批次记录") + return + } + + qrcodeWithBatchRecordsDb := db.QrcodeWithBatchRecordsDb{} + qrcodeWithBatchRecordsDb.Set() + data, _, err := qrcodeWithBatchRecordsDb.FindQrcodeWithBatchRecordsLeftJoinQrcode(utils.StrToInt(batchId)) + if err != nil { + e.OutErr(c, e.ERR_DB_ORM, err.Error()) + return + } + + titleList := []string{"批次", "有效期", "金额", "二维码地址"} + xlsx := excelize.NewFile() + xlsx.SetSheetRow("Sheet1", "A1", &titleList) + //表头被第一行用了,只能从第二行开始 + j := 2 + for _, vv := range data { + xlsx.SetSheetRow("Sheet1", "A"+strconv.Itoa(j), &[]interface{}{qrcodeBatch.Name, qrcodeBatch.ExpireDate, vv.Amount, vv.Url}) + j++ + } + + //if err := xlsx.SaveAs(qrcodeBatch.Name + ".xlsx"); err != nil { + // e.OutErr(c, e.ERR, err.Error()) + // return + //} + c.Header("Content-Type", "application/octet-stream") + c.Header("Content-Disposition", "attachment; filename="+qrcodeBatch.Name+".xlsx") + c.Header("Content-Transfer-Encoding", "binary") + //回写到web 流媒体 形成下载 + _ = xlsx.Write(c.Writer) + return +} diff --git a/app/hdl/hdl_sys_cfg.go b/app/hdl/hdl_sys_cfg.go new file mode 100644 index 0000000..8a9e952 --- /dev/null +++ b/app/hdl/hdl_sys_cfg.go @@ -0,0 +1,39 @@ +package hdl + +import ( + "applet/app/db" + "applet/app/e" + "applet/app/enum" + "applet/app/lib/validate" + "applet/app/md" + "github.com/gin-gonic/gin" +) + +func GetSysCfg(c *gin.Context) { + sysCfgDb := db.SysCfgDb{} + sysCfgDb.Set() + res := sysCfgDb.SysCfgFindWithDb(enum.WxMchApiV3Key, enum.WxMchCertificateSerialNumber, enum.WxMchId, enum.WxOfficialAccountAppId, enum.WxOfficialAccountAppSecret) + e.OutSuc(c, res, nil) + return +} + +func SetSysCfg(c *gin.Context) { + var req md.SetSysCfgReq + err := c.ShouldBindJSON(&req) + if err != nil { + err = validate.HandleValidateErr(err) + err1 := err.(e.E) + e.OutErr(c, err1.Code, err1.Error()) + return + } + sysCfgDb := db.SysCfgDb{} + sysCfgDb.Set() + sysCfgDb.SysCfgUpdate(enum.WxMchApiV3Key, req.WxMchApiV3Key) + sysCfgDb.SysCfgUpdate(enum.WxMchCertificateSerialNumber, req.WxMchCertificateSerialNumber) + sysCfgDb.SysCfgUpdate(enum.WxMchId, req.WxMchId) + sysCfgDb.SysCfgUpdate(enum.WxOfficialAccountAppId, req.WxOfficialAccountAppId) + sysCfgDb.SysCfgUpdate(enum.WxOfficialAccountAppSecret, req.WxOfficialAccountAppSecret) + //res := sysCfgDb.SysCfgFindWithDb(enum.WxMchApiV3Key, enum.WxMchCertificateSerialNumber, enum.WxMchId, enum.WxOfficialAccountAppId, enum.WxOfficialAccountAppSecret) + e.OutSuc(c, nil, nil) + return +} diff --git a/app/hdl/hdl_wx.go b/app/hdl/hdl_wx.go new file mode 100644 index 0000000..9440535 --- /dev/null +++ b/app/hdl/hdl_wx.go @@ -0,0 +1,148 @@ +package hdl + +import ( + "applet/app/utils" + "encoding/xml" + "fmt" + "log" + "time" + + "github.com/gin-gonic/gin" +) + +const Token = "temptoken" + +// WXCheckSignature 微信接入校验 +func WXCheckSignature(c *gin.Context) { + signature := c.Query("signature") + timestamp := c.Query("timestamp") + nonce := c.Query("nonce") + echostr := c.Query("echostr") + + ok := utils.CheckSignature(signature, timestamp, nonce, Token) + if !ok { + log.Println("[微信接入] - 微信公众号接入校验失败!") + return + } + + log.Println("[微信接入] - 微信公众号接入校验成功!") + _, _ = c.Writer.WriteString(echostr) +} + +// WXMsg 微信消息结构体 +type WXMsg struct { + ToUserName string + FromUserName string + CreateTime int64 + MsgType string +} + +// WXTextMsg 微信文本消息结构体 +type WXTextMsg struct { + ToUserName string + FromUserName string + CreateTime int64 + MsgType string + Content string + MsgId int64 +} + +// WXEventForSubscribeMsg 扫描带参数二维码事件消息结构体(用户未关注时,进行关注后的事件推送) +type WXEventForSubscribeMsg struct { + ToUserName string //开发者微信号 + FromUserName string //发送方帐号(一个OpenID) + CreateTime int64 //消息创建时间 (整型) + MsgType string //消息类型,event + Event string //事件类型,subscribe + EventKey string //事件KEY值,qrscene_为前缀,后面为二维码的参数值 + Ticket string //二维码的ticket,可用来换取二维码图片 +} + +// WXEventForScanMsg 扫描带参数二维码事件消息结构体(用户已关注时的事件推送) +type WXEventForScanMsg struct { + ToUserName string //开发者微信号 + FromUserName string //发送方帐号(一个OpenID) + CreateTime int64 //消息创建时间 (整型) + MsgType string //消息类型,event + Event string //事件类型,subscribe + EventKey string //事件KEY值,是一个32位无符号整数,即创建二维码时的二维码scene_id + Ticket string //二维码的ticket,可用来换取二维码图片 +} + +// WXMsgReceive 微信消息接收 +func WXMsgReceive(c *gin.Context) { + var msg WXMsg + err := c.ShouldBindXML(&msg) + if err != nil { + log.Printf("[消息接收] - XML数据包解析失败: %v\n", err) + return + } + log.Printf("[消息接收] - 收到消息, 消息类型为: %s", msg.MsgType) + if msg.MsgType == "event" { + //事件类型消息 + var eventMsg WXEventForSubscribeMsg + err := c.ShouldBindXML(&eventMsg) + if err != nil { + log.Printf("[事件类型-消息接收] - XML数据包解析失败: %v\n", err) + return + } + log.Printf("[事件类型]-收到消息, 事件类型为: %s, 事件KEY值为: %s\n, 二维码的ticket值为: %s\n", eventMsg.Event, eventMsg.EventKey, eventMsg.Ticket) + if eventMsg.Event == "subscribe" { + //用户未关注时,进行关注后的事件推送 + //userUseQrcodeRecordsDb := db.UserUseQrcodeRecordsDb{} + //userUseQrcodeRecordsDb.Set() + //userUseQrcodeRecordsDb.InsertUserUseQrcodeRecords(model.UserUseQrcodeRecords{ + // UserWxOpenId: eventMsg.FromUserName, + // RecordsId: 0, + // State: 0, + // CreateAt: "", + // UpdateAt: "", + //}) + } + if eventMsg.Event == "SCAN" { + //用户已关注时的事件推送 + + } + } + if msg.MsgType == "text" { + //事件类型消息 + var textMsg WXTextMsg + err := c.ShouldBindXML(&textMsg) + if err != nil { + log.Printf("[文本消息-消息接收] - XML数据包解析失败: %v\n", err) + return + } + log.Printf("[文本消息]-收到消息, 消息内容为: %s", textMsg.Content) + WXMsgReply(c, textMsg.ToUserName, textMsg.FromUserName) + } + +} + +// WXRepTextMsg 微信回复文本消息结构体 +type WXRepTextMsg struct { + ToUserName string + FromUserName string + CreateTime int64 + MsgType string + Content string + // 若不标记XMLName, 则解析后的xml名为该结构体的名称 + XMLName xml.Name `xml:"xml"` +} + +// WXMsgReply 微信消息回复 +func WXMsgReply(c *gin.Context, fromUser, toUser string) { + repTextMsg := WXRepTextMsg{ + ToUserName: toUser, + FromUserName: fromUser, + CreateTime: time.Now().Unix(), + MsgType: "text", + Content: fmt.Sprintf("[消息回复] - %s\n,hello, world!", time.Now().Format("2006-01-02 15:04:05")), + } + + msg, err := xml.Marshal(&repTextMsg) + if err != nil { + log.Printf("[消息回复] - 将对象进行XML编码出错: %v\n", err) + return + } + _, _ = c.Writer.Write(msg) +} diff --git a/app/lib/auth/base.go b/app/lib/auth/base.go new file mode 100644 index 0000000..d771802 --- /dev/null +++ b/app/lib/auth/base.go @@ -0,0 +1,19 @@ +package auth + +import ( + "time" + + "github.com/dgrijalva/jwt-go" +) + +// TokenExpireDuration is jwt 过期时间 +const TokenExpireDuration = time.Hour * 4380 + +var Secret = []byte("zyos") + +// JWTUser 如果想要保存更多信息,都可以添加到这个结构体中 +type JWTUser struct { + AdmId int `json:"adm_id"` + Username string `json:"username"` + jwt.StandardClaims +} diff --git a/app/lib/validate/validate_comm.go b/app/lib/validate/validate_comm.go new file mode 100644 index 0000000..9305d9e --- /dev/null +++ b/app/lib/validate/validate_comm.go @@ -0,0 +1,33 @@ +package validate + +import ( + "applet/app/e" + "applet/app/utils" + "applet/app/utils/logx" + "encoding/json" + "fmt" + "github.com/go-playground/validator/v10" +) + +func HandleValidateErr(err error) error { + switch err.(type) { + case *json.UnmarshalTypeError: + return e.NewErr(e.ERR_UNMARSHAL, "参数格式错误") + case validator.ValidationErrors: + errs := err.(validator.ValidationErrors) + transMsgMap := errs.Translate(utils.ValidatorTrans) // utils.ValidatorTrans \app\utils\validator_err_trans.go::ValidatorTransInit初始化获得 + transMsgOne := transMsgMap[GetOneKeyOfMapString(transMsgMap)] + return e.NewErr(e.ERR_INVALID_ARGS, transMsgOne) + default: + _ = logx.Error(err) + return e.NewErr(e.ERR, fmt.Sprintf("validate request params, err:%v\n", err)) + } +} + +// GetOneKeyOfMapString 取出Map的一个key +func GetOneKeyOfMapString(collection map[string]string) string { + for k := range collection { + return k + } + return "" +} diff --git a/app/lib/wx/wx_official_account.go b/app/lib/wx/wx_official_account.go new file mode 100644 index 0000000..c77c8ac --- /dev/null +++ b/app/lib/wx/wx_official_account.go @@ -0,0 +1,85 @@ +package wx + +import ( + "applet/app/db" + "applet/app/enum" + "applet/app/md" + "applet/app/utils" + "applet/app/utils/cache" + "encoding/json" + "errors" +) + +type OfficialAccount struct { + AccessToken string `json:"access_token"` + Appid string `json:"appid"` + Secret string `json:"secret"` +} + +func (officialAccount *OfficialAccount) Set() { // set方法 + sysCfgDb := db.SysCfgDb{} + sysCfgDb.Set() + officialAccount.Appid = sysCfgDb.SysCfgGetWithDb(enum.WxOfficialAccountAppId) + officialAccount.Secret = sysCfgDb.SysCfgGetWithDb(enum.WxOfficialAccountAppSecret) + officialAccount.AccessToken = officialAccount.createToken() +} + +func (officialAccount *OfficialAccount) createToken() (accessToken string) { + cacheKey := md.WxOfficialAccountCacheKey + accessToken, _ = cache.GetString(cacheKey) + if accessToken != "" { + return + } + + url := md.WxOfficialAccountRequestBaseUrl + enum.GetAccessToken + post, err := utils.CurlPost(url, map[string]string{ + "appid": officialAccount.Appid, + "secret": officialAccount.Secret, + "grant_type": "client_credential", + }, nil) + + utils.FilePutContents("wx_official_account_create_token", "resp"+string(post)) + var data md.CreateTokenResp + err = json.Unmarshal(post, &data) + if err != nil { + return + } + if data.AccessToken == "" { + panic(errors.New("获取 access_token 失败")) + } + + accessToken = data.AccessToken + cache.SetEx(cacheKey, accessToken, int(data.ExpiresIn-3600)) + return +} + +func (officialAccount *OfficialAccount) QrcodeCreate(sceneStr string) (qrcodeUrl string, err error) { + url := md.WxOfficialAccountRequestBaseUrl + enum.QrcodeCreate + "?access_token=" + officialAccount.AccessToken + //post, err := utils.CurlPost(url, map[string]interface{}{ + // "action_name": "QR_LIMIT_STR_SCENE", + // "action_info": map[string]interface{}{ + // "scene": map[string]string{ + // "scene_str": sceneStr, + // }, + // }, + //}, nil) + requestBody, _ := json.Marshal(map[string]interface{}{ + "action_name": "QR_STR_SCENE", + "expire_seconds": "6000", + "action_info": map[string]interface{}{ + "scene": map[string]string{ + "scene_str": sceneStr, + }, + }, + }) + post, err := utils.CurlPost(url, requestBody, nil) + + utils.FilePutContents("wx_official_account_qrcode_create", "resp"+string(post)) + var data md.CreateQrcodeResp + err = json.Unmarshal(post, &data) + if err != nil { + return + } + qrcodeUrl = "https://mp.weixin.qq.com/cgi-bin/showqrcode?ticket=" + data.Ticket + return +} diff --git a/app/md/md_app_redis_key.go b/app/md/md_app_redis_key.go new file mode 100644 index 0000000..4f93079 --- /dev/null +++ b/app/md/md_app_redis_key.go @@ -0,0 +1,10 @@ +package md + +// 缓存key统一管理 +const ( + AdminJwtTokenKey = "%s:admin_jwt_token:%s" // jwt, 占位符:ip, admin:id + JwtTokenCacheTime = 3600 * 24 * 365 + CfgCacheTime = 86400 + AppCfgCacheKey = "one_item_one_code:%s" // 占位符: key的第一个字母 + WxOfficialAccountCacheKey = "wx_official_account" // 占位符: key的第一个字母 +) diff --git a/app/md/md_login.go b/app/md/md_login.go new file mode 100644 index 0000000..16b0897 --- /dev/null +++ b/app/md/md_login.go @@ -0,0 +1,10 @@ +package md + +type LoginReq struct { + UserName string `json:"username" binding:"required" label:"登录账号"` + PassWord string `json:"password" binding:"required" label:"登录密码"` +} + +type LoginResponse struct { + Token string `json:"token"` +} diff --git a/app/md/md_qrcode.go b/app/md/md_qrcode.go new file mode 100644 index 0000000..170893e --- /dev/null +++ b/app/md/md_qrcode.go @@ -0,0 +1,31 @@ +package md + +const ( + QrcodeTotalNums = 100000 +) + +type QrcodeBatchListReq struct { + Page int `json:"page"` + Limit int `json:"limit"` +} + +type QrcodeBatchAddReq struct { + Name string `json:"name"` + ExpireDate string `json:"expire_date"` + List []QrcodeBatchAddReqList `json:"list"` + Memo string `json:"memo"` +} + +type QrcodeBatchAddReqList struct { + Num int `json:"num"` + Amount string `json:"amount"` +} + +type QrcodeBatchAddReqListDetail struct { + Num int `json:"num"` + WaitUseNum int `json:"wait_use_num"` + UsedNum int `json:"used_num"` + ExpiredNum int `json:"expired_num"` + CancelNum int `json:"cancel_num"` + Amount string `json:"amount"` +} diff --git a/app/md/md_sys_cfg.go b/app/md/md_sys_cfg.go new file mode 100644 index 0000000..ff0ca45 --- /dev/null +++ b/app/md/md_sys_cfg.go @@ -0,0 +1,9 @@ +package md + +type SetSysCfgReq struct { + WxMchApiV3Key string `json:"wx_mch_api_v3_key" label:"微信商户APIv3密钥"` + WxMchCertificateSerialNumber string `json:"wx_mch_certificate_serial_number" label:"微信商户证书序列号"` + WxMchId string `json:"wx_mch_id" label:"微信商户号"` + WxOfficialAccountAppId string `json:"wx_official_account_app_id" label:"微信公众号appId"` + WxOfficialAccountAppSecret string `json:"wx_official_account_app_secret" label:"微信公众号appSecret"` +} diff --git a/app/md/md_wx_official_account.go b/app/md/md_wx_official_account.go new file mode 100644 index 0000000..fdf8ecf --- /dev/null +++ b/app/md/md_wx_official_account.go @@ -0,0 +1,14 @@ +package md + +const WxOfficialAccountRequestBaseUrl = "https://api.weixin.qq.com/" + +type CreateTokenResp struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` +} + +type CreateQrcodeResp struct { + Ticket string `json:"ticket"` + ExpireSeconds int64 `json:"expire_seconds"` + Url string `json:"url"` +} diff --git a/app/mw/mw_access_log.go b/app/mw/mw_access_log.go new file mode 100644 index 0000000..84f6b52 --- /dev/null +++ b/app/mw/mw_access_log.go @@ -0,0 +1,31 @@ +package mw + +import ( + "time" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" + + "applet/app/utils/logx" +) + +// access log +func AccessLog(c *gin.Context) { + start := time.Now() + c.Next() + cost := time.Since(start) + + logx.Info(c.Request.URL.Path) + + logger := &zap.Logger{} + logger.Info(c.Request.URL.Path, + zap.Int("status", c.Writer.Status()), + zap.String("method", c.Request.Method), + zap.String("path", c.Request.URL.Path), + zap.String("query", c.Request.URL.RawQuery), + zap.String("ip", c.ClientIP()), + zap.String("user-agent", c.Request.UserAgent()), + zap.String("errors", c.Errors.ByType(gin.ErrorTypePrivate).String()), + zap.Duration("cost", cost), + ) +} diff --git a/app/mw/mw_auth.go b/app/mw/mw_auth.go new file mode 100644 index 0000000..2bf0a3c --- /dev/null +++ b/app/mw/mw_auth.go @@ -0,0 +1,27 @@ +package mw + +import ( + "applet/app/e" + "applet/app/svc" + "github.com/gin-gonic/gin" +) + +// 检查权限, 签名等等 +func Auth(c *gin.Context) { + admin, err := svc.CheckUser(c) + if err != nil { + switch err.(type) { + case e.E: + err1 := err.(e.E) + e.OutErr(c, err1.Code, err1.Error()) + return + default: + e.OutErr(c, e.ERR, err.Error()) + return + } + } + + // 将当前请求的username信息保存到请求的上下文c上 + c.Set("admin", admin) + c.Next() +} diff --git a/app/mw/mw_breaker.go b/app/mw/mw_breaker.go new file mode 100644 index 0000000..fefc078 --- /dev/null +++ b/app/mw/mw_breaker.go @@ -0,0 +1,30 @@ +package mw + +import ( + "errors" + "net/http" + "strconv" + + "github.com/afex/hystrix-go/hystrix" + "github.com/gin-gonic/gin" +) + +// 熔断器, 此组件需要在gin.Recovery中间之前进行调用, 否则可能会导致panic时候, 无法recovery, 正确顺序如下 +//r.Use(BreakerWrapper) +//r.Use(gin.Recovery()) +func Breaker(c *gin.Context) { + name := c.Request.Method + "-" + c.Request.RequestURI + hystrix.Do(name, func() error { + c.Next() + statusCode := c.Writer.Status() + if statusCode >= http.StatusInternalServerError { + return errors.New("status code " + strconv.Itoa(statusCode)) + } + return nil + }, func(e error) error { + if e == hystrix.ErrCircuitOpen { + c.String(http.StatusAccepted, "请稍后重试") //todo 修改报错方法 + } + return e + }) +} diff --git a/app/mw/mw_change_header.go b/app/mw/mw_change_header.go new file mode 100644 index 0000000..4a5aefa --- /dev/null +++ b/app/mw/mw_change_header.go @@ -0,0 +1,17 @@ +package mw + +import ( + "github.com/gin-gonic/gin" +) + +// 修改传过来的头部字段 +func ChangeHeader(c *gin.Context) { + appvserison := c.GetHeader("AppVersionName") + if appvserison == "" { + appvserison = c.GetHeader("app_version_name") + } + if appvserison != "" { + c.Request.Header.Add("app_version_name", appvserison) + } + c.Next() +} diff --git a/app/mw/mw_check_sign.go b/app/mw/mw_check_sign.go new file mode 100644 index 0000000..e3bf3c2 --- /dev/null +++ b/app/mw/mw_check_sign.go @@ -0,0 +1,34 @@ +package mw + +import ( + "applet/app/e" + "applet/app/utils" + "bytes" + "fmt" + "github.com/gin-gonic/gin" + "io/ioutil" +) + +// CheckSign is 中间件 用来检查签名 +func CheckSign(c *gin.Context) { + + bools := utils.SignCheck(c) + if bools == false { + e.OutErr(c, 400, e.NewErr(400, "签名校验错误,请求失败")) + return + } + c.Next() +} +func CheckBody(c *gin.Context) { + if utils.GetApiVersion(c) > 0 { + body, _ := ioutil.ReadAll(c.Request.Body) + fmt.Println(string(body)) + if string(body) != "" { + str := utils.ResultAesDecrypt(c, string(body)) + if str != "" { + c.Request.Body = ioutil.NopCloser(bytes.NewBuffer([]byte(str))) + } + } + } + c.Next() +} diff --git a/app/mw/mw_checker.go b/app/mw/mw_checker.go new file mode 100644 index 0000000..84d3c65 --- /dev/null +++ b/app/mw/mw_checker.go @@ -0,0 +1,11 @@ +package mw + +import ( + "github.com/gin-gonic/gin" +) + +// 检查设备等, 把头部信息下放到hdl可以获取 +func Checker(c *gin.Context) { + // 校验平台支持 + c.Next() +} diff --git a/app/mw/mw_cors.go b/app/mw/mw_cors.go new file mode 100644 index 0000000..3433553 --- /dev/null +++ b/app/mw/mw_cors.go @@ -0,0 +1,29 @@ +package mw + +import ( + "github.com/gin-gonic/gin" +) + +// cors跨域 +func Cors(c *gin.Context) { + // 放行所有OPTIONS方法 + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(204) + return + } + + origin := c.Request.Header.Get("Origin") // 请求头部 + if origin != "" { + c.Header("Access-Control-Allow-Origin", origin) // 这是允许访问来源域 + c.Header("Access-Control-Allow-Methods", "POST,GET,OPTIONS,PUT,DELETE,UPDATE") // 服务器支持的所有跨域请求的方法,为了避免浏览次请求的多次'预检'请求 + // header的类型 + c.Header("Access-Control-Allow-Headers", "Authorization,Content-Length,X-CSRF-Token,Token,session,X_Requested_With,Accept,Origin,Host,Connection,Accept-Encoding,Accept-Language,DNT,X-CustomHeader,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Pragma,X-Mx-ReqToken") + // 允许跨域设置,可以返回其他子段 + // 跨域关键设置 让浏览器可以解析 + c.Header("Access-Control-Expose-Headers", "Content-Length,Access-Control-Allow-Origin,Access-Control-Allow-Headers,Cache-Control,Content-Language,Content-Type,Expires,Last-Modified,Pragma,FooBar") + c.Header("Access-Control-Max-Age", "172800") // 缓存请求信息 单位为秒 + c.Header("Access-Control-Allow-Credentials", "false") // 跨域请求是否需要带cookie信息 默认设置为true + c.Set("Content-Type", "Application/json") // 设置返回格式是json + } + c.Next() +} diff --git a/app/mw/mw_csrf.go b/app/mw/mw_csrf.go new file mode 100644 index 0000000..b15619b --- /dev/null +++ b/app/mw/mw_csrf.go @@ -0,0 +1,136 @@ +package mw + +import ( + "crypto/sha1" + "encoding/base64" + "errors" + "io" + + "github.com/dchest/uniuri" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +// csrf,xsrf检查 +const ( + csrfSecret = "csrfSecret" + csrfSalt = "csrfSalt" + csrfToken = "csrfToken" +) + +var defaultIgnoreMethods = []string{"GET", "HEAD", "OPTIONS"} + +var defaultErrorFunc = func(c *gin.Context) { + panic(errors.New("CSRF token mismatch")) +} + +var defaultTokenGetter = func(c *gin.Context) string { + r := c.Request + + if t := r.FormValue("_csrf"); len(t) > 0 { + return t + } else if t := r.URL.Query().Get("_csrf"); len(t) > 0 { + return t + } else if t := r.Header.Get("X-CSRF-TOKEN"); len(t) > 0 { + return t + } else if t := r.Header.Get("X-XSRF-TOKEN"); len(t) > 0 { + return t + } + + return "" +} + +// Options stores configurations for a CSRF middleware. +type Options struct { + Secret string + IgnoreMethods []string + ErrorFunc gin.HandlerFunc + TokenGetter func(c *gin.Context) string +} + +func tokenize(secret, salt string) string { + h := sha1.New() + io.WriteString(h, salt+"-"+secret) + hash := base64.URLEncoding.EncodeToString(h.Sum(nil)) + + return hash +} + +func inArray(arr []string, value string) bool { + inarr := false + + for _, v := range arr { + if v == value { + inarr = true + break + } + } + + return inarr +} + +// Middleware validates CSRF token. +func Middleware(options Options) gin.HandlerFunc { + ignoreMethods := options.IgnoreMethods + errorFunc := options.ErrorFunc + tokenGetter := options.TokenGetter + + if ignoreMethods == nil { + ignoreMethods = defaultIgnoreMethods + } + + if errorFunc == nil { + errorFunc = defaultErrorFunc + } + + if tokenGetter == nil { + tokenGetter = defaultTokenGetter + } + + return func(c *gin.Context) { + session := sessions.Default(c) + c.Set(csrfSecret, options.Secret) + + if inArray(ignoreMethods, c.Request.Method) { + c.Next() + return + } + + salt, ok := session.Get(csrfSalt).(string) + + if !ok || len(salt) == 0 { + errorFunc(c) + return + } + + token := tokenGetter(c) + + if tokenize(options.Secret, salt) != token { + errorFunc(c) + return + } + + c.Next() + } +} + +// GetToken returns a CSRF token. +func GetToken(c *gin.Context) string { + session := sessions.Default(c) + secret := c.MustGet(csrfSecret).(string) + + if t, ok := c.Get(csrfToken); ok { + return t.(string) + } + + salt, ok := session.Get(csrfSalt).(string) + if !ok { + salt = uniuri.New() + session.Set(csrfSalt, salt) + session.Save() + } + token := tokenize(secret, salt) + c.Set(csrfToken, token) + + return token +} diff --git a/app/mw/mw_db.go b/app/mw/mw_db.go new file mode 100644 index 0000000..9e6f8ee --- /dev/null +++ b/app/mw/mw_db.go @@ -0,0 +1,24 @@ +package mw + +import ( + "fmt" + "github.com/gin-gonic/gin" +) + +// DB is 中间件 用来检查master_id是否有对应的数据库engine +func DB(c *gin.Context) { + fmt.Println(c.Request.Header) + masterID := c.GetHeader("master_id") + fmt.Println("master_id", masterID) + if masterID == "" { + fmt.Println("not found master_id found MasterId start") + masterID = c.GetHeader("MasterId") + fmt.Println("MasterId", masterID) + // if masterID still emtpy + + } + + fmt.Println("master_id", masterID) + c.Set("mid", masterID) + c.Next() +} diff --git a/app/mw/mw_limiter.go b/app/mw/mw_limiter.go new file mode 100644 index 0000000..4eb5299 --- /dev/null +++ b/app/mw/mw_limiter.go @@ -0,0 +1,58 @@ +package mw + +import ( + "bytes" + "io/ioutil" + + "github.com/gin-gonic/gin" + + "applet/app/utils" + "applet/app/utils/cache" +) + +// 限流器 +func Limiter(c *gin.Context) { + limit := 100 // 限流次数 + ttl := 1 // 限流过期时间 + ip := c.ClientIP() + // 读取token或者ip + token := c.GetHeader("Authorization") + // 判断是否已经超出限额次数 + method := c.Request.Method + host := c.Request.Host + uri := c.Request.URL.String() + + buf := make([]byte, 2048) + num, _ := c.Request.Body.Read(buf) + body := buf[:num] + // Write body back + c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(body)) + Md5 := utils.Md5(ip + token + method + host + uri + string(body)) + if cache.Exists(Md5) { + c.AbortWithStatusJSON(429, gin.H{ + "code": 429, + "msg": "don't repeat the request", + "data": struct{}{}, + }) + return + } + // 2s后没返回自动释放 + go cache.SetEx(Md5, "0", ttl) + key := "LIMITER_" + ip + reqs, _ := cache.GetInt(key) + if reqs >= limit { + c.AbortWithStatusJSON(429, gin.H{ + "code": 429, + "msg": "too many requests", + "data": struct{}{}, + }) + return + } + if reqs > 0 { + go cache.Incr(key) + } else { + go cache.SetEx(key, 1, ttl) + } + c.Next() + go cache.Del(Md5) +} diff --git a/app/mw/mw_recovery.go b/app/mw/mw_recovery.go new file mode 100644 index 0000000..b32cc82 --- /dev/null +++ b/app/mw/mw_recovery.go @@ -0,0 +1,57 @@ +package mw + +import ( + "net" + "net/http" + "net/http/httputil" + "os" + "runtime/debug" + "strings" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +func Recovery(logger *zap.Logger, stack bool) gin.HandlerFunc { + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + var brokenPipe bool + if ne, ok := err.(*net.OpError); ok { + if se, ok := ne.Err.(*os.SyscallError); ok { + if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { + brokenPipe = true + } + } + } + + httpRequest, _ := httputil.DumpRequest(c.Request, false) + if brokenPipe { + logger.Error(c.Request.URL.Path, + zap.Any("error", err), + zap.String("request", string(httpRequest)), + ) + // If the connection is dead, we can't write a status to it. + c.Error(err.(error)) + c.Abort() + return + } + + if stack { + logger.Error("[Recovery from panic]", + zap.Any("error", err), + zap.String("request", string(httpRequest)), + zap.String("stack", string(debug.Stack())), + ) + } else { + logger.Error("[Recovery from panic]", + zap.Any("error", err), + zap.String("request", string(httpRequest)), + ) + } + c.AbortWithStatus(http.StatusInternalServerError) + } + }() + c.Next() + } +} diff --git a/app/router/router.go b/app/router/router.go new file mode 100644 index 0000000..e540dc7 --- /dev/null +++ b/app/router/router.go @@ -0,0 +1,77 @@ +package router + +import ( + "applet/app/cfg" + "applet/app/hdl" + "applet/app/mw" + "github.com/gin-gonic/gin" +) + +//初始化路由 +func Init() *gin.Engine { + // debug, release, test 项目阶段 + mode := "release" + if cfg.Debug { + mode = "debug" + } + gin.SetMode(mode) + //创建一个新的启动器 + r := gin.New() + r.Use(mw.ChangeHeader) + + // 是否打印访问日志, 在非正式环境都打印 + if mode != "release" { + r.Use(gin.Logger()) + } + r.Use(gin.Recovery()) + // r.Use(mw.Limiter) + //r.LoadHTMLGlob("static/html/*") + + r.GET("/favicon.ico", func(c *gin.Context) { + c.Status(204) + }) + r.NoRoute(func(c *gin.Context) { + c.JSON(404, gin.H{"code": 404, "msg": "page not found", "data": []struct{}{}}) + }) + r.NoMethod(func(c *gin.Context) { + c.JSON(405, gin.H{"code": 405, "msg": "method not allowed", "data": []struct{}{}}) + }) + r.Use(mw.Cors) + route(r.Group("/api/v1")) + return r +} + +func route(r *gin.RouterGroup) { + r.Any("/demo", hdl.Demo) + r.POST("/login", hdl.Login) + + r.Group("/wx") + { + r.Use(mw.DB) + // 微信公众号消息通知 + r.GET("/msgReceive", hdl.WXCheckSignature) + r.POST("/msgReceive", hdl.WXMsgReceive) + + } + r.Use(mw.DB) // 以下接口需要用到数据库 + { + r.GET("/demo1", hdl.Demo1) + } + + r.Use(mw.Checker) // 以下接口需要检查Header: platform + { + } + + r.GET("/qrcodeBatchDownload", hdl.QrcodeBatchDownload) //二维码批次-下载 + r.Use(mw.Auth) // 以下接口需要JWT验证 + { + r.GET("/userInfo", hdl.UserInfo) //用户信息 + r.GET("/sysCfg", hdl.GetSysCfg) //基础配置-获取 + r.POST("/sysCfg", hdl.SetSysCfg) //基础配置-设置 + r.POST("/qrcodeBatchList", hdl.QrcodeBatchList) //二维码批次-列表 + r.GET("/getBatchAddName", hdl.GetBatchAddName) //二维码批次-自动获取添加时名称 + r.POST("/qrcodeBatchAdd", hdl.QrcodeBatchAdd) //二维码批次-添加 + r.GET("/qrcodeBatchDetail", hdl.QrcodeBatchDetail) //二维码批次-详情 + r.DELETE("/qrcodeBatchDelete/:id", hdl.QrcodeBatchDelete) //二维码批次-删除 + } +} diff --git a/app/svc/svc_auth.go b/app/svc/svc_auth.go new file mode 100644 index 0000000..47b93d6 --- /dev/null +++ b/app/svc/svc_auth.go @@ -0,0 +1,51 @@ +package svc + +import ( + "applet/app/db" + "applet/app/db/model" + "applet/app/utils" + "errors" + "github.com/gin-gonic/gin" + "strings" +) + +func GetUser(c *gin.Context) *model.Admin { + user, _ := c.Get("admin") + if user == nil { + return &model.Admin{ + AdmId: 0, + Username: "", + Password: "", + State: 0, + CreateAt: "", + UpdateAt: "", + } + } + return user.(*model.Admin) +} + +func CheckUser(c *gin.Context) (*model.Admin, error) { + token := c.GetHeader("Authorization") + if token == "" { + return nil, errors.New("token not exist") + } + // 按空格分割 + parts := strings.SplitN(token, " ", 2) + if !(len(parts) == 2 && parts[0] == "Bearer") { + return nil, errors.New("token format error") + } + // parts[1]是获取到的tokenString,我们使用之前定义好的解析JWT的函数来解析它 + mc, err := utils.ParseToken(parts[1]) + if err != nil { + return nil, err + } + + // 获取admin + adminDb := db.AdminDb{} + adminDb.Set() + admin, err := adminDb.GetAdmin(mc.AdmId) + if err != nil { + return nil, err + } + return admin, nil +} diff --git a/app/svc/svc_login.go b/app/svc/svc_login.go new file mode 100644 index 0000000..961c9e2 --- /dev/null +++ b/app/svc/svc_login.go @@ -0,0 +1,33 @@ +package svc + +import ( + "applet/app/db/model" + "applet/app/md" + "applet/app/utils" + "applet/app/utils/cache" + "applet/app/utils/logx" +) + +func HandleLoginToken(cacheKey string, admin *model.Admin) (string, error) { + // 获取之前生成的token + token, err := cache.GetString(cacheKey) + if err != nil { + _ = logx.Error(err) + } + // 没有获取到 + if err != nil || token == "" { + // 生成token + token, err = utils.GenToken(admin.AdmId, admin.Username) + if err != nil { + return "", err + } + // 缓存token + _, err = cache.SetEx(cacheKey, token, md.JwtTokenCacheTime) + if err != nil { + return "", err + } + return token, nil + } + + return token, nil +} diff --git a/app/svc/svc_qrcode.go b/app/svc/svc_qrcode.go new file mode 100644 index 0000000..f5a56cd --- /dev/null +++ b/app/svc/svc_qrcode.go @@ -0,0 +1,108 @@ +package svc + +import ( + "applet/app/db" + "applet/app/db/model" + "applet/app/enum" + "applet/app/lib/wx" + "applet/app/md" + "applet/app/utils" + "errors" + "time" + "xorm.io/xorm" +) + +func StatisticsQrcodeData() (qrcodeTotalNums, waitUseQrcodeNums, alreadyUseQrcodeNums, allowCreateQrcodeNums int64, err error) { + qrcodeTotalNums = md.QrcodeTotalNums //二维码总量 + + qrcodeWithBatchRecordsDb := db.QrcodeWithBatchRecordsDb{} + qrcodeWithBatchRecordsDb.Set() + qrcodeWithBatchRecordsForUseWait, err := qrcodeWithBatchRecordsDb.FindQrcodeWithBatchRecordsByState(enum.QrcodeWithBatchRecordsStateForWait) + if err != nil { + return + } + waitUseQrcodeNums = int64(len(qrcodeWithBatchRecordsForUseWait)) //待使用二维码数量 + + qrcodeWithBatchRecordsForUseAlready, err := qrcodeWithBatchRecordsDb.FindQrcodeWithBatchRecordsByState(enum.QrcodeWithBatchRecordsStateForAlready) + if err != nil { + return + } + alreadyUseQrcodeNums = int64(len(qrcodeWithBatchRecordsForUseAlready)) //已使用二维码数量 + + allowCreateQrcodeNums = qrcodeTotalNums - waitUseQrcodeNums //可生成二维码数量 + return +} + +func createQrcodeIndex() string { + date := utils.Int64ToStr(time.Now().UnixMicro()) + sceneStr := date + "_" + utils.RandString(6) //根据当前时间戳(微秒)+ 随机6位字符串 作为唯一标识符 + return sceneStr +} + +func CreateQrcode(createNums int) (err error) { + now := time.Now() + var insertData []*model.Qrcode + //1、调用微信 `cgi-bin/qrcode/create` 生成带参的永久二维码 + wxOfficial := wx.OfficialAccount{} + wxOfficial.Set() + for i := 0; i < createNums; i++ { + sceneStr := createQrcodeIndex() + qrcodeUrl, err1 := wxOfficial.QrcodeCreate(sceneStr) + if err1 != nil { + return err1 + } + insertData = append(insertData, &model.Qrcode{ + Url: qrcodeUrl, + State: enum.QrcodeSateAllowUse, + Index: sceneStr, + CreateAt: now.Format("2006-01-02 15:00:00"), + UpdateAt: now.Format("2006-01-02 15:00:00"), + }) + } + //2、批量新增二维码 + qrcodeDb := db.QrcodeDb{} + qrcodeDb.Set() + _, err = qrcodeDb.BatchAddQrcode(insertData) + return +} + +func OperateQrcode(batchId, totalNums int, args md.QrcodeBatchAddReq, session *xorm.Session) (err error) { + qrcodeDb := db.QrcodeDb{} + qrcodeDb.Set() + //1、获取当前可用二维码 + allowUseQrcodeList, allowUseQrcodeTotal, err := qrcodeDb.FindQrcodeForAllowUse() + if int(allowUseQrcodeTotal) < totalNums { + err = errors.New("可用二维码不足") + return + } + + now := time.Now() + var insertData []*model.QrcodeWithBatchRecords + var updateQrcodeIds []int + var k = 0 + for _, v := range args.List { + for i := 0; i < v.Num; i++ { + insertData = append(insertData, &model.QrcodeWithBatchRecords{ + QrcodeId: allowUseQrcodeList[k].Id, + BatchId: batchId, + Amount: v.Amount, + State: enum.QrcodeWithBatchRecordsStateForWait, + CreateAt: now.Format("2006-01-02 15:00:00"), + UpdateAt: now.Format("2006-01-02 15:00:00"), + }) + updateQrcodeIds = append(updateQrcodeIds, allowUseQrcodeList[k].Id) + k++ + } + } + + //2、新增“二维码-批次”记录 + qrcodeWithBatchRecordsDb := db.QrcodeWithBatchRecordsDb{} + qrcodeWithBatchRecordsDb.Set() + if _, err = qrcodeWithBatchRecordsDb.BatchAddQrcodeWithBatchRecordsBySession(session, insertData); err != nil { + return + } + + //3、修改"二维码状态"为不可用 + _, err = qrcodeDb.BatchUpdateQrcodeBySession(session, updateQrcodeIds, enum.QrcodeSateAllowNotUse) + return +} diff --git a/app/task/init.go b/app/task/init.go new file mode 100644 index 0000000..bd11346 --- /dev/null +++ b/app/task/init.go @@ -0,0 +1,92 @@ +package task + +import ( + taskMd "applet/app/task/md" + "time" + + "applet/app/db/model" + "applet/app/utils/logx" + "github.com/robfig/cron/v3" + "xorm.io/xorm" +) + +var ( + timer *cron.Cron + jobs = map[string]func(*xorm.Engine, string){} + baseEntryId cron.EntryID + entryIds []cron.EntryID + taskCfgList map[string]*[]model.SysCfg + ch = make(chan int, 30) + workerNum = 15 // 智盟跟单并发数量 + otherCh = make(chan int, 30) + otherWorkerNum = 18 // 淘宝, 苏宁, 考拉并发量 +) + +func Init() { + // 初始化任务列表 + initTasks() + var err error + timer = cron.New() + // reload为初始化数据库方法 + if baseEntryId, err = timer.AddFunc("@every 15m", reload); err != nil { + _ = logx.Fatal(err) + } +} + +func Run() { + reload() + timer.Start() + _ = logx.Info("auto tasks running...") +} + +func reload() { + // 重新初始化数据库 + + if len(taskCfgList) > 0 { + // 删除原有所有任务 + if len(entryIds) > 0 { + for _, v := range entryIds { + if v != baseEntryId { + timer.Remove(v) + } + } + entryIds = nil + } + var ( + entryId cron.EntryID + err error + ) + // 添加任务 + for dbName, v := range taskCfgList { + for _, vv := range *v { + if _, ok := jobs[vv.Key]; ok && vv.Val != "" { + // fmt.Println(vv.Val) + if entryId, err = timer.AddFunc(vv.Val, doTask(dbName, vv.Key)); err == nil { + entryIds = append(entryIds, entryId) + } + } + } + } + + } +} + +func doTask(dbName, fnName string) func() { + return func() { + begin := time.Now().Local() + end := time.Now().Local() + logx.Infof( + "[%s] AutoTask <%s> started at <%s>, ended at <%s> duration <%s>", + dbName, + fnName, + begin.Format("2006-01-02 15:04:05.000"), + end.Format("2006-01-02 15:04:05.000"), + time.Duration(end.UnixNano()-begin.UnixNano()).String(), + ) + } +} + +// 增加自动任务队列 +func initTasks() { + jobs[taskMd.MallCronOrderCancel] = taskCancelOrder // 取消订单 +} diff --git a/app/task/md/cron_key.go b/app/task/md/cron_key.go new file mode 100644 index 0000000..b38ccc8 --- /dev/null +++ b/app/task/md/cron_key.go @@ -0,0 +1,5 @@ +package md + +const ( + MallCronOrderCancel = "mall_cron_order_cancel" // 取消订单任务 +) diff --git a/app/task/svc/svc_cancel_order.go b/app/task/svc/svc_cancel_order.go new file mode 100644 index 0000000..0c35b5f --- /dev/null +++ b/app/task/svc/svc_cancel_order.go @@ -0,0 +1,64 @@ +package svc + +import ( + "applet/app/db" + "applet/app/utils" + "applet/app/utils/logx" + "errors" + "fmt" + "time" + "xorm.io/xorm" +) + +func CancelOrder(eg *xorm.Engine, dbName string) { + fmt.Println("cancel order...") + defer func() { + if err := recover(); err != nil { + _ = logx.Error(err) + } + }() + + timeStr, err := getCancelCfg(eg, dbName) + if err != nil { + fmt.Println(err.Error()) + return + } + + now := time.Now() + // x 分钟后取消订单 + expTime := now.Add(-time.Hour * time.Duration(utils.StrToInt64(timeStr))) + expTimeStr := utils.Time2String(expTime, "") + + page := 1 + + for { + isEmpty, err := handleOnePage(eg, dbName, expTimeStr) + if err != nil { + _ = logx.Error(err) + break + } + if isEmpty { + break + } + + if page > 100 { + break + } + + page += 1 + + } +} + +func handleOnePage(eg *xorm.Engine, dbName, expTimeStr string) (isEmpty bool, err error) { + return false, nil +} + +func getCancelCfg(eg *xorm.Engine, masterId string) (string, error) { + cfg := db.SysCfgGetWithDb(eg, masterId, "order_expiration_time") + + if cfg == "" { + return "", errors.New("order_expiration_time no found") + } + return cfg, nil +} diff --git a/app/task/task_cancel_order.go b/app/task/task_cancel_order.go new file mode 100644 index 0000000..2e45bbb --- /dev/null +++ b/app/task/task_cancel_order.go @@ -0,0 +1,23 @@ +package task + +import ( + "applet/app/task/svc" + "math/rand" + "time" + "xorm.io/xorm" +) + +// 取消订单 +func taskCancelOrder(eg *xorm.Engine, dbName string) { + for { + if len(ch) > workerNum { + time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000))) + } else { + goto START + } + } +START: + ch <- 1 + svc.CancelOrder(eg, dbName) + <-ch +} diff --git a/app/utils/aes.go b/app/utils/aes.go new file mode 100644 index 0000000..8f5aaac --- /dev/null +++ b/app/utils/aes.go @@ -0,0 +1,123 @@ +package utils + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "fmt" +) + +func AesEncrypt(rawData, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + rawData = PKCS5Padding(rawData, blockSize) + // rawData = ZeroPadding(rawData, block.BlockSize()) + blockMode := cipher.NewCBCEncrypter(block, key[:blockSize]) + encrypted := make([]byte, len(rawData)) + // 根据CryptBlocks方法的说明,如下方式初始化encrypted也可以 + // encrypted := rawData + blockMode.CryptBlocks(encrypted, rawData) + return encrypted, nil +} + +func AesDecrypt(encrypted, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + blockMode := cipher.NewCBCDecrypter(block, key[:blockSize]) + rawData := make([]byte, len(encrypted)) + // rawData := encrypted + blockMode.CryptBlocks(rawData, encrypted) + rawData = PKCS5UnPadding(rawData) + // rawData = ZeroUnPadding(rawData) + return rawData, nil +} + +func ZeroPadding(cipherText []byte, blockSize int) []byte { + padding := blockSize - len(cipherText)%blockSize + padText := bytes.Repeat([]byte{0}, padding) + return append(cipherText, padText...) +} + +func ZeroUnPadding(rawData []byte) []byte { + length := len(rawData) + unPadding := int(rawData[length-1]) + return rawData[:(length - unPadding)] +} + +func PKCS5Padding(cipherText []byte, blockSize int) []byte { + padding := blockSize - len(cipherText)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(cipherText, padText...) +} + +func PKCS5UnPadding(rawData []byte) []byte { + length := len(rawData) + // 去掉最后一个字节 unPadding 次 + unPadding := int(rawData[length-1]) + return rawData[:(length - unPadding)] +} + +// 填充0 +func zeroFill(key *string) { + l := len(*key) + if l != 16 && l != 24 && l != 32 { + if l < 16 { + *key = *key + fmt.Sprintf("%0*d", 16-l, 0) + } else if l < 24 { + *key = *key + fmt.Sprintf("%0*d", 24-l, 0) + } else if l < 32 { + *key = *key + fmt.Sprintf("%0*d", 32-l, 0) + } else { + *key = string([]byte(*key)[:32]) + } + } +} + +type AesCrypt struct { + Key []byte + Iv []byte +} + +func (a *AesCrypt) Encrypt(data []byte) ([]byte, error) { + aesBlockEncrypt, err := aes.NewCipher(a.Key) + if err != nil { + println(err.Error()) + return nil, err + } + + content := pKCS5Padding(data, aesBlockEncrypt.BlockSize()) + cipherBytes := make([]byte, len(content)) + aesEncrypt := cipher.NewCBCEncrypter(aesBlockEncrypt, a.Iv) + aesEncrypt.CryptBlocks(cipherBytes, content) + return cipherBytes, nil +} + +func (a *AesCrypt) Decrypt(src []byte) (data []byte, err error) { + decrypted := make([]byte, len(src)) + var aesBlockDecrypt cipher.Block + aesBlockDecrypt, err = aes.NewCipher(a.Key) + if err != nil { + println(err.Error()) + return nil, err + } + aesDecrypt := cipher.NewCBCDecrypter(aesBlockDecrypt, a.Iv) + aesDecrypt.CryptBlocks(decrypted, src) + return pKCS5Trimming(decrypted), nil +} + +func pKCS5Padding(cipherText []byte, blockSize int) []byte { + padding := blockSize - len(cipherText)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(cipherText, padText...) +} + +func pKCS5Trimming(encrypt []byte) []byte { + padding := encrypt[len(encrypt)-1] + return encrypt[:len(encrypt)-int(padding)] +} diff --git a/app/utils/auth.go b/app/utils/auth.go new file mode 100644 index 0000000..7b95682 --- /dev/null +++ b/app/utils/auth.go @@ -0,0 +1,41 @@ +package utils + +import ( + "applet/app/lib/auth" + "errors" + "time" + + "github.com/dgrijalva/jwt-go" +) + +// GenToken 生成JWT +func GenToken(admId int, username string) (string, error) { + // 创建一个我们自己的声明 + c := auth.JWTUser{ + AdmId: admId, + Username: username, + StandardClaims: jwt.StandardClaims{ + ExpiresAt: time.Now().Add(auth.TokenExpireDuration).Unix(), // 过期时间 + Issuer: "zyos", // 签发人 + }, + } + // 使用指定的签名方法创建签名对象 + token := jwt.NewWithClaims(jwt.SigningMethodHS256, c) + // 使用指定的secret签名并获得完整的编码后的字符串token + return token.SignedString(auth.Secret) +} + +// ParseToken 解析JWT +func ParseToken(tokenString string) (*auth.JWTUser, error) { + // 解析token + token, err := jwt.ParseWithClaims(tokenString, &auth.JWTUser{}, func(token *jwt.Token) (i interface{}, err error) { + return auth.Secret, nil + }) + if err != nil { + return nil, err + } + if claims, ok := token.Claims.(*auth.JWTUser); ok && token.Valid { // 校验token + return claims, nil + } + return nil, errors.New("invalid token") +} diff --git a/app/utils/base64.go b/app/utils/base64.go new file mode 100644 index 0000000..ee16553 --- /dev/null +++ b/app/utils/base64.go @@ -0,0 +1,95 @@ +package utils + +import ( + "encoding/base64" + "fmt" +) + +const ( + Base64Std = iota + Base64Url + Base64RawStd + Base64RawUrl +) + +func Base64StdEncode(str interface{}) string { + return Base64Encode(str, Base64Std) +} + +func Base64StdDecode(str interface{}) string { + return Base64Decode(str, Base64Std) +} + +func Base64UrlEncode(str interface{}) string { + return Base64Encode(str, Base64Url) +} + +func Base64UrlDecode(str interface{}) string { + return Base64Decode(str, Base64Url) +} + +func Base64RawStdEncode(str interface{}) string { + return Base64Encode(str, Base64RawStd) +} + +func Base64RawStdDecode(str interface{}) string { + return Base64Decode(str, Base64RawStd) +} + +func Base64RawUrlEncode(str interface{}) string { + return Base64Encode(str, Base64RawUrl) +} + +func Base64RawUrlDecode(str interface{}) string { + return Base64Decode(str, Base64RawUrl) +} + +func Base64Encode(str interface{}, encode int) string { + newEncode := base64Encode(encode) + if newEncode == nil { + return "" + } + switch v := str.(type) { + case string: + return newEncode.EncodeToString([]byte(v)) + case []byte: + return newEncode.EncodeToString(v) + } + return newEncode.EncodeToString([]byte(fmt.Sprint(str))) +} + +func Base64Decode(str interface{}, encode int) string { + var err error + var b []byte + newEncode := base64Encode(encode) + if newEncode == nil { + return "" + } + switch v := str.(type) { + case string: + b, err = newEncode.DecodeString(v) + case []byte: + b, err = newEncode.DecodeString(string(v)) + default: + return "" + } + if err != nil { + return "" + } + return string(b) +} + +func base64Encode(encode int) *base64.Encoding { + switch encode { + case Base64Std: + return base64.StdEncoding + case Base64Url: + return base64.URLEncoding + case Base64RawStd: + return base64.RawStdEncoding + case Base64RawUrl: + return base64.RawURLEncoding + default: + return nil + } +} diff --git a/app/utils/boolean.go b/app/utils/boolean.go new file mode 100644 index 0000000..d64c876 --- /dev/null +++ b/app/utils/boolean.go @@ -0,0 +1,26 @@ +package utils + +import "reflect" + +// 检验一个值是否为空 +func Empty(val interface{}) bool { + v := reflect.ValueOf(val) + switch v.Kind() { + case reflect.String, reflect.Array: + return v.Len() == 0 + case reflect.Map, reflect.Slice: + return v.Len() == 0 || v.IsNil() + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + + return reflect.DeepEqual(val, reflect.Zero(v.Type()).Interface()) +} \ No newline at end of file diff --git a/app/utils/cache/base.go b/app/utils/cache/base.go new file mode 100644 index 0000000..9e5b7fe --- /dev/null +++ b/app/utils/cache/base.go @@ -0,0 +1,422 @@ +package cache + +import ( + "errors" + "fmt" + "strconv" + "time" +) + +const ( + redisPassword = "sanhu" + redisDialTTL = 10 * time.Second + redisReadTTL = 3 * time.Second + redisWriteTTL = 3 * time.Second + redisIdleTTL = 10 * time.Second + redisPoolTTL = 10 * time.Second + redisPoolSize int = 512 + redisMaxIdleConn int = 64 + redisMaxActive int = 512 +) + +var ( + ErrNil = errors.New("nil return") + ErrWrongArgsNum = errors.New("args num error") + ErrNegativeInt = errors.New("redis cluster: unexpected value for Uint64") +) + +// 以下为提供类型转换 + +func Int(reply interface{}, err error) (int, error) { + if err != nil { + return 0, err + } + switch reply := reply.(type) { + case int: + return reply, nil + case int8: + return int(reply), nil + case int16: + return int(reply), nil + case int32: + return int(reply), nil + case int64: + x := int(reply) + if int64(x) != reply { + return 0, strconv.ErrRange + } + return x, nil + case uint: + n := int(reply) + if n < 0 { + return 0, strconv.ErrRange + } + return n, nil + case uint8: + return int(reply), nil + case uint16: + return int(reply), nil + case uint32: + n := int(reply) + if n < 0 { + return 0, strconv.ErrRange + } + return n, nil + case uint64: + n := int(reply) + if n < 0 { + return 0, strconv.ErrRange + } + return n, nil + case []byte: + data := string(reply) + if len(data) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseInt(data, 10, 0) + return int(n), err + case string: + if len(reply) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseInt(reply, 10, 0) + return int(n), err + case nil: + return 0, ErrNil + case error: + return 0, reply + } + return 0, fmt.Errorf("redis cluster: unexpected type for Int, got type %T", reply) +} + +func Int64(reply interface{}, err error) (int64, error) { + if err != nil { + return 0, err + } + switch reply := reply.(type) { + case int: + return int64(reply), nil + case int8: + return int64(reply), nil + case int16: + return int64(reply), nil + case int32: + return int64(reply), nil + case int64: + return reply, nil + case uint: + n := int64(reply) + if n < 0 { + return 0, strconv.ErrRange + } + return n, nil + case uint8: + return int64(reply), nil + case uint16: + return int64(reply), nil + case uint32: + return int64(reply), nil + case uint64: + n := int64(reply) + if n < 0 { + return 0, strconv.ErrRange + } + return n, nil + case []byte: + data := string(reply) + if len(data) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseInt(data, 10, 64) + return n, err + case string: + if len(reply) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseInt(reply, 10, 64) + return n, err + case nil: + return 0, ErrNil + case error: + return 0, reply + } + return 0, fmt.Errorf("redis cluster: unexpected type for Int64, got type %T", reply) +} + +func Uint64(reply interface{}, err error) (uint64, error) { + if err != nil { + return 0, err + } + switch reply := reply.(type) { + case uint: + return uint64(reply), nil + case uint8: + return uint64(reply), nil + case uint16: + return uint64(reply), nil + case uint32: + return uint64(reply), nil + case uint64: + return reply, nil + case int: + if reply < 0 { + return 0, ErrNegativeInt + } + return uint64(reply), nil + case int8: + if reply < 0 { + return 0, ErrNegativeInt + } + return uint64(reply), nil + case int16: + if reply < 0 { + return 0, ErrNegativeInt + } + return uint64(reply), nil + case int32: + if reply < 0 { + return 0, ErrNegativeInt + } + return uint64(reply), nil + case int64: + if reply < 0 { + return 0, ErrNegativeInt + } + return uint64(reply), nil + case []byte: + data := string(reply) + if len(data) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseUint(data, 10, 64) + return n, err + case string: + if len(reply) == 0 { + return 0, ErrNil + } + + n, err := strconv.ParseUint(reply, 10, 64) + return n, err + case nil: + return 0, ErrNil + case error: + return 0, reply + } + return 0, fmt.Errorf("redis cluster: unexpected type for Uint64, got type %T", reply) +} + +func Float64(reply interface{}, err error) (float64, error) { + if err != nil { + return 0, err + } + + var value float64 + err = nil + switch v := reply.(type) { + case float32: + value = float64(v) + case float64: + value = v + case int: + value = float64(v) + case int8: + value = float64(v) + case int16: + value = float64(v) + case int32: + value = float64(v) + case int64: + value = float64(v) + case uint: + value = float64(v) + case uint8: + value = float64(v) + case uint16: + value = float64(v) + case uint32: + value = float64(v) + case uint64: + value = float64(v) + case []byte: + data := string(v) + if len(data) == 0 { + return 0, ErrNil + } + value, err = strconv.ParseFloat(string(v), 64) + case string: + if len(v) == 0 { + return 0, ErrNil + } + value, err = strconv.ParseFloat(v, 64) + case nil: + err = ErrNil + case error: + err = v + default: + err = fmt.Errorf("redis cluster: unexpected type for Float64, got type %T", v) + } + + return value, err +} + +func Bool(reply interface{}, err error) (bool, error) { + if err != nil { + return false, err + } + switch reply := reply.(type) { + case bool: + return reply, nil + case int64: + return reply != 0, nil + case []byte: + data := string(reply) + if len(data) == 0 { + return false, ErrNil + } + + return strconv.ParseBool(data) + case string: + if len(reply) == 0 { + return false, ErrNil + } + + return strconv.ParseBool(reply) + case nil: + return false, ErrNil + case error: + return false, reply + } + return false, fmt.Errorf("redis cluster: unexpected type for Bool, got type %T", reply) +} + +func Bytes(reply interface{}, err error) ([]byte, error) { + if err != nil { + return nil, err + } + switch reply := reply.(type) { + case []byte: + if len(reply) == 0 { + return nil, ErrNil + } + return reply, nil + case string: + data := []byte(reply) + if len(data) == 0 { + return nil, ErrNil + } + return data, nil + case nil: + return nil, ErrNil + case error: + return nil, reply + } + return nil, fmt.Errorf("redis cluster: unexpected type for Bytes, got type %T", reply) +} + +func String(reply interface{}, err error) (string, error) { + if err != nil { + return "", err + } + + value := "" + err = nil + switch v := reply.(type) { + case string: + if len(v) == 0 { + return "", ErrNil + } + + value = v + case []byte: + if len(v) == 0 { + return "", ErrNil + } + + value = string(v) + case int: + value = strconv.FormatInt(int64(v), 10) + case int8: + value = strconv.FormatInt(int64(v), 10) + case int16: + value = strconv.FormatInt(int64(v), 10) + case int32: + value = strconv.FormatInt(int64(v), 10) + case int64: + value = strconv.FormatInt(v, 10) + case uint: + value = strconv.FormatUint(uint64(v), 10) + case uint8: + value = strconv.FormatUint(uint64(v), 10) + case uint16: + value = strconv.FormatUint(uint64(v), 10) + case uint32: + value = strconv.FormatUint(uint64(v), 10) + case uint64: + value = strconv.FormatUint(v, 10) + case float32: + value = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + value = strconv.FormatFloat(v, 'f', -1, 64) + case bool: + value = strconv.FormatBool(v) + case nil: + err = ErrNil + case error: + err = v + default: + err = fmt.Errorf("redis cluster: unexpected type for String, got type %T", v) + } + + return value, err +} + +func Strings(reply interface{}, err error) ([]string, error) { + if err != nil { + return nil, err + } + switch reply := reply.(type) { + case []interface{}: + result := make([]string, len(reply)) + for i := range reply { + if reply[i] == nil { + continue + } + switch subReply := reply[i].(type) { + case string: + result[i] = subReply + case []byte: + result[i] = string(subReply) + default: + return nil, fmt.Errorf("redis cluster: unexpected element type for String, got type %T", reply[i]) + } + } + return result, nil + case []string: + return reply, nil + case nil: + return nil, ErrNil + case error: + return nil, reply + } + return nil, fmt.Errorf("redis cluster: unexpected type for Strings, got type %T", reply) +} + +func Values(reply interface{}, err error) ([]interface{}, error) { + if err != nil { + return nil, err + } + switch reply := reply.(type) { + case []interface{}: + return reply, nil + case nil: + return nil, ErrNil + case error: + return nil, reply + } + return nil, fmt.Errorf("redis cluster: unexpected type for Values, got type %T", reply) +} diff --git a/app/utils/cache/cache/cache.go b/app/utils/cache/cache/cache.go new file mode 100644 index 0000000..e43c5f0 --- /dev/null +++ b/app/utils/cache/cache/cache.go @@ -0,0 +1,107 @@ +package cache + +import ( + "fmt" + "time" +) + +var c Cache + +type Cache interface { + // get cached value by key. + Get(key string) interface{} + // GetMulti is a batch version of Get. + GetMulti(keys []string) []interface{} + // set cached value with key and expire time. + Put(key string, val interface{}, timeout time.Duration) error + // delete cached value by key. + Delete(key string) error + // increase cached int value by key, as a counter. + Incr(key string) error + // decrease cached int value by key, as a counter. + Decr(key string) error + // check if cached value exists or not. + IsExist(key string) bool + // clear all cache. + ClearAll() error + // start gc routine based on config string settings. + StartAndGC(config string) error +} + +// Instance is a function create a new Cache Instance +type Instance func() Cache + +var adapters = make(map[string]Instance) + +// Register makes a cache adapter available by the adapter name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, adapter Instance) { + if adapter == nil { + panic("cache: Register adapter is nil") + } + if _, ok := adapters[name]; ok { + panic("cache: Register called twice for adapter " + name) + } + adapters[name] = adapter +} + +// NewCache Create a new cache driver by adapter name and config string. +// config need to be correct JSON as string: {"interval":360}. +// it will start gc automatically. +func NewCache(adapterName, config string) (adapter Cache, err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + adapter = instanceFunc() + err = adapter.StartAndGC(config) + if err != nil { + adapter = nil + } + return +} + +func InitCache(adapterName, config string) (err error) { + instanceFunc, ok := adapters[adapterName] + if !ok { + err = fmt.Errorf("cache: unknown adapter name %q (forgot to import?)", adapterName) + return + } + c = instanceFunc() + err = c.StartAndGC(config) + if err != nil { + c = nil + } + return +} + +func Get(key string) interface{} { + return c.Get(key) +} + +func GetMulti(keys []string) []interface{} { + return c.GetMulti(keys) +} +func Put(key string, val interface{}, ttl time.Duration) error { + return c.Put(key, val, ttl) +} +func Delete(key string) error { + return c.Delete(key) +} +func Incr(key string) error { + return c.Incr(key) +} +func Decr(key string) error { + return c.Decr(key) +} +func IsExist(key string) bool { + return c.IsExist(key) +} +func ClearAll() error { + return c.ClearAll() +} +func StartAndGC(cfg string) error { + return c.StartAndGC(cfg) +} diff --git a/app/utils/cache/cache/conv.go b/app/utils/cache/cache/conv.go new file mode 100644 index 0000000..6b700ae --- /dev/null +++ b/app/utils/cache/cache/conv.go @@ -0,0 +1,86 @@ +package cache + +import ( + "fmt" + "strconv" +) + +// GetString convert interface to string. +func GetString(v interface{}) string { + switch result := v.(type) { + case string: + return result + case []byte: + return string(result) + default: + if v != nil { + return fmt.Sprint(result) + } + } + return "" +} + +// GetInt convert interface to int. +func GetInt(v interface{}) int { + switch result := v.(type) { + case int: + return result + case int32: + return int(result) + case int64: + return int(result) + default: + if d := GetString(v); d != "" { + value, _ := strconv.Atoi(d) + return value + } + } + return 0 +} + +// GetInt64 convert interface to int64. +func GetInt64(v interface{}) int64 { + switch result := v.(type) { + case int: + return int64(result) + case int32: + return int64(result) + case int64: + return result + default: + + if d := GetString(v); d != "" { + value, _ := strconv.ParseInt(d, 10, 64) + return value + } + } + return 0 +} + +// GetFloat64 convert interface to float64. +func GetFloat64(v interface{}) float64 { + switch result := v.(type) { + case float64: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseFloat(d, 64) + return value + } + } + return 0 +} + +// GetBool convert interface to bool. +func GetBool(v interface{}) bool { + switch result := v.(type) { + case bool: + return result + default: + if d := GetString(v); d != "" { + value, _ := strconv.ParseBool(d) + return value + } + } + return false +} diff --git a/app/utils/cache/cache/file.go b/app/utils/cache/cache/file.go new file mode 100644 index 0000000..5c4e366 --- /dev/null +++ b/app/utils/cache/cache/file.go @@ -0,0 +1,241 @@ +package cache + +import ( + "bytes" + "crypto/md5" + "encoding/gob" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "strconv" + "time" +) + +// FileCacheItem is basic unit of file cache adapter. +// it contains data and expire time. +type FileCacheItem struct { + Data interface{} + LastAccess time.Time + Expired time.Time +} + +// FileCache Config +var ( + FileCachePath = "cache" // cache directory + FileCacheFileSuffix = ".bin" // cache file suffix + FileCacheDirectoryLevel = 2 // cache file deep level if auto generated cache files. + FileCacheEmbedExpiry time.Duration // cache expire time, default is no expire forever. +) + +// FileCache is cache adapter for file storage. +type FileCache struct { + CachePath string + FileSuffix string + DirectoryLevel int + EmbedExpiry int +} + +// NewFileCache Create new file cache with no config. +// the level and expiry need set in method StartAndGC as config string. +func NewFileCache() Cache { + // return &FileCache{CachePath:FileCachePath, FileSuffix:FileCacheFileSuffix} + return &FileCache{} +} + +// StartAndGC will start and begin gc for file cache. +// the config need to be like {CachePath:"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0} +func (fc *FileCache) StartAndGC(config string) error { + + var cfg map[string]string + json.Unmarshal([]byte(config), &cfg) + if _, ok := cfg["CachePath"]; !ok { + cfg["CachePath"] = FileCachePath + } + if _, ok := cfg["FileSuffix"]; !ok { + cfg["FileSuffix"] = FileCacheFileSuffix + } + if _, ok := cfg["DirectoryLevel"]; !ok { + cfg["DirectoryLevel"] = strconv.Itoa(FileCacheDirectoryLevel) + } + if _, ok := cfg["EmbedExpiry"]; !ok { + cfg["EmbedExpiry"] = strconv.FormatInt(int64(FileCacheEmbedExpiry.Seconds()), 10) + } + fc.CachePath = cfg["CachePath"] + fc.FileSuffix = cfg["FileSuffix"] + fc.DirectoryLevel, _ = strconv.Atoi(cfg["DirectoryLevel"]) + fc.EmbedExpiry, _ = strconv.Atoi(cfg["EmbedExpiry"]) + + fc.Init() + return nil +} + +// Init will make new dir for file cache if not exist. +func (fc *FileCache) Init() { + if ok, _ := exists(fc.CachePath); !ok { // todo : error handle + _ = os.MkdirAll(fc.CachePath, os.ModePerm) // todo : error handle + } +} + +// get cached file name. it's md5 encoded. +func (fc *FileCache) getCacheFileName(key string) string { + m := md5.New() + io.WriteString(m, key) + keyMd5 := hex.EncodeToString(m.Sum(nil)) + cachePath := fc.CachePath + switch fc.DirectoryLevel { + case 2: + cachePath = filepath.Join(cachePath, keyMd5[0:2], keyMd5[2:4]) + case 1: + cachePath = filepath.Join(cachePath, keyMd5[0:2]) + } + + if ok, _ := exists(cachePath); !ok { // todo : error handle + _ = os.MkdirAll(cachePath, os.ModePerm) // todo : error handle + } + + return filepath.Join(cachePath, fmt.Sprintf("%s%s", keyMd5, fc.FileSuffix)) +} + +// Get value from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) Get(key string) interface{} { + fileData, err := FileGetContents(fc.getCacheFileName(key)) + if err != nil { + return "" + } + var to FileCacheItem + GobDecode(fileData, &to) + if to.Expired.Before(time.Now()) { + return "" + } + return to.Data +} + +// GetMulti gets values from file cache. +// if non-exist or expired, return empty string. +func (fc *FileCache) GetMulti(keys []string) []interface{} { + var rc []interface{} + for _, key := range keys { + rc = append(rc, fc.Get(key)) + } + return rc +} + +// Put value into file cache. +// timeout means how long to keep this file, unit of ms. +// if timeout equals FileCacheEmbedExpiry(default is 0), cache this item forever. +func (fc *FileCache) Put(key string, val interface{}, timeout time.Duration) error { + gob.Register(val) + + item := FileCacheItem{Data: val} + if timeout == FileCacheEmbedExpiry { + item.Expired = time.Now().Add((86400 * 365 * 10) * time.Second) // ten years + } else { + item.Expired = time.Now().Add(timeout) + } + item.LastAccess = time.Now() + data, err := GobEncode(item) + if err != nil { + return err + } + return FilePutContents(fc.getCacheFileName(key), data) +} + +// Delete file cache value. +func (fc *FileCache) Delete(key string) error { + filename := fc.getCacheFileName(key) + if ok, _ := exists(filename); ok { + return os.Remove(filename) + } + return nil +} + +// Incr will increase cached int value. +// fc value is saving forever unless Delete. +func (fc *FileCache) Incr(key string) error { + data := fc.Get(key) + var incr int + if reflect.TypeOf(data).Name() != "int" { + incr = 0 + } else { + incr = data.(int) + 1 + } + fc.Put(key, incr, FileCacheEmbedExpiry) + return nil +} + +// Decr will decrease cached int value. +func (fc *FileCache) Decr(key string) error { + data := fc.Get(key) + var decr int + if reflect.TypeOf(data).Name() != "int" || data.(int)-1 <= 0 { + decr = 0 + } else { + decr = data.(int) - 1 + } + fc.Put(key, decr, FileCacheEmbedExpiry) + return nil +} + +// IsExist check value is exist. +func (fc *FileCache) IsExist(key string) bool { + ret, _ := exists(fc.getCacheFileName(key)) + return ret +} + +// ClearAll will clean cached files. +// not implemented. +func (fc *FileCache) ClearAll() error { + return nil +} + +// check file exist. +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +// FileGetContents Get bytes to file. +// if non-exist, create this file. +func FileGetContents(filename string) (data []byte, e error) { + return ioutil.ReadFile(filename) +} + +// FilePutContents Put bytes to file. +// if non-exist, create this file. +func FilePutContents(filename string, content []byte) error { + return ioutil.WriteFile(filename, content, os.ModePerm) +} + +// GobEncode Gob encodes file cache item. +func GobEncode(data interface{}) ([]byte, error) { + buf := bytes.NewBuffer(nil) + enc := gob.NewEncoder(buf) + err := enc.Encode(data) + if err != nil { + return nil, err + } + return buf.Bytes(), err +} + +// GobDecode Gob decodes file cache item. +func GobDecode(data []byte, to *FileCacheItem) error { + buf := bytes.NewBuffer(data) + dec := gob.NewDecoder(buf) + return dec.Decode(&to) +} + +func init() { + Register("file", NewFileCache) +} diff --git a/app/utils/cache/cache/memory.go b/app/utils/cache/cache/memory.go new file mode 100644 index 0000000..0cc5015 --- /dev/null +++ b/app/utils/cache/cache/memory.go @@ -0,0 +1,239 @@ +package cache + +import ( + "encoding/json" + "errors" + "sync" + "time" +) + +var ( + // DefaultEvery means the clock time of recycling the expired cache items in memory. + DefaultEvery = 60 // 1 minute +) + +// MemoryItem store memory cache item. +type MemoryItem struct { + val interface{} + createdTime time.Time + lifespan time.Duration +} + +func (mi *MemoryItem) isExpire() bool { + // 0 means forever + if mi.lifespan == 0 { + return false + } + return time.Now().Sub(mi.createdTime) > mi.lifespan +} + +// MemoryCache is Memory cache adapter. +// it contains a RW locker for safe map storage. +type MemoryCache struct { + sync.RWMutex + dur time.Duration + items map[string]*MemoryItem + Every int // run an expiration check Every clock time +} + +// NewMemoryCache returns a new MemoryCache. +func NewMemoryCache() Cache { + cache := MemoryCache{items: make(map[string]*MemoryItem)} + return &cache +} + +// Get cache from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) Get(name string) interface{} { + bc.RLock() + defer bc.RUnlock() + if itm, ok := bc.items[name]; ok { + if itm.isExpire() { + return nil + } + return itm.val + } + return nil +} + +// GetMulti gets caches from memory. +// if non-existed or expired, return nil. +func (bc *MemoryCache) GetMulti(names []string) []interface{} { + var rc []interface{} + for _, name := range names { + rc = append(rc, bc.Get(name)) + } + return rc +} + +// Put cache to memory. +// if lifespan is 0, it will be forever till restart. +func (bc *MemoryCache) Put(name string, value interface{}, lifespan time.Duration) error { + bc.Lock() + defer bc.Unlock() + bc.items[name] = &MemoryItem{ + val: value, + createdTime: time.Now(), + lifespan: lifespan, + } + return nil +} + +// Delete cache in memory. +func (bc *MemoryCache) Delete(name string) error { + bc.Lock() + defer bc.Unlock() + if _, ok := bc.items[name]; !ok { + return errors.New("key not exist") + } + delete(bc.items, name) + if _, ok := bc.items[name]; ok { + return errors.New("delete key error") + } + return nil +} + +// Incr increase cache counter in memory. +// it supports int,int32,int64,uint,uint32,uint64. +func (bc *MemoryCache) Incr(key string) error { + bc.RLock() + defer bc.RUnlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch itm.val.(type) { + case int: + itm.val = itm.val.(int) + 1 + case int32: + itm.val = itm.val.(int32) + 1 + case int64: + itm.val = itm.val.(int64) + 1 + case uint: + itm.val = itm.val.(uint) + 1 + case uint32: + itm.val = itm.val.(uint32) + 1 + case uint64: + itm.val = itm.val.(uint64) + 1 + default: + return errors.New("item val is not (u)int (u)int32 (u)int64") + } + return nil +} + +// Decr decrease counter in memory. +func (bc *MemoryCache) Decr(key string) error { + bc.RLock() + defer bc.RUnlock() + itm, ok := bc.items[key] + if !ok { + return errors.New("key not exist") + } + switch itm.val.(type) { + case int: + itm.val = itm.val.(int) - 1 + case int64: + itm.val = itm.val.(int64) - 1 + case int32: + itm.val = itm.val.(int32) - 1 + case uint: + if itm.val.(uint) > 0 { + itm.val = itm.val.(uint) - 1 + } else { + return errors.New("item val is less than 0") + } + case uint32: + if itm.val.(uint32) > 0 { + itm.val = itm.val.(uint32) - 1 + } else { + return errors.New("item val is less than 0") + } + case uint64: + if itm.val.(uint64) > 0 { + itm.val = itm.val.(uint64) - 1 + } else { + return errors.New("item val is less than 0") + } + default: + return errors.New("item val is not int int64 int32") + } + return nil +} + +// IsExist check cache exist in memory. +func (bc *MemoryCache) IsExist(name string) bool { + bc.RLock() + defer bc.RUnlock() + if v, ok := bc.items[name]; ok { + return !v.isExpire() + } + return false +} + +// ClearAll will delete all cache in memory. +func (bc *MemoryCache) ClearAll() error { + bc.Lock() + defer bc.Unlock() + bc.items = make(map[string]*MemoryItem) + return nil +} + +// StartAndGC start memory cache. it will check expiration in every clock time. +func (bc *MemoryCache) StartAndGC(config string) error { + var cf map[string]int + json.Unmarshal([]byte(config), &cf) + if _, ok := cf["interval"]; !ok { + cf = make(map[string]int) + cf["interval"] = DefaultEvery + } + dur := time.Duration(cf["interval"]) * time.Second + bc.Every = cf["interval"] + bc.dur = dur + go bc.vacuum() + return nil +} + +// check expiration. +func (bc *MemoryCache) vacuum() { + bc.RLock() + every := bc.Every + bc.RUnlock() + + if every < 1 { + return + } + for { + <-time.After(bc.dur) + if bc.items == nil { + return + } + if keys := bc.expiredKeys(); len(keys) != 0 { + bc.clearItems(keys) + } + } +} + +// expiredKeys returns key list which are expired. +func (bc *MemoryCache) expiredKeys() (keys []string) { + bc.RLock() + defer bc.RUnlock() + for key, itm := range bc.items { + if itm.isExpire() { + keys = append(keys, key) + } + } + return +} + +// clearItems removes all the items which key in keys. +func (bc *MemoryCache) clearItems(keys []string) { + bc.Lock() + defer bc.Unlock() + for _, key := range keys { + delete(bc.items, key) + } +} + +func init() { + Register("memory", NewMemoryCache) +} diff --git a/app/utils/cache/redis.go b/app/utils/cache/redis.go new file mode 100644 index 0000000..de3be89 --- /dev/null +++ b/app/utils/cache/redis.go @@ -0,0 +1,409 @@ +package cache + +import ( + "encoding/json" + "errors" + "log" + "strings" + "time" + + redigo "github.com/gomodule/redigo/redis" +) + +// configuration +type Config struct { + Server string + Password string + MaxIdle int // Maximum number of idle connections in the pool. + + // Maximum number of connections allocated by the pool at a given time. + // When zero, there is no limit on the number of connections in the pool. + MaxActive int + + // Close connections after remaining idle for this duration. If the value + // is zero, then idle connections are not closed. Applications should set + // the timeout to a value less than the server's timeout. + IdleTimeout time.Duration + + // If Wait is true and the pool is at the MaxActive limit, then Get() waits + // for a connection to be returned to the pool before returning. + Wait bool + KeyPrefix string // prefix to all keys; example is "dev environment name" + KeyDelimiter string // delimiter to be used while appending keys; example is ":" + KeyPlaceholder string // placeholder to be parsed using given arguments to obtain a final key; example is "?" +} + +var pool *redigo.Pool +var conf *Config + +func NewRedis(addr string) { + if addr == "" { + panic("\nredis connect string cannot be empty\n") + } + pool = &redigo.Pool{ + MaxIdle: redisMaxIdleConn, + IdleTimeout: redisIdleTTL, + MaxActive: redisMaxActive, + // MaxConnLifetime: redisDialTTL, + Wait: true, + Dial: func() (redigo.Conn, error) { + c, err := redigo.Dial("tcp", addr, + redigo.DialPassword(redisPassword), + redigo.DialConnectTimeout(redisDialTTL), + redigo.DialReadTimeout(redisReadTTL), + redigo.DialWriteTimeout(redisWriteTTL), + ) + if err != nil { + log.Println("Redis Dial failed: ", err) + return nil, err + } + return c, err + }, + TestOnBorrow: func(c redigo.Conn, t time.Time) error { + _, err := c.Do("PING") + if err != nil { + log.Println("Unable to ping to redis server:", err) + } + return err + }, + } + conn := pool.Get() + defer conn.Close() + if conn.Err() != nil { + println("\nredis connect " + addr + " error: " + conn.Err().Error()) + } else { + println("\nredis connect " + addr + " success!\n") + } +} + +func Do(cmd string, args ...interface{}) (reply interface{}, err error) { + conn := pool.Get() + defer conn.Close() + return conn.Do(cmd, args...) +} + +func GetPool() *redigo.Pool { + return pool +} + +func ParseKey(key string, vars []string) (string, error) { + arr := strings.Split(key, conf.KeyPlaceholder) + actualKey := "" + if len(arr) != len(vars)+1 { + return "", errors.New("redis/connection.go: Insufficient arguments to parse key") + } else { + for index, val := range arr { + if index == 0 { + actualKey = arr[index] + } else { + actualKey += vars[index-1] + val + } + } + } + return getPrefixedKey(actualKey), nil +} + +func getPrefixedKey(key string) string { + return conf.KeyPrefix + conf.KeyDelimiter + key +} +func StripEnvKey(key string) string { + return strings.TrimLeft(key, conf.KeyPrefix+conf.KeyDelimiter) +} +func SplitKey(key string) []string { + return strings.Split(key, conf.KeyDelimiter) +} +func Expire(key string, ttl int) (interface{}, error) { + return Do("EXPIRE", key, ttl) +} +func Persist(key string) (interface{}, error) { + return Do("PERSIST", key) +} + +func Del(key string) (interface{}, error) { + return Do("DEL", key) +} +func Set(key string, data interface{}) (interface{}, error) { + // set + return Do("SET", key, data) +} +func SetNX(key string, data interface{}) (interface{}, error) { + return Do("SETNX", key, data) +} +func SetEx(key string, data interface{}, ttl int) (interface{}, error) { + return Do("SETEX", key, ttl, data) +} + +func SetJson(key string, data interface{}, ttl int) bool { + c, err := json.Marshal(data) + if err != nil { + return false + } + if ttl < 1 { + _, err = Set(key, c) + } else { + _, err = SetEx(key, c, ttl) + } + if err != nil { + return false + } + return true +} + +func GetJson(key string, dst interface{}) error { + b, err := GetBytes(key) + if err != nil { + return err + } + if err = json.Unmarshal(b, dst); err != nil { + return err + } + return nil +} + +func Get(key string) (interface{}, error) { + // get + return Do("GET", key) +} +func GetTTL(key string) (time.Duration, error) { + ttl, err := redigo.Int64(Do("TTL", key)) + return time.Duration(ttl) * time.Second, err +} +func GetBytes(key string) ([]byte, error) { + return redigo.Bytes(Do("GET", key)) +} +func GetString(key string) (string, error) { + return redigo.String(Do("GET", key)) +} +func GetStringMap(key string) (map[string]string, error) { + return redigo.StringMap(Do("GET", key)) +} +func GetInt(key string) (int, error) { + return redigo.Int(Do("GET", key)) +} +func GetInt64(key string) (int64, error) { + return redigo.Int64(Do("GET", key)) +} +func GetStringLength(key string) (int, error) { + return redigo.Int(Do("STRLEN", key)) +} +func ZAdd(key string, score float64, data interface{}) (interface{}, error) { + return Do("ZADD", key, score, data) +} +func ZAddNX(key string, score float64, data interface{}) (interface{}, error) { + return Do("ZADD", key, "NX", score, data) +} +func ZRem(key string, data interface{}) (interface{}, error) { + return Do("ZREM", key, data) +} +func ZRange(key string, start int, end int, withScores bool) ([]interface{}, error) { + if withScores { + return redigo.Values(Do("ZRANGE", key, start, end, "WITHSCORES")) + } + return redigo.Values(Do("ZRANGE", key, start, end)) +} +func ZRemRangeByScore(key string, start int64, end int64) ([]interface{}, error) { + return redigo.Values(Do("ZREMRANGEBYSCORE", key, start, end)) +} +func ZCard(setName string) (int64, error) { + return redigo.Int64(Do("ZCARD", setName)) +} +func ZScan(setName string) (int64, error) { + return redigo.Int64(Do("ZCARD", setName)) +} +func SAdd(setName string, data interface{}) (interface{}, error) { + return Do("SADD", setName, data) +} +func SCard(setName string) (int64, error) { + return redigo.Int64(Do("SCARD", setName)) +} +func SIsMember(setName string, data interface{}) (bool, error) { + return redigo.Bool(Do("SISMEMBER", setName, data)) +} +func SMembers(setName string) ([]string, error) { + return redigo.Strings(Do("SMEMBERS", setName)) +} +func SRem(setName string, data interface{}) (interface{}, error) { + return Do("SREM", setName, data) +} +func HSet(key string, HKey string, data interface{}) (interface{}, error) { + return Do("HSET", key, HKey, data) +} + +func HGet(key string, HKey string) (interface{}, error) { + return Do("HGET", key, HKey) +} + +func HMGet(key string, hashKeys ...string) ([]interface{}, error) { + ret, err := Do("HMGET", key, hashKeys) + if err != nil { + return nil, err + } + reta, ok := ret.([]interface{}) + if !ok { + return nil, errors.New("result not an array") + } + return reta, nil +} + +func HMSet(key string, hashKeys []string, vals []interface{}) (interface{}, error) { + if len(hashKeys) == 0 || len(hashKeys) != len(vals) { + var ret interface{} + return ret, errors.New("bad length") + } + input := []interface{}{key} + for i, v := range hashKeys { + input = append(input, v, vals[i]) + } + return Do("HMSET", input...) +} + +func HGetString(key string, HKey string) (string, error) { + return redigo.String(Do("HGET", key, HKey)) +} +func HGetFloat(key string, HKey string) (float64, error) { + f, err := redigo.Float64(Do("HGET", key, HKey)) + return f, err +} +func HGetInt(key string, HKey string) (int, error) { + return redigo.Int(Do("HGET", key, HKey)) +} +func HGetInt64(key string, HKey string) (int64, error) { + return redigo.Int64(Do("HGET", key, HKey)) +} +func HGetBool(key string, HKey string) (bool, error) { + return redigo.Bool(Do("HGET", key, HKey)) +} +func HDel(key string, HKey string) (interface{}, error) { + return Do("HDEL", key, HKey) +} + +func HGetAll(key string) (map[string]interface{}, error) { + vals, err := redigo.Values(Do("HGETALL", key)) + if err != nil { + return nil, err + } + num := len(vals) / 2 + result := make(map[string]interface{}, num) + for i := 0; i < num; i++ { + key, _ := redigo.String(vals[2*i], nil) + result[key] = vals[2*i+1] + } + return result, nil +} + +func FlushAll() bool { + res, _ := redigo.String(Do("FLUSHALL")) + if res == "" { + return false + } + return true +} + +// NOTE: Use this in production environment with extreme care. +// Read more here:https://redigo.io/commands/keys +func Keys(pattern string) ([]string, error) { + return redigo.Strings(Do("KEYS", pattern)) +} + +func HKeys(key string) ([]string, error) { + return redigo.Strings(Do("HKEYS", key)) +} + +func Exists(key string) bool { + count, err := redigo.Int(Do("EXISTS", key)) + if count == 0 || err != nil { + return false + } + return true +} + +func Incr(key string) (int64, error) { + return redigo.Int64(Do("INCR", key)) +} + +func Decr(key string) (int64, error) { + return redigo.Int64(Do("DECR", key)) +} + +func IncrBy(key string, incBy int64) (int64, error) { + return redigo.Int64(Do("INCRBY", key, incBy)) +} + +func DecrBy(key string, decrBy int64) (int64, error) { + return redigo.Int64(Do("DECRBY", key)) +} + +func IncrByFloat(key string, incBy float64) (float64, error) { + return redigo.Float64(Do("INCRBYFLOAT", key, incBy)) +} + +func DecrByFloat(key string, decrBy float64) (float64, error) { + return redigo.Float64(Do("DECRBYFLOAT", key, decrBy)) +} + +// use for message queue +func LPush(key string, data interface{}) (interface{}, error) { + // set + return Do("LPUSH", key, data) +} + +func LPop(key string) (interface{}, error) { + return Do("LPOP", key) +} + +func LPopString(key string) (string, error) { + return redigo.String(Do("LPOP", key)) +} +func LPopFloat(key string) (float64, error) { + f, err := redigo.Float64(Do("LPOP", key)) + return f, err +} +func LPopInt(key string) (int, error) { + return redigo.Int(Do("LPOP", key)) +} +func LPopInt64(key string) (int64, error) { + return redigo.Int64(Do("LPOP", key)) +} + +func RPush(key string, data interface{}) (interface{}, error) { + // set + return Do("RPUSH", key, data) +} + +func RPop(key string) (interface{}, error) { + return Do("RPOP", key) +} + +func RPopString(key string) (string, error) { + return redigo.String(Do("RPOP", key)) +} +func RPopFloat(key string) (float64, error) { + f, err := redigo.Float64(Do("RPOP", key)) + return f, err +} +func RPopInt(key string) (int, error) { + return redigo.Int(Do("RPOP", key)) +} +func RPopInt64(key string) (int64, error) { + return redigo.Int64(Do("RPOP", key)) +} + +func Scan(cursor int64, pattern string, count int64) (int64, []string, error) { + var items []string + var newCursor int64 + + values, err := redigo.Values(Do("SCAN", cursor, "MATCH", pattern, "COUNT", count)) + if err != nil { + return 0, nil, err + } + values, err = redigo.Scan(values, &newCursor, &items) + if err != nil { + return 0, nil, err + } + return newCursor, items, nil +} + +func LPushMax(key string, data ...interface{}) (interface{}, error) { + // set + return Do("LPUSH", key, data) +} diff --git a/app/utils/cache/redis_cluster.go b/app/utils/cache/redis_cluster.go new file mode 100644 index 0000000..901f30c --- /dev/null +++ b/app/utils/cache/redis_cluster.go @@ -0,0 +1,622 @@ +package cache + +import ( + "strconv" + "time" + + "github.com/go-redis/redis" +) + +var pools *redis.ClusterClient + +func NewRedisCluster(addrs []string) error { + opt := &redis.ClusterOptions{ + Addrs: addrs, + PoolSize: redisPoolSize, + PoolTimeout: redisPoolTTL, + IdleTimeout: redisIdleTTL, + DialTimeout: redisDialTTL, + ReadTimeout: redisReadTTL, + WriteTimeout: redisWriteTTL, + } + pools = redis.NewClusterClient(opt) + if err := pools.Ping().Err(); err != nil { + return err + } + return nil +} + +func RCGet(key string) (interface{}, error) { + res, err := pools.Get(key).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func RCSet(key string, value interface{}) error { + err := pools.Set(key, value, 0).Err() + return convertError(err) +} +func RCGetSet(key string, value interface{}) (interface{}, error) { + res, err := pools.GetSet(key, value).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func RCSetNx(key string, value interface{}) (int64, error) { + res, err := pools.SetNX(key, value, 0).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func RCSetEx(key string, value interface{}, timeout int64) error { + _, err := pools.Set(key, value, time.Duration(timeout)*time.Second).Result() + if err != nil { + return convertError(err) + } + return nil +} + +// nil表示成功,ErrNil表示数据库内已经存在这个key,其他表示数据库发生错误 +func RCSetNxEx(key string, value interface{}, timeout int64) error { + res, err := pools.SetNX(key, value, time.Duration(timeout)*time.Second).Result() + if err != nil { + return convertError(err) + } + if res { + return nil + } + return ErrNil +} +func RCMGet(keys ...string) ([]interface{}, error) { + res, err := pools.MGet(keys...).Result() + return res, convertError(err) +} + +// 为确保多个key映射到同一个slot,每个key最好加上hash tag,如:{test} +func RCMSet(kvs map[string]interface{}) error { + pairs := make([]string, 0, len(kvs)*2) + for k, v := range kvs { + val, err := String(v, nil) + if err != nil { + return err + } + pairs = append(pairs, k, val) + } + return convertError(pools.MSet(pairs).Err()) +} + +// 为确保多个key映射到同一个slot,每个key最好加上hash tag,如:{test} +func RCMSetNX(kvs map[string]interface{}) (bool, error) { + pairs := make([]string, 0, len(kvs)*2) + for k, v := range kvs { + val, err := String(v, nil) + if err != nil { + return false, err + } + pairs = append(pairs, k, val) + } + res, err := pools.MSetNX(pairs).Result() + return res, convertError(err) +} +func RCExpireAt(key string, timestamp int64) (int64, error) { + res, err := pools.ExpireAt(key, time.Unix(timestamp, 0)).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func RCDel(keys ...string) (int64, error) { + args := make([]interface{}, 0, len(keys)) + for _, key := range keys { + args = append(args, key) + } + res, err := pools.Del(keys...).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCIncr(key string) (int64, error) { + res, err := pools.Incr(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCIncrBy(key string, delta int64) (int64, error) { + res, err := pools.IncrBy(key, delta).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCExpire(key string, duration int64) (int64, error) { + res, err := pools.Expire(key, time.Duration(duration)*time.Second).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func RCExists(key string) (bool, error) { + res, err := pools.Exists(key).Result() + if err != nil { + return false, convertError(err) + } + if res > 0 { + return true, nil + } + return false, nil +} +func RCHGet(key string, field string) (interface{}, error) { + res, err := pools.HGet(key, field).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func RCHLen(key string) (int64, error) { + res, err := pools.HLen(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCHSet(key string, field string, val interface{}) error { + value, err := String(val, nil) + if err != nil && err != ErrNil { + return err + } + _, err = pools.HSet(key, field, value).Result() + if err != nil { + return convertError(err) + } + return nil +} +func RCHDel(key string, fields ...string) (int64, error) { + args := make([]interface{}, 0, len(fields)+1) + args = append(args, key) + for _, field := range fields { + args = append(args, field) + } + res, err := pools.HDel(key, fields...).Result() + if err != nil { + return 0, convertError(err) + } + return res, nil +} + +func RCHMGet(key string, fields ...string) (interface{}, error) { + args := make([]interface{}, 0, len(fields)+1) + args = append(args, key) + for _, field := range fields { + args = append(args, field) + } + if len(fields) == 0 { + return nil, ErrNil + } + res, err := pools.HMGet(key, fields...).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCHMSet(key string, kvs ...interface{}) error { + if len(kvs) == 0 { + return nil + } + if len(kvs)%2 != 0 { + return ErrWrongArgsNum + } + var err error + v := map[string]interface{}{} // todo change + v["field"], err = String(kvs[0], nil) + if err != nil && err != ErrNil { + return err + } + v["value"], err = String(kvs[1], nil) + if err != nil && err != ErrNil { + return err + } + pairs := make([]string, 0, len(kvs)-2) + if len(kvs) > 2 { + for _, kv := range kvs[2:] { + kvString, err := String(kv, nil) + if err != nil && err != ErrNil { + return err + } + pairs = append(pairs, kvString) + } + } + v["paris"] = pairs + _, err = pools.HMSet(key, v).Result() + if err != nil { + return convertError(err) + } + return nil +} + +func RCHKeys(key string) ([]string, error) { + res, err := pools.HKeys(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCHVals(key string) ([]interface{}, error) { + res, err := pools.HVals(key).Result() + if err != nil { + return nil, convertError(err) + } + rs := make([]interface{}, 0, len(res)) + for _, res := range res { + rs = append(rs, res) + } + return rs, nil +} +func RCHGetAll(key string) (map[string]string, error) { + vals, err := pools.HGetAll(key).Result() + if err != nil { + return nil, convertError(err) + } + return vals, nil +} +func RCHIncrBy(key, field string, delta int64) (int64, error) { + res, err := pools.HIncrBy(key, field, delta).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCZAdd(key string, kvs ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(kvs)+1) + args = append(args, key) + args = append(args, kvs...) + if len(kvs) == 0 { + return 0, nil + } + if len(kvs)%2 != 0 { + return 0, ErrWrongArgsNum + } + zs := make([]redis.Z, len(kvs)/2) + for i := 0; i < len(kvs); i += 2 { + idx := i / 2 + score, err := Float64(kvs[i], nil) + if err != nil && err != ErrNil { + return 0, err + } + zs[idx].Score = score + zs[idx].Member = kvs[i+1] + } + res, err := pools.ZAdd(key, zs...).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCZRem(key string, members ...string) (int64, error) { + args := make([]interface{}, 0, len(members)) + args = append(args, key) + for _, member := range members { + args = append(args, member) + } + res, err := pools.ZRem(key, members).Result() + if err != nil { + return res, convertError(err) + } + return res, err +} + +func RCZRange(key string, min, max int64, withScores bool) (interface{}, error) { + res := make([]interface{}, 0) + if withScores { + zs, err := pools.ZRangeWithScores(key, min, max).Result() + if err != nil { + return nil, convertError(err) + } + for _, z := range zs { + res = append(res, z.Member, strconv.FormatFloat(z.Score, 'f', -1, 64)) + } + } else { + ms, err := pools.ZRange(key, min, max).Result() + if err != nil { + return nil, convertError(err) + } + for _, m := range ms { + res = append(res, m) + } + } + return res, nil +} +func RCZRangeByScoreWithScore(key string, min, max int64) (map[string]int64, error) { + opt := new(redis.ZRangeBy) + opt.Min = strconv.FormatInt(int64(min), 10) + opt.Max = strconv.FormatInt(int64(max), 10) + opt.Count = -1 + opt.Offset = 0 + vals, err := pools.ZRangeByScoreWithScores(key, *opt).Result() + if err != nil { + return nil, convertError(err) + } + res := make(map[string]int64, len(vals)) + for _, val := range vals { + key, err := String(val.Member, nil) + if err != nil && err != ErrNil { + return nil, err + } + res[key] = int64(val.Score) + } + return res, nil +} +func RCLRange(key string, start, stop int64) (interface{}, error) { + res, err := pools.LRange(key, start, stop).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCLSet(key string, index int, value interface{}) error { + err := pools.LSet(key, int64(index), value).Err() + return convertError(err) +} +func RCLLen(key string) (int64, error) { + res, err := pools.LLen(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCLRem(key string, count int, value interface{}) (int, error) { + val, _ := value.(string) + res, err := pools.LRem(key, int64(count), val).Result() + if err != nil { + return int(res), convertError(err) + } + return int(res), nil +} +func RCTTl(key string) (int64, error) { + duration, err := pools.TTL(key).Result() + if err != nil { + return int64(duration.Seconds()), convertError(err) + } + return int64(duration.Seconds()), nil +} +func RCLPop(key string) (interface{}, error) { + res, err := pools.LPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCRPop(key string) (interface{}, error) { + res, err := pools.RPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCBLPop(key string, timeout int) (interface{}, error) { + res, err := pools.BLPop(time.Duration(timeout)*time.Second, key).Result() + if err != nil { + // 兼容redis 2.x + if err == redis.Nil { + return nil, ErrNil + } + return nil, err + } + return res[1], nil +} +func RCBRPop(key string, timeout int) (interface{}, error) { + res, err := pools.BRPop(time.Duration(timeout)*time.Second, key).Result() + if err != nil { + // 兼容redis 2.x + if err == redis.Nil { + return nil, ErrNil + } + return nil, convertError(err) + } + return res[1], nil +} +func RCLPush(key string, value ...interface{}) error { + args := make([]interface{}, 0, len(value)+1) + args = append(args, key) + args = append(args, value...) + vals := make([]string, 0, len(value)) + for _, v := range value { + val, err := String(v, nil) + if err != nil && err != ErrNil { + return err + } + vals = append(vals, val) + } + _, err := pools.LPush(key, vals).Result() // todo ... + if err != nil { + return convertError(err) + } + return nil +} +func RCRPush(key string, value ...interface{}) error { + args := make([]interface{}, 0, len(value)+1) + args = append(args, key) + args = append(args, value...) + vals := make([]string, 0, len(value)) + for _, v := range value { + val, err := String(v, nil) + if err != nil && err != ErrNil { + if err == ErrNil { + continue + } + return err + } + if val == "" { + continue + } + vals = append(vals, val) + } + _, err := pools.RPush(key, vals).Result() // todo ... + if err != nil { + return convertError(err) + } + return nil +} + +// 为确保srcKey跟destKey映射到同一个slot,srcKey和destKey需要加上hash tag,如:{test} +func RCBRPopLPush(srcKey string, destKey string, timeout int) (interface{}, error) { + res, err := pools.BRPopLPush(srcKey, destKey, time.Duration(timeout)*time.Second).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} + +// 为确保srcKey跟destKey映射到同一个slot,srcKey和destKey需要加上hash tag,如:{test} +func RCRPopLPush(srcKey string, destKey string) (interface{}, error) { + res, err := pools.RPopLPush(srcKey, destKey).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCSAdd(key string, members ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(members)+1) + args = append(args, key) + args = append(args, members...) + ms := make([]string, 0, len(members)) + for _, member := range members { + m, err := String(member, nil) + if err != nil && err != ErrNil { + return 0, err + } + ms = append(ms, m) + } + res, err := pools.SAdd(key, ms).Result() // todo ... + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCSPop(key string) ([]byte, error) { + res, err := pools.SPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func RCSIsMember(key string, member interface{}) (bool, error) { + m, _ := member.(string) + res, err := pools.SIsMember(key, m).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCSRem(key string, members ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(members)+1) + args = append(args, key) + args = append(args, members...) + ms := make([]string, 0, len(members)) + for _, member := range members { + m, err := String(member, nil) + if err != nil && err != ErrNil { + return 0, err + } + ms = append(ms, m) + } + res, err := pools.SRem(key, ms).Result() // todo ... + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCSMembers(key string) ([]string, error) { + res, err := pools.SMembers(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCScriptLoad(luaScript string) (interface{}, error) { + res, err := pools.ScriptLoad(luaScript).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCEvalSha(sha1 string, numberKeys int, keysArgs ...interface{}) (interface{}, error) { + vals := make([]interface{}, 0, len(keysArgs)+2) + vals = append(vals, sha1, numberKeys) + vals = append(vals, keysArgs...) + keys := make([]string, 0, numberKeys) + args := make([]string, 0, len(keysArgs)-numberKeys) + for i, value := range keysArgs { + val, err := String(value, nil) + if err != nil && err != ErrNil { + return nil, err + } + if i < numberKeys { + keys = append(keys, val) + } else { + args = append(args, val) + } + } + res, err := pools.EvalSha(sha1, keys, args).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCEval(luaScript string, numberKeys int, keysArgs ...interface{}) (interface{}, error) { + vals := make([]interface{}, 0, len(keysArgs)+2) + vals = append(vals, luaScript, numberKeys) + vals = append(vals, keysArgs...) + keys := make([]string, 0, numberKeys) + args := make([]string, 0, len(keysArgs)-numberKeys) + for i, value := range keysArgs { + val, err := String(value, nil) + if err != nil && err != ErrNil { + return nil, err + } + if i < numberKeys { + keys = append(keys, val) + } else { + args = append(args, val) + } + } + res, err := pools.Eval(luaScript, keys, args).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func RCGetBit(key string, offset int64) (int64, error) { + res, err := pools.GetBit(key, offset).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func RCSetBit(key string, offset uint32, value int) (int, error) { + res, err := pools.SetBit(key, int64(offset), value).Result() + return int(res), convertError(err) +} +func RCGetClient() *redis.ClusterClient { + return pools +} +func convertError(err error) error { + if err == redis.Nil { + // 为了兼容redis 2.x,这里不返回 ErrNil,ErrNil在调用redis_cluster_reply函数时才返回 + return nil + } + return err +} diff --git a/app/utils/cache/redis_pool.go b/app/utils/cache/redis_pool.go new file mode 100644 index 0000000..ca38b3f --- /dev/null +++ b/app/utils/cache/redis_pool.go @@ -0,0 +1,324 @@ +package cache + +import ( + "errors" + "log" + "strings" + "time" + + redigo "github.com/gomodule/redigo/redis" +) + +type RedisPool struct { + *redigo.Pool +} + +func NewRedisPool(cfg *Config) *RedisPool { + return &RedisPool{&redigo.Pool{ + MaxIdle: cfg.MaxIdle, + IdleTimeout: cfg.IdleTimeout, + MaxActive: cfg.MaxActive, + Wait: cfg.Wait, + Dial: func() (redigo.Conn, error) { + c, err := redigo.Dial("tcp", cfg.Server) + if err != nil { + log.Println("Redis Dial failed: ", err) + return nil, err + } + if cfg.Password != "" { + if _, err := c.Do("AUTH", cfg.Password); err != nil { + c.Close() + log.Println("Redis AUTH failed: ", err) + return nil, err + } + } + return c, err + }, + TestOnBorrow: func(c redigo.Conn, t time.Time) error { + _, err := c.Do("PING") + if err != nil { + log.Println("Unable to ping to redis server:", err) + } + return err + }, + }} +} + +func (p *RedisPool) Do(cmd string, args ...interface{}) (reply interface{}, err error) { + conn := pool.Get() + defer conn.Close() + return conn.Do(cmd, args...) +} + +func (p *RedisPool) GetPool() *redigo.Pool { + return pool +} + +func (p *RedisPool) ParseKey(key string, vars []string) (string, error) { + arr := strings.Split(key, conf.KeyPlaceholder) + actualKey := "" + if len(arr) != len(vars)+1 { + return "", errors.New("redis/connection.go: Insufficient arguments to parse key") + } else { + for index, val := range arr { + if index == 0 { + actualKey = arr[index] + } else { + actualKey += vars[index-1] + val + } + } + } + return getPrefixedKey(actualKey), nil +} + +func (p *RedisPool) getPrefixedKey(key string) string { + return conf.KeyPrefix + conf.KeyDelimiter + key +} +func (p *RedisPool) StripEnvKey(key string) string { + return strings.TrimLeft(key, conf.KeyPrefix+conf.KeyDelimiter) +} +func (p *RedisPool) SplitKey(key string) []string { + return strings.Split(key, conf.KeyDelimiter) +} +func (p *RedisPool) Expire(key string, ttl int) (interface{}, error) { + return Do("EXPIRE", key, ttl) +} +func (p *RedisPool) Persist(key string) (interface{}, error) { + return Do("PERSIST", key) +} + +func (p *RedisPool) Del(key string) (interface{}, error) { + return Do("DEL", key) +} +func (p *RedisPool) Set(key string, data interface{}) (interface{}, error) { + // set + return Do("SET", key, data) +} +func (p *RedisPool) SetNX(key string, data interface{}) (interface{}, error) { + return Do("SETNX", key, data) +} +func (p *RedisPool) SetEx(key string, data interface{}, ttl int) (interface{}, error) { + return Do("SETEX", key, ttl, data) +} +func (p *RedisPool) Get(key string) (interface{}, error) { + // get + return Do("GET", key) +} +func (p *RedisPool) GetStringMap(key string) (map[string]string, error) { + // get + return redigo.StringMap(Do("GET", key)) +} + +func (p *RedisPool) GetTTL(key string) (time.Duration, error) { + ttl, err := redigo.Int64(Do("TTL", key)) + return time.Duration(ttl) * time.Second, err +} +func (p *RedisPool) GetBytes(key string) ([]byte, error) { + return redigo.Bytes(Do("GET", key)) +} +func (p *RedisPool) GetString(key string) (string, error) { + return redigo.String(Do("GET", key)) +} +func (p *RedisPool) GetInt(key string) (int, error) { + return redigo.Int(Do("GET", key)) +} +func (p *RedisPool) GetStringLength(key string) (int, error) { + return redigo.Int(Do("STRLEN", key)) +} +func (p *RedisPool) ZAdd(key string, score float64, data interface{}) (interface{}, error) { + return Do("ZADD", key, score, data) +} +func (p *RedisPool) ZRem(key string, data interface{}) (interface{}, error) { + return Do("ZREM", key, data) +} +func (p *RedisPool) ZRange(key string, start int, end int, withScores bool) ([]interface{}, error) { + if withScores { + return redigo.Values(Do("ZRANGE", key, start, end, "WITHSCORES")) + } + return redigo.Values(Do("ZRANGE", key, start, end)) +} +func (p *RedisPool) SAdd(setName string, data interface{}) (interface{}, error) { + return Do("SADD", setName, data) +} +func (p *RedisPool) SCard(setName string) (int64, error) { + return redigo.Int64(Do("SCARD", setName)) +} +func (p *RedisPool) SIsMember(setName string, data interface{}) (bool, error) { + return redigo.Bool(Do("SISMEMBER", setName, data)) +} +func (p *RedisPool) SMembers(setName string) ([]string, error) { + return redigo.Strings(Do("SMEMBERS", setName)) +} +func (p *RedisPool) SRem(setName string, data interface{}) (interface{}, error) { + return Do("SREM", setName, data) +} +func (p *RedisPool) HSet(key string, HKey string, data interface{}) (interface{}, error) { + return Do("HSET", key, HKey, data) +} + +func (p *RedisPool) HGet(key string, HKey string) (interface{}, error) { + return Do("HGET", key, HKey) +} + +func (p *RedisPool) HMGet(key string, hashKeys ...string) ([]interface{}, error) { + ret, err := Do("HMGET", key, hashKeys) + if err != nil { + return nil, err + } + reta, ok := ret.([]interface{}) + if !ok { + return nil, errors.New("result not an array") + } + return reta, nil +} + +func (p *RedisPool) HMSet(key string, hashKeys []string, vals []interface{}) (interface{}, error) { + if len(hashKeys) == 0 || len(hashKeys) != len(vals) { + var ret interface{} + return ret, errors.New("bad length") + } + input := []interface{}{key} + for i, v := range hashKeys { + input = append(input, v, vals[i]) + } + return Do("HMSET", input...) +} + +func (p *RedisPool) HGetString(key string, HKey string) (string, error) { + return redigo.String(Do("HGET", key, HKey)) +} +func (p *RedisPool) HGetFloat(key string, HKey string) (float64, error) { + f, err := redigo.Float64(Do("HGET", key, HKey)) + return float64(f), err +} +func (p *RedisPool) HGetInt(key string, HKey string) (int, error) { + return redigo.Int(Do("HGET", key, HKey)) +} +func (p *RedisPool) HGetInt64(key string, HKey string) (int64, error) { + return redigo.Int64(Do("HGET", key, HKey)) +} +func (p *RedisPool) HGetBool(key string, HKey string) (bool, error) { + return redigo.Bool(Do("HGET", key, HKey)) +} +func (p *RedisPool) HDel(key string, HKey string) (interface{}, error) { + return Do("HDEL", key, HKey) +} +func (p *RedisPool) HGetAll(key string) (map[string]interface{}, error) { + vals, err := redigo.Values(Do("HGETALL", key)) + if err != nil { + return nil, err + } + num := len(vals) / 2 + result := make(map[string]interface{}, num) + for i := 0; i < num; i++ { + key, _ := redigo.String(vals[2*i], nil) + result[key] = vals[2*i+1] + } + return result, nil +} + +// NOTE: Use this in production environment with extreme care. +// Read more here:https://redigo.io/commands/keys +func (p *RedisPool) Keys(pattern string) ([]string, error) { + return redigo.Strings(Do("KEYS", pattern)) +} + +func (p *RedisPool) HKeys(key string) ([]string, error) { + return redigo.Strings(Do("HKEYS", key)) +} + +func (p *RedisPool) Exists(key string) (bool, error) { + count, err := redigo.Int(Do("EXISTS", key)) + if count == 0 { + return false, err + } else { + return true, err + } +} + +func (p *RedisPool) Incr(key string) (int64, error) { + return redigo.Int64(Do("INCR", key)) +} + +func (p *RedisPool) Decr(key string) (int64, error) { + return redigo.Int64(Do("DECR", key)) +} + +func (p *RedisPool) IncrBy(key string, incBy int64) (int64, error) { + return redigo.Int64(Do("INCRBY", key, incBy)) +} + +func (p *RedisPool) DecrBy(key string, decrBy int64) (int64, error) { + return redigo.Int64(Do("DECRBY", key)) +} + +func (p *RedisPool) IncrByFloat(key string, incBy float64) (float64, error) { + return redigo.Float64(Do("INCRBYFLOAT", key, incBy)) +} + +func (p *RedisPool) DecrByFloat(key string, decrBy float64) (float64, error) { + return redigo.Float64(Do("DECRBYFLOAT", key, decrBy)) +} + +// use for message queue +func (p *RedisPool) LPush(key string, data interface{}) (interface{}, error) { + // set + return Do("LPUSH", key, data) +} + +func (p *RedisPool) LPop(key string) (interface{}, error) { + return Do("LPOP", key) +} + +func (p *RedisPool) LPopString(key string) (string, error) { + return redigo.String(Do("LPOP", key)) +} +func (p *RedisPool) LPopFloat(key string) (float64, error) { + f, err := redigo.Float64(Do("LPOP", key)) + return float64(f), err +} +func (p *RedisPool) LPopInt(key string) (int, error) { + return redigo.Int(Do("LPOP", key)) +} +func (p *RedisPool) LPopInt64(key string) (int64, error) { + return redigo.Int64(Do("LPOP", key)) +} + +func (p *RedisPool) RPush(key string, data interface{}) (interface{}, error) { + // set + return Do("RPUSH", key, data) +} + +func (p *RedisPool) RPop(key string) (interface{}, error) { + return Do("RPOP", key) +} + +func (p *RedisPool) RPopString(key string) (string, error) { + return redigo.String(Do("RPOP", key)) +} +func (p *RedisPool) RPopFloat(key string) (float64, error) { + f, err := redigo.Float64(Do("RPOP", key)) + return float64(f), err +} +func (p *RedisPool) RPopInt(key string) (int, error) { + return redigo.Int(Do("RPOP", key)) +} +func (p *RedisPool) RPopInt64(key string) (int64, error) { + return redigo.Int64(Do("RPOP", key)) +} + +func (p *RedisPool) Scan(cursor int64, pattern string, count int64) (int64, []string, error) { + var items []string + var newCursor int64 + + values, err := redigo.Values(Do("SCAN", cursor, "MATCH", pattern, "COUNT", count)) + if err != nil { + return 0, nil, err + } + values, err = redigo.Scan(values, &newCursor, &items) + if err != nil { + return 0, nil, err + } + + return newCursor, items, nil +} diff --git a/app/utils/cache/redis_pool_cluster.go b/app/utils/cache/redis_pool_cluster.go new file mode 100644 index 0000000..cd1911b --- /dev/null +++ b/app/utils/cache/redis_pool_cluster.go @@ -0,0 +1,617 @@ +package cache + +import ( + "strconv" + "time" + + "github.com/go-redis/redis" +) + +type RedisClusterPool struct { + client *redis.ClusterClient +} + +func NewRedisClusterPool(addrs []string) (*RedisClusterPool, error) { + opt := &redis.ClusterOptions{ + Addrs: addrs, + PoolSize: 512, + PoolTimeout: 10 * time.Second, + IdleTimeout: 10 * time.Second, + DialTimeout: 10 * time.Second, + ReadTimeout: 3 * time.Second, + WriteTimeout: 3 * time.Second, + } + c := redis.NewClusterClient(opt) + if err := c.Ping().Err(); err != nil { + return nil, err + } + return &RedisClusterPool{client: c}, nil +} + +func (p *RedisClusterPool) Get(key string) (interface{}, error) { + res, err := p.client.Get(key).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func (p *RedisClusterPool) Set(key string, value interface{}) error { + err := p.client.Set(key, value, 0).Err() + return convertError(err) +} +func (p *RedisClusterPool) GetSet(key string, value interface{}) (interface{}, error) { + res, err := p.client.GetSet(key, value).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func (p *RedisClusterPool) SetNx(key string, value interface{}) (int64, error) { + res, err := p.client.SetNX(key, value, 0).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func (p *RedisClusterPool) SetEx(key string, value interface{}, timeout int64) error { + _, err := p.client.Set(key, value, time.Duration(timeout)*time.Second).Result() + if err != nil { + return convertError(err) + } + return nil +} + +// nil表示成功,ErrNil表示数据库内已经存在这个key,其他表示数据库发生错误 +func (p *RedisClusterPool) SetNxEx(key string, value interface{}, timeout int64) error { + res, err := p.client.SetNX(key, value, time.Duration(timeout)*time.Second).Result() + if err != nil { + return convertError(err) + } + if res { + return nil + } + return ErrNil +} +func (p *RedisClusterPool) MGet(keys ...string) ([]interface{}, error) { + res, err := p.client.MGet(keys...).Result() + return res, convertError(err) +} + +// 为确保多个key映射到同一个slot,每个key最好加上hash tag,如:{test} +func (p *RedisClusterPool) MSet(kvs map[string]interface{}) error { + pairs := make([]string, 0, len(kvs)*2) + for k, v := range kvs { + val, err := String(v, nil) + if err != nil { + return err + } + pairs = append(pairs, k, val) + } + return convertError(p.client.MSet(pairs).Err()) +} + +// 为确保多个key映射到同一个slot,每个key最好加上hash tag,如:{test} +func (p *RedisClusterPool) MSetNX(kvs map[string]interface{}) (bool, error) { + pairs := make([]string, 0, len(kvs)*2) + for k, v := range kvs { + val, err := String(v, nil) + if err != nil { + return false, err + } + pairs = append(pairs, k, val) + } + res, err := p.client.MSetNX(pairs).Result() + return res, convertError(err) +} +func (p *RedisClusterPool) ExpireAt(key string, timestamp int64) (int64, error) { + res, err := p.client.ExpireAt(key, time.Unix(timestamp, 0)).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func (p *RedisClusterPool) Del(keys ...string) (int64, error) { + args := make([]interface{}, 0, len(keys)) + for _, key := range keys { + args = append(args, key) + } + res, err := p.client.Del(keys...).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) Incr(key string) (int64, error) { + res, err := p.client.Incr(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) IncrBy(key string, delta int64) (int64, error) { + res, err := p.client.IncrBy(key, delta).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) Expire(key string, duration int64) (int64, error) { + res, err := p.client.Expire(key, time.Duration(duration)*time.Second).Result() + if err != nil { + return 0, convertError(err) + } + if res { + return 1, nil + } + return 0, nil +} +func (p *RedisClusterPool) Exists(key string) (bool, error) { // todo (bool, error) + res, err := p.client.Exists(key).Result() + if err != nil { + return false, convertError(err) + } + if res > 0 { + return true, nil + } + return false, nil +} +func (p *RedisClusterPool) HGet(key string, field string) (interface{}, error) { + res, err := p.client.HGet(key, field).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func (p *RedisClusterPool) HLen(key string) (int64, error) { + res, err := p.client.HLen(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) HSet(key string, field string, val interface{}) error { + value, err := String(val, nil) + if err != nil && err != ErrNil { + return err + } + _, err = p.client.HSet(key, field, value).Result() + if err != nil { + return convertError(err) + } + return nil +} +func (p *RedisClusterPool) HDel(key string, fields ...string) (int64, error) { + args := make([]interface{}, 0, len(fields)+1) + args = append(args, key) + for _, field := range fields { + args = append(args, field) + } + res, err := p.client.HDel(key, fields...).Result() + if err != nil { + return 0, convertError(err) + } + return res, nil +} + +func (p *RedisClusterPool) HMGet(key string, fields ...string) (interface{}, error) { + args := make([]interface{}, 0, len(fields)+1) + args = append(args, key) + for _, field := range fields { + args = append(args, field) + } + if len(fields) == 0 { + return nil, ErrNil + } + res, err := p.client.HMGet(key, fields...).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) HMSet(key string, kvs ...interface{}) error { + if len(kvs) == 0 { + return nil + } + if len(kvs)%2 != 0 { + return ErrWrongArgsNum + } + var err error + v := map[string]interface{}{} // todo change + v["field"], err = String(kvs[0], nil) + if err != nil && err != ErrNil { + return err + } + v["value"], err = String(kvs[1], nil) + if err != nil && err != ErrNil { + return err + } + pairs := make([]string, 0, len(kvs)-2) + if len(kvs) > 2 { + for _, kv := range kvs[2:] { + kvString, err := String(kv, nil) + if err != nil && err != ErrNil { + return err + } + pairs = append(pairs, kvString) + } + } + v["paris"] = pairs + _, err = p.client.HMSet(key, v).Result() + if err != nil { + return convertError(err) + } + return nil +} + +func (p *RedisClusterPool) HKeys(key string) ([]string, error) { + res, err := p.client.HKeys(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) HVals(key string) ([]interface{}, error) { + res, err := p.client.HVals(key).Result() + if err != nil { + return nil, convertError(err) + } + rs := make([]interface{}, 0, len(res)) + for _, res := range res { + rs = append(rs, res) + } + return rs, nil +} +func (p *RedisClusterPool) HGetAll(key string) (map[string]string, error) { + vals, err := p.client.HGetAll(key).Result() + if err != nil { + return nil, convertError(err) + } + return vals, nil +} +func (p *RedisClusterPool) HIncrBy(key, field string, delta int64) (int64, error) { + res, err := p.client.HIncrBy(key, field, delta).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) ZAdd(key string, kvs ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(kvs)+1) + args = append(args, key) + args = append(args, kvs...) + if len(kvs) == 0 { + return 0, nil + } + if len(kvs)%2 != 0 { + return 0, ErrWrongArgsNum + } + zs := make([]redis.Z, len(kvs)/2) + for i := 0; i < len(kvs); i += 2 { + idx := i / 2 + score, err := Float64(kvs[i], nil) + if err != nil && err != ErrNil { + return 0, err + } + zs[idx].Score = score + zs[idx].Member = kvs[i+1] + } + res, err := p.client.ZAdd(key, zs...).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) ZRem(key string, members ...string) (int64, error) { + args := make([]interface{}, 0, len(members)) + args = append(args, key) + for _, member := range members { + args = append(args, member) + } + res, err := p.client.ZRem(key, members).Result() + if err != nil { + return res, convertError(err) + } + return res, err +} + +func (p *RedisClusterPool) ZRange(key string, min, max int64, withScores bool) (interface{}, error) { + res := make([]interface{}, 0) + if withScores { + zs, err := p.client.ZRangeWithScores(key, min, max).Result() + if err != nil { + return nil, convertError(err) + } + for _, z := range zs { + res = append(res, z.Member, strconv.FormatFloat(z.Score, 'f', -1, 64)) + } + } else { + ms, err := p.client.ZRange(key, min, max).Result() + if err != nil { + return nil, convertError(err) + } + for _, m := range ms { + res = append(res, m) + } + } + return res, nil +} +func (p *RedisClusterPool) ZRangeByScoreWithScore(key string, min, max int64) (map[string]int64, error) { + opt := new(redis.ZRangeBy) + opt.Min = strconv.FormatInt(int64(min), 10) + opt.Max = strconv.FormatInt(int64(max), 10) + opt.Count = -1 + opt.Offset = 0 + vals, err := p.client.ZRangeByScoreWithScores(key, *opt).Result() + if err != nil { + return nil, convertError(err) + } + res := make(map[string]int64, len(vals)) + for _, val := range vals { + key, err := String(val.Member, nil) + if err != nil && err != ErrNil { + return nil, err + } + res[key] = int64(val.Score) + } + return res, nil +} +func (p *RedisClusterPool) LRange(key string, start, stop int64) (interface{}, error) { + res, err := p.client.LRange(key, start, stop).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) LSet(key string, index int, value interface{}) error { + err := p.client.LSet(key, int64(index), value).Err() + return convertError(err) +} +func (p *RedisClusterPool) LLen(key string) (int64, error) { + res, err := p.client.LLen(key).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) LRem(key string, count int, value interface{}) (int, error) { + val, _ := value.(string) + res, err := p.client.LRem(key, int64(count), val).Result() + if err != nil { + return int(res), convertError(err) + } + return int(res), nil +} +func (p *RedisClusterPool) TTl(key string) (int64, error) { + duration, err := p.client.TTL(key).Result() + if err != nil { + return int64(duration.Seconds()), convertError(err) + } + return int64(duration.Seconds()), nil +} +func (p *RedisClusterPool) LPop(key string) (interface{}, error) { + res, err := p.client.LPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) RPop(key string) (interface{}, error) { + res, err := p.client.RPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) BLPop(key string, timeout int) (interface{}, error) { + res, err := p.client.BLPop(time.Duration(timeout)*time.Second, key).Result() + if err != nil { + // 兼容redis 2.x + if err == redis.Nil { + return nil, ErrNil + } + return nil, err + } + return res[1], nil +} +func (p *RedisClusterPool) BRPop(key string, timeout int) (interface{}, error) { + res, err := p.client.BRPop(time.Duration(timeout)*time.Second, key).Result() + if err != nil { + // 兼容redis 2.x + if err == redis.Nil { + return nil, ErrNil + } + return nil, convertError(err) + } + return res[1], nil +} +func (p *RedisClusterPool) LPush(key string, value ...interface{}) error { + args := make([]interface{}, 0, len(value)+1) + args = append(args, key) + args = append(args, value...) + vals := make([]string, 0, len(value)) + for _, v := range value { + val, err := String(v, nil) + if err != nil && err != ErrNil { + return err + } + vals = append(vals, val) + } + _, err := p.client.LPush(key, vals).Result() // todo ... + if err != nil { + return convertError(err) + } + return nil +} +func (p *RedisClusterPool) RPush(key string, value ...interface{}) error { + args := make([]interface{}, 0, len(value)+1) + args = append(args, key) + args = append(args, value...) + vals := make([]string, 0, len(value)) + for _, v := range value { + val, err := String(v, nil) + if err != nil && err != ErrNil { + if err == ErrNil { + continue + } + return err + } + if val == "" { + continue + } + vals = append(vals, val) + } + _, err := p.client.RPush(key, vals).Result() // todo ... + if err != nil { + return convertError(err) + } + return nil +} + +// 为确保srcKey跟destKey映射到同一个slot,srcKey和destKey需要加上hash tag,如:{test} +func (p *RedisClusterPool) BRPopLPush(srcKey string, destKey string, timeout int) (interface{}, error) { + res, err := p.client.BRPopLPush(srcKey, destKey, time.Duration(timeout)*time.Second).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} + +// 为确保srcKey跟destKey映射到同一个slot,srcKey和destKey需要加上hash tag,如:{test} +func (p *RedisClusterPool) RPopLPush(srcKey string, destKey string) (interface{}, error) { + res, err := p.client.RPopLPush(srcKey, destKey).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) SAdd(key string, members ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(members)+1) + args = append(args, key) + args = append(args, members...) + ms := make([]string, 0, len(members)) + for _, member := range members { + m, err := String(member, nil) + if err != nil && err != ErrNil { + return 0, err + } + ms = append(ms, m) + } + res, err := p.client.SAdd(key, ms).Result() // todo ... + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) SPop(key string) ([]byte, error) { + res, err := p.client.SPop(key).Result() + if err != nil { + return nil, convertError(err) + } + return []byte(res), nil +} +func (p *RedisClusterPool) SIsMember(key string, member interface{}) (bool, error) { + m, _ := member.(string) + res, err := p.client.SIsMember(key, m).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) SRem(key string, members ...interface{}) (int64, error) { + args := make([]interface{}, 0, len(members)+1) + args = append(args, key) + args = append(args, members...) + ms := make([]string, 0, len(members)) + for _, member := range members { + m, err := String(member, nil) + if err != nil && err != ErrNil { + return 0, err + } + ms = append(ms, m) + } + res, err := p.client.SRem(key, ms).Result() // todo ... + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) SMembers(key string) ([]string, error) { + res, err := p.client.SMembers(key).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) ScriptLoad(luaScript string) (interface{}, error) { + res, err := p.client.ScriptLoad(luaScript).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) EvalSha(sha1 string, numberKeys int, keysArgs ...interface{}) (interface{}, error) { + vals := make([]interface{}, 0, len(keysArgs)+2) + vals = append(vals, sha1, numberKeys) + vals = append(vals, keysArgs...) + keys := make([]string, 0, numberKeys) + args := make([]string, 0, len(keysArgs)-numberKeys) + for i, value := range keysArgs { + val, err := String(value, nil) + if err != nil && err != ErrNil { + return nil, err + } + if i < numberKeys { + keys = append(keys, val) + } else { + args = append(args, val) + } + } + res, err := p.client.EvalSha(sha1, keys, args).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) Eval(luaScript string, numberKeys int, keysArgs ...interface{}) (interface{}, error) { + vals := make([]interface{}, 0, len(keysArgs)+2) + vals = append(vals, luaScript, numberKeys) + vals = append(vals, keysArgs...) + keys := make([]string, 0, numberKeys) + args := make([]string, 0, len(keysArgs)-numberKeys) + for i, value := range keysArgs { + val, err := String(value, nil) + if err != nil && err != ErrNil { + return nil, err + } + if i < numberKeys { + keys = append(keys, val) + } else { + args = append(args, val) + } + } + res, err := p.client.Eval(luaScript, keys, args).Result() + if err != nil { + return nil, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) GetBit(key string, offset int64) (int64, error) { + res, err := p.client.GetBit(key, offset).Result() + if err != nil { + return res, convertError(err) + } + return res, nil +} +func (p *RedisClusterPool) SetBit(key string, offset uint32, value int) (int, error) { + res, err := p.client.SetBit(key, int64(offset), value).Result() + return int(res), convertError(err) +} +func (p *RedisClusterPool) GetClient() *redis.ClusterClient { + return pools +} diff --git a/app/utils/convert.go b/app/utils/convert.go new file mode 100644 index 0000000..a638d37 --- /dev/null +++ b/app/utils/convert.go @@ -0,0 +1,322 @@ +package utils + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "math" + "strconv" + "strings" +) + +func ToString(raw interface{}, e error) (res string) { + if e != nil { + return "" + } + return AnyToString(raw) +} + +func ToInt64(raw interface{}, e error) int64 { + if e != nil { + return 0 + } + return AnyToInt64(raw) +} + +func AnyToBool(raw interface{}) bool { + switch i := raw.(type) { + case float32, float64, int, int64, uint, uint8, uint16, uint32, uint64, int8, int16, int32: + return i != 0 + case []byte: + return i != nil + case string: + if i == "false" { + return false + } + return i != "" + case error: + return false + case nil: + return true + } + val := fmt.Sprint(raw) + val = strings.TrimLeft(val, "&") + if strings.TrimLeft(val, "{}") == "" { + return false + } + if strings.TrimLeft(val, "[]") == "" { + return false + } + // ptr type + b, err := json.Marshal(raw) + if err != nil { + return false + } + if strings.TrimLeft(string(b), "\"\"") == "" { + return false + } + if strings.TrimLeft(string(b), "{}") == "" { + return false + } + return true +} + +func AnyToInt64(raw interface{}) int64 { + switch i := raw.(type) { + case string: + res, _ := strconv.ParseInt(i, 10, 64) + return res + case []byte: + return BytesToInt64(i) + case int: + return int64(i) + case int64: + return i + case uint: + return int64(i) + case uint8: + return int64(i) + case uint16: + return int64(i) + case uint32: + return int64(i) + case uint64: + return int64(i) + case int8: + return int64(i) + case int16: + return int64(i) + case int32: + return int64(i) + case float32: + return int64(i) + case float64: + return int64(i) + case error: + return 0 + case bool: + if i { + return 1 + } + return 0 + } + return 0 +} + +func AnyToString(raw interface{}) string { + switch i := raw.(type) { + case []byte: + return string(i) + case int: + return strconv.FormatInt(int64(i), 10) + case int64: + return strconv.FormatInt(i, 10) + case float32: + return Float64ToStr(float64(i)) + case float64: + return Float64ToStr(i) + case uint: + return strconv.FormatInt(int64(i), 10) + case uint8: + return strconv.FormatInt(int64(i), 10) + case uint16: + return strconv.FormatInt(int64(i), 10) + case uint32: + return strconv.FormatInt(int64(i), 10) + case uint64: + return strconv.FormatInt(int64(i), 10) + case int8: + return strconv.FormatInt(int64(i), 10) + case int16: + return strconv.FormatInt(int64(i), 10) + case int32: + return strconv.FormatInt(int64(i), 10) + case string: + return i + case error: + return i.Error() + case bool: + return strconv.FormatBool(i) + } + return fmt.Sprintf("%#v", raw) +} + +func AnyToFloat64(raw interface{}) float64 { + switch i := raw.(type) { + case []byte: + f, _ := strconv.ParseFloat(string(i), 64) + return f + case int: + return float64(i) + case int64: + return float64(i) + case float32: + return float64(i) + case float64: + return i + case uint: + return float64(i) + case uint8: + return float64(i) + case uint16: + return float64(i) + case uint32: + return float64(i) + case uint64: + return float64(i) + case int8: + return float64(i) + case int16: + return float64(i) + case int32: + return float64(i) + case string: + f, _ := strconv.ParseFloat(i, 64) + return f + case bool: + if i { + return 1 + } + } + return 0 +} + +func ToByte(raw interface{}, e error) []byte { + if e != nil { + return []byte{} + } + switch i := raw.(type) { + case string: + return []byte(i) + case int: + return Int64ToBytes(int64(i)) + case int64: + return Int64ToBytes(i) + case float32: + return Float32ToByte(i) + case float64: + return Float64ToByte(i) + case uint: + return Int64ToBytes(int64(i)) + case uint8: + return Int64ToBytes(int64(i)) + case uint16: + return Int64ToBytes(int64(i)) + case uint32: + return Int64ToBytes(int64(i)) + case uint64: + return Int64ToBytes(int64(i)) + case int8: + return Int64ToBytes(int64(i)) + case int16: + return Int64ToBytes(int64(i)) + case int32: + return Int64ToBytes(int64(i)) + case []byte: + return i + case error: + return []byte(i.Error()) + case bool: + if i { + return []byte("true") + } + return []byte("false") + } + return []byte(fmt.Sprintf("%#v", raw)) +} + +func Int64ToBytes(i int64) []byte { + var buf = make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(i)) + return buf +} + +func BytesToInt64(buf []byte) int64 { + return int64(binary.BigEndian.Uint64(buf)) +} + +func StrToInt(s string) int { + res, _ := strconv.Atoi(s) + return res +} + +func StrToInt64(s string) int64 { + res, _ := strconv.ParseInt(s, 10, 64) + return res +} + +func Float32ToByte(float float32) []byte { + bits := math.Float32bits(float) + bytes := make([]byte, 4) + binary.LittleEndian.PutUint32(bytes, bits) + + return bytes +} + +func ByteToFloat32(bytes []byte) float32 { + bits := binary.LittleEndian.Uint32(bytes) + return math.Float32frombits(bits) +} + +func Float64ToByte(float float64) []byte { + bits := math.Float64bits(float) + bytes := make([]byte, 8) + binary.LittleEndian.PutUint64(bytes, bits) + return bytes +} + +func ByteToFloat64(bytes []byte) float64 { + bits := binary.LittleEndian.Uint64(bytes) + return math.Float64frombits(bits) +} + +func Float64ToStr(f float64) string { + return strconv.FormatFloat(f, 'f', 2, 64) +} +func Float64ToStrPrec1(f float64) string { + return strconv.FormatFloat(f, 'f', 1, 64) +} + +func Float32ToStr(f float32) string { + return Float64ToStr(float64(f)) +} + +func StrToFloat64(s string) float64 { + res, err := strconv.ParseFloat(s, 64) + if err != nil { + return 0 + } + return res +} + +func StrToFloat32(s string) float32 { + res, err := strconv.ParseFloat(s, 32) + if err != nil { + return 0 + } + return float32(res) +} + +func StrToBool(s string) bool { + b, _ := strconv.ParseBool(s) + return b +} + +func BoolToStr(b bool) string { + if b { + return "true" + } + return "false" +} + +func FloatToInt64(f float64) int64 { + return int64(f) +} + +func IntToStr(i int) string { + return strconv.Itoa(i) +} + +func Int64ToStr(i int64) string { + return strconv.FormatInt(i, 10) +} diff --git a/app/utils/crypto.go b/app/utils/crypto.go new file mode 100644 index 0000000..56289c5 --- /dev/null +++ b/app/utils/crypto.go @@ -0,0 +1,19 @@ +package utils + +import ( + "crypto/md5" + "encoding/base64" + "fmt" +) + +func GetMd5(raw []byte) string { + h := md5.New() + h.Write(raw) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func GetBase64Md5(raw []byte) string { + h := md5.New() + h.Write(raw) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} diff --git a/app/utils/curl.go b/app/utils/curl.go new file mode 100644 index 0000000..0a45607 --- /dev/null +++ b/app/utils/curl.go @@ -0,0 +1,209 @@ +package utils + +import ( + "bytes" + "crypto/tls" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "sort" + "strings" + "time" +) + +var CurlDebug bool + +func CurlGet(router string, header map[string]string) ([]byte, error) { + return curl(http.MethodGet, router, nil, header) +} +func CurlGetJson(router string, body interface{}, header map[string]string) ([]byte, error) { + return curl_new(http.MethodGet, router, body, header) +} + +// 只支持form 与json 提交, 请留意body的类型, 支持string, []byte, map[string]string +func CurlPost(router string, body interface{}, header map[string]string) ([]byte, error) { + return curl(http.MethodPost, router, body, header) +} + +func CurlPut(router string, body interface{}, header map[string]string) ([]byte, error) { + return curl(http.MethodPut, router, body, header) +} + +// 只支持form 与json 提交, 请留意body的类型, 支持string, []byte, map[string]string +func CurlPatch(router string, body interface{}, header map[string]string) ([]byte, error) { + return curl(http.MethodPatch, router, body, header) +} + +// CurlDelete is curl delete +func CurlDelete(router string, body interface{}, header map[string]string) ([]byte, error) { + return curl(http.MethodDelete, router, body, header) +} + +func curl(method, router string, body interface{}, header map[string]string) ([]byte, error) { + var reqBody io.Reader + contentType := "application/json" + switch v := body.(type) { + case string: + reqBody = strings.NewReader(v) + case []byte: + reqBody = bytes.NewReader(v) + case map[string]string: + val := url.Values{} + for k, v := range v { + val.Set(k, v) + } + reqBody = strings.NewReader(val.Encode()) + contentType = "application/x-www-form-urlencoded" + case map[string]interface{}: + val := url.Values{} + for k, v := range v { + val.Set(k, v.(string)) + } + reqBody = strings.NewReader(val.Encode()) + contentType = "application/x-www-form-urlencoded" + } + if header == nil { + header = map[string]string{"Content-Type": contentType} + } + if _, ok := header["Content-Type"]; !ok { + header["Content-Type"] = contentType + } + resp, er := CurlReq(method, router, reqBody, header) + if er != nil { + return nil, er + } + res, err := ioutil.ReadAll(resp.Body) + if CurlDebug { + blob := SerializeStr(body) + if contentType != "application/json" { + blob = HttpBuild(body) + } + fmt.Printf("\n\n=====================\n[url]: %s\n[time]: %s\n[method]: %s\n[content-type]: %v\n[req_header]: %s\n[req_body]: %#v\n[resp_err]: %v\n[resp_header]: %v\n[resp_body]: %v\n=====================\n\n", + router, + time.Now().Format("2006-01-02 15:04:05.000"), + method, + contentType, + HttpBuildQuery(header), + blob, + err, + SerializeStr(resp.Header), + string(res), + ) + } + resp.Body.Close() + return res, err +} + +func curl_new(method, router string, body interface{}, header map[string]string) ([]byte, error) { + var reqBody io.Reader + contentType := "application/json" + + if header == nil { + header = map[string]string{"Content-Type": contentType} + } + if _, ok := header["Content-Type"]; !ok { + header["Content-Type"] = contentType + } + resp, er := CurlReq(method, router, reqBody, header) + if er != nil { + return nil, er + } + res, err := ioutil.ReadAll(resp.Body) + if CurlDebug { + blob := SerializeStr(body) + if contentType != "application/json" { + blob = HttpBuild(body) + } + fmt.Printf("\n\n=====================\n[url]: %s\n[time]: %s\n[method]: %s\n[content-type]: %v\n[req_header]: %s\n[req_body]: %#v\n[resp_err]: %v\n[resp_header]: %v\n[resp_body]: %v\n=====================\n\n", + router, + time.Now().Format("2006-01-02 15:04:05.000"), + method, + contentType, + HttpBuildQuery(header), + blob, + err, + SerializeStr(resp.Header), + string(res), + ) + } + resp.Body.Close() + return res, err +} + +func CurlReq(method, router string, reqBody io.Reader, header map[string]string) (*http.Response, error) { + req, _ := http.NewRequest(method, router, reqBody) + if header != nil { + for k, v := range header { + req.Header.Set(k, v) + } + } + // 绕过github等可能因为特征码返回503问题 + // https://www.imwzk.com/posts/2021-03-14-why-i-always-get-503-with-golang/ + defaultCipherSuites := []uint16{0xc02f, 0xc030, 0xc02b, 0xc02c, 0xcca8, 0xcca9, 0xc013, 0xc009, + 0xc014, 0xc00a, 0x009c, 0x009d, 0x002f, 0x0035, 0xc012, 0x000a} + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + CipherSuites: append(defaultCipherSuites[8:], defaultCipherSuites[:8]...), + }, + }, + // 获取301重定向 + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + return client.Do(req) +} + +// 组建get请求参数,sortAsc true为小到大,false为大到小,nil不排序 a=123&b=321 +func HttpBuildQuery(args map[string]string, sortAsc ...bool) string { + str := "" + if len(args) == 0 { + return str + } + if len(sortAsc) > 0 { + keys := make([]string, 0, len(args)) + for k := range args { + keys = append(keys, k) + } + if sortAsc[0] { + sort.Strings(keys) + } else { + sort.Sort(sort.Reverse(sort.StringSlice(keys))) + } + for _, k := range keys { + str += "&" + k + "=" + args[k] + } + } else { + for k, v := range args { + str += "&" + k + "=" + v + } + } + return str[1:] +} + +func HttpBuild(body interface{}, sortAsc ...bool) string { + params := map[string]string{} + if args, ok := body.(map[string]interface{}); ok { + for k, v := range args { + params[k] = AnyToString(v) + } + return HttpBuildQuery(params, sortAsc...) + } + if args, ok := body.(map[string]string); ok { + for k, v := range args { + params[k] = AnyToString(v) + } + return HttpBuildQuery(params, sortAsc...) + } + if args, ok := body.(map[string]int); ok { + for k, v := range args { + params[k] = AnyToString(v) + } + return HttpBuildQuery(params, sortAsc...) + } + return AnyToString(body) +} diff --git a/app/utils/debug.go b/app/utils/debug.go new file mode 100644 index 0000000..bb2e9d3 --- /dev/null +++ b/app/utils/debug.go @@ -0,0 +1,25 @@ +package utils + +import ( + "fmt" + "os" + "strconv" + "time" +) + +func Debug(args ...interface{}) { + s := "" + l := len(args) + if l < 1 { + fmt.Println("please input some data") + os.Exit(0) + } + i := 1 + for _, v := range args { + s += fmt.Sprintf("【"+strconv.Itoa(i)+"】: %#v\n", v) + i++ + } + s = "******************** 【DEBUG - " + time.Now().Format("2006-01-02 15:04:05") + "】 ********************\n" + s + "******************** 【DEBUG - END】 ********************\n" + fmt.Println(s) + os.Exit(0) +} diff --git a/app/utils/duplicate.go b/app/utils/duplicate.go new file mode 100644 index 0000000..17cea88 --- /dev/null +++ b/app/utils/duplicate.go @@ -0,0 +1,37 @@ +package utils + +func RemoveDuplicateString(elms []string) []string { + res := make([]string, 0, len(elms)) + temp := map[string]struct{}{} + for _, item := range elms { + if _, ok := temp[item]; !ok { + temp[item] = struct{}{} + res = append(res, item) + } + } + return res +} + +func RemoveDuplicateInt(elms []int) []int { + res := make([]int, 0, len(elms)) + temp := map[int]struct{}{} + for _, item := range elms { + if _, ok := temp[item]; !ok { + temp[item] = struct{}{} + res = append(res, item) + } + } + return res +} + +func RemoveDuplicateInt64(elms []int64) []int64 { + res := make([]int64, 0, len(elms)) + temp := map[int64]struct{}{} + for _, item := range elms { + if _, ok := temp[item]; !ok { + temp[item] = struct{}{} + res = append(res, item) + } + } + return res +} diff --git a/app/utils/file.go b/app/utils/file.go new file mode 100644 index 0000000..93ed08f --- /dev/null +++ b/app/utils/file.go @@ -0,0 +1,22 @@ +package utils + +import ( + "os" + "path" + "strings" + "time" +) + +// 获取文件后缀 +func FileExt(fname string) string { + return strings.ToLower(strings.TrimLeft(path.Ext(fname), ".")) +} + +func FilePutContents(fileName string, content string) { + fd, _ := os.OpenFile("./tmp/"+fileName+".log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) + fd_time := time.Now().Format("2006-01-02 15:04:05") + fd_content := strings.Join([]string{"[", fd_time, "] ", content, "\n"}, "") + buf := []byte(fd_content) + fd.Write(buf) + fd.Close() +} diff --git a/app/utils/file_and_dir.go b/app/utils/file_and_dir.go new file mode 100644 index 0000000..93141f9 --- /dev/null +++ b/app/utils/file_and_dir.go @@ -0,0 +1,29 @@ +package utils + +import "os" + +// 判断所给路径文件、文件夹是否存在 +func Exists(path string) bool { + _, err := os.Stat(path) //os.Stat获取文件信息 + if err != nil { + if os.IsExist(err) { + return true + } + return false + } + return true +} + +// 判断所给路径是否为文件夹 +func IsDir(path string) bool { + s, err := os.Stat(path) + if err != nil { + return false + } + return s.IsDir() +} + +// 判断所给路径是否为文件 +func IsFile(path string) bool { + return !IsDir(path) +} diff --git a/app/utils/format.go b/app/utils/format.go new file mode 100644 index 0000000..997fe80 --- /dev/null +++ b/app/utils/format.go @@ -0,0 +1,59 @@ +package utils + +import ( + "math" +) + +func CouponFormat(data string) string { + switch data { + case "0.00", "0", "": + return "" + default: + return Int64ToStr(FloatToInt64(StrToFloat64(data))) + } +} +func CommissionFormat(data string) string { + if StrToFloat64(data) > 0 { + return data + } + + return "" +} + +func HideString(src string, hLen int) string { + str := []rune(src) + if hLen == 0 { + hLen = 4 + } + hideStr := "" + for i := 0; i < hLen; i++ { + hideStr += "*" + } + hideLen := len(str) / 2 + showLen := len(str) - hideLen + if hideLen == 0 || showLen == 0 { + return hideStr + } + subLen := showLen / 2 + if subLen == 0 { + return string(str[:showLen]) + hideStr + } + s := string(str[:subLen]) + s += hideStr + s += string(str[len(str)-subLen:]) + return s +} + +//SaleCountFormat is 格式化销量 +func SaleCountFormat(s string) string { + return s + "已售" +} + +// 小数格式化 +func FloatFormat(f float64, i int) float64 { + if i > 14 { + return f + } + p := math.Pow10(i) + return float64(int64((f+0.000000000000009)*p)) / p +} diff --git a/app/utils/ip.go b/app/utils/ip.go new file mode 100644 index 0000000..6ed8286 --- /dev/null +++ b/app/utils/ip.go @@ -0,0 +1,146 @@ +package utils + +import ( + "errors" + "math" + "net" + "net/http" + "strings" +) + +func GetIP(r *http.Request) string { + ip := ClientPublicIP(r) + if ip == "" { + ip = ClientIP(r) + } + if ip == "" { + ip = "0000" + } + return ip +} + +// HasLocalIPddr 检测 IP 地址字符串是否是内网地址 +// Deprecated: 此为一个错误名称错误拼写的函数,计划在将来移除,请使用 HasLocalIPAddr 函数 +func HasLocalIPddr(ip string) bool { + return HasLocalIPAddr(ip) +} + +// HasLocalIPAddr 检测 IP 地址字符串是否是内网地址 +func HasLocalIPAddr(ip string) bool { + return HasLocalIP(net.ParseIP(ip)) +} + +// HasLocalIP 检测 IP 地址是否是内网地址 +// 通过直接对比ip段范围效率更高,详见:https://github.com/thinkeridea/go-extend/issues/2 +func HasLocalIP(ip net.IP) bool { + if ip.IsLoopback() { + return true + } + + ip4 := ip.To4() + if ip4 == nil { + return false + } + + return ip4[0] == 10 || // 10.0.0.0/8 + (ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) || // 172.16.0.0/12 + (ip4[0] == 169 && ip4[1] == 254) || // 169.254.0.0/16 + (ip4[0] == 192 && ip4[1] == 168) // 192.168.0.0/16 +} + +// ClientIP 尽最大努力实现获取客户端 IP 的算法。 +// 解析 X-Real-IP 和 X-Forwarded-For 以便于反向代理(nginx 或 haproxy)可以正常工作。 +func ClientIP(r *http.Request) string { + ip := strings.TrimSpace(strings.Split(r.Header.Get("X-Forwarded-For"), ",")[0]) + if ip != "" { + return ip + } + + ip = strings.TrimSpace(r.Header.Get("X-Real-Ip")) + if ip != "" { + return ip + } + + if ip, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)); err == nil { + return ip + } + + return "" +} + +// ClientPublicIP 尽最大努力实现获取客户端公网 IP 的算法。 +// 解析 X-Real-IP 和 X-Forwarded-For 以便于反向代理(nginx 或 haproxy)可以正常工作。 +func ClientPublicIP(r *http.Request) string { + var ip string + for _, ip = range strings.Split(r.Header.Get("X-Forwarded-For"), ",") { + if ip = strings.TrimSpace(ip); ip != "" && !HasLocalIPAddr(ip) { + return ip + } + } + + if ip = strings.TrimSpace(r.Header.Get("X-Real-Ip")); ip != "" && !HasLocalIPAddr(ip) { + return ip + } + + if ip = RemoteIP(r); !HasLocalIPAddr(ip) { + return ip + } + + return "" +} + +// RemoteIP 通过 RemoteAddr 获取 IP 地址, 只是一个快速解析方法。 +func RemoteIP(r *http.Request) string { + ip, _, _ := net.SplitHostPort(r.RemoteAddr) + return ip +} + +// IPString2Long 把ip字符串转为数值 +func IPString2Long(ip string) (uint, error) { + b := net.ParseIP(ip).To4() + if b == nil { + return 0, errors.New("invalid ipv4 format") + } + + return uint(b[3]) | uint(b[2])<<8 | uint(b[1])<<16 | uint(b[0])<<24, nil +} + +// Long2IPString 把数值转为ip字符串 +func Long2IPString(i uint) (string, error) { + if i > math.MaxUint32 { + return "", errors.New("beyond the scope of ipv4") + } + + ip := make(net.IP, net.IPv4len) + ip[0] = byte(i >> 24) + ip[1] = byte(i >> 16) + ip[2] = byte(i >> 8) + ip[3] = byte(i) + + return ip.String(), nil +} + +// IP2Long 把net.IP转为数值 +func IP2Long(ip net.IP) (uint, error) { + b := ip.To4() + if b == nil { + return 0, errors.New("invalid ipv4 format") + } + + return uint(b[3]) | uint(b[2])<<8 | uint(b[1])<<16 | uint(b[0])<<24, nil +} + +// Long2IP 把数值转为net.IP +func Long2IP(i uint) (net.IP, error) { + if i > math.MaxUint32 { + return nil, errors.New("beyond the scope of ipv4") + } + + ip := make(net.IP, net.IPv4len) + ip[0] = byte(i >> 24) + ip[1] = byte(i >> 16) + ip[2] = byte(i >> 8) + ip[3] = byte(i) + + return ip, nil +} diff --git a/app/utils/json.go b/app/utils/json.go new file mode 100644 index 0000000..998bcec --- /dev/null +++ b/app/utils/json.go @@ -0,0 +1,17 @@ +package utils + +import ( + "bytes" + "encoding/json" +) + +func JsonMarshal(interface{}) { + +} + +// 不科学计数法 +func JsonDecode(data []byte, v interface{}) error { + d := json.NewDecoder(bytes.NewReader(data)) + d.UseNumber() + return d.Decode(v) +} diff --git a/app/utils/logx/log.go b/app/utils/logx/log.go new file mode 100644 index 0000000..ca11223 --- /dev/null +++ b/app/utils/logx/log.go @@ -0,0 +1,245 @@ +package logx + +import ( + "os" + "strings" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type LogConfig struct { + AppName string `yaml:"app_name" json:"app_name" toml:"app_name"` + Level string `yaml:"level" json:"level" toml:"level"` + StacktraceLevel string `yaml:"stacktrace_level" json:"stacktrace_level" toml:"stacktrace_level"` + IsStdOut bool `yaml:"is_stdout" json:"is_stdout" toml:"is_stdout"` + TimeFormat string `yaml:"time_format" json:"time_format" toml:"time_format"` // second, milli, nano, standard, iso, + Encoding string `yaml:"encoding" json:"encoding" toml:"encoding"` // console, json + Skip int `yaml:"skip" json:"skip" toml:"skip"` + + IsFileOut bool `yaml:"is_file_out" json:"is_file_out" toml:"is_file_out"` + FileDir string `yaml:"file_dir" json:"file_dir" toml:"file_dir"` + FileName string `yaml:"file_name" json:"file_name" toml:"file_name"` + FileMaxSize int `yaml:"file_max_size" json:"file_max_size" toml:"file_max_size"` + FileMaxAge int `yaml:"file_max_age" json:"file_max_age" toml:"file_max_age"` +} + +var ( + l *LogX = defaultLogger() + conf *LogConfig +) + +// default logger setting +func defaultLogger() *LogX { + conf = &LogConfig{ + Level: "debug", + StacktraceLevel: "error", + IsStdOut: true, + TimeFormat: "standard", + Encoding: "console", + Skip: 2, + } + writers := []zapcore.WriteSyncer{os.Stdout} + lg, lv := newZapLogger(setLogLevel(conf.Level), setLogLevel(conf.StacktraceLevel), conf.Encoding, conf.TimeFormat, conf.Skip, zapcore.NewMultiWriteSyncer(writers...)) + zap.RedirectStdLog(lg) + return &LogX{logger: lg, atomLevel: lv} +} + +// initial standard log, if you don't init, it will use default logger setting +func InitDefaultLogger(cfg *LogConfig) { + var writers []zapcore.WriteSyncer + if cfg.IsStdOut || (!cfg.IsStdOut && !cfg.IsFileOut) { + writers = append(writers, os.Stdout) + } + if cfg.IsFileOut { + writers = append(writers, NewRollingFile(cfg.FileDir, cfg.FileName, cfg.FileMaxSize, cfg.FileMaxAge)) + } + + lg, lv := newZapLogger(setLogLevel(cfg.Level), setLogLevel(cfg.StacktraceLevel), cfg.Encoding, cfg.TimeFormat, cfg.Skip, zapcore.NewMultiWriteSyncer(writers...)) + zap.RedirectStdLog(lg) + if cfg.AppName != "" { + lg = lg.With(zap.String("app", cfg.AppName)) // 加上应用名称 + } + l = &LogX{logger: lg, atomLevel: lv} +} + +// create a new logger +func NewLogger(cfg *LogConfig) *LogX { + var writers []zapcore.WriteSyncer + if cfg.IsStdOut || (!cfg.IsStdOut && !cfg.IsFileOut) { + writers = append(writers, os.Stdout) + } + if cfg.IsFileOut { + writers = append(writers, NewRollingFile(cfg.FileDir, cfg.FileName, cfg.FileMaxSize, cfg.FileMaxAge)) + } + + lg, lv := newZapLogger(setLogLevel(cfg.Level), setLogLevel(cfg.StacktraceLevel), cfg.Encoding, cfg.TimeFormat, cfg.Skip, zapcore.NewMultiWriteSyncer(writers...)) + zap.RedirectStdLog(lg) + if cfg.AppName != "" { + lg = lg.With(zap.String("app", cfg.AppName)) // 加上应用名称 + } + return &LogX{logger: lg, atomLevel: lv} +} + +// create a new zaplog logger +func newZapLogger(level, stacktrace zapcore.Level, encoding, timeType string, skip int, output zapcore.WriteSyncer) (*zap.Logger, *zap.AtomicLevel) { + encCfg := zapcore.EncoderConfig{ + TimeKey: "T", + LevelKey: "L", + NameKey: "N", + CallerKey: "C", + MessageKey: "M", + StacktraceKey: "S", + LineEnding: zapcore.DefaultLineEnding, + EncodeCaller: zapcore.ShortCallerEncoder, + EncodeDuration: zapcore.NanosDurationEncoder, + EncodeLevel: zapcore.LowercaseLevelEncoder, + } + setTimeFormat(timeType, &encCfg) // set time type + atmLvl := zap.NewAtomicLevel() // set level + atmLvl.SetLevel(level) + encoder := zapcore.NewJSONEncoder(encCfg) // 确定encoder格式 + if encoding == "console" { + encoder = zapcore.NewConsoleEncoder(encCfg) + } + return zap.New(zapcore.NewCore(encoder, output, atmLvl), zap.AddCaller(), zap.AddStacktrace(stacktrace), zap.AddCallerSkip(skip)), &atmLvl +} + +// set log level +func setLogLevel(lvl string) zapcore.Level { + switch strings.ToLower(lvl) { + case "panic": + return zapcore.PanicLevel + case "fatal": + return zapcore.FatalLevel + case "error": + return zapcore.ErrorLevel + case "warn", "warning": + return zapcore.WarnLevel + case "info": + return zapcore.InfoLevel + default: + return zapcore.DebugLevel + } +} + +// set time format +func setTimeFormat(timeType string, z *zapcore.EncoderConfig) { + switch strings.ToLower(timeType) { + case "iso": // iso8601 standard + z.EncodeTime = zapcore.ISO8601TimeEncoder + case "sec": // only for unix second, without millisecond + z.EncodeTime = func(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + enc.AppendInt64(t.Unix()) + } + case "second": // unix second, with millisecond + z.EncodeTime = zapcore.EpochTimeEncoder + case "milli", "millisecond": // millisecond + z.EncodeTime = zapcore.EpochMillisTimeEncoder + case "nano", "nanosecond": // nanosecond + z.EncodeTime = zapcore.EpochNanosTimeEncoder + default: // standard format + z.EncodeTime = func(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + enc.AppendString(t.Format("2006-01-02 15:04:05.000")) + } + } +} + +func GetLevel() string { + switch l.atomLevel.Level() { + case zapcore.PanicLevel: + return "panic" + case zapcore.FatalLevel: + return "fatal" + case zapcore.ErrorLevel: + return "error" + case zapcore.WarnLevel: + return "warn" + case zapcore.InfoLevel: + return "info" + default: + return "debug" + } +} + +func SetLevel(lvl string) { + l.atomLevel.SetLevel(setLogLevel(lvl)) +} + +// temporary add call skip +func AddCallerSkip(skip int) *LogX { + l.logger.WithOptions(zap.AddCallerSkip(skip)) + return l +} + +// permanent add call skip +func AddDepth(skip int) *LogX { + l.logger = l.logger.WithOptions(zap.AddCallerSkip(skip)) + return l +} + +// permanent add options +func AddOptions(opts ...zap.Option) *LogX { + l.logger = l.logger.WithOptions(opts...) + return l +} + +func AddField(k string, v interface{}) { + l.logger.With(zap.Any(k, v)) +} + +func AddFields(fields map[string]interface{}) *LogX { + for k, v := range fields { + l.logger.With(zap.Any(k, v)) + } + return l +} + +// Normal log +func Debug(e interface{}, args ...interface{}) error { + return l.Debug(e, args...) +} +func Info(e interface{}, args ...interface{}) error { + return l.Info(e, args...) +} +func Warn(e interface{}, args ...interface{}) error { + return l.Warn(e, args...) +} +func Error(e interface{}, args ...interface{}) error { + return l.Error(e, args...) +} +func Panic(e interface{}, args ...interface{}) error { + return l.Panic(e, args...) +} +func Fatal(e interface{}, args ...interface{}) error { + return l.Fatal(e, args...) +} + +// Format logs +func Debugf(format string, args ...interface{}) error { + return l.Debugf(format, args...) +} +func Infof(format string, args ...interface{}) error { + return l.Infof(format, args...) +} +func Warnf(format string, args ...interface{}) error { + return l.Warnf(format, args...) +} +func Errorf(format string, args ...interface{}) error { + return l.Errorf(format, args...) +} +func Panicf(format string, args ...interface{}) error { + return l.Panicf(format, args...) +} +func Fatalf(format string, args ...interface{}) error { + return l.Fatalf(format, args...) +} + +func formatFieldMap(m FieldMap) []Field { + var res []Field + for k, v := range m { + res = append(res, zap.Any(k, v)) + } + return res +} diff --git a/app/utils/logx/output.go b/app/utils/logx/output.go new file mode 100644 index 0000000..ef33f0b --- /dev/null +++ b/app/utils/logx/output.go @@ -0,0 +1,105 @@ +package logx + +import ( + "bytes" + "io" + "os" + "path/filepath" + "time" + + "gopkg.in/natefinch/lumberjack.v2" +) + +// output interface +type WriteSyncer interface { + io.Writer + Sync() error +} + +// split writer +func NewRollingFile(dir, filename string, maxSize, MaxAge int) WriteSyncer { + s, err := os.Stat(dir) + if err != nil || !s.IsDir() { + os.RemoveAll(dir) + if err := os.MkdirAll(dir, 0766); err != nil { + panic(err) + } + } + return newLumberjackWriteSyncer(&lumberjack.Logger{ + Filename: filepath.Join(dir, filename), + MaxSize: maxSize, // megabytes, MB + MaxAge: MaxAge, // days + LocalTime: true, + Compress: false, + }) +} + +type lumberjackWriteSyncer struct { + *lumberjack.Logger + buf *bytes.Buffer + logChan chan []byte + closeChan chan interface{} + maxSize int +} + +func newLumberjackWriteSyncer(l *lumberjack.Logger) *lumberjackWriteSyncer { + ws := &lumberjackWriteSyncer{ + Logger: l, + buf: bytes.NewBuffer([]byte{}), + logChan: make(chan []byte, 5000), + closeChan: make(chan interface{}), + maxSize: 1024, + } + go ws.run() + return ws +} + +func (l *lumberjackWriteSyncer) run() { + ticker := time.NewTicker(1 * time.Second) + + for { + select { + case <-ticker.C: + if l.buf.Len() > 0 { + l.sync() + } + case bs := <-l.logChan: + _, err := l.buf.Write(bs) + if err != nil { + continue + } + if l.buf.Len() > l.maxSize { + l.sync() + } + case <-l.closeChan: + l.sync() + return + } + } +} + +func (l *lumberjackWriteSyncer) Stop() { + close(l.closeChan) +} + +func (l *lumberjackWriteSyncer) Write(bs []byte) (int, error) { + b := make([]byte, len(bs)) + for i, c := range bs { + b[i] = c + } + l.logChan <- b + return 0, nil +} + +func (l *lumberjackWriteSyncer) Sync() error { + return nil +} + +func (l *lumberjackWriteSyncer) sync() error { + defer l.buf.Reset() + _, err := l.Logger.Write(l.buf.Bytes()) + if err != nil { + return err + } + return nil +} diff --git a/app/utils/logx/sugar.go b/app/utils/logx/sugar.go new file mode 100644 index 0000000..ab380fc --- /dev/null +++ b/app/utils/logx/sugar.go @@ -0,0 +1,192 @@ +package logx + +import ( + "errors" + "fmt" + "strconv" + + "go.uber.org/zap" +) + +type LogX struct { + logger *zap.Logger + atomLevel *zap.AtomicLevel +} + +type Field = zap.Field +type FieldMap map[string]interface{} + +// 判断其他类型--start +func getFields(msg string, format bool, args ...interface{}) (string, []Field) { + var str []interface{} + var fields []zap.Field + if len(args) > 0 { + for _, v := range args { + if f, ok := v.(Field); ok { + fields = append(fields, f) + } else if f, ok := v.(FieldMap); ok { + fields = append(fields, formatFieldMap(f)...) + } else { + str = append(str, AnyToString(v)) + } + } + if format { + return fmt.Sprintf(msg, str...), fields + } + str = append([]interface{}{msg}, str...) + return fmt.Sprintln(str...), fields + } + return msg, []Field{} +} + +func (l *LogX) Debug(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Debug(msg, field...) + } + return e +} +func (l *LogX) Info(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Info(msg, field...) + } + return e +} +func (l *LogX) Warn(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Warn(msg, field...) + } + return e +} +func (l *LogX) Error(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Error(msg, field...) + } + return e +} +func (l *LogX) DPanic(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.DPanic(msg, field...) + } + return e +} +func (l *LogX) Panic(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Panic(msg, field...) + } + return e +} +func (l *LogX) Fatal(s interface{}, args ...interface{}) error { + es, e := checkErr(s) + if es != "" { + msg, field := getFields(es, false, args...) + l.logger.Fatal(msg, field...) + } + return e +} + +func checkErr(s interface{}) (string, error) { + switch e := s.(type) { + case error: + return e.Error(), e + case string: + return e, errors.New(e) + case []byte: + return string(e), nil + default: + return "", nil + } +} + +func (l *LogX) LogError(err error) error { + return l.Error(err.Error()) +} + +func (l *LogX) Debugf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Debug(s, f...) + return errors.New(s) +} + +func (l *LogX) Infof(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Info(s, f...) + return errors.New(s) +} + +func (l *LogX) Warnf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Warn(s, f...) + return errors.New(s) +} + +func (l *LogX) Errorf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Error(s, f...) + return errors.New(s) +} + +func (l *LogX) DPanicf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.DPanic(s, f...) + return errors.New(s) +} + +func (l *LogX) Panicf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Panic(s, f...) + return errors.New(s) +} + +func (l *LogX) Fatalf(msg string, args ...interface{}) error { + s, f := getFields(msg, true, args...) + l.logger.Fatal(s, f...) + return errors.New(s) +} + +func AnyToString(raw interface{}) string { + switch i := raw.(type) { + case []byte: + return string(i) + case int: + return strconv.FormatInt(int64(i), 10) + case int64: + return strconv.FormatInt(i, 10) + case float32: + return strconv.FormatFloat(float64(i), 'f', 2, 64) + case float64: + return strconv.FormatFloat(i, 'f', 2, 64) + case uint: + return strconv.FormatInt(int64(i), 10) + case uint8: + return strconv.FormatInt(int64(i), 10) + case uint16: + return strconv.FormatInt(int64(i), 10) + case uint32: + return strconv.FormatInt(int64(i), 10) + case uint64: + return strconv.FormatInt(int64(i), 10) + case int8: + return strconv.FormatInt(int64(i), 10) + case int16: + return strconv.FormatInt(int64(i), 10) + case int32: + return strconv.FormatInt(int64(i), 10) + case string: + return i + case error: + return i.Error() + } + return fmt.Sprintf("%#v", raw) +} diff --git a/app/utils/map.go b/app/utils/map.go new file mode 100644 index 0000000..d9f3b7a --- /dev/null +++ b/app/utils/map.go @@ -0,0 +1,9 @@ +package utils + +// GetOneKeyOfMapString 取出Map的一个key +func GetOneKeyOfMapString(collection map[string]string) string { + for k := range collection { + return k + } + return "" +} diff --git a/app/utils/map_and_struct.go b/app/utils/map_and_struct.go new file mode 100644 index 0000000..34904ce --- /dev/null +++ b/app/utils/map_and_struct.go @@ -0,0 +1,341 @@ +package utils + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" +) + +func Map2Struct(vals map[string]interface{}, dst interface{}) (err error) { + return Map2StructByTag(vals, dst, "json") +} + +func Map2StructByTag(vals map[string]interface{}, dst interface{}, structTag string) (err error) { + defer func() { + e := recover() + if e != nil { + if v, ok := e.(error); ok { + err = fmt.Errorf("Panic: %v", v.Error()) + } else { + err = fmt.Errorf("Panic: %v", e) + } + } + }() + + pt := reflect.TypeOf(dst) + pv := reflect.ValueOf(dst) + + if pv.Kind() != reflect.Ptr || pv.Elem().Kind() != reflect.Struct { + return fmt.Errorf("not a pointer of struct") + } + + var f reflect.StructField + var ft reflect.Type + var fv reflect.Value + + for i := 0; i < pt.Elem().NumField(); i++ { + f = pt.Elem().Field(i) + fv = pv.Elem().Field(i) + ft = f.Type + + if f.Anonymous || !fv.CanSet() { + continue + } + + tag := f.Tag.Get(structTag) + + name, option := parseTag(tag) + + if name == "-" { + continue + } + + if name == "" { + name = strings.ToLower(f.Name) + } + val, ok := vals[name] + + if !ok { + if option == "required" { + return fmt.Errorf("'%v' not found", name) + } + if len(option) != 0 { + val = option // default value + } else { + //fv.Set(reflect.Zero(ft)) // TODO set zero value or just ignore it? + continue + } + } + + // convert or set value to field + vv := reflect.ValueOf(val) + vt := reflect.TypeOf(val) + + if vt.Kind() != reflect.String { + // try to assign and convert + if vt.AssignableTo(ft) { + fv.Set(vv) + continue + } + + if vt.ConvertibleTo(ft) { + fv.Set(vv.Convert(ft)) + continue + } + + return fmt.Errorf("value type not match: field=%v(%v) value=%v(%v)", f.Name, ft.Kind(), val, vt.Kind()) + } + s := strings.TrimSpace(vv.String()) + if len(s) == 0 && option == "required" { + return fmt.Errorf("value of required argument can't not be empty") + } + fk := ft.Kind() + + // convert string to value + if fk == reflect.Ptr && ft.Elem().Kind() == reflect.String { + fv.Set(reflect.ValueOf(&s)) + continue + } + if fk == reflect.Ptr || fk == reflect.Struct { + err = convertJsonValue(s, name, fv) + } else if fk == reflect.Slice { + err = convertSlice(s, f.Name, ft, fv) + } else { + err = convertValue(fk, s, f.Name, fv) + } + + if err != nil { + return err + } + continue + } + + return nil +} + +func Struct2Map(s interface{}) map[string]interface{} { + return Struct2MapByTag(s, "json") +} +func Struct2MapByTag(s interface{}, tagName string) map[string]interface{} { + t := reflect.TypeOf(s) + v := reflect.ValueOf(s) + + if v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct { + t = t.Elem() + v = v.Elem() + } + + if v.Kind() != reflect.Struct { + return nil + } + + m := make(map[string]interface{}) + + for i := 0; i < t.NumField(); i++ { + fv := v.Field(i) + ft := t.Field(i) + + if !fv.CanInterface() { + continue + } + + if ft.PkgPath != "" { // unexported + continue + } + + var name string + var option string + tag := ft.Tag.Get(tagName) + if tag != "" { + ts := strings.Split(tag, ",") + if len(ts) == 1 { + name = ts[0] + } else if len(ts) > 1 { + name = ts[0] + option = ts[1] + } + if name == "-" { + continue // skip this field + } + if name == "" { + name = strings.ToLower(ft.Name) + } + if option == "omitempty" { + if isEmpty(&fv) { + continue // skip empty field + } + } + } else { + name = strings.ToLower(ft.Name) + } + + if ft.Anonymous && fv.Kind() == reflect.Ptr && fv.IsNil() { + continue + } + if (ft.Anonymous && fv.Kind() == reflect.Struct) || + (ft.Anonymous && fv.Kind() == reflect.Ptr && fv.Elem().Kind() == reflect.Struct) { + + // embedded struct + embedded := Struct2MapByTag(fv.Interface(), tagName) + for embName, embValue := range embedded { + m[embName] = embValue + } + } else if option == "string" { + kind := fv.Kind() + if kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 || kind == reflect.Int32 || kind == reflect.Int64 { + m[name] = strconv.FormatInt(fv.Int(), 10) + } else if kind == reflect.Uint || kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 || kind == reflect.Uint64 { + m[name] = strconv.FormatUint(fv.Uint(), 10) + } else if kind == reflect.Float32 || kind == reflect.Float64 { + m[name] = strconv.FormatFloat(fv.Float(), 'f', 2, 64) + } else { + m[name] = fv.Interface() + } + } else { + m[name] = fv.Interface() + } + } + + return m +} + +func isEmpty(v *reflect.Value) bool { + k := v.Kind() + if k == reflect.Bool { + return v.Bool() == false + } else if reflect.Int < k && k < reflect.Int64 { + return v.Int() == 0 + } else if reflect.Uint < k && k < reflect.Uintptr { + return v.Uint() == 0 + } else if k == reflect.Float32 || k == reflect.Float64 { + return v.Float() == 0 + } else if k == reflect.Array || k == reflect.Map || k == reflect.Slice || k == reflect.String { + return v.Len() == 0 + } else if k == reflect.Interface || k == reflect.Ptr { + return v.IsNil() + } + return false +} + +func convertSlice(s string, name string, ft reflect.Type, fv reflect.Value) error { + var err error + et := ft.Elem() + + if et.Kind() == reflect.Ptr || et.Kind() == reflect.Struct { + return convertJsonValue(s, name, fv) + } + + ss := strings.Split(s, ",") + + if len(s) == 0 || len(ss) == 0 { + return nil + } + + fs := reflect.MakeSlice(ft, 0, len(ss)) + + for _, si := range ss { + ev := reflect.New(et).Elem() + + err = convertValue(et.Kind(), si, name, ev) + if err != nil { + return err + } + fs = reflect.Append(fs, ev) + } + + fv.Set(fs) + + return nil +} + +func convertJsonValue(s string, name string, fv reflect.Value) error { + var err error + d := StringToSlice(s) + + if fv.Kind() == reflect.Ptr { + if fv.IsNil() { + fv.Set(reflect.New(fv.Type().Elem())) + } + } else { + fv = fv.Addr() + } + + err = json.Unmarshal(d, fv.Interface()) + + if err != nil { + return fmt.Errorf("invalid json '%v': %v, %v", name, err.Error(), s) + } + + return nil +} + +func convertValue(kind reflect.Kind, s string, name string, fv reflect.Value) error { + if !fv.CanAddr() { + return fmt.Errorf("can not addr: %v", name) + } + + if kind == reflect.String { + fv.SetString(s) + return nil + } + + if kind == reflect.Bool { + switch s { + case "true": + fv.SetBool(true) + case "false": + fv.SetBool(false) + case "1": + fv.SetBool(true) + case "0": + fv.SetBool(false) + default: + return fmt.Errorf("invalid bool: %v value=%v", name, s) + } + return nil + } + + if reflect.Int <= kind && kind <= reflect.Int64 { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return fmt.Errorf("invalid int: %v value=%v", name, s) + } + fv.SetInt(i) + + } else if reflect.Uint <= kind && kind <= reflect.Uint64 { + i, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return fmt.Errorf("invalid int: %v value=%v", name, s) + } + fv.SetUint(i) + + } else if reflect.Float32 == kind || kind == reflect.Float64 { + i, err := strconv.ParseFloat(s, 64) + + if err != nil { + return fmt.Errorf("invalid float: %v value=%v", name, s) + } + + fv.SetFloat(i) + } else { + // not support or just ignore it? + // return fmt.Errorf("type not support: field=%v(%v) value=%v(%v)", name, ft.Kind(), val, vt.Kind()) + } + return nil +} + +func parseTag(tag string) (string, string) { + tags := strings.Split(tag, ",") + + if len(tags) <= 0 { + return "", "" + } + + if len(tags) == 1 { + return tags[0], "" + } + + return tags[0], tags[1] +} diff --git a/app/utils/md5.go b/app/utils/md5.go new file mode 100644 index 0000000..52c108d --- /dev/null +++ b/app/utils/md5.go @@ -0,0 +1,12 @@ +package utils + +import ( + "crypto/md5" + "encoding/hex" +) + +func Md5(str string) string { + h := md5.New() + h.Write([]byte(str)) + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/app/utils/qrcode/decodeFile.go b/app/utils/qrcode/decodeFile.go new file mode 100644 index 0000000..f50fb28 --- /dev/null +++ b/app/utils/qrcode/decodeFile.go @@ -0,0 +1,33 @@ +package qrcode + +import ( + "image" + _ "image/jpeg" + _ "image/png" + "os" + + "github.com/makiuchi-d/gozxing" + "github.com/makiuchi-d/gozxing/qrcode" +) + +func DecodeFile(fi string) (string, error) { + file, err := os.Open(fi) + if err != nil { + return "", err + } + img, _, err := image.Decode(file) + if err != nil { + return "", err + } + // prepare BinaryBitmap + bmp, err := gozxing.NewBinaryBitmapFromImage(img) + if err != nil { + return "", err + } + // decode image + result, err := qrcode.NewQRCodeReader().Decode(bmp, nil) + if err != nil { + return "", err + } + return result.String(), nil +} diff --git a/app/utils/qrcode/getBase64.go b/app/utils/qrcode/getBase64.go new file mode 100644 index 0000000..11d149c --- /dev/null +++ b/app/utils/qrcode/getBase64.go @@ -0,0 +1,43 @@ +package qrcode + +// 生成登录二维码图片, 方便在网页上显示 + +import ( + "bytes" + "encoding/base64" + "image/jpeg" + "image/png" + + "github.com/boombuler/barcode" + "github.com/boombuler/barcode/qr" +) + +func GetJPGBase64(content string, edges ...int) string { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + + emptyBuff := bytes.NewBuffer(nil) // 开辟一个新的空buff缓冲区 + jpeg.Encode(emptyBuff, img, nil) + dist := make([]byte, 50000) // 开辟存储空间 + base64.StdEncoding.Encode(dist, emptyBuff.Bytes()) // buff转成base64 + return "data:image/png;base64," + string(dist) // 输出图片base64(type = []byte) +} + +func GetPNGBase64(content string, edges ...int) string { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + + emptyBuff := bytes.NewBuffer(nil) // 开辟一个新的空buff缓冲区 + png.Encode(emptyBuff, img) + dist := make([]byte, 50000) // 开辟存储空间 + base64.StdEncoding.Encode(dist, emptyBuff.Bytes()) // buff转成base64 + return string(dist) // 输出图片base64(type = []byte) +} diff --git a/app/utils/qrcode/saveFile.go b/app/utils/qrcode/saveFile.go new file mode 100644 index 0000000..4854783 --- /dev/null +++ b/app/utils/qrcode/saveFile.go @@ -0,0 +1,85 @@ +package qrcode + +// 生成登录二维码图片 + +import ( + "errors" + "image" + "image/jpeg" + "image/png" + "os" + "path/filepath" + "strings" + + "github.com/boombuler/barcode" + "github.com/boombuler/barcode/qr" +) + +func SaveJpegFile(filePath, content string, edges ...int) error { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + + return writeFile(filePath, img, "jpg") +} + +func SavePngFile(filePath, content string, edges ...int) error { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + + return writeFile(filePath, img, "png") +} + +func writeFile(filePath string, img image.Image, format string) error { + if err := createDir(filePath); err != nil { + return err + } + file, err := os.Create(filePath) + defer file.Close() + if err != nil { + return err + } + switch strings.ToLower(format) { + case "png": + err = png.Encode(file, img) + break + case "jpg": + err = jpeg.Encode(file, img, nil) + default: + return errors.New("format not accept") + } + if err != nil { + return err + } + return nil +} + +func createDir(filePath string) error { + var err error + // filePath, _ = filepath.Abs(filePath) + dirPath := filepath.Dir(filePath) + dirInfo, err := os.Stat(dirPath) + if err != nil { + if !os.IsExist(err) { + err = os.MkdirAll(dirPath, 0777) + if err != nil { + return err + } + } else { + return err + } + } else { + if dirInfo.IsDir() { + return nil + } + return errors.New("directory is a file") + } + return nil +} diff --git a/app/utils/qrcode/writeWeb.go b/app/utils/qrcode/writeWeb.go new file mode 100644 index 0000000..57e1e92 --- /dev/null +++ b/app/utils/qrcode/writeWeb.go @@ -0,0 +1,39 @@ +package qrcode + +import ( + "bytes" + "image/jpeg" + "image/png" + "net/http" + + "github.com/boombuler/barcode" + "github.com/boombuler/barcode/qr" +) + +func WritePng(w http.ResponseWriter, content string, edges ...int) error { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + buff := bytes.NewBuffer(nil) + png.Encode(buff, img) + w.Header().Set("Content-Type", "image/png") + _, err := w.Write(buff.Bytes()) + return err +} + +func WriteJpg(w http.ResponseWriter, content string, edges ...int) error { + edgeLen := 300 + if len(edges) > 0 && edges[0] > 100 && edges[0] < 2000 { + edgeLen = edges[0] + } + img, _ := qr.Encode(content, qr.L, qr.Unicode) + img, _ = barcode.Scale(img, edgeLen, edgeLen) + buff := bytes.NewBuffer(nil) + jpeg.Encode(buff, img, nil) + w.Header().Set("Content-Type", "image/jpg") + _, err := w.Write(buff.Bytes()) + return err +} diff --git a/app/utils/rand.go b/app/utils/rand.go new file mode 100644 index 0000000..0024fd0 --- /dev/null +++ b/app/utils/rand.go @@ -0,0 +1,31 @@ +package utils + +import ( + crand "crypto/rand" + "fmt" + "math/big" + "math/rand" + "time" +) + +func RandString(l int, c ...string) string { + var ( + chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + str string + num *big.Int + ) + if len(c) > 0 { + chars = c[0] + } + chrLen := int64(len(chars)) + for len(str) < l { + num, _ = crand.Int(crand.Reader, big.NewInt(chrLen)) + str += string(chars[num.Int64()]) + } + return str +} + +func RandNum() string { + seed := time.Now().UnixNano() + rand.Int63() + return fmt.Sprintf("%05v", rand.New(rand.NewSource(seed)).Int31n(1000000)) +} diff --git a/app/utils/rsa.go b/app/utils/rsa.go new file mode 100644 index 0000000..fb8274a --- /dev/null +++ b/app/utils/rsa.go @@ -0,0 +1,170 @@ +package utils + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io/ioutil" + "log" + "os" +) + +// 生成私钥文件 TODO 未指定路径 +func RsaKeyGen(bits int) error { + privateKey, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return err + } + derStream := x509.MarshalPKCS1PrivateKey(privateKey) + block := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: derStream, + } + priFile, err := os.Create("private.pem") + if err != nil { + return err + } + err = pem.Encode(priFile, block) + priFile.Close() + if err != nil { + return err + } + // 生成公钥文件 + publicKey := &privateKey.PublicKey + derPkix, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return err + } + block = &pem.Block{ + Type: "PUBLIC KEY", + Bytes: derPkix, + } + pubFile, err := os.Create("public.pem") + if err != nil { + return err + } + err = pem.Encode(pubFile, block) + pubFile.Close() + if err != nil { + return err + } + return nil +} + +// 生成私钥文件, 返回 privateKey , publicKey, error +func RsaKeyGenText(bits int) (string, string, error) { // bits 字节位 1024/2048 + privateKey, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return "", "", err + } + derStream := x509.MarshalPKCS1PrivateKey(privateKey) + block := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: derStream, + } + priBuff := bytes.NewBuffer(nil) + err = pem.Encode(priBuff, block) + if err != nil { + return "", "", err + } + // 生成公钥文件 + publicKey := &privateKey.PublicKey + derPkix, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return "", "", err + } + block = &pem.Block{ + Type: "PUBLIC KEY", + Bytes: derPkix, + } + pubBuff := bytes.NewBuffer(nil) + err = pem.Encode(pubBuff, block) + if err != nil { + return "", "", err + } + return priBuff.String(), pubBuff.String(), nil +} + +// 加密 +func RsaEncrypt(rawData, publicKey []byte) ([]byte, error) { + block, _ := pem.Decode(publicKey) + if block == nil { + return nil, errors.New("public key error") + } + pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, err + } + pub := pubInterface.(*rsa.PublicKey) + return rsa.EncryptPKCS1v15(rand.Reader, pub, rawData) +} + +// 公钥加密 +func RsaEncrypts(data, keyBytes []byte) []byte { + //解密pem格式的公钥 + block, _ := pem.Decode(keyBytes) + if block == nil { + panic(errors.New("public key error")) + } + // 解析公钥 + pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + panic(err) + } + // 类型断言 + pub := pubInterface.(*rsa.PublicKey) + //加密 + ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, pub, data) + if err != nil { + panic(err) + } + return ciphertext +} + +// 解密 +func RsaDecrypt(cipherText, privateKey []byte) ([]byte, error) { + block, _ := pem.Decode(privateKey) + if block == nil { + return nil, errors.New("private key error") + } + priv, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + return rsa.DecryptPKCS1v15(rand.Reader, priv, cipherText) +} + +// 从证书获取公钥 +func OpensslPemGetPublic(pathOrString string) (interface{}, error) { + var certPem []byte + var err error + if IsFile(pathOrString) && Exists(pathOrString) { + certPem, err = ioutil.ReadFile(pathOrString) + if err != nil { + return nil, err + } + if string(certPem) == "" { + return nil, errors.New("empty pem file") + } + } else { + if pathOrString == "" { + return nil, errors.New("empty pem string") + } + certPem = StringToSlice(pathOrString) + } + block, rest := pem.Decode(certPem) + if block == nil || block.Type != "PUBLIC KEY" { + //log.Fatal("failed to decode PEM block containing public key") + return nil, errors.New("failed to decode PEM block containing public key") + } + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Got a %T, with remaining data: %q", pub, rest) + return pub, nil +} diff --git a/app/utils/serialize.go b/app/utils/serialize.go new file mode 100644 index 0000000..1ac4d80 --- /dev/null +++ b/app/utils/serialize.go @@ -0,0 +1,23 @@ +package utils + +import ( + "encoding/json" +) + +func Serialize(data interface{}) []byte { + res, err := json.Marshal(data) + if err != nil { + return []byte{} + } + return res +} + +func Unserialize(b []byte, dst interface{}) { + if err := json.Unmarshal(b, dst); err != nil { + dst = nil + } +} + +func SerializeStr(data interface{}, arg ...interface{}) string { + return string(Serialize(data)) +} diff --git a/app/utils/shuffle.go b/app/utils/shuffle.go new file mode 100644 index 0000000..2c845a8 --- /dev/null +++ b/app/utils/shuffle.go @@ -0,0 +1,48 @@ +package utils + +import ( + "math/rand" + "time" +) + +// 打乱随机字符串 +func ShuffleString(s *string) { + if len(*s) > 1 { + b := []byte(*s) + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(b), func(x, y int) { + b[x], b[y] = b[y], b[x] + }) + *s = string(b) + } +} + +// 打乱随机slice +func ShuffleSliceBytes(b []byte) { + if len(b) > 1 { + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(b), func(x, y int) { + b[x], b[y] = b[y], b[x] + }) + } +} + +// 打乱slice int +func ShuffleSliceInt(i []int) { + if len(i) > 1 { + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(i), func(x, y int) { + i[x], i[y] = i[y], i[x] + }) + } +} + +// 打乱slice interface +func ShuffleSliceInterface(i []interface{}) { + if len(i) > 1 { + rand.Seed(time.Now().UnixNano()) + rand.Shuffle(len(i), func(x, y int) { + i[x], i[y] = i[y], i[x] + }) + } +} diff --git a/app/utils/sign_check.go b/app/utils/sign_check.go new file mode 100644 index 0000000..798f63d --- /dev/null +++ b/app/utils/sign_check.go @@ -0,0 +1,125 @@ +package utils + +import ( + "applet/app/utils/logx" + "fmt" + "github.com/forgoer/openssl" + "github.com/gin-gonic/gin" + "github.com/syyongx/php2go" + "strings" +) + +var publicKey = []byte(`-----BEGIN PUBLIC KEY----- +MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCFQD7RL2tDNuwdg0jTfV0zjAzh +WoCWfGrcNiucy2XUHZZU2oGhHv1N10qu3XayTDD4pu4sJ73biKwqR6ZN7IS4Sfon +vrzaXGvrTG4kmdo3XrbrkzmyBHDLTsJvv6pyS2HPl9QPSvKDN0iJ66+KN8QjBpw1 +FNIGe7xbDaJPY733/QIDAQAB +-----END PUBLIC KEY-----`) + +var privateKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQCFQD7RL2tDNuwdg0jTfV0zjAzhWoCWfGrcNiucy2XUHZZU2oGh +Hv1N10qu3XayTDD4pu4sJ73biKwqR6ZN7IS4SfonvrzaXGvrTG4kmdo3Xrbrkzmy +BHDLTsJvv6pyS2HPl9QPSvKDN0iJ66+KN8QjBpw1FNIGe7xbDaJPY733/QIDAQAB +AoGADi14wY8XDY7Bbp5yWDZFfV+QW0Xi2qAgSo/k8gjeK8R+I0cgdcEzWF3oz1Q2 +9d+PclVokAAmfj47e0AmXLImqMCSEzi1jDBUFIRoJk9WE1YstE94mrCgV0FW+N/u ++L6OgZcjmF+9dHKprnpaUGQuUV5fF8j0qp8S2Jfs3Sw+dOECQQCQnHALzFjmXXIR +Ez3VSK4ZoYgDIrrpzNst5Hh6AMDNZcG3CrCxlQrgqjgTzBSr3ZSavvkfYRj42STk +TqyX1tQFAkEA6+O6UENoUTk2lG7iO/ta7cdIULnkTGwQqvkgLIUjk6w8E3sBTIfw +rerTEmquw5F42HHE+FMrRat06ZN57lENmQJAYgUHlZevcoZIePZ35Qfcqpbo4Gc8 +Fpm6vwKr/tZf2Vlt0qo2VkhWFS6L0C92m4AX6EQmDHT+Pj7BWNdS+aCuGQJBAOkq +NKPZvWdr8jNOV3mKvxqB/U0uMigIOYGGtvLKt5vkh42J7ILFbHW8w95UbWMKjDUG +X/hF3WQEUo//Imsa2yECQHSZIpJxiTRueoDiyRt0LH+jdbYFUu/6D0UIYXhFvP/p +EZX+hfCfUnNYX59UVpRjSZ66g0CbCjuBPOhmOD+hDeQ= +-----END RSA PRIVATE KEY-----`) + +func GetApiVersion(c *gin.Context) int { + var apiVersion = c.GetHeader("apiVersion") + if StrToInt(apiVersion) == 0 { //没有版本号先不校验 + apiVersion = c.GetHeader("Apiversion") + } + if StrToInt(apiVersion) == 0 { //没有版本号先不校验 + apiVersion = c.GetHeader("api_version") + } + return StrToInt(apiVersion) +} + +//签名校验 +func SignCheck(c *gin.Context) bool { + var apiVersion = GetApiVersion(c) + if apiVersion == 0 { //没有版本号先不校验 + return true + } + //1.通过rsa 解析出 aes + var key = c.GetHeader("key") + + //拼接对应参数 + var uri = c.Request.RequestURI + var query = GetQueryParam(uri) + fmt.Println(query) + query["timestamp"] = c.GetHeader("timestamp") + query["nonce"] = c.GetHeader("nonce") + query["key"] = key + token := c.GetHeader("Authorization") + if token != "" { + // 按空格分割 + parts := strings.SplitN(token, " ", 2) + if len(parts) == 2 && parts[0] == "Bearer" { + token = parts[1] + } + } + query["token"] = token + //2.query参数按照 ASCII 码从小到大排序 + str := JoinStringsInASCII(query, "&", false, false, "") + //3.拼上密钥 + secret := "" + if InArr(c.GetHeader("platform"), []string{"android", "ios"}) { + secret = c.GetString("app_api_secret_key") + } else if c.GetHeader("platform") == "wap" { + secret = c.GetString("h5_api_secret_key") + } else { + secret = c.GetString("applet_api_secret_key") + } + str = fmt.Sprintf("%s&secret=%s", str, secret) + fmt.Println(str) + //4.md5加密 转小写 + sign := strings.ToLower(Md5(str)) + //5.判断跟前端传来的sign是否一致 + if sign != c.GetHeader("sign") { + return false + } + return true +} + +func ResultAes(c *gin.Context, raw []byte) string { + var key = c.GetHeader("key") + base, _ := php2go.Base64Decode(key) + aes, err := RsaDecrypt([]byte(base), privateKey) + if err != nil { + logx.Info(err) + return "" + } + + str, _ := openssl.AesECBEncrypt(raw, aes, openssl.PKCS7_PADDING) + value := php2go.Base64Encode(string(str)) + fmt.Println(value) + + return value +} + +func ResultAesDecrypt(c *gin.Context, raw string) string { + var key = c.GetHeader("key") + base, _ := php2go.Base64Decode(key) + aes, err := RsaDecrypt([]byte(base), privateKey) + if err != nil { + logx.Info(err) + return "" + } + fmt.Println(raw) + value1, _ := php2go.Base64Decode(raw) + if value1 == "" { + return "" + } + str1, _ := openssl.AesECBDecrypt([]byte(value1), aes, openssl.PKCS7_PADDING) + + return string(str1) +} diff --git a/app/utils/slice.go b/app/utils/slice.go new file mode 100644 index 0000000..30bd4ee --- /dev/null +++ b/app/utils/slice.go @@ -0,0 +1,26 @@ +package utils + +// ContainsString is 字符串是否包含在字符串切片里 +func ContainsString(array []string, val string) (index int) { + index = -1 + for i := 0; i < len(array); i++ { + if array[i] == val { + index = i + return + } + } + return +} + +func PaginateSliceInt64(x []int64, skip int, size int) []int64 { + if skip > len(x) { + skip = len(x) + } + + end := skip + size + if end > len(x) { + end = len(x) + } + + return x[skip:end] +} diff --git a/app/utils/slice_and_string.go b/app/utils/slice_and_string.go new file mode 100644 index 0000000..3ae6946 --- /dev/null +++ b/app/utils/slice_and_string.go @@ -0,0 +1,47 @@ +package utils + +import ( + "fmt" + "reflect" + "strings" + "unsafe" +) + +// string与slice互转,零copy省内存 + +// zero copy to change slice to string +func Slice2String(b []byte) (s string) { + pBytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + pString := (*reflect.StringHeader)(unsafe.Pointer(&s)) + pString.Data = pBytes.Data + pString.Len = pBytes.Len + return +} + +// no copy to change string to slice +func StringToSlice(s string) (b []byte) { + pBytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + pString := (*reflect.StringHeader)(unsafe.Pointer(&s)) + pBytes.Data = pString.Data + pBytes.Len = pString.Len + pBytes.Cap = pString.Len + return +} + +// 任意slice合并 +func SliceJoin(sep string, elems ...interface{}) string { + l := len(elems) + if l == 0 { + return "" + } + if l == 1 { + s := fmt.Sprint(elems[0]) + sLen := len(s) - 1 + if s[0] == '[' && s[sLen] == ']' { + return strings.Replace(s[1:sLen], " ", sep, -1) + } + return s + } + sep = strings.Replace(fmt.Sprint(elems), " ", sep, -1) + return sep[1 : len(sep)-1] +} diff --git a/app/utils/string.go b/app/utils/string.go new file mode 100644 index 0000000..e7142ef --- /dev/null +++ b/app/utils/string.go @@ -0,0 +1,155 @@ +package utils + +import ( + "fmt" + "github.com/syyongx/php2go" + "reflect" + "sort" + "strings" +) + +func Implode(glue string, args ...interface{}) string { + data := make([]string, len(args)) + for i, s := range args { + data[i] = fmt.Sprint(s) + } + return strings.Join(data, glue) +} + +//字符串是否在数组里 +func InArr(target string, str_array []string) bool { + for _, element := range str_array { + if target == element { + return true + } + } + return false +} + +//把数组的值放到key里 +func ArrayColumn(array interface{}, key string) (result map[string]interface{}, err error) { + result = make(map[string]interface{}) + t := reflect.TypeOf(array) + v := reflect.ValueOf(array) + if t.Kind() != reflect.Slice { + return nil, nil + } + if v.Len() == 0 { + return nil, nil + } + for i := 0; i < v.Len(); i++ { + indexv := v.Index(i) + if indexv.Type().Kind() != reflect.Struct { + return nil, nil + } + mapKeyInterface := indexv.FieldByName(key) + if mapKeyInterface.Kind() == reflect.Invalid { + return nil, nil + } + mapKeyString, err := InterfaceToString(mapKeyInterface.Interface()) + if err != nil { + return nil, err + } + result[mapKeyString] = indexv.Interface() + } + return result, err +} + +//转string +func InterfaceToString(v interface{}) (result string, err error) { + switch reflect.TypeOf(v).Kind() { + case reflect.Int64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: + result = fmt.Sprintf("%v", v) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + result = fmt.Sprintf("%v", v) + case reflect.String: + result = v.(string) + default: + err = nil + } + return result, err +} + +func HideTrueName(name string) string { + res := "**" + if name != "" { + runs := []rune(name) + leng := len(runs) + if leng <= 3 { + res = string(runs[0:1]) + res + } else if leng < 5 { + res = string(runs[0:2]) + res + } else if leng < 10 { + res = string(runs[0:2]) + "***" + string(runs[leng-2:leng]) + } else if leng < 16 { + res = string(runs[0:3]) + "****" + string(runs[leng-3:leng]) + } else { + res = string(runs[0:4]) + "*****" + string(runs[leng-4:leng]) + } + } + return res +} +func GetQueryParam(uri string) map[string]string { + //根据问号分割路由还是query参数 + uriList := strings.Split(uri, "?") + var query = make(map[string]string, 0) + //有参数才处理 + if len(uriList) == 2 { + //分割query参数 + var queryList = strings.Split(uriList[1], "&") + if len(queryList) > 0 { + //key value 分别赋值 + for _, v := range queryList { + var valueList = strings.Split(v, "=") + if len(valueList) == 2 { + value, _ := php2go.URLDecode(valueList[1]) + if value == "" { + value = valueList[1] + } + query[valueList[0]] = value + } + } + } + } + return query +} + +//JoinStringsInASCII 按照规则,参数名ASCII码从小到大排序后拼接 +//data 待拼接的数据 +//sep 连接符 +//onlyValues 是否只包含参数值,true则不包含参数名,否则参数名和参数值均有 +//includeEmpty 是否包含空值,true则包含空值,否则不包含,注意此参数不影响参数名的存在 +//exceptKeys 被排除的参数名,不参与排序及拼接 +func JoinStringsInASCII(data map[string]string, sep string, onlyValues, includeEmpty bool, exceptKeys ...string) string { + var list []string + var keyList []string + m := make(map[string]int) + if len(exceptKeys) > 0 { + for _, except := range exceptKeys { + m[except] = 1 + } + } + for k := range data { + if _, ok := m[k]; ok { + continue + } + value := data[k] + if !includeEmpty && value == "" { + continue + } + if onlyValues { + keyList = append(keyList, k) + } else { + list = append(list, fmt.Sprintf("%s=%s", k, value)) + } + } + if onlyValues { + sort.Strings(keyList) + for _, v := range keyList { + list = append(list, AnyToString(data[v])) + } + } else { + sort.Strings(list) + } + return strings.Join(list, sep) +} diff --git a/app/utils/time.go b/app/utils/time.go new file mode 100644 index 0000000..6860a57 --- /dev/null +++ b/app/utils/time.go @@ -0,0 +1,226 @@ +package utils + +import ( + "errors" + "fmt" + "strconv" + "strings" + "time" +) + +func StrToTime(s string) (int64, error) { + // delete all not int characters + if s == "" { + return time.Now().Unix(), nil + } + r := make([]rune, 14) + l := 0 + // 过滤除数字以外的字符 + for _, v := range s { + if '0' <= v && v <= '9' { + r[l] = v + l++ + if l == 14 { + break + } + } + } + for l < 14 { + r[l] = '0' // 补0 + l++ + } + t, err := time.Parse("20060102150405", string(r)) + if err != nil { + return 0, err + } + return t.Unix(), nil +} + +func TimeToStr(unixSecTime interface{}, layout ...string) string { + i := AnyToInt64(unixSecTime) + if i == 0 { + return "" + } + f := "2006-01-02 15:04:05" + if len(layout) > 0 { + f = layout[0] + } + return time.Unix(i, 0).Format(f) +} +func Time2String(date time.Time, format string) string { + if format == "" { + format = "2006-01-02 15:04:05" + } + timeS := date.Format(format) + if timeS == "0001-01-01 00:00:00" { + return "" + } + return timeS +} + +func FormatNanoUnix() string { + return strings.Replace(time.Now().Format("20060102150405.0000000"), ".", "", 1) +} + +func TimeParse(format, src string) (time.Time, error) { + return time.ParseInLocation(format, src, time.Local) +} + +func TimeParseStd(src string) time.Time { + t, _ := TimeParse("2006-01-02 15:04:05", src) + return t +} + +func TimeStdParseUnix(src string) int64 { + t, err := TimeParse("2006-01-02 15:04:05", src) + if err != nil { + return 0 + } + return t.Unix() +} + +// 获取一个当前时间 时间间隔 时间戳 +func GetTimeInterval(unit string, amount int) (startTime, endTime int64) { + t := time.Now() + nowTime := t.Unix() + tmpTime := int64(0) + switch unit { + case "years": + tmpTime = time.Date(t.Year()+amount, t.Month(), t.Day(), t.Hour(), 0, 0, 0, t.Location()).Unix() + case "months": + tmpTime = time.Date(t.Year(), t.Month()+time.Month(amount), t.Day(), t.Hour(), 0, 0, 0, t.Location()).Unix() + case "days": + tmpTime = time.Date(t.Year(), t.Month(), t.Day()+amount, t.Hour(), 0, 0, 0, t.Location()).Unix() + case "hours": + tmpTime = time.Date(t.Year(), t.Month(), t.Day(), t.Hour()+amount, 0, 0, 0, t.Location()).Unix() + } + if amount > 0 { + startTime = nowTime + endTime = tmpTime + } else { + startTime = tmpTime + endTime = nowTime + } + return +} + +// 几天前 +func TimeInterval(newTime int) string { + now := time.Now().Unix() + newTime64 := AnyToInt64(newTime) + if newTime64 >= now { + return "刚刚" + } + interval := now - newTime64 + switch { + case interval < 60: + return AnyToString(interval) + "秒前" + case interval < 60*60: + return AnyToString(interval/60) + "分前" + case interval < 60*60*24: + return AnyToString(interval/60/60) + "小时前" + case interval < 60*60*24*30: + return AnyToString(interval/60/60/24) + "天前" + case interval < 60*60*24*30*12: + return AnyToString(interval/60/60/24/30) + "月前" + default: + return AnyToString(interval/60/60/24/30/12) + "年前" + } +} + +// 时分秒字符串转时间戳,传入示例:8:40 or 8:40:10 +func HmsToUnix(str string) (int64, error) { + t := time.Now() + arr := strings.Split(str, ":") + if len(arr) < 2 { + return 0, errors.New("Time format error") + } + h, _ := strconv.Atoi(arr[0]) + m, _ := strconv.Atoi(arr[1]) + s := 0 + if len(arr) == 3 { + s, _ = strconv.Atoi(arr[3]) + } + formatted1 := fmt.Sprintf("%d%02d%02d%02d%02d%02d", t.Year(), t.Month(), t.Day(), h, m, s) + res, err := time.ParseInLocation("20060102150405", formatted1, time.Local) + if err != nil { + return 0, err + } else { + return res.Unix(), nil + } +} + +// 获取特定时间范围 +func GetTimeRange(s string) map[string]int64 { + t := time.Now() + var stime, etime time.Time + + switch s { + case "today": + stime = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) + etime = time.Date(t.Year(), t.Month(), t.Day()+1, 0, 0, 0, 0, t.Location()) + case "yesterday": + stime = time.Date(t.Year(), t.Month(), t.Day()-1, 0, 0, 0, 0, t.Location()) + etime = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) + case "within_seven_days": + // 前6天0点 + stime = time.Date(t.Year(), t.Month(), t.Day()-6, 0, 0, 0, 0, t.Location()) + // 明天 0点 + etime = time.Date(t.Year(), t.Month(), t.Day()+1, 0, 0, 0, 0, t.Location()) + case "current_month": + stime = time.Date(t.Year(), t.Month(), 0, 0, 0, 0, 0, t.Location()) + etime = time.Date(t.Year(), t.Month()+1, 0, 0, 0, 0, 0, t.Location()) + case "last_month": + stime = time.Date(t.Year(), t.Month()-1, 0, 0, 0, 0, 0, t.Location()) + etime = time.Date(t.Year(), t.Month(), 0, 0, 0, 0, 0, t.Location()) + } + + return map[string]int64{ + "start": stime.Unix(), + "end": etime.Unix(), + } +} + +// 获取特定时间范围 +func GetDateTimeRangeStr(s string) (string, string) { + t := time.Now() + var stime, etime time.Time + + switch s { + case "today": + stime = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) + etime = time.Date(t.Year(), t.Month(), t.Day()+1, 0, 0, 0, 0, t.Location()) + case "yesterday": + stime = time.Date(t.Year(), t.Month(), t.Day()-1, 0, 0, 0, 0, t.Location()) + etime = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) + case "within_seven_days": + // 前6天0点 + stime = time.Date(t.Year(), t.Month(), t.Day()-6, 0, 0, 0, 0, t.Location()) + // 明天 0点 + etime = time.Date(t.Year(), t.Month(), t.Day()+1, 0, 0, 0, 0, t.Location()) + case "current_month": + stime = time.Date(t.Year(), t.Month(), 0, 0, 0, 0, 0, t.Location()) + etime = time.Date(t.Year(), t.Month()+1, 0, 0, 0, 0, 0, t.Location()) + case "last_month": + stime = time.Date(t.Year(), t.Month()-1, 0, 0, 0, 0, 0, t.Location()) + etime = time.Date(t.Year(), t.Month(), 0, 0, 0, 0, 0, t.Location()) + } + + return stime.Format("2006-01-02 15:04:05"), etime.Format("2006-01-02 15:04:05") +} + +//获取传入的时间所在月份的第一天,即某月第一天的0点。如传入time.Now(), 返回当前月份的第一天0点时间。 +func GetFirstDateOfMonth(d time.Time) time.Time { + d = d.AddDate(0, 0, -d.Day()+1) + return GetZeroTime(d) +} + +//获取传入的时间所在月份的最后一天,即某月最后一天的0点。如传入time.Now(), 返回当前月份的最后一天0点时间。 +func GetLastDateOfMonth(d time.Time) time.Time { + return GetFirstDateOfMonth(d).AddDate(0, 1, -1) +} + +//获取某一天的0点时间 +func GetZeroTime(d time.Time) time.Time { + return time.Date(d.Year(), d.Month(), d.Day(), 0, 0, 0, 0, d.Location()) +} diff --git a/app/utils/uuid.go b/app/utils/uuid.go new file mode 100644 index 0000000..da7018b --- /dev/null +++ b/app/utils/uuid.go @@ -0,0 +1,76 @@ +package utils + +import ( + "github.com/sony/sonyflake" + + "applet/app/utils/logx" + "fmt" + "math/rand" + "time" +) + +const ( + KC_RAND_KIND_NUM = 0 // 纯数字 + KC_RAND_KIND_LOWER = 1 // 小写字母 + KC_RAND_KIND_UPPER = 2 // 大写字母 + KC_RAND_KIND_ALL = 3 // 数字、大小写字母 +) + +func newUUID() *[16]byte { + u := &[16]byte{} + rand.Read(u[:16]) + u[8] = (u[8] | 0x80) & 0xBf + u[6] = (u[6] | 0x40) & 0x4f + return u +} + +func UUIDString() string { + u := newUUID() + return fmt.Sprintf("%x-%x-%x-%x-%x", u[:4], u[4:6], u[6:8], u[8:10], u[10:]) +} + +func UUIDHexString() string { + u := newUUID() + return fmt.Sprintf("%x%x%x%x%x", u[:4], u[4:6], u[6:8], u[8:10], u[10:]) +} +func UUIDBinString() string { + u := newUUID() + return fmt.Sprintf("%s", [16]byte(*u)) +} + +func Krand(size int, kind int) []byte { + ikind, kinds, result := kind, [][]int{[]int{10, 48}, []int{26, 97}, []int{26, 65}}, make([]byte, size) + isAll := kind > 2 || kind < 0 + rand.Seed(time.Now().UnixNano()) + for i := 0; i < size; i++ { + if isAll { // random ikind + ikind = rand.Intn(3) + } + scope, base := kinds[ikind][0], kinds[ikind][1] + result[i] = uint8(base + rand.Intn(scope)) + } + return result +} + +// OrderUUID is only num for uuid +func OrderUUID(uid int) string { + ustr := IntToStr(uid) + tstr := Int64ToStr(time.Now().Unix()) + ulen := len(ustr) + tlen := len(tstr) + rlen := 18 - ulen - tlen + krb := Krand(rlen, KC_RAND_KIND_NUM) + return ustr + tstr + string(krb) +} + +var flake *sonyflake.Sonyflake + +func GenId() int64 { + + id, err := flake.NextID() + if err != nil { + _ = logx.Errorf("flake.NextID() failed with %s\n", err) + panic(err) + } + return int64(id) +} diff --git a/app/utils/validator_err_trans.go b/app/utils/validator_err_trans.go new file mode 100644 index 0000000..29d97bf --- /dev/null +++ b/app/utils/validator_err_trans.go @@ -0,0 +1,55 @@ +package utils + +import ( + "fmt" + "github.com/gin-gonic/gin/binding" + "github.com/go-playground/locales/en" + "github.com/go-playground/locales/zh" + ut "github.com/go-playground/universal-translator" + "github.com/go-playground/validator/v10" + enTranslations "github.com/go-playground/validator/v10/translations/en" + chTranslations "github.com/go-playground/validator/v10/translations/zh" + "reflect" +) + +var ValidatorTrans ut.Translator + +// ValidatorTransInit 验证器错误信息翻译初始化 +// local 通常取决于 http 请求头的 'Accept-Language' +func ValidatorTransInit(local string) (err error) { + if v, ok := binding.Validator.Engine().(*validator.Validate); ok { + zhT := zh.New() //chinese + enT := en.New() //english + uni := ut.New(enT, zhT, enT) + + var o bool + ValidatorTrans, o = uni.GetTranslator(local) + if !o { + return fmt.Errorf("uni.GetTranslator(%s) failed", local) + } + // 注册一个方法,从自定义标签label中获取值(用在把字段名映射为中文) + v.RegisterTagNameFunc(func(field reflect.StructField) string { + label := field.Tag.Get("label") + if label == "" { + return field.Name + } + return label + }) + // 注册翻译器 + switch local { + case "en": + err = enTranslations.RegisterDefaultTranslations(v, ValidatorTrans) + case "zh": + err = chTranslations.RegisterDefaultTranslations(v, ValidatorTrans) + default: + err = enTranslations.RegisterDefaultTranslations(v, ValidatorTrans) + } + return + } + return +} + +// ValidatorTransInitZh 验证器错误信息翻译为中文初始化 +func ValidatorTransInitZh() (err error) { + return ValidatorTransInit("zh") +} diff --git a/app/utils/wx.go b/app/utils/wx.go new file mode 100644 index 0000000..6967da5 --- /dev/null +++ b/app/utils/wx.go @@ -0,0 +1,31 @@ +package utils + +import ( + "crypto/sha1" + "encoding/hex" + "sort" + "strings" +) + +// CheckSignature 微信公众号签名检查 +func CheckSignature(signature, timestamp, nonce, token string) bool { + arr := []string{timestamp, nonce, token} + // 字典序排序 + sort.Strings(arr) + + n := len(timestamp) + len(nonce) + len(token) + var b strings.Builder + b.Grow(n) + for i := 0; i < len(arr); i++ { + b.WriteString(arr[i]) + } + + return Sha1(b.String()) == signature +} + +// 进行Sha1编码 +func Sha1(str string) string { + h := sha1.New() + h.Write([]byte(str)) + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/bin/.gitignore b/bin/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/bin/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/build.sh b/build.sh new file mode 100644 index 0000000..adb5edb --- /dev/null +++ b/build.sh @@ -0,0 +1,18 @@ +echo "update -> repo" +git fetch +git reset --hard origin/master +echo "update repo -> Success" + +id="git rev-parse --short HEAD" +export ZYOS_APP_COMMIT_ID=`eval $id` +echo "GET the Commit ID for git -> $ZYOS_APP_COMMIT_ID" + +echo "Start build image " + +image_name=registry-vpc.cn-shenzhen.aliyuncs.com/fnuoos-prd/zyos-mall:${ZYOS_APP_COMMIT_ID} +#final_image_name=registry.cn-shenzhen.aliyuncs.com/fnuoos-prd/zyos:${ZYOS_APP_COMMIT_ID} +docker build -t ${image_name} . + +docker push ${image_name} +echo "Push image -> $image_name Success" +export ZYOS_APP_LATEST_VERSION=${image_name} \ No newline at end of file diff --git a/cmd/task/main.go b/cmd/task/main.go new file mode 100644 index 0000000..ec3c112 --- /dev/null +++ b/cmd/task/main.go @@ -0,0 +1,43 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "syscall" + + "applet/app/cfg" + "applet/app/db" + "applet/app/task" + "applet/app/utils" + "applet/app/utils/logx" +) + +func init() { + // 加载任务配置 + cfg.InitTaskCfg() + // 日志配置 + cfg.InitLog() + // 初始化redis + cfg.InitCache() + baseDb := *cfg.DB + baseDb.Path = fmt.Sprintf(cfg.DB.Path, cfg.DB.Name) + if err := db.InitDB(&baseDb); err != nil { + panic(err) + } + utils.CurlDebug = true + //cfg.InitMemCache() +} + +func main() { + go func() { + // 初始化jobs方法列表、添加reload方法定时更新任务 + task.Init() + task.Run() + }() + // graceful shutdown + quit := make(chan os.Signal) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + _ = logx.Info("Server exiting...") +} diff --git a/cmd_db.bat b/cmd_db.bat new file mode 100644 index 0000000..7a9e8d5 --- /dev/null +++ b/cmd_db.bat @@ -0,0 +1,25 @@ +@echo off + +set Table=* +set TName="" +set one=%1 + +if "%one%" NEQ "" ( + set Table=%one% + set TName="^%one%$" +) + +set BasePath="./" +set DBUSER="root" +set DBPSW="Fnuo123com@" +set DBNAME="fnuoos_test1" +set DBHOST="119.23.182.117" +set DBPORT="3306" + +del "app\db\model\%Table%.go" + +echo start reverse table %Table% + +xorm reverse mysql "%DBUSER%:%DBPSW%@tcp(%DBHOST%:%DBPORT%)/%DBNAME%?charset=utf8" %BasePath%/etc/db_tpl %BasePath%/app/db/model/ %TName% + +echo end \ No newline at end of file diff --git a/cmd_db.sh b/cmd_db.sh new file mode 100644 index 0000000..ab769e2 --- /dev/null +++ b/cmd_db.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# 使用方法, 直接执行该脚本更新所有表, cmd_db.sh 表名, 如 ./cmd_db.sh tableName + +Table=* +TName="" +if [ "$1" ] ;then + Table=$1 + TName="^$1$" +fi + +BasePath="./" +DBUSER="root" +DBPSW="Fnuo123com@" +DBNAME="fnuoos_test1" +DBHOST="119.23.182.117" +DBPORT="3306" + +rm -rf $BasePath/app/db/model/$Table.go && \ + +xorm reverse mysql "$DBUSER:$DBPSW@tcp($DBHOST:$DBPORT)/$DBNAME?charset=utf8" $BasePath/etc/db_tpl $BasePath/app/db/model/ $TName \ No newline at end of file diff --git a/cmd_run.bat b/cmd_run.bat new file mode 100644 index 0000000..51d7b81 --- /dev/null +++ b/cmd_run.bat @@ -0,0 +1,12 @@ +@echo off + +set BasePath=%~dp0 +set APP=applet.exe +set CfgPath=%BasePath%\etc\cfg.yml + +del %BasePath%\bin\%APP% + +go build -o %BasePath%\bin\%APP% %BasePath%\cmd\main.go && %BasePath%\bin\%APP% -c=%CfgPath% + + +pause diff --git a/cmd_run.sh b/cmd_run.sh new file mode 100644 index 0000000..6758f1b --- /dev/null +++ b/cmd_run.sh @@ -0,0 +1,8 @@ +#!/bin/bash +APP=applet +BasePath=$(dirname $(readlink -f $0)) +CfgPath=$BasePath/etc/cfg.yml +cd $BasePath +rm -rf $BasePath/bin/$APP +go build -o $BasePath/bin/$APP $BasePath/main.go \ +&& $BasePath/bin/$APP -c=$CfgPath \ No newline at end of file diff --git a/cmd_task.bat b/cmd_task.bat new file mode 100644 index 0000000..f70eabc --- /dev/null +++ b/cmd_task.bat @@ -0,0 +1,13 @@ +@echo off + +set Name=task +set BasePath=%~dp0 +set APP=%Name%.exe +set CfgPath=%BasePath%etc\%Name%.yml + +del %BasePath%\bin\%APP% + +go build -o %BasePath%\bin\%APP% %BasePath%\cmd\%Name%\main.go && %BasePath%\bin\%APP% -c=%CfgPath% + + +pause diff --git a/cmd_task.sh b/cmd_task.sh new file mode 100644 index 0000000..8f7da10 --- /dev/null +++ b/cmd_task.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +APP=task +BasePath=$(dirname $(readlink -f $0)) +CfgPath=$BasePath/etc/task.yml +cd $BasePath +rm -rf $BasePath/bin/$APP +go build -o $BasePath/bin/$APP $BasePath/cmd/$APP/main.go \ +&& $BasePath/bin/$APP -c=$CfgPath \ No newline at end of file diff --git a/etc/db_tpl/config b/etc/db_tpl/config new file mode 100644 index 0000000..34c75ee --- /dev/null +++ b/etc/db_tpl/config @@ -0,0 +1,7 @@ +lang=go +genJson=1 +prefix=cos_ +ignoreColumnsJSON= +created= +updated= +deleted= \ No newline at end of file diff --git a/etc/db_tpl/struct.go.tpl b/etc/db_tpl/struct.go.tpl new file mode 100644 index 0000000..74b2896 --- /dev/null +++ b/etc/db_tpl/struct.go.tpl @@ -0,0 +1,17 @@ +package {{.Models}} + +{{$ilen := len .Imports}} +{{if gt $ilen 0}} +import ( + {{range .Imports}}"{{.}}"{{end}} +) +{{end}} + +{{range .Tables}} +type {{Mapper .Name}} struct { +{{$table := .}} +{{range .ColumnsSeq}}{{$col := $table.GetColumn .}} {{Mapper $col.Name}} {{Type $col}} {{Tag $table $col}} +{{end}} +} +{{end}} + diff --git a/etc/task.yml b/etc/task.yml new file mode 100644 index 0000000..449483c --- /dev/null +++ b/etc/task.yml @@ -0,0 +1,30 @@ +# debug release test +debug: true +prd: false +local: true +# 缓存 +redis_addr: '120.24.28.6:32572' + +# 数据库 +db: + host: '119.23.182.117:3306' + name: 'zyos_website' + user: 'root' + psw: 'Fnuo123com@' + show_log: true + max_lifetime: 30 + max_open_conns: 100 + max_idle_conns: 100 + path: 'tmp/task_sql_%v.log' + +# 日志 +log: + level: 'debug' # 普通日志级别 #debug, info, warn, fatal, panic + is_stdout: true + time_format: 'standard' # sec, second, milli, nano, standard, iso + encoding: 'console' + is_file_out: true + file_dir: './tmp/' + file_max_size: 256 + file_max_age: 1 + file_name: 'task.log' diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c86563f --- /dev/null +++ b/go.mod @@ -0,0 +1,59 @@ +module applet + +go 1.15 + +require ( + github.com/360EntSecGroup-Skylar/excelize v1.4.1 // indirect + github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5 + github.com/antchfx/htmlquery v1.3.0 // indirect + github.com/antchfx/xmlquery v1.3.16 // indirect + github.com/boombuler/barcode v1.0.1 + github.com/dchest/uniuri v0.0.0-20200228104902-7aecb25e1fe5 + github.com/dgrijalva/jwt-go v3.2.0+incompatible + github.com/forgoer/openssl v0.0.0-20201023062029-c3112b0c8700 + github.com/gin-contrib/sessions v0.0.3 + github.com/gin-gonic/gin v1.6.3 + github.com/go-playground/locales v0.13.0 + github.com/go-playground/universal-translator v0.17.0 + github.com/go-playground/validator/v10 v10.4.2 + github.com/go-redis/redis v6.15.9+incompatible + github.com/go-sql-driver/mysql v1.6.0 + github.com/gobwas/glob v0.2.3 // indirect + github.com/gocolly/colly v1.2.0 + github.com/golang/protobuf v1.5.2 // indirect + github.com/golang/snappy v0.0.3 // indirect + github.com/gomodule/redigo v2.0.0+incompatible + github.com/gorilla/sessions v1.2.1 // indirect + github.com/json-iterator/go v1.1.10 // indirect + github.com/kennygrant/sanitize v1.2.4 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/leodido/go-urn v1.2.1 // indirect + github.com/makiuchi-d/gozxing v0.0.0-20210324052758-57132e828831 + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.1 // indirect + github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect + github.com/onsi/ginkgo v1.15.0 // indirect + github.com/onsi/gomega v1.10.5 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/robfig/cron/v3 v3.0.1 + github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect + github.com/shopspring/decimal v1.3.1 + github.com/smartystreets/goconvey v1.6.4 // indirect + github.com/sony/sonyflake v1.0.0 + github.com/stretchr/testify v1.7.0 // indirect + github.com/syyongx/php2go v0.9.4 + github.com/temoto/robotstxt v1.1.2 // indirect + github.com/tidwall/gjson v1.7.4 + github.com/ugorji/go v1.2.5 // indirect + go.uber.org/multierr v1.6.0 // indirect + go.uber.org/zap v1.16.0 + golang.org/x/lint v0.0.0-20200302205851-738671d3881b // indirect + google.golang.org/appengine v1.6.1 // indirect + gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect + gopkg.in/natefinch/lumberjack.v2 v2.0.0 + gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect + honnef.co/go/tools v0.0.1-2020.1.4 // indirect + xorm.io/builder v0.3.9 // indirect + xorm.io/xorm v1.0.7 +) diff --git a/main.go b/main.go new file mode 100644 index 0000000..7e47f6c --- /dev/null +++ b/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "applet/app/cfg" + "applet/app/db" + "applet/app/router" +) + +//系统初始化 +func init() { + cfg.InitCfg() //配置初始化 + cfg.InitLog() //日志初始化 + cfg.InitCache() //缓存初始化 + if cfg.Debug { //判断是否是debug + if err := db.InitDB(cfg.DB); err != nil { //主数据库初始化 + panic(err) + } + } + fmt.Println("init success") + +} + +func main() { + r := router.Init() //创建路由 + + srv := &http.Server{ //设置http服务参数 + Addr: cfg.SrvAddr, //指定ip和端口 + Handler: r, //指定路由 + } + + go func() { //协程启动监听http服务 + fmt.Println("Listening and serving HTTP on " + cfg.SrvAddr) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("listen: %s\n", err) + } + }() + + //退出go守护进程 + quit := make(chan os.Signal) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + log.Println("Shutting down server...") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + log.Fatal("Server forced to shutdown:", err) + } + log.Println("Server exiting") + +} diff --git a/static/html/LandingPage.html b/static/html/LandingPage.html new file mode 100644 index 0000000..313f321 --- /dev/null +++ b/static/html/LandingPage.html @@ -0,0 +1,179 @@ + + + +
+ +{{.goodTitle}}
+ +
+
+
+
+