Extend rax usage by allowing any long long value (#12837)

The raxFind implementation uses a special pointer value (the address of
a static string) as the "not found" value. It works as long as actual
pointers were used. However we've seen usages where long long,
non-pointer values have been used. It creates a risk that one of the
long long value precisely is the address of the special "not found"
value. This commit changes raxFind to return 1 or 0 to indicate
elementhood, and take in a new void **value to optionally return the
associated value.

By extension, this also allow the RedisModule_DictSet/Replace operations
to also safely insert integers instead of just pointers.
This commit is contained in:
Guillaume Koenig 2023-12-14 17:50:18 -05:00 committed by GitHub
parent e95a5d4831
commit 967fb3c6e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 97 additions and 73 deletions

View File

@ -437,7 +437,7 @@ aclSelector *ACLUserGetRootSelector(user *u) {
*
* If the user with such name already exists NULL is returned. */
user *ACLCreateUser(const char *name, size_t namelen) {
if (raxFind(Users,(unsigned char*)name,namelen) != raxNotFound) return NULL;
if (raxFind(Users,(unsigned char*)name,namelen,NULL)) return NULL;
user *u = zmalloc(sizeof(*u));
u->name = sdsnewlen(name,namelen);
u->flags = USER_FLAG_DISABLED;
@ -1553,8 +1553,8 @@ unsigned long ACLGetCommandID(sds cmdname) {
sds lowername = sdsdup(cmdname);
sdstolower(lowername);
if (commandId == NULL) commandId = raxNew();
void *id = raxFind(commandId,(unsigned char*)lowername,sdslen(lowername));
if (id != raxNotFound) {
void *id;
if (raxFind(commandId,(unsigned char*)lowername,sdslen(lowername),&id)) {
sdsfree(lowername);
return (unsigned long)id;
}
@ -1585,8 +1585,8 @@ void ACLClearCommandID(void) {
/* Return an username by its name, or NULL if the user does not exist. */
user *ACLGetUserByName(const char *name, size_t namelen) {
void *myuser = raxFind(Users,(unsigned char*)name,namelen);
if (myuser == raxNotFound) return NULL;
void *myuser = NULL;
raxFind(Users,(unsigned char*)name,namelen,&myuser);
return myuser;
}

View File

@ -9130,7 +9130,7 @@ RedisModuleTimerID RM_CreateTimer(RedisModuleCtx *ctx, mstime_t period, RedisMod
while(1) {
key = htonu64(expiretime);
if (raxFind(Timers, (unsigned char*)&key,sizeof(key)) == raxNotFound) {
if (!raxFind(Timers, (unsigned char*)&key,sizeof(key),NULL)) {
raxInsert(Timers,(unsigned char*)&key,sizeof(key),timer,NULL);
break;
} else {
@ -9169,8 +9169,11 @@ RedisModuleTimerID RM_CreateTimer(RedisModuleCtx *ctx, mstime_t period, RedisMod
* If not NULL, the data pointer is set to the value of the data argument when
* the timer was created. */
int RM_StopTimer(RedisModuleCtx *ctx, RedisModuleTimerID id, void **data) {
RedisModuleTimer *timer = raxFind(Timers,(unsigned char*)&id,sizeof(id));
if (timer == raxNotFound || timer->module != ctx->module)
void *result;
if (!raxFind(Timers,(unsigned char*)&id,sizeof(id),&result))
return REDISMODULE_ERR;
RedisModuleTimer *timer = result;
if (timer->module != ctx->module)
return REDISMODULE_ERR;
if (data) *data = timer->data;
raxRemove(Timers,(unsigned char*)&id,sizeof(id),NULL);
@ -9185,8 +9188,11 @@ int RM_StopTimer(RedisModuleCtx *ctx, RedisModuleTimerID id, void **data) {
* REDISMODULE_OK is returned. The arguments remaining or data can be NULL if
* the caller does not need certain information. */
int RM_GetTimerInfo(RedisModuleCtx *ctx, RedisModuleTimerID id, uint64_t *remaining, void **data) {
RedisModuleTimer *timer = raxFind(Timers,(unsigned char*)&id,sizeof(id));
if (timer == raxNotFound || timer->module != ctx->module)
void *result;
if (!raxFind(Timers,(unsigned char*)&id,sizeof(id),&result))
return REDISMODULE_ERR;
RedisModuleTimer *timer = result;
if (timer->module != ctx->module)
return REDISMODULE_ERR;
if (remaining) {
int64_t rem = ntohu64(id)-ustime();
@ -9954,9 +9960,10 @@ int RM_DictReplace(RedisModuleDict *d, RedisModuleString *key, void *ptr) {
* be set by reference to 1 if the key does not exist, or to 0 if the key
* exists. */
void *RM_DictGetC(RedisModuleDict *d, void *key, size_t keylen, int *nokey) {
void *res = raxFind(d->rax,key,keylen);
if (nokey) *nokey = (res == raxNotFound);
return (res == raxNotFound) ? NULL : res;
void *res = NULL;
int found = raxFind(d->rax,key,keylen,&res);
if (nokey) *nokey = !found;
return res;
}
/* Like RedisModule_DictGetC() but takes the key as a RedisModuleString. */
@ -10378,8 +10385,10 @@ void RM_FreeServerInfo(RedisModuleCtx *ctx, RedisModuleServerInfoData *data) {
* mechanism to release the returned string. Return value will be NULL if the
* field was not found. */
RedisModuleString *RM_ServerInfoGetField(RedisModuleCtx *ctx, RedisModuleServerInfoData *data, const char* field) {
sds val = raxFind(data->rax, (unsigned char *)field, strlen(field));
if (val == raxNotFound) return NULL;
void *result;
if (!raxFind(data->rax, (unsigned char *)field, strlen(field), &result))
return NULL;
sds val = result;
RedisModuleString *o = createStringObject(val,sdslen(val));
if (ctx != NULL) autoMemoryAdd(ctx,REDISMODULE_AM_STRING,o);
return o;
@ -10387,9 +10396,9 @@ RedisModuleString *RM_ServerInfoGetField(RedisModuleCtx *ctx, RedisModuleServerI
/* Similar to RM_ServerInfoGetField, but returns a char* which should not be freed but the caller. */
const char *RM_ServerInfoGetFieldC(RedisModuleServerInfoData *data, const char* field) {
sds val = raxFind(data->rax, (unsigned char *)field, strlen(field));
if (val == raxNotFound) return NULL;
return val;
void *result = NULL;
raxFind(data->rax, (unsigned char *)field, strlen(field), &result);
return result;
}
/* Get the value of a field from data collected with RM_GetServerInfo(). If the
@ -10397,11 +10406,12 @@ const char *RM_ServerInfoGetFieldC(RedisModuleServerInfoData *data, const char*
* 0, and the optional out_err argument will be set to REDISMODULE_ERR. */
long long RM_ServerInfoGetFieldSigned(RedisModuleServerInfoData *data, const char* field, int *out_err) {
long long ll;
sds val = raxFind(data->rax, (unsigned char *)field, strlen(field));
if (val == raxNotFound) {
void *result;
if (!raxFind(data->rax, (unsigned char *)field, strlen(field), &result)) {
if (out_err) *out_err = REDISMODULE_ERR;
return 0;
}
sds val = result;
if (!string2ll(val,sdslen(val),&ll)) {
if (out_err) *out_err = REDISMODULE_ERR;
return 0;
@ -10415,11 +10425,12 @@ long long RM_ServerInfoGetFieldSigned(RedisModuleServerInfoData *data, const cha
* 0, and the optional out_err argument will be set to REDISMODULE_ERR. */
unsigned long long RM_ServerInfoGetFieldUnsigned(RedisModuleServerInfoData *data, const char* field, int *out_err) {
unsigned long long ll;
sds val = raxFind(data->rax, (unsigned char *)field, strlen(field));
if (val == raxNotFound) {
void *result;
if (!raxFind(data->rax, (unsigned char *)field, strlen(field), &result)) {
if (out_err) *out_err = REDISMODULE_ERR;
return 0;
}
sds val = result;
if (!string2ull(val,&ll)) {
if (out_err) *out_err = REDISMODULE_ERR;
return 0;
@ -10433,11 +10444,12 @@ unsigned long long RM_ServerInfoGetFieldUnsigned(RedisModuleServerInfoData *data
* optional out_err argument will be set to REDISMODULE_ERR. */
double RM_ServerInfoGetFieldDouble(RedisModuleServerInfoData *data, const char* field, int *out_err) {
double dbl;
sds val = raxFind(data->rax, (unsigned char *)field, strlen(field));
if (val == raxNotFound) {
void *result;
if (!raxFind(data->rax, (unsigned char *)field, strlen(field), &result)) {
if (out_err) *out_err = REDISMODULE_ERR;
return 0;
}
sds val = result;
if (!string2d(val,sdslen(val),&dbl)) {
if (out_err) *out_err = REDISMODULE_ERR;
return 0;

View File

@ -1812,8 +1812,9 @@ int freeClientsInAsyncFreeQueue(void) {
* are not registered clients. */
client *lookupClientByID(uint64_t id) {
id = htonu64(id);
client *c = raxFind(server.clients_index,(unsigned char*)&id,sizeof(id));
return (c == raxNotFound) ? NULL : c;
void *c = NULL;
raxFind(server.clients_index,(unsigned char*)&id,sizeof(id),&c);
return c;
}
/* This function should be called from _writeToClient when the reply list is not empty,

View File

@ -44,11 +44,6 @@
#include RAX_MALLOC_INCLUDE
/* This is a special pointer that is guaranteed to never have the same value
* of a radix tree node. It's used in order to report "not found" error without
* requiring the function to have multiple return values. */
void *raxNotFound = (void*)"rax-not-found-pointer";
/* -------------------------------- Debugging ------------------------------ */
void raxDebugShowNode(const char *msg, raxNode *n);
@ -912,18 +907,19 @@ int raxTryInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old)
return raxGenericInsert(rax,s,len,data,old,0);
}
/* Find a key in the rax, returns raxNotFound special void pointer value
* if the item was not found, otherwise the value associated with the
* item is returned. */
void *raxFind(rax *rax, unsigned char *s, size_t len) {
/* Find a key in the rax: return 1 if the item is found, 0 otherwise.
* If there is an item and 'value' is passed in a non-NULL pointer,
* the value associated with the item is set at that address. */
int raxFind(rax *rax, unsigned char *s, size_t len, void **value) {
raxNode *h;
debugf("### Lookup: %.*s\n", (int)len, s);
int splitpos = 0;
size_t i = raxLowWalk(rax,s,len,&h,NULL,&splitpos,NULL);
if (i != len || (h->iscompr && splitpos != 0) || !h->iskey)
return raxNotFound;
return raxGetData(h);
return 0;
if (value != NULL) *value = raxGetData(h);
return 1;
}
/* Return the memory address where the 'parent' node stores the specified

View File

@ -185,15 +185,12 @@ typedef struct raxIterator {
raxNodeCallback node_cb; /* Optional node callback. Normally set to NULL. */
} raxIterator;
/* A special pointer returned for not found items. */
extern void *raxNotFound;
/* Exported API. */
rax *raxNew(void);
int raxInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old);
int raxTryInsert(rax *rax, unsigned char *s, size_t len, void *data, void **old);
int raxRemove(rax *rax, unsigned char *s, size_t len, void **old);
void *raxFind(rax *rax, unsigned char *s, size_t len);
int raxFind(rax *rax, unsigned char *s, size_t len, void **value);
void raxFree(rax *rax);
void raxFreeWithCallback(rax *rax, void (*free_callback)(void*));
void raxStart(raxIterator *it, rax *rt);

View File

@ -2751,13 +2751,14 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key, int dbid, int *error) {
decrRefCount(o);
return NULL;
}
streamNACK *nack = raxFind(cgroup->pel,rawid,sizeof(rawid));
if (nack == raxNotFound) {
void *result;
if (!raxFind(cgroup->pel,rawid,sizeof(rawid),&result)) {
rdbReportCorruptRDB("Consumer entry not found in "
"group global PEL");
decrRefCount(o);
return NULL;
}
streamNACK *nack = result;
/* Set the NACK consumer, that was left to NULL when
* loading the global PEL. Then set the same shared

View File

@ -4279,13 +4279,15 @@ int processCommand(client *c) {
/* ====================== Error lookup and execution ===================== */
void incrementErrorCount(const char *fullerr, size_t namelen) {
struct redisError *error = raxFind(server.errors,(unsigned char*)fullerr,namelen);
if (error == raxNotFound) {
error = zmalloc(sizeof(*error));
error->count = 0;
void *result;
if (!raxFind(server.errors,(unsigned char*)fullerr,namelen,&result)) {
struct redisError *error = zmalloc(sizeof(*error));
error->count = 1;
raxInsert(server.errors,(unsigned char*)fullerr,namelen,error,NULL);
} else {
struct redisError *error = result;
error->count++;
}
error->count++;
}
/*================================== Shutdown =============================== */

View File

@ -242,10 +242,12 @@ robj *streamDup(robj *o) {
raxStart(&ri_cpel, consumer->pel);
raxSeek(&ri_cpel, "^", NULL, 0);
while (raxNext(&ri_cpel)) {
streamNACK *new_nack = raxFind(new_cg->pel,ri_cpel.key,sizeof(streamID));
void *result;
int found = raxFind(new_cg->pel,ri_cpel.key,sizeof(streamID),&result);
serverAssert(new_nack != raxNotFound);
serverAssert(found);
streamNACK *new_nack = result;
new_nack->consumer = new_consumer;
raxInsert(new_consumer->pel,ri_cpel.key,sizeof(streamID),new_nack,NULL);
}
@ -1760,8 +1762,10 @@ size_t streamReplyWithRange(client *c, stream *s, streamID *start, streamID *end
* or update it if the consumer is the same as before. */
if (group_inserted == 0) {
streamFreeNACK(nack);
nack = raxFind(group->pel,buf,sizeof(buf));
serverAssert(nack != raxNotFound);
void *result;
int found = raxFind(group->pel,buf,sizeof(buf),&result);
serverAssert(found);
nack = result;
raxRemove(nack->consumer->pel,buf,sizeof(buf),NULL);
/* Update the consumer and NACK metadata. */
nack->consumer = consumer;
@ -2473,7 +2477,7 @@ void streamFreeConsumer(streamConsumer *sc) {
* consumer group is returned. */
streamCG *streamCreateCG(stream *s, char *name, size_t namelen, streamID *id, long long entries_read) {
if (s->cgroups == NULL) s->cgroups = raxNew();
if (raxFind(s->cgroups,(unsigned char*)name,namelen) != raxNotFound)
if (raxFind(s->cgroups,(unsigned char*)name,namelen,NULL))
return NULL;
streamCG *cg = zmalloc(sizeof(*cg));
@ -2496,9 +2500,9 @@ void streamFreeCG(streamCG *cg) {
* pointer, otherwise if there is no such group, NULL is returned. */
streamCG *streamLookupCG(stream *s, sds groupname) {
if (s->cgroups == NULL) return NULL;
streamCG *cg = raxFind(s->cgroups,(unsigned char*)groupname,
sdslen(groupname));
return (cg == raxNotFound) ? NULL : cg;
void *cg = NULL;
raxFind(s->cgroups,(unsigned char*)groupname,sdslen(groupname),&cg);
return cg;
}
/* Create a consumer with the specified name in the group 'cg' and return.
@ -2528,9 +2532,8 @@ streamConsumer *streamCreateConsumer(streamCG *cg, sds name, robj *key, int dbid
/* Lookup the consumer with the specified name in the group 'cg'. */
streamConsumer *streamLookupConsumer(streamCG *cg, sds name) {
if (cg == NULL) return NULL;
streamConsumer *consumer = raxFind(cg->consumers,(unsigned char*)name,
sdslen(name));
if (consumer == raxNotFound) return NULL;
void *consumer = NULL;
raxFind(cg->consumers,(unsigned char*)name,sdslen(name),&consumer);
return consumer;
}
@ -2844,8 +2847,9 @@ void xackCommand(client *c) {
/* Lookup the ID in the group PEL: it will have a reference to the
* NACK structure that will have a reference to the consumer, so that
* we are able to remove the entry from both PELs. */
streamNACK *nack = raxFind(group->pel,buf,sizeof(buf));
if (nack != raxNotFound) {
void *result;
if (raxFind(group->pel,buf,sizeof(buf),&result)) {
streamNACK *nack = result;
raxRemove(group->pel,buf,sizeof(buf),NULL);
raxRemove(nack->consumer->pel,buf,sizeof(buf),NULL);
streamFreeNACK(nack);
@ -3224,12 +3228,14 @@ void xclaimCommand(client *c) {
streamEncodeID(buf,&id);
/* Lookup the ID in the group PEL. */
streamNACK *nack = raxFind(group->pel,buf,sizeof(buf));
void *result = NULL;
raxFind(group->pel,buf,sizeof(buf),&result);
streamNACK *nack = result;
/* Item must exist for us to transfer it to another consumer. */
if (!streamEntryExists(o->ptr,&id)) {
/* Clear this entry from the PEL, it no longer exists */
if (nack != raxNotFound) {
if (nack != NULL) {
/* Propagate this change (we are going to delete the NACK). */
streamPropagateXCLAIM(c,c->argv[1],group,c->argv[2],c->argv[j],nack);
propagate_last_id = 0; /* Will be propagated by XCLAIM itself. */
@ -3247,13 +3253,13 @@ void xclaimCommand(client *c) {
* entry in the PEL from scratch, so that XCLAIM can also
* be used to create entries in the PEL. Useful for AOF
* and replication of consumer groups. */
if (force && nack == raxNotFound) {
if (force && nack == NULL) {
/* Create the NACK. */
nack = streamCreateNACK(NULL);
raxInsert(group->pel,buf,sizeof(buf),nack,NULL);
}
if (nack != raxNotFound) {
if (nack != NULL) {
/* We need to check if the minimum idle time requested
* by the caller is satisfied by this entry.
*

View File

@ -72,8 +72,10 @@ void disableTracking(client *c) {
raxStart(&ri,c->client_tracking_prefixes);
raxSeek(&ri,"^",NULL,0);
while(raxNext(&ri)) {
bcastState *bs = raxFind(PrefixTable,ri.key,ri.key_len);
serverAssert(bs != raxNotFound);
void *result;
int found = raxFind(PrefixTable,ri.key,ri.key_len,&result);
serverAssert(found);
bcastState *bs = result;
raxRemove(bs->clients,(unsigned char*)&c,sizeof(c),NULL);
/* Was it the last client? Remove the prefix from the
* table. */
@ -153,14 +155,17 @@ int checkPrefixCollisionsOrReply(client *c, robj **prefixes, size_t numprefix) {
/* Set the client 'c' to track the prefix 'prefix'. If the client 'c' is
* already registered for the specified prefix, no operation is performed. */
void enableBcastTrackingForPrefix(client *c, char *prefix, size_t plen) {
bcastState *bs = raxFind(PrefixTable,(unsigned char*)prefix,plen);
void *result;
bcastState *bs;
/* If this is the first client subscribing to such prefix, create
* the prefix in the table. */
if (bs == raxNotFound) {
if (!raxFind(PrefixTable,(unsigned char*)prefix,plen,&result)) {
bs = zmalloc(sizeof(*bs));
bs->keys = raxNew();
bs->clients = raxNew();
raxInsert(PrefixTable,(unsigned char*)prefix,plen,bs,NULL);
} else {
bs = result;
}
if (raxTryInsert(bs->clients,(unsigned char*)&c,sizeof(c),NULL,NULL)) {
if (c->client_tracking_prefixes == NULL)
@ -240,12 +245,15 @@ void trackingRememberKeys(client *tracking, client *executing) {
for(int j = 0; j < numkeys; j++) {
int idx = keys[j].pos;
sds sdskey = executing->argv[idx]->ptr;
rax *ids = raxFind(TrackingTable,(unsigned char*)sdskey,sdslen(sdskey));
if (ids == raxNotFound) {
void *result;
rax *ids;
if (!raxFind(TrackingTable,(unsigned char*)sdskey,sdslen(sdskey),&result)) {
ids = raxNew();
int inserted = raxTryInsert(TrackingTable,(unsigned char*)sdskey,
sdslen(sdskey),ids, NULL);
serverAssert(inserted == 1);
} else {
ids = result;
}
if (raxTryInsert(ids,(unsigned char*)&tracking->id,sizeof(tracking->id),NULL,NULL))
TrackingTableTotalItems++;
@ -372,8 +380,9 @@ void trackingInvalidateKey(client *c, robj *keyobj, int bcast) {
if (bcast && raxSize(PrefixTable) > 0)
trackingRememberKeyToBroadcast(c,(char *)key,keylen);
rax *ids = raxFind(TrackingTable,key,keylen);
if (ids == raxNotFound) return;
void *result;
if (!raxFind(TrackingTable,key,keylen,&result)) return;
rax *ids = result;
raxIterator ri;
raxStart(&ri,ids);