Fix heap-use-after-free when pubsubshard_channels became NULL (#13038)

After fix for #13033, address sanitizer reports this heap-use-after-free
error. When the pubsubshard_channels dict becomes empty, we will delete
the dict, and the dictReleaseIterator will call dictResetIterator, it
will use the dict so we will trigger the error.

This PR introduced a new struct kvstoreDictIterator to wrap
dictIterator.
Replace the original dict iterator with the new kvstore dict iterator.

---------

Co-authored-by: Oran Agra <oran@redislabs.com>
Co-authored-by: guybe7 <guy.benoish@redislabs.com>
This commit is contained in:
Binbin 2024-02-07 20:53:50 +08:00 committed by GitHub
parent 886b117031
commit 81666a6510
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 64 additions and 31 deletions

View File

@ -917,16 +917,16 @@ void clusterCommand(client *c) {
unsigned int keys_in_slot = countKeysInSlot(slot);
unsigned int numkeys = maxkeys > keys_in_slot ? keys_in_slot : maxkeys;
addReplyArrayLen(c,numkeys);
dictIterator *iter = NULL;
kvstoreDictIterator *kvs_di = NULL;
dictEntry *de = NULL;
iter = kvstoreDictGetIterator(server.db->keys, slot);
kvs_di = kvstoreGetDictIterator(server.db->keys, slot);
for (unsigned int i = 0; i < numkeys; i++) {
de = dictNext(iter);
de = kvstoreDictIteratorNext(kvs_di);
serverAssert(de != NULL);
sds sdskey = dictGetKey(de);
addReplyBulkCBuffer(c, sdskey, sdslen(sdskey));
}
dictReleaseIterator(iter);
kvstoreReleaseDictIterator(kvs_di);
} else if ((!strcasecmp(c->argv[1]->ptr,"slaves") ||
!strcasecmp(c->argv[1]->ptr,"replicas")) && c->argc == 3) {
/* CLUSTER SLAVES <NODE ID> */

View File

@ -5742,10 +5742,10 @@ unsigned int delKeysInSlot(unsigned int hashslot) {
unsigned int j = 0;
dictIterator *iter = NULL;
kvstoreDictIterator *kvs_di = NULL;
dictEntry *de = NULL;
iter = kvstoreDictGetSafeIterator(server.db->keys, hashslot);
while((de = dictNext(iter)) != NULL) {
kvs_di = kvstoreGetDictSafeIterator(server.db->keys, hashslot);
while((de = kvstoreDictIteratorNext(kvs_di)) != NULL) {
enterExecutionUnit(1, 0);
sds sdskey = dictGetKey(de);
robj *key = createStringObject(sdskey, sdslen(sdskey));
@ -5762,7 +5762,7 @@ unsigned int delKeysInSlot(unsigned int hashslot) {
j++;
server.dirty++;
}
dictReleaseIterator(iter);
kvstoreReleaseDictIterator(kvs_di);
return j;
}

View File

@ -800,7 +800,7 @@ void keysCommand(client *c) {
if (server.cluster_enabled && !allkeys) {
pslot = patternHashSlot(pattern, plen);
}
dictIterator *di = NULL;
kvstoreDictIterator *kvs_di = NULL;
kvstoreIterator *kvs_it = NULL;
if (pslot != -1) {
if (!kvstoreDictSize(c->db->keys, pslot)) {
@ -808,12 +808,12 @@ void keysCommand(client *c) {
setDeferredArrayLen(c,replylen,0);
return;
}
di = kvstoreDictGetSafeIterator(c->db->keys, pslot);
kvs_di = kvstoreGetDictSafeIterator(c->db->keys, pslot);
} else {
kvs_it = kvstoreIteratorInit(c->db->keys);
}
robj keyobj;
while ((de = di ? dictNext(di) : kvstoreIteratorNext(kvs_it)) != NULL) {
while ((de = kvs_di ? kvstoreDictIteratorNext(kvs_di) : kvstoreIteratorNext(kvs_it)) != NULL) {
sds key = dictGetKey(de);
if (allkeys || stringmatchlen(pattern,plen,key,sdslen(key),0)) {
@ -826,8 +826,8 @@ void keysCommand(client *c) {
if (c->flags & CLIENT_CLOSE_ASAP)
break;
}
if (di)
dictReleaseIterator(di);
if (kvs_di)
kvstoreReleaseDictIterator(kvs_di);
if (kvs_it)
kvstoreIteratorRelease(kvs_it);
setDeferredArrayLen(c,replylen,numkeys);

View File

@ -71,6 +71,13 @@ struct _kvstoreIterator {
dictIterator di;
};
/* Structure for kvstore dict iterator that allows iterating the corresponding dict. */
struct _kvstoreDictIterator {
kvstore *kvs;
long long didx;
dictIterator di;
};
/* Dict metadata for database, used for record the position in rehashing list. */
typedef struct {
listNode *rehashing_node; /* list node in rehashing list */
@ -80,6 +87,7 @@ typedef struct {
/*** Helpers **********************/
/**********************************/
/* Get the dictionary pointer based on dict-index. */
static dict *kvstoreGetDict(kvstore *kvs, int didx) {
return kvs->dicts[didx];
}
@ -529,7 +537,7 @@ kvstoreIterator *kvstoreIteratorInit(kvstore *kvs) {
return kvs_it;
}
/* Free the dbit returned by dbIteratorInit. */
/* Free the kvs_it returned by kvstoreIteratorInit. */
void kvstoreIteratorRelease(kvstoreIterator *kvs_it) {
dictIterator *iter = &kvs_it->di;
dictResetIterator(iter);
@ -621,16 +629,41 @@ unsigned long kvstoreDictSize(kvstore *kvs, int didx)
return dictSize(d);
}
dictIterator *kvstoreDictGetIterator(kvstore *kvs, int didx)
kvstoreDictIterator *kvstoreGetDictIterator(kvstore *kvs, int didx)
{
dict *d = kvstoreGetDict(kvs, didx);
return dictGetIterator(d);
kvstoreDictIterator *kvs_di = zmalloc(sizeof(*kvs_di));
kvs_di->kvs = kvs;
kvs_di->didx = didx;
dictInitIterator(&kvs_di->di, kvstoreGetDict(kvs, didx));
return kvs_di;
}
dictIterator *kvstoreDictGetSafeIterator(kvstore *kvs, int didx)
kvstoreDictIterator *kvstoreGetDictSafeIterator(kvstore *kvs, int didx)
{
dict *d = kvstoreGetDict(kvs, didx);
return dictGetSafeIterator(d);
kvstoreDictIterator *kvs_di = zmalloc(sizeof(*kvs_di));
kvs_di->kvs = kvs;
kvs_di->didx = didx;
dictInitSafeIterator(&kvs_di->di, kvstoreGetDict(kvs, didx));
return kvs_di;
}
/* Free the kvs_di returned by kvstoreGetDictIterator and kvstoreGetDictSafeIterator. */
void kvstoreReleaseDictIterator(kvstoreDictIterator *kvs_di)
{
/* The dict may be deleted during the iteration process, so here need to check for NULL. */
if (kvstoreGetDict(kvs_di->kvs, kvs_di->didx)) dictResetIterator(&kvs_di->di);
zfree(kvs_di);
}
/* Get the next element of the dict through kvstoreDictIterator and dictNext. */
dictEntry *kvstoreDictIteratorNext(kvstoreDictIterator *kvs_di)
{
/* The dict may be deleted during the iteration process, so here need to check for NULL. */
dict *d = kvstoreGetDict(kvs_di->kvs, kvs_di->didx);
if (!d) return NULL;
return dictNext(&kvs_di->di);
}
dictEntry *kvstoreDictGetRandomKey(kvstore *kvs, int didx)

View File

@ -6,6 +6,7 @@
typedef struct _kvstore kvstore;
typedef struct _kvstoreIterator kvstoreIterator;
typedef struct _kvstoreDictIterator kvstoreDictIterator;
typedef int (kvstoreScanShouldSkipDict)(dict *d);
typedef int (kvstoreExpandShouldSkipDictIndex)(int didx);
@ -45,8 +46,10 @@ uint64_t kvstoreIncrementallyRehash(kvstore *kvs, uint64_t threshold_ms);
/* Specific dict access by dict-index */
unsigned long kvstoreDictSize(kvstore *kvs, int didx);
dictIterator *kvstoreDictGetIterator(kvstore *kvs, int didx);
dictIterator *kvstoreDictGetSafeIterator(kvstore *kvs, int didx);
kvstoreDictIterator *kvstoreGetDictIterator(kvstore *kvs, int didx);
kvstoreDictIterator *kvstoreGetDictSafeIterator(kvstore *kvs, int didx);
void kvstoreReleaseDictIterator(kvstoreDictIterator *kvs_id);
dictEntry *kvstoreDictIteratorNext(kvstoreDictIterator *kvs_di);
dictEntry *kvstoreDictGetRandomKey(kvstore *kvs, int didx);
dictEntry *kvstoreDictGetFairRandomKey(kvstore *kvs, int didx);
dictEntry *kvstoreDictFindEntryByPtrAndHash(kvstore *kvs, int didx, const void *oldptr, uint64_t hash);

View File

@ -329,9 +329,9 @@ void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) {
if (!kvstoreDictSize(server.pubsubshard_channels, slot))
return;
dictIterator *di = kvstoreDictGetSafeIterator(server.pubsubshard_channels, slot);
kvstoreDictIterator *kvs_di = kvstoreGetDictSafeIterator(server.pubsubshard_channels, slot);
dictEntry *de;
while ((de = dictNext(di)) != NULL) {
while ((de = kvstoreDictIteratorNext(kvs_di)) != NULL) {
robj *channel = dictGetKey(de);
dict *clients = dictGetVal(de);
/* For each client subscribed to the channel, unsubscribe it. */
@ -350,11 +350,8 @@ void pubsubShardUnsubscribeAllChannelsInSlot(unsigned int slot) {
}
dictReleaseIterator(iter);
kvstoreDictDelete(server.pubsubshard_channels, slot, channel);
/* After the dict becomes empty, the dict will be deleted.
* We break out without calling dictNext. */
if (!kvstoreDictSize(server.pubsubshard_channels, slot)) break;
}
dictReleaseIterator(di);
kvstoreReleaseDictIterator(kvs_di);
}
/* Subscribe a client to a pattern. Returns 1 if the operation succeeded, or 0 if the client was already subscribed to that pattern. */
@ -697,9 +694,9 @@ void channelList(client *c, sds pat, kvstore *pubsub_channels) {
for (unsigned int i = 0; i < slot_cnt; i++) {
if (!kvstoreDictSize(pubsub_channels, i))
continue;
dictIterator *di = kvstoreDictGetIterator(pubsub_channels, i);
kvstoreDictIterator *kvs_di = kvstoreGetDictIterator(pubsub_channels, i);
dictEntry *de;
while((de = dictNext(di)) != NULL) {
while((de = kvstoreDictIteratorNext(kvs_di)) != NULL) {
robj *cobj = dictGetKey(de);
sds channel = cobj->ptr;
@ -710,7 +707,7 @@ void channelList(client *c, sds pat, kvstore *pubsub_channels) {
mblen++;
}
}
dictReleaseIterator(di);
kvstoreReleaseDictIterator(kvs_di);
}
setDeferredArrayLen(c,replylen,mblen);
}