Optimize performance when many clients [p|s]unsubscribe simultaneously (#12838)

I've been testing the performance of the Pub/Sub commands recently. I found
that if many clients unsubscribe or are killed simultaneously, Redis needs a
long time to deal with it.

In my experiment, I set up 5000 clients, each subscribed to 100
channels. Then I called `client kill type pubsub` to simulate the
situation where clients unsubscribe from all channels at the same time, and
measured the execution time. The result shows that it takes about 23s.
I used _perf_ and found that `listSearchKey` in
`pubsubUnsubscribeChannel` costs more than 90% of the CPU time. I think we can
optimize this situation.

In this PR, I replace the list with a dict to track the clients subscribed to
a channel more efficiently. This changes the search phase from O(N) to O(1).
Then I repeated the experiment described above. The results are as
follows.

|              | Execution Time(s) |used_memory(MB) |
| :---------------- | :------: | :----: |
| unstable(1bd0b54)        |   23.734   | 65.41 |
| optimize-pubsub           |   0.288   | 67.66 |

Thanks to #11595, I use a no-value dict, and the results show that the
performance improves significantly while the memory usage increases only
slightly.

Notice:

- This PR will cause a performance degradation of about 20% in the
`[p|s]subscribe` commands, but it won't freeze Redis.
This commit is contained in:
Yanqi Lv 2024-01-08 16:32:31 +08:00 committed by GitHub
parent 4730563e93
commit c452e414a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 45 deletions

View File

@ -280,7 +280,7 @@ void unmarkClientAsPubSub(client *c) {
int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) {
dict **d_ptr;
dictEntry *de;
list *clients = NULL;
dict *clients = NULL;
int retval = 0;
unsigned int slot = 0;
@ -294,13 +294,13 @@ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) {
}
d_ptr = type.serverPubSubChannels(slot);
if (*d_ptr == NULL) {
*d_ptr = dictCreate(&keylistDictType);
*d_ptr = dictCreate(&objToDictDictType);
de = NULL;
} else {
de = dictFind(*d_ptr, channel);
}
if (de == NULL) {
clients = listCreate();
clients = dictCreate(&clientDictType);
dictAdd(*d_ptr, channel, clients);
incrRefCount(channel);
if (type.shard) {
@ -309,7 +309,7 @@ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) {
} else {
clients = dictGetVal(de);
}
listAddNodeTail(clients,c);
serverAssert(dictAdd(clients, c, NULL) != DICT_ERR);
}
/* Notify the client */
addReplyPubsubSubscribed(c,channel,type);
@ -321,8 +321,7 @@ int pubsubSubscribeChannel(client *c, robj *channel, pubsubtype type) {
int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype type) {
dict *d;
dictEntry *de;
list *clients;
listNode *ln;
dict *clients;
int retval = 0;
int slot = 0;
@ -340,11 +339,9 @@ int pubsubUnsubscribeChannel(client *c, robj *channel, int notify, pubsubtype ty
de = dictFind(d, channel);
serverAssertWithInfo(c,NULL,de != NULL);
clients = dictGetVal(de);
ln = listSearchKey(clients,c);
serverAssertWithInfo(c,NULL,ln != NULL);
listDelNode(clients,ln);
if (listLength(clients) == 0) {
/* Free the list and associated hash entry at all if this was
serverAssertWithInfo(c, NULL, dictDelete(clients, c) == DICT_OK);
if (dictSize(clients) == 0) {
/* Free the dict and associated hash entry at all if this was
* the last client, so that it will not be possible to abuse
* Redis PUBSUB by creating millions of channels. */
dictDelete(d, channel);
@ -376,11 +373,13 @@ void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) {
dictEntry *de;
while ((de = dictNext(di)) != NULL) {
robj *channel = dictGetKey(de);
list *clients = dictGetVal(de);
dict *clients = dictGetVal(de);
if (dictSize(clients) == 0) goto cleanup;
/* For each client subscribed to the channel, unsubscribe it. */
listNode *ln;
while ((ln = listFirst(clients)) != NULL) {
client *c = listNodeValue(ln);
dictIterator *iter = dictGetSafeIterator(clients);
dictEntry *entry;
while ((entry = dictNext(iter)) != NULL) {
client *c = dictGetKey(entry);
int retval = dictDelete(c->pubsubshard_channels, channel);
serverAssertWithInfo(c,channel,retval == DICT_OK);
addReplyPubsubUnsubscribed(c, channel, pubSubShardType);
@ -389,8 +388,9 @@ void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) {
if (clientTotalPubSubSubscriptionCount(c) == 0) {
unmarkClientAsPubSub(c);
}
listDelNode(clients, ln);
}
dictReleaseIterator(iter);
cleanup:
server.shard_channel_count--;
dictDelete(d, channel);
}
@ -402,7 +402,7 @@ void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) {
/* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. */
int pubsubSubscribePattern(client *c, robj *pattern) {
dictEntry *de;
list *clients;
dict *clients;
int retval = 0;
if (dictAdd(c->pubsub_patterns, pattern, NULL) == DICT_OK) {
@ -411,13 +411,13 @@ int pubsubSubscribePattern(client *c, robj *pattern) {
/* Add the client to the pattern -> dict of clients hash table */
de = dictFind(server.pubsub_patterns,pattern);
if (de == NULL) {
clients = listCreate();
clients = dictCreate(&clientDictType);
dictAdd(server.pubsub_patterns,pattern,clients);
incrRefCount(pattern);
} else {
clients = dictGetVal(de);
}
listAddNodeTail(clients,c);
serverAssert(dictAdd(clients, c, NULL) != DICT_ERR);
}
/* Notify the client */
addReplyPubsubPatSubscribed(c,pattern);
@ -428,8 +428,7 @@ int pubsubSubscribePattern(client *c, robj *pattern) {
* 0 if the client was not subscribed to the specified channel. */
int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) {
dictEntry *de;
list *clients;
listNode *ln;
dict *clients;
int retval = 0;
incrRefCount(pattern); /* Protect the object. May be the same we remove */
@ -439,11 +438,9 @@ int pubsubUnsubscribePattern(client *c, robj *pattern, int notify) {
de = dictFind(server.pubsub_patterns,pattern);
serverAssertWithInfo(c,NULL,de != NULL);
clients = dictGetVal(de);
ln = listSearchKey(clients,c);
serverAssertWithInfo(c,NULL,ln != NULL);
listDelNode(clients,ln);
if (listLength(clients) == 0) {
/* Free the list and associated hash entry at all if this was
serverAssertWithInfo(c, NULL, dictDelete(clients, c) == DICT_OK);
if (dictSize(clients) == 0) {
/* Free the dict and associated hash entry at all if this was
* the latest client. */
dictDelete(server.pubsub_patterns,pattern);
}
@ -521,8 +518,6 @@ int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type)
dict *d;
dictEntry *de;
dictIterator *di;
listNode *ln;
listIter li;
unsigned int slot = 0;
/* Send to clients listening for that channel */
@ -532,17 +527,16 @@ int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type)
d = *type.serverPubSubChannels(slot);
de = d ? dictFind(d, channel) : NULL;
if (de) {
list *list = dictGetVal(de);
listNode *ln;
listIter li;
listRewind(list,&li);
while ((ln = listNext(&li)) != NULL) {
client *c = ln->value;
dict *clients = dictGetVal(de);
dictEntry *entry;
dictIterator *iter = dictGetSafeIterator(clients);
while ((entry = dictNext(iter)) != NULL) {
client *c = dictGetKey(entry);
addReplyPubsubMessage(c,channel,message,*type.messageBulk);
updateClientMemUsageAndBucket(c);
receivers++;
}
dictReleaseIterator(iter);
}
if (type.shard) {
@ -556,19 +550,21 @@ int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type)
channel = getDecodedObject(channel);
while((de = dictNext(di)) != NULL) {
robj *pattern = dictGetKey(de);
list *clients = dictGetVal(de);
dict *clients = dictGetVal(de);
if (!stringmatchlen((char*)pattern->ptr,
sdslen(pattern->ptr),
(char*)channel->ptr,
sdslen(channel->ptr),0)) continue;
listRewind(clients,&li);
while ((ln = listNext(&li)) != NULL) {
client *c = listNodeValue(ln);
dictEntry *entry;
dictIterator *iter = dictGetSafeIterator(clients);
while ((entry = dictNext(iter)) != NULL) {
client *c = dictGetKey(entry);
addReplyPubsubPatMessage(c,pattern,channel,message);
updateClientMemUsageAndBucket(c);
receivers++;
}
dictReleaseIterator(iter);
}
decrRefCount(channel);
dictReleaseIterator(di);
@ -706,10 +702,10 @@ NULL
addReplyArrayLen(c,(c->argc-2)*2);
for (j = 2; j < c->argc; j++) {
list *l = dictFetchValue(server.pubsub_channels,c->argv[j]);
dict *d = dictFetchValue(server.pubsub_channels, c->argv[j]);
addReplyBulk(c,c->argv[j]);
addReplyLongLong(c,l ? listLength(l) : 0);
addReplyLongLong(c, d ? dictSize(d) : 0);
}
} else if (!strcasecmp(c->argv[1]->ptr,"numpat") && c->argc == 2) {
/* PUBSUB NUMPAT */
@ -727,10 +723,10 @@ NULL
for (j = 2; j < c->argc; j++) {
unsigned int slot = calculateKeySlot(c->argv[j]->ptr);
dict *d = server.pubsubshard_channels[slot];
list *l = d ? dictFetchValue(d, c->argv[j]) : NULL;
dict *clients = d ? dictFetchValue(d, c->argv[j]) : NULL;
addReplyBulk(c,c->argv[j]);
addReplyLongLong(c,l ? listLength(l) : 0);
/* The slot dict may exist without containing this channel, in which case
 * clients is NULL: test clients itself (not d) before taking its size. */
addReplyLongLong(c, clients ? dictSize(clients) : 0);
}
} else {
addReplySubcommandSyntaxError(c);

View File

@ -282,6 +282,12 @@ void dictListDestructor(dict *d, void *val)
listRelease((list*)val);
}
/* Value destructor for dicts whose values are themselves dicts: releases
 * the nested dict stored as the value. The owning dict is not used. */
void dictDictDestructor(dict *d, void *val)
{
    dict *nested = (dict*)val;

    UNUSED(d);
    dictRelease(nested);
}
int dictSdsKeyCompare(dict *d, const void *key1,
const void *key2)
{
@ -351,6 +357,17 @@ uint64_t dictCStrCaseHash(const void *key) {
return dictGenCaseHashFunction((unsigned char*)key, strlen((char*)key));
}
/* Dict hash function for client */
uint64_t dictClientHash(const void *key) {
return ((client *)key)->id;
}
/* Dict compare function for client keys: two keys match when the clients
 * they point to carry the same id. Returns non-zero on a match. */
int dictClientKeyCompare(dict *d, const void *key1, const void *key2) {
    client *lhs = (client *)key1;
    client *rhs = (client *)key2;

    UNUSED(d);
    return lhs->id == rhs->id;
}
/* Dict compare function for null terminated string */
int dictCStrKeyCompare(dict *d, const void *key1, const void *key2) {
int l1,l2;
@ -596,6 +613,18 @@ dictType keylistDictType = {
NULL /* allow to expand */
};
/* Dict type with unencoded redis objects as keys and dicts as values.
 * It is used by PUBSUB to map a channel or pattern to the dict of clients
 * subscribed to it. The value destructor releases the nested dict when
 * an entry is removed. */
dictType objToDictDictType = {
dictObjHash, /* hash function */
NULL, /* key dup */
NULL, /* val dup */
dictObjKeyCompare, /* key compare */
dictObjectDestructor, /* key destructor */
dictDictDestructor, /* val destructor */
NULL /* allow to expand */
};
/* Modules system dictionary type. Keys are module name,
* values are pointer to RedisModule struct. */
dictType modulesDictType = {
@ -655,6 +684,15 @@ dictType sdsHashDictType = {
NULL /* allow to expand */
};
/* Client set dictionary type. Keys are client pointers, hashed and compared
 * by their id field; values are not used (no_value is set), so the dict
 * behaves as a set of clients. */
dictType clientDictType = {
dictClientHash, /* hash function */
NULL, /* key dup */
NULL, /* val dup */
dictClientKeyCompare, /* key compare */
.no_value = 1 /* no values in this dict */
};
int htNeedsResize(dict *dict) {
long long size, used;
@ -2745,8 +2783,8 @@ void initServer(void) {
}
server.rehashing = listCreate();
evictionPoolAlloc(); /* Initialize the LRU keys pool. */
server.pubsub_channels = dictCreate(&keylistDictType);
server.pubsub_patterns = dictCreate(&keylistDictType);
server.pubsub_channels = dictCreate(&objToDictDictType);
server.pubsub_patterns = dictCreate(&objToDictDictType);
server.pubsubshard_channels = zcalloc(sizeof(dict *) * slot_count);
server.shard_channel_count = 0;
server.pubsub_clients = 0;

View File

@ -2499,6 +2499,8 @@ extern dictType hashDictType;
extern dictType stringSetDictType;
extern dictType externalStringType;
extern dictType sdsHashDictType;
extern dictType clientDictType;
extern dictType objToDictDictType;
extern dictType dbExpiresDictType;
extern dictType modulesDictType;
extern dictType sdsReplyDictType;