|
- package rabbit
-
- import (
- "encoding/json"
- "errors"
- "fmt"
- "github.com/streadway/amqp"
- "log"
- "os"
- "sync"
- "sync/atomic"
- "time"
- )
-
- const (
- defaultLogPrefix = "[rabbit-pool]"
- )
-
- var (
- ErrInvalidConfig = errors.New("invalid pool config\n")
- ErrFailedConnection = errors.New("failed to establish connection\n")
- ErrConnectionMaximum = errors.New("the number of connections exceeds the maximum\n")
- ErrChannelMaximum = errors.New("the number of channels exceeds the maximum\n")
- ErrGetChannelTimeOut = errors.New("get channel timeout\n")
- )
-
- type LoggerInter interface {
- Print(v ...interface{})
- }
-
- type Config struct {
- Host string // MQ的地址
- MinConn int // 最少建立的连接数
- MaxConn int // 最大建立的连接数
- MaxChannelPerConn int // 每个连接最多建立的信道数量
- MaxLifetime time.Duration
- }
-
- // 连接池
- type Pool struct {
- mu *sync.Mutex
- conf *Config
- logger LoggerInter
- connectionNum int32
- connections map[int64]*Connection
- connectionSerialNumber int64
- idleChannels chan *Channel
- }
-
- func NewPool(conf *Config, logger ...LoggerInter) (*Pool, error) {
- if conf.MaxConn <= 0 || conf.MinConn > conf.MaxConn {
- return nil, ErrInvalidConfig
- }
- p := &Pool{
- mu: new(sync.Mutex),
- connections: make(map[int64]*Connection),
- idleChannels: make(chan *Channel, conf.MaxConn*conf.MaxChannelPerConn),
- }
-
- if conf.MaxLifetime == 0 {
- conf.MaxLifetime = time.Duration(3600)
- }
-
- if len(logger) > 0 {
- p.SetLogger(logger[0])
- } else {
- p.SetLogger(log.New(os.Stdout, defaultLogPrefix, log.LstdFlags))
- }
- p.conf = conf
-
- var conn *Connection
- var err error
- // 建立最少连接数
- for i := 0; i < conf.MinConn; i++ {
- conn, err = p.NewConnection()
- if err != nil {
- p.GetLogger().Print(ErrFailedConnection.Error())
- return nil, ErrFailedConnection
- }
- p.connections[conn.connIdentity] = conn
- }
- return p, nil
- }
-
- func (p *Pool) SetConfig(conf *Config) *Pool {
- p.conf = conf
- return p
- }
-
- func (p *Pool) GetConfig() *Config {
- return p.conf
- }
-
- func (p *Pool) SetLogger(logger LoggerInter) *Pool {
- p.logger = logger
- return p
- }
-
- func (p *Pool) GetLogger() LoggerInter {
- return p.logger
- }
-
- func (p *Pool) NewConnection() (*Connection, error) {
- // 判断连接是否达到最大值
- if atomic.AddInt32(&p.connectionNum, 1) > int32(p.conf.MaxConn) {
- atomic.AddInt32(&p.connectionNum, -1)
- return nil, ErrConnectionMaximum
- }
- conn, err := amqp.Dial(p.conf.Host)
- if err != nil {
- atomic.AddInt32(&p.connectionNum, -1)
- return nil, err
- }
-
- return &Connection{
- mu: new(sync.Mutex),
- conn: conn,
- pool: p,
- channelNum: 0,
- expireTime: time.Duration(time.Now().Unix()) + p.conf.MaxLifetime,
- connIdentity: atomic.AddInt64(&p.connectionSerialNumber, 1),
- }, nil
- }
-
- func (p *Pool) CloseConnection(c *Connection) error {
- p.mu.Lock()
- defer p.mu.Unlock()
- atomic.AddInt32(&p.connectionNum, -1)
- delete(p.connections, c.connIdentity)
- return c.conn.Close()
- }
-
- func (p *Pool) GetChannel() (*Channel, error) {
- ch, _ := p.getOrCreate()
- if ch != nil {
- return ch, nil
- }
-
- C := time.After(time.Second * 10)
- for {
- ch, _ := p.getOrCreate()
- if ch != nil {
- return ch, nil
- }
- select {
- case <-C:
- p.GetLogger().Print(ErrGetChannelTimeOut.Error())
- return nil, ErrGetChannelTimeOut
- default:
- }
- }
- }
-
- func (p *Pool) getOrCreate() (*Channel, error) {
- // 池中是否有空闲channel
- var (
- ch *Channel
- err error
- )
- select {
- case ch = <-p.idleChannels:
- return ch, nil
- default:
- }
-
- p.mu.Lock()
- defer p.mu.Unlock()
- // 池中已有连接是否可以建立新的channel
- for _, conn := range p.connections {
- if conn.CheckExpire() {
- continue
- }
- ch, err = conn.NewChannel()
- if ch != nil {
- return ch, nil
- }
- }
- // 新建连接获取新的channel
- var conn *Connection
- conn, err = p.NewConnection()
- if err != nil {
- return nil, err
- }
- p.connections[conn.connIdentity] = conn
- ch, err = conn.NewChannel()
- if err != nil {
- return nil, err
- }
- return ch, nil
- }
-
- func (p *Pool) ReleaseChannel(ch *Channel) error {
- p.idleChannels <- ch
- return nil
- }
-
- type Connection struct {
- mu *sync.Mutex
- conn *amqp.Connection
- pool *Pool
- expireTime time.Duration
- isExpire bool
- connIdentity int64 // 连接标记
- channelNum int32 // 该连接的信道数量
- channelSerialNumber int64 // 第几个channel
- }
-
- func (c *Connection) NewChannel() (*Channel, error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if atomic.AddInt32(&c.channelNum, 1) > int32(c.pool.conf.MaxChannelPerConn) {
- atomic.AddInt32(&c.channelNum, -1)
- return nil, ErrChannelMaximum
- }
- ch, err := c.conn.Channel()
- if err != nil {
- atomic.AddInt32(&c.channelNum, -1)
- return nil, err
- }
- return &Channel{
- Channel: ch,
- conn: c,
- chanIdentity: atomic.AddInt64(&c.channelSerialNumber, 1),
- }, nil
- }
-
- func (c *Connection) ReleaseChannel(ch *Channel) error {
- if c.CheckExpire() {
- return c.CloseChannel(ch)
- }
- return c.pool.ReleaseChannel(ch)
- }
-
- func (c *Connection) CloseChannel(ch *Channel) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- atomic.AddInt32(&c.channelNum, -1)
- var err = ch.Channel.Close()
- if atomic.LoadInt32(&c.channelNum) <= 0 && c.CheckExpire() {
- return c.pool.CloseConnection(c)
- }
- return err
- }
-
- // 检查是否过期
- func (c *Connection) CheckExpire() bool {
- if c.isExpire {
- return true
- }
- if time.Duration(time.Now().Unix()) > c.expireTime {
- c.isExpire = true
- }
- return c.isExpire
- }
-
- /************************************************************************************************************/
-
- type Channel struct {
- *amqp.Channel
- conn *Connection
- chanIdentity int64 // 该连接的第几个channel
- Name string
- exchange string
- }
-
- func (ch *Channel) Release() error {
- return ch.conn.ReleaseChannel(ch)
- }
-
- func (ch *Channel) Close() error {
- return ch.conn.CloseChannel(ch)
- }
-
- // QueueDeclare 声明交换机
- func (ch *Channel) QueueDeclare(queue string) {
- _, e := ch.Channel.QueueDeclare(queue, false, true, false, false, nil)
- failOnError(e, "声明交换机!")
- }
-
- // QueueDelete 删除交换机
- func (ch *Channel) QueueDelete(queue string) {
- _, e := ch.Channel.QueueDelete(queue, false, true, false)
- failOnError(e, "删除队列失败!")
- }
-
- //初始化队列
- func (ch *Channel) NewQueue(name string, args ...bool) {
- var durable, autoDelete, exclusive, noWait = true, false, false, false
- if len(args) > 0 {
- durable = args[0]
- }
- if len(args) > 1 {
- autoDelete = args[1]
- }
- if len(args) > 2 {
- exclusive = args[2]
- }
- if len(args) > 3 {
- noWait = args[3]
- }
- q, e := ch.Channel.QueueDeclare(
- name, //队列名
- durable, //是否开启持久化
- autoDelete, //不使用时删除
- exclusive, //排他
- noWait, //不等待
- nil, //参数
- )
- failOnError(e, "初始化队列失败!")
- ch.Name = q.Name
- }
-
- // NewExchange 初始化交换机
- //s:rabbitmq服务器的链接,name:交换机名字,typename:交换机类型
- func (ch *Channel) NewExchange(name string, typename string, args ...bool) {
- var durable, autoDelete, internal, noWait = true, false, false, false
- if len(args) > 0 {
- durable = args[0]
- }
- if len(args) > 1 {
- autoDelete = args[1]
- }
- if len(args) > 2 {
- internal = args[2]
- }
- if len(args) > 3 {
- noWait = args[3]
- }
- e := ch.ExchangeDeclare(
- name, // name
- typename, // type
- durable, // durable
- autoDelete, // auto-deleted
- internal, // 是否只在rabbitmq server内部使用
- noWait, // no-wait
- nil, // arguments
- )
- failOnError(e, "初始化交换机失败!")
- }
-
- // ExchangeDelete 删除交换机
- func (ch *Channel) ExchangeDelete(exchange string) {
- e := ch.Channel.ExchangeDelete(exchange, false, true)
- failOnError(e, "绑定队列失败!")
- }
-
- // Bind 绑定消息队列到哪个exchange
- func (ch *Channel) Bind(queueName string, exchange string, key string) {
- e := ch.Channel.QueueBind(
- queueName,
- key,
- exchange,
- false,
- nil,
- )
- failOnError(e, "绑定队列失败!")
- ch.exchange = exchange
- }
-
- // Qos 配置队列参数
- func (ch *Channel) Qos(prefetchCount int) {
- e := ch.Channel.Qos(prefetchCount, 0, false)
- failOnError(e, "无法设置QoS")
- }
-
- //Send 向消息队列发送消息
- //可以往某个消息队列发送消息
- func (ch *Channel) Send(queue string, body interface{}) {
- str, e := json.Marshal(body)
- failOnError(e, "消息序列化失败!")
- e = ch.Channel.Publish(
- "", //交换机
- queue, //路由键
- false, //必填
- false, //立即
- amqp.Publishing{
- ReplyTo: ch.Name,
- Body: []byte(str),
- })
- msg := "向队列:" + ch.Name + "发送消息失败!"
- failOnError(e, msg)
- }
-
- // Publish 向exchange发送消息
- // 可以往某个exchange发送消息
- func (ch *Channel) Publish(exchange string, body interface{}, key string) {
- str, e := json.Marshal(body)
- failOnError(e, "消息序列化失败!")
- e = ch.Channel.Publish(
- exchange,
- key,
- false,
- false,
- amqp.Publishing{
- ContentType: "text/plain",
- DeliveryMode: amqp.Transient,
- Priority: 0,
- Body: []byte(str)},
- )
- failOnError(e, "向路由发送消息失败!")
- }
-
- // 接收某个消息队列的消息
- func (ch *Channel) Consume(name string) <-chan amqp.Delivery {
- c, e := ch.Channel.Consume(
- name, // 指定从哪个队列中接收消息
- "", // 用来区分多个消费者
- false, // 是否自动应答
- false, // 是否独有
- false, // 如果设置为true,表示不能将同一个connection中发送的消息传递给这个connection中的消费者
- false, // 列是否阻塞
- nil,
- )
- failOnError(e, "接收消息失败!")
- return c
- }
-
- //错误处理函数
- func failOnError(err error, msg string) {
- if err != nil {
- log.Fatalf("%s: %s", msg, err)
- panic(fmt.Sprintf("%s:%s", msg, err))
- }
- }
|