Skip to content

Commit

Permalink
fix CBOR protocol WS issue
Browse files Browse the repository at this point in the history
  • Loading branch information
remade committed Sep 12, 2024
1 parent 95ef10b commit 4884a2e
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 133 deletions.
23 changes: 12 additions & 11 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ func New(connectionURL string) (*DB, error) {
return nil, err
}

baseURL := fmt.Sprintf("%s://%s", u.Scheme, u.Host)
scheme := u.Scheme

newParams := connection.NewConnectionParams{
Expand All @@ -37,23 +36,25 @@ func New(connectionURL string) (*DB, error) {
return nil, fmt.Errorf("invalid connection url")
}

err := conn.Connect()
err = conn.Connect()
if err != nil {
return nil, err
}

// Only Websocket exposes live fields, try to connect to ws
liveconn := connection.NewWebSocket(newParams)
//liveScheme := "ws"
liveScheme := "ws"
if scheme == "wss" || scheme == "https" {
liveScheme = "wss"
}
//liveconn, err = liveconn.Connect(fmt.Sprintf("%s://%s", liveScheme, u.Host))
newLiveConnParams := newParams
newLiveConnParams.BaseURL = fmt.Sprintf("%s://%s", liveScheme, u.Host)
liveconn := connection.NewWebSocket(newParams)
err = liveconn.Connect()
if err != nil {
return nil, err
}

return &DB{conn: connect, liveHandler: liveconn}, nil
return &DB{conn: conn, liveHandler: liveconn}, nil
}

// --------------------------------------------------
Expand Down Expand Up @@ -94,12 +95,12 @@ func (db *DB) Authenticate(token string) (interface{}, error) {
return db.send("authenticate", token)
}

func (db *DB) Let(key string, val interface{}) (interface{}, error) {
return db.send("let", key, val)
func (db *DB) Let(key string, val interface{}) error {
return db.conn.Let(key, val)
}

func (db *DB) Unset(key string, val interface{}) (interface{}, error) {
return db.send("unset", key, val)
func (db *DB) Unset(key string) error {
return db.conn.Unset(key)
}

// Query is a convenient method for sending a query to the database.
Expand Down Expand Up @@ -150,7 +151,7 @@ func (db *DB) Live(table string, diff bool) (string, error) {
}

func (db *DB) Kill(liveQueryID string) (interface{}, error) {
return db.send("kill", liveQueryID)
return db.liveHandler.Kill(liveQueryID)
}

// LiveNotifications returns a channel for live query.
Expand Down
2 changes: 1 addition & 1 deletion db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (s *SurrealDBTestSuite) createTestDB() *surrealdb.DB {
// openConnection opens a new connection to the database
func (s *SurrealDBTestSuite) openConnection(url string, impl connection.Connection) *surrealdb.DB {
require.NotNil(s.T(), impl)
db, err := surrealdb.New(url, "")
db, err := surrealdb.New(url)
s.Require().NoError(err)
return db
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@ package connection

import (
"github.com/surrealdb/surrealdb.go/internal/codec"
"github.com/surrealdb/surrealdb.go/pkg/model"
)

type Connection interface {
Connect() error
Close() error
Send(method string, params []interface{}) (interface{}, error)
Use(namespace string, database string) error
SignIn(auth model.Auth) (string, error)
Let(key string, value interface{}) error
Unset(key string) error
}

type LiveHandler interface {
LiveNotifications(id string) (chan Notification, error)
Kill(id string) (interface{}, error)
}

type BaseConnection struct {
Expand Down
78 changes: 34 additions & 44 deletions pkg/connection/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,26 @@ import (
"bytes"
"fmt"
"github.com/surrealdb/surrealdb.go/internal/rand"
"github.com/surrealdb/surrealdb.go/pkg/model"
"io"
"log"
"net/http"
"sync"
"time"
)

type Http struct {
BaseConnection

httpClient *http.Client

namespace string
database string
token string
variables sync.Map
}

func NewHttp(p NewConnectionParams) *Http {
con := Http{
BaseConnection: BaseConnection{
marshaler: p.Marshaler,
unmarshaler: p.Unmarshaler,
baseURL: p.BaseURL,
},
}

Expand Down Expand Up @@ -82,10 +80,6 @@ func (h *Http) Send(method string, params []interface{}) (interface{}, error) {
return nil, fmt.Errorf("connection host not set")
}

if h.namespace == "" || h.database == "" {
return nil, fmt.Errorf("namespace or database or both are not set")
}

rpcReq := &RPCRequest{
ID: rand.String(RequestIDLength),
Method: method,
Expand All @@ -101,16 +95,20 @@ func (h *Http) Send(method string, params []interface{}) (interface{}, error) {
req.Header.Set("Accept", "application/cbor")
req.Header.Set("Content-Type", "application/cbor")

if h.namespace != "" {
req.Header.Set("Surreal-NS", h.namespace)
if namespace, ok := h.variables.Load("namespace"); ok {
req.Header.Set("Surreal-NS", namespace.(string))
} else {
return nil, fmt.Errorf("namespace or database or both are not set")
}

if h.database != "" {
req.Header.Set("Surreal-DB", h.database)
if database, ok := h.variables.Load("database"); ok {
req.Header.Set("Surreal-DB", database.(string))
} else {
return nil, fmt.Errorf("namespace or database or both are not set")
}

if h.token != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", h.token))
if token, ok := h.variables.Load("token"); ok {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
}

resp, err := h.MakeRequest(req)
Expand All @@ -121,6 +119,19 @@ func (h *Http) Send(method string, params []interface{}) (interface{}, error) {
var rpcResponse RPCResponse
err = h.unmarshaler.Unmarshal(resp, &rpcResponse)

// Manage auth tokens
switch method {
case "signin", "signup":
h.variables.Store("token", rpcResponse.Result)
break
case "authenticate":
h.variables.Store("token", params[0])
break
case "invalidate":
h.variables.Delete("token")
break
}

return rpcResponse.Result, nil
}

Expand All @@ -139,39 +150,18 @@ func (h *Http) MakeRequest(req *http.Request) ([]byte, error) {
}

func (h *Http) Use(namespace string, database string) error {
h.namespace = namespace
h.database = database
h.variables.Store("namespace", namespace)
h.variables.Store("database", database)

return nil
}

func (h *Http) SignIn(auth model.Auth) (string, error) {
resp, err := h.Send("signin", []interface{}{auth})
if err != nil {
return "", err
}

h.token = resp.(string)

return resp.(string), nil
}

func (h *Http) signup() {

}

func (h *Http) let() {

}

func (h *Http) unset() {

}

func (h *Http) authenticate() {

func (h *Http) Let(key string, value interface{}) error {
h.variables.Store(key, value)
return nil
}

func (h *Http) invalidate() {

func (h *Http) Unset(key string) error {
h.variables.Delete(key)
return nil
}
53 changes: 32 additions & 21 deletions pkg/connection/http_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package connection

import (
"bytes"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/surrealdb/surrealdb.go/pkg/model"
"io/ioutil"
"net/http"
"testing"
"time"
Expand All @@ -24,26 +26,35 @@ func NewTestClient(fn RoundTripFunc) *http.Client {
}
}

func TestEngine_MakeRequest(t *testing.T) {
httpClient := NewTestClient(func(req *http.Request) *http.Response {
assert.Equal(t, req.URL.String(), "http://test.surreal/rpc")

return &http.Response{
StatusCode: 400,
// Send response to be tested
Body: ioutil.NopCloser(bytes.NewBufferString(`OK`)),
// Must be set to non-nil value or it panics
Header: make(http.Header),
}
})

p := NewConnectionParams{
BaseURL: "http://test.surreal",
Marshaler: model.CborMarshaler{},
Unmarshaler: model.CborUnmashaler{},
}
httpEngine := NewHttp(p)
httpEngine.SetHttpClient(httpClient)

req, _ := http.NewRequest(http.MethodGet, "http://test.surreal/rpc", nil)
resp, err := httpEngine.MakeRequest(req)
assert.Error(t, err, "should return error for status code 400")

fmt.Println(resp)
}

func TestEngine_HttpMakeRequest(t *testing.T) {
//httpClient := NewTestClient(func(req *http.Request) *http.Response {
// assert.Equal(t, req.URL.String(), "http://test.surreal/rpc")
//
// return &http.Response{
// StatusCode: 400,
// // Send response to be tested
// Body: ioutil.NopCloser(bytes.NewBufferString(`OK`)),
// // Must be set to non-nil value or it panics
// Header: make(http.Header),
// }
//})
//
//httpEngine := (NewHttp(p)).(*Http)
//httpEngine.SetHttpClient(httpClient)
//
//resp, err := httpEngine.MakeRequest(http.MethodGet, "http://test.surreal/rpc", nil)
//assert.Error(t, err, "should return error for status code 400")
//
//fmt.Println(resp)

p := NewConnectionParams{
BaseURL: "http://localhost:8000",
Expand All @@ -54,10 +65,10 @@ func TestEngine_HttpMakeRequest(t *testing.T) {
err := con.Use("test", "test")
assert.Nil(t, err, "no error returned when setting namespace and database")

con, err = con.Connect("http://127.0.0.1:8000")
err = con.Connect() //implement a is ready
assert.Nil(t, err, "no error returned when initializing engine connection")

token, err := con.SignIn(model.Auth{Username: "pass", Password: "pass"})
token, err := con.Send("signin", []interface{}{model.Auth{Username: "pass", Password: "pass"}})
assert.Nil(t, err, "no error returned when signing in")
fmt.Println(token)

Expand Down
2 changes: 1 addition & 1 deletion pkg/connection/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type RPCNotification struct {
type RPCFunction string

var (
FUse RPCFunction = "use"
Use RPCFunction = "use"
Info RPCFunction = "info"
SignUp RPCFunction = "signup"
SignIn RPCFunction = "signin"
Expand Down
Loading

0 comments on commit 4884a2e

Please sign in to comment.