diff --git a/sync3/connmap_test.go b/sync3/connmap_test.go index da6ff364..8fbcf8cd 100644 --- a/sync3/connmap_test.go +++ b/sync3/connmap_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/matrix-org/complement/must" "github.com/matrix-org/sliding-sync/sync3/caches" ) @@ -17,6 +16,15 @@ const ( bob = "@bob:localhost" ) +// mustEqual ensures that got==want else logs an error. +// The 'msg' is displayed with the error to provide extra context. +func mustEqual[V comparable](t *testing.T, got, want V, msg string) { + t.Helper() + if got != want { + t.Errorf("Equal %s: got '%v' want '%v'", msg, got, want) + } +} + func TestConnMap(t *testing.T) { cm := NewConnMap(false, time.Minute) cid := ConnID{UserID: alice, DeviceID: "A", CID: "room-list"} @@ -24,13 +32,13 @@ func TestConnMap(t *testing.T) { conn := cm.CreateConn(cid, cancel, func() ConnHandler { return &mockConnHandler{} }) - must.Equal(t, conn.ConnID, cid, "cid mismatch") + mustEqual(t, conn.ConnID, cid, "cid mismatch") // lookups work - must.Equal(t, cm.Conn(cid), conn, "*Conn wasn't the same when fetched via Conn(ConnID)") + mustEqual(t, cm.Conn(cid), conn, "*Conn wasn't the same when fetched via Conn(ConnID)") conns := cm.Conns(cid.UserID, cid.DeviceID) - must.Equal(t, len(conns), 1, "Conns length mismatch") - must.Equal(t, conns[0], conn, "*Conn wasn't the same when fetched via Conns()[0]") + mustEqual(t, len(conns), 1, "Conns length mismatch") + mustEqual(t, conns[0], conn, "*Conn wasn't the same when fetched via Conns()[0]") } func TestConnMap_CloseConnsForDevice(t *testing.T) { @@ -85,7 +93,7 @@ func TestConnMap_CloseConnsForUser(t *testing.T) { num := cm.CloseConnsForUsers([]string{alice}) time.Sleep(100 * time.Millisecond) // some stuff happens asyncly in goroutines - must.Equal(t, num, 6, "unexpected number of closed conns") + mustEqual(t, num, 6, "unexpected number of closed conns") // Destroy should have been called for all alice connections assertDestroyedConns(t, cidToConn, func(cid ConnID) bool { @@ -183,7 +191,7 @@ func TestConnMap_TTLExpiryStaggeredDevices(t *testing.T) { } sort.Strings(gotIDs) wantIDs := []string{"encryption", "notifications"} - must.Equal(t, len(conns), 2, "unexpected number of Conns for device") + mustEqual(t, len(conns), 2, "unexpected number of Conns for device") if !reflect.DeepEqual(gotIDs, wantIDs) { t.Fatalf("unexpected active conns: got %v want %v", gotIDs, wantIDs) } @@ -193,9 +201,9 @@ func assertDestroyedConns(t *testing.T, cidToConn map[ConnID]*Conn, isDestroyedF t.Helper() for cid, conn := range cidToConn { if isDestroyedFn(cid) { - must.Equal(t, conn.handler.(*mockConnHandler).isDestroyed, true, fmt.Sprintf("conn %+v was not destroyed", cid)) + mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed, true, fmt.Sprintf("conn %+v was not destroyed", cid)) } else { - must.Equal(t, conn.handler.(*mockConnHandler).isDestroyed, false, fmt.Sprintf("conn %+v was destroyed", cid)) + mustEqual(t, conn.handler.(*mockConnHandler).isDestroyed, false, fmt.Sprintf("conn %+v was destroyed", cid)) } } }