From 1a96391971572b5d7b210559b99f6e2330c11af9 Mon Sep 17 00:00:00 2001 From: Roger Peppe Date: Wed, 16 Feb 2011 08:14:41 -0800 Subject: [PATCH] netchan: allow use of arbitrary connections. R=r, r2, rsc CC=golang-dev https://golang.org/cl/4119055 --- src/pkg/netchan/common.go | 15 ++-- src/pkg/netchan/export.go | 50 ++++++----- src/pkg/netchan/import.go | 39 ++++----- src/pkg/netchan/netchan_test.go | 146 ++++++-------------------------- 4 files changed, 87 insertions(+), 163 deletions(-) diff --git a/src/pkg/netchan/common.go b/src/pkg/netchan/common.go index 6c085484e5e..dd06050ee50 100644 --- a/src/pkg/netchan/common.go +++ b/src/pkg/netchan/common.go @@ -6,7 +6,7 @@ package netchan import ( "gob" - "net" + "io" "os" "reflect" "sync" @@ -93,7 +93,7 @@ type encDec struct { enc *gob.Encoder } -func newEncDec(conn net.Conn) *encDec { +func newEncDec(conn io.ReadWriter) *encDec { return &encDec{ dec: gob.NewDecoder(conn), enc: gob.NewEncoder(conn), @@ -199,9 +199,10 @@ func (cs *clientSet) sync(timeout int64) os.Error { // are delivered into the local channel. type netChan struct { *chanDir - name string - id int - size int // buffer size of channel. + name string + id int + size int // buffer size of channel. + closed bool // sender-specific state ackCh chan bool // buffered with space for all the acks we need @@ -227,6 +228,9 @@ func newNetChan(name string, id int, ch *chanDir, ed *encDec, size int, count in // Close the channel. func (nch *netChan) close() { + if nch.closed { + return + } if nch.dir == Recv { if nch.sendCh != nil { // If the sender goroutine is active, close the channel to it. @@ -239,6 +243,7 @@ func (nch *netChan) close() { nch.ch.Close() close(nch.ackCh) } + nch.closed = true } // Send message from remote side to local receiver. diff --git a/src/pkg/netchan/export.go b/src/pkg/netchan/export.go index 675e252d5c3..55eba0e2e0f 100644 --- a/src/pkg/netchan/export.go +++ b/src/pkg/netchan/export.go @@ -23,6 +23,7 @@ package netchan import ( "log" + "io" "net" "os" "reflect" @@ -43,7 +44,6 @@ func expLog(args ...interface{}) { // but they must use different ports. type Exporter struct { *clientSet - listener net.Listener } type expClient struct { @@ -57,7 +57,7 @@ type expClient struct { seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu } -func newClient(exp *Exporter, conn net.Conn) *expClient { +func newClient(exp *Exporter, conn io.ReadWriter) *expClient { client := new(expClient) client.exp = exp client.encDec = newEncDec(conn) @@ -260,39 +260,50 @@ func (client *expClient) ack() int64 { return n } -// Wait for incoming connections, start a new runner for each -func (exp *Exporter) listen() { +// Serve waits for incoming connections on the listener +// and serves the Exporter's channels on each. +// It blocks until the listener is closed. +func (exp *Exporter) Serve(listener net.Listener) { for { - conn, err := exp.listener.Accept() + conn, err := listener.Accept() if err != nil { expLog("listen:", err) break } - client := exp.addClient(conn) - go client.run() + go exp.ServeConn(conn) } } -// NewExporter creates a new Exporter to export channels -// on the network and local address defined as in net.Listen. -func NewExporter(network, localaddr string) (*Exporter, os.Error) { - listener, err := net.Listen(network, localaddr) - if err != nil { - return nil, err - } +// ServeConn exports the Exporter's channels on conn. +// It blocks until the connection is terminated. +func (exp *Exporter) ServeConn(conn io.ReadWriter) { + exp.addClient(conn).run() +} + +// NewExporter creates a new Exporter that exports a set of channels. +func NewExporter() *Exporter { e := &Exporter{ - listener: listener, clientSet: &clientSet{ names: make(map[string]*chanDir), clients: make(map[unackedCounter]bool), }, } - go e.listen() - return e, nil + return e +} + +// ListenAndServe exports the exporter's channels through the +// given network and local address defined as in net.Listen. +func (exp *Exporter) ListenAndServe(network, localaddr string) os.Error { + listener, err := net.Listen(network, localaddr) + if err != nil { + return err + } + go exp.Serve(listener) + return nil } // addClient creates a new expClient and records its existence -func (exp *Exporter) addClient(conn net.Conn) *expClient { +func (exp *Exporter) addClient(conn io.ReadWriter) *expClient { client := newClient(exp, conn) exp.mu.Lock() exp.clients[client] = true @@ -329,9 +340,6 @@ func (exp *Exporter) Sync(timeout int64) os.Error { return exp.clientSet.sync(timeout) } -// Addr returns the Exporter's local network address. -func (exp *Exporter) Addr() net.Addr { return exp.listener.Addr() } - func checkChan(chT interface{}, dir Dir) (*reflect.ChanValue, os.Error) { chanType, ok := reflect.Typeof(chT).(*reflect.ChanType) if !ok { diff --git a/src/pkg/netchan/import.go b/src/pkg/netchan/import.go index d220d9a662a..30edcd8123b 100644 --- a/src/pkg/netchan/import.go +++ b/src/pkg/netchan/import.go @@ -5,6 +5,7 @@ package netchan import ( + "io" "log" "net" "os" @@ -25,7 +26,6 @@ func impLog(args ...interface{}) { // importers, even from the same machine/network port. type Importer struct { *encDec - conn net.Conn chanLock sync.Mutex // protects access to channel map names map[string]*netChan chans map[int]*netChan @@ -33,23 +33,26 @@ type Importer struct { maxId int } -// NewImporter creates a new Importer object to import channels -// from an Exporter at the network and remote address as defined in net.Dial. -// The Exporter must be available and serving when the Importer is -// created. -func NewImporter(network, remoteaddr string) (*Importer, os.Error) { - conn, err := net.Dial(network, "", remoteaddr) - if err != nil { - return nil, err - } +// NewImporter creates a new Importer object to import a set of channels +// from the given connection. The Exporter must be available and serving when +// the Importer is created. +func NewImporter(conn io.ReadWriter) *Importer { imp := new(Importer) imp.encDec = newEncDec(conn) - imp.conn = conn imp.chans = make(map[int]*netChan) imp.names = make(map[string]*netChan) imp.errors = make(chan os.Error, 10) go imp.run() - return imp, nil + return imp +} + +// Import imports a set of channels from the given network and address. +func Import(network, remoteaddr string) (*Importer, os.Error) { + conn, err := net.Dial(network, "", remoteaddr) + if err != nil { + return nil, err + } + return NewImporter(conn), nil } // shutdown closes all channels for which we are receiving data from the remote side. @@ -231,15 +234,13 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size, // the channel. Messages in flight for the channel may be dropped. func (imp *Importer) Hangup(name string) os.Error { imp.chanLock.Lock() - nc, ok := imp.names[name] - if ok { - imp.names[name] = nil, false - imp.chans[nc.id] = nil, false - } - imp.chanLock.Unlock() - if !ok { + defer imp.chanLock.Unlock() + nc := imp.names[name] + if nc == nil { return os.ErrorString("netchan import: hangup: no such channel: " + name) } + imp.names[name] = nil, false + imp.chans[nc.id] = nil, false nc.close() return nil } diff --git a/src/pkg/netchan/netchan_test.go b/src/pkg/netchan/netchan_test.go index 4076aefebfc..1c84a9d14de 100644 --- a/src/pkg/netchan/netchan_test.go +++ b/src/pkg/netchan/netchan_test.go @@ -5,6 +5,7 @@ package netchan import ( + "net" "strings" "testing" "time" @@ -94,27 +95,13 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) { } func TestExportSendImportReceive(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) exportSend(exp, count, t, nil) importReceive(imp, t, nil) } func TestExportReceiveImportSend(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) expDone := make(chan bool) done := make(chan bool) go func() { @@ -127,27 +114,13 @@ func TestExportReceiveImportSend(t *testing.T) { } func TestClosingExportSendImportReceive(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) exportSend(exp, closeCount, t, nil) importReceive(imp, t, nil) } func TestClosingImportSendExportReceive(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) expDone := make(chan bool) done := make(chan bool) go func() { @@ -160,17 +133,10 @@ func TestClosingImportSendExportReceive(t *testing.T) { } func TestErrorForIllegalChannel(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) // Now export a channel. ch := make(chan int, 1) - err = exp.Export("aChannel", ch, Send) + err := exp.Export("aChannel", ch, Send) if err != nil { t.Fatal("export:", err) } @@ -200,14 +166,7 @@ func TestErrorForIllegalChannel(t *testing.T) { // Not a great test but it does at least invoke Drain. func TestExportDrain(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) done := make(chan bool) go func() { exportSend(exp, closeCount, t, nil) @@ -221,14 +180,7 @@ func TestExportDrain(t *testing.T) { // Not a great test but it does at least invoke Sync. func TestExportSync(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) done := make(chan bool) exportSend(exp, closeCount, t, nil) go importReceive(imp, t, done) @@ -239,16 +191,9 @@ func TestExportSync(t *testing.T) { // Test hanging up the send side of an export. // TODO: test hanging up the receive side of an export. func TestExportHangup(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) ech := make(chan int) - err = exp.Export("exportedSend", ech, Send) + err := exp.Export("exportedSend", ech, Send) if err != nil { t.Fatal("export:", err) } @@ -276,16 +221,9 @@ func TestExportHangup(t *testing.T) { // Test hanging up the send side of an import. // TODO: test hanging up the receive side of an import. func TestImportHangup(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) ech := make(chan int) - err = exp.Export("exportedRecv", ech, Recv) + err := exp.Export("exportedRecv", ech, Recv) if err != nil { t.Fatal("export:", err) } @@ -343,14 +281,7 @@ func exportLoopback(exp *Exporter, t *testing.T) { // This test checks that channel operations can proceed // even when other concurrent operations are blocked. func TestIndependentSends(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) exportLoopback(exp, t) @@ -377,23 +308,8 @@ type value struct { } func TestCrossConnect(t *testing.T) { - e1, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - i1, err := NewImporter("tcp", e1.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } - - e2, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - i2, err := NewImporter("tcp", e2.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + e1, i1 := pair(t) + e2, i2 := pair(t) crossExport(e1, e2, t) crossImport(i1, i2, t) @@ -452,20 +368,13 @@ const flowCount = 100 // test flow control from exporter to importer. func TestExportFlowControl(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) sendDone := make(chan bool, 1) exportSend(exp, flowCount, t, sendDone) ch := make(chan int) - err = imp.ImportNValues("exportedSend", ch, Recv, 20, -1) + err := imp.ImportNValues("exportedSend", ch, Recv, 20, -1) if err != nil { t.Fatal("importReceive:", err) } @@ -475,17 +384,10 @@ func TestExportFlowControl(t *testing.T) { // test flow control from importer to exporter. func TestImportFlowControl(t *testing.T) { - exp, err := NewExporter("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal("new exporter:", err) - } - imp, err := NewImporter("tcp", exp.Addr().String()) - if err != nil { - t.Fatal("new importer:", err) - } + exp, imp := pair(t) ch := make(chan int) - err = exp.Export("exportedRecv", ch, Recv) + err := exp.Export("exportedRecv", ch, Recv) if err != nil { t.Fatal("importReceive:", err) } @@ -513,3 +415,11 @@ func testFlow(sendDone chan bool, ch <-chan int, N int, t *testing.T) { t.Fatalf("expected %d values; got %d", N, n) } } + +func pair(t *testing.T) (*Exporter, *Importer) { + c0, c1 := net.Pipe() + exp := NewExporter() + go exp.ServeConn(c0) + imp := NewImporter(c1) + return exp, imp +}