zstd: Refactor prefetching for the decoding loop

Following facebook/zstd#2545, I noticed that one field in `seq_t` is
optional, and only used in combination with prefetching. (This may have
contributed to static analyzer failure to detect correct
initialization).

I then wondered if it would be possible to rewrite the code so that this
optional part is handled directly by the prefetching code rather than
delegated as an option into `ZSTD_decodeSequence()`.

This resulted into this refactoring exercise where the prefetching
responsibility is better isolated into its own function and
`ZSTD_decodeSequence()` is streamlined to contain strictly Sequence
decoding operations.  Incidently, due to better code locality, it
reduces the need to send information around, leading to simplified
interface, and smaller state structures.

Port of facebook/zstd@f5434663ea

Reported-by: Coverity (CID 1462271)
Reviewed-by: Damian Szuberski <szuberskidamian@gmail.com>
Reviewed-by: Brian Behlendorf <behlendorf1@llnl.gov>
Reviewed-by: Tino Reichardt <milky-zfs@mcmilk.de>
Ported-by: Richard Yao <richard.yao@alumni.stonybrook.edu>
Closes #14212
This commit is contained in:
Yann Collet 2021-03-18 11:28:22 -07:00 committed by Brian Behlendorf
parent 466cf54ecf
commit 776441152e

View File

@ -554,7 +554,6 @@ typedef struct {
size_t litLength; size_t litLength;
size_t matchLength; size_t matchLength;
size_t offset; size_t offset;
const BYTE* match;
} seq_t; } seq_t;
typedef struct { typedef struct {
@ -568,9 +567,6 @@ typedef struct {
ZSTD_fseState stateOffb; ZSTD_fseState stateOffb;
ZSTD_fseState stateML; ZSTD_fseState stateML;
size_t prevOffset[ZSTD_REP_NUM]; size_t prevOffset[ZSTD_REP_NUM];
const BYTE* prefixStart;
const BYTE* dictEnd;
size_t pos;
} seqState_t; } seqState_t;
/*! ZSTD_overlapCopy8() : /*! ZSTD_overlapCopy8() :
@ -832,10 +828,9 @@ ZSTD_updateFseStateWithDInfo(ZSTD_fseState* DStatePtr, BIT_DStream_t* bitD, ZSTD
: 0) : 0)
typedef enum { ZSTD_lo_isRegularOffset, ZSTD_lo_isLongOffset=1 } ZSTD_longOffset_e; typedef enum { ZSTD_lo_isRegularOffset, ZSTD_lo_isLongOffset=1 } ZSTD_longOffset_e;
typedef enum { ZSTD_p_noPrefetch=0, ZSTD_p_prefetch=1 } ZSTD_prefetch_e;
FORCE_INLINE_TEMPLATE seq_t FORCE_INLINE_TEMPLATE seq_t
ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets, const ZSTD_prefetch_e prefetch) ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets)
{ {
seq_t seq; seq_t seq;
ZSTD_seqSymbol const llDInfo = seqState->stateLL.table[seqState->stateLL.state]; ZSTD_seqSymbol const llDInfo = seqState->stateLL.table[seqState->stateLL.state];
@ -910,14 +905,6 @@ ZSTD_decodeSequence(seqState_t* seqState, const ZSTD_longOffset_e longOffsets, c
DEBUGLOG(6, "seq: litL=%u, matchL=%u, offset=%u", DEBUGLOG(6, "seq: litL=%u, matchL=%u, offset=%u",
(U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset); (U32)seq.litLength, (U32)seq.matchLength, (U32)seq.offset);
if (prefetch == ZSTD_p_prefetch) {
size_t const pos = seqState->pos + seq.litLength;
const BYTE* const matchBase = (seq.offset > pos) ? seqState->dictEnd : seqState->prefixStart;
seq.match = matchBase + pos - seq.offset; /* note : this operation can overflow when seq.offset is really too large, which can only happen when input is corrupted.
* No consequence though : no memory access will occur, offset is only used for prefetching */
seqState->pos = pos + seq.matchLength;
}
/* ANS state update /* ANS state update
* gcc-9.0.0 does 2.5% worse with ZSTD_updateFseStateWithDInfo(). * gcc-9.0.0 does 2.5% worse with ZSTD_updateFseStateWithDInfo().
* clang-9.2.0 does 7% worse with ZSTD_updateFseState(). * clang-9.2.0 does 7% worse with ZSTD_updateFseState().
@ -1072,7 +1059,7 @@ ZSTD_decompressSequences_body( ZSTD_DCtx* dctx,
__asm__(".p2align 4"); __asm__(".p2align 4");
#endif #endif
for ( ; ; ) { for ( ; ; ) {
seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, ZSTD_p_noPrefetch); seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litEnd, prefixStart, vBase, dictEnd); size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequence, &litPtr, litEnd, prefixStart, vBase, dictEnd);
#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE)
assert(!ZSTD_isError(oneSeqSize)); assert(!ZSTD_isError(oneSeqSize));
@ -1124,6 +1111,24 @@ ZSTD_decompressSequences_default(ZSTD_DCtx* dctx,
#endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */ #endif /* ZSTD_FORCE_DECOMPRESS_SEQUENCES_LONG */
#ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT #ifndef ZSTD_FORCE_DECOMPRESS_SEQUENCES_SHORT
FORCE_INLINE_TEMPLATE size_t
ZSTD_prefetchMatch(size_t prefixPos, seq_t const sequence,
const BYTE* const prefixStart, const BYTE* const dictEnd)
{
prefixPos += sequence.litLength;
{ const BYTE* const matchBase = (sequence.offset > prefixPos) ? dictEnd : prefixStart;
const BYTE* const match = matchBase + prefixPos - sequence.offset; /* note : this operation can overflow when seq.offset is really too large, which can only happen when input is corrupted.
* No consequence though : no memory access will occur, offset is only used for prefetching */
PREFETCH_L1(match); PREFETCH_L1(match + sequence.matchLength - 1); /* note : it's safe to invoke PREFETCH() on any memory address, including invalid ones */
}
return prefixPos + sequence.matchLength;
}
/* This decoding function employs prefetching
* to reduce latency impact of cache misses.
* It's generally employed when block contains a significant portion of long-distance matches
* or when coupled with a "cold" dictionary */
FORCE_INLINE_TEMPLATE size_t FORCE_INLINE_TEMPLATE size_t
ZSTD_decompressSequencesLong_body( ZSTD_decompressSequencesLong_body(
ZSTD_DCtx* dctx, ZSTD_DCtx* dctx,
@ -1153,11 +1158,10 @@ ZSTD_decompressSequencesLong_body(
int const seqAdvance = MIN(nbSeq, ADVANCED_SEQS); int const seqAdvance = MIN(nbSeq, ADVANCED_SEQS);
seqState_t seqState; seqState_t seqState;
int seqNb; int seqNb;
size_t prefixPos = (size_t)(op-prefixStart); /* track position relative to prefixStart */
dctx->fseEntropy = 1; dctx->fseEntropy = 1;
{ int i; for (i=0; i<ZSTD_REP_NUM; i++) seqState.prevOffset[i] = dctx->entropy.rep[i]; } { int i; for (i=0; i<ZSTD_REP_NUM; i++) seqState.prevOffset[i] = dctx->entropy.rep[i]; }
seqState.prefixStart = prefixStart;
seqState.pos = (size_t)(op-prefixStart);
seqState.dictEnd = dictEnd;
assert(dst != NULL); assert(dst != NULL);
assert(iend >= ip); assert(iend >= ip);
RETURN_ERROR_IF( RETURN_ERROR_IF(
@ -1169,21 +1173,23 @@ ZSTD_decompressSequencesLong_body(
/* prepare in advance */ /* prepare in advance */
for (seqNb=0; (BIT_reloadDStream(&seqState.DStream) <= BIT_DStream_completed) && (seqNb<seqAdvance); seqNb++) { for (seqNb=0; (BIT_reloadDStream(&seqState.DStream) <= BIT_DStream_completed) && (seqNb<seqAdvance); seqNb++) {
sequences[seqNb] = ZSTD_decodeSequence(&seqState, isLongOffset, ZSTD_p_prefetch); seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
PREFETCH_L1(sequences[seqNb].match); PREFETCH_L1(sequences[seqNb].match + sequences[seqNb].matchLength - 1); /* note : it's safe to invoke PREFETCH() on any memory address, including invalid ones */ prefixPos = ZSTD_prefetchMatch(prefixPos, sequence, prefixStart, dictEnd);
sequences[seqNb] = sequence;
} }
RETURN_ERROR_IF(seqNb<seqAdvance, corruption_detected, ""); RETURN_ERROR_IF(seqNb<seqAdvance, corruption_detected, "");
/* decode and decompress */ /* decode and decompress */
for ( ; (BIT_reloadDStream(&(seqState.DStream)) <= BIT_DStream_completed) && (seqNb<nbSeq) ; seqNb++) { for ( ; (BIT_reloadDStream(&(seqState.DStream)) <= BIT_DStream_completed) && (seqNb<nbSeq) ; seqNb++) {
seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset, ZSTD_p_prefetch); seq_t const sequence = ZSTD_decodeSequence(&seqState, isLongOffset);
size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequences[(seqNb-ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litEnd, prefixStart, dictStart, dictEnd); size_t const oneSeqSize = ZSTD_execSequence(op, oend, sequences[(seqNb-ADVANCED_SEQS) & STORED_SEQS_MASK], &litPtr, litEnd, prefixStart, dictStart, dictEnd);
#if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE) #if defined(FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION) && defined(FUZZING_ASSERT_VALID_SEQUENCE)
assert(!ZSTD_isError(oneSeqSize)); assert(!ZSTD_isError(oneSeqSize));
if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb-ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart); if (frame) ZSTD_assertValidSequence(dctx, op, oend, sequences[(seqNb-ADVANCED_SEQS) & STORED_SEQS_MASK], prefixStart, dictStart);
#endif #endif
if (ZSTD_isError(oneSeqSize)) return oneSeqSize; if (ZSTD_isError(oneSeqSize)) return oneSeqSize;
PREFETCH_L1(sequence.match); PREFETCH_L1(sequence.match + sequence.matchLength - 1); /* note : it's safe to invoke PREFETCH() on any memory address, including invalid ones */
prefixPos = ZSTD_prefetchMatch(prefixPos, sequence, prefixStart, dictEnd);
sequences[seqNb & STORED_SEQS_MASK] = sequence; sequences[seqNb & STORED_SEQS_MASK] = sequence;
op += oneSeqSize; op += oneSeqSize;
} }