用Go导入大型CSV到PostgreSQL

April 12, 2022

用Go导入大型CSV到PostgreSQL 最近我想试试 PostgreSQL,素闻美名,一直没有尝试过。从网上下载了一个超大的CSV,解压后达18G,一般的文件编辑器 直接打不开,简单的方案是直接用 PostgreSQL 提供的 \copy 命令,或者 COPY 语句,但是这个文件无法使用,因为 其中有几行都是坏数据。

如果是MySQL的话,可以使用 LOAD FILE IGNORE... 来忽略错误,但是PostgreSQL没有这个选项,所以我只能选择用Go自己 来导入。

吐个槽,MySQL用IGNORE之后,连数据错误也会忽略,导致我导入数据之后,才发现 int 不够表示CSV里的数据字段,导入的 很多数据直接变成了 2 ** 31 -1 也就是 2147483647 了,白等了一个小时。

对于大型文件,如果没有足够的内存,也确实是很难处理,我们采取的基本策略就是分块处理,为了提高吞吐量,我们要做 批量提交和并发处理,为了处理异常数据,我们要能dump出有问题的那一块数据,以便处理之后我们再次导入, 由于每一个块是相对较小的,dump出来之后,我们是可以直接用文本编辑器处理问题行的,此外由于涉及到 string和bytes 的转换,我们需要避免频繁申请内存,可以使用上 黑科技:

// b2s converts byte slice to a string without memory allocation. // See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ . // // Note it may break if string and/or slice header will change // in the future go versions. func b2s(b []byte) string { / #nosec G103 / return (string)(unsafe.Pointer(&b)) }

// s2b converts string to a byte slice without memory allocation. // // Note it may break if string and/or slice header will change // in the future go versions. func s2b(s string) (b []byte) { / #nosec G103 / bh := (reflect.SliceHeader)(unsafe.Pointer(&b)) / #nosec G103 / sh := (reflect.StringHeader)(unsafe.Pointer(&s)) bh.Data = sh.Data bh.Cap = sh.Len bh.Len = sh.Len return b } 以及 sync.Pool,最开始我使用的是 bufio.Scanner,但是没想到超过了它的限制,报了 bufio.Scanner: token too long, 翻了一下代码如下:

var ( ErrTooLong = errors.New("bufio.Scanner: token too long") )

//...

if len(s.buf) >= s.maxTokenSize || len(s.buf) > maxInt/2 { s.setErr(ErrTooLong) return false }

//...

const ( // MaxScanTokenSize is the maximum size used to buffer a token // unless the user provides an explicit buffer with Scanner.Buffer. // The actual maximum token size may be smaller as the buffer // may need to include, for instance, a newline. MaxScanTokenSize = 64 * 1024 ) 可以自己把buffer调大,不过我选择自己逐行读取并且处理,毕竟咱也不知道这么大的数据里,最长的那行到底有多长。

最后代码如下:

package main

import ( "bufio" "io" "io/ioutil" "log" "os" "strings" "sync"

"github.com/jmoiron/sqlx"
"github.com/lib/pq"

)

const size uint64 = 10000

var ( tokens = make(chan bool, 50) stringSlicePool = sync.Pool{ New: func() interface{} { cache := make([]string, size) return cache[:0] }, } )

// 用wrapper避免参数是 []string 时,是值拷贝的问题 type wrapper struct { lines []string }

func dumpData(w *wrapper) { file, err := ioutil.TempFile("./dumps/", "damage") if err != nil { log.Printf("failed to open temp file") return } defer file.Close()

for _, line := range w.lines {
    file.WriteString(line)
}

}

func newWrapper() *wrapper { lines := stringSlicePool.Get().([]string) return &wrapper{lines: lines} }

func deleteWrapper(w *wrapper) { w.lines = w.lines[:0] stringSlicePool.Put(w.lines) }

func writeData(wg sync.WaitGroup, db sqlx.DB, w *wrapper) { wg.Add(1)

token := <-tokens // 并发控制

tx := db.MustBegin()
stmt, err := tx.Prepare(pq.CopyIn("表名", "字段1", "字段2" /*字段3...*/))
if err != nil {
    log.Printf("failed to prepare: %s", err)
    goto done
}

if len(w.lines) == 0 {
    goto done
}

for _, line := range w.lines {
    data := strings.Split(line, "\t") // 此处是一个频繁内存申请的点
    if len(data) < 2 {
        log.Printf("ignore %s", line)
        continue
    }

    stmt.Exec(data[len(data)-2], data[len(data)-1][:len(data[1])-1])
}
stmt.Close()
if err = tx.Commit(); err != nil {
    log.Printf("failed to commit: %s", err)
    dumpData(w)
    goto done
}

log.Printf("saving %d lines", len(w.lines))

done: tokens <- token deleteWrapper(w) wg.Done() }

func main() { var wg sync.WaitGroup

db, err := sqlx.Connect("postgres", "user=postgres dbname=数据库名 sslmode=disable password=密码")
if err != nil {
    log.Fatalln(err)
}
log.Printf("%v, %s", db, err)

file, err := os.Open("./to_import.csv")
if err != nil {
    log.Fatal(err)
}
defer file.Close()

for i := 0; i < cap(tokens); i++ {
    tokens <- true
}

reader := bufio.NewReader(file)
cache := newWrapper()

// optionally, resize scanner's capacity for lines over 64K, see next example
var i uint64 = 0
reader.ReadString('\n')
for {
    line, err := reader.ReadString('\n')
    if err == io.EOF {
        break
    }
    cache.lines = append(cache.lines, line)
    i += 1

    if i%size == 0 {
        oldCache := cache
        go writeData(&wg, db, oldCache)
        cache = newWrapper()
    }
}

log.Printf("wait wg...")
wg.Wait()
log.Printf("done...")

}