diff --git a/src/pkg/netchan/common.go b/src/pkg/netchan/common.go index bde3087a5a0..56c0b25199f 100644 --- a/src/pkg/netchan/common.go +++ b/src/pkg/netchan/common.go @@ -38,21 +38,24 @@ const ( payData // user payload follows payAck // acknowledgement; no payload payClosed // channel is now closed + payAckSend // payload has been delivered. ) // A header is sent as a prefix to every transmission. It will be followed by // a request structure, an error structure, or an arbitrary user payload structure. type header struct { - Name string + Id int PayloadType int SeqNum int64 } // Sent with a header once per channel from importer to exporter to report // that it wants to bind to a channel with the specified direction for count -// messages. If count is -1, it means unlimited. +// messages, with space for size buffered values. If count is -1, it means unlimited. type request struct { + Name string Count int64 + Size int Dir Dir } @@ -78,7 +81,7 @@ type chanDir struct { // clients of an exporter and draining outstanding messages. type clientSet struct { mu sync.Mutex // protects access to channel and client maps - chans map[string]*chanDir + names map[string]*chanDir clients map[unackedCounter]bool } @@ -132,7 +135,7 @@ func (cs *clientSet) drain(timeout int64) os.Error { pending := false cs.mu.Lock() // Any messages waiting for a client? - for _, chDir := range cs.chans { + for _, chDir := range cs.names { if chDir.ch.Len() > 0 { pending = true } @@ -189,3 +192,134 @@ func (cs *clientSet) sync(timeout int64) os.Error { } return nil } + +// A netChan represents a channel imported or exported +// on a single connection. Flow is controlled by the receiving +// side by sending payAckSend messages when values +// are delivered into the local channel. +type netChan struct { + *chanDir + name string + id int + size int // buffer size of channel. + + // sender-specific state + ackCh chan bool // buffered with space for all the acks we need + space int // available space. + + // receiver-specific state + sendCh chan reflect.Value // buffered channel of values received from other end. + ed *encDec // so that we can send acks. + count int64 // number of values still to receive. +} + +// Create a new netChan with the given name (only used for +// messages), id, direction, buffer size, and count. +// The connection to the other side is represented by ed. +func newNetChan(name string, id int, ch *chanDir, ed *encDec, size int, count int64) *netChan { + c := &netChan{chanDir: ch, name: name, id: id, size: size, ed: ed, count: count} + if c.dir == Send { + c.ackCh = make(chan bool, size) + c.space = size + } + return c +} + +// Close the channel. +func (nch *netChan) close() { + if nch.dir == Recv { + if nch.sendCh != nil { + // If the sender goroutine is active, close the channel to it. + // It will close nch.ch when it can. + close(nch.sendCh) + } else { + nch.ch.Close() + } + } else { + nch.ch.Close() + close(nch.ackCh) + } +} + +// Send message from remote side to local receiver. +func (nch *netChan) send(val reflect.Value) { + if nch.dir != Recv { + panic("send on wrong direction of channel") + } + if nch.sendCh == nil { + // If possible, do local send directly and ack immediately. + if nch.ch.TrySend(val) { + nch.sendAck() + return + } + // Start sender goroutine to manage delayed delivery of values. + nch.sendCh = make(chan reflect.Value, nch.size) + go nch.sender() + } + if ok := nch.sendCh <- val; !ok { + // TODO: should this be more resilient? + panic("netchan: remote sender sent more values than allowed") + } +} + +// sendAck sends an acknowledgment that a message has left +// the channel's buffer. If the messages remaining to be sent +// will fit in the channel's buffer, then we don't +// need to send an ack. +func (nch *netChan) sendAck() { + if nch.count < 0 || nch.count > int64(nch.size) { + nch.ed.encode(&header{Id: nch.id}, payAckSend, nil) + } + if nch.count > 0 { + nch.count-- + } +} + +// The sender process forwards items from the sending queue +// to the destination channel, acknowledging each item. +func (nch *netChan) sender() { + if nch.dir != Recv { + panic("sender on wrong direction of channel") + } + // When Exporter.Hangup is called, the underlying channel is closed, + // and so we may get a "too many operations on closed channel" error + // if there are outstanding messages in sendCh. + // Make sure that this doesn't panic the whole program. + defer func() { + if r := recover(); r != nil { + // TODO check that r is "too many operations", otherwise re-panic. + } + }() + for v := range nch.sendCh { + nch.ch.Send(v) + nch.sendAck() + } + nch.ch.Close() +} + +// Receive value from local side for sending to remote side. +func (nch *netChan) recv() (val reflect.Value, closed bool) { + if nch.dir != Send { + panic("recv on wrong direction of channel") + } + + if nch.space == 0 { + // Wait for buffer space. + <-nch.ackCh + nch.space++ + } + nch.space-- + return nch.ch.Recv(), nch.ch.Closed() +} + +// acked is called when the remote side indicates that +// a value has been delivered. +func (nch *netChan) acked() { + if nch.dir != Send { + panic("recv on wrong direction of channel") + } + if ok := nch.ackCh <- true; !ok { + panic("netchan: remote receiver sent too many acks") + // TODO: should this be more resilient? + } +} diff --git a/src/pkg/netchan/export.go b/src/pkg/netchan/export.go index 9ad388c1828..0f72ca7a940 100644 --- a/src/pkg/netchan/export.go +++ b/src/pkg/netchan/export.go @@ -26,6 +26,7 @@ import ( "net" "os" "reflect" + "strconv" "sync" ) @@ -48,11 +49,12 @@ type Exporter struct { type expClient struct { *encDec exp *Exporter - mu sync.Mutex // protects remaining fields - errored bool // client has been sent an error - seqNum int64 // sequences messages sent to client; has value of highest sent - ackNum int64 // highest sequence number acknowledged - seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu + chans map[int]*netChan // channels in use by client + mu sync.Mutex // protects remaining fields + errored bool // client has been sent an error + seqNum int64 // sequences messages sent to client; has value of highest sent + ackNum int64 // highest sequence number acknowledged + seqLock sync.Mutex // guarantees messages are in sequence, only locked under mu } func newClient(exp *Exporter, conn net.Conn) *expClient { @@ -61,8 +63,8 @@ func newClient(exp *Exporter, conn net.Conn) *expClient { client.encDec = newEncDec(conn) client.seqNum = 0 client.ackNum = 0 + client.chans = make(map[int]*netChan) return client - } func (client *expClient) sendError(hdr *header, err string) { @@ -74,20 +76,33 @@ func (client *expClient) sendError(hdr *header, err string) { client.mu.Unlock() } -func (client *expClient) getChan(hdr *header, dir Dir) *chanDir { +func (client *expClient) newChan(hdr *header, dir Dir, name string, size int, count int64) *netChan { exp := client.exp exp.mu.Lock() - ech, ok := exp.chans[hdr.Name] + ech, ok := exp.names[name] exp.mu.Unlock() if !ok { - client.sendError(hdr, "no such channel: "+hdr.Name) + client.sendError(hdr, "no such channel: "+name) return nil } if ech.dir != dir { - client.sendError(hdr, "wrong direction for channel: "+hdr.Name) + client.sendError(hdr, "wrong direction for channel: "+name) return nil } - return ech + nch := newNetChan(name, hdr.Id, ech, client.encDec, size, count) + client.chans[hdr.Id] = nch + return nch +} + +func (client *expClient) getChan(hdr *header, dir Dir) *netChan { + nch := client.chans[hdr.Id] + if nch == nil { + return nil + } + if nch.dir != dir { + client.sendError(hdr, "wrong direction for channel: "+nch.name) + } + return nch } // The function run manages sends and receives for a single client. For each @@ -113,12 +128,18 @@ func (client *expClient) run() { expLog("error decoding client request:", err) break } + if req.Size < 1 { + panic("netchan: remote requested " + strconv.Itoa(req.Size) + " values") + } switch req.Dir { case Recv: - go client.serveRecv(*hdr, req.Count) + // look up channel before calling serveRecv to + // avoid a lock around client.chans. + if nch := client.newChan(hdr, Send, req.Name, req.Size, req.Count); nch != nil { + go client.serveRecv(nch, *hdr, req.Count) + } case Send: - // Request to send is clear as a matter of protocol - // but not actually used by the implementation. + client.newChan(hdr, Recv, req.Name, req.Size, req.Count) // The actual sends will have payload type payData. // TODO: manage the count? default: @@ -143,6 +164,10 @@ func (client *expClient) run() { client.ackNum = hdr.SeqNum } client.mu.Unlock() + case payAckSend: + if nch := client.getChan(hdr, Send); nch != nil { + nch.acked() + } default: log.Exit("netchan export: unknown payload type", hdr.PayloadType) } @@ -152,14 +177,10 @@ func (client *expClient) run() { // Send all the data on a single channel to a client asking for a Recv. // The header is passed by value to avoid issues of overwriting. -func (client *expClient) serveRecv(hdr header, count int64) { - ech := client.getChan(&hdr, Send) - if ech == nil { - return - } +func (client *expClient) serveRecv(nch *netChan, hdr header, count int64) { for { - val := ech.ch.Recv() - if ech.ch.Closed() { + val, closed := nch.recv() + if closed { if err := client.encode(&hdr, payClosed, nil); err != nil { expLog("error encoding server closed message:", err) } @@ -167,7 +188,7 @@ func (client *expClient) serveRecv(hdr header, count int64) { } // We hold the lock during transmission to guarantee messages are // sent in sequence number order. Also, we increment first so the - // value of client.seqNum is the value of the highest used sequence + // value of client.SeqNum is the value of the highest used sequence // number, not one beyond. client.mu.Lock() client.seqNum++ @@ -193,27 +214,27 @@ func (client *expClient) serveRecv(hdr header, count int64) { // Receive and deliver locally one item from a client asking for a Send // The header is passed by value to avoid issues of overwriting. func (client *expClient) serveSend(hdr header) { - ech := client.getChan(&hdr, Recv) - if ech == nil { + nch := client.getChan(&hdr, Recv) + if nch == nil { return } // Create a new value for each received item. - val := reflect.MakeZero(ech.ch.Type().(*reflect.ChanType).Elem()) + val := reflect.MakeZero(nch.ch.Type().(*reflect.ChanType).Elem()) if err := client.decode(val); err != nil { - expLog("value decode:", err) + expLog("value decode:", err, "; type ", nch.ch.Type()) return } - ech.ch.Send(val) + nch.send(val) } // Report that client has closed the channel that is sending to us. // The header is passed by value to avoid issues of overwriting. func (client *expClient) serveClosed(hdr header) { - ech := client.getChan(&hdr, Recv) - if ech == nil { + nch := client.getChan(&hdr, Recv) + if nch == nil { return } - ech.ch.Close() + nch.close() } func (client *expClient) unackedCount() int64 { @@ -260,7 +281,7 @@ func NewExporter(network, localaddr string) (*Exporter, os.Error) { e := &Exporter{ listener: listener, clientSet: &clientSet{ - chans: make(map[string]*chanDir), + names: make(map[string]*chanDir), clients: make(map[unackedCounter]bool), }, } @@ -343,11 +364,11 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error { } exp.mu.Lock() defer exp.mu.Unlock() - _, present := exp.chans[name] + _, present := exp.names[name] if present { return os.ErrorString("channel name already being exported:" + name) } - exp.chans[name] = &chanDir{ch, dir} + exp.names[name] = &chanDir{ch, dir} return nil } @@ -355,10 +376,11 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error { // the channel. Messages in flight for the channel may be dropped. func (exp *Exporter) Hangup(name string) os.Error { exp.mu.Lock() - chDir, ok := exp.chans[name] + chDir, ok := exp.names[name] if ok { - exp.chans[name] = nil, false + exp.names[name] = nil, false } + // TODO drop all instances of channel from client sets exp.mu.Unlock() if !ok { return os.ErrorString("netchan export: hangup: no such channel: " + name) diff --git a/src/pkg/netchan/import.go b/src/pkg/netchan/import.go index baae367a0c9..22b0f69ba38 100644 --- a/src/pkg/netchan/import.go +++ b/src/pkg/netchan/import.go @@ -27,8 +27,10 @@ type Importer struct { *encDec conn net.Conn chanLock sync.Mutex // protects access to channel map - chans map[string]*chanDir + names map[string]*netChan + chans map[int]*netChan errors chan os.Error + maxId int } // NewImporter creates a new Importer object to import channels @@ -43,7 +45,8 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) { imp := new(Importer) imp.encDec = newEncDec(conn) imp.conn = conn - imp.chans = make(map[string]*chanDir) + imp.chans = make(map[int]*netChan) + imp.names = make(map[string]*netChan) imp.errors = make(chan os.Error, 10) go imp.run() return imp, nil @@ -54,7 +57,7 @@ func (imp *Importer) shutdown() { imp.chanLock.Lock() for _, ich := range imp.chans { if ich.dir == Recv { - ich.ch.Close() + ich.close() } } imp.chanLock.Unlock() @@ -95,43 +98,53 @@ func (imp *Importer) run() { continue // errors are not acknowledged. } case payClosed: - ich := imp.getChan(hdr.Name) - if ich != nil { - ich.ch.Close() + nch := imp.getChan(hdr.Id, false) + if nch != nil { + nch.close() } continue // closes are not acknowledged. + case payAckSend: + // we can receive spurious acks if the channel is + // hung up, so we ask getChan to ignore any errors. + nch := imp.getChan(hdr.Id, true) + if nch != nil { + nch.acked() + } + continue default: impLog("unexpected payload type:", hdr.PayloadType) return } - ich := imp.getChan(hdr.Name) - if ich == nil { + nch := imp.getChan(hdr.Id, false) + if nch == nil { continue } - if ich.dir != Recv { + if nch.dir != Recv { impLog("cannot happen: receive from non-Recv channel") return } // Acknowledge receipt - ackHdr.Name = hdr.Name + ackHdr.Id = hdr.Id ackHdr.SeqNum = hdr.SeqNum imp.encode(ackHdr, payAck, nil) // Create a new value for each received item. - value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem()) + value := reflect.MakeZero(nch.ch.Type().(*reflect.ChanType).Elem()) if e := imp.decode(value); e != nil { impLog("importer value decode:", e) return } - ich.ch.Send(value) + nch.send(value) } } -func (imp *Importer) getChan(name string) *chanDir { +func (imp *Importer) getChan(id int, errOk bool) *netChan { imp.chanLock.Lock() - ich := imp.chans[name] + ich := imp.chans[id] imp.chanLock.Unlock() if ich == nil { - impLog("unknown name in netchan request:", name) + if !errOk { + impLog("unknown id in netchan request: ", id) + } return nil } return ich @@ -145,17 +158,18 @@ func (imp *Importer) Errors() chan os.Error { return imp.errors } -// Import imports a channel of the given type and specified direction. +// Import imports a channel of the given type, size and specified direction. // It is equivalent to ImportNValues with a count of -1, meaning unbounded. -func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error { - return imp.ImportNValues(name, chT, dir, -1) +func (imp *Importer) Import(name string, chT interface{}, dir Dir, size int) os.Error { + return imp.ImportNValues(name, chT, dir, size, -1) } -// ImportNValues imports a channel of the given type and specified direction -// and then receives or transmits up to n values on that channel. A value of -// n==-1 implies an unbounded number of values. The channel to be bound to -// the remote site's channel is provided in the call and may be of arbitrary -// channel type. +// ImportNValues imports a channel of the given type and specified +// direction and then receives or transmits up to n values on that +// channel. A value of n==-1 implies an unbounded number of values. The +// channel will have buffer space for size values, or 1 value if size < 1. +// The channel to be bound to the remote site's channel is provided +// in the call and may be of arbitrary channel type. // Despite the literal signature, the effective signature is // ImportNValues(name string, chT chan T, dir Dir, n int) os.Error // Example usage: @@ -165,21 +179,28 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error { // err = imp.ImportNValues("name", ch, Recv, 1) // if err != nil { log.Exit(err) } // fmt.Printf("%+v\n", <-ch) -func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) os.Error { +func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size, n int) os.Error { ch, err := checkChan(chT, dir) if err != nil { return err } imp.chanLock.Lock() defer imp.chanLock.Unlock() - _, present := imp.chans[name] + _, present := imp.names[name] if present { return os.ErrorString("channel name already being imported:" + name) } - imp.chans[name] = &chanDir{ch, dir} + if size < 1 { + size = 1 + } + id := imp.maxId + imp.maxId++ + nch := newNetChan(name, id, &chanDir{ch, dir}, imp.encDec, size, int64(n)) + imp.names[name] = nch + imp.chans[id] = nch // Tell the other side about this channel. - hdr := &header{Name: name} - req := &request{Count: int64(n), Dir: dir} + hdr := &header{Id: id} + req := &request{Name: name, Count: int64(n), Dir: dir, Size: size} if err = imp.encode(hdr, payRequest, req); err != nil { impLog("request encode:", err) return err @@ -187,8 +208,8 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) if dir == Send { go func() { for i := 0; n == -1 || i < n; i++ { - val := ch.Recv() - if ch.Closed() { + val, closed := nch.recv() + if closed { if err = imp.encode(hdr, payClosed, nil); err != nil { impLog("error encoding client closed message:", err) } @@ -208,14 +229,15 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) // the channel. Messages in flight for the channel may be dropped. func (imp *Importer) Hangup(name string) os.Error { imp.chanLock.Lock() - chDir, ok := imp.chans[name] + nc, ok := imp.names[name] if ok { - imp.chans[name] = nil, false + imp.names[name] = nil, false + imp.chans[nc.id] = nil, false } imp.chanLock.Unlock() if !ok { return os.ErrorString("netchan import: hangup: no such channel: " + name) } - chDir.ch.Close() + nc.close() return nil } diff --git a/src/pkg/netchan/netchan_test.go b/src/pkg/netchan/netchan_test.go index 766c4c47401..2134297c40b 100644 --- a/src/pkg/netchan/netchan_test.go +++ b/src/pkg/netchan/netchan_test.go @@ -15,7 +15,7 @@ const closeCount = 5 // number of items when sender closes early const base = 23 -func exportSend(exp *Exporter, n int, t *testing.T) { +func exportSend(exp *Exporter, n int, t *testing.T, done chan bool) { ch := make(chan int) err := exp.Export("exportedSend", ch, Send) if err != nil { @@ -26,6 +26,9 @@ func exportSend(exp *Exporter, n int, t *testing.T) { ch <- base+i } close(ch) + if done != nil { + done <- true + } }() } @@ -50,9 +53,9 @@ func exportReceive(exp *Exporter, t *testing.T, expDone chan bool) { } } -func importSend(imp *Importer, n int, t *testing.T) { +func importSend(imp *Importer, n int, t *testing.T, done chan bool) { ch := make(chan int) - err := imp.ImportNValues("exportedRecv", ch, Send, count) + err := imp.ImportNValues("exportedRecv", ch, Send, 3, -1) if err != nil { t.Fatal("importSend:", err) } @@ -61,12 +64,15 @@ func importSend(imp *Importer, n int, t *testing.T) { ch <- base+i } close(ch) + if done != nil { + done <- true + } }() } func importReceive(imp *Importer, t *testing.T, done chan bool) { ch := make(chan int) - err := imp.ImportNValues("exportedSend", ch, Recv, count) + err := imp.ImportNValues("exportedSend", ch, Recv, 3, count) if err != nil { t.Fatal("importReceive:", err) } @@ -78,7 +84,7 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) { } break } - if v != 23+i { + if v != base+i { t.Errorf("importReceive: bad value: expected %d+%d=%d; got %+d", base, i, base+i, v) } } @@ -96,7 +102,7 @@ func TestExportSendImportReceive(t *testing.T) { if err != nil { t.Fatal("new importer:", err) } - exportSend(exp, count, t) + exportSend(exp, count, t, nil) importReceive(imp, t, nil) } @@ -116,7 +122,7 @@ func TestExportReceiveImportSend(t *testing.T) { done <- true }() <-expDone - importSend(imp, count, t) + importSend(imp, count, t, nil) <-done } @@ -129,7 +135,7 @@ func TestClosingExportSendImportReceive(t *testing.T) { if err != nil { t.Fatal("new importer:", err) } - exportSend(exp, closeCount, t) + exportSend(exp, closeCount, t, nil) importReceive(imp, t, nil) } @@ -149,7 +155,7 @@ func TestClosingImportSendExportReceive(t *testing.T) { done <- true }() <-expDone - importSend(imp, closeCount, t) + importSend(imp, closeCount, t, nil) <-done } @@ -172,7 +178,7 @@ func TestErrorForIllegalChannel(t *testing.T) { close(ch) // Now try to import a different channel. ch = make(chan int) - err = imp.Import("notAChannel", ch, Recv) + err = imp.Import("notAChannel", ch, Recv, 1) if err != nil { t.Fatal("import:", err) } @@ -204,7 +210,7 @@ func TestExportDrain(t *testing.T) { } done := make(chan bool) go func() { - exportSend(exp, closeCount, t) + exportSend(exp, closeCount, t, nil) done <- true }() <-done @@ -224,7 +230,7 @@ func TestExportSync(t *testing.T) { t.Fatal("new importer:", err) } done := make(chan bool) - exportSend(exp, closeCount, t) + exportSend(exp, closeCount, t, nil) go importReceive(imp, t, done) exp.Sync(0) <-done @@ -248,7 +254,7 @@ func TestExportHangup(t *testing.T) { } // Prepare to receive two values. We'll actually deliver only one. ich := make(chan int) - err = imp.ImportNValues("exportedSend", ich, Recv, 2) + err = imp.ImportNValues("exportedSend", ich, Recv, 1, 2) if err != nil { t.Fatal("import exportedSend:", err) } @@ -285,7 +291,7 @@ func TestImportHangup(t *testing.T) { } // Prepare to Send two values. We'll actually deliver only one. ich := make(chan int) - err = imp.ImportNValues("exportedRecv", ich, Send, 2) + err = imp.ImportNValues("exportedRecv", ich, Send, 1, 2) if err != nil { t.Fatal("import exportedRecv:", err) } @@ -304,10 +310,70 @@ func TestImportHangup(t *testing.T) { } } +// loop back exportedRecv to exportedSend, +// but receive a value from ctlch before starting the loop. +func exportLoopback(exp *Exporter, t *testing.T) { + inch := make(chan int) + if err := exp.Export("exportedRecv", inch, Recv); err != nil { + t.Fatal("exportRecv") + } + + outch := make(chan int) + if err := exp.Export("exportedSend", outch, Send); err != nil { + t.Fatal("exportSend") + } + + ctlch := make(chan int) + if err := exp.Export("exportedCtl", ctlch, Recv); err != nil { + t.Fatal("exportRecv") + } + + go func() { + <-ctlch + for i := 0; i < count; i++ { + x := <-inch + if x != base+i { + t.Errorf("exportLoopback expected %d; got %d", i, x) + } + outch <- x + } + }() +} + +// This test checks that channel operations can proceed +// even when other concurrent operations are blocked. +func TestIndependentSends(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + + exportLoopback(exp, t) + + importSend(imp, count, t, nil) + done := make(chan bool) + go importReceive(imp, t, done) + + // wait for export side to try to deliver some values. + time.Sleep(0.25e9) + + ctlch := make(chan int) + if err := imp.ImportNValues("exportedCtl", ctlch, Send, 1, 1); err != nil { + t.Fatal("importSend:", err) + } + ctlch <- 0 + + <-done +} + // This test cross-connects a pair of exporter/importer pairs. type value struct { - i int - source string + I int + Source string } func TestCrossConnect(t *testing.T) { @@ -353,13 +419,13 @@ func crossExport(e1, e2 *Exporter, t *testing.T) { // Import side of cross-traffic. func crossImport(i1, i2 *Importer, t *testing.T) { s := make(chan value) - err := i2.Import("exportedReceive", s, Send) + err := i2.Import("exportedReceive", s, Send, 2) if err != nil { t.Fatal("import of exportedReceive:", err) } r := make(chan value) - err = i1.Import("exportedSend", r, Recv) + err = i1.Import("exportedSend", r, Recv, 2) if err != nil { t.Fatal("import of exported Send:", err) } @@ -374,10 +440,76 @@ func crossLoop(name string, s, r chan value, t *testing.T) { case s <- value{si, name}: si++ case v := <-r: - if v.i != ri { + if v.I != ri { t.Errorf("loop: bad value: expected %d, hello; got %+v", ri, v) } ri++ } } } + +const flowCount = 100 + +// test flow control from exporter to importer. +func TestExportFlowControl(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + + sendDone := make(chan bool, 1) + exportSend(exp, flowCount, t, sendDone) + + ch := make(chan int) + err = imp.ImportNValues("exportedSend", ch, Recv, 20, -1) + if err != nil { + t.Fatal("importReceive:", err) + } + + testFlow(sendDone, ch, flowCount, t) +} + +// test flow control from importer to exporter. +func TestImportFlowControl(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + + ch := make(chan int) + err = exp.Export("exportedRecv", ch, Recv) + if err != nil { + t.Fatal("importReceive:", err) + } + + sendDone := make(chan bool, 1) + importSend(imp, flowCount, t, sendDone) + testFlow(sendDone, ch, flowCount, t) +} + +func testFlow(sendDone chan bool, ch <-chan int, N int, t *testing.T) { + go func() { + time.Sleep(1e9) + sendDone <- false + }() + + if <-sendDone { + t.Fatal("send did not block") + } + n := 0 + for i := range ch { + t.Log("after blocking, got value ", i) + n++ + } + if n != N { + t.Fatalf("expected %d values; got %d", N, n) + } +}