[download] Fix corrupted download after mirror fail-over bug
When we had already read bytes from one mirror, but needed to fail-over to another mirror, we would re-read the same bytes again corrupting our download. Fixes #2412 This commit resumes the download where the last mirror left off.
This commit is contained in:
parent
f6d1637d92
commit
c8514adb94
|
@ -49,8 +49,9 @@ public open class HttpManager @JvmOverloads constructor(
|
|||
private val httpClientEngineFactory: HttpClientEngineFactory<*> = getHttpClientEngineFactory(),
|
||||
) {
|
||||
|
||||
private companion object {
|
||||
internal companion object {
|
||||
val log = KotlinLogging.logger {}
|
||||
const val READ_BUFFER = 8 * 1024
|
||||
}
|
||||
|
||||
private var httpClient = getNewHttpClient(proxyConfig)
|
||||
|
@ -122,18 +123,23 @@ public open class HttpManager @JvmOverloads constructor(
|
|||
request: DownloadRequest,
|
||||
skipFirstBytes: Long? = null,
|
||||
receiver: BytesReceiver,
|
||||
): Unit = mirrorChooser.mirrorRequest(request) { mirror, url ->
|
||||
getHttpStatement(request, mirror, url, skipFirstBytes).execute { response ->
|
||||
val contentLength = response.contentLength()
|
||||
if (skipFirstBytes != null && response.status != PartialContent) {
|
||||
throw NoResumeException()
|
||||
}
|
||||
val channel: ByteReadChannel = response.body()
|
||||
val limit = 8L * 1024L
|
||||
while (!channel.isClosedForRead) {
|
||||
val packet = channel.readRemaining(limit)
|
||||
while (!packet.isEmpty) {
|
||||
receiver.receive(packet.readBytes(), contentLength)
|
||||
) {
|
||||
// remember what we've read already, so we can pass it to the next mirror if needed
|
||||
var skipBytes = skipFirstBytes ?: 0L
|
||||
mirrorChooser.mirrorRequest(request) { mirror, url ->
|
||||
getHttpStatement(request, mirror, url, skipBytes).execute { response ->
|
||||
val contentLength = response.contentLength()
|
||||
if (skipBytes > 0L && response.status != PartialContent) {
|
||||
throw NoResumeException()
|
||||
}
|
||||
val channel: ByteReadChannel = response.body()
|
||||
while (!channel.isClosedForRead) {
|
||||
val packet = channel.readRemaining(READ_BUFFER.toLong())
|
||||
while (!packet.isEmpty) {
|
||||
val readBytes = packet.readBytes()
|
||||
skipBytes += readBytes.size
|
||||
receiver.receive(readBytes, contentLength)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -143,7 +149,7 @@ public open class HttpManager @JvmOverloads constructor(
|
|||
request: DownloadRequest,
|
||||
mirror: Mirror,
|
||||
url: Url,
|
||||
skipFirstBytes: Long? = null,
|
||||
skipFirstBytes: Long,
|
||||
): HttpStatement {
|
||||
resetProxyIfNeeded(request.proxy, mirror)
|
||||
log.info { "GET $url" }
|
||||
|
@ -154,7 +160,7 @@ public open class HttpManager @JvmOverloads constructor(
|
|||
// increase connect timeout if using Tor mirror
|
||||
if (mirror.isOnion()) timeout { connectTimeoutMillis = 20_000 }
|
||||
// add range header if set
|
||||
if (skipFirstBytes != null) header(Range, "bytes=$skipFirstBytes-")
|
||||
if (skipFirstBytes > 0) header(Range, "bytes=$skipFirstBytes-")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -167,7 +173,7 @@ public open class HttpManager @JvmOverloads constructor(
|
|||
): ByteReadChannel {
|
||||
// TODO check if closed
|
||||
return mirrorChooser.mirrorRequest(request) { mirror, url ->
|
||||
getHttpStatement(request, mirror, url, skipFirstBytes).body()
|
||||
getHttpStatement(request, mirror, url, skipFirstBytes ?: 0L).body()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,16 @@ import io.ktor.client.engine.HttpClientEngine
|
|||
import io.ktor.client.engine.HttpClientEngineFactory
|
||||
import io.ktor.client.engine.mock.MockEngine
|
||||
import io.ktor.client.engine.mock.MockEngineConfig
|
||||
import io.ktor.utils.io.ByteReadChannel
|
||||
import io.ktor.utils.io.LookAheadSession
|
||||
import io.ktor.utils.io.LookAheadSuspendSession
|
||||
import io.ktor.utils.io.ReadSession
|
||||
import io.ktor.utils.io.SuspendableReadSession
|
||||
import io.ktor.utils.io.bits.Memory
|
||||
import io.ktor.utils.io.core.ByteReadPacket
|
||||
import io.ktor.utils.io.core.internal.ChunkBuffer
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import java.nio.ByteBuffer
|
||||
import kotlin.random.Random
|
||||
|
||||
fun getRandomString(length: Int = Random.nextInt(4, 16)): String {
|
||||
|
@ -38,3 +47,59 @@ internal fun getIndexFile(
|
|||
override fun serialize(): String = error("Not yet implemented")
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("OVERRIDE_DEPRECATION", "OverridingDeprecatedMember", "DEPRECATION")
|
||||
internal abstract class TestByteReadChannel : ByteReadChannel {
|
||||
override val closedCause: Throwable? get() = error("Not yet implemented")
|
||||
override val isClosedForRead: Boolean get() = error("Not yet implemented")
|
||||
override val isClosedForWrite: Boolean get() = error("Not yet implemented")
|
||||
override val totalBytesRead: Long get() = error("Not yet implemented")
|
||||
|
||||
override suspend fun awaitContent() = error("Not yet implemented")
|
||||
override fun cancel(cause: Throwable?): Boolean = error("Not yet implemented")
|
||||
override suspend fun discard(max: Long): Long = error("Not yet implemented")
|
||||
override fun <R> lookAhead(visitor: LookAheadSession.() -> R): R = error("Not yet implemented")
|
||||
override suspend fun <R> lookAheadSuspend(visitor: suspend LookAheadSuspendSession.() -> R): R =
|
||||
error("Not yet implemented")
|
||||
|
||||
override suspend fun peekTo(
|
||||
destination: Memory,
|
||||
destinationOffset: Long,
|
||||
offset: Long,
|
||||
min: Long,
|
||||
max: Long,
|
||||
): Long = error("Not yet implemented")
|
||||
|
||||
override suspend fun read(min: Int, consumer: (ByteBuffer) -> Unit) =
|
||||
error("Not yet implemented")
|
||||
|
||||
override suspend fun readAvailable(dst: ByteBuffer): Int = error("Not yet implemented")
|
||||
override suspend fun readAvailable(dst: ByteArray, offset: Int, length: Int): Int =
|
||||
error("Not yet implemented")
|
||||
|
||||
override fun readAvailable(min: Int, block: (ByteBuffer) -> Unit): Int =
|
||||
error("Not yet implemented")
|
||||
|
||||
override suspend fun readBoolean(): Boolean = error("Not yet implemented")
|
||||
override suspend fun readByte(): Byte = error("Not yet implemented")
|
||||
override suspend fun readDouble(): Double = error("Not yet implemented")
|
||||
override suspend fun readFloat(): Float = error("Not yet implemented")
|
||||
override suspend fun readFully(dst: ChunkBuffer, n: Int) = error("Not yet implemented")
|
||||
override suspend fun readFully(dst: ByteBuffer): Int = error("Not yet implemented")
|
||||
override suspend fun readFully(dst: ByteArray, offset: Int, length: Int) =
|
||||
error("Not yet implemented")
|
||||
|
||||
override suspend fun readInt(): Int = error("Not yet implemented")
|
||||
override suspend fun readLong(): Long = error("Not yet implemented")
|
||||
override suspend fun readPacket(size: Int): ByteReadPacket = error("Not yet implemented")
|
||||
override suspend fun readRemaining(limit: Long): ByteReadPacket = error("Not yet implemented")
|
||||
override fun readSession(consumer: ReadSession.() -> Unit) = error("Not yet implemented")
|
||||
override suspend fun readShort(): Short = error("Not yet implemented")
|
||||
override suspend fun readSuspendableSession(
|
||||
consumer: suspend SuspendableReadSession.() -> Unit,
|
||||
) = error("Not yet implemented")
|
||||
|
||||
override suspend fun readUTF8Line(limit: Int): String? = error("Not yet implemented")
|
||||
override suspend fun <A : Appendable> readUTF8LineTo(out: A, limit: Int): Boolean =
|
||||
error("Not yet implemented")
|
||||
}
|
||||
|
|
|
@ -3,15 +3,18 @@ package org.fdroid.download
|
|||
import io.ktor.client.engine.HttpClientEngine
|
||||
import io.ktor.client.engine.HttpClientEngineFactory
|
||||
import io.ktor.client.engine.ProxyBuilder
|
||||
import io.ktor.client.engine.config
|
||||
import io.ktor.client.engine.mock.MockEngine
|
||||
import io.ktor.client.engine.mock.MockEngineConfig
|
||||
import io.ktor.client.engine.mock.respond
|
||||
import io.ktor.client.engine.mock.respondError
|
||||
import io.ktor.client.engine.mock.respondOk
|
||||
import io.ktor.client.engine.mock.respondRedirect
|
||||
import io.ktor.client.network.sockets.SocketTimeoutException
|
||||
import io.ktor.client.plugins.ClientRequestException
|
||||
import io.ktor.client.plugins.RedirectResponseException
|
||||
import io.ktor.client.plugins.ServerResponseException
|
||||
import io.ktor.client.request.HttpRequestData
|
||||
import io.ktor.http.HttpHeaders.Authorization
|
||||
import io.ktor.http.HttpHeaders.ETag
|
||||
import io.ktor.http.HttpHeaders.Range
|
||||
|
@ -24,6 +27,10 @@ import io.ktor.http.HttpStatusCode.Companion.PartialContent
|
|||
import io.ktor.http.HttpStatusCode.Companion.TemporaryRedirect
|
||||
import io.ktor.http.Url
|
||||
import io.ktor.http.headersOf
|
||||
import io.ktor.utils.io.core.internal.ChunkBuffer
|
||||
import io.ktor.utils.io.core.writeFully
|
||||
import org.fdroid.TestByteReadChannel
|
||||
import org.fdroid.download.HttpManager.Companion.READ_BUFFER
|
||||
import org.fdroid.get
|
||||
import org.fdroid.getRandomString
|
||||
import org.fdroid.runSuspend
|
||||
|
@ -122,13 +129,7 @@ internal class HttpManagerTest {
|
|||
|
||||
var requestNum = 1
|
||||
val mockEngine = MockEngine { request ->
|
||||
assertNotNull(request.headers[Range])
|
||||
val (fromStr, endStr) = request.headers[Range]!!
|
||||
.replace("bytes=", "")
|
||||
.split('-')
|
||||
val from =
|
||||
fromStr.toIntOrNull() ?: fail("No valid content range ${request.headers[Range]}")
|
||||
assertEquals("", endStr)
|
||||
val from = request.getByteRangeFrom()
|
||||
assertEquals(skipBytes, from)
|
||||
if (requestNum++ == 1) respond(content.copyOfRange(from, content.size), PartialContent)
|
||||
else respond(content, OK)
|
||||
|
@ -146,6 +147,50 @@ internal class HttpManagerTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testResumeDownloadWhenMirrorFailOver() = runSuspend {
|
||||
val failBytes = READ_BUFFER
|
||||
val content = Random.nextBytes(failBytes * 2)
|
||||
|
||||
val readChannel = object : TestByteReadChannel() {
|
||||
var wasRead = 0
|
||||
override val availableForRead: Int = 4096
|
||||
override suspend fun readAvailable(dst: ChunkBuffer): Int {
|
||||
// We allow three reads. Only the first two give us the first half of content.
|
||||
// While the third seems to be required, it isn't filling the buffer
|
||||
// before we throw the exception, so it isn't considered.
|
||||
if (wasRead == 3) throw SocketTimeoutException("boom!")
|
||||
dst.writeFully(content, wasRead * 4096, 4096)
|
||||
wasRead++
|
||||
return 4096
|
||||
}
|
||||
}
|
||||
|
||||
val mockEngine = MockEngine.config {
|
||||
reuseHandlers = false
|
||||
addHandler {
|
||||
respond(readChannel, OK)
|
||||
}
|
||||
addHandler { request ->
|
||||
val from = request.getByteRangeFrom()
|
||||
assertEquals(failBytes, from)
|
||||
respond(content.copyOfRange(from, content.size), PartialContent)
|
||||
}
|
||||
}
|
||||
val httpManager = HttpManager(userAgent, null, httpClientEngineFactory = mockEngine)
|
||||
|
||||
var chunk = 0
|
||||
httpManager.get(downloadRequest) { bytes, _ ->
|
||||
// we expect two chunks: 0 and 1
|
||||
// the first is the first half of content and the second is the second half
|
||||
val offset = chunk * READ_BUFFER
|
||||
val expectedBytes = content.copyOfRange(offset, offset + READ_BUFFER)
|
||||
assertContentEquals(expectedBytes, bytes)
|
||||
chunk++
|
||||
}
|
||||
assertEquals(2, chunk)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMirrorFallback() = runSuspend {
|
||||
val mockEngine = MockEngine { respondError(InternalServerError) }
|
||||
|
@ -315,4 +360,12 @@ internal class HttpManagerTest {
|
|||
assertEquals(2, numEngines)
|
||||
}
|
||||
|
||||
private fun HttpRequestData.getByteRangeFrom(): Int {
|
||||
val (fromStr, endStr) = (headers[Range] ?: fail("No Range header"))
|
||||
.replace("bytes=", "")
|
||||
.split('-')
|
||||
assertEquals("", endStr)
|
||||
return fromStr.toIntOrNull() ?: fail("No valid content range ${headers[Range]}")
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue