diff --git a/runtest-moduleapi b/runtest-moduleapi index 81af306e5..c924a1f98 100755 --- a/runtest-moduleapi +++ b/runtest-moduleapi @@ -49,4 +49,5 @@ $TCLSH tests/test_helper.tcl \ --single unit/moduleapi/cmdintrospection \ --single unit/moduleapi/eventloop \ --single unit/moduleapi/timer \ +--single unit/moduleapi/publish \ "${@}" diff --git a/src/cluster.c b/src/cluster.c index a757172e7..adad07e19 100644 --- a/src/cluster.c +++ b/src/cluster.c @@ -2506,11 +2506,7 @@ int clusterProcessPacket(clusterLink *link) { message = createStringObject( (char*)hdr->data.publish.msg.bulk_data+channel_len, message_len); - if (type == CLUSTERMSG_TYPE_PUBLISHSHARD) { - pubsubPublishMessageShard(channel, message); - } else { - pubsubPublishMessage(channel,message); - } + pubsubPublishMessage(channel, message, type == CLUSTERMSG_TYPE_PUBLISHSHARD); decrRefCount(channel); decrRefCount(message); } @@ -3199,20 +3195,19 @@ int clusterSendModuleMessageToTarget(const char *target, uint64_t module_id, uin /* ----------------------------------------------------------------------------- * CLUSTER Pub/Sub support * - * For now we do very little, just propagating PUBLISH messages across the whole + * If `sharded` is 0: + * For now we do very little, just propagating [S]PUBLISH messages across the whole * cluster. In the future we'll try to get smarter and avoiding propagating those * messages to hosts without receives for a given channel. - * -------------------------------------------------------------------------- */ -void clusterPropagatePublish(robj *channel, robj *message) { - clusterSendPublish(NULL, channel, message, CLUSTERMSG_TYPE_PUBLISH); -} - -/* ----------------------------------------------------------------------------- - * CLUSTER Pub/Sub shard support - * + * Otherwise: * Publish this message across the slot (primary/replica). * -------------------------------------------------------------------------- */ -void clusterPropagatePublishShard(robj *channel, robj *message) { +void clusterPropagatePublish(robj *channel, robj *message, int sharded) { + if (!sharded) { + clusterSendPublish(NULL, channel, message, CLUSTERMSG_TYPE_PUBLISH); + return; + } + list *nodes_for_slot = clusterGetNodesServingMySlots(server.cluster->myself); if (listLength(nodes_for_slot) != 0) { listIter li; diff --git a/src/cluster.h b/src/cluster.h index 90b775ca2..1349a7a92 100644 --- a/src/cluster.h +++ b/src/cluster.h @@ -384,8 +384,7 @@ void migrateCloseTimedoutSockets(void); int verifyClusterConfigWithData(void); unsigned long getClusterConnectionsCount(void); int clusterSendModuleMessageToTarget(const char *target, uint64_t module_id, uint8_t type, const char *payload, uint32_t len); -void clusterPropagatePublish(robj *channel, robj *message); -void clusterPropagatePublishShard(robj *channel, robj *message); +void clusterPropagatePublish(robj *channel, robj *message, int sharded); unsigned int keyHashSlot(char *key, int keylen); void slotToKeyAddEntry(dictEntry *entry, redisDb *db); void slotToKeyDelEntry(dictEntry *entry, redisDb *db); diff --git a/src/module.c b/src/module.c index 163104d8a..04a5be721 100644 --- a/src/module.c +++ b/src/module.c @@ -3381,10 +3381,13 @@ int RM_GetClientInfoById(void *ci, uint64_t id) { /* Publish a message to subscribers (see PUBLISH command). */ int RM_PublishMessage(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) { UNUSED(ctx); - int receivers = pubsubPublishMessage(channel, message); - if (server.cluster_enabled) - clusterPropagatePublish(channel, message); - return receivers; + return pubsubPublishMessageAndPropagateToCluster(channel, message, 0); +} + +/* Publish a message to shard-subscribers (see SPUBLISH command). */ +int RM_PublishMessageShard(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) { + UNUSED(ctx); + return pubsubPublishMessageAndPropagateToCluster(channel, message, 1); } /* Return the currently selected DB. */ @@ -12545,6 +12548,7 @@ void moduleRegisterCoreAPI(void) { REGISTER_API(ServerInfoGetFieldDouble); REGISTER_API(GetClientInfoById); REGISTER_API(PublishMessage); + REGISTER_API(PublishMessageShard); REGISTER_API(SubscribeToServerEvent); REGISTER_API(SetLRU); REGISTER_API(GetLRU); diff --git a/src/notify.c b/src/notify.c index 633e35bdc..2881a48db 100644 --- a/src/notify.c +++ b/src/notify.c @@ -126,7 +126,7 @@ void notifyKeyspaceEvent(int type, char *event, robj *key, int dbid) { chan = sdscatlen(chan, "__:", 3); chan = sdscatsds(chan, key->ptr); chanobj = createObject(OBJ_STRING, chan); - pubsubPublishMessage(chanobj, eventobj); + pubsubPublishMessage(chanobj, eventobj, 0); decrRefCount(chanobj); } @@ -138,7 +138,7 @@ void notifyKeyspaceEvent(int type, char *event, robj *key, int dbid) { chan = sdscatlen(chan, "__:", 3); chan = sdscatsds(chan, eventobj->ptr); chanobj = createObject(OBJ_STRING, chan); - pubsubPublishMessage(chanobj, key); + pubsubPublishMessage(chanobj, key, 0); decrRefCount(chanobj); } decrRefCount(eventobj); diff --git a/src/pubsub.c b/src/pubsub.c index e805b16ef..a630afc8f 100644 --- a/src/pubsub.c +++ b/src/pubsub.c @@ -499,16 +499,10 @@ int pubsubPublishMessageInternal(robj *channel, robj *message, pubsubtype type) } /* Publish a message to all the subscribers. */ -int pubsubPublishMessage(robj *channel, robj *message) { - return pubsubPublishMessageInternal(channel,message,pubSubType); +int pubsubPublishMessage(robj *channel, robj *message, int sharded) { + return pubsubPublishMessageInternal(channel, message, sharded? pubSubShardType : pubSubType); } -/* Publish a shard message to all the subscribers. */ -int pubsubPublishMessageShard(robj *channel, robj *message) { - return pubsubPublishMessageInternal(channel, message, pubSubShardType); -} - - /*----------------------------------------------------------------------------- * Pubsub commands implementation *----------------------------------------------------------------------------*/ @@ -578,6 +572,15 @@ void punsubscribeCommand(client *c) { if (clientTotalPubSubSubscriptionCount(c) == 0) c->flags &= ~CLIENT_PUBSUB; } +/* This function wraps pubsubPublishMessage and also propagates the message to cluster. + * Used by the commands PUBLISH/SPUBLISH and their respective module APIs.*/ +int pubsubPublishMessageAndPropagateToCluster(robj *channel, robj *message, int sharded) { + int receivers = pubsubPublishMessage(channel, message, sharded); + if (server.cluster_enabled) + clusterPropagatePublish(channel, message, sharded); + return receivers; +} + /* PUBLISH */ void publishCommand(client *c) { if (server.sentinel_mode) { @@ -585,10 +588,8 @@ void publishCommand(client *c) { return; } - int receivers = pubsubPublishMessage(c->argv[1],c->argv[2]); - if (server.cluster_enabled) - clusterPropagatePublish(c->argv[1],c->argv[2]); - else + int receivers = pubsubPublishMessageAndPropagateToCluster(c->argv[1],c->argv[2],0); + if (!server.cluster_enabled) forceCommandPropagation(c,PROPAGATE_REPL); addReplyLongLong(c,receivers); } @@ -677,12 +678,9 @@ void channelList(client *c, sds pat, dict *pubsub_channels) { /* SPUBLISH */ void spublishCommand(client *c) { - int receivers = pubsubPublishMessageInternal(c->argv[1], c->argv[2], pubSubShardType); - if (server.cluster_enabled) { - clusterPropagatePublishShard(c->argv[1], c->argv[2]); - } else { + int receivers = pubsubPublishMessageAndPropagateToCluster(c->argv[1],c->argv[2],1); + if (!server.cluster_enabled) forceCommandPropagation(c,PROPAGATE_REPL); - } addReplyLongLong(c,receivers); } diff --git a/src/redismodule.h b/src/redismodule.h index 846967a62..e72f5f7bc 100644 --- a/src/redismodule.h +++ b/src/redismodule.h @@ -999,6 +999,7 @@ REDISMODULE_API unsigned long long (*RedisModule_GetClientId)(RedisModuleCtx *ct REDISMODULE_API RedisModuleString * (*RedisModule_GetClientUserNameById)(RedisModuleCtx *ctx, uint64_t id) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_GetClientInfoById)(void *ci, uint64_t id) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_PublishMessage)(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_PublishMessageShard)(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_GetContextFlags)(RedisModuleCtx *ctx) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_AvoidReplicaTraffic)() REDISMODULE_ATTR; REDISMODULE_API void * (*RedisModule_PoolAlloc)(RedisModuleCtx *ctx, size_t bytes) REDISMODULE_ATTR; @@ -1423,6 +1424,7 @@ static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int REDISMODULE_GET_API(ServerInfoGetFieldDouble); REDISMODULE_GET_API(GetClientInfoById); REDISMODULE_GET_API(PublishMessage); + REDISMODULE_GET_API(PublishMessageShard); REDISMODULE_GET_API(SubscribeToServerEvent); REDISMODULE_GET_API(SetLRU); REDISMODULE_GET_API(GetLRU); diff --git a/src/sentinel.c b/src/sentinel.c index 3ad8f902b..9ea78aae5 100644 --- a/src/sentinel.c +++ b/src/sentinel.c @@ -705,7 +705,7 @@ void sentinelEvent(int level, char *type, sentinelRedisInstance *ri, if (level != LL_DEBUG) { channel = createStringObject(type,strlen(type)); payload = createStringObject(msg,strlen(msg)); - pubsubPublishMessage(channel,payload); + pubsubPublishMessage(channel,payload,0); decrRefCount(channel); decrRefCount(payload); } diff --git a/src/server.h b/src/server.h index 65727cd40..b57f39d38 100644 --- a/src/server.h +++ b/src/server.h @@ -2965,8 +2965,8 @@ int pubsubUnsubscribeAllChannels(client *c, int notify); int pubsubUnsubscribeShardAllChannels(client *c, int notify); void pubsubUnsubscribeShardChannels(robj **channels, unsigned int count); int pubsubUnsubscribeAllPatterns(client *c, int notify); -int pubsubPublishMessage(robj *channel, robj *message); -int pubsubPublishMessageShard(robj *channel, robj *message); +int pubsubPublishMessage(robj *channel, robj *message, int sharded); +int pubsubPublishMessageAndPropagateToCluster(robj *channel, robj *message, int sharded); void addReplyPubsubMessage(client *c, robj *channel, robj *msg); int serverPubsubSubscriptionCount(); int serverPubsubShardSubscriptionCount(); diff --git a/tests/modules/Makefile b/tests/modules/Makefile index 16b5570aa..ac4c3e27b 100644 --- a/tests/modules/Makefile +++ b/tests/modules/Makefile @@ -57,7 +57,8 @@ TEST_MODULES = \ cmdintrospection.so \ eventloop.so \ moduleconfigs.so \ - moduleconfigstwo.so + moduleconfigstwo.so \ + publish.so .PHONY: all diff --git a/tests/modules/publish.c b/tests/modules/publish.c new file mode 100644 index 000000000..eee96d689 --- /dev/null +++ b/tests/modules/publish.c @@ -0,0 +1,42 @@ +#include "redismodule.h" +#include +#include +#include + +#define UNUSED(V) ((void) V) + +int cmd_publish_classic(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) +{ + if (argc != 3) + return RedisModule_WrongArity(ctx); + + int receivers = RedisModule_PublishMessage(ctx, argv[1], argv[2]); + RedisModule_ReplyWithLongLong(ctx, receivers); + return REDISMODULE_OK; +} + +int cmd_publish_shard(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) +{ + if (argc != 3) + return RedisModule_WrongArity(ctx); + + int receivers = RedisModule_PublishMessageShard(ctx, argv[1], argv[2]); + RedisModule_ReplyWithLongLong(ctx, receivers); + return REDISMODULE_OK; +} + +int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + UNUSED(argv); + UNUSED(argc); + + if (RedisModule_Init(ctx,"publish",1,REDISMODULE_APIVER_1)== REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx,"publish.classic",cmd_publish_classic,"",0,0,0) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + if (RedisModule_CreateCommand(ctx,"publish.shard",cmd_publish_shard,"",0,0,0) == REDISMODULE_ERR) + return REDISMODULE_ERR; + + return REDISMODULE_OK; +} diff --git a/tests/support/util.tcl b/tests/support/util.tcl index 4ad96ab10..a7972a854 100644 --- a/tests/support/util.tcl +++ b/tests/support/util.tcl @@ -823,11 +823,21 @@ proc subscribe {client channels} { consume_subscribe_messages $client subscribe $channels } +proc ssubscribe {client channels} { + $client ssubscribe {*}$channels + consume_subscribe_messages $client ssubscribe $channels +} + proc unsubscribe {client {channels {}}} { $client unsubscribe {*}$channels consume_subscribe_messages $client unsubscribe $channels } +proc sunsubscribe {client {channels {}}} { + $client sunsubscribe {*}$channels + consume_subscribe_messages $client sunsubscribe $channels +} + proc psubscribe {client channels} { $client psubscribe {*}$channels consume_subscribe_messages $client psubscribe $channels diff --git a/tests/unit/moduleapi/publish.tcl b/tests/unit/moduleapi/publish.tcl new file mode 100644 index 000000000..ab3611093 --- /dev/null +++ b/tests/unit/moduleapi/publish.tcl @@ -0,0 +1,17 @@ +set testmodule [file normalize tests/modules/publish.so] + +start_server {tags {"modules"}} { + r module load $testmodule + + test {PUBLISH and SPUBLISH via a module} { + set rd1 [redis_deferring_client] + set rd2 [redis_deferring_client] + + assert_equal {1} [ssubscribe $rd1 {chan1}] + assert_equal {1} [subscribe $rd2 {chan1}] + assert_equal 1 [r publish.shard chan1 hello] + assert_equal 1 [r publish.classic chan1 world] + assert_equal {message chan1 hello} [$rd1 read] + assert_equal {message chan1 world} [$rd2 read] + } +} diff --git a/tests/unit/pubsubshard.tcl b/tests/unit/pubsubshard.tcl index 5c3564afe..d0023a841 100644 --- a/tests/unit/pubsubshard.tcl +++ b/tests/unit/pubsubshard.tcl @@ -1,80 +1,4 @@ start_server {tags {"pubsubshard external:skip"}} { - proc __consume_ssubscribe_messages {client type channels} { - set numsub -1 - set counts {} - - for {set i [llength $channels]} {$i > 0} {incr i -1} { - set msg [$client read] - assert_equal $type [lindex $msg 0] - - # when receiving subscribe messages the channels names - # are ordered. when receiving unsubscribe messages - # they are unordered - set idx [lsearch -exact $channels [lindex $msg 1]] - if {[string match "sunsubscribe" $type]} { - assert {$idx >= 0} - } else { - assert {$idx == 0} - } - set channels [lreplace $channels $idx $idx] - - # aggregate the subscription count to return to the caller - lappend counts [lindex $msg 2] - } - - # we should have received messages for channels - assert {[llength $channels] == 0} - return $counts - } - - proc __consume_subscribe_messages {client type channels} { - set numsub -1 - set counts {} - - for {set i [llength $channels]} {$i > 0} {incr i -1} { - set msg [$client read] - assert_equal $type [lindex $msg 0] - - # when receiving subscribe messages the channels names - # are ordered. when receiving unsubscribe messages - # they are unordered - set idx [lsearch -exact $channels [lindex $msg 1]] - if {[string match "unsubscribe" $type]} { - assert {$idx >= 0} - } else { - assert {$idx == 0} - } - set channels [lreplace $channels $idx $idx] - - # aggregate the subscription count to return to the caller - lappend counts [lindex $msg 2] - } - - # we should have received messages for channels - assert {[llength $channels] == 0} - return $counts - } - - proc ssubscribe {client channels} { - $client ssubscribe {*}$channels - __consume_ssubscribe_messages $client ssubscribe $channels - } - - proc subscribe {client channels} { - $client subscribe {*}$channels - __consume_subscribe_messages $client subscribe $channels - } - - proc sunsubscribe {client {channels {}}} { - $client sunsubscribe {*}$channels - __consume_subscribe_messages $client sunsubscribe $channels - } - - proc unsubscribe {client {channels {}}} { - $client unsubscribe {*}$channels - __consume_subscribe_messages $client unsubscribe $channels - } - test "SPUBLISH/SSUBSCRIBE basics" { set rd1 [redis_deferring_client]