diff --git a/src/pkg/netchan/common.go b/src/pkg/netchan/common.go index 3f99868490b..87981ca8603 100644 --- a/src/pkg/netchan/common.go +++ b/src/pkg/netchan/common.go @@ -37,6 +37,7 @@ const ( payError // error structure follows payData // user payload follows payAck // acknowledgement; no payload + payClosed // channel is now closed ) // A header is sent as a prefix to every transmission. It will be followed by diff --git a/src/pkg/netchan/export.go b/src/pkg/netchan/export.go index a58797e630c..d7dceead990 100644 --- a/src/pkg/netchan/export.go +++ b/src/pkg/netchan/export.go @@ -31,6 +31,12 @@ import ( // Export +// expLog is a logging convenience function. The first argument must be a string. +func expLog(args ...interface{}) { + args[0] = "netchan export: " + args[0].(string) + log.Stderr(args) +} + // An Exporter allows a set of channels to be published on a single // network port. A single machine may have multiple Exporters // but they must use different ports. @@ -60,7 +66,7 @@ func newClient(exp *Exporter, conn net.Conn) *expClient { func (client *expClient) sendError(hdr *header, err string) { error := &error{err} - log.Stderr("export:", error.error) + expLog("sending error to client", error.error) client.encode(hdr, payError, error) // ignore any encode error, hope client gets it client.mu.Lock() client.errored = true @@ -96,13 +102,13 @@ func (client *expClient) run() { for { *hdr = header{} if err := client.decode(hdrValue); err != nil { - log.Stderr("error decoding client header:", err) + expLog("error decoding client header:", err) break } switch hdr.payloadType { case payRequest: if err := client.decode(reqValue); err != nil { - log.Stderr("error decoding client request:", err) + expLog("error decoding client request:", err) break } switch req.dir { @@ -114,12 +120,14 @@ func (client *expClient) run() { // The actual sends will have payload type payData. // TODO: manage the count? default: - error.error = "export request: can't handle channel direction" - log.Stderr(error.error, req.dir) + error.error = "request: can't handle channel direction" + expLog(error.error, req.dir) client.encode(hdr, payError, error) } case payData: client.serveSend(*hdr) + case payClosed: + client.serveClosed(*hdr) case payAck: client.mu.Lock() if client.ackNum != hdr.seqNum-1 { @@ -127,12 +135,14 @@ func (client *expClient) run() { // in a single instance of locking client.mu, the messages are guaranteed // to be sent in order. Therefore receipt of acknowledgement N means // all messages <=N have been seen by the recipient. We check anyway. - log.Stderr("netchan export: sequence out of order:", client.ackNum, hdr.seqNum) + expLog("sequence out of order:", client.ackNum, hdr.seqNum) } if client.ackNum < hdr.seqNum { // If there has been an error, don't back up the count. client.ackNum = hdr.seqNum } client.mu.Unlock() + default: + log.Exit("netchan export: unknown payload type", hdr.payloadType) } } client.exp.delClient(client) @@ -148,7 +158,9 @@ func (client *expClient) serveRecv(hdr header, count int64) { for { val := ech.ch.Recv() if ech.ch.Closed() { - client.sendError(&hdr, os.EOF.String()) + if err := client.encode(&hdr, payClosed, nil); err != nil { + expLog("error encoding server closed message:", err) + } break } // We hold the lock during transmission to guarantee messages are @@ -161,7 +173,7 @@ func (client *expClient) serveRecv(hdr header, count int64) { err := client.encode(&hdr, payData, val.Interface()) client.mu.Unlock() if err != nil { - log.Stderr("error encoding client response:", err) + expLog("error encoding client response:", err) client.sendError(&hdr, err.String()) break } @@ -184,11 +196,20 @@ func (client *expClient) serveSend(hdr header) { // Create a new value for each received item. val := reflect.MakeZero(ech.ch.Type().(*reflect.ChanType).Elem()) if err := client.decode(val); err != nil { - log.Stderr("exporter value decode:", err) + expLog("value decode:", err) return } ech.ch.Send(val) - // TODO count +} + +// Report that client has closed the channel that is sending to us. +// The header is passed by value to avoid issues of overwriting. +func (client *expClient) serveClosed(hdr header) { + ech := client.getChan(&hdr, Recv) + if ech == nil { + return + } + ech.ch.Close() } func (client *expClient) unackedCount() int64 { @@ -217,7 +238,7 @@ func (exp *Exporter) listen() { for { conn, err := exp.listener.Accept() if err != nil { - log.Stderr("exporter.listen:", err) + expLog("listen:", err) break } client := exp.addClient(conn) diff --git a/src/pkg/netchan/import.go b/src/pkg/netchan/import.go index 028a25f7f80..e6bf4cbb328 100644 --- a/src/pkg/netchan/import.go +++ b/src/pkg/netchan/import.go @@ -14,6 +14,12 @@ import ( // Import +// impLog is a logging convenience function. The first argument must be a string. +func impLog(args ...interface{}) { + args[0] = "netchan import: " + args[0].(string) + log.Stderr(args) +} + // An Importer allows a set of channels to be imported from a single // remote machine/network port. A machine may have multiple // importers, even from the same machine/network port. @@ -66,7 +72,7 @@ func (imp *Importer) run() { for { *hdr = header{} if e := imp.decode(hdrValue); e != nil { - log.Stderr("importer header:", e) + impLog("header:", e) imp.shutdown() return } @@ -75,27 +81,30 @@ func (imp *Importer) run() { // done lower in loop case payError: if e := imp.decode(errValue); e != nil { - log.Stderr("importer error:", e) + impLog("error:", e) return } if err.error != "" { - log.Stderr("importer response error:", err.error) + impLog("response error:", err.error) imp.shutdown() return } + case payClosed: + ich := imp.getChan(hdr.name) + if ich != nil { + ich.ch.Close() + } + continue default: - log.Stderr("unexpected payload type:", hdr.payloadType) + impLog("unexpected payload type:", hdr.payloadType) return } - imp.chanLock.Lock() - ich, ok := imp.chans[hdr.name] - imp.chanLock.Unlock() - if !ok { - log.Stderr("unknown name in request:", hdr.name) - return + ich := imp.getChan(hdr.name) + if ich == nil { + continue } if ich.dir != Recv { - log.Stderr("cannot happen: receive from non-Recv channel") + impLog("cannot happen: receive from non-Recv channel") return } // Acknowledge receipt @@ -105,13 +114,24 @@ func (imp *Importer) run() { // Create a new value for each received item. value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem()) if e := imp.decode(value); e != nil { - log.Stderr("importer value decode:", e) + impLog("importer value decode:", e) return } ich.ch.Send(value) } } +func (imp *Importer) getChan(name string) *chanDir { + imp.chanLock.Lock() + ich := imp.chans[name] + imp.chanLock.Unlock() + if ich == nil { + impLog("unknown name in netchan request:", name) + return nil + } + return ich +} + // Import imports a channel of the given type and specified direction. // It is equivalent to ImportNValues with a count of -1, meaning unbounded. func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error { @@ -145,18 +165,24 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) } imp.chans[name] = &chanDir{ch, dir} // Tell the other side about this channel. - hdr := &header{name: name, payloadType: payRequest} + hdr := &header{name: name} req := &request{count: int64(n), dir: dir} - if err := imp.encode(hdr, payRequest, req); err != nil { - log.Stderr("importer request encode:", err) + if err = imp.encode(hdr, payRequest, req); err != nil { + impLog("request encode:", err) return err } if dir == Send { go func() { for i := 0; n == -1 || i < n; i++ { val := ch.Recv() - if err := imp.encode(hdr, payData, val.Interface()); err != nil { - log.Stderr("error encoding client response:", err) + if ch.Closed() { + if err = imp.encode(hdr, payClosed, nil); err != nil { + impLog("error encoding client closed message:", err) + } + return + } + if err = imp.encode(hdr, payData, val.Interface()); err != nil { + impLog("error encoding client send:", err) return } } diff --git a/src/pkg/netchan/netchan_test.go b/src/pkg/netchan/netchan_test.go index 1626c367d3b..42cb3d1ec1e 100644 --- a/src/pkg/netchan/netchan_test.go +++ b/src/pkg/netchan/netchan_test.go @@ -9,6 +9,8 @@ import "testing" const count = 10 // number of items in most tests const closeCount = 5 // number of items when sender closes early +const base = 23 + func exportSend(exp *Exporter, n int, t *testing.T) { ch := make(chan int) err := exp.Export("exportedSend", ch, Send) @@ -17,7 +19,7 @@ func exportSend(exp *Exporter, n int, t *testing.T) { } go func() { for i := 0; i < n; i++ { - ch <- 23+i + ch <- base+i } close(ch) }() @@ -31,12 +33,32 @@ func exportReceive(exp *Exporter, t *testing.T) { } for i := 0; i < count; i++ { v := <-ch - if v != 45+i { - t.Errorf("export Receive: bad value: expected 4%d; got %d", 45+i, v) + if closed(ch) { + if i != closeCount { + t.Errorf("exportReceive expected close at %d; got one at %d\n", closeCount, i) + } + break + } + if v != base+i { + t.Errorf("export Receive: bad value: expected %d+%d=%d; got %d", base, i, base+i, v) } } } +func importSend(imp *Importer, n int, t *testing.T) { + ch := make(chan int) + err := imp.ImportNValues("exportedRecv", ch, Send, count) + if err != nil { + t.Fatal("importSend:", err) + } + go func() { + for i := 0; i < n; i++ { + ch <- base+i + } + close(ch) + }() +} + func importReceive(imp *Importer, t *testing.T, done chan bool) { ch := make(chan int) err := imp.ImportNValues("exportedSend", ch, Recv, count) @@ -47,12 +69,12 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) { v := <-ch if closed(ch) { if i != closeCount { - t.Errorf("expected close at %d; got one at %d\n", closeCount, i) + t.Errorf("importReceive expected close at %d; got one at %d\n", closeCount, i) } break } if v != 23+i { - t.Errorf("importReceive: bad value: expected %d; got %+d", 23+i, v) + t.Errorf("importReceive: bad value: expected %%d+%d=%d; got %+d", base, i, base+i, v) } } if done != nil { @@ -60,17 +82,6 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) { } } -func importSend(imp *Importer, t *testing.T) { - ch := make(chan int) - err := imp.ImportNValues("exportedRecv", ch, Send, count) - if err != nil { - t.Fatal("importSend:", err) - } - for i := 0; i < count; i++ { - ch <- 45+i - } -} - func TestExportSendImportReceive(t *testing.T) { exp, err := NewExporter("tcp", "127.0.0.1:0") if err != nil { @@ -93,7 +104,7 @@ func TestExportReceiveImportSend(t *testing.T) { if err != nil { t.Fatal("new importer:", err) } - go importSend(imp, t) + importSend(imp, count, t) exportReceive(exp, t) } @@ -110,6 +121,19 @@ func TestClosingExportSendImportReceive(t *testing.T) { importReceive(imp, t, nil) } +func TestClosingImportSendExportReceive(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + importSend(imp, closeCount, t) + exportReceive(exp, t) +} + // Not a great test but it does at least invoke Drain. func TestExportDrain(t *testing.T) { exp, err := NewExporter("tcp", "127.0.0.1:0")