diff --git a/src/connection.c b/src/connection.c
index 3a17d983d..11fc4ba28 100644
--- a/src/connection.c
+++ b/src/connection.c
@@ -178,6 +178,25 @@ static int connSocketWrite(connection *conn, const void *data, size_t data_len)
     return ret;
 }
 
+/* Gather-write 'iovcnt' buffers from 'iov' to the socket with writev(2) and
+ * return its result. On failures other than EAGAIN the errno is saved in
+ * conn->last_errno and, unless it was EINTR, a connection that is currently
+ * in CONN_STATE_CONNECTED is moved to CONN_STATE_ERROR. */
+static int connSocketWritev(connection *conn, const struct iovec *iov, int iovcnt) {
+    int ret = writev(conn->fd, iov, iovcnt);
+    if (ret < 0 && errno != EAGAIN) {
+        conn->last_errno = errno;
+
+        /* Don't overwrite the state of a connection that is not already
+         * connected, not to mess with handler callbacks.
+         */
+        if (errno != EINTR && conn->state == CONN_STATE_CONNECTED)
+            conn->state = CONN_STATE_ERROR;
+    }
+
+    return ret;
+}
+
 static int connSocketRead(connection *conn, void *buf, size_t buf_len) {
     int ret = read(conn->fd, buf, buf_len);
     if (!ret) {
@@ -349,6 +368,7 @@ ConnectionType CT_Socket = {
     .ae_handler = connSocketEventHandler,
     .close = connSocketClose,
     .write = connSocketWrite,
+    .writev = connSocketWritev,
     .read = connSocketRead,
     .accept = connSocketAccept,
     .connect = connSocketConnect,
diff --git a/src/connection.h b/src/connection.h
index 07c1d4dd8..dad2e2fd6 100644
--- a/src/connection.h
+++ b/src/connection.h
@@ -32,6 +32,7 @@
 #define __REDIS_CONNECTION_H
 
 #include <errno.h>
+#include <sys/uio.h>
 
 #define CONN_INFO_LEN 32
 
@@ -59,6 +60,7 @@ typedef struct ConnectionType {
     void (*ae_handler)(struct aeEventLoop *el, int fd, void *clientData, int mask);
     int (*connect)(struct connection *conn, const char *addr, int port, const char *source_addr, ConnectionCallbackFunc connect_handler);
     int (*write)(struct connection *conn, const void *data, size_t data_len);
+    int (*writev)(struct connection *conn, const struct iovec *iov, int iovcnt);
     int (*read)(struct connection *conn, void *buf, size_t buf_len);
     void (*close)(struct connection *conn);
     int (*accept)(struct connection *conn, ConnectionCallbackFunc accept_handler);
@@ -142,6 +144,18 @@ static inline int connWrite(connection *conn, const void *data, size_t data_len)
     return conn->type->write(conn, data, data_len);
 }
 
+/* Gather output data from the iovcnt buffers specified by the members of the iov
+ * array: iov[0], iov[1], ..., iov[iovcnt-1] and write to connection, behaves the same as writev(2).
+ *
+ * Like writev(2), a short write is possible. A -1 return indicates an error.
+ *
+ * The caller should NOT rely on errno. To test for an EAGAIN-like condition, use
+ * connGetState() to see if the connection state is still CONN_STATE_CONNECTED.
+ */
+static inline int connWritev(connection *conn, const struct iovec *iov, int iovcnt) {
+    return conn->type->writev(conn, iov, iovcnt);
+}
+
 /* Read from the connection, behaves the same as read(2).
  *
  * Like read(2), a short read is possible. A return value of 0 will indicate the
diff --git a/src/networking.c b/src/networking.c
index 2a8bcc0b9..27bfbad8b 100644
--- a/src/networking.c
+++ b/src/networking.c
@@ -1697,10 +1697,82 @@ client *lookupClientByID(uint64_t id) {
     return (c == raxNotFound) ? NULL : c;
 }
 
+/* This function should be called from _writeToClient when the reply list is not empty,
+ * it gathers the scattered buffers from the reply list and sends them with connWritev.
+ * If we write successfully, it returns C_OK, otherwise, C_ERR is returned,
+ * and 'nwritten' is an output parameter holding how many bytes the server wrote
+ * to the client. */
+static int _writevToClient(client *c, ssize_t *nwritten) {
+    struct iovec iov[IOV_MAX];
+    int iovcnt = 0;
+    size_t iov_bytes_len = 0;
+    /* If the static reply buffer is not empty,
+     * add it to the iov array for writev() as well. */
+    if (c->bufpos > 0) {
+        iov[iovcnt].iov_base = c->buf + c->sentlen;
+        iov[iovcnt].iov_len = c->bufpos - c->sentlen;
+        iov_bytes_len += iov[iovcnt++].iov_len;
+    }
+    /* The first node of reply list might be incomplete from the last call,
+     * thus it needs to be calibrated to get the actual data address and length. */
+    size_t offset = c->bufpos > 0 ? 0 : c->sentlen;
+    listIter iter;
+    listNode *next;
+    clientReplyBlock *o;
+    listRewind(c->reply, &iter);
+    while ((next = listNext(&iter)) && iovcnt < IOV_MAX && iov_bytes_len < NET_MAX_WRITES_PER_EVENT) {
+        o = listNodeValue(next);
+        if (o->used == 0) { /* empty node, just release it and skip. */
+            c->reply_bytes -= o->size;
+            listDelNode(c->reply, next);
+            offset = 0;
+            continue;
+        }
+
+        iov[iovcnt].iov_base = o->buf + offset;
+        iov[iovcnt].iov_len = o->used - offset;
+        iov_bytes_len += iov[iovcnt++].iov_len;
+        offset = 0;
+    }
+    if (iovcnt == 0) return C_OK;
+    *nwritten = connWritev(c->conn, iov, iovcnt);
+    if (*nwritten <= 0) return C_ERR;
+
+    /* Locate the new node which has leftover data and
+     * release all nodes in front of it. */
+    ssize_t remaining = *nwritten;
+    if (c->bufpos > 0) { /* deal with static reply buffer first. */
+        int buf_len = c->bufpos - c->sentlen;
+        c->sentlen += remaining;
+        /* If the buffer was sent, set bufpos to zero to continue with
+         * the remainder of the reply. */
+        if (remaining >= buf_len) {
+            c->bufpos = 0;
+            c->sentlen = 0;
+        }
+        remaining -= buf_len;
+    }
+    listRewind(c->reply, &iter);
+    while (remaining > 0) {
+        next = listNext(&iter);
+        o = listNodeValue(next);
+        if (remaining < (ssize_t)(o->used - c->sentlen)) {
+            c->sentlen += remaining;
+            break;
+        }
+        remaining -= (ssize_t)(o->used - c->sentlen);
+        c->reply_bytes -= o->size;
+        listDelNode(c->reply, next);
+        c->sentlen = 0;
+    }
+
+    return C_OK;
+}
+
 /* This function does actual writing output buffers to different types of
  * clients, it is called by writeToClient.
- * If we write successfully, it return C_OK, otherwise, C_ERR is returned,
- * And 'nwritten' is a output parameter, it means how many bytes server write
+ * If we write successfully, it returns C_OK, otherwise, C_ERR is returned,
+ * and 'nwritten' is an output parameter, it means how many bytes server write
  * to client. */
 int _writeToClient(client *c, ssize_t *nwritten) {
     *nwritten = 0;
@@ -1729,8 +1801,18 @@ int _writeToClient(client *c, ssize_t *nwritten) {
         return C_OK;
     }
 
-    if (c->bufpos > 0) {
-        *nwritten = connWrite(c->conn,c->buf+c->sentlen,c->bufpos-c->sentlen);
+    /* When the reply list is not empty, it's better to use writev to save us some
+     * system calls and TCP packets. */
+    if (listLength(c->reply) > 0) {
+        int ret = _writevToClient(c, nwritten);
+        if (ret != C_OK) return ret;
+
+        /* If there are no longer objects in the list, we expect
+         * the count of reply bytes to be exactly zero. */
+        if (listLength(c->reply) == 0)
+            serverAssert(c->reply_bytes == 0);
+    } else if (c->bufpos > 0) {
+        *nwritten = connWrite(c->conn, c->buf + c->sentlen, c->bufpos - c->sentlen);
         if (*nwritten <= 0) return C_ERR;
         c->sentlen += *nwritten;
 
@@ -1740,31 +1822,8 @@ int _writeToClient(client *c, ssize_t *nwritten) {
             c->bufpos = 0;
             c->sentlen = 0;
         }
-    } else {
-        clientReplyBlock *o = listNodeValue(listFirst(c->reply));
-        size_t objlen = o->used;
+    }
 
-        if (objlen == 0) {
-            c->reply_bytes -= o->size;
-            listDelNode(c->reply,listFirst(c->reply));
-            return C_OK;
-        }
-
-        *nwritten = connWrite(c->conn, o->buf + c->sentlen, objlen - c->sentlen);
-        if (*nwritten <= 0) return C_ERR;
-        c->sentlen += *nwritten;
-
-        /* If we fully sent the object on head go to the next one */
-        if (c->sentlen == objlen) {
-            c->reply_bytes -= o->size;
-            listDelNode(c->reply,listFirst(c->reply));
-            c->sentlen = 0;
-            /* If there are no longer objects in the list, we expect
-             * the count of reply bytes to be exactly zero. */
-            if (listLength(c->reply) == 0)
-                serverAssert(c->reply_bytes == 0);
-        }
-    }
     return C_OK;
 }
 
diff --git a/src/tls.c b/src/tls.c
index c6449b2a7..66c485ac2 100644
--- a/src/tls.c
+++ b/src/tls.c
@@ -42,6 +42,7 @@
 #if OPENSSL_VERSION_NUMBER >= 0x30000000L
 #include <openssl/decoder.h>
 #endif
+#include <sys/uio.h>
 
 #define REDIS_TLS_PROTO_TLSv1 (1<<0)
 #define REDIS_TLS_PROTO_TLSv1_1 (1<<1)
@@ -819,6 +820,50 @@ static int connTLSWrite(connection *conn_, const void *data, size_t data_len) {
     return ret;
 }
 
+/* writev() counterpart for TLS connections, built on top of connTLSWrite().
+ * Depending on the total size, the buffers are either written one by one or
+ * copied into one contiguous buffer and written with a single call.
+ * Returns the number of bytes written, or connTLSWrite()'s non-positive result
+ * when nothing could be written. */
+static int connTLSWritev(connection *conn_, const struct iovec *iov, int iovcnt) {
+    if (iovcnt == 1) return connTLSWrite(conn_, iov[0].iov_base, iov[0].iov_len);
+
+    /* Accumulate the amount of bytes of each buffer and check if it exceeds NET_MAX_WRITES_PER_EVENT. */
+    size_t iov_bytes_len = 0;
+    for (int i = 0; i < iovcnt; i++) {
+        iov_bytes_len += iov[i].iov_len;
+        if (iov_bytes_len > NET_MAX_WRITES_PER_EVENT) break;
+    }
+
+    /* The amount of all buffers is greater than NET_MAX_WRITES_PER_EVENT,
+     * which is not worth doing so much memory copying to reduce system calls,
+     * therefore, invoke connTLSWrite() multiple times to avoid memory copies. */
+    if (iov_bytes_len > NET_MAX_WRITES_PER_EVENT) {
+        ssize_t tot_sent = 0;
+        for (int i = 0; i < iovcnt; i++) {
+            /* connTLSWrite() returns a signed int; keep the result signed so
+             * that the `sent <= 0` check can actually see a failure. */
+            ssize_t sent = connTLSWrite(conn_, iov[i].iov_base, iov[i].iov_len);
+            if (sent <= 0) return tot_sent > 0 ? tot_sent : sent;
+            tot_sent += sent;
+            if ((size_t) sent != iov[i].iov_len) break;
+        }
+        return tot_sent;
+    }
+
+    /* The amount of all buffers is less than NET_MAX_WRITES_PER_EVENT,
+     * which is worth doing more memory copies in exchange for fewer system calls,
+     * so concatenate these scattered buffers into a contiguous piece of memory
+     * and send it away by one call to connTLSWrite(). */
+    char buf[iov_bytes_len];
+    size_t offset = 0;
+    for (int i = 0; i < iovcnt; i++) {
+        memcpy(buf + offset, iov[i].iov_base, iov[i].iov_len);
+        offset += iov[i].iov_len;
+    }
+    return connTLSWrite(conn_, buf, iov_bytes_len);
+}
+
 static int connTLSRead(connection *conn_, void *buf, size_t buf_len) {
     tls_connection *conn = (tls_connection *) conn_;
     int ret;
@@ -982,6 +1027,7 @@ ConnectionType CT_TLS = {
     .blocking_connect = connTLSBlockingConnect,
     .read = connTLSRead,
     .write = connTLSWrite,
+    .writev = connTLSWritev,
     .close = connTLSClose,
     .set_write_handler = connTLSSetWriteHandler,
     .set_read_handler = connTLSSetReadHandler,