diff --git a/source/Lib/CommonLib/TrQuant.cpp b/source/Lib/CommonLib/TrQuant.cpp index 5fb8639cec53913731a467c5d858552654fb307f..e84d183b5fbb3b98a3445416219da4607173ac90 100644 --- a/source/Lib/CommonLib/TrQuant.cpp +++ b/source/Lib/CommonLib/TrQuant.cpp @@ -281,21 +281,54 @@ void TrQuant::invRdpcmNxN(TransformUnit& tu, const ComponentID &compID, PelBuf & // Logical transform // ------------------------------------------------------------------------------------------------ +#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP +void TrQuant::getTrTypes(const TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer) +#else void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer ) +#endif { +#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP + const bool isExplicitMTS = (CU::isIntra(*tu.cu) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter(*tu.cu)) && isLuma(compID); + const bool isImplicitMTS = CU::isIntra(*tu.cu) && tu.cs->sps->getUseImplicitMTS() && isLuma(compID); + const bool isISP = CU::isIntra(*tu.cu) && tu.cu->ispMode && isLuma(compID); + const bool isSBT = CU::isInter(*tu.cu) && tu.cu->sbtInfo && isLuma(compID); +#else bool mtsActivated = CU::isIntra( *tu.cu ) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter( *tu.cu ); - bool mtsImplicit = CU::isIntra( *tu.cu ) && tu.cs->sps->getUseImplicitMTS() && compID == COMPONENT_Y; +#endif trTypeHor = DCT2; trTypeVer = DCT2; +#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP + if (isImplicitMTS || isISP) + { + int width = tu.blocks[compID].width; + int height = tu.blocks[compID].height; + bool widthDstOk = width >= 4 && width <= 16; + bool heightDstOk = height >= 4 && height <= 16; + + if (widthDstOk) + trTypeHor = DST7; + if (heightDstOk) + trTypeVer = DST7; + return; + } +#endif + +#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP if (tu.cu->ispMode && isLuma(compID)) { TU::getTransformTypeISP(tu, compID, trTypeHor, trTypeVer); return; -} + } +#endif + +#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP + if (isSBT) +#else if( tu.cu->sbtInfo && compID == COMPONENT_Y ) +#endif { uint8_t sbtIdx = tu.cu->getSbtIdx(); uint8_t sbtPos = tu.cu->getSbtPos(); @@ -329,6 +362,19 @@ void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTy return; } +#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP + if (isExplicitMTS) + { + if (tu.mtsIdx > 1) + { + int indHor = (tu.mtsIdx - 2) & 1; + int indVer = (tu.mtsIdx - 2) >> 1; + + trTypeHor = indHor ? DCT8 : DST7; + trTypeVer = indVer ? DCT8 : DST7; + } + } +#else if ( mtsActivated ) { if( compID == COMPONENT_Y ) @@ -357,8 +403,11 @@ void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTy else if ( width == height && widthDstOk ) trTypeHor = trTypeVer = DST7; } +#endif } + + void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPelBuf &resi, CoeffBuf &dstCoeff, const int width, const int height ) { const unsigned maxLog2TrDynamicRange = tu.cs->sps->getMaxLog2TrDynamicRange( toChannelType( compID ) ); diff --git a/source/Lib/CommonLib/TrQuant.h b/source/Lib/CommonLib/TrQuant.h index 85964c1c8efa973db6ab9f8a842ba46918dca4b9..61735cec78e4383cf3317f9dc152702182b64ea0 100644 --- a/source/Lib/CommonLib/TrQuant.h +++ b/source/Lib/CommonLib/TrQuant.h @@ -77,9 +77,11 @@ public: const bool bEnc = false, const bool useTransformSkipFast = false ); - +#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP + void getTrTypes(const TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer); +#else void getTrTypes( TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer ); - +#endif protected: diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h index 18ae5a7bc484844858970e08507ad05e2434274e..acd80e6141f3d32df6cabfc77540bf368ca2715b 100644 --- a/source/Lib/CommonLib/TypeDef.h +++ b/source/Lib/CommonLib/TypeDef.h @@ -50,6 +50,8 @@ #include <assert.h> #include <cassert> +#define JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP 1 // JVET-N0866: unified transform derivation for ISP and implicit MTS (combining JVET-N0172, JVET-N0375, JVET-N0419 and JVET-N0420) + #define JVET_N0340_TRI_MERGE_CAND 1 #define JVET_N0324_REGULAR_MRG_FLAG 1 diff --git a/source/Lib/CommonLib/UnitTools.cpp b/source/Lib/CommonLib/UnitTools.cpp index 377e76a8dc2f7e6e968cbd5cde509518685790f8..321523089cb56ea7c0f392f0872df4bdb3965204 100644 --- a/source/Lib/CommonLib/UnitTools.cpp +++ b/source/Lib/CommonLib/UnitTools.cpp @@ -5658,6 +5658,7 @@ bool TU::getPrevTuCbfAtDepth( const TransformUnit ¤tTu, const ComponentID return ( prevTU != nullptr ) ? TU::getCbfAtDepth( *prevTU, compID, trDepth ) : false; } +#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP void TU::getTransformTypeISP( const TransformUnit &tu, const ComponentID compID, int &typeH, int &typeV ) { typeH = DCT2, typeV = DCT2; @@ -5690,6 +5691,7 @@ void TU::getTransformTypeISP( const TransformUnit &tu, const ComponentID compID, typeV = tuArea.height <= 2 || tuArea.height >= 32 ? DCT2 : typeV; } +#endif // other tools diff --git a/source/Lib/CommonLib/UnitTools.h b/source/Lib/CommonLib/UnitTools.h index 45b749c7a23e166b8b403d9ddefa8fbf7a44db9c..89815e85b31aa4f45619e336477942f3f29efdf3 100644 --- a/source/Lib/CommonLib/UnitTools.h +++ b/source/Lib/CommonLib/UnitTools.h @@ -220,7 +220,10 @@ namespace TU #endif TransformUnit* getPrevTU ( const TransformUnit &tu, const ComponentID compID ); bool getPrevTuCbfAtDepth( const TransformUnit &tu, const ComponentID compID, const int trDepth ); +#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP void getTransformTypeISP( const TransformUnit &tu, const ComponentID compID, int &typeH, int &typeV ); +#endif + } uint32_t getCtuAddr (const Position& pos, const PreCalcValues &pcv);