|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445 |
- package rabbit
-
- import (
- "encoding/json"
- "errors"
- "fmt"
- "github.com/streadway/amqp"
- "log"
- "os"
- "sync"
- "sync/atomic"
- "time"
- )
-
// defaultLogPrefix is the prefix of the stdout logger NewPool installs when
// the caller does not supply a LoggerInter.
const defaultLogPrefix = "[rabbit-pool]"
-
// Sentinel errors returned by the pool. Compare against them with errors.Is
// (or ==). Per Go convention the messages carry no trailing newline or
// punctuation; loggers such as *log.Logger add the line break themselves.
var (
	// ErrInvalidConfig is returned by NewPool for an unusable Config.
	ErrInvalidConfig = errors.New("invalid pool config")
	// ErrFailedConnection is returned by NewPool when an initial connection fails.
	ErrFailedConnection = errors.New("failed to establish connection")
	// ErrConnectionMaximum is returned when MaxConn connections are already open.
	ErrConnectionMaximum = errors.New("the number of connections exceeds the maximum")
	// ErrChannelMaximum is returned when a connection already has MaxChannelPerConn channels.
	ErrChannelMaximum = errors.New("the number of channels exceeds the maximum")
	// ErrGetChannelTimeOut is returned by GetChannel when no channel became available in time.
	ErrGetChannelTimeOut = errors.New("get channel timeout")
)
-
// LoggerInter is the minimal logging dependency of the pool: anything with a
// variadic Print method, such as *log.Logger, can be used.
type LoggerInter interface {
	// Print logs its operands in the manner of fmt.Print.
	Print(args ...interface{})
}
-
// Config holds the settings a Pool is created with.
type Config struct {
	// Host is the AMQP URL passed to amqp.Dial.
	Host string
	// MinConn is the number of connections NewPool establishes up front.
	MinConn int
	// MaxConn is the maximum number of connections the pool may hold open.
	MaxConn int
	// MaxChannelPerConn is the maximum number of channels per connection.
	MaxChannelPerConn int
	// MaxLifetime bounds how long a connection keeps being used.
	// NOTE(review): NewConnection adds this value to the current Unix time in
	// seconds, so its integer value is effectively a number of seconds (the
	// default NewPool fills in is time.Duration(3600)); a real duration such
	// as time.Hour would effectively never expire. Confirm before relying on it.
	MaxLifetime time.Duration
}
-
- // 连接池
- type Pool struct {
- mu *sync.Mutex
- conf *Config
- logger LoggerInter
- connectionNum int32
- connections map[int64]*Connection
- connectionSerialNumber int64
- idleChannels chan *Channel
- }
-
- func NewPool(conf *Config, logger ...LoggerInter) (*Pool, error) {
- if conf.MaxConn <= 0 || conf.MinConn > conf.MaxConn {
- return nil, ErrInvalidConfig
- }
- p := &Pool{
- mu: new(sync.Mutex),
- connections: make(map[int64]*Connection),
- idleChannels: make(chan *Channel, conf.MaxConn*conf.MaxChannelPerConn),
- }
-
- if conf.MaxLifetime == 0 {
- conf.MaxLifetime = time.Duration(3600)
- }
-
- if len(logger) > 0 {
- p.SetLogger(logger[0])
- } else {
- p.SetLogger(log.New(os.Stdout, defaultLogPrefix, log.LstdFlags))
- }
- p.conf = conf
-
- var conn *Connection
- var err error
- // 建立最少连接数
- for i := 0; i < conf.MinConn; i++ {
- conn, err = p.NewConnection()
- if err != nil {
- p.GetLogger().Print(ErrFailedConnection.Error())
- return nil, ErrFailedConnection
- }
- p.connections[conn.connIdentity] = conn
- }
- return p, nil
- }
-
- func (p *Pool) SetConfig(conf *Config) *Pool {
- p.conf = conf
- return p
- }
-
- func (p *Pool) GetConfig() *Config {
- return p.conf
- }
-
- func (p *Pool) SetLogger(logger LoggerInter) *Pool {
- p.logger = logger
- return p
- }
-
- func (p *Pool) GetLogger() LoggerInter {
- return p.logger
- }
-
- func (p *Pool) NewConnection() (*Connection, error) {
- // 判断连接是否达到最大值
- if atomic.AddInt32(&p.connectionNum, 1) > int32(p.conf.MaxConn) {
- atomic.AddInt32(&p.connectionNum, -1)
- return nil, ErrConnectionMaximum
- }
- conn, err := amqp.Dial(p.conf.Host)
- if err != nil {
- atomic.AddInt32(&p.connectionNum, -1)
- return nil, err
- }
-
- return &Connection{
- mu: new(sync.Mutex),
- conn: conn,
- pool: p,
- channelNum: 0,
- expireTime: time.Duration(time.Now().Unix()) + p.conf.MaxLifetime,
- connIdentity: atomic.AddInt64(&p.connectionSerialNumber, 1),
- }, nil
- }
-
- func (p *Pool) CloseConnection(c *Connection) error {
- p.mu.Lock()
- defer p.mu.Unlock()
- atomic.AddInt32(&p.connectionNum, -1)
- delete(p.connections, c.connIdentity)
- return c.conn.Close()
- }
-
- func (p *Pool) GetChannel() (*Channel, error) {
- ch, _ := p.getOrCreate()
- if ch != nil {
- return ch, nil
- }
-
- C := time.After(time.Second * 10)
- for {
- ch, _ := p.getOrCreate()
- if ch != nil {
- return ch, nil
- }
- select {
- case <-C:
- p.GetLogger().Print(ErrGetChannelTimeOut.Error())
- return nil, ErrGetChannelTimeOut
- default:
- }
- }
- }
-
- func (p *Pool) getOrCreate() (*Channel, error) {
- // 池中是否有空闲channel
- var (
- ch *Channel
- err error
- )
- select {
- case ch = <-p.idleChannels:
- return ch, nil
- default:
- }
-
- p.mu.Lock()
- defer p.mu.Unlock()
- // 池中已有连接是否可以建立新的channel
- for _, conn := range p.connections {
- if conn.CheckExpire() {
- continue
- }
- ch, err = conn.NewChannel()
- if ch != nil {
- return ch, nil
- }
- }
- // 新建连接获取新的channel
- var conn *Connection
- conn, err = p.NewConnection()
- if err != nil {
- return nil, err
- }
- p.connections[conn.connIdentity] = conn
- ch, err = conn.NewChannel()
- if err != nil {
- return nil, err
- }
- return ch, nil
- }
-
- func (p *Pool) ReleaseChannel(ch *Channel) error {
- p.idleChannels <- ch
- return nil
- }
-
- type Connection struct {
- mu *sync.Mutex
- conn *amqp.Connection
- pool *Pool
- expireTime time.Duration
- isExpire bool
- connIdentity int64 // 连接标记
- channelNum int32 // 该连接的信道数量
- channelSerialNumber int64 // 第几个channel
- }
-
- func (c *Connection) NewChannel() (*Channel, error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if atomic.AddInt32(&c.channelNum, 1) > int32(c.pool.conf.MaxChannelPerConn) {
- atomic.AddInt32(&c.channelNum, -1)
- return nil, ErrChannelMaximum
- }
- ch, err := c.conn.Channel()
- if err != nil {
- atomic.AddInt32(&c.channelNum, -1)
- return nil, err
- }
- return &Channel{
- Channel: ch,
- conn: c,
- chanIdentity: atomic.AddInt64(&c.channelSerialNumber, 1),
- }, nil
- }
-
- func (c *Connection) ReleaseChannel(ch *Channel) error {
- if c.CheckExpire() {
- return c.CloseChannel(ch)
- }
- return c.pool.ReleaseChannel(ch)
- }
-
- func (c *Connection) CloseChannel(ch *Channel) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- atomic.AddInt32(&c.channelNum, -1)
- var err = ch.Channel.Close()
- if atomic.LoadInt32(&c.channelNum) <= 0 && c.CheckExpire() {
- return c.pool.CloseConnection(c)
- }
- return err
- }
-
- // 检查是否过期
- func (c *Connection) CheckExpire() bool {
- if c.isExpire {
- return true
- }
- if time.Duration(time.Now().Unix()) > c.expireTime {
- c.isExpire = true
- }
- return c.isExpire
- }
-
- /************************************************************************************************************/
-
- type Channel struct {
- *amqp.Channel
- conn *Connection
- chanIdentity int64 // 该连接的第几个channel
- Name string
- exchange string
- }
-
- func (ch *Channel) Release() error {
- return ch.conn.ReleaseChannel(ch)
- }
-
- func (ch *Channel) Close() error {
- return ch.conn.CloseChannel(ch)
- }
-
- // QueueDeclare 声明交换机
- func (ch *Channel) QueueDeclare(queue string) {
- _, e := ch.Channel.QueueDeclare(queue, false, true, false, false, nil)
- failOnError(e, "声明交换机!")
- }
-
- // QueueDelete 删除交换机
- func (ch *Channel) QueueDelete(queue string) {
- _, e := ch.Channel.QueueDelete(queue, false, true, false)
- failOnError(e, "删除队列失败!")
- }
-
- //初始化队列
- func (ch *Channel) NewQueue(name string, args ...bool) {
- var durable, autoDelete, exclusive, noWait = true, false, false, false
- if len(args) > 0 {
- durable = args[0]
- }
- if len(args) > 1 {
- autoDelete = args[1]
- }
- if len(args) > 2 {
- exclusive = args[2]
- }
- if len(args) > 3 {
- noWait = args[3]
- }
- q, e := ch.Channel.QueueDeclare(
- name, //队列名
- durable, //是否开启持久化
- autoDelete, //不使用时删除
- exclusive, //排他
- noWait, //不等待
- nil, //参数
- )
- failOnError(e, "初始化队列失败!")
- ch.Name = q.Name
- }
-
- // NewExchange 初始化交换机
- //s:rabbitmq服务器的链接,name:交换机名字,typename:交换机类型
- func (ch *Channel) NewExchange(name string, typename string, args ...bool) {
- var durable, autoDelete, internal, noWait = true, false, false, false
- if len(args) > 0 {
- durable = args[0]
- }
- if len(args) > 1 {
- autoDelete = args[1]
- }
- if len(args) > 2 {
- internal = args[2]
- }
- if len(args) > 3 {
- noWait = args[3]
- }
- e := ch.ExchangeDeclare(
- name, // name
- typename, // type
- durable, // durable
- autoDelete, // auto-deleted
- internal, // 是否只在rabbitmq server内部使用
- noWait, // no-wait
- nil, // arguments
- )
- failOnError(e, "初始化交换机失败!")
- }
-
- // ExchangeDelete 删除交换机
- func (ch *Channel) ExchangeDelete(exchange string) {
- e := ch.Channel.ExchangeDelete(exchange, false, true)
- failOnError(e, "绑定队列失败!")
- }
-
- // Bind 绑定消息队列到哪个exchange
- func (ch *Channel) Bind(queueName string, exchange string, key string) {
- e := ch.Channel.QueueBind(
- queueName,
- key,
- exchange,
- false,
- nil,
- )
- failOnError(e, "绑定队列失败!")
- ch.exchange = exchange
- }
-
- // Qos 配置队列参数
- func (ch *Channel) Qos(prefetchCount int) {
- e := ch.Channel.Qos(prefetchCount, 0, false)
- failOnError(e, "无法设置QoS")
- }
-
- //Send 向消息队列发送消息
- //可以往某个消息队列发送消息
- func (ch *Channel) Send(queue string, body interface{}) {
- str, e := json.Marshal(body)
- failOnError(e, "消息序列化失败!")
- e = ch.Channel.Publish(
- "", //交换机
- queue, //路由键
- false, //必填
- false, //立即
- amqp.Publishing{
- ReplyTo: ch.Name,
- Body: []byte(str),
- })
- msg := "向队列:" + ch.Name + "发送消息失败!"
- failOnError(e, msg)
- }
-
- // Publish 向exchange发送消息
- // 可以往某个exchange发送消息
- func (ch *Channel) Publish(exchange string, body interface{}, key string) {
- str, e := json.Marshal(body)
- failOnError(e, "消息序列化失败!")
- e = ch.Channel.Publish(
- exchange,
- key,
- false,
- false,
- amqp.Publishing{
- ContentType: "text/plain",
- DeliveryMode: amqp.Transient,
- Priority: 0,
- Body: []byte(str)},
- )
- failOnError(e, "向路由发送消息失败!")
- }
-
- // PublishV2 向exchange发送消息
- // 可以往某个exchange发送消息
- func (ch *Channel) PublishV2(exchange string, body interface{}, key string) error {
- str, e := json.Marshal(body)
- if e != nil {
- return e
- }
- e = ch.Channel.Publish(
- exchange,
- key,
- false,
- false,
- amqp.Publishing{
- ContentType: "text/plain",
- DeliveryMode: amqp.Transient,
- Priority: 0,
- Body: []byte(str)},
- )
- return e
- }
-
- // 接收某个消息队列的消息
- func (ch *Channel) Consume(name string, autoAck bool) <-chan amqp.Delivery {
- c, e := ch.Channel.Consume(
- name, // 指定从哪个队列中接收消息
- "", // 用来区分多个消费者
- autoAck, // 是否自动应答
- false, // 是否独有
- false, // 如果设置为true,表示不能将同一个connection中发送的消息传递给这个connection中的消费者
- false, // 列是否阻塞
- nil,
- )
- failOnError(e, "接收消息失败!")
- return c
- }
-
// failOnError does nothing when err is nil. Otherwise it logs "msg: err" and
// panics with "msg:err".
//
// The previous log.Fatalf call exited the process via os.Exit, which made the
// panic unreachable, skipped all deferred cleanup and killed the host program
// on any AMQP error. Panicking keeps the abort-on-error behaviour while
// letting callers recover.
func failOnError(err error, msg string) {
	if err == nil {
		return
	}
	log.Printf("%s: %s", msg, err)
	panic(fmt.Sprintf("%s:%s", msg, err))
}
|