
About text mask #32

@ZeyuLing

Thanks a lot for your paper and code!
In your implementation, you don't set an attention mask for the text sequence, either in the textTransEncoder layers or in the LinearTemporalCrossAttention layers. Why doesn't this cause any problem? The related code is below.

def encode_text(self, text, device):
    with torch.no_grad():
        text = clip.tokenize(text, truncate=True).to(device)
        x = self.clip.token_embedding(text).type(self.clip.dtype)  # [batch_size, n_ctx, latent_dim]
        x = x + self.clip.positional_embedding.type(self.clip.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip.transformer(x)
        x = self.clip.ln_final(x).type(self.clip.dtype)
    # T, B, D
    x = self.text_pre_proj(x)
    xf_out = self.textTransEncoder(x)
    xf_out = self.text_ln(xf_out)
    xf_proj = self.text_proj(xf_out[text.argmax(dim=-1), torch.arange(xf_out.shape[1])])
    # B, T, D
    xf_out = xf_out.permute(1, 0, 2)
    return xf_proj, xf_out
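
To make the question concrete, here is a minimal sketch of what I expected for the text encoder. This is my own modification, not your code, and it assumes textTransEncoder is a standard nn.TransformerEncoder and that clip.tokenize pads the sequence with token id 0:

def encode_text(self, text, device):
    with torch.no_grad():
        text = clip.tokenize(text, truncate=True).to(device)
        # hypothetical padding mask: True at padded positions (assumes pad id 0)
        pad_mask = (text == 0)  # B, n_ctx
        x = self.clip.token_embedding(text).type(self.clip.dtype)
        x = x + self.clip.positional_embedding.type(self.clip.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip.transformer(x)
        x = self.clip.ln_final(x).type(self.clip.dtype)
    # T, B, D
    x = self.text_pre_proj(x)
    # src_key_padding_mask keeps every query from attending to padded tokens
    xf_out = self.textTransEncoder(x, src_key_padding_mask=pad_mask)
    xf_out = self.text_ln(xf_out)
    xf_proj = self.text_proj(xf_out[text.argmax(dim=-1), torch.arange(xf_out.shape[1])])
    # B, T, D
    xf_out = xf_out.permute(1, 0, 2)
    return xf_proj, xf_out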

class LinearTemporalCrossAttention(nn.Module):

  def __init__(self, seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim):
      super().__init__()
      self.num_head = num_head
      self.norm = nn.LayerNorm(latent_dim)
      self.text_norm = nn.LayerNorm(text_latent_dim)
      self.query = nn.Linear(latent_dim, latent_dim)
      self.key = nn.Linear(text_latent_dim, latent_dim)
      self.value = nn.Linear(text_latent_dim, latent_dim)
      self.dropout = nn.Dropout(dropout)
      self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

  def forward(self, x, xf, emb):
      """
      x: B, T, D
      xf: B, N, L
      """
      B, T, D = x.shape
      N = xf.shape[1]
      H = self.num_head
      # B, T, D
      query = self.query(self.norm(x))
      # B, N, D
      key = self.key(self.text_norm(xf))
      query = F.softmax(query.view(B, T, H, -1), dim=-1)
      key = F.softmax(key.view(B, N, H, -1), dim=1)
      # B, N, H, HD
      value = self.value(self.text_norm(xf)).view(B, N, H, -1)
      # B, H, HD, HD
      attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
      y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
      y = x + self.proj_out(y, emb)
      return y
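
And for the cross-attention, this is the masked variant I had in mind, again my own sketch and not your code. It assumes a hypothetical boolean argument xf_mask that is True at padded text positions; padded keys are pushed to -inf before the softmax over the token dimension, so they get roughly zero weight in the key-value summary:

  def forward(self, x, xf, emb, xf_mask=None):
      """
      x: B, T, D
      xf: B, N, L
      xf_mask: B, N, True at padded text tokens (hypothetical argument)
      """
      B, T, D = x.shape
      N = xf.shape[1]
      H = self.num_head
      # B, T, D
      query = self.query(self.norm(x))
      # B, N, D
      key = self.key(self.text_norm(xf))
      if xf_mask is not None:
          # padded tokens get ~zero weight after the softmax over dim=1
          key = key.masked_fill(xf_mask[:, :, None], float('-inf'))
      query = F.softmax(query.view(B, T, H, -1), dim=-1)
      key = F.softmax(key.view(B, N, H, -1), dim=1)
      # B, N, H, HD
      value = self.value(self.text_norm(xf)).view(B, N, H, -1)
      # B, H, HD, HD
      attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
      y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
      y = x + self.proj_out(y, emb)
      return y

Is a mask like this unnecessary in practice, or does the model simply learn to ignore the padded tokens?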
