diff --git a/.golangci.yml b/.golangci.yml index 95db2d0..5a5d214 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -22,6 +22,7 @@ linters-settings: - ifElseChain - octalLiteral - whyNoLint + - dupSubExpr # https://github.com/go-critic/go-critic/issues/897#issuecomment-568896534 gocyclo: min-complexity: 15 goimports: diff --git a/README.md b/README.md index b3624a6..c915324 100644 --- a/README.md +++ b/README.md @@ -216,6 +216,20 @@ or for a secure connection db, err := surrealdb.New("https://localhost:8000") ``` +### Using SurrealKV and Memory +SurrealKV and Memory also do not support live notifications at this time. This would be updated in the next +release. + +For Surreal KV +```go +db, err := surrealdb.New("surrealkv://path/to/dbfile.kv") +``` + +For Memory +```go +db, err := surrealdb.New("mem://") +db, err := surrealdb.New("memory://") +``` ## Data Models This package facilitates communication between client and the backend service using the Concise diff --git a/db.go b/db.go index 8ce17ab..ec62e1e 100644 --- a/db.go +++ b/db.go @@ -49,6 +49,8 @@ func New(connectionURL string) (*DB, error) { con = connection.NewHTTPConnection(newParams) } else if scheme == "ws" || scheme == "wss" { con = connection.NewWebSocketConnection(newParams) + } else if scheme == "memory" || scheme == "mem" || scheme == "surrealkv" { + con = connection.NewEmbeddedConnection(newParams) } else { return nil, fmt.Errorf("invalid connection url") } diff --git a/db_test.go b/db_test.go index 41d1473..5635c42 100644 --- a/db_test.go +++ b/db_test.go @@ -7,8 +7,9 @@ import ( "testing" "time" - "github.com/stretchr/testify/suite" "github.com/surrealdb/surrealdb.go" + + "github.com/stretchr/testify/suite" "github.com/surrealdb/surrealdb.go/pkg/connection" "github.com/surrealdb/surrealdb.go/pkg/models" ) diff --git a/libsrc/libsurrealdb_c.dylib b/libsrc/libsurrealdb_c.dylib new file mode 100755 index 0000000..73885b1 Binary files /dev/null and b/libsrc/libsurrealdb_c.dylib differ diff --git a/libsrc/surrealdb.h b/libsrc/surrealdb.h new file mode 100644 index 0000000..f27490e --- /dev/null +++ b/libsrc/surrealdb.h @@ -0,0 +1,440 @@ +#include +#include +#include +#include + +#define sr_SR_NONE 0 + +#define sr_SR_CLOSED -1 + +#define sr_SR_ERROR -2 + +#define sr_SR_FATAL -3 + +typedef enum sr_action { + SR_ACTION_CREATE, + SR_ACTION_UPDATE, + SR_ACTION_DELETE, +} sr_action; + +typedef struct sr_opaque_object_internal_t sr_opaque_object_internal_t; + +typedef struct sr_RpcStream sr_RpcStream; + +/** + * may be sent across threads, but must not be aliased + */ +typedef struct sr_stream_t sr_stream_t; + +/** + * The object representing a Surreal connection + * + * It is safe to be referenced from multiple threads + * If any operation, on any thread returns SR_FATAL then the connection is poisoned and must not be used again. + * (use will cause the program to abort) + * + * should be freed with sr_surreal_disconnect + */ +typedef struct sr_surreal_t sr_surreal_t; + +/** + * The object representing a Surreal connection + * + * It is safe to be referenced from multiple threads + * If any operation, on any thread returns SR_FATAL then the connection is poisoned and must not be used again. + * (use will cause the program to abort) + * + * should be freed with sr_surreal_disconnect + */ +typedef struct sr_surreal_rpc_t sr_surreal_rpc_t; + +typedef char *sr_string_t; + +typedef struct sr_object_t { + struct sr_opaque_object_internal_t *_0; +} sr_object_t; + +typedef enum sr_number_t_Tag { + SR_NUMBER_INT, + SR_NUMBER_FLOAT, +} sr_number_t_Tag; + +typedef struct sr_number_t { + sr_number_t_Tag tag; + union { + struct { + int64_t sr_number_int; + }; + struct { + double sr_number_float; + }; + }; +} sr_number_t; + +typedef struct sr_duration_t { + uint64_t secs; + uint32_t nanos; +} sr_duration_t; + +typedef struct sr_uuid_t { + uint8_t _0[16]; +} sr_uuid_t; + +typedef struct sr_bytes_t { + uint8_t *arr; + int len; +} sr_bytes_t; + +typedef enum sr_id_t_Tag { + SR_ID_NUMBER, + SR_ID_STRING, + SR_ID_ARRAY, + SR_ID_OBJECT, +} sr_id_t_Tag; + +typedef struct sr_id_t { + sr_id_t_Tag tag; + union { + struct { + int64_t sr_id_number; + }; + struct { + sr_string_t sr_id_string; + }; + struct { + struct sr_array_t *sr_id_array; + }; + struct { + struct sr_object_t sr_id_object; + }; + }; +} sr_id_t; + +typedef struct sr_thing_t { + sr_string_t table; + struct sr_id_t id; +} sr_thing_t; + +typedef enum sr_value_t_Tag { + SR_VALUE_NONE, + SR_VALUE_NULL, + SR_VALUE_BOOL, + SR_VALUE_NUMBER, + SR_VALUE_STRAND, + SR_VALUE_DURATION, + SR_VALUE_DATETIME, + SR_VALUE_UUID, + SR_VALUE_ARRAY, + SR_VALUE_OBJECT, + SR_VALUE_BYTES, + SR_VALUE_THING, +} sr_value_t_Tag; + +typedef struct sr_value_t { + sr_value_t_Tag tag; + union { + struct { + bool sr_value_bool; + }; + struct { + struct sr_number_t sr_value_number; + }; + struct { + sr_string_t sr_value_strand; + }; + struct { + struct sr_duration_t sr_value_duration; + }; + struct { + sr_string_t sr_value_datetime; + }; + struct { + struct sr_uuid_t sr_value_uuid; + }; + struct { + struct sr_array_t *sr_value_array; + }; + struct { + struct sr_object_t sr_value_object; + }; + struct { + struct sr_bytes_t sr_value_bytes; + }; + struct { + struct sr_thing_t sr_value_thing; + }; + }; +} sr_value_t; + +typedef struct sr_array_t { + struct sr_value_t *arr; + int len; +} sr_array_t; + +/** + * when code = 0 there is no error + */ +typedef struct sr_SurrealError { + int code; + sr_string_t msg; +} sr_SurrealError; + +typedef struct sr_arr_res_t { + struct sr_array_t ok; + struct sr_SurrealError err; +} sr_arr_res_t; + +typedef struct sr_option_t { + bool strict; + uint8_t query_timeout; + uint8_t transaction_timeout; +} sr_option_t; + +typedef struct sr_notification_t { + struct sr_uuid_t query_id; + enum sr_action action; + struct sr_value_t data; +} sr_notification_t; + +/** + * connects to a local, remote, or embedded database + * + * if any function returns SR_FATAL, this must not be used (except to drop) (TODO: check this is safe) doing so will cause the program to abort + * + * # Examples + * + * ```c + * sr_string_t err; + * sr_surreal_t *db; + * + * // connect to in-memory instance + * if (sr_connect(&err, &db, "mem://") < 0) { + * printf("error connecting to db: %s\n", err); + * return 1; + * } + * + * // connect to surrealkv file + * if (sr_connect(&err, &db, "surrealkv://test.skv") < 0) { + * printf("error connecting to db: %s\n", err); + * return 1; + * } + * + * // connect to surrealdb server + * if (sr_connect(&err, &db, "wss://localhost:8000") < 0) { + * printf("error connecting to db: %s\n", err); + * return 1; + * } + * + * sr_surreal_disconnect(db); + * ``` + */ +int sr_connect(sr_string_t *err_ptr, + struct sr_surreal_t **surreal_ptr, + const char *endpoint); + +/** + * disconnect a database connection + * note: the Surreal object must not be used after this function has been called + * any object allocations will still be valid, and should be freed, using the appropriate function + * TODO: check if Stream can be freed after disconnection because of rt + * + * # Examples + * + * ```c + * sr_surreal_t *db; + * // connect + * disconnect(db); + * ``` + */ +void sr_surreal_disconnect(struct sr_surreal_t *db); + +/** + * create a record + * + */ +int sr_create(const struct sr_surreal_t *db, + sr_string_t *err_ptr, + struct sr_object_t **res_ptr, + const char *resource, + const struct sr_object_t *content); + +/** + * make a live selection + * if successful sets *stream_ptr to be an exclusive reference to an opaque Stream object + * which can be moved accross threads but not aliased + * + * # Examples + * + * sr_stream_t *stream; + * if (sr_select_live(db, &err, &stream, "foo") < 0) + * { + * printf("%s", err); + * return 1; + * } + * + * sr_notification_t not ; + * if (sr_stream_next(stream, ¬ ) > 0) + * { + * sr_print_notification(¬ ); + * } + * sr_stream_kill(stream); + */ +int sr_select_live(const struct sr_surreal_t *db, + sr_string_t *err_ptr, + struct sr_stream_t **stream_ptr, + const char *resource); + +int sr_query(const struct sr_surreal_t *db, + sr_string_t *err_ptr, + struct sr_arr_res_t **res_ptr, + const char *query, + const struct sr_object_t *vars); + +/** + * select a resource + * + * can be used to select everything from a table or a single record + * writes values to *res_ptr, and returns number of values + * result values are allocated by Surreal and must be freed with sr_free_arr + * + * # Examples + * + * ```c + * sr_surreal_t *db; + * sr_string_t err; + * sr_value_t *foos; + * int len = sr_select(db, &err, &foos, "foo"); + * if (len < 0) { + * printf("%s", err); + * return 1; + * } + * ``` + * for (int i = 0; i < len; i++) + * { + * sr_value_print(&foos[i]); + * } + * sr_free_arr(foos, len); + */ +int sr_select(const struct sr_surreal_t *db, + sr_string_t *err_ptr, + struct sr_value_t **res_ptr, + const char *resource); + +/** + * select database + * NOTE: namespace must be selected first with sr_use_ns + * + * # Examples + * ```c + * sr_surreal_t *db; + * sr_string_t err; + * if (sr_use_db(db, &err, "test") < 0) + * { + * printf("%s", err); + * return 1; + * } + * ``` + */ +int sr_use_db(const struct sr_surreal_t *db, sr_string_t *err_ptr, const char *query); + +/** + * select namespace + * NOTE: database must be selected before use with sr_use_db + * + * # Examples + * ```c + * sr_surreal_t *db; + * sr_string_t err; + * if (sr_use_ns(db, &err, "test") < 0) + * { + * printf("%s", err); + * return 1; + * } + * ``` + */ +int sr_use_ns(const struct sr_surreal_t *db, sr_string_t *err_ptr, const char *query); + +/** + * returns the db version + * NOTE: version is allocated in Surreal and must be freed with sr_free_string + * # Examples + * ```c + * sr_surreal_t *db; + * sr_string_t err; + * sr_string_t ver; + * + * if (sr_version(db, &err, &ver) < 0) + * { + * printf("%s", err); + * return 1; + * } + * printf("%s", ver); + * sr_free_string(ver); + * ``` + */ +int sr_version(const struct sr_surreal_t *db, sr_string_t *err_ptr, sr_string_t *res_ptr); + +int sr_surreal_rpc_new(sr_string_t *err_ptr, + struct sr_surreal_rpc_t **surreal_ptr, + const char *endpoint, + struct sr_option_t options); + +/** + * execute rpc + * + * free result with sr_free_byte_arr + */ +int sr_surreal_rpc_execute(const struct sr_surreal_rpc_t *self, + sr_string_t *err_ptr, + uint8_t **res_ptr, + const uint8_t *ptr, + int len); + +int sr_surreal_rpc_notifications(const struct sr_surreal_rpc_t *self, + sr_string_t *err_ptr, + struct sr_RpcStream **stream_ptr); + +void sr_surreal_rpc_free(struct sr_surreal_rpc_t *ctx); + +void sr_free_arr(struct sr_value_t *ptr, int len); + +void sr_free_bytes(struct sr_bytes_t bytes); + +void sr_free_byte_arr(uint8_t *ptr, int len); + +void sr_print_notification(const struct sr_notification_t *notification); + +const struct sr_value_t *sr_object_get(const struct sr_object_t *obj, const char *key); + +struct sr_object_t sr_object_new(void); + +void sr_object_insert(struct sr_object_t *obj, const char *key, const struct sr_value_t *value); + +void sr_object_insert_str(struct sr_object_t *obj, const char *key, const char *value); + +void sr_object_insert_int(struct sr_object_t *obj, const char *key, int value); + +void sr_object_insert_float(struct sr_object_t *obj, const char *key, float value); + +void sr_object_insert_double(struct sr_object_t *obj, const char *key, double value); + +void sr_free_object(struct sr_object_t obj); + +void sr_free_arr_res(struct sr_arr_res_t res); + +void sr_free_arr_res_arr(struct sr_arr_res_t *ptr, int len); + +/** + * blocks until next item is recieved on stream + * will return 1 and write notification to notification_ptr is recieved + * will return SR_NONE if the stream is closed + */ +int sr_stream_next(struct sr_stream_t *self, struct sr_notification_t *notification_ptr); + +void sr_stream_kill(struct sr_stream_t *stream); + +void sr_free_string(sr_string_t string); + +void sr_value_print(const struct sr_value_t *val); + +bool sr_value_eq(const struct sr_value_t *lhs, const struct sr_value_t *rhs); diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go index 1a65928..46387e9 100644 --- a/pkg/connection/connection.go +++ b/pkg/connection/connection.go @@ -91,6 +91,14 @@ func (bc *BaseConnection) createNotificationChannel(liveQueryID string) (chan No return ch, nil } +func (bc *BaseConnection) getNotificationChannel(id string) (chan Notification, bool) { + bc.notificationChannelsLock.RLock() + defer bc.notificationChannelsLock.RUnlock() + ch, ok := bc.notificationChannels[id] + + return ch, ok +} + func (bc *BaseConnection) removeResponseChannel(id string) { bc.responseChannelsLock.Lock() defer bc.responseChannelsLock.Unlock() @@ -117,14 +125,6 @@ func (bc *BaseConnection) getErrorChannel(id string) (chan error, bool) { return ch, ok } -func (bc *BaseConnection) getLiveChannel(id string) (chan Notification, bool) { - bc.notificationChannelsLock.RLock() - defer bc.notificationChannelsLock.RUnlock() - ch, ok := bc.notificationChannels[id] - - return ch, ok -} - func (bc *BaseConnection) preConnectionChecks() error { if bc.baseURL == "" { return constants.ErrNoBaseURL diff --git a/pkg/connection/connection_test.go b/pkg/connection/connection_test.go index c03a263..4862258 100644 --- a/pkg/connection/connection_test.go +++ b/pkg/connection/connection_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/suite" + "github.com/surrealdb/surrealdb.go/pkg/constants" "github.com/surrealdb/surrealdb.go/pkg/logger" "github.com/surrealdb/surrealdb.go/pkg/models" diff --git a/pkg/connection/embedded.go b/pkg/connection/embedded.go new file mode 100644 index 0000000..052a6c5 --- /dev/null +++ b/pkg/connection/embedded.go @@ -0,0 +1,143 @@ +package connection + +/* +#cgo LDFLAGS: -L./../../libsrc -lsurrealdb_c +#include +#include "./../../libsrc/surrealdb.h" +*/ +import "C" + +import ( + "fmt" + "net/url" + "sync" + "unsafe" + + "github.com/fxamacker/cbor/v2" + "github.com/surrealdb/surrealdb.go/internal/codec" + "github.com/surrealdb/surrealdb.go/internal/rand" + "github.com/surrealdb/surrealdb.go/pkg/constants" +) + +type EmbeddedConnection struct { + BaseConnection + + variables sync.Map + + surrealRPC *C.sr_surreal_rpc_t + surrealStream *C.sr_RpcStream + + closeChan chan int + closeErr error +} + +func (h *EmbeddedConnection) GetUnmarshaler() codec.Unmarshaler { + return h.unmarshaler +} + +func NewEmbeddedConnection(p NewConnectionParams) *EmbeddedConnection { + con := EmbeddedConnection{ + BaseConnection: BaseConnection{ + baseURL: p.BaseURL, + + marshaler: p.Marshaler, + unmarshaler: p.Unmarshaler, + + responseChannels: make(map[string]chan []byte), + notificationChannels: make(map[string]chan Notification), + }, + + closeChan: make(chan int), + } + + return &con +} + +func (h *EmbeddedConnection) Connect() error { + err := h.preConnectionChecks() + if err != nil { + return err + } + + var cErr C.sr_string_t + defer C.sr_free_string(cErr) + + cEndpoint := C.CString(h.baseURL) + u, err := url.ParseRequestURI(h.baseURL) + if err != nil { + return err + } + if u.Scheme == "mem" || u.Scheme == "memory" { + cEndpoint = C.CString("memory") + } + defer C.free(unsafe.Pointer(cEndpoint)) + + var surrealOptions C.sr_option_t + var surrealRPC *C.sr_surreal_rpc_t + if ret := C.sr_surreal_rpc_new(&cErr, &surrealRPC, cEndpoint, surrealOptions); ret < 0 { + return fmt.Errorf("error initiating rpc. %v. return %v", C.GoString(cErr), ret) + } + h.surrealRPC = surrealRPC + + var cStream *C.sr_RpcStream + if ret := C.sr_surreal_rpc_notifications(h.surrealRPC, &cErr, &cStream); ret < 0 { + return fmt.Errorf("error initiating rpc. %v. return %v", C.GoString(cErr), ret) + } + h.surrealStream = cStream + + return nil +} + +func (h *EmbeddedConnection) Close() error { + C.sr_surreal_rpc_free(h.surrealRPC) + + h.surrealRPC = nil + return nil +} + +func (h *EmbeddedConnection) Send(res interface{}, method string, params ...interface{}) error { + request := &RPCRequest{ + ID: rand.String(constants.RequestIDLength), + Method: method, + Params: params, + } + reqBody, err := h.marshaler.Marshal(request) + if err != nil { + return err + } + + var cErr C.sr_string_t + defer C.sr_free_string(cErr) + + inputPtr := (*C.uint8_t)(unsafe.Pointer(&reqBody[0])) + inputLen := C.int(len(reqBody)) + + var cRes *C.uint8_t + defer C.free(unsafe.Pointer(cRes)) + + resSize := C.sr_surreal_rpc_execute(h.surrealRPC, &cErr, &cRes, inputPtr, inputLen) + if resSize < 0 { + return fmt.Errorf("%v", C.GoString(cErr)) + } + + if res == nil { + return nil + } + + resultBytes := cbor.RawMessage(C.GoBytes(unsafe.Pointer(cRes), resSize)) + + rpcRes, _ := h.marshaler.Marshal(RPCResponse[cbor.RawMessage]{ID: request.ID, Result: &resultBytes}) + return h.unmarshaler.Unmarshal(rpcRes, res) +} + +func (h *EmbeddedConnection) Use(namespace, database string) error { + return h.Send(nil, "use", namespace, database) +} + +func (h *EmbeddedConnection) Let(key string, value interface{}) error { + return h.Send(nil, "let", key, value) +} + +func (h *EmbeddedConnection) Unset(key string) error { + return h.Send(nil, "unset", key) +} diff --git a/pkg/connection/embedded_test.go b/pkg/connection/embedded_test.go new file mode 100644 index 0000000..d0e8423 --- /dev/null +++ b/pkg/connection/embedded_test.go @@ -0,0 +1,77 @@ +package connection + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/surrealdb/surrealdb.go/pkg/models" +) + +type EmbeddedConnectionTestSuite struct { + suite.Suite + con *EmbeddedConnection + name string +} + +func TestEmbeddedConnectionTestSuite(t *testing.T) { + s := new(EmbeddedConnectionTestSuite) + s.name = "Test_Embedded_Connection" + suite.Run(t, s) +} + +// SetupSuite is called before the s starts running +func (s *EmbeddedConnectionTestSuite) SetupSuite() { + con := NewEmbeddedConnection(NewConnectionParams{ + BaseURL: "memory", + Marshaler: models.CborMarshaler{}, + Unmarshaler: models.CborUnmarshaler{}, + }) + + err := con.Connect() + s.Require().NoError(err, "no error during connection") + + s.con = con +} + +// TearDownTest is called after each test +func (s *EmbeddedConnectionTestSuite) TearDownTest() { + +} + +// TearDownSuite is called after the s has finished running +func (s *EmbeddedConnectionTestSuite) TearDownSuite() { + err := s.con.Close() + s.Require().NoError(err) +} + +func (s *EmbeddedConnectionTestSuite) TestSendRequest() { + err := s.con.Use("test", "test") + s.Require().NoError(err) + + var versionRes RPCResponse[string] + err = s.con.Send(&versionRes, "version") + s.Require().NoError(err) +} + +func (s *EmbeddedConnectionTestSuite) TestLiveAndNotification() { + err := s.con.Use("test", "test") + s.Require().NoError(err) + + var liveRes RPCResponse[models.UUID] + err = s.con.Send(&liveRes, "live", "users", false) + s.Require().NoError(err, "should not return error on live request") + + liveID := liveRes.Result.String() + defer func() { + err = s.con.Send(nil, "kill", liveID) + s.Require().NoError(err) + }() + + notifications, err := s.con.LiveNotifications(liveID) + s.Require().NoError(err) + + fmt.Println(notifications) + + // Notification reader not ready on C lib +} diff --git a/pkg/connection/http.go b/pkg/connection/http.go index b94b6a7..b56637a 100644 --- a/pkg/connection/http.go +++ b/pkg/connection/http.go @@ -87,7 +87,6 @@ func (h *HTTPConnection) Send(dest any, method string, params ...interface{}) er Params: params, } reqBody, err := h.marshaler.Marshal(request) - if err != nil { return err } @@ -137,6 +136,7 @@ func (h *HTTPConnection) Send(dest any, method string, params ...interface{}) er func (h *HTTPConnection) MakeRequest(req *http.Request) ([]byte, error) { resp, err := h.httpClient.Do(req) + if err != nil { return nil, fmt.Errorf("error making HTTP request: %w", err) } diff --git a/pkg/connection/http_test.go b/pkg/connection/http_test.go index 2981ac7..88c4fbe 100644 --- a/pkg/connection/http_test.go +++ b/pkg/connection/http_test.go @@ -3,12 +3,14 @@ package connection import ( "bytes" "context" + "encoding/base64" "io" "net/http" "testing" "github.com/stretchr/testify/suite" + "github.com/surrealdb/surrealdb.go/pkg/models" ) diff --git a/pkg/connection/ws.go b/pkg/connection/ws.go index df4b474..91b4e81 100644 --- a/pkg/connection/ws.go +++ b/pkg/connection/ws.go @@ -122,12 +122,7 @@ func (ws *WebSocketConnection) Close() error { } func (ws *WebSocketConnection) Use(namespace, database string) error { - err := ws.Send(nil, "use", namespace, database) - if err != nil { - return err - } - - return nil + return ws.Send(nil, "use", namespace, database) } func (ws *WebSocketConnection) Let(key string, value interface{}) error { @@ -285,7 +280,7 @@ func (ws *WebSocketConnection) handleResponse(res []byte) { channelID := notificationRes.Result.ID - LiveNotificationChan, ok := ws.getLiveChannel(channelID.String()) + LiveNotificationChan, ok := ws.getNotificationChannel(channelID.String()) if !ok { err := fmt.Errorf("unavailable ResponseChannel %+v", channelID.String()) ws.logger.Error(err.Error(), "result", fmt.Sprint(rpcRes.Result))