119 řádky
2.7 KiB

  1. package session
  2. import (
  3. "database/sql"
  4. )
  5. const beginStatus = 1
  6. // SessionFactory 会话工厂
  7. type SessionFactory struct {
  8. *sql.DB
  9. }
  10. // Session 会话
  11. type Session struct {
  12. DB *sql.DB // 原生db
  13. tx *sql.Tx // 原生事务
  14. commitSign int8 // 提交标记,控制是否提交事务
  15. rollbackSign bool // 回滚标记,控制是否回滚事务
  16. }
  17. // NewSessionFactory 创建一个会话工厂
  18. func NewSessionFactory(driverName, dataSourseName string) (*SessionFactory, error) {
  19. db, err := sql.Open(driverName, dataSourseName)
  20. if err != nil {
  21. panic(err)
  22. }
  23. factory := new(SessionFactory)
  24. factory.DB = db
  25. return factory, nil
  26. }
  27. // GetSession 获取一个Session
  28. func (sf *SessionFactory) GetSession() *Session {
  29. session := new(Session)
  30. session.DB = sf.DB
  31. return session
  32. }
  33. // Begin 开启事务
  34. func (s *Session) Begin() error {
  35. s.rollbackSign = true
  36. if s.tx == nil {
  37. tx, err := s.DB.Begin()
  38. if err != nil {
  39. return err
  40. }
  41. s.tx = tx
  42. s.commitSign = beginStatus
  43. return nil
  44. }
  45. s.commitSign++
  46. return nil
  47. }
  48. // Rollback 回滚事务
  49. func (s *Session) Rollback() error {
  50. if s.tx != nil && s.rollbackSign {
  51. err := s.tx.Rollback()
  52. if err != nil {
  53. return err
  54. }
  55. s.tx = nil
  56. return nil
  57. }
  58. return nil
  59. }
  60. // Commit 提交事务
  61. func (s *Session) Commit() error {
  62. s.rollbackSign = false
  63. if s.tx != nil {
  64. if s.commitSign == beginStatus {
  65. err := s.tx.Commit()
  66. if err != nil {
  67. return err
  68. }
  69. s.tx = nil
  70. return nil
  71. } else {
  72. s.commitSign--
  73. }
  74. return nil
  75. }
  76. return nil
  77. }
  78. // Exec 执行sql语句,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
  79. func (s *Session) Exec(query string, args ...interface{}) (sql.Result, error) {
  80. if s.tx != nil {
  81. return s.tx.Exec(query, args...)
  82. }
  83. return s.DB.Exec(query, args...)
  84. }
  85. // QueryRow 如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
  86. func (s *Session) QueryRow(query string, args ...interface{}) *sql.Row {
  87. if s.tx != nil {
  88. return s.tx.QueryRow(query, args...)
  89. }
  90. return s.DB.QueryRow(query, args...)
  91. }
  92. // Query 查询数据,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
  93. func (s *Session) Query(query string, args ...interface{}) (*sql.Rows, error) {
  94. if s.tx != nil {
  95. return s.tx.Query(query, args...)
  96. }
  97. return s.DB.Query(query, args...)
  98. }
  99. // Prepare 预执行,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
  100. func (s *Session) Prepare(query string) (*sql.Stmt, error) {
  101. if s.tx != nil {
  102. return s.tx.Prepare(query)
  103. }
  104. return s.DB.Prepare(query)
  105. }