FSE decoder

Now that the FSE distribution table is available, the FSE table itself must be built. As you saw in the example, those 7 coefficients will be used to build a 32-state table.

✅ In the decoders::fse module, add a FseTable type containing the states (from a State type) of a FSE decoder.

The symbols decoded by the FSE decoder are not necessarily bytes: it is suggested that you use u16 as the symbol type to ensure you will be able to decode all sequences. Tables may contain up to \(2^{20}\) states, but you are allowed to set a limit. No common Zstandard file uses more than \(2^{14}\) states, so using a u16 for the baseline looks reasonable.

Building the FSE decoder tables

A FSE decoder table is built from a distribution of symbols, i.e. the number of states that will decode to each symbol. The states corresponding to a symbol are roughly evenly spaced in the table, so that after emitting a symbol, the decoder can branch to a state emitting any other symbol by reading only a small number of bits.

For every symbol, the distribution list contains:

  • \(-1\) if the symbol is very rare (it has a "less than 1" probability). In this case, it gets a single state placed at the end of the states table, and this state reads all AL (accuracy log) bits to get the following state, making every other state reachable.
  • \(0\) if the symbol does not appear in the output. No state will be built for this symbol.
  • \(n > 0\): the symbol will be emitted by \(n\) states spread through the table. Every such state reads enough bits so that, by choosing the right state for the current symbol, any other state in the table can be reached.

✅ Write a FseTable::from_distribution() method that builds the FSE states table from the accuracy log and the distribution list.

One possible signature is:

impl FseTable {
    pub fn from_distribution(accuracy_log: u8, distribution: &[i16]) -> Self {
        todo!()
    }
}

The steps to build the table are described below, and are also described in the RFC.

Symbols with "less than 1" probability

Some symbols have a "less than 1" probability, represented with the -1 distribution value. They will be emitted by one state only. Since you must be able to jump to any state after emitting such a symbol, its baseline will be 0 and the number of bits to read will be the accuracy log.

✅ After building an empty table, fill it starting from the end with states emitting those symbols with a -1 distribution value (the lowest symbol with a -1 distribution value will be stored last in the table).

The rest of the table must now be filled with states emitting symbols with a positive distribution value, in order.

Spreading the entries for a given symbol with a positive distribution value

The states corresponding to a symbol will be spread throughout the remaining table entries. The idea behind spreading them, rather than grouping them together, is that, for a given symbol, fewer bits will be needed to reach any other state.

Aside: why are the entries for a symbol spread in a FSE table?

Let us assume that your output uses four symbols A, B, C, and D, with densities \(\frac 12\) for A, \(\frac 14\) for B, and \(\frac 18\) for both C and D. The chosen accuracy log is 3, so there are 8 states.

The table may then be built as follows:

State  Output symbol  Base line (BL)  Number of bits to read
0      A              0               1
1      B              0               2
2      A              2               1
3      C              0               3
4      A              4               1
5      B              4               2
6      A              6               1
7      D              0               3

Whenever we want to emit a given symbol and then jump to a given state, there always exists a state with this property. An encoder starts from the end of the stream of symbols to encode and works backward to find an appropriate state for each symbol. For example, let us encode "AACBAABD" and build the stream (which will be parsed by a backward bit parser when decoding) in the process. Underscores (_) will be used to separate bytes once they have been filled entirely; within a byte, newly added bits appear on the left.

  • Only state 7 emits D. We want the previous state to jump there and stop.
  • State 5 emits B and jumps to 7 when reading 11. Stream: 11.
  • State 4 emits A and jumps to 5 when reading 1. Stream: 111.
  • State 4 emits A and jumps to 4 when reading 0. Stream: 0111.
  • State 5 emits B and jumps to 4 when reading 00. Stream: 000111.
  • State 3 emits C and jumps to 5 when reading 101. Stream: 01000111_1.
  • State 2 emits A and jumps to 3 when reading 1. Stream: 01000111_11.
  • State 2 emits A and jumps to 2 when reading 0. Stream: 01000111_011.
  • The initial state is obtained by reading a number of bits equal to the accuracy log (3): 010 designates state 2. Stream: 01000111_010011.
  • A 1 is added to mark the beginning of the stream, then zeroes are added as necessary to complete the last byte (here only one is needed). Stream: 01000111_01010011, or [0b01000111, 0b01010011], or [0x47, 0x53].
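A quick way to check such a hand-built stream is to run the decoding loop on it. Below is a minimal sketch for this 8-state table only; `BackwardBits` and `decode` are hypothetical helpers written for this illustration (they are not the lab's `BackwardBitParser`). The sketch assumes bits are read from the last byte toward the first, most significant bit first, and simply stops once every bit has been consumed.

```rust
/// The 8-state example table: (output symbol, base line, number of bits to read).
const TABLE: [(char, usize, usize); 8] = [
    ('A', 0, 1), ('B', 0, 2), ('A', 2, 1), ('C', 0, 3),
    ('A', 4, 1), ('B', 4, 2), ('A', 6, 1), ('D', 0, 3),
];

/// Hypothetical backward bit reader: from the last byte to the first,
/// most significant bit first.
struct BackwardBits<'a> {
    data: &'a [u8],
    remaining: usize, // number of unread bits, counted from bit 0 of data[0]
}

impl<'a> BackwardBits<'a> {
    /// Skips the padding zeroes and the 1 marker at the end of the stream.
    fn new(data: &'a [u8]) -> Option<Self> {
        let last = *data.last()?;
        if last == 0 {
            return None; // no marker in the last byte
        }
        let marker = 7 - last.leading_zeros() as usize; // position of the highest 1 bit
        Some(BackwardBits { data, remaining: (data.len() - 1) * 8 + marker })
    }

    /// Reads `n` bits, or returns None if fewer than `n` bits remain.
    fn read(&mut self, n: usize) -> Option<usize> {
        if n > self.remaining {
            return None;
        }
        let mut value = 0;
        for _ in 0..n {
            self.remaining -= 1;
            let bit = (self.data[self.remaining / 8] >> (self.remaining % 8)) & 1;
            value = (value << 1) | bit as usize;
        }
        Some(value)
    }
}

/// Decodes symbols until all bits have been consumed (assumed stop rule).
fn decode(data: &[u8]) -> Option<String> {
    let mut bits = BackwardBits::new(data)?;
    let mut state = bits.read(3)?; // the accuracy log is 3
    let mut out = String::new();
    loop {
        let (symbol, base_line, num_bits) = TABLE[state];
        out.push(symbol);
        if bits.remaining == 0 {
            return Some(out);
        }
        state = base_line + bits.read(num_bits)?;
    }
}
```

Running `decode` on the bytes of the final stream should give back "AACBAABD"; any other output means a step of the hand encoding went wrong.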

The advantage of spreading appears clearly in this example: since the most frequent symbol A is spread across the table, every time an A appears in the output, only one bit is needed to select the following state.

The Zstandard RFC specifies how the symbols will be spread:

  • Start by filling state 0.
  • If state \(s\) was the last state filled and \(R\) is the table size, compute \[s' = (s + \frac R2 + \frac R8 + 3) \land (R-1)\].
  • If state \(s'\) is already filled by a symbol with a "less than 1" probability, apply this operation again, otherwise you have your new state.

Given that \(R\) is a power of 2, \(x \land (R-1)\) is equivalent to x % R in Rust. Also, the accuracy log is defined as being at least 5, which means that \(R \geq 32\). Both \(\frac R2\) and \(\frac R8\) are then even, so the step \(\frac R2 + \frac R8 + 3\) is odd, and thus coprime with \(R\). As a consequence, repeatedly applying the operation from state 0 reaches every state exactly once before coming back to 0.

💡 Using functions such as std::iter::successors() and combinators such as .filter(), you can easily build an iterator whose .next() result is the next state to use.
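As an illustration, here is one possible shape for such an iterator. The `spread_positions` name and its `high_threshold` parameter (the number of table entries left once the "less than 1" symbols have been placed at the end) are made up for this sketch.

```rust
use std::iter::successors;

/// Hypothetical helper: yields the positions at which symbols are spread in a
/// table of `1 << accuracy_log` states, starting from 0 and skipping the
/// positions `>= high_threshold` (those hold "less than 1" symbols).
fn spread_positions(accuracy_log: u8, high_threshold: usize) -> impl Iterator<Item = usize> {
    let size = 1usize << accuracy_log;
    let step = (size >> 1) + (size >> 3) + 3;
    successors(Some(0usize), move |&s| Some((s + step) & (size - 1)))
        .filter(move |&s| s < high_threshold)
}
```

With an accuracy log of 5 and no "less than 1" symbol, the first positions are 0, 23, 14 and 5. A table builder can keep a single such iterator and call `.by_ref().take(d)` for each symbol in turn, so that every symbol resumes where the previous one stopped.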

✅ Determine the list of states that will be used for the current symbol (there will be as many states as the distribution value). Sort the list in natural order.

If the table size is \(R\) and the distribution value for the current symbol is \(d\), starting with a baseline of \(\frac {kR}d\) for \(k \in [0, d[\) and reading \(\log_2{\frac Rd}\) bits to compute an offset would allow you to reach any other state. However, that works only when \(d\) is itself a power of 2.

If \(d\) is not a power of 2, some of the states will need to read one more bit in order to ensure the coverage of the whole state table starting from the current symbol.

✅ Given the distribution value \(d\):

  • Compute \(p\) as being the smallest power of two greater than or equal to \(d\). Note that \(p\) may be equal to \(d\).
  • Compute the number of bits \(b = \log_2\frac Rp\) you would need to read to reach all states if you had \(p\) states available for your symbol instead of \(d\).
  • Compute \(e = p - d\), the number of states to which you will need to add one extra bit provided that, in reality, you only have \(d\) states available for your symbol. \(e\) may be 0 if \(p=d\).

Once you have done that, you will be able to:

  • reach \(2^{b+1}\) states from each of \(e\) of your states
  • reach \(2^b\) states from each of \(d-e\) of your states

which allows you to reach a total of \(e \times 2^{b+1} + (d-e) \times 2^b = 2^b (d + e) = 2^b p = R\) states. The whole table is covered provided the baselines are chosen correctly.
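Computing \(b\) and \(e\) only takes a few lines. The sketch below (the `bit_split` name is hypothetical) assumes \(0 < d \leq R\) and relies on the fact that, \(p\) being a power of two, its number of trailing zero bits is \(\log_2 p\).

```rust
/// Hypothetical helper: for a table of `1 << accuracy_log` states and a symbol
/// with distribution value `d` (assumed to satisfy 0 < d <= table size),
/// returns `(b, e)`: `e` states read `b + 1` bits, the `d - e` others read `b` bits.
fn bit_split(accuracy_log: u8, d: u32) -> (u8, u32) {
    let p = d.next_power_of_two(); // smallest power of two >= d
    let b = accuracy_log - p.trailing_zeros() as u8; // log2(R / p)
    (b, p - d)
}
```

For instance, `bit_split(7, 5)` returns `(4, 3)`: with 128 states and 5 states for the symbol, 3 states read 5 bits and the 2 others read 4 bits.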

The Zstandard RFC mandates that:

  • the first \(e\) states for the current symbol will use a Number of bits to read of \(b+1\);
  • the remaining \(d-e\) states for the current symbol will use a Number of bits to read of \(b\);
  • the base line starts at 0 for the first state with \(b\) bits;
  • after setting the base line of a state, it is incremented by 2 to the power of the number of bits corresponding to this state, then the next state is processed (wrapping around the list if needed) until we have completed all states for the current symbol.

A very detailed example is given in the RFC itself, with an accuracy log of 7 (128 states). Its table describing the 5 states allocated to a given symbol shows that:

  • The first 3 of those states require 5 bits (4 + 1 extra) and the other 2 require 4 bits.
  • The baseline:
    • starts at 0 in the first state requiring 4 bits, then is incremented by \(2^4=16\);
    • is set to 16 in the next state, then incremented by \(2^4=16\), giving \(16+16=32\);
    • is set to 32 in the next state (wrapping around), then incremented by \(2^5=32\), giving \(32+32=64\);
    • is set to 64 in the next state, then incremented by \(2^5=32\), giving \(64+32=96\);
    • is set to 96 in the next state, then incremented by \(2^5=32\), giving \(96+32=128\).
  • All states have their baseline filled, and the baseline has reached 128, which is exactly the table size.
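The baseline assignment for a single symbol can be sketched as follows. `assign_base_lines` is a hypothetical helper: it only needs the accuracy log and the number \(d > 0\) of states owned by the symbol, and returns one (number of bits, baseline) pair per state, the states being taken in natural order.

```rust
/// Hypothetical helper: for a symbol owning `d > 0` states in a table of
/// `1 << accuracy_log` states, returns `(number of bits, base line)` for each
/// of its states, in natural (increasing) order of the states.
fn assign_base_lines(accuracy_log: u8, d: usize) -> Vec<(u8, usize)> {
    assert!(d > 0, "symbols with a 0 or -1 distribution value are handled elsewhere");
    let p = d.next_power_of_two(); // smallest power of two >= d
    let b = accuracy_log - p.trailing_zeros() as u8; // log2(R / p)
    let e = p - d; // the first `e` states read one extra bit
    let mut result = vec![(0u8, 0usize); d];
    let mut base_line = 0usize;
    // Start with the first state reading `b` bits, then wrap around to the
    // `e` states reading `b + 1` bits.
    for i in (e..d).chain(0..e) {
        let num_bits = if i < e { b + 1 } else { b };
        result[i] = (num_bits, base_line);
        base_line += 1usize << num_bits;
    }
    result
}
```

`assign_base_lines(7, 5)` yields `[(5, 32), (5, 64), (5, 96), (4, 0), (4, 16)]`, the values of the RFC example.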

✅ Follow a similar process and fill the states assigned to your symbol.

✅ Do this for all symbols with a distribution value greater than 0.

You can now return the table: the process is complete.
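Putting all the steps together, from_distribution() could look like the following sketch. The State and FseTable field names are assumptions, the distribution is assumed to be valid (its values, with -1 counting as 1, add up to the table size), and no limit on the accuracy log is enforced.

```rust
/// One entry of the decoding table (field names are assumptions of this sketch).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct State {
    pub symbol: u16,
    pub base_line: u16,
    pub num_bits: u8,
}

#[derive(Debug)]
pub struct FseTable {
    pub accuracy_log: u8,
    pub states: Vec<State>,
}

impl FseTable {
    /// Sketch only: assumes a valid distribution whose values (with -1
    /// counting as 1) add up to `1 << accuracy_log`.
    pub fn from_distribution(accuracy_log: u8, distribution: &[i16]) -> Self {
        let size = 1usize << accuracy_log;
        let mut states = vec![State::default(); size];

        // "Less than 1" symbols fill the table from the end, lowest symbol last.
        let mut high = size;
        for (symbol, &d) in distribution.iter().enumerate() {
            if d == -1 {
                high -= 1;
                states[high] = State { symbol: symbol as u16, base_line: 0, num_bits: accuracy_log };
            }
        }

        // The other symbols are spread over positions 0..high, in symbol order.
        let step = (size >> 1) + (size >> 3) + 3;
        let mut position = 0usize;
        for (symbol, &d) in distribution.iter().enumerate() {
            if d <= 0 {
                continue;
            }
            let d = d as usize;
            let mut positions = Vec::with_capacity(d);
            for _ in 0..d {
                positions.push(position);
                states[position].symbol = symbol as u16;
                // Skip the positions holding "less than 1" symbols.
                loop {
                    position = (position + step) & (size - 1);
                    if position < high {
                        break;
                    }
                }
            }
            positions.sort_unstable(); // natural order

            // The first `e` states read `b + 1` bits, the others `b` bits;
            // base lines start at the first state reading `b` bits.
            let p = d.next_power_of_two();
            let b = accuracy_log - p.trailing_zeros() as u8; // log2(R / p)
            let e = p - d;
            let mut base_line = 0usize;
            for i in (e..d).chain(0..e) {
                let num_bits = if i < e { b + 1 } else { b };
                states[positions[i]].num_bits = num_bits;
                states[positions[i]].base_line = base_line as u16;
                base_line += 1usize << num_bits;
            }
        }
        FseTable { accuracy_log, states }
    }
}
```

Fed with the default literals length distribution given in the RFC (accuracy log 6), this sketch gives for instance symbol 0 with 4 bits and baseline 16 in state 1, and symbol 32 with 6 bits and baseline 0 in the last state.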

Building a FSE table from a bitstream

✅ Add a FseTable::parse() constructor which takes a mutable reference to a ForwardBitParser and returns either the newly built table or an error.

✅ Add tests.

You may, for example, check that the table built from [0x30, 0x6f, 0x9b, 0x03] matches the table shown in Nigel Tao's example.

The FSE decoder

✅ In the fse module, add a FseDecoder struct which encapsulates everything a FSE decoder needs.

Your decoder will need to hold the following information:

  • the FSE table used to decode
  • the current baseline
  • the next symbol size
  • the symbol to emit when requested, as an Option<u16> (you may encounter more than 256 symbols while decoding sequences, but never more than 65535)

✅ In the decoders module, add the BitDecoder trait with the following signature:

/// A (possibly) stateful bit-level decoder
pub trait BitDecoder<E, T = u8> {
    /// Initialize the state.
    ///
    /// # Panics
    ///
    /// This method may panic if the decoder is already initialized.
    fn initialize(&mut self, bitstream: &mut BackwardBitParser) -> Result<(), E>;

    /// Return the next expected input size in bits
    ///
    /// # Panics
    ///
    /// This method may panic if no bits are expected right now
    fn expected_bits(&self) -> usize;

    /// Retrieve a decoded symbol
    ///
    /// # Panics
    ///
    /// This method may panic if the state has not been updated
    /// since the last state retrieval.
    fn symbol(&mut self) -> T;

    /// Update the state from a bitstream by reading the right
    /// number of bits, silently completing with zeroes if needed.
    /// Return `true` if zeroes have been added.
    ///
    /// # Panics
    ///
    /// This method may panic if the symbol has not been retrieved since
    /// the last update.
    fn update_bits(&mut self, bitstream: &mut BackwardBitParser) -> Result<bool, E>;

    /// Reset the table at its state before `initialize` is called. It allows
    /// reusing the same decoder.
    fn reset(&mut self);
}

This trait will be implemented by several bit decoders, starting with FseDecoder. Note that some methods may panic: this is expected, as such a panic would reveal a logic error (for example, initializing a decoder twice without resetting it), not something caused by the input of our program. While panicking when the user feeds our program bad data is not advised, it is acceptable when our program has an internal bug.

The expected use of a type implementing this trait is:

  1. Call initialize().
  2. Extract the decoded symbol from the decoder using symbol() and add it to the output stream.
  3. If you have previously executed step 4 and it returned Ok(true), meaning that the input stream had to be completed with zeroes, stop there: you have reached the end of the decoding process.
  4. Update the decoder state with update_bits() (this will attempt to consume expected_bits()).
  5. Go back to 2.

✅ Implement the BitDecoder trait for FseDecoder.
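Here is a self-contained sketch of what such an implementation might look like. Several parts are assumptions made only to keep it runnable: the BackwardBitParser stand-in and its take_padded() method are hypothetical (your own parser will differ), the error type is std::convert::Infallible because this sketch never fails, the trait's type parameters are written as BitDecoder<E, T = u8> (Rust requires parameters with a default to come last), and decode_all() is a hypothetical helper following steps 1 to 5 above.

```rust
use std::convert::Infallible;

/// Hypothetical stand-in for the lab's BackwardBitParser.
pub struct BackwardBitParser<'a> {
    data: &'a [u8],
    remaining: usize, // unread bits, located below the final 1 marker
}

impl<'a> BackwardBitParser<'a> {
    pub fn new(data: &'a [u8]) -> Option<Self> {
        let last = *data.last().filter(|&&b| b != 0)?; // the last byte holds the 1 marker
        let marker = 7 - last.leading_zeros() as usize;
        Some(BackwardBitParser { data, remaining: (data.len() - 1) * 8 + marker })
    }

    /// Hypothetical method: reads `n` bits, most significant first, completing
    /// with zeroes if needed, and tells whether zeroes had to be added.
    pub fn take_padded(&mut self, n: usize) -> (usize, bool) {
        let padded = n > self.remaining;
        let mut value = 0usize;
        for _ in 0..n {
            value <<= 1;
            if self.remaining > 0 {
                self.remaining -= 1;
                value |= ((self.data[self.remaining / 8] >> (self.remaining % 8)) & 1) as usize;
            }
        }
        (value, padded)
    }
}

/// The trait from the text (doc comments omitted), defaulted parameter last.
pub trait BitDecoder<E, T = u8> {
    fn initialize(&mut self, bitstream: &mut BackwardBitParser) -> Result<(), E>;
    fn expected_bits(&self) -> usize;
    fn symbol(&mut self) -> T;
    fn update_bits(&mut self, bitstream: &mut BackwardBitParser) -> Result<bool, E>;
    fn reset(&mut self);
}

#[derive(Clone, Copy)]
pub struct State { pub symbol: u16, pub base_line: u16, pub num_bits: u8 }
pub struct FseTable { pub accuracy_log: u8, pub states: Vec<State> }

pub struct FseDecoder {
    table: FseTable,
    base_line: usize,    // baseline of the current state
    num_bits: usize,     // number of bits the next update will read
    symbol: Option<u16>, // symbol waiting to be retrieved
}

impl FseDecoder {
    pub fn new(table: FseTable) -> Self {
        FseDecoder { table, base_line: 0, num_bits: 0, symbol: None }
    }
    fn enter(&mut self, index: usize) {
        let state = self.table.states[index];
        self.symbol = Some(state.symbol);
        self.base_line = state.base_line as usize;
        self.num_bits = state.num_bits as usize;
    }
}

impl BitDecoder<Infallible, u16> for FseDecoder {
    fn initialize(&mut self, bitstream: &mut BackwardBitParser) -> Result<(), Infallible> {
        // A real decoder may rather report an error if the stream is too short.
        let (index, _) = bitstream.take_padded(self.table.accuracy_log as usize);
        self.enter(index);
        Ok(())
    }
    fn expected_bits(&self) -> usize { self.num_bits }
    fn symbol(&mut self) -> u16 { self.symbol.take().expect("symbol already retrieved") }
    fn update_bits(&mut self, bitstream: &mut BackwardBitParser) -> Result<bool, Infallible> {
        assert!(self.symbol.is_none(), "symbol not retrieved since the last update");
        let (offset, padded) = bitstream.take_padded(self.num_bits);
        self.enter(self.base_line + offset);
        Ok(padded)
    }
    fn reset(&mut self) { (self.base_line, self.num_bits, self.symbol) = (0, 0, None); }
}

/// Hypothetical helper following steps 1 to 5 above.
pub fn decode_all<E, T, D: BitDecoder<E, T>>(
    decoder: &mut D,
    bitstream: &mut BackwardBitParser,
) -> Result<Vec<T>, E> {
    decoder.initialize(bitstream)?; // step 1
    let (mut output, mut padded) = (Vec::new(), false);
    loop {
        output.push(decoder.symbol()); // step 2
        if padded {
            return Ok(output); // step 3
        }
        padded = decoder.update_bits(bitstream)?; // step 4, then back to step 2
    }
}
```

Note how the end of the stream behaves with this protocol. With the 8-state table from the aside (A, B, C and D numbered 0 to 3) and the bytes [0x47, 0x53], an encoding of AACBAABD for that table, the update following D runs out of bits, completes them with zeroes and lands on state 0, so one more A is emitted: the decoder returns AACBAABDA.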

✅ Write some tests, using hand-crafted pathological cases as well as Nigel Tao's examples.