diff --git a/src/mime/type.go b/src/mime/type.go index c605a94787..89254f5001 100644 --- a/src/mime/type.go +++ b/src/mime/type.go @@ -25,7 +25,8 @@ var ( ".svg": "image/svg+xml", ".xml": "text/xml; charset=utf-8", } - mimeTypes = clone(mimeTypesLower) + mimeTypes = clone(mimeTypesLower) + extensions = invert(mimeTypesLower) ) func clone(m map[string]string) map[string]string { @@ -39,6 +40,18 @@ func clone(m map[string]string) map[string]string { return m2 } +func invert(m map[string]string) map[string][]string { + m2 := make(map[string][]string, len(m)) + for k, v := range m { + justType, _, err := ParseMediaType(v) + if err != nil { + panic(err) + } + m2[justType] = append(m2[justType], k) + } + return m2 +} + var once sync.Once // guards initMime // TypeByExtension returns the MIME type associated with the file extension ext. @@ -92,6 +105,26 @@ func TypeByExtension(ext string) string { return mimeTypesLower[string(lower)] } +// ExtensionsByType returns the extensions known to be associated with the MIME +// type typ. The returned extensions will each begin with a leading dot, as in +// ".html". When typ has no associated extensions, ExtensionsByType returns an +// nil slice. +func ExtensionsByType(typ string) ([]string, error) { + justType, _, err := ParseMediaType(typ) + if err != nil { + return nil, err + } + + once.Do(initMime) + mimeLock.RLock() + defer mimeLock.RUnlock() + s, ok := extensions[justType] + if !ok { + return nil, nil + } + return append([]string{}, s...), nil +} + // AddExtensionType sets the MIME type associated with // the extension ext to typ. The extension should begin with // a leading dot, as in ".html". @@ -104,7 +137,7 @@ func AddExtensionType(ext, typ string) error { } func setExtensionType(extension, mimeType string) error { - _, param, err := ParseMediaType(mimeType) + justType, param, err := ParseMediaType(mimeType) if err != nil { return err } @@ -115,8 +148,14 @@ func setExtensionType(extension, mimeType string) error { extLower := strings.ToLower(extension) mimeLock.Lock() + defer mimeLock.Unlock() mimeTypes[extension] = mimeType mimeTypesLower[extLower] = mimeType - mimeLock.Unlock() + for _, v := range extensions[justType] { + if v == extLower { + return nil + } + } + extensions[justType] = append(extensions[justType], extLower) return nil } diff --git a/src/mime/type_test.go b/src/mime/type_test.go index d2d254ae9a..dabb585e21 100644 --- a/src/mime/type_test.go +++ b/src/mime/type_test.go @@ -5,6 +5,9 @@ package mime import ( + "reflect" + "sort" + "strings" "testing" ) @@ -43,6 +46,59 @@ func TestTypeByExtensionCase(t *testing.T) { } } +func TestExtensionsByType(t *testing.T) { + for want, typ := range typeTests { + val, err := ExtensionsByType(typ) + if err != nil { + t.Errorf("error %s for ExtensionsByType(%q)", err, typ) + continue + } + if len(val) != 1 { + t.Errorf("ExtensionsByType(%q) = %v; expected exactly 1 entry", typ, val) + continue + } + // We always expect lower case, test data includes upper-case. + want = strings.ToLower(want) + if val[0] != want { + t.Errorf("ExtensionsByType(%q) = %q, want %q", typ, val[0], want) + } + } +} + +func TestExtensionsByTypeMultiple(t *testing.T) { + const typ = "text/html" + exts, err := ExtensionsByType(typ) + if err != nil { + t.Fatalf("ExtensionsByType(%q) error: %v", typ, err) + } + sort.Strings(exts) + if want := []string{".htm", ".html"}; !reflect.DeepEqual(exts, want) { + t.Errorf("ExtensionsByType(%q) = %v; want %v", typ, exts, want) + } +} + +func TestExtensionsByTypeNoDuplicates(t *testing.T) { + const ( + typ = "text/html" + ext = ".html" + ) + AddExtensionType(ext, typ) + AddExtensionType(ext, typ) + exts, err := ExtensionsByType(typ) + if err != nil { + t.Fatalf("ExtensionsByType(%q) error: %v", typ, err) + } + count := 0 + for _, v := range exts { + if v == ext { + count++ + } + } + if count != 1 { + t.Errorf("ExtensionsByType(%q) = %v; want %v once", typ, exts, ext) + } +} + func TestLookupMallocs(t *testing.T) { n := testing.AllocsPerRun(10000, func() { TypeByExtension(".html")