Skip to content

Commit

Permalink
feat(broadcast): added broadcast request auth
Browse files Browse the repository at this point in the history
  • Loading branch information
aleksander-vedvik committed Jun 11, 2024
1 parent 8dba797 commit 656a7fc
Show file tree
Hide file tree
Showing 18 changed files with 203 additions and 46 deletions.
24 changes: 24 additions & 0 deletions authentication/authentication.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package authentication

import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
Expand Down Expand Up @@ -151,3 +152,26 @@ func (ec *EllipticCurve) EncodeMsg(msg any) ([]byte, error) {
}
return encodedMsg.Bytes(), nil*/
}

func encodeMsg(msg any) ([]byte, error) {
return []byte(fmt.Sprintf("%v", msg)), nil
}

func Verify(pemEncodedPub string, signature, digest []byte, msg any) (bool, error) {
encodedMsg, err := encodeMsg(msg)
if err != nil {
return false, err
}
ec := New(elliptic.P256())
h := sha256.Sum256(encodedMsg)
hash := h[:]
if !bytes.Equal(hash, digest) {
return false, fmt.Errorf("wrong digest")
}
pubKey, err := ec.DecodePublic(pemEncodedPub)
if err != nil {
return false, err
}
ok := ecdsa.VerifyASN1(pubKey, hash, signature)
return ok, nil
}
11 changes: 10 additions & 1 deletion broadcast.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,13 @@ type BroadcastMetadata struct {
OriginAddr string // address of the origin
OriginMethod string // the first method called by the origin
Method string // the current method
Digest []byte // digest of original message sent by client
Timestamp time.Time // timestamp in seconds when the broadcast request was issued by the client/server
ShardID uint16 // ID of the shard handling the broadcast request
MachineID uint16 // ID of the client/server that issued the broadcast request
SequenceNo uint32 // sequence number of the broadcast request from that particular client/server. Will roll over when reaching max.
OriginDigest []byte
OriginSignature []byte
OriginPubKey string
}

func newBroadcastMetadata(md *ordering.Metadata) BroadcastMetadata {
Expand All @@ -184,10 +186,17 @@ func newBroadcastMetadata(md *ordering.Metadata) BroadcastMetadata {
SenderAddr: md.BroadcastMsg.SenderAddr,
OriginAddr: md.BroadcastMsg.OriginAddr,
OriginMethod: md.BroadcastMsg.OriginMethod,
OriginDigest: md.BroadcastMsg.OriginDigest,
OriginSignature: md.BroadcastMsg.OriginSignature,
OriginPubKey: md.BroadcastMsg.OriginPubKey,
Method: m,
Timestamp: broadcast.Epoch().Add(time.Duration(timestamp) * time.Second),
ShardID: shardID,
MachineID: machineID,
SequenceNo: sequenceNo,
}
}

func (md BroadcastMetadata) Verify(msg protoreflect.ProtoMessage) (bool, error) {
return authentication.Verify(md.OriginPubKey, md.OriginSignature, md.OriginDigest, msg)
}
2 changes: 1 addition & 1 deletion broadcast/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ const (
ServerOriginAddr string = "server"
)

type ServerHandler func(ctx context.Context, in protoreflect.ProtoMessage, broadcastID uint64, originAddr, originMethod string, options BroadcastOptions, id uint32, addr string)
type ServerHandler func(ctx context.Context, in protoreflect.ProtoMessage, broadcastID uint64, originAddr, originMethod string, options BroadcastOptions, id uint32, addr string, originDigest, originSignature []byte, originPubKey string)
21 changes: 18 additions & 3 deletions broadcast/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ type metadata struct {
IsBroadcastClient bool
SentCancellation bool
HasReceivedClientReq bool
OriginDigest []byte
OriginPubKey string
OriginSignature []byte
}

func (p *BroadcastProcessor) run(msg *Content) {
Expand All @@ -50,6 +53,9 @@ func (p *BroadcastProcessor) run(msg *Content) {
SendFn: msg.SendFn,
Sent: false,
SentCancellation: false,
OriginDigest: msg.OriginDigest,
OriginSignature: msg.OriginSignature,
OriginPubKey: msg.OriginPubKey,
}
// methods is placed here and not in the metadata as an optimization strategy.
// Testing shows that it does not allocate memory for it on the heap.
Expand Down Expand Up @@ -102,7 +108,7 @@ func (p *BroadcastProcessor) handleCancellation(bMsg *Msg, metadata *metadata) b
p.log("broadcast: sent cancellation", nil, logging.MsgType(bMsg.MsgType.String()), logging.Stopping(false))
metadata.SentCancellation = true
go func(broadcastID uint64, cancellationMsg *cancellation) {
_ = p.router.Send(broadcastID, "", "", cancellationMsg)
_ = p.router.Send(broadcastID, "", "", nil, nil, "", cancellationMsg)
}(p.broadcastID, bMsg.Cancellation)
}
return false
Expand All @@ -114,7 +120,7 @@ func (p *BroadcastProcessor) handleBroadcast(bMsg *Msg, methods []string, metada
if !bMsg.Msg.options.AllowDuplication && alreadyBroadcasted(methods, bMsg.Method) {
return false
}
err := p.router.Send(p.broadcastID, metadata.OriginAddr, metadata.OriginMethod, bMsg.Msg)
err := p.router.Send(p.broadcastID, metadata.OriginAddr, metadata.OriginMethod, metadata.OriginDigest, metadata.OriginSignature, metadata.OriginPubKey, bMsg.Msg)
p.log("broadcast: sending broadcast", err, logging.MsgType(bMsg.MsgType.String()), logging.Method(bMsg.Method), logging.Stopping(false), logging.IsBroadcastCall(metadata.isBroadcastCall()))

p.updateOrder(bMsg.Method, bMsg.Msg.options.ProgressTo)
Expand All @@ -126,7 +132,7 @@ func (p *BroadcastProcessor) handleReply(bMsg *Msg, metadata *metadata) bool {
// BroadcastCall if origin addr is non-empty.
if metadata.isBroadcastCall() {
go func(broadcastID uint64, originAddr, originMethod string, replyMsg *reply) {
err := p.router.Send(broadcastID, originAddr, originMethod, replyMsg)
err := p.router.Send(broadcastID, originAddr, originMethod, metadata.OriginDigest, metadata.OriginSignature, metadata.OriginPubKey, replyMsg)
p.log("broadcast: sent reply to client", err, logging.Method(originMethod), logging.MsgType(bMsg.MsgType.String()), logging.Stopping(true), logging.IsBroadcastCall(metadata.isBroadcastCall()))
}(p.broadcastID, metadata.OriginAddr, metadata.OriginMethod, bMsg.Reply)
// the request is done becuase we have sent a reply to the client
Expand Down Expand Up @@ -267,6 +273,15 @@ func (m *metadata) update(new *Content) {
if m.OriginMethod == "" && new.OriginMethod != "" {
m.OriginMethod = new.OriginMethod
}
if m.OriginPubKey == "" && new.OriginPubKey != "" {
m.OriginPubKey = new.OriginPubKey
}
if m.OriginSignature == nil && new.OriginSignature != nil {
m.OriginSignature = new.OriginSignature
}
if m.OriginDigest == nil && new.OriginDigest != nil {
m.OriginDigest = new.OriginDigest
}
if m.SendFn == nil && new.SendFn != nil {
m.SendFn = new.SendFn
m.IsBroadcastClient = new.IsBroadcastClient
Expand Down
2 changes: 1 addition & 1 deletion broadcast/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type mockRouter struct {
resp protoreflect.ProtoMessage
}

func (r *mockRouter) Send(broadcastID uint64, addr, method string, req msg) error {
func (r *mockRouter) Send(broadcastID uint64, addr, method string, originDigest, originSignature []byte, originPubKey string, req msg) error {
switch val := req.(type) {
case *broadcastMsg:
r.reqType = "Broadcast"
Expand Down
18 changes: 9 additions & 9 deletions broadcast/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ import (

type Client struct {
Addr string
SendMsg func(broadcastID uint64, method string, msg protoreflect.ProtoMessage, timeout time.Duration) error
SendMsg func(broadcastID uint64, method string, msg protoreflect.ProtoMessage, timeout time.Duration, originDigest, originSignature []byte, originPubKey string) error
Close func() error
}

type Router interface {
Send(broadcastID uint64, addr, method string, req msg) error
Send(broadcastID uint64, addr, method string, originDigest, originSignature []byte, originPubKey string, req msg) error
Connect(addr string)
}

Expand Down Expand Up @@ -61,15 +61,15 @@ func (r *BroadcastRouter) registerState(state *BroadcastState) {

type msg interface{}

func (r *BroadcastRouter) Send(broadcastID uint64, addr, method string, req msg) error {
func (r *BroadcastRouter) Send(broadcastID uint64, addr, method string, originDigest, originSignature []byte, originPubKey string, req msg) error {
if r.addr == "" {
panic("The listen addr on the broadcast server cannot be empty. Use the WithListenAddr() option when creating the server.")
}
switch val := req.(type) {
case *broadcastMsg:
return r.routeBroadcast(broadcastID, addr, method, val)
return r.routeBroadcast(broadcastID, addr, method, val, originDigest, originSignature, originPubKey)
case *reply:
return r.routeClientReply(broadcastID, addr, method, val)
return r.routeClientReply(broadcastID, addr, method, val, originDigest, originSignature, originPubKey)
case *cancellation:
r.canceler(broadcastID, val.srvAddrs)
return nil
Expand All @@ -83,27 +83,27 @@ func (r *BroadcastRouter) Connect(addr string) {
_, _ = r.getClient(addr)
}

func (r *BroadcastRouter) routeBroadcast(broadcastID uint64, addr, method string, msg *broadcastMsg) error {
func (r *BroadcastRouter) routeBroadcast(broadcastID uint64, addr, method string, msg *broadcastMsg, originDigest, originSignature []byte, originPubKey string) error {
if handler, ok := r.serverHandlers[msg.method]; ok {
// it runs an interceptor prior to broadcastCall, hence a different signature.
// see: (srv *broadcastServer) registerBroadcastFunc(method string).
handler(msg.ctx, msg.request, broadcastID, addr, method, msg.options, r.id, r.addr)
handler(msg.ctx, msg.request, broadcastID, addr, method, msg.options, r.id, r.addr, originDigest, originSignature, originPubKey)
return nil
}
err := errors.New("handler not found")
r.log("router (broadcast): could not find handler", err, logging.BroadcastID(broadcastID), logging.NodeAddr(addr), logging.Method(method))
return err
}

func (r *BroadcastRouter) routeClientReply(broadcastID uint64, addr, method string, resp *reply) error {
func (r *BroadcastRouter) routeClientReply(broadcastID uint64, addr, method string, resp *reply, originDigest, originSignature []byte, originPubKey string) error {
// the client has initiated a broadcast call and the reply should be sent as an RPC
if _, ok := r.clientHandlers[method]; ok && addr != "" {
client, err := r.getClient(addr)
if err != nil {
r.log("router (reply): could not get client", err, logging.BroadcastID(broadcastID), logging.NodeAddr(addr), logging.Method(method))
return err
}
err = client.SendMsg(broadcastID, method, resp.getResponse(), r.dialTimeout)
err = client.SendMsg(broadcastID, method, resp.getResponse(), r.dialTimeout, originDigest, originSignature, originPubKey)
r.log("router (reply): sending reply to client", err, logging.BroadcastID(broadcastID), logging.NodeAddr(addr), logging.Method(method))
return err
}
Expand Down
2 changes: 1 addition & 1 deletion broadcast/shard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type slowRouter struct {
resp protoreflect.ProtoMessage
}

func (r *slowRouter) Send(broadcastID uint64, addr, method string, req msg) error {
func (r *slowRouter) Send(broadcastID uint64, addr, method string, originDigest, originSignature []byte, originPubKey string, req msg) error {
time.Sleep(1 * time.Second)
switch val := req.(type) {
case *broadcastMsg:
Expand Down
3 changes: 3 additions & 0 deletions broadcast/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ type Content struct {
IsCancellation bool
OriginAddr string
OriginMethod string
OriginPubKey string
OriginSignature []byte
OriginDigest []byte
ViewNumber uint64
SenderAddr string
CurrentMethod string
Expand Down
8 changes: 7 additions & 1 deletion broadcastcall.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ type BroadcastCallData struct {
SenderAddr string
OriginAddr string
OriginMethod string
OriginPubKey string
OriginSignature []byte
OriginDigest []byte
ServerAddresses []string
SkipSelf bool
}
Expand Down Expand Up @@ -46,10 +49,13 @@ func (c RawConfiguration) BroadcastCall(ctx context.Context, d BroadcastCallData
SenderAddr: d.SenderAddr,
OriginAddr: d.OriginAddr,
OriginMethod: d.OriginMethod,
OriginPubKey: d.OriginPubKey,
OriginSignature: d.OriginSignature,
OriginDigest: d.OriginDigest,
}}
msg := &Message{Metadata: md, Message: d.Message}
o := getCallOptions(E_Broadcast, opts)
c.sign(msg)
c.sign(msg, o.signOrigin)

var replyChan chan response
if !o.noSendWaiting {
Expand Down
2 changes: 1 addition & 1 deletion channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func TestChannelUnsuccessfulConnection(t *testing.T) {
}

func TestChannelReconnection(t *testing.T) {
srvAddr := "127.0.0.1:5000"
srvAddr := "127.0.0.1:5005"
// wait to start the server
startServer, stopServer := testServerSetup(t, srvAddr, dummySrv())
node, err := NewRawNode(srvAddr)
Expand Down
76 changes: 72 additions & 4 deletions clientserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"sync"
"time"

"github.com/relab/gorums/authentication"
"github.com/relab/gorums/broadcast"
"github.com/relab/gorums/logging"
"github.com/relab/gorums/ordering"
Expand All @@ -34,6 +35,8 @@ type ClientServer struct {
grpcServer *grpc.Server
handlers map[string]requestHandler
logger *slog.Logger
auth *authentication.EllipticCurve
allowList map[string]string
ordering.UnimplementedGorumsServer
}

Expand Down Expand Up @@ -186,6 +189,10 @@ func (s *ClientServer) NodeStream(srv ordering.Gorums_NodeStreamServer) error {
if err != nil {
return err
}
err = s.verify(req)
if err != nil {
continue
}
if handler, ok := s.handlers[req.Metadata.Method]; ok {
go handler(ServerCtx{Context: ctx, once: new(sync.Once), mut: &mut}, req, nil)
mut.Lock()
Expand Down Expand Up @@ -221,6 +228,12 @@ func NewClientServer(lis net.Listener, opts ...ServerOption) *ClientServer {
}
ordering.RegisterGorumsServer(srv.grpcServer, srv)
srv.lis = lis
if serverOpts.auth != nil {
srv.auth = serverOpts.auth
}
if serverOpts.allowList != nil {
srv.allowList = serverOpts.allowList
}
return srv
}

Expand All @@ -239,6 +252,58 @@ func (srv *ClientServer) Serve(listener net.Listener) error {
return srv.grpcServer.Serve(listener)
}

func (srv *ClientServer) encodeMsg(req *Message) ([]byte, error) {
// we must not consider the signature field when validating.
// also the msgType must be set to requestType.
signature := make([]byte, len(req.Metadata.AuthMsg.Signature))
copy(signature, req.Metadata.AuthMsg.Signature)
reqType := req.msgType
req.Metadata.AuthMsg.Signature = nil
req.msgType = 0
encodedMsg, err := srv.auth.EncodeMsg(*req)
req.Metadata.AuthMsg.Signature = make([]byte, len(signature))
copy(req.Metadata.AuthMsg.Signature, signature)
req.msgType = reqType
return encodedMsg, err
}

func (srv *ClientServer) verify(req *Message) error {
if srv.auth == nil {
return nil
}
if req.Metadata.AuthMsg == nil {
return fmt.Errorf("missing authMsg")
}
if req.Metadata.AuthMsg.Signature == nil {
return fmt.Errorf("missing signature")
}
if req.Metadata.AuthMsg.PublicKey == "" {
return fmt.Errorf("missing publicKey")
}
authMsg := req.Metadata.AuthMsg
if srv.allowList != nil {
pemEncodedPub, ok := srv.allowList[authMsg.Sender]
if !ok {
return fmt.Errorf("not allowed")
}
if pemEncodedPub != authMsg.PublicKey {
return fmt.Errorf("publicKey did not match")
}
}
encodedMsg, err := srv.encodeMsg(req)
if err != nil {
return err
}
valid, err := srv.auth.VerifySignature(authMsg.PublicKey, encodedMsg, authMsg.Signature)
if err != nil {
return err
}
if !valid {
return fmt.Errorf("invalid signature")
}
return nil
}

func createClient(addr string, dialOpts []grpc.DialOption) (*broadcast.Client, error) {
// necessary to ensure correct marshalling and unmarshalling of gorums messages
// TODO: find a better solution
Expand All @@ -258,13 +323,16 @@ func createClient(addr string, dialOpts []grpc.DialOption) (*broadcast.Client, e
}
return &broadcast.Client{
Addr: node.Address(),
SendMsg: func(broadcastID uint64, method string, msg protoreflect.ProtoMessage, timeout time.Duration) error {
SendMsg: func(broadcastID uint64, method string, msg protoreflect.ProtoMessage, timeout time.Duration, originDigest, originSignature []byte, originPubKey string) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
cd := CallData{
Method: method,
Message: msg,
BroadcastID: broadcastID,
Method: method,
Message: msg,
BroadcastID: broadcastID,
OriginDigest: originDigest,
OriginSignature: originSignature,
OriginPubKey: originPubKey,
}
_, err := node.RPCCall(ctx, cd)
return err
Expand Down
Loading

0 comments on commit 656a7fc

Please sign in to comment.