import { BlockId } from '@shapeci/types'
import { getShapeClassName, LIGHT_THEME } from '@shapeci/ui'
import { sleep } from '@shapeci/utils'
import { createRoot, Root } from 'react-dom/client'
import { ThemeProvider } from 'styled-components'

import {
    GHOST_CONTAINER_ID,
    THREAD_GAP,
    THREAD_TRANSITION_DELAY_MS,
    THREAD_WIDTH,
} from './constants'
import { ThreadGhost } from './Thread'
import { AnyThread, AnyThreadWithHeight, AnyThreadWithPosition, NewThread } from './types'

/**
 * Resolves the element whose position is used to place a block's threads.
 *
 * In preview mode the block element itself is returned unchanged (including `null`/`undefined`).
 * Otherwise the block element's grandparent is returned; if the element or its parent is missing
 * the result is `undefined`.
 *
 * @param element - the DOM element of the referenced block
 * @param isDocumentInPreviewMode - whether the document is rendered in preview mode
 */
export const getContainerElementFromBlockElement = (
    element: HTMLElement | null | undefined,
    isDocumentInPreviewMode: boolean
) => {
    if (isDocumentInPreviewMode) {
        return element
    }

    const parent = element?.parentElement
    return parent?.parentElement
}

/**
 * Returns the vertical offset (px) of a block's container relative to the top of the editor element.
 *
 * The container is resolved with getContainerElementFromBlockElement; when it cannot be resolved
 * the offset is 0. The editor element is looked up by the id `getShapeClassName('editor')`; when
 * it is missing, the offset is the container's viewport-relative top.
 *
 * @param element - the DOM element of the referenced block
 * @param isDocumentInPreviewMode - whether the document is rendered in preview mode
 */
export const getThreadOffsetTop = (element: HTMLElement, isDocumentInPreviewMode: boolean) => {
    const editorElement = document.getElementById(getShapeClassName('editor'))
    const blockContainer = getContainerElementFromBlockElement(element, isDocumentInPreviewMode)

    if (blockContainer == null) {
        return 0
    }

    const containerTop = blockContainer.getBoundingClientRect().top
    const editorTop = (editorElement && editorElement.getBoundingClientRect().top) || 0

    return containerTop - editorTop
}

/**
 * Creates a NewThread attached to `referencedBlock`.
 *
 * The thread always has the fixed id `'NEW_THREAD'` and `isNew: true` (the flag isNewThread
 * checks), and its `meta.dateCreated` is stamped with the current time.
 *
 * @param referencedBlock - id of the block the thread refers to
 */
export function createNewThread(referencedBlock: BlockId): NewThread {
    const dateCreated = new Date()

    const thread: NewThread = {
        id: 'NEW_THREAD',
        isNew: true,
        referencedBlock,
        meta: { dateCreated },
    }

    return thread
}

/**
 * Type guard: true when the thread carries a truthy `isNew` flag (as set by createNewThread).
 */
export const isNewThread = (t: AnyThread): t is NewThread => {
    const { isNew } = t as NewThread
    return Boolean(isNew)
}

/**
 * Runs an async operation, retrying it with exponential backoff when it rejects.
 *
 * `fn` is attempted once, then retried up to `maxRetries` more times. Before the first retry the
 * function sleeps `delay` ms, and the wait doubles before each subsequent retry. When retries are
 * exhausted, the error from the last attempt is rethrown.
 *
 * @param fn - the operation to attempt
 * @param maxRetries - retries allowed after the first attempt; values <= 0 mean no retries
 *                     (a negative value previously recursed forever)
 * @param delay - wait in ms before the first retry
 * @returns the value of the first successful attempt
 */
async function withBackoff<T>(
    fn: () => Promise<T>,
    maxRetries = 5,
    delay = THREAD_TRANSITION_DELAY_MS
): Promise<T> {
    let retriesLeft = maxRetries
    let wait = delay

    for (;;) {
        try {
            // eslint-disable-next-line no-await-in-loop
            return await fn()
        } catch (err) {
            if (retriesLeft <= 0) {
                throw err
            }
        }

        retriesLeft -= 1
        // eslint-disable-next-line no-await-in-loop
        await sleep(wait)
        wait *= 2
    }
}

/** Options accepted by the ThreadPositionsManager constructor. */
interface ThreadPositionManagerOptions {
    /**
     * Whether the document is rendered in preview mode. It is passed to getThreadOffsetTop and
     * changes which element is used to measure a referenced block's offset.
     */
    isDocumentInPreviewMode: boolean

    /** Makes the ghost threads visible for easier testing. */
    _debug?: boolean
}

/**
 * Lays out comment threads beside the blocks they reference.
 *
 * Thread heights are measured by rendering "ghost" copies of the threads into a hidden container
 * that has its own React root. The measured heights are then used to place every thread as close
 * to its referenced block as possible without overlapping the previously placed thread.
 */
export class ThreadPositionsManager {
    root: Root

    // bumped on every ghost render so each render gets a fresh container id
    containerIteration = 0

    // cache of height by content of thread and open state (see getThreadKey)
    threadContentToHeight: Record<string, number>

    isDocumentInPreviewMode: boolean

    // hidden element hosting `root`; kept by reference so cleanup() removes exactly this node
    private readonly container: HTMLDivElement

    // set once cleanup() has unmounted the root, so late deferred resets become no-ops
    private isDisposed = false

    /**
     * Appends a hidden, absolutely positioned container to `document.body` and mounts a separate
     * React root into it for rendering and measuring ghost threads.
     *
     * @param {ThreadPositionManagerOptions} options - preview-mode flag and optional debug flag
     */
    constructor(options: ThreadPositionManagerOptions) {
        const container = document.createElement('div')
        // tag the element so it can be located by GHOST_CONTAINER_ID
        container.id = GHOST_CONTAINER_ID
        Object.assign(container.style, {
            position: 'absolute',
            top: '0',
            left: '0',
            width: THREAD_WIDTH,
            height: '100%',
            overflow: 'hidden',
            pointerEvents: 'none',
            display: 'inline-block',
            visibility: 'hidden',
            zIndex: '-1',
        })

        // eslint-disable-next-line no-underscore-dangle
        if (options?._debug) {
            Object.assign(container.style, {
                visibility: 'visible',
                zIndex: '1000',
            })
        }

        document.body.appendChild(container)

        this.container = container
        this.root = createRoot(container)
        this.threadContentToHeight = {}
        this.isDocumentInPreviewMode = options.isDocumentInPreviewMode
    }

    private get threadContainerId() {
        return `shape__ghost-thread-container-${this.containerIteration}`
    }

    private get threadContainer() {
        return document.getElementById(this.threadContainerId)
    }

    /**
     * Builds the cache key under which a thread's measured height is stored.
     *
     * New threads share the single key `'new'`. Existing threads are keyed by the contents of
     * their comments plus their open state. The contents are JSON-encoded so comment boundaries
     * are unambiguous, e.g. `['a|b']` and `['a', 'b']` get different keys.
     *
     * NOTE(review): `editingCommentId` is passed to ThreadGhost but is not part of this key, so a
     * cached height is reused when a comment enters or leaves edit mode. Confirm that editing
     * does not change a thread's height.
     * NOTE(review): assumes `comment.content` is JSON-serializable (e.g. a string).
     */
    private static getThreadKey(thread: AnyThread, isOpen: boolean) {
        if (isNewThread(thread)) {
            return 'new'
        }

        const contents = thread.comments.map((comment) => comment.content)
        return JSON.stringify([contents, isOpen ? 'open' : 'closed'])
    }

    /**
     * Renders un-memoized ghost threads in the hidden container so their height can be measured.
     * Threads whose height is already cached are not rendered.
     *
     * @param {AnyThread[]} threads - threads to measure
     * @param {string | null} openThreadId - id of the thread that is open
     * @param {string | null} editingCommentId - id of the comment that is being edited
     * @returns {number} - the container iteration this render was issued under
     */
    private renderThreads(
        threads: AnyThread[],
        openThreadId: string | null,
        editingCommentId: string | null
    ): number {
        const threadElements = threads
            // skip threads whose height is already known; this uses the same `in` check as
            // getThreadHeight, so a cached height of 0 also counts as known
            .filter(
                (thread) =>
                    !(
                        ThreadPositionsManager.getThreadKey(thread, thread.id === openThreadId) in
                        this.threadContentToHeight
                    )
            )
            .map((thread) => (
                <ThreadGhost
                    key={thread.id}
                    thread={thread}
                    isOpen={thread.id === openThreadId}
                    editingCommentId={editingCommentId}
                />
            ))

        // move to a new container id; getThreadHeight looks threads up under this id, so it only
        // finds them once this render has been committed
        this.containerIteration += 1

        this.root.render(
            <div id={this.threadContainerId} style={{ width: THREAD_WIDTH }}>
                <ThemeProvider theme={LIGHT_THEME}>{threadElements}</ThemeProvider>
            </div>
        )

        return this.containerIteration
    }

    /**
     * Measures the height of a ghost thread in the hidden container, caching the result.
     *
     * @param {AnyThread} thread - thread to measure
     * @param {boolean} isOpen - if the thread is open or not
     * @returns {Promise<number>} - the thread's `scrollHeight` in px
     * @throws {Error} if the ghost thread is not (yet) present in the current container
     */
    private async getThreadHeight(thread: AnyThread, isOpen: boolean) {
        const key = ThreadPositionsManager.getThreadKey(thread, isOpen)

        if (key in this.threadContentToHeight) {
            // if we already have the height, return it
            return this.threadContentToHeight[key]
        }

        const maybeThread = this.threadContainer?.querySelector(
            `[data-ghost-thread-id="${thread.id}"]`
        )

        if (!maybeThread) {
            // generally this happens when the thread is not rendered due to a race condition
            // between the thread being rendered and the thread being measured
            throw new Error(`Could not find thread with id ${thread.id}`)
        }

        const height = maybeThread.scrollHeight

        // cache the height
        this.threadContentToHeight[key] = height

        return height
    }

    /**
     * Gets pixel heights of ghost threads as an array, retrying each measurement with backoff.
     *
     * @param {AnyThread[]} threads - threads to measure
     * @param {string | null} openThreadId - id of the thread that is open
     * @returns {Promise<number[]>} - heights, in the same order as `threads`
     */
    private getThreadHeights = async (
        threads: AnyThread[],
        openThreadId: string | null
    ): Promise<number[]> =>
        Promise.all(
            threads.map((thread) =>
                withBackoff<number>(() => this.getThreadHeight(thread, thread.id === openThreadId))
            )
        )

    /**
     * Gets threads with their pixel heights as the `height` property.
     *
     * @param {AnyThread[]} threads - threads to measure
     * @param {string | null} openThreadId - id of the thread that is open
     * @returns {Promise<AnyThreadWithHeight[]>} - copies of `threads` with `height` added
     */
    private getThreadsWithHeights = async (
        threads: AnyThread[],
        openThreadId: string | null
    ): Promise<AnyThreadWithHeight[]> => {
        const heights = await this.getThreadHeights(threads, openThreadId)
        return threads.map((thread, index) => ({
            ...thread,
            height: heights[index],
        }))
    }

    /**
     * Layout engine: positions threads relative to each other and to the blocks they reference
     * while avoiding overlaps.
     *
     * Threads are placed in the order given. Each one goes at the top of its referenced block, or
     * directly below the previously placed thread (plus THREAD_GAP) if that block is higher.
     * NOTE(review): this assumes `threads` is ordered top-to-bottom by referenced block; confirm
     * that callers sort.
     * Threads whose referenced block is not in the DOM are logged and keep an undefined `position`.
     * If a thread is open, every positioned thread is then shifted by the same amount so the open
     * thread sits exactly at its block.
     *
     * @param {AnyThread[]} threads - threads to position
     * @param {string | null} openThreadId - id of the thread that is open
     * @returns {Promise<AnyThreadWithPosition[]>} - copies of `threads` with `position` (px offset from the editor top)
     */
    private async getThreadsWithPositions(
        threads: AnyThread[],
        openThreadId: string | null
    ): Promise<AnyThreadWithPosition[]> {
        const threadToPosition: Record<string, number> = {}

        // the lowest point on the page that a thread has been positioned so far
        let lastScan = 0

        const threadsWithHeights: AnyThreadWithHeight[] = await this.getThreadsWithHeights(
            threads,
            openThreadId
        )

        threadsWithHeights.forEach((thread) => {
            const referenced = document.getElementById(thread.referencedBlock)

            if (!referenced) {
                console.error(
                    `Could not find referenced block ${thread.referencedBlock} for thread ${thread.id}`
                )
                return
            }

            // the top of the referenced block (ideal position for the thread)
            const blockPosition = getThreadOffsetTop(referenced, this.isDocumentInPreviewMode)

            // collision detection: if the ideal spot is above the bottom of the previously placed
            // thread, place this thread directly below it instead so threads don't overlap
            const position = blockPosition < lastScan ? lastScan : blockPosition

            threadToPosition[thread.id] = position
            lastScan = position + thread.height + THREAD_GAP
        })

        const threadsWithPositions = threads.map((thread) => ({
            ...thread,
            position: threadToPosition[thread.id],
        }))

        /**
         * TODO: Opening a specific thread without unnecessary re-positioning
         *
         * Currently all threads are shifted up to make room for the current open thread by
         * the same amount `delta`. Some threads above the open thread might not need to be
         * shifted at all.
         *
         * The solution is to rewrite this function to order and traverse the threads
         * in two lists bi-directionally. One list for threads above the open thread and
         * one for threads below the open thread.
         *
         * Ask @sirajchokshi for more details if you're interested in working on this.
         */
        if (!openThreadId) {
            return threadsWithPositions
        }

        const openThread = threads.find((thread) => thread.id === openThreadId)

        if (!openThread) {
            // the open thread is not among `threads`, so there is nothing to anchor to
            return threadsWithPositions
        }

        const referenced = document.getElementById(openThread.referencedBlock)

        if (!referenced) {
            console.error(
                `Could not find referenced block ${openThread.referencedBlock} for open thread ${openThread.id}`
            )
            return threadsWithPositions
        }

        const blockPosition = getThreadOffsetTop(referenced, this.isDocumentInPreviewMode)

        // the open thread is guaranteed its ideal position; layout only ever pushes threads
        // below their ideal spot, so delta is normally <= 0 (a shift up)
        const delta = blockPosition - threadToPosition[openThreadId]

        if (!delta) {
            return threadsWithPositions
        }

        // shift all positioned threads by the same amount to make room for the open thread;
        // threads without a position keep it undefined instead of becoming NaN
        return threadsWithPositions.map((thread) =>
            thread.position === undefined ? thread : { ...thread, position: thread.position + delta }
        )
    }

    /**
     * Clears all ghost threads rendered in the container. No-op once cleanup() has run, since
     * the root is unmounted by then.
     */
    private reset() {
        if (this.isDisposed) {
            return
        }

        this.root.render(<></>)
    }

    /**
     * Unmounts the ghost React root and removes its container element from the DOM.
     * Used in cleanup for any component that creates a ThreadPositionsManager. Calling it more
     * than once is harmless.
     */
    public async cleanup() {
        // 1 second timeout to allow for all parent unmounting to complete
        await sleep(1000)

        if (this.isDisposed) {
            return
        }
        this.isDisposed = true

        this.root.unmount()
        this.container.remove()
    }

    /**
     * Renders threads as ghost threads in a hidden container to measure their height,
     * then returns them with their computed positions.
     *
     * The ghost render is cleared afterwards, whether measurement succeeded or failed, unless a
     * newer getPositions call has rendered in the meantime (that call owns the container now).
     *
     * @param {AnyThread[]} threads - threads to render
     * @param {string | null} openThreadId - id of the thread that is open
     * @param {string | null} editingCommentId - id of the comment that is being edited
     * @returns {Promise<AnyThreadWithPosition[]>}
     */
    public getPositions = async (
        threads: AnyThread[],
        openThreadId: string | null,
        editingCommentId: string | null
    ): Promise<AnyThreadWithPosition[]> => {
        const iteration = this.renderThreads(threads, openThreadId, editingCommentId)

        try {
            await sleep(THREAD_TRANSITION_DELAY_MS)

            return await this.getThreadsWithPositions(threads, openThreadId)
        } finally {
            setTimeout(() => {
                // throw this into the event loop so we can return the positions first; skip it if
                // a newer render has replaced ours, otherwise we would wipe threads it still needs
                if (iteration === this.containerIteration) {
                    this.reset()
                }
            }, 0)
        }
    }
}
