diff --git a/src/encoding/xml/xml.go b/src/encoding/xml/xml.go
index 0fe323f7c86..adbc996b283 100644
--- a/src/encoding/xml/xml.go
+++ b/src/encoding/xml/xml.go
@@ -1169,15 +1169,28 @@ func (d *Decoder) nsname() (name Name, ok bool) {
if !ok {
return
}
- if strings.Count(s, ":") > 1 {
- return name, false
- } else if space, local, ok := strings.Cut(s, ":"); !ok || space == "" || local == "" {
- name.Local = s
- } else {
- name.Space = space
- name.Local = local
+ // XML does not allow a document to end with a name, so there must
+ // be another byte.
+ b, ok := d.mustgetc()
+ if !ok {
+ return
}
- return name, true
+ if b != ':' {
+ d.ungetc(b)
+ name.Local = s
+ return
+ }
+ n, ok := d.name()
+ if ok {
+ // give a better error message than would otherwise be possible
+ if d.nextByte == ':' {
+ d.err = d.syntaxError("colon after prefixed XML name " + string(s) + ":" + string(n))
+ return name, false
+ }
+ name.Space = s
+ name.Local = n
+ }
+ return
}
// Get name: /first(first|second)*/
@@ -1229,7 +1242,7 @@ func isNameByte(c byte) bool {
return 'A' <= c && c <= 'Z' ||
'a' <= c && c <= 'z' ||
'0' <= c && c <= '9' ||
- c == '_' || c == ':' || c == '.' || c == '-'
+ c == '_' || c == '.' || c == '-'
}
func isName(s []byte) bool {
@@ -1287,7 +1300,6 @@ func isNameString(s string) bool {
var first = &unicode.RangeTable{
R16: []unicode.Range16{
- {0x003A, 0x003A, 1},
{0x0041, 0x005A, 1},
{0x005F, 0x005F, 1},
{0x0061, 0x007A, 1},
diff --git a/src/encoding/xml/xml_test.go b/src/encoding/xml/xml_test.go
index b2a06a76397..7e446607b36 100644
--- a/src/encoding/xml/xml_test.go
+++ b/src/encoding/xml/xml_test.go
@@ -31,6 +31,89 @@ func (t *toks) Token() (Token, error) {
return tok, nil
}
+func TestDecodeBadName(t *testing.T) {
+ tests := []struct {
+ name string
+ invalid string
+ message string
+ }{
+ {
+ name: "Number after colon",
+ invalid: ``,
+ message: "invalid XML name: 1",
+ },
+ {
+ name: "Two colons at end",
+ invalid: ``,
+ message: "expected element name after <",
+ },
+ {
+ name: "Two colons together in middle",
+ invalid: "",
+ message: "expected element name after <",
+ },
+ {
+ name: "Colon at end",
+ invalid: "",
+ message: "expected element name after <",
+ },
+ {
+ name: "Colon at start",
+ invalid: "<:a/>",
+ message: "expected element name after <",
+ },
+ {
+ name: "Number after colon in attribute",
+ invalid: ``,
+ message: "invalid XML name: 1",
+ },
+ {
+ name: "Two colons separate",
+ invalid: ``,
+ message: "colon after prefixed XML name a:b",
+ },
+ {
+ name: "Two colons at end",
+ invalid: ``,
+ message: "expected attribute name in element",
+ },
+ {
+ name: "Two colons together in middle",
+ invalid: ``,
+ message: "expected attribute name in element",
+ },
+ {
+ name: "Colon at end",
+ invalid: ``,
+ message: "expected attribute name in element",
+ },
+ {
+ name: "Colon at start",
+ invalid: ``,
+ message: "expected attribute name in element",
+ },
+ }
+ for i, j := range tests {
+ t.Run(j.name, func(t *testing.T) {
+ d := NewDecoder(strings.NewReader(j.invalid))
+ tok, err := d.RawToken()
+ if tok != nil {
+ t.Fatalf("%d: d.Decode: expected nil token, got %#v", i, tok)
+ }
+ if err == nil {
+ t.Fatalf("%d: d.Decode: expected non-nil error, got nil", i)
+ }
+ syntaxError, ok := err.(*SyntaxError)
+ if !ok {
+ t.Fatalf("%d: d.Decode: expected syntax error", i)
+ }
+ if syntaxError.Msg != j.message {
+ t.Errorf("%d: bad message: expected %q, got %q", i, j.message, syntaxError.Msg)
+ }
+ })
+ }
+}
+
func TestDecodeEOF(t *testing.T) {
start := StartElement{Name: Name{Local: "test"}}
tests := []struct {
@@ -1130,12 +1213,12 @@ func TestIssue20396(t *testing.T) {
wantErr error
}{
{``, // Issue 20396
- UnmarshalError("XML syntax error on line 1: expected element name after <")},
+ UnmarshalError("XML syntax error on line 1: colon after prefixed XML name a:te")},
{``, attrError},
{``, attrError},
{``, nil},
{`1`,
- UnmarshalError("XML syntax error on line 1: expected element name after <")},
+ UnmarshalError("XML syntax error on line 1: colon after prefixed XML name a:te")},
{`1`, attrError},
{`1`, attrError},
{`1`, nil},
@@ -1324,7 +1407,6 @@ func testRoundTrip(t *testing.T, input string) {
func TestRoundTrip(t *testing.T) {
tests := map[string]string{
- "trailing colon": ``,
"comments in directives": `--x --> > --x ]>`,
}
for name, input := range tests {