redis/src/multi.c

/*
* Copyright (c) 2009-Present, Redis Ltd.
* All rights reserved.
*
* Licensed under your choice of the Redis Source Available License 2.0
* (RSALv2) or the Server Side Public License v1 (SSPLv1).
*/
#include "server.h"
/* ================================ MULTI/EXEC ============================== */
/* Client state initialization for MULTI/EXEC */
void initClientMultiState(client *c) {
c->mstate.commands = NULL;
c->mstate.count = 0;
c->mstate.cmd_flags = 0;
c->mstate.cmd_inv_flags = 0;
c->mstate.argv_len_sums = 0;
c->mstate.alloc_count = 0;
}
&#32;
/* Release all the resources associated with MULTI/EXEC state */
void freeClientMultiState(client *c) {
    int j;

    for (j = 0; j < c->mstate.count; j++) {
        int i;
        multiCmd *mc = c->mstate.commands+j;

        for (i = 0; i < mc->argc; i++)
            decrRefCount(mc->argv[i]);
        zfree(mc->argv);
    }
    zfree(c->mstate.commands);
}
&#32;
/* Add a new command into the MULTI commands queue */
void queueMultiCommand(client *c, uint64_t cmd_flags) {
    multiCmd *mc;

    /* No sense wasting memory if the transaction has already been aborted.
     * This is useful when the client sends commands in a pipeline, or doesn't
     * bother to read previous responses and didn't notice that the MULTI was
     * already aborted. */
    if (c->flags & (CLIENT_DIRTY_CAS|CLIENT_DIRTY_EXEC))
        return;
    if (c->mstate.count == 0) {
        /* If a client is using MULTI/EXEC, assume it will execute at least
         * two commands, so create the queue with room for two by default. */
        c->mstate.commands = zmalloc(sizeof(multiCmd)*2);
        c->mstate.alloc_count = 2;
    }
    if (c->mstate.count == c->mstate.alloc_count) {
        c->mstate.alloc_count = c->mstate.alloc_count < INT_MAX/2 ? c->mstate.alloc_count*2 : INT_MAX;
        c->mstate.commands = zrealloc(c->mstate.commands, sizeof(multiCmd)*(c->mstate.alloc_count));
    }
    mc = c->mstate.commands+c->mstate.count;
    mc->cmd = c->cmd;
    mc->argc = c->argc;
    mc->argv = c->argv;
    mc->argv_len = c->argv_len;
    c->mstate.count++;
    c->mstate.cmd_flags |= cmd_flags;
    c->mstate.cmd_inv_flags |= ~cmd_flags;
    c->mstate.argv_len_sums += c->argv_len_sum + sizeof(robj*)*c->argc;

    /* Reset the client's args since we copied them into the mstate and shouldn't
     * reference them from c anymore. */
    c->argv = NULL;
    c->argc = 0;
    c->argv_len_sum = 0;
    c->argv_len = 0;
}
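&#32;
/* Abort the current transaction: release all the queued commands, reset the
 * MULTI/EXEC client state and flags, and unwatch all the WATCHed keys. */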
void discardTransaction(client *c) {
    freeClientMultiState(c);
    initClientMultiState(c);
    c->flags &= ~(CLIENT_MULTI|CLIENT_DIRTY_CAS|CLIENT_DIRTY_EXEC);
    unwatchAllKeys(c);
}
&#32;
/* Flag the transaction as DIRTY_EXEC so that EXEC will fail.
 * Should be called every time there is an error while queueing a command. */
void flagTransaction(client *c) {
    if (c->flags & CLIENT_MULTI)
        c->flags |= CLIENT_DIRTY_EXEC;
}
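&#32;
/* MULTI command implementation. Puts the client into transaction mode, so
 * that the following commands are queued instead of executed, until EXEC or
 * DISCARD is called. */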
void multiCommand(client *c) {
    if (c->flags & CLIENT_MULTI) {
        addReplyError(c,"MULTI calls can not be nested");
        return;
    }
    c->flags |= CLIENT_MULTI;

    addReply(c,shared.ok);
}
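&#32;
/* DISCARD command implementation. Drops the queued commands and leaves the
 * MULTI state, replying with an error if no transaction is in progress. */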
void discardCommand(client *c) {
    if (!(c->flags & CLIENT_MULTI)) {
        addReplyError(c,"DISCARD without MULTI");
        return;
    }
    discardTransaction(c);
    addReply(c,shared.ok);
}
&#32;
/* Aborts a transaction, with a specific error message.
 * The transaction is always aborted with -EXECABORT so that the client knows
 * the server exited the multi state, but the actual reason for the abort is
 * included too.
 * Note: 'error' may or may not end with \r\n. See addReplyErrorFormat. */
void execCommandAbort(client *c, sds error) {
    discardTransaction(c);

    if (error[0] == '-') error++;
    addReplyErrorFormat(c, "-EXECABORT Transaction discarded because of: %s", error);

    /* Send EXEC to clients waiting for data from MONITOR. We already sent a
     * MULTI, and didn't send any of the queued commands; now we just send
     * EXEC so it is clear that the transaction is over. */
    replicationFeedMonitors(c,server.monitors,c->db->id,c->argv,c->argc);
}
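&#32;
/* EXEC command implementation. Runs all the commands queued since MULTI and
 * replies with an array of their results, or aborts the transaction if a
 * WATCHed key was touched or an error occurred while queueing. */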
void execCommand(client *c) {
    int j;
    robj **orig_argv;
    int orig_argc, orig_argv_len;
    struct redisCommand *orig_cmd;

    if (!(c->flags & CLIENT_MULTI)) {
        addReplyError(c,"EXEC without MULTI");
        return;
    }

    /* EXEC with an expired watched key is disallowed. */
    if (isWatchedKeyExpired(c)) {
        c->flags |= (CLIENT_DIRTY_CAS);
    }

    /* Check if we need to abort the EXEC because:
     * 1) Some WATCHed key was touched.
     * 2) There was a previous error while queueing commands.
     * A failed EXEC in the first case returns a multi bulk nil object
     * (technically it is not an error but a special behavior), while
     * in the second an EXECABORT error is returned. */
    if (c->flags & (CLIENT_DIRTY_CAS | CLIENT_DIRTY_EXEC)) {
        if (c->flags & CLIENT_DIRTY_EXEC) {
            addReplyErrorObject(c, shared.execaborterr);
        } else {
            addReply(c, shared.nullarray[c->resp]);
        }

        discardTransaction(c);
        return;
    }

    uint64_t old_flags = c->flags;

    /* we do not want to allow blocking commands inside multi */
    c->flags |= CLIENT_DENY_BLOCKING;

    /* Exec all the queued commands */
    unwatchAllKeys(c); /* Unwatch ASAP otherwise we'll waste CPU cycles */

    server.in_exec = 1;

    orig_argv = c->argv;
    orig_argv_len = c->argv_len;
    orig_argc = c->argc;
    orig_cmd = c->cmd;
    addReplyArrayLen(c,c->mstate.count);
    for (j = 0; j < c->mstate.count; j++) {
        c->argc = c->mstate.commands[j].argc;
        c->argv = c->mstate.commands[j].argv;
        c->argv_len = c->mstate.commands[j].argv_len;
        c->cmd = c->realcmd = c->mstate.commands[j].cmd;

        /* ACL permissions are also checked at the time of execution in case
         * they were changed after the commands were queued. */
        int acl_errpos;
        int acl_retval = ACLCheckAllPerm(c,&acl_errpos);
        if (acl_retval != ACL_OK) {
            char *reason;
            switch (acl_retval) {
            case ACL_DENIED_CMD:
                reason = "no permission to execute the command or subcommand";
                break;
            case ACL_DENIED_KEY:
                reason = "no permission to touch the specified keys";
                break;
            case ACL_DENIED_CHANNEL:
                reason = "no permission to access one of the channels used "
                         "as arguments";
                break;
            default:
                reason = "no permission";
                break;
            }
            addACLLogEntry(c,acl_retval,ACL_LOG_CTX_MULTI,acl_errpos,NULL,NULL);
            addReplyErrorFormat(c,
                "-NOPERM ACLs rules changed between the moment the "
                "transaction was accumulated and the EXEC call. "
                "This command is no longer allowed for the "
                "following reason: %s", reason);
        } else {
            if (c->id == CLIENT_ID_AOF)
                call(c,CMD_CALL_NONE);
            else
                call(c,CMD_CALL_FULL);

            serverAssert((c->flags & CLIENT_BLOCKED) == 0);
        }

        /* Commands may alter argc/argv, restore mstate. */
        c->mstate.commands[j].argc = c->argc;
        c->mstate.commands[j].argv = c->argv;
        c->mstate.commands[j].argv_len = c->argv_len;
        c->mstate.commands[j].cmd = c->cmd;
    }

    // restore old DENY_BLOCKING value
    if (!(old_flags & CLIENT_DENY_BLOCKING))
        c->flags &= ~CLIENT_DENY_BLOCKING;

    c->argv = orig_argv;
    c->argv_len = orig_argv_len;
    c->argc = orig_argc;
    c->cmd = c->realcmd = orig_cmd;
    discardTransaction(c);

    server.in_exec = 0;
}
&#32;
/* ===================== WATCH (CAS alike for MULTI/EXEC) ===================
 *
 * The implementation uses a per-DB hash table mapping keys to lists of clients
 * WATCHing those keys, so that given a key that is going to be modified
 * we can mark all the associated clients as dirty.
 *
 * Also every client contains a list of WATCHed keys, so that it's possible to
 * unwatch such keys when the client is freed or when UNWATCH is called. */
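&#32;
/* For illustration, this is the optimistic-locking pattern that WATCH enables,
 * as seen from a client (pseudocode, not code in this file):
 *
 *     WATCH mykey
 *     val = GET mykey
 *     MULTI
 *     SET mykey <modified val>
 *     EXEC
 *
 * If mykey is touched by another client between WATCH and EXEC, the watching
 * client is flagged CLIENT_DIRTY_CAS and EXEC replies with a null array
 * instead of executing the queued commands. */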
&#32;
/* The watchedKey struct is included in two lists: the client->watched_keys list,
 * and the db->watched_keys dict (each value in that dict is a list of watchedKey structs).
 * The list in the client struct is a plain list, where each node's value is a pointer to a watchedKey.
 * The lists in db->watched_keys are different: the listNode member that's embedded in this struct
 * is the node of that list, and the value inside that listNode is a pointer back to the list itself,
 * so we can use struct member offset math to get from the listNode to the watchedKey struct.
 * This is done to avoid the need for listSearchKey and dictFind when we remove from the list. */
typedef struct watchedKey {
    listNode node;
    robj *key;
    redisDb *db;
    client *client;
    unsigned expired:1; /* Flag that we're watching an already expired key. */
} watchedKey;
&#32;
/* Attach a watchedKey to the list of clients watching that key. */
static inline void watchedKeyLinkToClients(list *clients, watchedKey *wk) {
    wk->node.value = clients; /* Point the value back to the list */
    listLinkNodeTail(clients, &wk->node); /* Link the embedded node */
}

/* Get the list of clients watching that key. */
static inline list *watchedKeyGetClients(watchedKey *wk) {
    return listNodeValue(&wk->node); /* embedded node->value points back to the list */
}

/* Get the node with wk->client in the list of clients watching that key. Actually it
 * is just the embedded node. */
static inline listNode *watchedKeyGetClientNode(watchedKey *wk) {
    return &wk->node;
}
&#32;
/* Watch for the specified key */
void watchForKey(client *c, robj *key) {
    list *clients = NULL;
    listIter li;
    listNode *ln;
    watchedKey *wk;

    if (listLength(c->watched_keys) == 0) server.watching_clients++;
    /* Check if we are already watching for this key */
    listRewind(c->watched_keys,&li);
    while((ln = listNext(&li))) {
        wk = listNodeValue(ln);
        if (wk->db == c->db && equalStringObjects(key,wk->key))
            return; /* Key already watched */
    }
    /* This key is not already watched in this DB. Let's add it */
    clients = dictFetchValue(c->db->watched_keys,key);
    if (!clients) {
        clients = listCreate();
        dictAdd(c->db->watched_keys,key,clients);
        incrRefCount(key);
    }
    /* Add the new key to the list of keys watched by this client */
    wk = zmalloc(sizeof(*wk));
    wk->key = key;
    wk->client = c;
    wk->db = c->db;
    wk->expired = keyIsExpired(c->db, key);
    incrRefCount(key);
    listAddNodeTail(c->watched_keys, wk);
    watchedKeyLinkToClients(clients, wk);
}
&#32;
/* Unwatch all the keys watched by this client. Clearing the EXEC dirty
 * flag is up to the caller. */
void unwatchAllKeys(client *c) {
    listIter li;
    listNode *ln;

    if (listLength(c->watched_keys) == 0) return;
    listRewind(c->watched_keys,&li);
    while((ln = listNext(&li))) {
        list *clients;
        watchedKey *wk;

        /* Remove the client's wk from the list of clients watching the key. */
        wk = listNodeValue(ln);
        clients = watchedKeyGetClients(wk);
        serverAssertWithInfo(c,NULL,clients != NULL);
        listUnlinkNode(clients, watchedKeyGetClientNode(wk));
        /* Remove the dict entry altogether if this was the only client. */
        if (listLength(clients) == 0)
            dictDelete(wk->db->watched_keys, wk->key);
        /* Remove this watched key from the client->watched list */
        listDelNode(c->watched_keys,ln);
        decrRefCount(wk->key);
        zfree(wk);
    }
    server.watching_clients--;
}
&#32;
/* Iterates over the watched_keys list and looks for an expired key. Keys which
 * were expired already when WATCH was called are ignored. */
int isWatchedKeyExpired(client *c) {
    listIter li;
    listNode *ln;
    watchedKey *wk;

    if (listLength(c->watched_keys) == 0) return 0;
    listRewind(c->watched_keys,&li);
    while ((ln = listNext(&li))) {
        wk = listNodeValue(ln);
        if (wk->expired) continue; /* was expired when WATCH was called */
        if (keyIsExpired(wk->db, wk->key)) return 1;
    }

    return 0;
}
/* "Touch" a key, so that if this key is being WATCHed by some client the
* next EXEC will fail. */
void touchWatchedKey(redisDb *db, robj *key) {
list *clients;
listIter li;
listNode *ln;
if (dictSize(db->watched_keys) == 0) return;
clients = dictFetchValue(db->watched_keys, key);
if (!clients) return;
/* Mark all the clients watching this key as CLIENT_DIRTY_CAS */
/* Check if we are already watching for this key */
listRewind(clients,&li);
while((ln = listNext(&li))) {
watchedKey *wk = redis_member2struct(watchedKey, node, ln);
client *c = wk->client;
if (wk->expired) {
/* The key was already expired when WATCH was called. */
if (db == wk->db &&
equalStringObjects(key, wk->key) &&
dbFind(db, key->ptr) == NULL)
{
/* Already expired key is deleted, so logically no change. Clear
* the flag. Deleted keys are not flagged as expired. */
wk->expired = 0;
goto skip_client;
}
break;
}
c->flags |= CLIENT_DIRTY_CAS;
/* As the client is marked as dirty, there is no point in getting here
* again in case that key (or others) are modified again (or keep the
* memory overhead till EXEC). */
unwatchAllKeys(c);
skip_client:
continue;
}
}
&#32;
/* Set CLIENT_DIRTY_CAS on all clients WATCHing keys of a DB whose contents
 * are replaced wholesale. It may happen in the following situations:
 * FLUSHDB, FLUSHALL, SWAPDB, end of a successful diskless replication.
 *
 * replaced_with: for SWAPDB, the WATCH should be invalidated if
 * the key exists in either of the two DBs, and skipped only if it
 * doesn't exist in both. */
void touchAllWatchedKeysInDb(redisDb *emptied, redisDb *replaced_with) {
    listIter li;
    listNode *ln;
    dictEntry *de;

    if (dictSize(emptied->watched_keys) == 0) return;

    dictIterator *di = dictGetSafeIterator(emptied->watched_keys);
    while((de = dictNext(di)) != NULL) {
        robj *key = dictGetKey(de);
        int exists_in_emptied = dbFind(emptied, key->ptr) != NULL;
        if (exists_in_emptied ||
            (replaced_with && dbFind(replaced_with, key->ptr) != NULL))
        {
            list *clients = dictGetVal(de);
            if (!clients) continue;
            listRewind(clients,&li);
            while((ln = listNext(&li))) {
                watchedKey *wk = redis_member2struct(watchedKey, node, ln);
                if (wk->expired) {
                    if (!replaced_with || !dbFind(replaced_with, key->ptr)) {
                        /* Expired key now deleted. No logical change. Clear the
                         * flag. Deleted keys are not flagged as expired. */
                        wk->expired = 0;
                        continue;
                    } else if (keyIsExpired(replaced_with, key)) {
                        /* Expired key remains expired. */
                        continue;
                    }
                } else if (!exists_in_emptied && keyIsExpired(replaced_with, key)) {
                    /* Non-existing key is replaced with an expired key. */
                    wk->expired = 1;
                    continue;
                }
                client *c = wk->client;
                c->flags |= CLIENT_DIRTY_CAS;
                /* Note - we could potentially call unwatchAllKeys for this
                 * specific client in order to reduce the total number of
                 * iterations. BUT this could also free the current next entry
                 * pointer held by the iterator and lead to use-after-free. */
            }
        }
    }
    dictReleaseIterator(di);
}
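&#32;
/* WATCH command implementation. Registers every key given as an argument so
 * that a later EXEC is aborted if any of them is modified in the meantime. */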
void watchCommand(client *c) {
    int j;

    if (c->flags & CLIENT_MULTI) {
        addReplyError(c,"WATCH inside MULTI is not allowed");
        return;
    }
    /* No point in watching if the client is already dirty. */
    if (c->flags & CLIENT_DIRTY_CAS) {
        addReply(c,shared.ok);
        return;
    }
    for (j = 1; j < c->argc; j++)
        watchForKey(c,c->argv[j]);
    addReply(c,shared.ok);
}
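&#32;
/* UNWATCH command implementation. Forgets all the WATCHed keys and clears the
 * dirty CAS flag. */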
void unwatchCommand(client *c) {
    unwatchAllKeys(c);
    c->flags &= (~CLIENT_DIRTY_CAS);
    addReply(c,shared.ok);
}
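&#32;
/* Return the memory overhead of the client's MULTI/EXEC state: the memory used
 * by the queued command arguments, the reserved multiCmd slots, and the
 * per-client watched-keys bookkeeping. */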
size_t multiStateMemOverhead(client *c) {
    size_t mem = c->mstate.argv_len_sums;
    /* Add the watched keys overhead. Note: this doesn't account for the watched
     * keys themselves, because they aren't managed on a per-client basis. */
    mem += listLength(c->watched_keys) * (sizeof(listNode) + sizeof(watchedKey));
    /* Reserved memory for queued multi commands. */
    mem += c->mstate.alloc_count * sizeof(multiCmd);
    return mem;
}