import { createSlice, PayloadAction } from '@reduxjs/toolkit'
import { ClientAgentReference } from './schemas'

/**
 * Lookup from a `citation_id` to the number of highlights attached to it.
 * A single citation can carry several distinct highlights, so this stores a count rather than a flag.
 */
export type CitationIdHighlightCount = Record<string, number>

/** Agent references keyed by `citation_id`; a lookup may yield `undefined` for unknown ids. */
export type AgentReferencesRecord = Partial<Record<string, ClientAgentReference>>

/** Keyed by `turn_id`: the citation ids seen in that turn, each mapped to its highlight count. */
export type AgentTurnUniqueCitationAndHighlight = Partial<Record<string, CitationIdHighlightCount>>

/** Shape of the agent references slice state. */
export type AgentReferencesState = {
  /** Every agent reference received so far, keyed by citation id. */
  references: AgentReferencesRecord
  /** Per turn id: the citation ids in that turn mapped to their highlight counts. */
  turnUniqueCitationAndHighlightCount: AgentTurnUniqueCitationAndHighlight
}

/** Empty starting state: no references and no per-turn highlight counts. */
const initialState: AgentReferencesState = { references: {}, turnUniqueCitationAndHighlightCount: {} }

/**
 * Agent References Slice
 *
 * Holds the agent references received by the client (keyed by `citation_id`) plus a per-turn
 * index mapping each citation id seen in a turn to its number of highlighted citations.
 */
export const agentReferencesSlice = createSlice({
  name: 'agentReferencesState',
  initialState,
  reducers: {
    // ============== Slice Actions ============== >

    /** Reset the slice back to its initial (empty) state. */
    nullifyData: () => initialState,

    /**
     * Upsert a list of agent references and refresh the per-turn highlight counts.
     *
     * Each reference is stored under its `reference.citation_id`, replacing any existing entry.
     * For every turn that appears in the payload, the highlight count of every stored reference
     * belonging to that turn is recomputed and merged into
     * `turnUniqueCitationAndHighlightCount[turn_id]`. Citation ids already recorded for the turn
     * are kept; recomputed ones overwrite their previous counts.
     */
    upsertReferences: (state, action: PayloadAction<{ references: ClientAgentReference[] }>) => {
      const { references } = action.payload

      // Mutate the Immer draft directly instead of replacing `state.references` with a copy:
      // an empty payload then leaves the object identity untouched (no needless re-renders).
      references.forEach((reference) => {
        state.references[reference.reference.citation_id] = reference
      })

      // Only turns present in this payload need their counts refreshed.
      const touchedTurnIds = new Set(references.map((ref) => ref.reference.turn_id))

      // One pass over the stored references, grouping highlight counts by touched turn.
      // This replaces rebuilding the whole turn map once per reference (quadratic per turn).
      const countsByTurn = new Map<ClientAgentReference['reference']['turn_id'], CitationIdHighlightCount>()
      Object.values(state.references).forEach((ref) => {
        if (!ref || !touchedTurnIds.has(ref.reference.turn_id)) {
          return
        }
        const turnId = ref.reference.turn_id
        let turnCounts = countsByTurn.get(turnId)
        if (!turnCounts) {
          turnCounts = {}
          countsByTurn.set(turnId, turnCounts)
        }
        turnCounts[ref.reference.citation_id] = ref.highlighted_citations.length
      })

      // Merge the recomputed counts over whatever was already recorded for each turn.
      countsByTurn.forEach((citationCounts, turnId) => {
        const existingTurnCounts = state.turnUniqueCitationAndHighlightCount[turnId] ?? {}
        state.turnUniqueCitationAndHighlightCount[turnId] = { ...existingTurnCounts, ...citationCounts }
      })
    },
  },
})

// Pull the generated action creators and reducer off the slice for export.
const { actions: sliceActions, reducer: sliceReducer } = agentReferencesSlice

/** Action creators for the agent references slice. */
export const AgentReferencesActions = sliceActions

export default sliceReducer
