# Jina CLIP v2 and Jina Embeddings v3 both use their modified XLM-RoBERTa architecture. Reference implementation:
# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py
# Jina XLM-RoBERTa (text only): https://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py

from dataclasses import dataclass

import torch
from torch import nn as nn
from torch.nn import functional as F

import comfy.ldm.modules.attention
import comfy.model_management
import comfy.ops
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer

class JinaClip2Tokenizer(sd1_clip.SDTokenizer):
    """SentencePiece-based tokenizer for the Jina CLIP v2 text encoder.

    The serialized SentencePiece model is read from ``tokenizer_data["spiece_model"]``
    (``None`` when the key is absent) and handed to ``SPieceTokenizer``, which is asked
    to add BOS and EOS tokens. Sequences are not padded to ``max_length``.

    Args:
        embedding_directory: Directory (or directories) searched for prompt embeddings.
            Forwarded to the base tokenizer so that it is not silently dropped.
        tokenizer_data: Mapping that may carry the ``"spiece_model"`` entry. This class
            only reads from it.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer = tokenizer_data.get("spiece_model", None)
        # The official NewBie uses max_length=8000, but Jina Embeddings v3 actually supports 8192
        super().__init__(
            tokenizer,
            # Previously this argument was accepted but never passed on, which left the
            # base tokenizer without an embedding directory even though embedding_size
            # and embedding_key were configured.
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=1024,
            embedding_key='jina_clip_2',
            tokenizer_class=SPieceTokenizer,
            has_start_token=True,
            has_end_token=True,
            pad_to_max_length=False,
            max_length=8192,
            min_length=1,
            pad_token=1,
            end_token=2,
            tokenizer_args={"add_bos": True, "add_eos": True},
            tokenizer_data=tokenizer_data,
        )

    def state_dict(self):
        """Return the serialized SentencePiece model under the ``"spiece_model"`` key."""
        return {"spiece_model": self.tokenizer.serialize_model()}

class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer):
    """Exposes ``JinaClip2Tokenizer`` through ``SD1Tokenizer`` under the name ``"jina_clip_2"``."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # Both arguments are passed through untouched; only the tokenizer class and
        # its registration name are fixed here.
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            tokenizer=JinaClip2Tokenizer,
            name="jina_clip_2",
        )

# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json
@dataclass
class XLMRobertaConfig:
    """Hyper-parameters of the Jina XLM-RoBERTa text encoder.

    Every field has a default, so ``XLMRobertaConfig()`` and
    ``XLMRobertaConfig(**{})`` both yield the configuration below.
    """

    # --- embedding tables ---
    vocab_size: int = 250_002
    type_vocab_size: int = 1

    # --- transformer geometry ---
    hidden_size: int = 1_024
    num_hidden_layers: int = 24
    num_attention_heads: int = 16
    rotary_emb_base: float = 20_000.0
    intermediate_size: int = 4_096

    # Not read by the MLP in this file, which always applies GELU.
    hidden_act: str = "gelu"

    # --- regularisation / normalisation ---
    hidden_dropout_prob: float = 0.1
    # Not read anywhere in this file.
    attention_probs_dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-5

    # --- special token ids ---
    bos_token_id: int = 0
    eos_token_id: int = 2
    # Used as padding_idx of the word-embedding table.
    pad_token_id: int = 1

class XLMRobertaEmbeddings(nn.Module):
    """Word embeddings plus a token-type embedding (type id 0 for every position).

    No positional embedding is added here.
    """

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        width = config.hidden_size
        self.word_embeddings = ops.Embedding(
            config.vocab_size, width, padding_idx=config.pad_token_id, device=device, dtype=dtype
        )
        self.token_type_embeddings = ops.Embedding(
            config.type_vocab_size, width, device=device, dtype=dtype
        )

    def forward(self, input_ids=None, embeddings=None):
        """Embed ``input_ids`` (or take ``embeddings`` as given) and add the type-0 embedding.

        Pre-computed ``embeddings`` take precedence over ``input_ids``. When both are
        ``None`` the result is ``None``.
        """
        if embeddings is None:
            if input_ids is None:
                return None
            embeddings = self.word_embeddings(input_ids)

        # One type id per sequence position; the result broadcasts over the batch axis.
        seq_len = embeddings.shape[1]
        type_ids = torch.zeros(seq_len, device=embeddings.device, dtype=torch.int32)
        return embeddings + self.token_type_embeddings(type_ids)

class RotaryEmbedding(nn.Module):
    """Rotary position embedding applied to query/key tensors in "rotate half" layout.

    The cos/sin tables are cached on the instance and rebuilt whenever a longer
    sequence is seen or the requested device/dtype differs from the cached one.
    """

    def __init__(self, dim, base, device=None):
        super().__init__()
        exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
        # Non-persistent: the frequencies are derived from (dim, base), not loaded from a checkpoint.
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rebuild the cached cos/sin tables unless the current ones already cover the request."""
        cached = self._cos_cached
        cache_is_valid = (
            cached is not None
            and seqlen <= self._seq_len_cached
            and cached.device == device
            and cached.dtype == dtype
        )
        if cache_is_valid:
            return

        self._seq_len_cached = seqlen
        # Angles are computed in float32 and only cast to the target dtype at the end.
        positions = torch.arange(seqlen, device=device, dtype=torch.float32)
        angles = torch.outer(positions, self.inv_freq.to(device=positions.device))
        table = torch.cat((angles, angles), dim=-1)
        self._cos_cached = table.cos().to(dtype)
        self._sin_cached = table.sin().to(dtype)

    @staticmethod
    def _rotate_half(x):
        """Map the last axis ``(x1, x2)`` (two equal halves) to ``(-x2, x1)``."""
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def forward(self, q, k):
        """Rotate ``q`` and ``k``; ``q`` must be ``(batch, seqlen, heads, head_dim)``.

        Returns the rotated ``(q, k)`` pair.
        """
        _, seqlen, _, head_dim = q.shape
        self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype)

        # Broadcast the tables over the batch and head axes.
        table_shape = (1, seqlen, 1, head_dim)
        cos = self._cos_cached[:seqlen].view(table_shape)
        sin = self._sin_cached[:seqlen].view(table_shape)

        q_rot = (q * cos) + (self._rotate_half(q) * sin)
        k_rot = (k * cos) + (self._rotate_half(k) * sin)
        return q_rot, k_rot

class MHA(nn.Module):
    """Multi-head self-attention with a fused QKV projection and rotary position embedding."""

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        self.num_heads = heads
        self.head_dim = hidden // heads

        self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device)
        # A single projection produces queries, keys and values side by side.
        self.Wqkv = ops.Linear(hidden, 3 * hidden, device=device, dtype=dtype)
        self.out_proj = ops.Linear(hidden, hidden, device=device, dtype=dtype)

    def forward(self, x, mask=None, optimized_attention=None):
        """Attend over ``x`` using the supplied ``optimized_attention`` callable.

        ``optimized_attention`` is invoked with tensors laid out as
        ``(batch, heads, seq, head_dim)`` and ``skip_reshape=True``; its result is fed
        straight into the output projection.
        """
        fused = self.Wqkv(x)
        bsz, tokens, _ = fused.shape
        q, k, v = fused.view(bsz, tokens, 3, self.num_heads, self.head_dim).unbind(2)

        # Rotary embedding is applied while the layout is still (batch, seq, heads, head_dim).
        q, k = self.rotary_emb(q, k)

        # NHD -> HND
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        attended = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True)
        return self.out_proj(attended)

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    The activation is fixed to ``F.gelu``; ``config.hidden_act`` is not consulted.
    """

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.fc1 = ops.Linear(hidden, inner, device=device, dtype=dtype)
        self.activation = F.gelu
        self.fc2 = ops.Linear(inner, hidden, device=device, dtype=dtype)

    def forward(self, x):
        """Expand to the intermediate width, apply the activation, project back."""
        return self.fc2(self.activation(self.fc1(x)))

class Block(nn.Module):
    """One post-norm transformer layer: attention and MLP, each followed by residual add + LayerNorm."""

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        drop_p = config.hidden_dropout_prob

        # Attention sub-layer.
        self.mixer = MHA(config, ops=ops, **factory)
        self.dropout1 = nn.Dropout(drop_p)
        self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, **factory)

        # Feed-forward sub-layer.
        self.mlp = MLP(config, ops=ops, **factory)
        self.dropout2 = nn.Dropout(drop_p)
        self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, **factory)

    def forward(self, hidden_states, mask=None, optimized_attention=None):
        """Run the attention and MLP sub-layers; normalisation happens after each residual add."""
        attn_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention)
        after_attn = self.norm1(self.dropout1(attn_out) + hidden_states)
        return self.norm2(self.dropout2(self.mlp(after_attn)) + after_attn)

class XLMRobertaEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` transformer ``Block`` layers applied in order."""

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        """Run ``hidden_states`` through every layer and return the final hidden states.

        Args:
            hidden_states: Input activations; their device selects the attention implementation.
            attention_mask: Optional mask passed unchanged to every layer. Only its
                presence (not its content) influences which attention function is chosen.
        """
        # The attention implementation is chosen once per call and shared by all layers.
        # NOTE: this relies on `comfy.ldm.modules.attention` being imported explicitly at
        # the top of the file rather than as a side effect of some other module's imports.
        optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(
            hidden_states.device,
            mask=attention_mask is not None,
            small_input=True,
        )
        for layer in self.layers:
            hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention)
        return hidden_states

class XLMRobertaModel_(nn.Module):
    """Embeddings -> LayerNorm -> dropout -> encoder, followed by mean pooling over tokens."""

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops)
        self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops)

    def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
        """Encode a batch and return ``(sequence_output, None, pooled_output)``.

        Args:
            input_ids: Token ids; ignored by the embedding layer when ``embeds`` is given.
            attention_mask: Optional ``(batch, seq)`` mask with 1 for real tokens and 0
                for padding. Used both to mask attention and to weight the mean pooling.
            embeds: Optional pre-computed input embeddings.
            num_tokens, intermediate_output, final_layer_norm_intermediate, dtype, embeds_info:
                Accepted for call compatibility; not read by this implementation.

        Returns:
            A 3-tuple: the encoder output, ``None`` (placeholder for the intermediate
            output, which is not implemented), and the mean-pooled output. Rows whose
            attention mask is entirely zero pool to zeros rather than NaN.
        """
        x = self.embeddings(input_ids=input_ids, embeddings=embeds)
        x = self.emb_ln(x)
        x = self.emb_drop(x)

        mask = None
        if attention_mask is not None:
            # Turn the 1/0 keep-mask into an additive mask broadcastable over heads and
            # query positions: 0 where attention is allowed, the most negative finite
            # value of the working dtype where it is not.
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1]))
            mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

        sequence_output = self.encoder(x, attention_mask=mask)

        # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py
        pooled_output = None
        if attention_mask is None:
            pooled_output = sequence_output.mean(dim=1)
        else:
            attention_mask = attention_mask.to(sequence_output.dtype)
            # Clamp the per-row token count to at least 1 so a fully masked row yields
            # zeros instead of 0/0 = NaN. A floor of 1 (rather than a tiny epsilon) is
            # used because an epsilon would underflow to 0 in half precision; counts of
            # one or more are unaffected.
            token_count = attention_mask.sum(dim=-1, keepdim=True).clamp(min=1)
            pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / token_count

        # Intermediate output is not yet implemented, use None for placeholder
        return sequence_output, None, pooled_output

class XLMRobertaModel(nn.Module):
    """Thin adapter that builds ``XLMRobertaModel_`` from a plain config dict.

    Exposes the word-embedding table through getter/setter methods and forwards
    every call to the wrapped model.
    """

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        # Keys missing from config_dict fall back to the XLMRobertaConfig defaults.
        self.config = XLMRobertaConfig(**config_dict)
        self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations)
        self.num_layers = self.config.num_hidden_layers

    def get_input_embeddings(self):
        """Return the word-embedding module of the wrapped model."""
        embeddings = self.model.embeddings
        return embeddings.word_embeddings

    def set_input_embeddings(self, embeddings):
        """Replace the word-embedding module of the wrapped model."""
        target = self.model.embeddings
        target.word_embeddings = embeddings

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model with the arguments unchanged."""
        return self.model(*args, **kwargs)

class JinaClip2TextModel(sd1_clip.SDClipModel):
    """``SDClipModel`` configured for the Jina XLM-RoBERTa text encoder.

    An empty config dict is supplied, so the encoder is built from the
    ``XLMRobertaConfig`` defaults. Attention masks are both used and returned.
    """

    def __init__(self, device="cpu", dtype=None, model_options={}):
        # Token ids for the start / end / pad specials.
        special_tokens = {"start": 0, "end": 2, "pad": 1}
        super().__init__(
            device=device,
            dtype=dtype,
            textmodel_json_config={},
            model_class=XLMRobertaModel,
            special_tokens=special_tokens,
            enable_attention_masks=True,
            return_attention_masks=True,
            model_options=model_options,
        )

class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel):
    """Registers ``JinaClip2TextModel`` with ``SD1ClipModel`` under the name ``"jina_clip_2"``."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        # device, dtype and model_options are passed through unchanged.
        super().__init__(
            device=device,
            dtype=dtype,
            clip_model=JinaClip2TextModel,
            name="jina_clip_2",
            model_options=model_options,
        )
