
Vision-Language Fusion

Why this matters

Multimodal models — CLIP, Flamingo, BLIP, GPT-4o — combine representations from at least two modalities (vision, language, audio). The simplest, most pervasive trick they share is early fusion: project each modality into a shared space, mix them with a non-linearity, and let downstream layers do the heavy lifting.

Before joint Transformers became standard for VLMs, simple fusion modules like this were the ENTIRE multimodal head, used in VQA (visual question answering), image captioning, and retrieval. They’re still the building block at the boundary where modalities first meet.

The recipe

Two 1-D inputs (e.g., a CLIP image embedding and a text encoder output):

image_features ∈ R^{D_v}
text_features  ∈ R^{D_t}

The fusion module:

  1. Project each modality to a shared hidden dim: v = Dense_v(image_features) → (hidden,), t = Dense_t(text_features) → (hidden,).
  2. Mix with element-wise add and a non-linearity: fused = tanh(v + t).
  3. Reproject with one more Dense(hidden) so the mixed representation has a chance to recombine its features.

Element-wise add is the simplest fusion. More sophisticated methods (concatenation, gating, bilinear pooling, cross-attention) all build on the same pattern: “project to shared space, combine.”
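
Written out as equations, the three-step recipe looks like the following. The symbols W_v, W_t, W_o (weights), b_v, b_t, b_o (biases), x_img, x_txt (the two inputs) and H (the hidden dim) are notation introduced here for illustration, assuming each Dense is a standard affine layer:

```latex
\begin{aligned}
v &= W_v\, x_{\text{img}} + b_v, & W_v &\in \mathbb{R}^{H \times D_v} \\
t &= W_t\, x_{\text{txt}} + b_t, & W_t &\in \mathbb{R}^{H \times D_t} \\
\text{fused} &= \tanh(v + t), & \text{fused} &\in \mathbb{R}^{H} \\
\text{out} &= W_o\, \text{fused} + b_o, & W_o &\in \mathbb{R}^{H \times H}
\end{aligned}
```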

Why tanh, not relu?

tanh outputs values in (-1, 1), so the sign of the combined signal survives fusion. If v says +0.5 and t says -0.5, the sum is 0 and the two modalities cancel; if both say -0.5, the sum is -1 and fused = tanh(-1) ≈ -0.76, a strong negative signal. relu would zero out anything negative: that still works, but it discards sign information at the fusion point. tanh is a common choice in early-fusion modules for this reason.
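
A small numeric check of the argument above. The values of v and t are made up for illustration (they are not outputs of a trained model); each position is one hidden unit:

```python
import jax
import jax.numpy as jnp

# Hypothetical projected values for three hidden units.
v = jnp.array([0.5, -0.5, 0.5])
t = jnp.array([-0.5, -0.5, 0.5])

fused_tanh = jnp.tanh(v + t)     # sign of the sum is preserved
fused_relu = jax.nn.relu(v + t)  # negative sums are clipped to zero

print(fused_tanh)  # approx [ 0.0, -0.76, 0.76]
print(fused_relu)  # [0., 0., 1.]
```

The second unit is where the two differ: both modalities agree on a negative value, tanh keeps it, and relu drops it.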

Worked walk-through

image_features shape (6,), text_features shape (4,), hidden = 4:

  1. v = Dense(hidden=4)(image_features) → (4,).
  2. t = Dense(hidden=4)(text_features) → (4,).
  3. fused = tanh(v + t) → (4,).
  4. out = Dense(hidden=4)(fused) → (4,).

Note that D_v and D_t can differ; the projections handle the width mismatch. After fusion, downstream layers see a uniform hidden-dim vector.
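
The same shape trace can be reproduced in plain JAX. This is only a sketch: the random matrices W_v, W_t and W_o are hypothetical stand-ins for the three Dense kernels, and biases are left out because they do not affect shapes:

```python
import jax
import jax.numpy as jnp

D_v, D_t, hidden = 6, 4, 4
k_img, k_txt, k_v, k_t, k_o = jax.random.split(jax.random.PRNGKey(0), 5)

image_features = jax.random.normal(k_img, (D_v,))
text_features = jax.random.normal(k_txt, (D_t,))

# Hypothetical random weights standing in for the Dense kernels.
W_v = jax.random.normal(k_v, (D_v, hidden))
W_t = jax.random.normal(k_t, (D_t, hidden))
W_o = jax.random.normal(k_o, (hidden, hidden))

v = image_features @ W_v   # (6,) @ (6, 4) -> (4,)
t = text_features @ W_t    # (4,) @ (4, 4) -> (4,)
fused = jnp.tanh(v + t)    # (4,)
out = fused @ W_o          # (4,) @ (4, 4) -> (4,)

print(v.shape, t.shape, fused.shape, out.shape)  # (4,) (4,) (4,) (4,)
```

The width mismatch between D_v = 6 and D_t = 4 is absorbed entirely by the first dimension of W_v and W_t.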

Common pitfalls

  • Sharing one Dense for both modalities: each needs its OWN projection. The point of early fusion is to learn modality-specific maps into the shared space.
  • No non-linearity at fusion: out = Dense(v + t) is just a linear map of (v, t). Without tanh (or some other non-linear activation), the two projections, the add, and the final Dense compose into a single linear layer on the concatenated inputs, so the extra layers add parameters but no capacity.
  • Forgetting the final Dense(hidden): the fused vector goes BACK through one more learned projection. This is the “reproject after mixing” step that lets the network shape the fused representation.
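
The “no non-linearity” pitfall can be checked numerically. In this sketch the weights are random and hypothetical, biases are omitted for brevity, and W_single is a name introduced here for the folded matrix:

```python
import jax
import jax.numpy as jnp

D_v, D_t, hidden = 6, 4, 4
keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (D_v,))
y = jax.random.normal(keys[1], (D_t,))
W_v = jax.random.normal(keys[2], (D_v, hidden))
W_t = jax.random.normal(keys[3], (D_t, hidden))
W_o = jax.random.normal(keys[4], (hidden, hidden))

# Fusion WITHOUT an activation: project, add, reproject.
no_activation = (x @ W_v + y @ W_t) @ W_o

# The same function as ONE linear layer on the concatenated input.
W_single = jnp.concatenate([W_v, W_t], axis=0) @ W_o  # (D_v + D_t, hidden)
single_layer = jnp.concatenate([x, y]) @ W_single

collapses = bool(jnp.allclose(no_activation, single_layer, rtol=1e-4, atol=1e-4))
print(collapses)
```

With tanh between the add and the final projection, no single matrix W_single can reproduce the output for all inputs, which is the capacity the activation buys.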

Problem

Implement vision_language_fusion_forward(seed, image_features, text_features, hidden):

  1. VLFusion(nn.Module) with hidden field.
  2. Inside @nn.compact:
    • v = nn.Dense(hidden)(image_features),
    • t = nn.Dense(hidden)(text_features),
    • fused = jnp.tanh(v + t),
    • out = nn.Dense(hidden)(fused).
  3. Return out flattened.

Inputs:

  • seed: int.
  • image_features: 1-D (D_v,).
  • text_features: 1-D (D_t,).
  • hidden: int.

Output: 1-D, length hidden.

