diff --git a/client/sync.go b/client/sync.go index 233313d4..b3f71310 100644 --- a/client/sync.go +++ b/client/sync.go @@ -232,6 +232,31 @@ func SyncStateHas(roomID string, check func(gjson.Result) bool) SyncCheckOpt { } } +// Check that the `state_after` section for `roomID` has an event which passes the check function. +// +// Note that the `state_after` section of a sync response will not contain the entire +// state of the room for incremental or `lazy_load_members` syncs. +func SyncStateAfterHas(roomID string, check func(gjson.Result) bool) SyncCheckOpt { + return func(clientUserID string, topLevelSyncJSON gjson.Result) error { + // Check the stable field + errStable := checkArrayElements( + topLevelSyncJSON, "rooms.join."+GjsonEscape(roomID)+".state_after.events", check, + ) + // Check the unstable field + // + // FIXME: Some implementations haven't stabilized yet (Synapse) so we'll keep this + // here until then. + errUnstable := checkArrayElements( + topLevelSyncJSON, "rooms.join."+GjsonEscape(roomID)+"."+GjsonEscape("org.matrix.msc4222.state_after")+".events", check, + ) + // Valid to find it in either place + if errStable == nil || errUnstable == nil { + return nil + } + return fmt.Errorf("SyncStateAfterHas(%s): Tried to check in the stable field: %s - and unstable field: %s", roomID, errStable, errUnstable) + } +} + func SyncEphemeralHas(roomID string, check func(gjson.Result) bool) SyncCheckOpt { return func(clientUserID string, topLevelSyncJSON gjson.Result) error { err := checkArrayElements( diff --git a/tests/msc4140/delayed_event_test.go b/tests/msc4140/delayed_event_test.go index 03557d88..6c93ffb5 100644 --- a/tests/msc4140/delayed_event_test.go +++ b/tests/msc4140/delayed_event_test.go @@ -30,16 +30,6 @@ const ( DelayedEventActionSend = "send" ) -// A filter for `/sync` that excludes timeline events. -// -// This is useful if you want to see `state` in the `/sync` response without the pesky -// de-duplication with `timeline` that traditional `/sync` does. -const NoTimelineSyncFilter = `{ - "room": { - "timeline": { "limit": 0 } - } -}` - // TODO: Test pagination of `GET /_matrix/client/v1/delayed_events` once // it is implemented in a homeserver. @@ -195,7 +185,7 @@ func TestDelayedEvents(t *testing.T) { // Check for the state change from the delayed state event (using `MustSyncUntil` to // account for any processing or worker replication delays) - user.MustSyncUntil(t, client.SyncReq{Filter: NoTimelineSyncFilter}, client.SyncStateHas(roomID, func(ev gjson.Result) bool { + user.MustSyncUntil(t, client.SyncReq{UseStateAfter: true}, client.SyncStateAfterHas(roomID, func(ev gjson.Result) bool { return ev.Get("type").Str == eventType && ev.Get("state_key").Str == stateKey })) // Make sure the state looks as expected after @@ -347,7 +337,7 @@ func TestDelayedEvents(t *testing.T) { // Check for the state change from the delayed state event (using `MustSyncUntil` to // account for any processing or worker replication delays) - user.MustSyncUntil(t, client.SyncReq{Filter: NoTimelineSyncFilter}, client.SyncStateHas(roomID, func(ev gjson.Result) bool { + user.MustSyncUntil(t, client.SyncReq{UseStateAfter: true}, client.SyncStateAfterHas(roomID, func(ev gjson.Result) bool { return ev.Get("type").Str == eventType && ev.Get("state_key").Str == stateKey })) // Make sure the state looks as expected after @@ -417,7 +407,7 @@ func TestDelayedEvents(t *testing.T) { // Check for the state change from the delayed state event (using `MustSyncUntil` to // account for any processing or worker replication delays) - user.MustSyncUntil(t, client.SyncReq{Filter: NoTimelineSyncFilter}, client.SyncStateHas(roomID, func(ev gjson.Result) bool { + user.MustSyncUntil(t, client.SyncReq{UseStateAfter: true}, client.SyncStateAfterHas(roomID, func(ev gjson.Result) bool { return ev.Get("type").Str == eventType && ev.Get("state_key").Str == stateKey })) // Make sure the state looks as expected after @@ -512,7 +502,7 @@ func TestDelayedEvents(t *testing.T) { // Sanity check that the room state was updated correctly with the delayed events // that were sent. (using `MustSyncUntil` to account for any processing or worker // replication delays) - user.MustSyncUntil(t, client.SyncReq{Filter: NoTimelineSyncFilter}, client.SyncStateHas(roomID, func(ev gjson.Result) bool { + user.MustSyncUntil(t, client.SyncReq{UseStateAfter: true}, client.SyncStateAfterHas(roomID, func(ev gjson.Result) bool { return ev.Get("type").Str == eventType && ev.Get("state_key").Str == stateKey1 })) @@ -528,7 +518,7 @@ func TestDelayedEvents(t *testing.T) { // FIXME: Ideally, we'd check specifically for the last one that was sent but it // will be a bit of a juggle and fiddly to get this right so for now we just check // one. - user.MustSyncUntil(t, client.SyncReq{Filter: NoTimelineSyncFilter}, client.SyncStateHas(roomID, func(ev gjson.Result) bool { + user.MustSyncUntil(t, client.SyncReq{UseStateAfter: true}, client.SyncStateAfterHas(roomID, func(ev gjson.Result) bool { return ev.Get("type").Str == eventType && ev.Get("state_key").Str == stateKey2 })) })