vllm.model_executor.models.qwen3_guard

Inference-only Qwen3 Guard model compatible with HuggingFace weights.
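
As a pooling model, it is driven through vLLM's encode path rather than
generate(). A minimal usage sketch (the checkpoint name is illustrative;
substitute a real Qwen3 Guard checkpoint, and note that engine arguments
can vary across vLLM versions):

from vllm import LLM

# Illustrative checkpoint name, not verified here.
llm = LLM(model="Qwen/Qwen3Guard-Stream-0.6B")

# encode() exercises the "encode" pooling task registered by this model.
(output,) = llm.encode(["How do I make a cake?"])
per_token = output.outputs.data  # per-token guard logits + hidden states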

logger module-attribute

logger = init_logger(__name__)

Qwen3ForGuardModel

Bases: Module, SupportsPP
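
The SupportsPP base marks the model as pipeline-parallel capable: on
non-final ranks the forward pass exchanges IntermediateTensors (see
make_empty_intermediate_tensors), and lm_head is replaced by a
PPMissingLayer.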

Source code in vllm/model_executor/models/qwen3_guard.py
@default_pooling_type("ALL")
class Qwen3ForGuardModel(nn.Module, SupportsPP):

    if envs.VLLM_USE_V1:
        is_pooling_model = True

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = Qwen3Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        # Response-side guard heads: a shared projection and RMSNorm feed
        # separate risk-level and category classifiers.
        self.risk_level_category_pre = nn.Linear(config.hidden_size,
                                                 config.guard_inner_size,
                                                 bias=False)
        self.risk_level_category_layernorm = RMSNorm(config.guard_inner_size,
                                                     eps=config.rms_norm_eps)
        self.risk_level_head = nn.Linear(config.guard_inner_size,
                                         config.num_risk_level,
                                         bias=False)
        self.category_head = nn.Linear(config.guard_inner_size,
                                       config.num_category,
                                       bias=False)

        # Query-side guard heads mirror the response-side stack, scoring the
        # user query instead of the model response.
        self.query_risk_level_category_pre = nn.Linear(config.hidden_size,
                                                       config.guard_inner_size,
                                                       bias=False)
        self.query_risk_level_category_layernorm = RMSNorm(
            config.guard_inner_size, eps=config.rms_norm_eps)
        self.query_risk_level_head = nn.Linear(config.guard_inner_size,
                                               config.num_query_risk_level,
                                               bias=False)
        self.query_category_head = nn.Linear(config.guard_inner_size,
                                             config.num_query_category,
                                             bias=False)

        if get_pp_group().is_last_rank:
            # With tied embeddings, reuse the token embedding matrix as the
            # LM head; load_weights() then skips the checkpoint's lm_head.*
            # tensors.
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

        self.pooler = DispatchPooler({
            "encode":
            Pooler.for_encode(
                PoolerConfig(
                    pooling_type="ALL",
                    normalize=False,
                    dimensions=None,
                    enable_chunked_processing=True,
                    activation=False,
                    softmax=False,
                )),
        })

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)

        # Insert a singleton dimension so each guard head output has shape
        # [num_tokens, 1, out_features], matching the concatenation below.
        hidden_states = hidden_states[:, None, :]

        # Response-side heads: shared projection + RMSNorm, then risk-level
        # and category logits.
        risk_level_category_x = self.risk_level_category_pre(hidden_states)
        risk_level_category_x = self.risk_level_category_layernorm(
            risk_level_category_x)
        risk_level_logits = self.risk_level_head(risk_level_category_x)
        category_logits = self.category_head(risk_level_category_x)

        # Query-side heads follow the same structure.
        query_risk_level_category_x = self.query_risk_level_category_pre(
            hidden_states)
        query_risk_level_category_x = self.query_risk_level_category_layernorm(
            query_risk_level_category_x)
        query_risk_level_logits = self.query_risk_level_head(
            query_risk_level_category_x)
        query_category_logits = self.query_category_head(
            query_risk_level_category_x)

        # Pack all guard logits plus the raw hidden states into one tensor;
        # the "ALL" pooler later returns this per token.
        return torch.cat([
            risk_level_logits, category_logits, query_risk_level_logits,
            query_category_logits, hidden_states
        ],
                         dim=-1)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)

category_head instance-attribute

category_head = Linear(
    guard_inner_size, num_category, bias=False
)

config instance-attribute

config = config

is_pooling_model class-attribute instance-attribute

is_pooling_model = True
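
Set only under the V1 engine (see the envs.VLLM_USE_V1 guard above), so the
engine treats this class as a pooling model.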

lm_head instance-attribute

lm_head = embed_tokens

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

lora_config instance-attribute

lora_config = lora_config

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

model instance-attribute

model = Qwen3Model(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
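
This mapping records that the checkpoint's separate q/k/v and gate/up
projections are fused into single modules in vLLM; components such as LoRA
consult it to target the fused layers. Illustratively (paths abbreviated):

# Checkpoint tensor              ->  fused vLLM parameter (shard)
# ...self_attn.q_proj.weight     ->  ...self_attn.qkv_proj.weight  ("q")
# ...self_attn.k_proj.weight     ->  ...self_attn.qkv_proj.weight  ("k")
# ...self_attn.v_proj.weight     ->  ...self_attn.qkv_proj.weight  ("v")
# ...mlp.gate_proj.weight        ->  ...mlp.gate_up_proj.weight    (0)
# ...mlp.up_proj.weight          ->  ...mlp.gate_up_proj.weight    (1)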

pooler instance-attribute

pooler = DispatchPooler(
    {
        "encode": for_encode(
            PoolerConfig(
                pooling_type="ALL",
                normalize=False,
                dimensions=None,
                enable_chunked_processing=True,
                activation=False,
                softmax=False,
            )
        )
    }
)
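
With pooling_type="ALL" and normalization, activation, and softmax all
disabled, the pooler returns the raw forward() output for every token of a
request rather than a single pooled vector. A hedged sketch of reading one
request's risk-level probabilities from that output (config attribute names
follow the HF config used in __init__):

import torch

# pooled: the per-token output tensor from one encode() request.
risk_logits = pooled[..., : config.num_risk_level]
risk_probs = torch.softmax(risk_logits.float(), dim=-1)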

quant_config instance-attribute

quant_config = quant_config

query_category_head instance-attribute

query_category_head = Linear(
    guard_inner_size, num_query_category, bias=False
)

query_risk_level_category_layernorm instance-attribute

query_risk_level_category_layernorm = RMSNorm(
    guard_inner_size, eps=rms_norm_eps
)

query_risk_level_category_pre instance-attribute

query_risk_level_category_pre = Linear(
    hidden_size, guard_inner_size, bias=False
)

query_risk_level_head instance-attribute

query_risk_level_head = Linear(
    guard_inner_size, num_query_risk_level, bias=False
)

risk_level_category_layernorm instance-attribute

risk_level_category_layernorm = RMSNorm(
    guard_inner_size, eps=rms_norm_eps
)

risk_level_category_pre instance-attribute

risk_level_category_pre = Linear(
    hidden_size, guard_inner_size, bias=False
)

risk_level_head instance-attribute

risk_level_head = Linear(
    guard_inner_size, num_risk_level, bias=False
)

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/qwen3_guard.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    lora_config = vllm_config.lora_config

    self.config = config
    self.lora_config = lora_config

    self.quant_config = quant_config
    self.model = Qwen3Model(vllm_config=vllm_config,
                            prefix=maybe_prefix(prefix, "model"))

    # Response-side guard heads: a shared projection and RMSNorm feed
    # separate risk-level and category classifiers.
    self.risk_level_category_pre = nn.Linear(config.hidden_size,
                                             config.guard_inner_size,
                                             bias=False)
    self.risk_level_category_layernorm = RMSNorm(config.guard_inner_size,
                                                 eps=config.rms_norm_eps)
    self.risk_level_head = nn.Linear(config.guard_inner_size,
                                     config.num_risk_level,
                                     bias=False)
    self.category_head = nn.Linear(config.guard_inner_size,
                                   config.num_category,
                                   bias=False)

    # Query-side guard heads mirror the response-side stack, scoring the
    # user query instead of the model response.
    self.query_risk_level_category_pre = nn.Linear(config.hidden_size,
                                                   config.guard_inner_size,
                                                   bias=False)
    self.query_risk_level_category_layernorm = RMSNorm(
        config.guard_inner_size, eps=config.rms_norm_eps)
    self.query_risk_level_head = nn.Linear(config.guard_inner_size,
                                           config.num_query_risk_level,
                                           bias=False)
    self.query_category_head = nn.Linear(config.guard_inner_size,
                                         config.num_query_category,
                                         bias=False)

    if get_pp_group().is_last_rank:
        # With tied embeddings, reuse the token embedding matrix as the
        # LM head; load_weights() then skips the checkpoint's lm_head.*
        # tensors.
        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config,
                                          prefix=maybe_prefix(
                                              prefix, "lm_head"))
    else:
        self.lm_head = PPMissingLayer()

    self.logits_processor = LogitsProcessor(config.vocab_size)

    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)

    self.pooler = DispatchPooler({
        "encode":
        Pooler.for_encode(
            PoolerConfig(
                pooling_type="ALL",
                normalize=False,
                dimensions=None,
                enable_chunked_processing=True,
                activation=False,
                softmax=False,
            )),
    })

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/qwen3_guard.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    hidden_states = self.model(input_ids, positions, intermediate_tensors,
                               inputs_embeds)

    # Insert a singleton dimension so each guard head output has shape
    # [num_tokens, 1, out_features], matching the concatenation below.
    hidden_states = hidden_states[:, None, :]

    # Response-side heads: shared projection + RMSNorm, then risk-level
    # and category logits.
    risk_level_category_x = self.risk_level_category_pre(hidden_states)
    risk_level_category_x = self.risk_level_category_layernorm(
        risk_level_category_x)
    risk_level_logits = self.risk_level_head(risk_level_category_x)
    category_logits = self.category_head(risk_level_category_x)

    # Query-side heads follow the same structure.
    query_risk_level_category_x = self.query_risk_level_category_pre(
        hidden_states)
    query_risk_level_category_x = self.query_risk_level_category_layernorm(
        query_risk_level_category_x)
    query_risk_level_logits = self.query_risk_level_head(
        query_risk_level_category_x)
    query_category_logits = self.query_category_head(
        query_risk_level_category_x)

    # Pack all guard logits plus the raw hidden states into one tensor;
    # the "ALL" pooler later returns this per token.
    return torch.cat([
        risk_level_logits, category_logits, query_risk_level_logits,
        query_category_logits, hidden_states
    ],
                     dim=-1)
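
The return value concatenates the four guard-head logit blocks and the raw
hidden states along the last dimension, in the order shown above. A minimal
sketch (not part of vLLM) of splitting that tensor back apart, given the
same HF config:

import torch

def split_guard_output(out: torch.Tensor, config) -> dict[str, torch.Tensor]:
    # Widths follow the concatenation order in forward().
    sizes = [
        config.num_risk_level,
        config.num_category,
        config.num_query_risk_level,
        config.num_query_category,
        config.hidden_size,
    ]
    names = ["risk_level", "category", "query_risk_level",
             "query_category", "hidden_states"]
    return dict(zip(names, torch.split(out, sizes, dim=-1)))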

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/qwen3_guard.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    return self.model.get_input_embeddings(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/qwen3_guard.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    loader = AutoWeightsLoader(
        self,
        skip_prefixes=(["lm_head."]
                       if self.config.tie_word_embeddings else None),
    )
    return loader.load_weights(weights)
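
Because a tied lm_head is just an alias of model.embed_tokens (see
__init__), any lm_head.* tensors in the checkpoint are skipped rather than
loaded twice.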