[download] Fix corrupted download after mirror fail-over bug

When we had already read bytes from one mirror, but needed to fail-over to another mirror, we would re-read the same bytes again corrupting our download.

Fixes #2412

This commit resumes the download where the last mirror left off.
This commit is contained in:
Torsten Grote 2022-09-02 14:39:28 -03:00
parent f6d1637d92
commit c8514adb94
No known key found for this signature in database
GPG Key ID: 3E5F77D92CF891FF
3 changed files with 147 additions and 23 deletions

View File

@ -49,8 +49,9 @@ public open class HttpManager @JvmOverloads constructor(
private val httpClientEngineFactory: HttpClientEngineFactory<*> = getHttpClientEngineFactory(),
) {
private companion object {
internal companion object {
val log = KotlinLogging.logger {}
const val READ_BUFFER = 8 * 1024
}
private var httpClient = getNewHttpClient(proxyConfig)
@ -122,18 +123,23 @@ public open class HttpManager @JvmOverloads constructor(
request: DownloadRequest,
skipFirstBytes: Long? = null,
receiver: BytesReceiver,
): Unit = mirrorChooser.mirrorRequest(request) { mirror, url ->
getHttpStatement(request, mirror, url, skipFirstBytes).execute { response ->
val contentLength = response.contentLength()
if (skipFirstBytes != null && response.status != PartialContent) {
throw NoResumeException()
}
val channel: ByteReadChannel = response.body()
val limit = 8L * 1024L
while (!channel.isClosedForRead) {
val packet = channel.readRemaining(limit)
while (!packet.isEmpty) {
receiver.receive(packet.readBytes(), contentLength)
) {
// remember what we've read already, so we can pass it to the next mirror if needed
var skipBytes = skipFirstBytes ?: 0L
mirrorChooser.mirrorRequest(request) { mirror, url ->
getHttpStatement(request, mirror, url, skipBytes).execute { response ->
val contentLength = response.contentLength()
if (skipBytes > 0L && response.status != PartialContent) {
throw NoResumeException()
}
val channel: ByteReadChannel = response.body()
while (!channel.isClosedForRead) {
val packet = channel.readRemaining(READ_BUFFER.toLong())
while (!packet.isEmpty) {
val readBytes = packet.readBytes()
skipBytes += readBytes.size
receiver.receive(readBytes, contentLength)
}
}
}
}
@ -143,7 +149,7 @@ public open class HttpManager @JvmOverloads constructor(
request: DownloadRequest,
mirror: Mirror,
url: Url,
skipFirstBytes: Long? = null,
skipFirstBytes: Long,
): HttpStatement {
resetProxyIfNeeded(request.proxy, mirror)
log.info { "GET $url" }
@ -154,7 +160,7 @@ public open class HttpManager @JvmOverloads constructor(
// increase connect timeout if using Tor mirror
if (mirror.isOnion()) timeout { connectTimeoutMillis = 20_000 }
// add range header if set
if (skipFirstBytes != null) header(Range, "bytes=$skipFirstBytes-")
if (skipFirstBytes > 0) header(Range, "bytes=$skipFirstBytes-")
}
}
@ -167,7 +173,7 @@ public open class HttpManager @JvmOverloads constructor(
): ByteReadChannel {
// TODO check if closed
return mirrorChooser.mirrorRequest(request) { mirror, url ->
getHttpStatement(request, mirror, url, skipFirstBytes).body()
getHttpStatement(request, mirror, url, skipFirstBytes ?: 0L).body()
}
}

View File

@ -4,7 +4,16 @@ import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.engine.HttpClientEngineFactory
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.MockEngineConfig
import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.LookAheadSession
import io.ktor.utils.io.LookAheadSuspendSession
import io.ktor.utils.io.ReadSession
import io.ktor.utils.io.SuspendableReadSession
import io.ktor.utils.io.bits.Memory
import io.ktor.utils.io.core.ByteReadPacket
import io.ktor.utils.io.core.internal.ChunkBuffer
import kotlinx.coroutines.runBlocking
import java.nio.ByteBuffer
import kotlin.random.Random
fun getRandomString(length: Int = Random.nextInt(4, 16)): String {
@ -38,3 +47,59 @@ internal fun getIndexFile(
override fun serialize(): String = error("Not yet implemented")
}
}
@Suppress("OVERRIDE_DEPRECATION", "OverridingDeprecatedMember", "DEPRECATION")
internal abstract class TestByteReadChannel : ByteReadChannel {
override val closedCause: Throwable? get() = error("Not yet implemented")
override val isClosedForRead: Boolean get() = error("Not yet implemented")
override val isClosedForWrite: Boolean get() = error("Not yet implemented")
override val totalBytesRead: Long get() = error("Not yet implemented")
override suspend fun awaitContent() = error("Not yet implemented")
override fun cancel(cause: Throwable?): Boolean = error("Not yet implemented")
override suspend fun discard(max: Long): Long = error("Not yet implemented")
override fun <R> lookAhead(visitor: LookAheadSession.() -> R): R = error("Not yet implemented")
override suspend fun <R> lookAheadSuspend(visitor: suspend LookAheadSuspendSession.() -> R): R =
error("Not yet implemented")
override suspend fun peekTo(
destination: Memory,
destinationOffset: Long,
offset: Long,
min: Long,
max: Long,
): Long = error("Not yet implemented")
override suspend fun read(min: Int, consumer: (ByteBuffer) -> Unit) =
error("Not yet implemented")
override suspend fun readAvailable(dst: ByteBuffer): Int = error("Not yet implemented")
override suspend fun readAvailable(dst: ByteArray, offset: Int, length: Int): Int =
error("Not yet implemented")
override fun readAvailable(min: Int, block: (ByteBuffer) -> Unit): Int =
error("Not yet implemented")
override suspend fun readBoolean(): Boolean = error("Not yet implemented")
override suspend fun readByte(): Byte = error("Not yet implemented")
override suspend fun readDouble(): Double = error("Not yet implemented")
override suspend fun readFloat(): Float = error("Not yet implemented")
override suspend fun readFully(dst: ChunkBuffer, n: Int) = error("Not yet implemented")
override suspend fun readFully(dst: ByteBuffer): Int = error("Not yet implemented")
override suspend fun readFully(dst: ByteArray, offset: Int, length: Int) =
error("Not yet implemented")
override suspend fun readInt(): Int = error("Not yet implemented")
override suspend fun readLong(): Long = error("Not yet implemented")
override suspend fun readPacket(size: Int): ByteReadPacket = error("Not yet implemented")
override suspend fun readRemaining(limit: Long): ByteReadPacket = error("Not yet implemented")
override fun readSession(consumer: ReadSession.() -> Unit) = error("Not yet implemented")
override suspend fun readShort(): Short = error("Not yet implemented")
override suspend fun readSuspendableSession(
consumer: suspend SuspendableReadSession.() -> Unit,
) = error("Not yet implemented")
override suspend fun readUTF8Line(limit: Int): String? = error("Not yet implemented")
override suspend fun <A : Appendable> readUTF8LineTo(out: A, limit: Int): Boolean =
error("Not yet implemented")
}

View File

@ -3,15 +3,18 @@ package org.fdroid.download
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.engine.HttpClientEngineFactory
import io.ktor.client.engine.ProxyBuilder
import io.ktor.client.engine.config
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.MockEngineConfig
import io.ktor.client.engine.mock.respond
import io.ktor.client.engine.mock.respondError
import io.ktor.client.engine.mock.respondOk
import io.ktor.client.engine.mock.respondRedirect
import io.ktor.client.network.sockets.SocketTimeoutException
import io.ktor.client.plugins.ClientRequestException
import io.ktor.client.plugins.RedirectResponseException
import io.ktor.client.plugins.ServerResponseException
import io.ktor.client.request.HttpRequestData
import io.ktor.http.HttpHeaders.Authorization
import io.ktor.http.HttpHeaders.ETag
import io.ktor.http.HttpHeaders.Range
@ -24,6 +27,10 @@ import io.ktor.http.HttpStatusCode.Companion.PartialContent
import io.ktor.http.HttpStatusCode.Companion.TemporaryRedirect
import io.ktor.http.Url
import io.ktor.http.headersOf
import io.ktor.utils.io.core.internal.ChunkBuffer
import io.ktor.utils.io.core.writeFully
import org.fdroid.TestByteReadChannel
import org.fdroid.download.HttpManager.Companion.READ_BUFFER
import org.fdroid.get
import org.fdroid.getRandomString
import org.fdroid.runSuspend
@ -122,13 +129,7 @@ internal class HttpManagerTest {
var requestNum = 1
val mockEngine = MockEngine { request ->
assertNotNull(request.headers[Range])
val (fromStr, endStr) = request.headers[Range]!!
.replace("bytes=", "")
.split('-')
val from =
fromStr.toIntOrNull() ?: fail("No valid content range ${request.headers[Range]}")
assertEquals("", endStr)
val from = request.getByteRangeFrom()
assertEquals(skipBytes, from)
if (requestNum++ == 1) respond(content.copyOfRange(from, content.size), PartialContent)
else respond(content, OK)
@ -146,6 +147,50 @@ internal class HttpManagerTest {
}
}
@Test
fun testResumeDownloadWhenMirrorFailOver() = runSuspend {
val failBytes = READ_BUFFER
val content = Random.nextBytes(failBytes * 2)
val readChannel = object : TestByteReadChannel() {
var wasRead = 0
override val availableForRead: Int = 4096
override suspend fun readAvailable(dst: ChunkBuffer): Int {
// We allow three reads. Only the first two give us the first half of content.
// While the third seems to be required, it isn't filling the buffer
// before we throw the exception, so it isn't considered.
if (wasRead == 3) throw SocketTimeoutException("boom!")
dst.writeFully(content, wasRead * 4096, 4096)
wasRead++
return 4096
}
}
val mockEngine = MockEngine.config {
reuseHandlers = false
addHandler {
respond(readChannel, OK)
}
addHandler { request ->
val from = request.getByteRangeFrom()
assertEquals(failBytes, from)
respond(content.copyOfRange(from, content.size), PartialContent)
}
}
val httpManager = HttpManager(userAgent, null, httpClientEngineFactory = mockEngine)
var chunk = 0
httpManager.get(downloadRequest) { bytes, _ ->
// we expect two chunks: 0 and 1
// the first is the first half of content and the second is the second half
val offset = chunk * READ_BUFFER
val expectedBytes = content.copyOfRange(offset, offset + READ_BUFFER)
assertContentEquals(expectedBytes, bytes)
chunk++
}
assertEquals(2, chunk)
}
@Test
fun testMirrorFallback() = runSuspend {
val mockEngine = MockEngine { respondError(InternalServerError) }
@ -315,4 +360,12 @@ internal class HttpManagerTest {
assertEquals(2, numEngines)
}
private fun HttpRequestData.getByteRangeFrom(): Int {
val (fromStr, endStr) = (headers[Range] ?: fail("No Range header"))
.replace("bytes=", "")
.split('-')
assertEquals("", endStr)
return fromStr.toIntOrNull() ?: fail("No valid content range ${headers[Range]}")
}
}