from typing import List, Tuple, Union, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
    DeepseekV3MoE,
    DeepseekV3MLP,
    DeepseekV3Attention,
    DeepseekV3RMSNorm,
    DeepseekV3DecoderLayer,
    DeepseekV3Model,
    DeepseekV3ForCausalLM,
    apply_rotary_pos_emb,
    apply_rotary_pos_emb_interleave,
)


def _usable_prefix(keys, values):
    """Return ``(keys, values)`` if both are non-empty tensors, else ``(None, None)``."""
    for tensor in (keys, values):
        if not isinstance(tensor, torch.Tensor) or tensor.numel() == 0:
            return None, None
    return keys, values


def _extract_kv_cache(
    cache: Optional[Cache], layer_idx: int
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Read the cached key/value tensors of one layer without modifying the cache.

    Two cache layouts are probed, in this order:

    * an object exposing parallel ``key_cache`` / ``value_cache`` sequences
      indexed by layer;
    * an object exposing a ``layers`` sequence whose items carry ``keys`` and
      ``values`` attributes.

    Args:
        cache: Cache object to read from, or ``None``.
        layer_idx: Index of the layer whose entries are requested.

    Returns:
        ``(keys, values)`` for the layer, or ``(None, None)`` when ``cache`` is
        ``None``, exposes neither layout, or holds nothing usable for
        ``layer_idx`` yet.

    Raises:
        ValueError: If ``key_cache`` and ``value_cache`` do not have the same
            number of entries.
    """
    if cache is None:
        return None, None

    key_cache = getattr(cache, "key_cache", None)
    value_cache = getattr(cache, "value_cache", None)
    if key_cache is not None:
        if value_cache is None or len(key_cache) != len(value_cache):
            raise ValueError("`key_cache` and `value_cache` must hold the same number of layers")
        # A cache that has not been written for this layer yet simply has no prefix.
        if layer_idx >= len(key_cache):
            return None, None
        return _usable_prefix(key_cache[layer_idx], value_cache[layer_idx])

    layers = getattr(cache, "layers", None)
    if layers is not None:
        if layer_idx >= len(layers):
            return None, None
        layer_cache = layers[layer_idx]
        return _usable_prefix(
            getattr(layer_cache, "keys", None),
            getattr(layer_cache, "values", None),
        )

    return None, None


class GFusionAttention(DeepseekV3Attention):
    """DeepseekV3 attention variant whose cache write is controlled by ``update_dllm_cache``."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        update_dllm_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attend over the cached prefix plus the current tokens.

        Args:
            hidden_states: Input of shape ``(bsz, q_len, hidden)``.
            attention_mask: Optional additive mask; must have shape
                ``(bsz, 1, q_len, kv_seq_len)``.
            position_ids: Unused by this method.
            position_embeddings: ``(cos, sin)`` pair passed to the rotary helper.
            past_key_value: Cache holding the prefix keys/values, or ``None``.
            output_attentions: Whether to return the attention probabilities.
            use_cache: Unused by this method.
            update_dllm_cache: If ``True`` the current keys/values are written
                to ``past_key_value`` via ``update``; if ``False`` the cache is
                only read and left untouched.
            **kwargs: Only ``cache_position`` is read (forwarded to the cache
                on update).

        Returns:
            ``(attn_output, attn_weights, past_key_value)`` where
            ``attn_weights`` is ``None`` unless ``output_attentions`` is set.

        Raises:
            ValueError: If ``position_embeddings`` is missing, or if the
                attention weights, mask, or output have an unexpected shape.
        """
        if position_embeddings is None:
            raise ValueError("`position_embeddings` (cos, sin) must be provided")

        bsz, q_len, _ = hidden_states.size()

        # Query projection (optionally low-rank), split into non-rotary / rotary parts.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.qk_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # Compressed key/value projection; the rotary key part is shared by all heads.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )
        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )

        # Length is measured before any cache write so it describes the prefix only.
        prefix_len = (
            past_key_value.get_seq_length(self.layer_idx)
            if past_key_value is not None
            else 0
        )
        kv_seq_len = prefix_len + value_states.shape[-2]

        cos, sin = position_embeddings
        if getattr(self.config, "rope_interleave", False):
            q_pe, k_pe = apply_rotary_pos_emb_interleave(q_pe, k_pe, cos, sin)
        else:
            q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)

        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.qk_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.qk_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        # k_pe has a singleton head dimension and is broadcast over all heads here.
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        if past_key_value is not None:
            if update_dllm_cache:
                cache_kwargs = {
                    "sin": sin,
                    "cos": cos,
                    "cache_position": kwargs.get("cache_position"),
                }
                key_states, value_states = past_key_value.update(
                    key_states, value_states, self.layer_idx, cache_kwargs
                )
            else:
                # Read-only use of the cache: prepend the prefix without storing
                # the current keys/values.
                key_prefix, value_prefix = _extract_kv_cache(past_key_value, self.layer_idx)
                if key_prefix is not None:
                    key_states = torch.cat([key_prefix, key_states], dim=2)
                    value_states = torch.cat([value_prefix, value_states], dim=2)

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
        )
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class GFusionDecoderLayer(DeepseekV3DecoderLayer):
    """Decoder layer that uses [`GFusionAttention`] for self attention."""

    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.hidden_size = config.hidden_size

        self.self_attn = GFusionAttention(config=config, layer_idx=layer_idx)

        # Layers below `first_k_dense_replace` use a dense MLP, the rest use MoE.
        if layer_idx >= config.first_k_dense_replace:
            self.mlp = DeepseekV3MoE(config)
        else:
            self.mlp = DeepseekV3MLP(config)

        self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        update_dllm_cache: bool = False,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run pre-norm self attention and the feed-forward block, each with a residual.

        Returns:
            A tuple starting with the new hidden states, followed by the
            attention weights when ``output_attentions`` is truthy, followed by
            the cache object returned by the attention module when
            ``use_cache`` is truthy.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            update_dllm_cache=update_dllm_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs


class GFusionModel(DeepseekV3Model):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers.
    Each layer is a [`GFusionDecoderLayer`]

    Args:
        config: DeepseekV3Config
    """

    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                GFusionDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        update_dllm_cache: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs and run them through every decoder layer.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            attention_mask: Forwarded unchanged to every layer (no mask is
                built here).
            position_ids: Positions for the rotary embedding. When omitted
                they continue from the current cache length.
            past_key_values: A ``Cache`` instance or a legacy cache, which is
                converted with ``DynamicCache.from_legacy_cache``.
            inputs_embeds: Pre-computed embeddings used instead of
                ``input_ids``.
            use_cache: Whether to return the cache; defaults to
                ``config.use_cache``. A ``DynamicCache`` is created when this
                is set and no cache was supplied.
            update_dllm_cache: Forwarded to every layer; controls whether the
                attention modules write the current keys/values to the cache.
            output_attentions: Whether to collect per-layer attention weights.
            output_hidden_states: Whether to collect per-layer hidden states.
            return_dict: Whether to return a ``BaseModelOutputWithPast``
                instead of a tuple.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("Both input_ids and inputs_embeds are specified at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        use_legacy_cache = past_key_values is not None and not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        past_key_values_length = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        # Every layer hands back the cache object it was given, so seeding with it
        # keeps `next_cache` valid even when there are no layers to iterate.
        next_decoder_cache = past_key_values if use_cache else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
                update_dllm_cache=update_dllm_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if use_legacy_cache
                else next_decoder_cache
            )
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class GFusionForDiffusionLM(DeepseekV3ForCausalLM):
    """[`GFusionModel`] with a language-model head and block-wise diffusion decoding."""

    def __init__(self, config: DeepseekV3Config):
        super().__init__(config)
        self.model = GFusionModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        update_dllm_cache: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the loss. Indices should either be in
                `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with
                indices set to `-100` are ignored (masked), the loss is only computed for the
                tokens with labels in `[0, ..., config.vocab_size]`. Labels are compared with
                the logits at the same position (no shifting is applied).
            update_dllm_cache (`bool`, *optional*, defaults to `False`):
                Forwarded to the decoder; controls whether the attention layers write the
                current keys/values into `past_key_values`.

        Returns:
            A `CausalLMOutputWithPast`, or a tuple when `return_dict` is false.

        Example:

        ```python
        >>> from transformers import AutoTokenizer

        >>> model = GFusionForDiffusionLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate (returns the newly generated ids only, without the prompt)
        >>> generate_ids = model.generate(inputs.input_ids, max_new_tokens=30)
        >>> text = tokenizer.batch_decode(
        ...     generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        ... )[0]
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            update_dllm_cache=update_dllm_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Labels are aligned with the logits at the same position (no shift).
            loss_fct = nn.CrossEntropyLoss()
            flat_logits = logits.view(-1, self.config.vocab_size)
            flat_labels = labels.view(-1)
            # Enable model parallelism
            flat_labels = flat_labels.to(flat_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @staticmethod
    def _top_k_logits(logits: torch.Tensor, k: Optional[int]):
        """Set every logit below the k-th largest (along the last dim) to ``-inf``.

        ``k`` of ``None`` or ``<= 0`` disables the filter; a ``k`` larger than
        the last dimension is clamped to it.
        """
        if k is None or k <= 0:
            return logits
        # torch.topk raises when k exceeds the dimension size.
        k = min(k, logits.size(-1))
        values, _ = torch.topk(logits, k)
        min_values = values[..., -1, None]
        return torch.where(
            logits < min_values, torch.full_like(logits, float("-inf")), logits
        )

    @staticmethod
    def _top_p_logits(logits, p):
        """Nucleus filter: mask tokens outside the smallest set whose probability exceeds ``p``.

        ``p`` of ``None`` or ``>= 1.0`` disables the filter. The highest-scoring
        token is always kept.
        """
        if p is None or p >= 1.0:
            return logits
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_mask = cumulative_probs > p
        # Shift right so the token that crosses the threshold is itself kept.
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        mask_indices = torch.scatter(
            torch.zeros_like(logits, dtype=torch.bool),
            dim=-1,
            index=sorted_indices,
            src=sorted_mask,
        )
        return logits.masked_fill(mask_indices, float("-inf"))

    @staticmethod
    def _sample_tokens(
        logits: torch.Tensor,
        do_sample: bool,
        temperature: float,
        top_k: Optional[int],
        top_p: Optional[float],
    ) -> torch.Tensor:
        """Pick one token per position from ``logits`` (last dim = vocabulary).

        Returns the argmax when ``do_sample`` is false. Otherwise applies
        temperature scaling (skipped when ``temperature`` is ``None``, ``<= 0``
        or ``1.0``), then top-k and top-p filtering, and samples from the
        resulting distribution. The result has the shape of ``logits`` without
        its last dimension.
        """
        if not do_sample:
            return torch.argmax(logits, dim=-1)

        vocab_size = logits.shape[-1]
        flat_logits = logits.reshape(-1, vocab_size)
        if temperature is not None and temperature > 0 and temperature != 1.0:
            flat_logits = flat_logits / temperature
        flat_logits = GFusionForDiffusionLM._top_k_logits(flat_logits, top_k)
        flat_logits = GFusionForDiffusionLM._top_p_logits(flat_logits, top_p)
        probs = F.softmax(flat_logits, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1)
        return sampled.view(logits.shape[:-1])

    @staticmethod
    def _entropy_bounded_step(
        block_input_ids: torch.Tensor,
        block_logits: torch.Tensor,
        do_sample: bool,
        temperature: float,
        top_p: Optional[float],
        top_k: Optional[int],
        gamma: float,
        mask_id: int
    ):
        """Reveal the lowest-entropy masked tokens of a block.

        Masked positions (row 0 only) are ranked by the entropy of their
        predicted distribution. The lowest-entropy position is always
        revealed; further positions are revealed, in ascending entropy order,
        for as long as the running sum of the entropies of the positions
        ranked before them stays ``<= gamma``.

        ``block_input_ids`` is modified **in place** and also returned, so a
        view passed in by the caller sees the update.

        Args:
            block_input_ids: Token ids of the block; only row 0 is processed.
            block_logits: Logits for the block, last dim = vocabulary.
            do_sample, temperature, top_p, top_k: Passed to ``_sample_tokens``.
            gamma: Entropy budget.
            mask_id: Token id marking masked positions; it is never predicted
                when it indexes into the vocabulary dimension.
        """
        masked_positions = torch.nonzero(
            block_input_ids[0] == mask_id, as_tuple=False
        ).flatten()
        if masked_positions.numel() == 0:
            return block_input_ids

        if 0 <= mask_id < block_logits.shape[-1]:
            # Work on a copy so the caller's logits are left untouched.
            block_logits = block_logits.clone()
            block_logits[..., mask_id] = torch.finfo(block_logits.dtype).min

        candidate_tokens = GFusionForDiffusionLM._sample_tokens(
            block_logits,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        masked_logits = block_logits[0, masked_positions]
        masked_log_probs = F.log_softmax(masked_logits, dim=-1)
        masked_probs = masked_log_probs.exp()
        masked_entropies = -(masked_probs * masked_log_probs).sum(dim=-1)

        sort_index = torch.argsort(masked_entropies, dim=0)
        sorted_positions = masked_positions[sort_index]
        sorted_entropies = masked_entropies[sort_index]

        unmasked_cnt = 1
        entropy_budget = 0.0
        # One host transfer for the whole vector instead of one `.item()` per token.
        for entropy in sorted_entropies[:-1].tolist():
            entropy_budget += entropy
            if entropy_budget > gamma:
                break
            unmasked_cnt += 1

        transfer_positions = sorted_positions[:unmasked_cnt]
        block_input_ids[0, transfer_positions] = candidate_tokens[0, transfer_positions]
        return block_input_ids

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        block_size: int = 32,
        denoise_steps: int = 32,
        max_new_tokens: int = 2048,
        do_sample: bool = False,
        temperature: float = 0.0,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        eos_early_stop: bool = True,
        gamma: float = 0.15,
        eos_id: int = 2,
        mask_id: int = 128170,
        **kwargs,
    ):
        """Generate continuation tokens with block diffusion decoding.

        The prompt is padded with ``mask_id`` up to a whole number of blocks.
        Blocks fully covered by the prompt are run once to fill the KV cache;
        every later block is denoised iteratively against the cached prefix
        and then written to the cache.

        Args:
            input_ids: Prompt token ids. Only batch size 1 is supported.
            block_size: Number of tokens denoised as one diffusion block.
            denoise_steps: Maximum denoising iterations per block.
            max_new_tokens: Upper bound on the number of returned tokens.
            do_sample, temperature, top_p, top_k: Sampling controls passed to
                the denoising step.
            eos_early_stop: Return as soon as an EOS token is present and every
                position before it is unmasked. Ignored when ``eos_id`` is
                ``None``.
            gamma: Entropy budget controlling how many tokens are revealed each step.
            eos_id: Token id used for early stopping and final truncation,
                or ``None`` to disable both.
            mask_id: Token id used for masked positions.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Generated token ids only, without the prompt prefix.

        Raises:
            ValueError: If the batch size is not 1 or ``block_size`` is not positive.
        """
        if input_ids.dim() != 2 or input_ids.shape[0] != 1:
            raise ValueError("generate() only supports `input_ids` of shape (1, prompt_length)")
        if block_size <= 0:
            raise ValueError("`block_size` must be a positive integer")

        input_ids = input_ids.to(self.device)

        prompt_length = input_ids.shape[1]
        max_end = prompt_length + max_new_tokens
        num_blocks = (max_end + block_size - 1) // block_size
        total_tokens = num_blocks * block_size

        position_ids = torch.arange(total_tokens, device=self.device).unsqueeze(0)
        x = torch.full((1, total_tokens), mask_id, device=self.device, dtype=torch.long)
        x[:, :prompt_length] = input_ids

        # Only blocks made entirely of prompt tokens are prefilled; a trailing
        # partial prompt block is handled by the decoding stage.
        num_prefill_blocks = prompt_length // block_size
        past_key_values = None
        check_eos = eos_early_stop and eos_id is not None

        # prefill stage
        for block_idx in range(num_prefill_blocks):
            block_start = block_idx * block_size
            block_end = block_start + block_size
            outputs = self.forward(
                input_ids=x[:, block_start:block_end],
                position_ids=position_ids[:, block_start:block_end],
                use_cache=True,
                update_dllm_cache=True,
                past_key_values=past_key_values,
                return_dict=True,
            )
            past_key_values = outputs.past_key_values

        # decoding stage
        for block_idx in range(num_prefill_blocks, num_blocks):
            block_start = block_idx * block_size
            block_end = block_start + block_size
            # `cur_x` is a view of `x`; the denoising step writes through it.
            cur_x = x[:, block_start:block_end]
            cur_position_ids = position_ids[:, block_start:block_end]

            for _ in range(denoise_steps):
                if not (cur_x == mask_id).any():
                    break

                # Read-only pass: attend to the cached prefix without storing
                # the still-noisy block.
                outputs = self.forward(
                    input_ids=cur_x,
                    position_ids=cur_position_ids,
                    use_cache=False,
                    update_dllm_cache=False,
                    past_key_values=past_key_values,
                    return_dict=True,
                )

                cur_x = GFusionForDiffusionLM._entropy_bounded_step(
                    block_input_ids=cur_x,
                    block_logits=outputs.logits,
                    do_sample=do_sample,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    gamma=gamma,
                    mask_id=mask_id,
                )

                if check_eos:
                    generated = x[0, prompt_length:block_end]
                    eos_offsets = torch.nonzero(generated == eos_id, as_tuple=False).flatten()
                    if eos_offsets.numel() > 0:
                        end = prompt_length + int(eos_offsets[0].item()) + 1
                        if not (x[0, prompt_length:end] == mask_id).any():
                            # Padding to a whole block can place EOS past the
                            # requested budget; never return more than asked.
                            return x[:, prompt_length:min(end, max_end)]

            # Nothing after this block will be decoded once EOS is present or the
            # last block is done, so the cache update pass would be wasted.
            if eos_id is not None and (x[0, prompt_length:block_end] == eos_id).any():
                break
            if block_idx == num_blocks - 1:
                break

            outputs = self.forward(
                input_ids=cur_x,
                position_ids=cur_position_ids,
                use_cache=True,
                update_dllm_cache=True,
                past_key_values=past_key_values,
                return_dict=True,
            )
            past_key_values = outputs.past_key_values

        # parse response
        generated = x[:, prompt_length:max_end]
        if eos_id is not None:
            eos_positions = (generated[0] == eos_id).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                generated = generated[:, : eos_positions[0].item() + 1]
        # Drop anything from the first still-masked position onwards.
        mask_positions = (generated[0] == mask_id).nonzero(as_tuple=True)[0]
        if len(mask_positions) > 0:
            generated = generated[:, : mask_positions[0].item()]
        return generated