diff --git a/internal/runtime/map.go b/internal/runtime/map.go index 886980fd..c71df18d 100644 --- a/internal/runtime/map.go +++ b/internal/runtime/map.go @@ -442,7 +442,7 @@ bucketloop: if t.IndirectKey() { k = *((*unsafe.Pointer)(k)) } - if mapKeyEqual(t, key, k) { + if t.Key.Equal(key, k) { e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.KeySize)+i*uintptr(t.ValueSize)) if t.IndirectElem() { e = *((*unsafe.Pointer)(e)) @@ -503,7 +503,7 @@ bucketloop: if t.IndirectKey() { k = *((*unsafe.Pointer)(k)) } - if mapKeyEqual(t, key, k) { + if t.Key.Equal(key, k) { e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.KeySize)+i*uintptr(t.ValueSize)) if t.IndirectElem() { e = *((*unsafe.Pointer)(e)) @@ -547,7 +547,7 @@ bucketloop: if t.IndirectKey() { k = *((*unsafe.Pointer)(k)) } - if mapKeyEqual(t, key, k) { + if t.Key.Equal(key, k) { e := add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.KeySize)+i*uintptr(t.ValueSize)) if t.IndirectElem() { e = *((*unsafe.Pointer)(e)) @@ -635,7 +635,7 @@ bucketloop: if t.IndirectKey() { k = *((*unsafe.Pointer)(k)) } - if !mapKeyEqual(t, key, k) { + if !t.Key.Equal(key, k) { continue } // already have a mapping for key. Update it. @@ -747,7 +747,7 @@ search: if t.IndirectKey() { k2 = *((*unsafe.Pointer)(k2)) } - if !mapKeyEqual(t, key, k2) { + if !t.Key.Equal(key, k2) { continue } // Only clear key if there are pointers in it. @@ -935,7 +935,7 @@ next: // through the oldbucket, skipping any keys that will go // to the other new bucket (each oldbucket expands to two // buckets during a grow). - if t.ReflexiveKey() || mapKeyEqual(t, k, k) { + if t.ReflexiveKey() || t.Key.Equal(k, k) { // If the item in the oldbucket is not destined for // the current new bucket in the iteration, skip it. hash := t.Hasher(k, uintptr(h.hash0)) @@ -956,7 +956,7 @@ next: } } if (b.tophash[offi] != evacuatedX && b.tophash[offi] != evacuatedY) || - !(t.ReflexiveKey() || mapKeyEqual(t, k, k)) { + !(t.ReflexiveKey() || t.Key.Equal(k, k)) { // This is the golden data, we can return it. // OR // key!=key, so the entry can't be deleted or updated, so we can just return it. @@ -1214,7 +1214,7 @@ func evacuate(t *maptype, h *hmap, oldbucket uintptr) { // Compute hash to make our evacuation decision (whether we need // to send this key/elem to bucket x or bucket y). hash := t.Hasher(k2, uintptr(h.hash0)) - if h.flags&iterator != 0 && !t.ReflexiveKey() && !mapKeyEqual(t, k2, k2) { + if h.flags&iterator != 0 && !t.ReflexiveKey() && !t.Key.Equal(k2, k2) { // If key != key (NaNs), then the hash could be (and probably // will be) entirely different from the old hash. Moreover, // it isn't reproducible. Reproducibility is required in the diff --git a/internal/runtime/z_face.go b/internal/runtime/z_face.go index e4d328cd..827e388b 100644 --- a/internal/runtime/z_face.go +++ b/internal/runtime/z_face.go @@ -218,6 +218,11 @@ func Interface(pkgPath, name string, methods []Imethod) *InterfaceType { PkgPath_: pkgPath, Methods: methods, } + if len(methods) == 0 { + ret.Equal = nilinterequal + } else { + ret.Equal = interequal + } return ret } @@ -355,12 +360,6 @@ func Implements(T, V *abi.Type) bool { } func EfaceEqual(v, u eface) bool { - if v.Kind() == abi.Interface { - v = v.Elem() - } - if u.Kind() == abi.Interface { - u = u.Elem() - } if v._type == nil || u._type == nil { return v._type == u._type } @@ -370,52 +369,10 @@ func EfaceEqual(v, u eface) bool { if isDirectIface(v._type) { return v.data == u.data } - switch v.Kind() { - case abi.Bool, - abi.Int, abi.Int8, abi.Int16, abi.Int32, abi.Int64, - abi.Uint, abi.Uint8, abi.Uint16, abi.Uint32, abi.Uint64, abi.Uintptr, - abi.Float32, abi.Float64: - return *(*uintptr)(v.data) == *(*uintptr)(u.data) - case abi.Complex64: - return *(*complex64)(v.data) == *(*complex64)(u.data) - case abi.Complex128: - return *(*complex128)(v.data) == *(*complex128)(u.data) - case abi.String: - return *(*string)(v.data) == *(*string)(u.data) - case abi.Pointer, abi.UnsafePointer: - return v.data == u.data - case abi.Array: - n := v._type.Len() - tt := v._type.ArrayType() - index := func(data unsafe.Pointer, i int) eface { - offset := i * int(tt.Elem.Size_) - return eface{tt.Elem, c.Advance(data, offset)} - } - for i := 0; i < n; i++ { - if !EfaceEqual(index(v.data, i), index(u.data, i)) { - return false - } - } - return true - case abi.Struct: - st := v._type.StructType() - field := func(data unsafe.Pointer, ft *abi.StructField) eface { - ptr := c.Advance(data, int(ft.Offset)) - if isDirectIface(ft.Typ) { - ptr = *(*unsafe.Pointer)(ptr) - } - return eface{ft.Typ, ptr} - } - for _, ft := range st.Fields { - if !EfaceEqual(field(v.data, &ft), field(u.data, &ft)) { - return false - } - } - return true - case abi.Func, abi.Map, abi.Slice: - break + if equal := v._type.Equal; equal != nil { + return equal(v.data, u.data) } - panic("not comparable") + panic(errorString("comparing uncomparable type " + v._type.String()).Error()) } func (v eface) Kind() abi.Kind { diff --git a/internal/runtime/z_map.go b/internal/runtime/z_map.go index 2cb8e9ee..76fee872 100644 --- a/internal/runtime/z_map.go +++ b/internal/runtime/z_map.go @@ -82,29 +82,3 @@ func MapIterNext(it *hiter) (ok bool, k unsafe.Pointer, v unsafe.Pointer) { mapiternext(it) return } - -func mapKeyEqual(t *maptype, p, q unsafe.Pointer) bool { - if isDirectIface(t.Key) { - switch t.Key.Size_ { - case 0: - return true - case 1: - return memequal8(p, q) - case 2: - return memequal16(p, q) - case 4: - return memequal32(p, q) - case 8: - return memequal64(p, q) - } - } - switch t.Key.Kind() { - case abi.String: - return strequal(p, q) - case abi.Complex64: - return c64equal(p, q) - case abi.Complex128: - return c128equal(p, q) - } - return t.Key.Equal(p, q) -} diff --git a/internal/runtime/z_type.go b/internal/runtime/z_type.go index 3caacb2a..684362d3 100644 --- a/internal/runtime/z_type.go +++ b/internal/runtime/z_type.go @@ -30,6 +30,36 @@ var ( tyBasic [abi.UnsafePointer + 1]*Type ) +func basicEqual(kind Kind, size uintptr) func(a, b unsafe.Pointer) bool { + switch kind { + case abi.Bool, abi.Int, abi.Int8, abi.Int16, abi.Int32, abi.Int64, + abi.Uint, abi.Uint8, abi.Uint16, abi.Uint32, abi.Uint64, abi.Uintptr: + switch size { + case 1: + return memequal8 + case 2: + return memequal16 + case 4: + return memequal32 + case 8: + return memequal64 + } + case abi.Float32: + return f32equal + case abi.Float64: + return f64equal + case abi.Complex64: + return c64equal + case abi.Complex128: + return c128equal + case abi.String: + return strequal + case abi.UnsafePointer: + return ptrequal + } + panic("unreachable") +} + func Basic(kind Kind) *Type { if tyBasic[kind] == nil { name, size, align := basicTypeInfo(kind) @@ -39,10 +69,8 @@ func Basic(kind Kind) *Type { Align_: uint8(align), FieldAlign_: uint8(align), Kind_: uint8(kind), + Equal: basicEqual(kind, size), Str_: name, - Equal: func(a, b unsafe.Pointer) bool { - return uintptr(a) == uintptr(b) - }, } } return tyBasic[kind] @@ -115,15 +143,33 @@ func Struct(pkgPath string, size uintptr, fields ...abi.StructField) *Type { PkgPath_: pkgPath, Fields: fields, } + var comparable bool = true var typalign uint8 for _, f := range fields { ft := f.Typ if ft.Align_ > typalign { typalign = ft.Align_ } + comparable = comparable && (ft.Equal != nil) } ret.Align_ = typalign ret.FieldAlign_ = typalign + if comparable { + if size == 0 { + ret.Equal = memequal0 + } else { + ret.Equal = func(p, q unsafe.Pointer) bool { + for _, ft := range fields { + pi := add(p, ft.Offset) + qi := add(q, ft.Offset) + if !ft.Typ.Equal(pi, qi) { + return false + } + } + return true + } + } + } return &ret.Type } @@ -149,6 +195,7 @@ func newPointer(elem *Type) *Type { Align_: pointerAlign, FieldAlign_: pointerAlign, Kind_: uint8(abi.Pointer), + Equal: ptrequal, }, Elem: elem, } @@ -192,6 +239,22 @@ func ArrayOf(length uintptr, elem *Type) *Type { Slice: SliceOf(elem), Len: length, } + if eequal := elem.Equal; eequal != nil { + if elem.Size_ == 0 { + ret.Equal = memequal0 + } else { + ret.Equal = func(p, q unsafe.Pointer) bool { + for i := uintptr(0); i < length; i++ { + pi := add(p, i*elem.Size_) + qi := add(q, i*elem.Size_) + if !eequal(pi, qi) { + return false + } + } + return true + } + } + } return &ret.Type }