diff --git a/src/pkg/xgb/auth.go b/src/pkg/xgb/auth.go new file mode 100644 index 00000000000..4d920c231aa --- /dev/null +++ b/src/pkg/xgb/auth.go @@ -0,0 +1,110 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xgb + +import ( + "bufio" + "io" + "os" +) + +func getU16BE(r io.Reader, b []byte) (uint16, os.Error) { + _, err := io.ReadFull(r, b[0:2]) + if err != nil { + return 0, err + } + return uint16(b[0])<<8 + uint16(b[1]), nil +} + +func getBytes(r io.Reader, b []byte) ([]byte, os.Error) { + n, err := getU16BE(r, b) + if err != nil { + return nil, err + } + if int(n) > len(b) { + return nil, os.NewError("bytes too long for buffer") + } + _, err = io.ReadFull(r, b[0:n]) + if err != nil { + return nil, err + } + return b[0:n], nil +} + +func getString(r io.Reader, b []byte) (string, os.Error) { + b, err := getBytes(r, b) + if err != nil { + return "", err + } + return string(b), nil +} + +// readAuthority reads the X authority file for the DISPLAY. +// If hostname == "" or hostname == "localhost", +// readAuthority uses the system's hostname (as returned by os.Hostname) instead. +func readAuthority(hostname, display string) (name string, data []byte, err os.Error) { + // b is a scratch buffer to use and should be at least 256 bytes long + // (i.e. it should be able to hold a hostname). + var b [256]byte + + // As per /usr/include/X11/Xauth.h. + const familyLocal = 256 + + if len(hostname) == 0 || hostname == "localhost" { + hostname, err = os.Hostname() + if err != nil { + return "", nil, err + } + } + + fname := os.Getenv("XAUTHORITY") + if len(fname) == 0 { + home := os.Getenv("HOME") + if len(home) == 0 { + err = os.NewError("Xauthority not found: $XAUTHORITY, $HOME not set") + return "", nil, err + } + fname = home + "/.Xauthority" + } + + r, err := os.Open(fname, os.O_RDONLY, 0444) + if err != nil { + return "", nil, err + } + defer r.Close() + + br := bufio.NewReader(r) + for { + family, err := getU16BE(br, b[0:2]) + if err != nil { + return "", nil, err + } + + addr, err := getString(br, b[0:]) + if err != nil { + return "", nil, err + } + + disp, err := getString(br, b[0:]) + if err != nil { + return "", nil, err + } + + name0, err := getString(br, b[0:]) + if err != nil { + return "", nil, err + } + + data0, err := getBytes(br, b[0:]) + if err != nil { + return "", nil, err + } + + if family == familyLocal && addr == hostname && disp == display { + return name0, data0, nil + } + } + panic("unreachable") +} diff --git a/src/pkg/xgb/xgb.go b/src/pkg/xgb/xgb.go index 1cde2b5c93e..3f6f0b0a68f 100644 --- a/src/pkg/xgb/xgb.go +++ b/src/pkg/xgb/xgb.go @@ -18,12 +18,14 @@ import ( // A Conn represents a connection to an X server. // Only one goroutine should use a Conn's methods at a time. type Conn struct { + host string conn net.Conn nextId Id nextCookie Cookie replies map[Cookie][]byte events queue err os.Error + display string defaultScreen int scratch [32]byte Setup SetupInfo @@ -89,7 +91,7 @@ func get32(buf []byte) uint32 { v := uint32(buf[0]) v |= uint32(buf[1]) << 8 v |= uint32(buf[2]) << 16 - v |= uint32(buf[3]) << 32 + v |= uint32(buf[3]) << 24 return v } @@ -284,49 +286,39 @@ func (c *Conn) PollForEvent() (Event, os.Error) { } // Dial connects to the X server given in the 'display' string. -// The display string is typically taken from os.Getenv("DISPLAY"). +// If 'display' is empty it will be taken from os.Getenv("DISPLAY"). +// +// Examples: +// Dial(":1") // connect to net.Dial("unix", "", "/tmp/.X11-unix/X1") +// Dial("/tmp/launch-123/:0") // connect to net.Dial("unix", "", "/tmp/launch-123/:0") +// Dial("hostname:2.1") // connect to net.Dial("tcp", "", "hostname:6002") +// Dial("tcp/hostname:1.0") // connect to net.Dial("tcp", "", "hostname:6001") func Dial(display string) (*Conn, os.Error) { - var err os.Error - - c := new(Conn) - - if display[0] == '/' { - c.conn, err = net.Dial("unix", "", display) - if err != nil { - fmt.Printf("cannot connect: %v\n", err) - return nil, err - } - } else { - parts := strings.Split(display, ":", 2) - host := parts[0] - port := 0 - if len(parts) > 1 { - parts = strings.Split(parts[1], ".", 2) - port, _ = strconv.Atoi(parts[0]) - if len(parts) > 1 { - c.defaultScreen, _ = strconv.Atoi(parts[1]) - } - } - display = fmt.Sprintf("%s:%d", host, port+6000) - c.conn, err = net.Dial("tcp", "", display) - if err != nil { - fmt.Printf("cannot connect: %v\n", err) - return nil, err - } + c, err := connect(display) + if err != nil { + return nil, err } - // TODO: get these from .Xauthority - var authName, authData []byte + // Get authentication data + authName, authData, err := readAuthority(c.host, c.display) + if err != nil { + return nil, err + } + + // Assume that the authentication protocol is "MIT-MAGIC-COOKIE-1". + if authName != "MIT-MAGIC-COOKIE-1" || len(authData) != 16 { + return nil, os.NewError("unsupported auth protocol " + authName) + } buf := make([]byte, 12+pad(len(authName))+pad(len(authData))) - buf[0] = 'l' + buf[0] = 0x6c buf[1] = 0 put16(buf[2:], 11) put16(buf[4:], 0) put16(buf[6:], uint16(len(authName))) put16(buf[8:], uint16(len(authData))) put16(buf[10:], 0) - copy(buf[12:], authName) + copy(buf[12:], strings.Bytes(authName)) copy(buf[12+pad(len(authName)):], authData) if _, err = c.conn.Write(buf); err != nil { return nil, err @@ -396,3 +388,77 @@ func getClientMessageData(b []byte, v *ClientMessageData) int { } return 20 } + +func connect(display string) (*Conn, os.Error) { + if len(display) == 0 { + display = os.Getenv("DISPLAY") + } + + display0 := display + if len(display) == 0 { + return nil, os.NewError("empty display string") + } + + colonIdx := strings.LastIndex(display, ":") + if colonIdx < 0 { + return nil, os.NewError("bad display string: " + display0) + } + + var protocol, socket string + c := new(Conn) + + if display[0] == '/' { + socket = display[0:colonIdx] + } else { + slashIdx := strings.LastIndex(display, "/") + if slashIdx >= 0 { + protocol = display[0:slashIdx] + c.host = display[slashIdx+1 : colonIdx] + } else { + c.host = display[0:colonIdx] + } + } + + display = display[colonIdx+1 : len(display)] + if len(display) == 0 { + return nil, os.NewError("bad display string: " + display0) + } + + var scr string + dotIdx := strings.LastIndex(display, ".") + if dotIdx < 0 { + c.display = display[0:] + } else { + c.display = display[0:dotIdx] + scr = display[dotIdx+1:] + } + + dispnum, err := strconv.Atoui(c.display) + if err != nil { + return nil, os.NewError("bad display string: " + display0) + } + + if len(scr) != 0 { + c.defaultScreen, err = strconv.Atoi(scr) + if err != nil { + return nil, os.NewError("bad display string: " + display0) + } + } + + // Connect to server + if len(socket) != 0 { + c.conn, err = net.Dial("unix", "", socket+":"+c.display) + } else if len(c.host) != 0 { + if protocol == "" { + protocol = "tcp" + } + c.conn, err = net.Dial(protocol, "", c.host+":"+strconv.Uitoa(6000+dispnum)) + } else { + c.conn, err = net.Dial("unix", "", "/tmp/.X11-unix/X"+c.display) + } + + if err != nil { + return nil, os.NewError("cannot connect to " + display0 + ": " + err.String()) + } + return c, nil +}