package studio.goodegg.capsule.downloader

import io.ktor.client.HttpClient
import io.ktor.client.plugins.cache.HttpCache
import io.ktor.client.plugins.cache.storage.CacheStorage
import io.ktor.client.plugins.timeout
import io.ktor.client.request.prepareGet
import io.ktor.client.statement.bodyAsChannel
import io.ktor.http.contentLength
import io.ktor.util.collections.ConcurrentMap
import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.readRemaining
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
import kotlinx.io.buffered
import kotlinx.io.files.Path
import kotlinx.io.files.SystemFileSystem
import kotlinx.io.readByteArray
import me.tatarka.inject.annotations.Inject
import kotlin.coroutines.cancellation.CancellationException
import kotlin.math.ceil
import kotlin.time.Duration.Companion.minutes

/**
 * Downloads remote files to local storage and reports progress while doing so.
 */
interface Downloader {

    /**
     * Downloads [url] into a file named [fileName] inside [directoryPath].
     *
     * Progress is reported through the returned [Flow] as [DownloadProgress] values. [downloadTag]
     * identifies this download so that it can later be stopped via [cancel].
     */
    fun download(downloadTag: String, directoryPath: String, fileName: String, url: String): Flow<DownloadProgress>

    /** Asks the download registered under [downloadTag] to stop. */
    fun cancel(downloadTag: String)
}

/**
 * Ktor-backed [Downloader] that streams a URL into a local file.
 *
 * An existing partial file is resumed with an HTTP `Range` request, failed attempts are retried,
 * and running downloads can be stopped with [cancel]. Requests go through a copy of the injected
 * [HttpClient] with both public and private HTTP cache storage disabled.
 */
@Inject
class DownloaderImpl(httpClient: HttpClient) : Downloader {

    /**
     * Downloads currently running, keyed by tag: `true` while running, `false` once [cancel] has been
     * requested and the download should stop at its next check.
     */
    private val inProgressMap = ConcurrentMap<String, Boolean>()

    // Copy of the injected client with HTTP caching turned off for download requests.
    private val downloadsClient = httpClient.config {
        install(HttpCache) {
            this.publicStorage(CacheStorage.Disabled)
            this.privateStorage(CacheStorage.Disabled)
        }
    }

    /**
     * Downloads [url] into [fileName] inside [directoryPath], emitting [DownloadProgress] as bytes
     * arrive. Work starts when the returned flow is collected.
     *
     * If the target file already exists, its size is used as a resume offset and a `Range` request
     * is sent. A `206 Partial Content` response is appended to the file. Any other 2xx response
     * carries the whole body, so the file is rewritten from the beginning. A non-2xx status counts
     * as a failed attempt, and nothing is written for it.
     *
     * A failed attempt is retried after [RETRY_DELAY_MILLIS], for at most [MAX_ATTEMPTS] attempts.
     * After the last failure the file is deleted and the error is rethrown. [cancel], or cancellation
     * of the collector, stops the download immediately with a [CancellationException] and no retry.
     * In that case the partial file is kept so that a later call can resume it.
     */
    override fun download(
        downloadTag: String,
        directoryPath: String,
        fileName: String,
        url: String,
    ): Flow<DownloadProgress> = channelFlow {
        SystemFileSystem.createDirectories(Path(directoryPath))

        // Path(base, part) inserts a separator only when directoryPath does not already end with one.
        val filePath = Path(directoryPath, fileName)
        val filePathString = filePath.toString()

        // Bytes already on disk from a previous run; the download resumes from this offset.
        var totalBytesRead = SystemFileSystem.metadataOrNull(filePath)?.size ?: 0L

        // Register before the first request so that a cancel() issued while connecting is honoured.
        inProgressMap[downloadTag] = true
        try {
            send(DownloadProgress(1, 100, filePathString)) // Initial dummy progress

            var attempt = 0
            while (true) {
                try {
                    throwIfCancelled(downloadTag)
                    downloadsClient
                        .prepareGet(url) {
                            timeout {
                                connectTimeoutMillis = TIMEOUT_MILLIS
                                requestTimeoutMillis = TIMEOUT_MILLIS
                                socketTimeoutMillis = TIMEOUT_MILLIS
                            }
                            if (totalBytesRead > 0) {
                                headers.append("Range", "bytes=$totalBytesRead-") // Resume support
                            }
                        }
                        .execute { httpResponse ->
                            val status = httpResponse.status
                            // Fail before touching the file so an error body never ends up in it.
                            if (status.value !in 200..299) {
                                throw IllegalStateException("Download $downloadTag failed: HTTP $status")
                            }

                            // Only 206 Partial Content means the Range header was honoured. Any other
                            // 2xx carries the whole file, so restart from byte 0 and overwrite.
                            val resuming = status.value == 206
                            if (!resuming) totalBytesRead = 0L

                            // For a 206 response, Content-Length counts only the remaining bytes.
                            val resumeOffset = totalBytesRead
                            val contentLength = httpResponse.contentLength()?.plus(resumeOffset) ?: -1L
                            val channel: ByteReadChannel = httpResponse.bodyAsChannel()

                            try {
                                // A fresh sink per attempt: `use` closes it, so it must not be shared
                                // across retries.
                                SystemFileSystem.sink(filePath, append = resuming).buffered().use { sink ->
                                    while (!channel.isClosedForRead) {
                                        throwIfCancelled(downloadTag)
                                        val bytes = channel.readRemaining(CHUNK_SIZE).readByteArray()
                                        if (bytes.isNotEmpty()) {
                                            sink.write(bytes)
                                            totalBytesRead += bytes.size
                                            // Best effort: dropped if the channel buffer is full.
                                            trySend(DownloadProgress(totalBytesRead, contentLength, filePathString))
                                        }
                                    }
                                }
                            } finally {
                                channel.cancel(null)
                            }

                            // Emitted after the sink is closed, i.e. once all bytes have been flushed.
                            send(DownloadProgress(totalBytesRead, contentLength, filePathString))
                        }
                    return@channelFlow
                } catch (ex: CancellationException) {
                    // cancel() or collector cancellation: never retry, keep the partial file for resuming.
                    throw ex
                } catch (ex: Exception) {
                    attempt++
                    if (attempt >= MAX_ATTEMPTS) {
                        SystemFileSystem.delete(filePath, mustExist = false) // Final cleanup
                        throw ex
                    }
                    delay(RETRY_DELAY_MILLIS)
                }
            }
        } finally {
            inProgressMap.remove(downloadTag)
        }
    }

    /**
     * Requests cancellation of the running download tagged [downloadTag].
     *
     * The download notices the request before its next attempt or its next chunk read, and its flow
     * then fails with a [CancellationException]. Tags that are not currently running are ignored.
     */
    override fun cancel(downloadTag: String) {
        if (inProgressMap[downloadTag] == true) {
            inProgressMap[downloadTag] = false
        }
    }

    /** Throws [CancellationException] if [cancel] has been requested for [downloadTag]. */
    private fun throwIfCancelled(downloadTag: String) {
        if (inProgressMap[downloadTag] == false) {
            throw CancellationException("Download $downloadTag cancelled")
        }
    }

    private companion object {
        /** Total number of request attempts before the download is abandoned. */
        const val MAX_ATTEMPTS = 3

        /** Pause between a failed attempt and the next one. */
        const val RETRY_DELAY_MILLIS = 2_000L

        /** Maximum number of bytes read from the response body per loop iteration. */
        const val CHUNK_SIZE = 4096L

        /** Connect, request and socket timeout applied to every download request. */
        val TIMEOUT_MILLIS = 5.minutes.inWholeMilliseconds
    }
}

/**
 * Snapshot of a download's progress.
 *
 * @property bytesRead number of bytes of the file received so far.
 * @property totalBytes expected final size in bytes; a value `<= 0` (e.g. `-1`) means unknown.
 * @property filePath path of the file being written.
 */
data class DownloadProgress(
    val bytesRead: Long,
    val totalBytes: Long,
    val filePath: String,
) {
    /**
     * Completion as a whole percentage, rounded up and clamped to `0..100`.
     * It is `0` when [totalBytes] is unknown (`<= 0`) rather than a negative or overflowed value.
     */
    val percentage: Int =
        if (totalBytes > 0) {
            // Multiply before dividing so exact percentages are not pushed over an integer by rounding error.
            ceil(bytesRead * 100.0 / totalBytes).toInt().coerceIn(0, 100)
        } else {
            0
        }
}
