diff --git a/_example/mod_regexp/Makefile b/_example/mod_regexp/Makefile index 97b1e0f3..1ef69a6f 100644 --- a/_example/mod_regexp/Makefile +++ b/_example/mod_regexp/Makefile @@ -1,22 +1,27 @@ ifeq ($(OS),Windows_NT) EXE=extension.exe -EXT=sqlite3_mod_regexp.dll +LIB_EXT=dll RM=cmd /c del LDFLAG= else EXE=extension -EXT=sqlite3_mod_regexp.so -RM=rm +ifeq ($(shell uname -s),Darwin) +LIB_EXT=dylib +else +LIB_EXT=so +endif +RM=rm -f LDFLAG=-fPIC endif +LIB=sqlite3_mod_regexp.$(LIB_EXT) -all : $(EXE) $(EXT) +all : $(EXE) $(LIB) $(EXE) : extension.go go build $< -$(EXT) : sqlite3_mod_regexp.c +$(LIB) : sqlite3_mod_regexp.c gcc $(LDFLAG) -shared -o $@ $< -lsqlite3 -lpcre clean : - @-$(RM) $(EXE) $(EXT) + @-$(RM) $(EXE) $(LIB) diff --git a/_example/mod_vtable/Makefile b/_example/mod_vtable/Makefile index cdd4853d..f65a0042 100644 --- a/_example/mod_vtable/Makefile +++ b/_example/mod_vtable/Makefile @@ -1,24 +1,29 @@ ifeq ($(OS),Windows_NT) EXE=extension.exe -EXT=sqlite3_mod_vtable.dll +LIB_EXT=dll RM=cmd /c del LIBCURL=-lcurldll LDFLAG= else EXE=extension -EXT=sqlite3_mod_vtable.so -RM=rm +ifeq ($(shell uname -s),Darwin) +LIB_EXT=dylib +else +LIB_EXT=so +endif +RM=rm -f LDFLAG=-fPIC LIBCURL=-lcurl endif +LIB=sqlite3_mod_vtable.$(LIB_EXT) -all : $(EXE) $(EXT) +all : $(EXE) $(LIB) $(EXE) : extension.go go build $< -$(EXT) : sqlite3_mod_vtable.cc +$(LIB) : sqlite3_mod_vtable.cc g++ $(LDFLAG) -shared -o $@ $< -lsqlite3 $(LIBCURL) clean : - @-$(RM) $(EXE) $(EXT) + @-$(RM) $(EXE) $(LIB) diff --git a/_example/mod_vtable/sqlite3_mod_vtable.cc b/_example/mod_vtable/sqlite3_mod_vtable.cc index 5bd4e66f..4caf4842 100644 --- a/_example/mod_vtable/sqlite3_mod_vtable.cc +++ b/_example/mod_vtable/sqlite3_mod_vtable.cc @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include "picojson.h" diff --git a/sqlite3_load_extension.go b/sqlite3_load_extension.go index 23c5d31c..e6c50f28 100644 --- a/sqlite3_load_extension.go +++ b/sqlite3_load_extension.go @@ -28,12 +28,9 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error { } for _, extension := range extensions { - cext := C.CString(extension) - defer C.free(unsafe.Pointer(cext)) - rv = C.sqlite3_load_extension(c.db, cext, nil, nil) - if rv != C.SQLITE_OK { + if err := c.loadExtension(extension, nil); err != nil { C.sqlite3_enable_load_extension(c.db, 0) - return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) + return err } } @@ -41,6 +38,7 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error { if rv != C.SQLITE_OK { return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) } + return nil } @@ -51,19 +49,35 @@ func (c *SQLiteConn) LoadExtension(lib string, entry string) error { return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) } - clib := C.CString(lib) - defer C.free(unsafe.Pointer(clib)) - centry := C.CString(entry) - defer C.free(unsafe.Pointer(centry)) + if err := c.loadExtension(lib, &entry); err != nil { + C.sqlite3_enable_load_extension(c.db, 0) + return err + } - rv = C.sqlite3_load_extension(c.db, clib, centry, nil) + rv = C.sqlite3_enable_load_extension(c.db, 0) if rv != C.SQLITE_OK { return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) } - rv = C.sqlite3_enable_load_extension(c.db, 0) + return nil +} + +func (c *SQLiteConn) loadExtension(lib string, entry *string) error { + clib := C.CString(lib) + defer C.free(unsafe.Pointer(clib)) + + var centry *C.char + if entry != nil { + centry := C.CString(*entry) + defer C.free(unsafe.Pointer(centry)) + } + + var errMsg *C.char + defer C.sqlite3_free(unsafe.Pointer(errMsg)) + + rv := C.sqlite3_load_extension(c.db, clib, centry, &errMsg) if rv != C.SQLITE_OK { - return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) + return errors.New(C.GoString(errMsg)) } return nil diff --git a/sqlite3_load_extension_test.go b/sqlite3_load_extension_test.go new file mode 100644 index 00000000..97b11233 --- /dev/null +++ b/sqlite3_load_extension_test.go @@ -0,0 +1,63 @@ +// Copyright (C) 2019 Yasuhiro Matsumoto . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// +build !sqlite_omit_load_extension + +package sqlite3 + +import ( + "database/sql" + "testing" +) + +func TestExtensionsError(t *testing.T) { + sql.Register("sqlite3_TestExtensionsError", + &SQLiteDriver{ + Extensions: []string{ + "foobar", + }, + }, + ) + + db, err := sql.Open("sqlite3_TestExtensionsError", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Ping() + if err == nil { + t.Fatal("expected error loading non-existent extension") + } + + if err.Error() == "not an error" { + t.Fatal("expected error from sqlite3_enable_load_extension to be returned") + } +} + +func TestLoadExtensionError(t *testing.T) { + sql.Register("sqlite3_TestLoadExtensionError", + &SQLiteDriver{ + ConnectHook: func(c *SQLiteConn) error { + return c.LoadExtension("foobar", "") + }, + }, + ) + + db, err := sql.Open("sqlite3_TestLoadExtensionError", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Ping() + if err == nil { + t.Fatal("expected error loading non-existent extension") + } + + if err.Error() == "not an error" { + t.Fatal("expected error from sqlite3_enable_load_extension to be returned") + } +}