golang封装mysql操作

Golang piniu 719浏览 0评论

database/sql 是 Go 操作数据库的标准库之一,它提供了一系列接口方法,用于访问数据库(mysql,sqllite,oralce,postgresql)。

一、安装go操作mysql的驱动

go get "github.com/go-sql-driver/mysql"

二、go封装的Mysql操作

package main

import (
	"database/sql"
	"encoding/json"
	"fmt"
	_ "github.com/go-sql-driver/mysql"
	"strconv"
	// "time"
	"strings"
)

var (
	key        string
	value      string
	conditions string
	str        string
)

// 数据库的相关配置
var (
	db_userName  string = "root"
	db_password  string = "root"
	db_ipAddrees string = "127.0.0.1"
	db_port      int    = 3306
	db_name      string = "test"
	db_charset   string = "utf8"
)

type Model struct {
	link       *sql.DB  //存储连接对象
	tableName  string   //存储表名
	field      string   //存储字段
	allFields  []string //存储当前表所有字段
	where      string   //存储where条件
	order      string   //存储order条件
	limit      string   //存储limit条件
	page       int      //当前页码
	limitCount int      //每页数据条数
}

// 构造方法
func NewModel(table string) Model {
	var this Model
	this.field = "*"
	this.limitCount = 10
	this.page = 1
	this.limit = "limit 0, 10"
	//1.存储操作的表名
	this.tableName = table
	//2.初始化连接数据库
	this.getConnect()
	//3.获得当前表的所有字段
	// this.getFields()  该方法仅在Add 和 Update 方法中使用,可以单独写在对应的方法中
	return this
}

/**
 * 初始化连接数据库操作
 */
func (this *Model) getConnect() {

	//1.连接数据库
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s", db_userName, db_password, db_ipAddrees, db_port, db_name, db_charset)
	db, err := sql.Open("mysql", dsn)

	//2.判断连接
	if err != nil {
		fmt.Printf("connect mysql fail ! [%s]", err)
	}

	this.link = db
}

/**
 * 获取当前表的所有字段
 */
func (this *Model) getFields() {

	//查看表结构
	sql := "DESC " + this.tableName
	//执行并发送SQL
	result, err := this.link.Query(sql)

	if err != nil {
		fmt.Printf("sql fail ! [%s]", err)
	}

	this.allFields = make([]string, 0)

	for result.Next() {
		var field string
		var Type interface{}
		var Null string
		var Key string
		var Default interface{}
		var Extra string
		err := result.Scan(&field, &Type, &Null, &Key, &Default, &Extra)
		if err != nil {
			fmt.Printf("scan fail ! [%s]", err)
		}

		this.allFields = append(this.allFields, field)
	}

}

/**
 * 执行并发送SQL(查询)
 * @param string $sql  要查询的SQL语句
 * @return array 返回查询出来的二维数组
 */
func (this *Model) query(sql string) interface{} {

	var result []map[string]interface{} //定义一个切片

	rows2, err := this.link.Query(sql)

	if err != nil {
		// return returnRes(0, ``, err)
		return result
	}

	//返回所有列
	cols, err := rows2.Columns()

	if err != nil {
		// return returnRes(0, ``, err)
		return result
	}

	//这里表示一行所有列的值,用[]byte表示
	vals := make([][]byte, len(cols))

	//这里表示一行填充数据
	scans := make([]interface{}, len(cols))

	//这里scans引用vals,把数据填充到[]byte里
	for k, _ := range vals {
		scans[k] = &vals[k]
	}

	// i := 0
	//result := make(map[int]map[string]string)

	for rows2.Next() {
		//填充数据
		rows2.Scan(scans...) //将列值的地址传入

		//每行数据
		row := make(map[string]interface{})
		//把vals中的数据复制到row中
		for k, v := range vals {
			key := cols[k]
			//这里把[]byte数据转成string
			row[key] = string(v)
		}
		//放入结果集
		//result[i] = row
		// i++
		result = append(result, row)

	}

	//关闭结果集(释放连接)
	rows2.Close()

	//return returnRes(1, result, "success")
	return result

}

/**
 * 设置要查询的字段信息
 * @param string $field  要查询的字段
 * @return object 返回自己,保证连贯操作
 */
func (this *Model) Field(field string) *Model {
	this.field = field
	return this
}

/**
 * Order排序条件
 * @param string  $order  以此为基准进行排序
 * @return $this  返回自己,保证连贯操作
 */
func (this *Model) Order(order string) *Model {
	this.order = `order by ` + order
	return this
}

/**
 * Limit查询数据条数
 * @param string  $limit  以此为基准进行排序
 * @return 返回 limitFilter()方法的返回值
 */
func (this *Model) Limit(limit int) *Model {

	this.limitCount = limit
	return this.limitFilter()

}

/**
 * Page 查询指定页码的数据
 * @param string  $page  以此为基准进行排序
 * @return 返回 limitFilter()方法的返回值
 */
func (this *Model) Page(page int) *Model {

	this.page = page
	return this.limitFilter()

}

/**
 * limit条件
 * @param string $limit 输入的limit条件
 * @return $this 返回自己,保证连贯操作
 */
func (this *Model) limitFilter() *Model {

	offset := (this.page - 1) * this.limitCount
	this.limit = "limit " + strconv.Itoa(offset) + "," + strconv.Itoa(this.limitCount)
	return this
}

/**
 * where条件
 * @param string $where 输入的where条件
 * @return $this 返回自己,保证连贯操作
 */
func (this *Model) Where(where string) *Model {
	this.where = `where ` + where
	return this
}

/**
 * 执行并发送SQL语句(增删改)
 * @param string $sql 要执行的SQL语句
 * @return bool|int|string 添加成功则返回上一次操作id,删除修改操作则返回true,失败则返回false
 */
func (this *Model) exec(sql string) interface{} {

	res, err := this.link.Exec(sql)

	ret := make(map[string]interface{})
	ret["InsertId"] = 0
	ret["Affected"] = 0

	if err != nil {
		//return returnRes(0, ``, err)
		return ret
	}

	result, err := res.LastInsertId()  //insert 
	if err != nil {
		return ret
	}

	result2, err := res.RowsAffected() 
	if err != nil {
		return ret
	}

	ret["InsertId"] = result
	ret["Affected"] = result2

	//return returnRes(1, result, "success")
	return ret

}

// 是否存在数组内
func in_array(need interface{}, needArr []string) bool {
	for _, v := range needArr {
		if need == v {
			return true
		}
	}
	return false
}

// 返回json
func returnRes(errCode int, res interface{}, msg interface{}) string {
	result := make(map[string]interface{})
	result["errCode"] = errCode
	result["result"] = res
	result["msg"] = msg
	data, _ := json.Marshal(result)
	return string(data)
}

/**
 * 添加操作
 * @param array  $data 要添加的数组
 * @return bool|int|string 添加成功则返回上一次操作的id,失败则返回false
 */
func (this *Model) Add(data map[string]interface{}) interface{} {

	this.getFields()  //获取下表的所有字段

	//过滤非法字段
	for k, v := range data {
		if res := in_array(k, this.allFields); res != true {
			delete(data, k)
		} else {
			key += `,` + k
			value += `,` + `'` + v.(string) + `'`
		}
	}

	//将map中取出的键转为字符串拼接
	key = strings.TrimLeft(key, ",")
	//将map中的值转化为字符串拼接
	value = strings.TrimLeft(value, ",")
	//准备SQL语句
	sql := `insert into ` + this.tableName + ` (` + key + `) values (` + value + `)`

	// //执行并发送SQL
	result := this.exec(sql)

	return result

}

/**
 * 修改操作
 * @param  array $data  要修改的数组
 * @return bool 修改成功返回true,失败返回false
 */
func (this *Model) Update(data map[string]interface{}) interface{} {

	this.getFields()  //获取下表的所有字段

	//过滤非法字段
	for k, v := range data {
		if res := in_array(k, this.allFields); res != true {
			delete(data, k)
		} else {
			str += k + ` = '` + v.(string) + `',`
		}
	}

	//去掉最右侧的逗号
	str = strings.TrimRight(str, ",")

	//判断是否有条件
	if this.where == "" { //避免全部更新
		fmt.Println("没有条件")
		return 0
	}

	sql := `update ` + this.tableName + ` set ` + str + ` ` + this.where

	result := this.exec(sql)
	return result
}

/**
 * 删除操作
 * @param string $id 要删除的id
 * @return bool  删除成功则返回true,失败则返回false
 */
func (this *Model) Delete() interface{} {

	//判断是否有条件
	if this.where == "" { //避免全部删除
		fmt.Println("没有条件")
		return 0
	}

	sql := `delete from ` + this.tableName + ` ` + this.where

	//执行并发送
	result := this.exec(sql)

	return result
}

/**
 * 查询多条数据
 */
func (this *Model) Get() []map[string]interface{} {
	sql := `select ` + this.field + ` from ` + this.tableName + ` ` + this.where + ` ` + this.order + ` ` + this.limit
	//执行并发送SQL
	result := this.query(sql).([]map[string]interface{})  //查询结果,转类型,赋值

	return result
}

/**
 * 查询一条数据
 * @param string $id 要查询的id
 * @return array  返回一条数据
 */
func (this *Model) Find() map[string]interface{} {
	//判断id是否存在
	sql := `select ` + this.field + ` from ` + this.tableName + ` ` + this.where + ` limit 1`
	//执行并发送sql
	result := this.query(sql).([]map[string]interface{})

	if result != nil {
		return result[0]
	}

	return make(map[string]interface{})

}

/**
 * 统计总条数
 * @return int 返回总数
 */
func (this *Model) Count() int {
	//准备SQL语句
	sql := `select count(*) as total from ` + this.tableName + ` ` + this.where + ` limit 1`
	result := this.query(sql).([]map[string]interface{})

	// return returnRes(1, result, "success")
	if result != nil {
		count , err :=  strconv.Atoi(result[0]["total"].(string))
		if err != nil {
			return 0
		}
		return count
	}

	return 0
}


func main() {

	//M := NewModel("fl_content")

	//多行查询链式操作
	//res := M.Field("id, title, created_at").Where("status=3").Order("id desc").Page(2).Limit(5).Get()
	//fmt.Println(res)

	//获取单行
	//res2 := M.Field("id, title, created_at").Where("status=3").Find()
	//fmt.Println(res2)

	//查询总条数
	//res3 := M.Where("status=3").Count()
	//fmt.Println("count = ", res3)

	M2 := library2.NewModel("fl_exam")

	insertData := make(map[string]interface{})
	insertData["name"] = "来自go insert 的标题"
	insertData["content"] = "这是内容"
	insertData["answer"] = "这是答案444"
	insertData["subject_id"] = "1"  //需要字符串
	insertData["year"] = time.Unix(time.Now().Unix(),0).Format("2006")
	insertData["created_at"] = strconv.FormatInt(time.Now().Unix(), 10)   //需要将int64 转为 string

	//添加数据
	//fmt.Println(insertData)
	// ret := M2.Add(insertData)
	//fmt.Println("insert ret = ", ret)

	//更新数据
	// ret2 := M2.Where("id = 18").Update(insertData)
	// fmt.Println("update ret = ", ret2)

	//删除数据
	// ret := M2.Where("id = 3").Delete()
	// fmt.Println("delete ret = ", ret)

}

三、Query方法说明:

query() 方法返回的是一个 sql.Rows 类型的结果集,也可以用来查询多个字段的数据,不过需要定义多个字段的变量进行接收,迭代后者的 Next() 方法,然后使用 Scan() 方法给对应类型变量赋值,以便取出结果,最后再把结果集关闭(释放连接)

四、查询说明

sql.DB支持4种查询:

  • db.Query()
  • db.QueryRow()
  • db.Prepare(sql) stmt.Query(args)
  • db.Exec()

说明

  • db.Query() 返回多行数据,需要依次遍历,并且需要自己关闭查询结果集
  • db.QueryRow() 是专门查询一行数据的一个语法糖,返回ErrNoRow或者一行数据,不需要自己关闭结果集
  • db.Prepare() 是预先将一个数据库连接(con)和一个条sql语句绑定并返回stmt结构体代表这个绑定后的连接,然后运行stmt.Query()或者stmt.QueryRow();stmt是并发安全的。之所以这样设计,是因为每次直接调用db.Prepare都会自动选择一个可用的con,每次选择的可能不是同一个con
  • db.Exec() 适用于执行insert、update、delete等不需要返回结果集的操作

发表我的评论
取消评论
表情

Hi,您需要填写昵称和邮箱!

  • * 昵称:
  • * 邮箱: