/* eslint-disable @typescript-eslint/no-use-before-define */
import { firstValueFrom, lastValueFrom, Observable, Observer, Subject, Subscriber, timer } from 'rxjs';
import { filter, shareReplay, startWith, switchMap, takeUntil, throwIfEmpty } from 'rxjs/operators';
import { v4 as guid } from 'uuid';
import { Log } from '@capital-access/common/logging';
import { Cancellation, fromAsync } from './utils/async-rx.utils';

/**
 * Authorization header object sent to AppSync. It is used both when opening the socket
 * (base64-encoded into the `header` query parameter) and in every subscription `start` message.
 */
export interface OidcOrLambdaHeader {
  host: string;
  Authorization: string;
}
/**
 * Produces the authorization header for one request.
 *
 * @param host - AppSync API host the header is requested for.
 * @param payload - exact payload string being authorized: the connect payload or a serialized
 *   `{ query, variables }` subscription request.
 * @returns observable whose first emitted value is used as the header.
 */
export interface WebsocketAuthProvider {
  (host: string, payload: string): Observable<OidcOrLambdaHeader>;
}

/**
 * Generic shape of a protocol message.
 * `type` discriminates the message kind. `id` ties a message to one subscription and is
 * omitted for connection-level messages such as `connection_init`. `payload` depends on `type`.
 */
interface WebSocketEventTemplate<TType extends string, TPayload> {
  id?: string;
  type: TType;
  payload?: TPayload;
}

/** Server reply to `connection_init`. It carries the keep-alive timeout the client should enforce. */
type AcknowledgedWebSocketEvent = WebSocketEventTemplate<'connection_ack', { connectionTimeoutMs: number }>;
/** `connection_error` rejects the handshake; `error` rejects or terminates a single subscription. */
type ErrorWebSocketEvent = WebSocketEventTemplate<
  'error' | 'connection_error',
  { errors: { errorType?: string; message: string }[] }
>;

/**
 * Every message exchanged over the socket, in both directions.
 * `start` / `data` payloads are not typed here. The remaining message kinds carry no payload.
 */
export type WebSocketEvent =
  | AcknowledgedWebSocketEvent
  | ErrorWebSocketEvent
  | WebSocketEventTemplate<'start' | 'data', unknown>
  | WebSocketEventTemplate<'connection_init' | 'ka' | 'start_ack' | 'complete' | 'stop', never>;

/** Payload that is authorized and base64-encoded into the connection URL when opening the socket. */
export const CONNECT_PAYLOAD = '{}';

/**
 * Client for AWS AppSync real-time subscriptions.
 *
 * AppSync uses its own variant of the `graphql-ws` protocol over WebSocket. The message flow
 * (connect, `connection_init`/`connection_ack`, `start`/`start_ack`, `data`, `stop`, `ka` keep-alives)
 * is documented here:
 * DOCS: https://docs.aws.amazon.com/appsync/latest/devguide/real-time-websocket-client.html
 *
 * All subscriptions share one socket (`shareReplay` with `refCount: true`). The socket is opened
 * on the first subscription and released when the last one is cancelled.
 */
export class AppsyncWebsocket {
  /** Lazily opened, handshaken socket shared between all active subscriptions. */
  private _socket$: Observable<RxWebSocket>;

  /**
   * @param host - AppSync API host; passed to `getAuthHeader` when authorizing payloads.
   * @param realtimeHost - host of the real-time endpoint the socket connects to (`wss://<realtimeHost>`).
   * @param getAuthHeader - provides the authorization header for a given host and payload.
   */
  constructor(private host: string, private realtimeHost: string, private getAuthHeader: WebsocketAuthProvider) {
    // NOTE(review): this timeout Cancellation is created once, at construction time, and is passed to
    // every (re)connection made through `_socket$`. Confirm in `fromAsync` that it only bounds the
    // connect phase and does not also close long-lived sockets 15s after construction.
    this._socket$ = fromAsync<RxWebSocket>(async (subscriber, cancellation) => {
      const socket = await this._open(cancellation);
      cancellation.subscribe(() => socket.close());

      try {
        await this._connectHandshake(socket, cancellation);
      } catch (err) {
        // Do not leave a half-initialised connection open when the handshake is rejected.
        // Closing twice (here and via the cancellation above) is harmless.
        socket.close();
        throw err;
      }

      subscriber.next(socket);
    }, Cancellation.timeout(15_000)).pipe(shareReplay({ bufferSize: 1, refCount: true }));
  }

  /**
   * Starts a GraphQL subscription over the shared socket.
   *
   * Subscribing acquires the shared socket, opening it if needed. Unsubscribing sends `stop` for
   * this subscription, then releases the socket reference.
   *
   * @param query - GraphQL subscription document.
   * @param variables - variables for `query`.
   * @returns observable of the `data` field of each `data` message received for this subscription.
   */
  public subscribe<TData = unknown, TVariables = unknown>(query: string, variables?: TVariables): Observable<TData> {
    return fromAsync<TData>(async (subscriber, cancellation) => {
      const subscriptionId = guid();
      // Controls our reference to the shared socket. It is cancelled only after `stop` has been sent.
      const socketCancellation = new Cancellation();
      const socket = await this._requestSocket(socketCancellation);

      // The consumer may have unsubscribed while we were waiting for the socket. The cleanup handler
      // below is not registered yet in that case, so release the socket reference here. Otherwise
      // the ref count never drops and the shared connection is never closed.
      if (cancellation.cancelled) {
        socketCancellation.cancel();
        return;
      }

      const subscriptionObserver = new AppsyncSubscriptionObserver(subscriptionId, socket, subscriber);
      const subscription = socket.messages$.subscribe(subscriptionObserver);

      cancellation.subscribe(async () => {
        try {
          await subscriptionObserver.cancelAsync();
        } finally {
          subscription.unsubscribe();
          socketCancellation.cancel();
        }
      });

      const subscriptionPayload = JSON.stringify({ query, variables });

      const authHeader = await this._getAuthHeader(subscriptionPayload, cancellation);
      socket.send({
        id: subscriptionId,
        type: 'start',
        payload: {
          extensions: {
            authorization: authHeader
          },
          data: subscriptionPayload
        }
      });
    });
  }

  /** Opens the raw socket, passing the authorized connect payload in the URL query string. */
  private async _open(cancellation: Cancellation) {
    const authHeader = await this._getAuthHeader(CONNECT_PAYLOAD, cancellation);

    return await RxWebSocket.openAsync(
      `wss://${this.realtimeHost}?header=${btoa(JSON.stringify(authHeader))}&payload=${btoa(CONNECT_PAYLOAD)}`,
      cancellation
    );
  }

  /**
   * Sends `connection_init` and waits for `connection_ack` or `connection_error`.
   * On ack, starts keep-alive tracking (default 300s when the server gives no timeout).
   *
   * @throws the server's `errors` array on `connection_error`, or `Error('connection error')` if it has none.
   */
  private async _connectHandshake(ws: RxWebSocket, cancellation: Cancellation) {
    ws.send({ type: 'connection_init' });

    const connectionResponse = await ws.nextAsync(
      (msg): msg is AcknowledgedWebSocketEvent | ErrorWebSocketEvent =>
        msg.type === 'connection_ack' || msg.type === 'connection_error',
      cancellation
    );

    if (connectionResponse.type === 'connection_ack') {
      this._trackConnectionTimeout(ws, connectionResponse.payload?.connectionTimeoutMs || 300_000, cancellation);
      return;
    }

    throw connectionResponse.payload?.errors || new Error('connection error');
  }

  /**
   * Fails `ws.messages$` with `Error('connection timeout')` when no `ka` message arrives within
   * `timeoutMs`. Each `ka` restarts the countdown. Tracking stops when `cancellation` fires.
   */
  private _trackConnectionTimeout(ws: RxWebSocket, timeoutMs: number, cancellation: Cancellation) {
    ws.messages$
      .pipe(
        filter(msg => msg.type === 'ka'),
        startWith({ type: 'ka' }), // trigger initial timeout countdown
        switchMap(_ => timer(timeoutMs)),
        takeUntil(cancellation.observable)
      )
      .subscribe({
        // the source is messages$ itself: its errors (including the one raised below) end tracking silently
        // eslint-disable-next-line @typescript-eslint/no-empty-function
        error: () => {},
        next: () => ws.messages$.error(new Error('connection timeout'))
      });
  }

  /**
   * Resolves with the shared socket. The underlying `_socket$` subscription stays open, holding a
   * ref-count reference, until `cancellation` fires.
   */
  private _requestSocket(cancellation: Cancellation): Promise<RxWebSocket> {
    return new Promise((resolve, reject) => {
      const subscription = this._socket$.subscribe({
        next: resolve,
        error: reject,
        complete: () => reject(new Error('failed to get socket'))
      });
      // unsubscribe on cancellation (user unsubscribed) to decrease ref count and close the connection if 0
      cancellation.subscribe(() => subscription.unsubscribe());
    });
  }

  /** First header emitted by `getAuthHeader` for `payload`. Rejects if `cancellation` fires before one is emitted. */
  private _getAuthHeader(payload: string, cancellation: Cancellation): Promise<unknown> {
    return firstValueFrom(this.getAuthHeader(this.host, payload).pipe(takeUntil(cancellation.observable)));
  }
}

/**
 * Observes the shared socket stream on behalf of one subscription id.
 * It forwards that subscription's `data` payloads to the consumer and maps server
 * `complete`/`error` messages and socket failures onto the consumer's subscriber.
 * It also performs the `stop` step of the protocol when the consumer unsubscribes.
 */
class AppsyncSubscriptionObserver<TData> implements Observer<WebSocketEvent> {
  /** true once the server acknowledged our `start` (`start_ack`) */
  private _acknowledged = false;
  /** true once the subscription is over: server `complete`, server `error`, or socket failure */
  private _finished = false;

  constructor(private subscriptionId: string, private socket: RxWebSocket, private subscriber: Subscriber<TData>) {}

  /** Handles one socket message; messages for other subscriptions on the shared socket are ignored. */
  next(msg: WebSocketEvent) {
    if (msg.id !== this.subscriptionId) return;

    if (msg.type === 'start_ack') {
      this._acknowledged = true;
    } else if (msg.type === 'data') {
      this._emitData(msg);
    } else if (msg.type === 'error') {
      this.error(msg.payload?.errors || []);
    } else if (msg.type === 'complete') {
      this._finished = true;
      this.subscriber.complete();
    }
  }

  /** Marks the subscription finished and errors the consumer. */
  error(err: unknown) {
    this._finished = true;
    this.subscriber.error(err);
  }

  /** Socket stream ended. This is only acceptable when the subscription had already finished. */
  complete() {
    if (this._finished) return;
    this.error(new Error('connection closed before subscription is completed'));
  }

  /**
   * Sends `stop` for this subscription unless it already finished. If `start` was not acknowledged
   * yet, it first waits up to 1s for `start_ack`. Failures are reported to the consumer if it is
   * still open.
   */
  async cancelAsync(): Promise<void> {
    if (this._finished) return;

    try {
      if (!this._acknowledged) {
        await this.socket.nextAsync(
          msg => msg.id === this.subscriptionId && msg.type === 'start_ack',
          Cancellation.timeout(1_000)
        );
      }
      this.socket.send({ type: 'stop', id: this.subscriptionId });
    } catch (err) {
      if (!this.subscriber.closed) this.subscriber.error(err);
    }
  }

  /** Pushes `payload.data` of a `data` message to the consumer; unreadable payloads are logged and skipped. */
  private _emitData(msg: WebSocketEvent) {
    try {
      const { data } = msg.payload as { data: TData };
      this.subscriber.next(data);
    } catch {
      Log.error({
        errorMessage: 'failed to retrieve websocket message payload. Message is skipped',
        receivedWsMessage: msg
      });
    }
  }
}

/**
 * Thin wrapper over the browser `WebSocket` (sub-protocol `graphql-ws`).
 * Incoming frames are parsed as JSON and exposed as an RxJS stream; outgoing events are
 * serialized as JSON.
 */
class RxWebSocket {
  /**
   * Parsed server messages. Completes when the socket closes and errors when the socket reports an
   * error. It is public so owners can also fail it themselves (e.g. on keep-alive timeout).
   */
  public messages$ = new Subject<WebSocketEvent>();

  private _ws: WebSocket;

  private constructor(url: string) {
    this._ws = new WebSocket(url, 'graphql-ws');
  }

  /**
   * Wires the socket event handlers and resolves once the socket is open.
   * Rejects when the socket errors or closes before opening, or when `cancellation` fires first
   * (`Error('Websocket open timeout')`).
   */
  private open(cancellation: Cancellation) {
    Log.debug('Appsync Websocket connection - opening');
    const openedSubject = new Subject<void>();
    this._ws.onopen = () => {
      openedSubject.next();
      openedSubject.complete();
    };
    this._ws.onmessage = msg => {
      let parsed: WebSocketEvent;
      try {
        parsed = JSON.parse(msg.data);
      } catch {
        // A malformed frame must not throw out of the event handler; skip it like unreadable payloads.
        Log.error({
          errorMessage: 'failed to parse websocket message. Message is skipped',
          receivedWsMessage: msg.data
        });
        return;
      }
      this.messages$.next(parsed);
    };
    // Once the socket has opened, openedSubject is already completed and the openedSubject.error
    // calls below are no-ops. Before that, they fail the pending open() immediately instead of
    // waiting for the cancellation.
    this._ws.onclose = () => {
      openedSubject.error(new Error('Websocket closed before it was opened'));
      this.messages$.complete();
    };
    this._ws.onerror = err => {
      openedSubject.error(new Error('Websocket connection error'));
      this.messages$.error(err);
    };

    return lastValueFrom(
      openedSubject.pipe(
        takeUntil(cancellation.observable),
        throwIfEmpty(() => new Error('Websocket open timeout'))
      )
    );
  }

  /**
   * Resolves with the first incoming message accepted by `filterFn`.
   * Rejects if the stream errors, or if it ends or `cancellation` fires before a match arrives.
   */
  public nextAsync<T extends WebSocketEvent>(
    filterFn: (msg: WebSocketEvent) => msg is T,
    cancellation: Cancellation
  ): Promise<T>;
  public nextAsync(filterFn: (msg: WebSocketEvent) => boolean, cancellation: Cancellation): Promise<WebSocketEvent>;
  public nextAsync(filterFn: (msg: WebSocketEvent) => boolean, cancellation: Cancellation): Promise<WebSocketEvent> {
    return firstValueFrom(
      this.messages$.pipe(
        filter(filterFn),
        takeUntil(cancellation.observable),
        throwIfEmpty(() =>
          cancellation.cancelled
            ? new Error('Received cancellation request while waiting for message from the server')
            : new Error('Connection closed while waiting for message from server')
        )
      )
    );
  }

  /** Serializes `e` as JSON and sends it. */
  public send(e: WebSocketEvent) {
    this._ws.send(JSON.stringify(e));
  }

  /** Closes the underlying socket; `messages$` completes when the close event arrives. */
  public close() {
    Log.debug('Appsync Websocket connection - closing');
    this._ws.close();
  }

  /**
   * Creates a socket for `url` and resolves once it is open.
   * On failure or cancellation the underlying socket is closed before rejecting. The caller never
   * receives the instance in that case, so nobody else could close it.
   */
  public static async openAsync(url: string, cancellation: Cancellation) {
    const socket = new RxWebSocket(url);
    try {
      await socket.open(cancellation);
    } catch (err) {
      socket.close();
      throw err;
    }
    return socket;
  }
}
