import { queryClient } from "@app/QueryClientWithHeaders";
import { GraphQLAiRepository } from "@app/repositories/GraphQLRepositories/ai";
import { RestAiRepository } from "@app/repositories/RestRepositories/aiRepository";
// We are using the same fn that handles optimistic updates, as it takes care of
// canceling outgoing refetchings, getting previous query values and setting the cache
import { tanstackUnitaryOptimisticUpdate as updateCache } from "@app/shared/utils/optimisticUpdates";
import { AiThreadMessageRole } from "@generated/client/graphql";
import { skipToken, useMutation, useQuery } from "@tanstack/react-query";
import { useAtom } from "jotai";
import { AssistantStream } from "openai/lib/AssistantStream";
import { type TextDelta } from "openai/resources/beta/threads/messages";
import {
    RunStep,
    ToolCall,
} from "openai/resources/beta/threads/runs/steps.mjs";
import { useMemo } from "react";
import { lastRunThreadIdAtom, RunStatus, runStatusAtom } from "./ctx";
import {
    filterAnnotation,
    makeAnnotation,
    makeThread,
    prependMessageToThreadQuery,
    updateFirstMessageOfThreadQuery,
} from "./services";
import { Thread } from "./types";

// Repository instances shared by every hook in this module:
// - GraphQL: listing, reading and creating threads.
// - REST: sending a prompt and receiving the assistant's answer as a stream.
const graphqlAiRepository = new GraphQLAiRepository();
const restAiRepository = new RestAiRepository();

/**
 * Query/mutation key factory for the "Ask AI" feature.
 *
 * Every key is prefixed with `askAiKeys.all`, so all the queries below can be
 * targeted together by that prefix. Optional identifiers are always kept in
 * the key, even when they are `undefined`.
 */
export const askAiKeys = {
    all: ["askAi"] as const,
    createThread() {
        return [...askAiKeys.all, "createThread"] as const;
    },
    threads() {
        return [...askAiKeys.all, "threads"] as const;
    },
    thread(id?: string) {
        return [...askAiKeys.all, "thread", id] as const;
    },
    annotationDocument(openAiFileId?: string) {
        return [...askAiKeys.all, "annotationDocument", openAiFileId] as const;
    },
    annotationsMetadata(openAiFileIds?: string[]) {
        return [
            ...askAiKeys.all,
            "annotationsMetadata",
            openAiFileIds,
        ] as const;
    },
};

/**
 * Fetches every AI thread of the current user.
 *
 * @param options.skip - When true, the query is disabled through `skipToken`.
 * @returns `threads` (each built with `makeThread`; `[]` until data arrives)
 *   plus the rest of the query state.
 */
export function useThreads({ skip }: { skip?: boolean } = {}) {
    const { data, ...query } = useQuery({
        queryKey: askAiKeys.threads(),
        // Call through an arrow (as `useThread` does): handing over the bare
        // method reference would detach it from `graphqlAiRepository` and pass
        // TanStack's QueryFunctionContext to it as an unintended argument.
        queryFn: skip ? skipToken : () => graphqlAiRepository.getUserThreads(),
    });

    const threads = useMemo(
        // Explicit arrow so `makeThread` only ever receives the thread itself,
        // not the `index`/`array` arguments that `Array.prototype.map` passes.
        () => (data?.aiThreads ?? []).map((aiThread) => makeThread(aiThread)),
        [data],
    );

    return { threads, ...query };
}

/**
 * Fetches a single AI thread by id.
 *
 * @param id - Thread id; when it is falsy the query is skipped via `skipToken`.
 * @returns `thread` (built with `makeThread`; `undefined` until data arrives)
 *   plus the rest of the query state.
 */
export function useThread(id?: string) {
    const { data: response, ...query } = useQuery({
        queryKey: askAiKeys.thread(id),
        queryFn: !id ? skipToken : () => graphqlAiRepository.getThread(id),
    });

    const thread = useMemo(
        () => (response ? makeThread(response.aiThread) : undefined),
        [response],
    );

    return { thread, ...query };
}

/**
 * Mutation that creates an AI thread through the GraphQL repository.
 *
 * Once the mutation settles (success or failure) the thread list query is
 * invalidated.
 *
 * @returns `createThreadAsync` (the mutation's `mutateAsync`) plus the rest of
 *   the mutation state.
 */
export function useCreateThread() {
    const refreshThreadList = () => {
        queryClient.invalidateQueries({ queryKey: askAiKeys.threads() });
    };

    const { mutateAsync: createThreadAsync, ...mutation } = useMutation({
        mutationKey: askAiKeys.createThread(),
        mutationFn: graphqlAiRepository.createThread,
        onSettled: refreshThreadList,
    });

    return { createThreadAsync, ...mutation };
}

/**
 * Mutation that creates a new, empty AI thread and seeds the query cache with it.
 *
 * Once the thread is created, the user's thread list and the new thread are
 * fetched in parallel and written directly into their queries. Consumers can
 * therefore display the new thread without waiting for a refetch.
 *
 * @returns `createThreadAsync` (resolves with the `createAiThread` payload)
 *   plus the rest of the mutation state.
 */
export function useAsyncCreateEmptyThread() {
    const { mutateAsync, ...mutation } = useMutation({
        mutationFn: async () => {
            const { createAiThread } = await graphqlAiRepository.createThread();

            // Cancel any outgoing refetches for all queries that depend on the threads,
            // so they cannot overwrite the data written below
            await queryClient.cancelQueries({
                queryKey: askAiKeys.threads(),
            });

            const [threads, thread] = await Promise.all([
                graphqlAiRepository.getUserThreads(),
                graphqlAiRepository.getThread(createAiThread.id),
            ]);

            queryClient.setQueryData(askAiKeys.threads(), threads);
            queryClient.setQueryData(
                askAiKeys.thread(createAiThread.id),
                thread,
            );

            return createAiThread;
        },
        onError: () => {
            // The thread may already exist server-side (only a follow-up fetch
            // failed), and a refetch cancelled above is never retried, so mark
            // the thread list stale to resync it with the server.
            void queryClient.invalidateQueries({
                queryKey: askAiKeys.threads(),
            });
        },
    });

    return { createThreadAsync: mutateAsync, ...mutation };
}

/**
 * Sends a prompt to a thread's assistant and streams the answer into the
 * `askAiKeys.thread(thread.id)` query cache.
 *
 * - Before the request: `runStatus` becomes `Streaming`, `lastRunThreadId` is
 *   set to the thread id and the user's prompt is optimistically prepended to
 *   the cached thread.
 * - While streaming: an empty assistant message is prepended and filled with
 *   text deltas; `runStatus` reflects what the assistant is doing
 *   (`Streaming`, `SearchingFiles`, `Writing`).
 * - When the run completes, the status is cleared and the thread query is
 *   invalidated. When the request fails, or the run fails, is aborted,
 *   errors, is cancelled or expires, `runStatus` becomes `Failed`.
 *
 * @returns `sendPrompt({ prompt, thread })`, the current `runStatus`,
 *   `lastRunThreadId`, and the rest of the mutation state.
 */
export function useSendPrompt() {
    const [runStatus, setRunStatus] = useAtom(runStatusAtom);
    const [lastRunThreadId, setLastRunThreadId] = useAtom(lastRunThreadIdAtom);

    const { mutate: sendPrompt, ...mutation } = useMutation({
        mutationFn: ({
            prompt,
            thread,
        }: {
            prompt: string;
            thread: Thread;
        }) => {
            return restAiRepository.sendMessage(prompt, thread.openAiThreadId);
        },
        onMutate: async ({ prompt, thread }) => {
            setRunStatus(RunStatus.Streaming);
            setLastRunThreadId(thread.id);
            // Optimistically show the user's prompt in the cached thread
            await updateCache(
                { role: AiThreadMessageRole.User, text: prompt },
                () => askAiKeys.thread(thread.id),
                prependMessageToThreadQuery,
            );
        },
        onError: () => {
            // The request (or one of these callbacks) failed: no stream handler
            // will ever run to reset the status set in onMutate.
            setRunStatus(RunStatus.Failed);
        },
        onSuccess: (stream, { thread }) => {
            /**
             * Listeners registered with `stream.on` are not awaited, and several
             * handlers below update the cache asynchronously. Every handler is
             * therefore chained on a single promise so it only starts once the
             * previous one has finished. Without this, the first text delta could
             * patch the user's message before the empty assistant message is
             * prepended, or a late delta could overwrite the final status.
             */
            let handlerChain: Promise<void> = Promise.resolve();
            const inOrder =
                <TArgs extends unknown[]>(
                    handler: (...args: TArgs) => unknown,
                ) =>
                (...args: TArgs): void => {
                    handlerChain = handlerChain
                        .then(async () => {
                            await handler(...args);
                        })
                        // Keep the chain alive so later events are still handled
                        .catch((error: unknown) => {
                            console.error(
                                "[askAi] assistant stream handler failed",
                                error,
                            );
                        });
                };

            const handleTextCreated = async () => {
                setRunStatus(RunStatus.Writing);
                // Awaited so the following deltas target this new message
                await updateCache(
                    { role: AiThreadMessageRole.Assistant, text: "" },
                    () => askAiKeys.thread(thread.id),
                    prependMessageToThreadQuery,
                );
            };

            const handleTextDelta = async ({
                value,
                annotations,
            }: TextDelta) => {
                if (value == null) return;
                const formattedAnnotations =
                    annotations
                        ?.map(makeAnnotation)
                        ?.filter(filterAnnotation) ?? [];

                await updateCache(
                    { text: value, annotations: formattedAnnotations },
                    () => askAiKeys.thread(thread.id),
                    updateFirstMessageOfThreadQuery,
                );
                if (value.length > 0) setRunStatus(RunStatus.Writing);
            };

            const handleMessageDone = () => {
                setRunStatus(undefined);
            };

            const handleRunCompleted = () => {
                setRunStatus(undefined);
                void queryClient.invalidateQueries({
                    queryKey: askAiKeys.thread(thread.id),
                });
            };

            const handleToolCallCreated = (toolCall: ToolCall) => {
                if (toolCall.type === "file_search") {
                    setRunStatus(RunStatus.SearchingFiles);
                }
            };

            const handleRunStepCreated = (runStep: RunStep) => {
                /**
                 * Small hack to display to the user that we are searching in his files.
                 * To be more precise, we should wait for the toolCallCreated event that includes the tool call type.
                 * In the case of file search, this event only appears a few ms before the answer starts streaming,
                 * so we would not display anything if we waited.
                 */
                if (runStep.type === "tool_calls") {
                    setRunStatus(RunStatus.SearchingFiles);
                }
            };

            const handleRunFailed = () => {
                setRunStatus(RunStatus.Failed);
            };

            const handleReadableStream = (assistantStream: AssistantStream) => {
                setRunStatus(RunStatus.Streaming);
                const runCompleted = inOrder(handleRunCompleted);
                const runFailed = inOrder(handleRunFailed);

                // Create an empty message from the assistant
                assistantStream.on("textCreated", inOrder(handleTextCreated));

                // Populate the message with the text as it comes
                assistantStream.on("textDelta", inOrder(handleTextDelta));

                // When Cleo uses a tool call, display an appropriate loading message
                assistantStream.on(
                    "toolCallCreated",
                    inOrder(handleToolCallCreated),
                );
                assistantStream.on(
                    "runStepCreated",
                    inOrder(handleRunStepCreated),
                );

                assistantStream.on("messageDone", inOrder(handleMessageDone));

                // When the run is completed, refetch the thread. Any other
                // terminal run state means the answer will never be finished.
                assistantStream.on("event", ({ event }) => {
                    if (event === "thread.run.completed") {
                        runCompleted();
                    } else if (
                        event === "thread.run.failed" ||
                        event === "thread.run.cancelled" ||
                        event === "thread.run.expired"
                    ) {
                        runFailed();
                    }
                });
                assistantStream.on("abort", runFailed);
                assistantStream.on("error", runFailed);
            };

            handleReadableStream(stream);
        },
    });

    return { sendPrompt, runStatus, lastRunThreadId, ...mutation };
}
