Conversation

@JakobEliasWagner (Collaborator) commented Jul 24, 2024

Feature: Heterogeneous Normalized Attention

Description

This pull request implements the Heterogeneous Normalized Attention mechanism described in Hao et al. (2023).

The heterogeneous normalized attention block computes the attention scores in two steps:

  1. Normalize the query and key sequences:

$$\tilde{q}_i = \mathrm{Softmax}(q_i)$$

$$\tilde{k}_i = \mathrm{Softmax}(k_i)$$

  2. Calculate the attention scores without a softmax:

$$z_t = \sum_i \frac{\tilde{q}_t \cdot \tilde{k}_i}{\sum_j \tilde{q}_t \cdot \tilde{k}_j} v_i$$

This implementation is linear with respect to the sequence length: because no softmax is applied to the scores, the sums can be reordered as $z_t = \frac{\tilde{q}_t^\top \left(\sum_i \tilde{k}_i v_i^\top\right)}{\tilde{q}_t \cdot \sum_j \tilde{k}_j}$, so both sums over the sequence are computed once and reused for every query.
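For illustration, here is a minimal PyTorch sketch of this factored computation. It is only a sketch, not the code added in this PR; the function name and tensor shapes are assumptions:

```python
import torch

def normalized_attention(q, k, v):
    """Linear-time normalized attention. q, k: (batch, seq, d); v: (batch, seq, e)."""
    # Step 1: normalize queries and keys with a softmax over the feature dimension.
    q = torch.softmax(q, dim=-1)
    k = torch.softmax(k, dim=-1)
    # Step 2: factor the sums over the sequence. kv = sum_i k_i v_i^T is a
    # (batch, d, e) tensor computed once, so the cost is O(seq), not O(seq^2).
    kv = torch.einsum("bsd,bse->bde", k, v)
    num = torch.einsum("btd,bde->bte", q, kv)          # q_t . (sum_i k_i v_i^T)
    den = torch.einsum("btd,bd->bt", q, k.sum(dim=1))  # q_t . sum_j k_j
    return num / den.unsqueeze(-1)
```

Note that the softmax makes all entries of $\tilde{q}_t$ and $\tilde{k}_j$ strictly positive, so the denominator can never vanish.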

We added a masking mechanism on top of the vanilla implementation suggested by Hao et al.
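A hedged sketch of how such a mask could be threaded through the factored sums (the actual API in this PR may differ): zeroing the normalized keys at padded positions removes them from both the numerator and the denominator.

```python
import torch

def masked_normalized_attention(q, k, v, mask):
    """mask: (batch, seq) boolean, True at valid key/value positions."""
    q = torch.softmax(q, dim=-1)
    k = torch.softmax(k, dim=-1)
    # Zero out masked keys so they contribute to neither sum below.
    k = k * mask.unsqueeze(-1)
    kv = torch.einsum("bsd,bse->bde", k, v)
    num = torch.einsum("btd,bde->bte", q, kv)
    # clamp guards the fully-masked edge case, where the denominator is zero.
    den = torch.einsum("btd,bd->bt", q, k.sum(dim=1)).clamp(min=1e-12)
    return num / den.unsqueeze(-1)
```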

Which issue does this PR tackle?

  • Heterogeneous normalized attention is not implemented.

How does it solve the problem?

  • Implements HeterogeneousNormalizedAttention, a linear-complexity attention mechanism.
  • Implements masking for HeterogeneousNormalizedAttention.

How are the changes tested?

  • Added 6 unit tests covering: initialization, shape projection, gradient flow, zero inputs, masked forward passes, and correctness by masking a known tensor (a rough sketch of the masking check follows below).
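For instance, the masking-correctness check might look roughly like this (hypothetical names, not the test code in this PR):

```python
import torch

def test_masked_positions_do_not_affect_output():
    q, k, v = (torch.randn(1, 4, 8) for _ in range(3))
    mask = torch.tensor([[True, True, False, False]])
    out = masked_normalized_attention(q, k, v, mask)
    # Changing values at masked positions must leave the output unchanged.
    v2 = v.clone()
    v2[:, 2:] = 123.0
    out2 = masked_normalized_attention(q, k, v2, mask)
    assert torch.allclose(out, out2)
```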

Checklist for Contributors

  • Scope: This PR tackles exactly one problem.
  • Conventions: The branch follows the feature/title-slug convention.
  • Conventions: The PR title follows the Bugfix: Title convention.
  • Coding style: The code passes all pre-commit hooks.
  • Documentation: All changes are well-documented.
  • Tests: New features are tested and all tests pass successfully.
  • Changelog: Updated CHANGELOG.md for new features or breaking changes.
  • Review: A suitable reviewer has been assigned.

Checklist for Reviewers

  • The PR solves the issue it claims to solve and only this one.
  • Changes are tested sufficiently and all tests pass.
  • Documentation is complete and well-written.
  • Changelog has been updated, if necessary.

@JakobEliasWagner JakobEliasWagner marked this pull request as draft July 24, 2024 13:54
@JakobEliasWagner JakobEliasWagner self-assigned this Jul 24, 2024
@JakobEliasWagner JakobEliasWagner added the enhancement New feature or request label Jul 24, 2024
@samuelburbulla (Contributor) left a comment

LGTM, you're the expert

Comment on lines +21 to +23
$$\tilde{q}_i = \frac{\exp(q_{i,j})}{\sum_j\exp(q_{i,j})}$$,
$$\tilde{k}_i = \frac{\exp(k_{i,j})}{\sum_j\exp(k_{i,j})}$$, and then calculating the attention without
softmax using $$z_t=\sum_i \frac{\tilde{q}_t \cdot \tilde{k}_i}{\sum_j \tilde{q}_t \cdot \tilde{k}_j}\cdot v_i$$.
Math does not render well in docs

