From 1d8534aa4dff5d140a7d2b6a94ef88d838311494 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Tue, 6 Sep 2022 10:02:38 -0400 Subject: [PATCH] Add Serialize and Deserialize support --- sqlite3.go | 44 ++++++++++++++++++++ sqlite3_test.go | 105 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+) diff --git a/sqlite3.go b/sqlite3.go index 5ac95709..cef691e8 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -905,6 +905,50 @@ func (c *SQLiteConn) begin(ctx context.Context) (driver.Tx, error) { return &SQLiteTx{c}, nil } +// Serialize returns a byte slice that is a serialization of the database. +// If the database fails to serialize, a nil slice will be returned. +// +// See https://www.sqlite.org/c3ref/serialize.html +func (c *SQLiteConn) Serialize(schema string) []byte { + if schema == "" { + schema = "main" + } + var zSchema *C.char + zSchema = C.CString(schema) + defer C.free(unsafe.Pointer(zSchema)) + + var sz C.sqlite3_int64 + ptr := C.sqlite3_serialize(c.db, zSchema, &sz, 0) + if ptr == nil { + return nil + } + defer C.sqlite3_free(unsafe.Pointer(ptr)) + return C.GoBytes(unsafe.Pointer(ptr), C.int(sz)) +} + +// Deserialize causes the connection to disconnect from the current database +// and then re-open as an in-memory database based on the contents of the +// byte slice. If deserelization fails, error will contain the return code +// of the underlying SQLite API call. +// +// See https://www.sqlite.org/c3ref/deserialize.html +func (c *SQLiteConn) Deserialize(b []byte, schema string) error { + if schema == "" { + schema = "main" + } + var zSchema *C.char + zSchema = C.CString(schema) + defer C.free(unsafe.Pointer(zSchema)) + + rc := C.sqlite3_deserialize(c.db, zSchema, + (*C.uint8_t)(unsafe.Pointer(&b[0])), + C.sqlite3_int64(len(b)), C.sqlite3_int64(len(b)), 0) + if rc != 0 { + return fmt.Errorf("deserialize failed with return %v", rc) + } + return nil +} + // Open database and return a new connection. // // A pragma can take either zero or one argument. diff --git a/sqlite3_test.go b/sqlite3_test.go index 878ec495..e5660c58 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -888,6 +888,111 @@ func TestTransaction(t *testing.T) { } } +func TestSerialize(t *testing.T) { + d := SQLiteDriver{} + + srcConn, err := d.Open(":memory:") + if err != nil { + t.Fatal("failed to get database connection:", err) + } + defer srcConn.Close() + sqlite3conn := srcConn.(*SQLiteConn) + + _, err = sqlite3conn.Exec(`CREATE TABLE foo (name string)`, nil) + if err != nil { + t.Fatal("failed to create table:", err) + } + _, err = sqlite3conn.Exec(`INSERT INTO foo(name) VALUES("alice")`, nil) + if err != nil { + t.Fatal("failed to insert record:", err) + } + + // Serialize the database to a file + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + if err := ioutil.WriteFile(tempFilename, sqlite3conn.Serialize(""), 0644); err != nil { + t.Fatalf("failed to write serialized database to disk") + } + + // Open the SQLite3 file, and test that contents are as expected. + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("failed to open database:", err) + } + defer db.Close() + + rows, err := db.Query(`SELECT * FROM foo`) + if err != nil { + t.Fatal("failed to query database:", err) + } + defer rows.Close() + + rows.Next() + + var name string + rows.Scan(&name) + if exp, got := name, "alice"; exp != got { + t.Errorf("Expected %s for fetched result, but got %s:", exp, got) + } +} + +func TestDeserialize(t *testing.T) { + var sqlite3conn *SQLiteConn + d := SQLiteDriver{} + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + + // Create source database on disk. + conn, err := d.Open(tempFilename) + if err != nil { + t.Fatal("failed to open on-disk database:", err) + } + defer conn.Close() + sqlite3conn = conn.(*SQLiteConn) + _, err = sqlite3conn.Exec(`CREATE TABLE foo (name string)`, nil) + if err != nil { + t.Fatal("failed to create table:", err) + } + _, err = sqlite3conn.Exec(`INSERT INTO foo(name) VALUES("alice")`, nil) + if err != nil { + t.Fatal("failed to insert record:", err) + } + conn.Close() + + // Read database file bytes from disk. + b, err := ioutil.ReadFile(tempFilename) + if err != nil { + t.Fatal("failed to read database file on disk", err) + } + + // Deserialize file contents into memory. + conn, err = d.Open(":memory:") + if err != nil { + t.Fatal("failed to open in-memory database:", err) + } + sqlite3conn = conn.(*SQLiteConn) + defer conn.Close() + if err := sqlite3conn.Deserialize(b, ""); err != nil { + t.Fatal("failed to deserialize database", err) + } + + // Check database contents are as expected. + rows, err := sqlite3conn.Query(`SELECT * FROM foo`, nil) + if err != nil { + t.Fatal("failed to query database:", err) + } + if len(rows.Columns()) != 1 { + t.Fatal("incorrect number of columns returned:", len(rows.Columns())) + } + values := make([]driver.Value, 1) + rows.Next(values) + if v, ok := values[0].(string); !ok { + t.Fatalf("wrong type for value: %T", v) + } else if v != "alice" { + t.Fatal("wrong value returned", v) + } +} + func TestWAL(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename)