golang封装mysql操作

Golang piniu 500浏览 0评论

database/sql 是 Go 操作数据库的标准库之一,它提供了一系列接口方法,用于访问数据库(mysql,sqllite,oralce,postgresql)。

一、安装go操作mysql的驱动

go get "github.com/go-sql-driver/mysql"

二、go封装的Mysql操作

package main
 
import(
	"fmt"
	"database/sql"
    _"github.com/go-sql-driver/mysql"
    "strconv"
    // "time"
    "strings"
    "encoding/json"
)
 
var(
    key string
    value string
    conditions string
    str string  
)

//数据库的相关配置
var (
    db_userName  string = "test"
    db_password  string = "test123"
    db_ipAddrees string = "127.0.0.1"
    db_port      int    = 3306
    db_name    string = "mydb"
    db_charset   string = "utf8"
)
 
type Model struct{
    link  *sql.DB  //存储连接对象
    tableName string //存储表名
    field string //存储字段
    allFields []string //存储当前表所有字段
    where string //存储where条件
    order string //存储order条件
    limit string //存储limit条件
    page int //当前页码
    limitCount int //每页数据条数
}
 
//构造方法
func NewModel(table string) Model{
	var this Model
	this.field = "*"
    this.limitCount = 10
    this.page = 1
    this.limit = "limit 0, 10"
	//1.存储操作的表名
	this.tableName = table
    //2.初始化连接数据库
    this.getConnect();
    //3.获得当前表的所有字段
    // this.getFields();  该方法仅在Add 和 Update 方法中使用,可以单独写在对应的方法中
	return this
}
 
/**
 * 初始化连接数据库操作
 */
func (this *Model) getConnect(){

    //1.连接数据库
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s", db_userName, db_password, db_ipAddrees, db_port, db_name, db_charset)
    db,err := sql.Open("mysql", dsn);

    //2.判断连接
    if err != nil{
    	fmt.Printf("connect mysql fail ! [%s]",err)
    }

    this.link = db
}
 
/**
 * 获取当前表的所有字段
 */
func (this *Model) getFields(){
 
    //查看表结构
    sql := "DESC "+this.tableName
    //执行并发送SQL
    result, err := this.link.Query(sql)
 
    if err != nil{
    	fmt.Printf("sql fail ! [%s]",err)
    }
 
    this.allFields = make([]string,0)
 
	for result.Next() {
		var field string
	  	var Type interface{}
	  	var Null string
	  	var Key string
	  	var Default interface{}
	  	var Extra string
		err :=result.Scan(&field,&Type,&Null,&Key,&Default,&Extra)
     	if err != nil{
      		fmt.Printf("scan fail ! [%s]",err)
     	}

     	this.allFields = append(this.allFields,field)
	}
 
}
 
/**
 * 执行并发送SQL(查询)
 * @param string $sql  要查询的SQL语句
 * @return array 返回查询出来的二维数组
 */
func (this *Model) query(sql string) interface{}{
 
    rows2, err := this.link.Query(sql)
    
	if err != nil {
		return returnRes(0,``,err)
	}
 
	//返回所有列
	cols, err := rows2.Columns()
 
	if err != nil {
        return returnRes(0,``,err)
	}
 
	//这里表示一行所有列的值,用[]byte表示 
	vals := make([][]byte, len(cols))
 
	//这里表示一行填充数据
	scans := make([]interface{}, len(cols))
 
	//这里scans引用vals,把数据填充到[]byte里
	for k, _ := range vals {
		scans[k] = &vals[k]
	}
 
	// i := 0
	//result := make(map[int]map[string]string)

    var result []map[string]string    //定义一个切片
 
	for rows2.Next() {
		//填充数据
		rows2.Scan(scans...) //将列值的地址传入

		//每行数据
		row := make(map[string]string)
		//把vals中的数据复制到row中
		for k, v := range vals {
			key := cols[k]
			//这里把[]byte数据转成string
			row[key] = string(v)
		}
		//放入结果集
		//result[i] = row
		// i++
        result = append(result, row)

	}

    //关闭结果集(释放连接)
    rows2.Close()

    return returnRes(1,result,"success")
 
}

 
/**
 * 设置要查询的字段信息
 * @param string $field  要查询的字段
 * @return object 返回自己,保证连贯操作
 */
func (this *Model) Field(field string) *Model{
    this.field = field
    return this
}
 
/**
 * Order排序条件
 * @param string  $order  以此为基准进行排序
 * @return $this  返回自己,保证连贯操作
 */
func (this *Model) Order(order string) *Model{
    this.order = `order by ` + order
    return this
}
 

/**
 * Limit查询数据条数
 * @param string  $limit  以此为基准进行排序
 * @return 返回 limitFilter()方法的返回值
 */
func (this *Model) Limit(limit int) *Model{

    this.limitCount = limit
    return this.limitFilter()

}

/**
 * Page 查询指定页码的数据
 * @param string  $page  以此为基准进行排序
 * @return 返回 limitFilter()方法的返回值
 */
func (this *Model) Page(page int) *Model{

    this.page = page
    return this.limitFilter()

}

/**
 * limit条件
 * @param string $limit 输入的limit条件
 * @return $this 返回自己,保证连贯操作
 */
func (this *Model) limitFilter() *Model{

    offset := (this.page - 1) * this.limitCount
    this.limit = "limit "+ strconv.Itoa(offset) +"," + strconv.Itoa(this.limitCount)
    return this
}
 
/**
 * where条件
 * @param string $where 输入的where条件
 * @return $this 返回自己,保证连贯操作
 */
func (this *Model) Where(where string) *Model{
   this.where = `where ` + where
   return this
}

 
/**
 * 执行并发送SQL语句(增删改)
 * @param string $sql 要执行的SQL语句
 * @return bool|int|string 添加成功则返回上一次操作id,删除修改操作则返回true,失败则返回false
 */
func (this *Model) exec(sql string) interface{}{
 
	res, err := this.link.Exec(sql)
 
    if err != nil {
        return returnRes(0,``,err)
    }
 
    result, err := res.LastInsertId()
    if err != nil {
        return returnRes(0,``,err)
    }
    return returnRes(1,result,"success")
}
 
 
//是否存在数组内
func in_array(need interface{}, needArr []string) bool {
    for _,v := range needArr{
        if need == v{
            return true
        }
    }
    return false
}
 
//返回json
func returnRes(errCode int,res interface{},msg interface{}) string{
    result := make(map[string]interface{})
    result["errCode"] = errCode
    result["result"] = res
    result["msg"]    = msg
    data,_:= json.Marshal(result)
    return string(data)
}


/**
 * 添加操作
 * @param array  $data 要添加的数组
 * @return bool|int|string 添加成功则返回上一次操作的id,失败则返回false
 */
 func (this *Model) Add(data map[string]interface{}) interface{}{
 
    //过滤非法字段
    for k,v:=range data{
    	if res:=in_array(k,this.allFields); res !=true {
    		delete(data,k)
    	}else{
    		key += `,` + k
    		value += `,` + `'`+v.(string)+`'`
    	}
    }
 
    //将map中取出的键转为字符串拼接
    key =strings.TrimLeft(key, ",")
    //将map中的值转化为字符串拼接
    value =strings.TrimLeft(value, ",")
    //准备SQL语句
    sql :=`insert into ` + this.tableName +` (`+ key + `) values (` + value+ `)`
    // //执行并发送SQL
    result :=this.exec(sql)
 
    return result
    
}

/**
 * 修改操作
 * @param  array $data  要修改的数组
 * @return bool 修改成功返回true,失败返回false
 */
 func (this *Model) Update(data map[string]interface{}) interface{}{
 
    //过滤非法字段
    for k,v:=range data{
        if res:=in_array(k,this.allFields); res !=true {
            delete(data,k)
        }else{
            str += k +` = '`+v.(string)+`',`
        }
    }
 
    //去掉最右侧的逗号
    str =strings.TrimRight(str, ",")
 
    //判断是否有条件
    if this.where == ""{
        fmt.Println("没有条件")
    }
 
    sql := `update ` + this.tableName + ` set ` + str  +` `+ this.where
 
    result :=this.exec(sql)
    return result
}
 
/**
 * 删除操作
 * @param string $id 要删除的id
 * @return bool  删除成功则返回true,失败则返回false
 */
func (this *Model) Delete(user_id int) interface{}{
 
    //判断id是否存在
    if this.where == "" {
        conditions = `where user_id = ` + strconv.Itoa(user_id)
    }else{
        conditions = this.where + ` and user_id = ` + strconv.Itoa(user_id)
    }
 
    sql := `delete from ` + this.tableName + ` ` + conditions
 
    //执行并发送
    result :=this.exec(sql)
 
    return result
}


/**
 * 查询多条数据
 */
 func (this *Model) Get() interface{}{
    sql := `select `+this.field+` from `+this.tableName+` `+this.where+` `+this.order+` `+this.limit
    //执行并发送SQL
    result :=this.query(sql)
    return result
}
 
/**
 * 查询一条数据
 * @param string $id 要查询的id
 * @return array  返回一条数据
 */
func (this *Model) Find() interface{}{
    //判断id是否存在
    sql := `select `+this.field+` from `+this.tableName + ` ` + this.where +` limit 1`
    //执行并发送sql
    result :=this.query(sql)
    return result
}
 
/**
 * 统计总条数
 * @return int 返回总数
 */
func (this *Model) Count() interface{}{
    //准备SQL语句
    sql := `select count(*) as total from ` + this.tableName + ` ` + this.where + ` limit 1`
    result :=this.query(sql)
    return returnRes(1,result,"success")
}


 

func main() {
 
	M := NewModel("userinfo")
	//查询链式操作
	//res := M.Field("user_name, email, add_time").Order("id desc").Where("user_name= '张三'").Page(1).Limit(10).Get()

	//添加操作
	// data := make(map[string]interface{})
	// data["ddd"] = "118284901@qq.com"
	// data["daaa"] = "118284901@qq.com"
	// data["email"] = "118284901@qq.com"
	// data["user_name"] = "张三"
	// data["add_time"] = time.Unix(time.Now().Unix(),0).Format("2006-01-02 15:04:05")
	// res:=M.add(data)
 
    //删除操作
    // res :=M.delete(18)
 
    //更新操作
    // data := make(map[string]interface{})
    // data["email"] = "118284901@qq.com"
    // data["user_name"] = "打啊"
    // data["add_time"] = time.Unix(time.Now().Unix(),0).Format("2006-01-02 15:04:05")
    // res:=M.Where("user_id = 1").update(data)
    fmt.Println(res)
 
}

三、query方法说明:

query() 方法返回的是一个 sql.Rows 类型的结果集,也可以用来查询多个字段的数据,不过需要定义多个字段的变量进行接收,迭代后者的 Next() 方法,然后使用 Scan() 方法给对应类型变量赋值,以便取出结果,最后再把结果集关闭(释放连接)


发表我的评论
取消评论
表情

Hi,您需要填写昵称和邮箱!

  • * 昵称:
  • * 邮箱: