diff --git a/.ko.yaml b/.ko.yaml index 6ca5987e..4bd31c2d 100644 --- a/.ko.yaml +++ b/.ko.yaml @@ -9,4 +9,4 @@ # github.com/carbynestack/ephemeral/cmd/ephemeral: ghcr.io/carbynestack/ephemeral-spdz-base-image:cleared-20210827 defaultBaseImage: ghcr.io/carbynestack/ubuntu:20.04-20210827-nonroot baseImageOverrides: - github.com/carbynestack/ephemeral/cmd/ephemeral: ghcr.io/carbynestack/spdz:20210827 + github.com/carbynestack/ephemeral/cmd/ephemeral: ghcr.io/carbynestack/spdz:642d11f diff --git a/cmd/discovery/main_test.go b/cmd/discovery/main_test.go index 598ffdbf..2148bba8 100644 --- a/cmd/discovery/main_test.go +++ b/cmd/discovery/main_test.go @@ -7,6 +7,7 @@ package main import ( + "context" "errors" "fmt" "io/ioutil" @@ -57,7 +58,7 @@ var _ = Describe("Main", func() { }) Context("all required parameters are specified", func() { AfterEach(func() { - _, _, err := cmder.CallCMD([]string{fmt.Sprintf("rm %s", path)}, "./") + _, _, err := cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", path)}, "./") Expect(err).NotTo(HaveOccurred()) }) Context("parameters are plausible", func() { @@ -100,7 +101,7 @@ var _ = Describe("Main", func() { Context("one of the required parameters is missing", func() { Context("when no frontendURL is defined", func() { AfterEach(func() { - _, _, err := cmder.CallCMD([]string{fmt.Sprintf("rm %s", path)}, "./") + _, _, err := cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", path)}, "./") Expect(err).NotTo(HaveOccurred()) }) It("returns an error", func() { diff --git a/cmd/ephemeral/main_test.go b/cmd/ephemeral/main_test.go index 6030723b..402473b1 100644 --- a/cmd/ephemeral/main_test.go +++ b/cmd/ephemeral/main_test.go @@ -7,6 +7,7 @@ package main_test import ( + "context" "fmt" "io/ioutil" "math/rand" @@ -43,7 +44,7 @@ var _ = Describe("Main", func() { path = fmt.Sprintf("/tmp/test-%d", random) }) AfterEach(func() { - _, _, err := cmder.CallCMD([]string{fmt.Sprintf("rm %s", path)}, "./") + _, _, err := cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", path)}, "./") Expect(err).NotTo(HaveOccurred()) }) Context("when it succeeds", func() { diff --git a/pkg/ephemeral/fake_spdz_test.go b/pkg/ephemeral/fake_spdz_test.go index 83144640..678075c9 100644 --- a/pkg/ephemeral/fake_spdz_test.go +++ b/pkg/ephemeral/fake_spdz_test.go @@ -7,6 +7,7 @@ package ephemeral import ( + "context" "errors" "github.com/carbynestack/ephemeral/pkg/discovery/fsm" pb "github.com/carbynestack/ephemeral/pkg/discovery/transport/proto" @@ -93,14 +94,14 @@ func (f *FakePlayer) PublishEvent(name, topic string, event *pb.Event) { type FakeExecutor struct { } -func (f *FakeExecutor) CallCMD(cmd []string, dir string) ([]byte, []byte, error) { +func (f *FakeExecutor) CallCMD(ctx context.Context, cmd []string, dir string) ([]byte, []byte, error) { return []byte{}, []byte{}, nil } type BrokenFakeExecutor struct { } -func (f *BrokenFakeExecutor) CallCMD(cmd []string, dir string) ([]byte, []byte, error) { +func (f *BrokenFakeExecutor) CallCMD(ctx context.Context, cmd []string, dir string) ([]byte, []byte, error) { return []byte{}, []byte{}, errors.New("some error") } diff --git a/pkg/ephemeral/io/carrier.go b/pkg/ephemeral/io/carrier.go index 7c8ef18d..e7a880d5 100644 --- a/pkg/ephemeral/io/carrier.go +++ b/pkg/ephemeral/io/carrier.go @@ -8,8 +8,11 @@ package io import ( "context" + "encoding/binary" "errors" + "fmt" "github.com/carbynestack/ephemeral/pkg/amphora" + "io" "io/ioutil" "net" ) @@ -21,7 +24,7 @@ type Result struct { // AbstractCarrier is the carriers interface. type AbstractCarrier interface { - Connect(context.Context, string, string) error + Connect(context.Context, int32, string, string) error Close() error Send([]amphora.SecretShare) error Read(ResponseConverter, bool) (*Result, error) @@ -42,16 +45,54 @@ type Config struct { } // Connect establishes a TCP connection to a socket on a given host and port. -func (c *Carrier) Connect(ctx context.Context, host, port string) error { +func (c *Carrier) Connect(ctx context.Context, playerID int32, host string, port string) error { conn, err := c.Dialer(ctx, host, port) + c.Conn = conn if err != nil { return err } - c.Conn = conn + _, err = conn.Write(c.buildHeader(playerID)) + if err != nil { + return err + } + if playerID == 0 { + err = c.readPrime() + if err != nil { + return err + } + } c.connected = true return nil } +// readPrime reads the file header from the MP-SPDZ connection +// In MP-SPDZ connection, this will only be used when player0 connects as client to MP-SPDZ +// +// For the header composition, check: +// https://github.com/data61/MP-SPDZ/issues/418#issuecomment-975424591 +// +// It is made up as follows: +// - Careful: The other header parts are not part of this communication, they are only used when reading tuple files +// - length of the prime as 4-byte number little-endian (e.g. 16), +// - prime in big-endian (e.g. 170141183460469231731687303715885907969) +func (c Carrier) readPrime() error { + const size = 4 + readBytes := make([]byte, size) + _, err := io.LimitReader(c.Conn, size).Read(readBytes) + if err != nil { + return err + } + + sizeOfHeader := binary.LittleEndian.Uint32(readBytes) + readBytes = make([]byte, sizeOfHeader) + _, err = io.LimitReader(c.Conn, int64(sizeOfHeader)).Read(readBytes) + if err != nil { + return err + } + //ToDo, compare read PRIME with prime number from config? + return nil +} + // Close closes the underlying TCP connection. func (c *Carrier) Close() error { if c.connected { @@ -78,6 +119,17 @@ func (c *Carrier) Send(secret []amphora.SecretShare) error { return nil } +// Returns a new Slice with the header appended +// The header consists of the clientId as string: +// - 1 Long (4 Byte) that contains the length of the string in bytes +// - Then come X Bytes for the String +func (c *Carrier) buildHeader(playerID int32) []byte { + playerIDString := []byte(fmt.Sprintf("%d", playerID)) + lengthOfString := make([]byte, 4) + binary.LittleEndian.PutUint32(lengthOfString, uint32(len(playerIDString))) + return append(lengthOfString, playerIDString...) +} + // Read reads the response from the TCP connection and unmarshals it. func (c *Carrier) Read(conv ResponseConverter, bulkObjects bool) (*Result, error) { resp := []byte{} diff --git a/pkg/ephemeral/io/carrier_test.go b/pkg/ephemeral/io/carrier_test.go index 5ed31b52..00cca319 100644 --- a/pkg/ephemeral/io/carrier_test.go +++ b/pkg/ephemeral/io/carrier_test.go @@ -9,17 +9,18 @@ package io_test import ( "context" "fmt" - "net" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "github.com/carbynestack/ephemeral/pkg/amphora" . "github.com/carbynestack/ephemeral/pkg/ephemeral/io" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "net" + "sync" ) var _ = Describe("Carrier", func() { var ctx = context.TODO() + var playerID = int32(1) // PlayerID 1, since PlayerID==0 contains another check when connecting + It("connects to a socket", func() { var connected bool conn := FakeNetConnection{} @@ -30,7 +31,7 @@ var _ = Describe("Carrier", func() { carrier := Carrier{ Dialer: fakeDialer, } - err := carrier.Connect(context.TODO(), "", "") + err := carrier.Connect(context.TODO(), playerID, "", "") Expect(connected).To(BeTrue()) Expect(err).NotTo(HaveOccurred()) }) @@ -42,7 +43,7 @@ var _ = Describe("Carrier", func() { carrier := Carrier{ Dialer: fakeDialer, } - err := carrier.Connect(context.TODO(), "", "") + err := carrier.Connect(context.TODO(), playerID, "", "") Expect(err).NotTo(HaveOccurred()) err = carrier.Close() Expect(err).NotTo(HaveOccurred()) @@ -50,16 +51,18 @@ var _ = Describe("Carrier", func() { }) var ( - secret []amphora.SecretShare - output []byte - client, server net.Conn - dialer func(ctx context.Context, addr, port string) (net.Conn, error) + secret []amphora.SecretShare + output []byte + connectionOutput []byte //Will contain (length 4 byte, playerID 1 byte) + client, server net.Conn + dialer func(ctx context.Context, addr, port string) (net.Conn, error) ) BeforeEach(func() { secret = []amphora.SecretShare{ amphora.SecretShare{}, } output = make([]byte, 1) + connectionOutput = make([]byte, 5) client, server = net.Pipe() dialer = func(ctx context.Context, addr, port string) (net.Conn, error) { return client, nil @@ -75,12 +78,14 @@ var _ = Describe("Carrier", func() { Dialer: dialer, Packer: packer, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(ctx, playerID, "", "") go server.Read(output) err := carrier.Send(secret) carrier.Close() Expect(err).NotTo(HaveOccurred()) Expect(output[0]).To(Equal(byte(1))) + Expect(connectionOutput).To(Equal([]byte{1, 0, 0, 0, fmt.Sprintf("%d", playerID)[0]})) }) It("returns an error when it fails to marshal the object", func() { packer := &FakeBrokenPacker{} @@ -88,7 +93,8 @@ var _ = Describe("Carrier", func() { Dialer: dialer, Packer: packer, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(ctx, playerID, "", "") go server.Read(output) err := carrier.Send(secret) carrier.Close() @@ -103,7 +109,8 @@ var _ = Describe("Carrier", func() { Dialer: dialer, Packer: packer, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(ctx, playerID, "", "") // Closing the connection to trigger a failure due to writing into the closed socket. server.Close() err := carrier.Send(secret) @@ -123,7 +130,8 @@ var _ = Describe("Carrier", func() { Dialer: dialer, Packer: &packer, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(ctx, playerID, "", "") go func() { server.Write(serverResponse) server.Close() @@ -143,7 +151,8 @@ var _ = Describe("Carrier", func() { Dialer: dialer, Packer: &packer, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(ctx, playerID, "", "") server.Close() anyConverter := &PlaintextConverter{} _, err := carrier.Read(anyConverter, false) @@ -156,7 +165,8 @@ var _ = Describe("Carrier", func() { Dialer: dialer, Packer: packer, } - carrier.Connect(ctx, "", "") + go server.Read(connectionOutput) + carrier.Connect(ctx, playerID, "", "") go func() { server.Write(serverResponse) server.Close() @@ -166,4 +176,41 @@ var _ = Describe("Carrier", func() { Expect(err).To(HaveOccurred()) }) }) + + Context("when connecting as Player0", func() { + playerID := int32(0) + It("will receive and handle the server's fileHeader", func() { + // Arrange + // ToDo: Better Response for real-life scenario? + serverResponse := []byte{1, 0, 0, 0, 1} // 4 byte length + header, in this case "1". In real case Descriptor + Prime + packer := &FakeBrokenPacker{} + carrier := Carrier{ + Dialer: dialer, + Packer: packer, + } + waitGroup := sync.WaitGroup{} + waitGroup.Add(1) + go server.Read(connectionOutput) + + // Act + var errConnecting error + go func() { + errConnecting = carrier.Connect(ctx, playerID, "", "") + waitGroup.Done() + }() + + numberOfBytesWritten, errWrite := server.Write(serverResponse) + errClose := server.Close() + + // Make sure we wait until the Connect and Write are done + waitGroup.Wait() + + // Assert + Expect(connectionOutput).To(Equal([]byte{1, 0, 0, 0, fmt.Sprintf("%d", playerID)[0]})) + Expect(errConnecting).NotTo(HaveOccurred()) + Expect(errWrite).NotTo(HaveOccurred()) + Expect(numberOfBytesWritten).To(Equal(len(serverResponse))) + Expect(errClose).NotTo(HaveOccurred()) + }) + }) }) diff --git a/pkg/ephemeral/io/feeder.go b/pkg/ephemeral/io/feeder.go index d9082416..45c90f18 100644 --- a/pkg/ephemeral/io/feeder.go +++ b/pkg/ephemeral/io/feeder.go @@ -118,7 +118,7 @@ func (f *AmphoraFeeder) feedAndRead(params []string, port string, ctx *CtxConfig default: return nil, fmt.Errorf("no output config is given, either %s, %s or %s must be defined", PlainText, SecretShare, AmphoraSecret) } - err := f.carrier.Connect(ctx.Context, "localhost", port) + err := f.carrier.Connect(ctx.Context, ctx.Spdz.PlayerID, "localhost", port) defer f.carrier.Close() if err != nil { return nil, err diff --git a/pkg/ephemeral/io/feeder_test.go b/pkg/ephemeral/io/feeder_test.go index e40e7865..77d703c3 100644 --- a/pkg/ephemeral/io/feeder_test.go +++ b/pkg/ephemeral/io/feeder_test.go @@ -44,6 +44,7 @@ var _ = Describe("Feeder", func() { conf = &CtxConfig{ Act: act, Context: context.TODO(), + Spdz: &SPDZEngineTypedConfig{PlayerCount: 2}, } }) @@ -211,7 +212,7 @@ type FakeCarrier struct { isBulk bool } -func (f *FakeCarrier) Connect(context.Context, string, string) error { +func (f *FakeCarrier) Connect(context.Context, int32, string, string) error { return nil } @@ -232,7 +233,7 @@ type BrokenConnectFakeCarrier struct { isBulk bool } -func (f *BrokenConnectFakeCarrier) Connect(context.Context, string, string) error { +func (f *BrokenConnectFakeCarrier) Connect(context.Context, int32, string, string) error { return errors.New("carrier connect error") } @@ -253,7 +254,7 @@ type BrokenSendFakeCarrier struct { isBulk bool } -func (f *BrokenSendFakeCarrier) Connect(context.Context, string, string) error { +func (f *BrokenSendFakeCarrier) Connect(context.Context, int32, string, string) error { return nil } diff --git a/pkg/ephemeral/player.go b/pkg/ephemeral/player.go index 07cde86a..d638b9e0 100644 --- a/pkg/ephemeral/player.go +++ b/pkg/ephemeral/player.go @@ -223,5 +223,5 @@ func (c *Callbacker) sendEvent(name, topic string, e interface{}) { }, } c.pb.PublishWithBody(name, topic, event, c.playerParams.GameID) - c.logger.Debugf("Sending event %v to topic %s\n", event.Name, topic) + c.logger.Debugw("Sending event", "event", event, "topic", topic) } diff --git a/pkg/ephemeral/server.go b/pkg/ephemeral/server.go index 3ca8896e..5ab7ee9c 100644 --- a/pkg/ephemeral/server.go +++ b/pkg/ephemeral/server.go @@ -340,7 +340,7 @@ func (s *Server) getPodName() (string, error) { // TODO: this is brittle, read the pod name from more reliable place. // use something like os.Getenv("HOST_NAME")? cmder := s.executor - name, _, err := cmder.CallCMD([]string{"hostname"}, "/") + name, _, err := cmder.CallCMD(context.TODO(), []string{"hostname"}, "/") if err != nil { return "", err } diff --git a/pkg/ephemeral/spdz.go b/pkg/ephemeral/spdz.go index 2ae7f2d9..d511b60f 100644 --- a/pkg/ephemeral/spdz.go +++ b/pkg/ephemeral/spdz.go @@ -7,17 +7,17 @@ package ephemeral import ( + "context" + "errors" + "fmt" d "github.com/carbynestack/ephemeral/pkg/discovery" pb "github.com/carbynestack/ephemeral/pkg/discovery/transport/proto" . "github.com/carbynestack/ephemeral/pkg/ephemeral/io" "github.com/carbynestack/ephemeral/pkg/ephemeral/network" . "github.com/carbynestack/ephemeral/pkg/types" . "github.com/carbynestack/ephemeral/pkg/utils" - "sort" - - "errors" - "fmt" "io/ioutil" + "sort" "strconv" "time" @@ -209,8 +209,8 @@ func (s *SPDZEngine) Compile(ctx *CtxConfig) error { } var stdoutSlice []byte var stderrSlice []byte - command := fmt.Sprintf("./compile.py %s", appName) - stdoutSlice, stderrSlice, err = s.cmder.CallCMD([]string{command}, s.baseDir) + command := fmt.Sprintf("./compile.py -M %s", appName) + stdoutSlice, stderrSlice, err = s.cmder.CallCMD(context.TODO(), []string{command}, s.baseDir) stdOut := string(stdoutSlice) stdErr := string(stderrSlice) s.logger.Debugw("Compiled Successfully", "Command", command, "StdOut", stdOut, "StdErr", stdErr) @@ -228,13 +228,14 @@ func (s *SPDZEngine) getFeedPort() string { func (s *SPDZEngine) startMPC(ctx *CtxConfig) { command := []string{fmt.Sprintf("./Player-Online.x %s %s -N %s --ip-file-name %s", fmt.Sprint(s.config.PlayerID), appName, fmt.Sprint(ctx.Spdz.PlayerCount), ipFile)} s.logger.Infow("Starting Player-Online.x", GameID, ctx.Act.GameID, "command", command) - stdout, stderr, err := s.cmder.CallCMD(command, s.baseDir) + stdout, stderr, err := s.cmder.CallCMD(ctx.Context, command, s.baseDir) if err != nil { + s.logger.Errorw("Error while executing the user code", GameID, ctx.Act.GameID, "StdErr", string(stderr), "StdOut", string(stdout), "error", err) err := fmt.Errorf("error while executing the user code: %v", err) ctx.ErrCh <- err - s.logger.Errorw(err.Error(), GameID, ctx.Act.GameID) + } else { + s.logger.Debugw("Computation finished", GameID, ctx.Act.GameID, "StdErr", string(stderr), "StdOut", string(stdout)) } - s.logger.Debugw("Computation finished", GameID, ctx.Act.GameID, "StdErr", string(stderr), "StdOut", string(stdout), "error", err) } func (s *SPDZEngine) writeIPFile(path string, addr string, parties int32) error { diff --git a/pkg/ephemeral/spdz_test.go b/pkg/ephemeral/spdz_test.go index 373104d9..0277d1b5 100644 --- a/pkg/ephemeral/spdz_test.go +++ b/pkg/ephemeral/spdz_test.go @@ -47,7 +47,7 @@ var _ = Describe("Spdz", func() { fileName = fmt.Sprintf("/tmp/program-%d.mpc", random) }) AfterEach(func() { - cmder.CallCMD([]string{fmt.Sprintf("rm %s", fileName)}, "./") + cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", fileName)}, "./") }) Context("writing succeeds", func() { It("writes the source code on the disk and runs the compiler", func() { @@ -63,7 +63,7 @@ var _ = Describe("Spdz", func() { } err := s.Compile(conf) Expect(err).NotTo(HaveOccurred()) - out, _, err := cmder.CallCMD([]string{fmt.Sprintf("cat %s", s.sourceCodePath)}, "./") + out, _, err := cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("cat %s", s.sourceCodePath)}, "./") Expect(err).NotTo(HaveOccurred()) Expect(string(out)).To(Equal("a")) }) diff --git a/pkg/utils/os.go b/pkg/utils/os.go index f44f4d58..e44e87cc 100644 --- a/pkg/utils/os.go +++ b/pkg/utils/os.go @@ -8,6 +8,7 @@ package utils import ( "bytes" + "context" "errors" "io/ioutil" "os" @@ -18,7 +19,7 @@ import ( // Executor is an interface for calling a command and process its output. type Executor interface { // CallCMD executes the command and returns the output's STDOUT, STDERR streams as well as any errors - CallCMD(cmd []string, dir string) ([]byte, []byte, error) + CallCMD(ctx context.Context, cmd []string, dir string) ([]byte, []byte, error) } var ( @@ -45,7 +46,7 @@ type Commander struct { // Run is a facade command that runs a single command from the current directory. func (c *Commander) Run(cmd string) ([]byte, []byte, error) { - return c.CallCMD([]string{cmd}, "./") + return c.CallCMD(context.TODO(), []string{cmd}, "./") } // CallCMD calls a specified command in sh and returns its stdout and stderr as a byte slice and potentially an error. @@ -53,22 +54,19 @@ func (c *Commander) Run(cmd string) ([]byte, []byte, error) { // ``` // If the command fails to run or doesn't complete successfully, the error is of type *ExitError. Other error types may be returned for I/O problems. // ``` -func (c *Commander) CallCMD(cmd []string, dir string) ([]byte, []byte, error) { +func (c *Commander) CallCMD(ctx context.Context, cmd []string, dir string) ([]byte, []byte, error) { baseCmd := c.Options baseCmd = append(baseCmd, cmd...) - command := exec.Command(c.Command, baseCmd...) - + command := exec.CommandContext(ctx, c.Command, baseCmd...) stderrBuffer := bytes.NewBuffer([]byte{}) stdoutBuffer := bytes.NewBuffer([]byte{}) command.Stderr = stderrBuffer command.Stdout = stdoutBuffer - command.Dir = dir err := command.Start() if err != nil { return nil, nil, err } - // Check if the command finished successfully. err = command.Wait() if err != nil { switch err.(type) { diff --git a/pkg/utils/os_test.go b/pkg/utils/os_test.go index c292d0df..e137dbe7 100644 --- a/pkg/utils/os_test.go +++ b/pkg/utils/os_test.go @@ -7,6 +7,7 @@ package utils_test import ( + "context" "fmt" "io/ioutil" "math/rand" @@ -82,7 +83,7 @@ var _ = Describe("OS utils", func() { } }) AfterEach(func() { - cmder.CallCMD([]string{fmt.Sprintf("rm %s", fileName)}, "./") + cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", fileName)}, "./") }) It("reads file content", func() { data := []byte(`a`)