From 3ef354706ad13ab558e0b2bd34125d598c1d6abf Mon Sep 17 00:00:00 2001
From: hegilmez <hegilmez@qti.qualcomm.com>
Date: Tue, 9 Apr 2019 13:21:51 -0700
Subject: [PATCH 1/2] JVET-N0866: unified transform derivation for ISP and
 implicit MTS

- Implicit transform derivation
- DST-7 is applied horizontally/vertically as long as the number of (luma) samples are less than or equal to 16 in a row/column, otherwise DCT-2 is applied.
---
 source/Lib/CommonLib/TrQuant.cpp   | 10 ++++
 source/Lib/CommonLib/TrQuant.h     |  3 +-
 source/Lib/CommonLib/TypeDef.h     |  2 +
 source/Lib/CommonLib/UnitTools.cpp | 74 ++++++++++++++++++++++++++++++
 source/Lib/CommonLib/UnitTools.h   |  6 +++
 5 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/source/Lib/CommonLib/TrQuant.cpp b/source/Lib/CommonLib/TrQuant.cpp
index 8c701503e3..b542d274a7 100644
--- a/source/Lib/CommonLib/TrQuant.cpp
+++ b/source/Lib/CommonLib/TrQuant.cpp
@@ -281,6 +281,7 @@ void TrQuant::invRdpcmNxN(TransformUnit& tu, const ComponentID &compID, PelBuf &
 // Logical transform
 // ------------------------------------------------------------------------------------------------
 
+#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
 void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer )
 {
   bool mtsActivated = CU::isIntra( *tu.cu ) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter( *tu.cu );
@@ -358,6 +359,7 @@ void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTy
       trTypeHor = trTypeVer = DST7;
   }
 }
+#endif
 
 void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPelBuf &resi, CoeffBuf &dstCoeff, const int width, const int height )
 {
@@ -371,7 +373,11 @@ void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPel
   int trTypeHor = DCT2;
   int trTypeVer = DCT2;
 
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+  TU::getTrTypes ( tu, compID, trTypeHor, trTypeVer );
+#else
   getTrTypes ( tu, compID, trTypeHor, trTypeVer );
+#endif
 
   const int      skipWidth  = ( trTypeHor != DCT2 && width  == 32 ) ? 16 : width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
   const int      skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
@@ -439,7 +445,11 @@ void TrQuant::xIT( const TransformUnit &tu, const ComponentID &compID, const CCo
   int trTypeHor = DCT2;
   int trTypeVer = DCT2;
 
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+  TU::getTrTypes ( tu, compID, trTypeHor, trTypeVer );
+#else
   getTrTypes ( tu, compID, trTypeHor, trTypeVer );
+#endif
 
   const int      skipWidth  = ( trTypeHor != DCT2 && width  == 32 ) ? 16 : width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
   const int      skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
diff --git a/source/Lib/CommonLib/TrQuant.h b/source/Lib/CommonLib/TrQuant.h
index 85964c1c8e..e023c8c624 100644
--- a/source/Lib/CommonLib/TrQuant.h
+++ b/source/Lib/CommonLib/TrQuant.h
@@ -78,8 +78,9 @@ public:
                     const bool useTransformSkipFast = false
   );
 
+#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
   void getTrTypes( TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer );
-
+#endif
 
 protected:
 
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index 67b78ab04a..d26c385b92 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -50,6 +50,8 @@
 #include <assert.h>
 #include <cassert>
 
+#define JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP             1 // JVET-N0866: unified transform derivation for ISP and implicit MTS (combining JVET-N0172, JVET-N0375, JVET-N0419 and JVET-N0420)
+
 #define JVET_N0103_CGSIZE_HARMONIZATION                   1 // Chroma CG sizes aligned to luma CG sizes
 
 #define JVET_N0146_DMVR_BDOF_CONDITION                    1 // JVET-N146/N0162/N0442/N0153/N0262/N0440/N0086 applicable condition of DMVR and BDOF
diff --git a/source/Lib/CommonLib/UnitTools.cpp b/source/Lib/CommonLib/UnitTools.cpp
index 43bfb9bf62..04b41e437a 100644
--- a/source/Lib/CommonLib/UnitTools.cpp
+++ b/source/Lib/CommonLib/UnitTools.cpp
@@ -5568,6 +5568,7 @@ bool TU::getPrevTuCbfAtDepth( const TransformUnit &currentTu, const ComponentID
   return ( prevTU != nullptr ) ? TU::getCbfAtDepth( *prevTU, compID, trDepth ) : false;
 }
 
+#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
 void TU::getTransformTypeISP( const TransformUnit &tu, const ComponentID compID, int &typeH, int &typeV )
 {
   typeH = DCT2, typeV = DCT2;
@@ -5600,7 +5601,80 @@ void TU::getTransformTypeISP( const TransformUnit &tu, const ComponentID compID,
   typeV = tuArea.height <= 2 || tuArea.height >= 32 ? DCT2 : typeV;
 }
 
+#endif
+
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+void TU::getTrTypes ( const TransformUnit &tu, const ComponentID compID, int &trTypeHor, int &trTypeVer )
+{
+  const bool isExplicitMTS = (CU::isIntra(*tu.cu) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter(*tu.cu)) && isLuma(compID);
+  const bool isImplicitMTS = CU::isIntra(*tu.cu) && tu.cs->sps->getUseImplicitMTS() && isLuma(compID);
+  const bool isISP         = CU::isIntra(*tu.cu) && tu.cu->ispMode && isLuma(compID);
+  const bool isSBT         = CU::isInter(*tu.cu) && tu.cu->sbtInfo && isLuma(compID);
+
+  trTypeHor = DCT2;
+  trTypeVer = DCT2;
+
+  if (isImplicitMTS || isISP)
+  {
+    int  width       = tu.blocks[compID].width;
+    int  height      = tu.blocks[compID].height;
+    bool widthDstOk  = width  >= 4 && width  <= 16;
+    bool heightDstOk = height >= 4 && height <= 16;
+
+    if (widthDstOk)
+      trTypeHor = DST7;
+    if (heightDstOk)
+      trTypeVer = DST7;
+    return;
+  }
+
+  if( isSBT )
+  {
+    uint8_t sbtIdx = tu.cu->getSbtIdx();
+    uint8_t sbtPos = tu.cu->getSbtPos();
+
+    if( sbtIdx == SBT_VER_HALF || sbtIdx == SBT_VER_QUAD )
+    {
+      assert( tu.lwidth() <= MTS_INTER_MAX_CU_SIZE );
+      if( tu.lheight() > MTS_INTER_MAX_CU_SIZE )
+      {
+        trTypeHor = trTypeVer = DCT2;
+      }
+      else
+      {
+        if( sbtPos == SBT_POS0 )  { trTypeHor = DCT8;  trTypeVer = DST7; }
+        else                      { trTypeHor = DST7;  trTypeVer = DST7; }
+      }
+    }
+    else
+    {
+      assert( tu.lheight() <= MTS_INTER_MAX_CU_SIZE );
+      if( tu.lwidth() > MTS_INTER_MAX_CU_SIZE )
+      {
+        trTypeHor = trTypeVer = DCT2;
+      }
+      else
+      {
+        if( sbtPos == SBT_POS0 )  { trTypeHor = DST7;  trTypeVer = DCT8; }
+        else                      { trTypeHor = DST7;  trTypeVer = DST7; }
+      }
+    }
+    return;
+  }
+
+  if ( isExplicitMTS )
+  {
+    if ( tu.mtsIdx > 1 )
+    {
+      int indHor = ( tu.mtsIdx - 2 ) &  1;
+      int indVer = ( tu.mtsIdx - 2 ) >> 1;
 
+      trTypeHor = indHor ? DCT8 : DST7;
+      trTypeVer = indVer ? DCT8 : DST7;
+    }
+  }
+}
+#endif
 // other tools
 
 uint32_t getCtuAddr( const Position& pos, const PreCalcValues& pcv )
diff --git a/source/Lib/CommonLib/UnitTools.h b/source/Lib/CommonLib/UnitTools.h
index 45b749c7a2..df6ddbe661 100644
--- a/source/Lib/CommonLib/UnitTools.h
+++ b/source/Lib/CommonLib/UnitTools.h
@@ -220,7 +220,13 @@ namespace TU
 #endif
   TransformUnit* getPrevTU          ( const TransformUnit &tu, const ComponentID compID );
   bool           getPrevTuCbfAtDepth( const TransformUnit &tu, const ComponentID compID, const int trDepth );
+#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
   void           getTransformTypeISP( const TransformUnit &tu, const ComponentID compID, int &typeH, int &typeV );
+#endif
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+  void          getTrTypes         ( const TransformUnit &tu, const ComponentID compID, int &trTypeHor, int &trTypeVer);
+#endif
+
 }
 
 uint32_t getCtuAddr        (const Position& pos, const PreCalcValues &pcv);
-- 
GitLab


From 197ec30de4540d0988abea747b96cac6c302723c Mon Sep 17 00:00:00 2001
From: hegilmez <hegilmez@qti.qualcomm.com>
Date: Wed, 10 Apr 2019 18:38:18 -0700
Subject: [PATCH 2/2] JVET-N0866: unified transform derivation for ISP and
 implicit MTS

- getTrTypes function is moved to TrQuant class
---
 source/Lib/CommonLib/TrQuant.cpp   | 63 +++++++++++++++++++++-----
 source/Lib/CommonLib/TrQuant.h     |  5 ++-
 source/Lib/CommonLib/TypeDef.h     |  6 +++
 source/Lib/CommonLib/UnitTools.cpp | 72 ------------------------------
 source/Lib/CommonLib/UnitTools.h   |  3 --
 5 files changed, 60 insertions(+), 89 deletions(-)

diff --git a/source/Lib/CommonLib/TrQuant.cpp b/source/Lib/CommonLib/TrQuant.cpp
index b542d274a7..44ddf3e47e 100644
--- a/source/Lib/CommonLib/TrQuant.cpp
+++ b/source/Lib/CommonLib/TrQuant.cpp
@@ -281,22 +281,54 @@ void TrQuant::invRdpcmNxN(TransformUnit& tu, const ComponentID &compID, PelBuf &
 // Logical transform
 // ------------------------------------------------------------------------------------------------
 
-#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+void TrQuant::getTrTypes(const TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer)
+#else
 void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer )
+#endif
 {
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+  const bool isExplicitMTS = (CU::isIntra(*tu.cu) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter(*tu.cu)) && isLuma(compID);
+  const bool isImplicitMTS = CU::isIntra(*tu.cu) && tu.cs->sps->getUseImplicitMTS() && isLuma(compID);
+  const bool isISP = CU::isIntra(*tu.cu) && tu.cu->ispMode && isLuma(compID);
+  const bool isSBT = CU::isInter(*tu.cu) && tu.cu->sbtInfo && isLuma(compID);
+#else
   bool mtsActivated = CU::isIntra( *tu.cu ) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter( *tu.cu );
-
   bool mtsImplicit  = CU::isIntra( *tu.cu ) && tu.cs->sps->getUseImplicitMTS() && compID == COMPONENT_Y;
+#endif
 
   trTypeHor = DCT2;
   trTypeVer = DCT2;
 
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+  if (isImplicitMTS || isISP)
+  {
+    int  width = tu.blocks[compID].width;
+    int  height = tu.blocks[compID].height;
+    bool widthDstOk = width >= 4 && width <= 16;
+    bool heightDstOk = height >= 4 && height <= 16;
+
+    if (widthDstOk)
+      trTypeHor = DST7;
+    if (heightDstOk)
+      trTypeVer = DST7;
+    return;
+  }
+#endif
+
+#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
   if (tu.cu->ispMode && isLuma(compID))
   {
     TU::getTransformTypeISP(tu, compID, trTypeHor, trTypeVer);
     return;
-}
+  }
+#endif
+
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+  if (isSBT)
+#else
   if( tu.cu->sbtInfo && compID == COMPONENT_Y )
+#endif
   {
     uint8_t sbtIdx = tu.cu->getSbtIdx();
     uint8_t sbtPos = tu.cu->getSbtPos();
@@ -330,6 +362,19 @@ void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTy
     return;
   }
 
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+  if (isExplicitMTS)
+  {
+    if (tu.mtsIdx > 1)
+    {
+      int indHor = (tu.mtsIdx - 2) & 1;
+      int indVer = (tu.mtsIdx - 2) >> 1;
+
+      trTypeHor = indHor ? DCT8 : DST7;
+      trTypeVer = indVer ? DCT8 : DST7;
+    }
+  }
+#else
   if ( mtsActivated )
   {
     if( compID == COMPONENT_Y )
@@ -358,8 +403,10 @@ void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTy
     else if ( width == height && widthDstOk )
       trTypeHor = trTypeVer = DST7;
   }
-}
 #endif
+}
+
+
 
 void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPelBuf &resi, CoeffBuf &dstCoeff, const int width, const int height )
 {
@@ -373,11 +420,7 @@ void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPel
   int trTypeHor = DCT2;
   int trTypeVer = DCT2;
 
-#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
-  TU::getTrTypes ( tu, compID, trTypeHor, trTypeVer );
-#else
   getTrTypes ( tu, compID, trTypeHor, trTypeVer );
-#endif
 
   const int      skipWidth  = ( trTypeHor != DCT2 && width  == 32 ) ? 16 : width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
   const int      skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
@@ -445,11 +488,7 @@ void TrQuant::xIT( const TransformUnit &tu, const ComponentID &compID, const CCo
   int trTypeHor = DCT2;
   int trTypeVer = DCT2;
 
-#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
-  TU::getTrTypes ( tu, compID, trTypeHor, trTypeVer );
-#else
   getTrTypes ( tu, compID, trTypeHor, trTypeVer );
-#endif
 
   const int      skipWidth  = ( trTypeHor != DCT2 && width  == 32 ) ? 16 : width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
   const int      skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
diff --git a/source/Lib/CommonLib/TrQuant.h b/source/Lib/CommonLib/TrQuant.h
index e023c8c624..61735cec78 100644
--- a/source/Lib/CommonLib/TrQuant.h
+++ b/source/Lib/CommonLib/TrQuant.h
@@ -77,8 +77,9 @@ public:
                     const bool bEnc                 = false,
                     const bool useTransformSkipFast = false
   );
-
-#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+  void getTrTypes(const TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer);
+#else
   void getTrTypes( TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer );
 #endif
 
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index d26c385b92..ac2bba9c8c 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -50,8 +50,14 @@
 #include <assert.h>
 #include <cassert>
 
+
+
+
 #define JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP             1 // JVET-N0866: unified transform derivation for ISP and implicit MTS (combining JVET-N0172, JVET-N0375, JVET-N0419 and JVET-N0420)
 
+
+
+
 #define JVET_N0103_CGSIZE_HARMONIZATION                   1 // Chroma CG sizes aligned to luma CG sizes
 
 #define JVET_N0146_DMVR_BDOF_CONDITION                    1 // JVET-N146/N0162/N0442/N0153/N0262/N0440/N0086 applicable condition of DMVR and BDOF
diff --git a/source/Lib/CommonLib/UnitTools.cpp b/source/Lib/CommonLib/UnitTools.cpp
index 04b41e437a..0ea22153cb 100644
--- a/source/Lib/CommonLib/UnitTools.cpp
+++ b/source/Lib/CommonLib/UnitTools.cpp
@@ -5603,78 +5603,6 @@ void TU::getTransformTypeISP( const TransformUnit &tu, const ComponentID compID,
 
 #endif
 
-#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
-void TU::getTrTypes ( const TransformUnit &tu, const ComponentID compID, int &trTypeHor, int &trTypeVer )
-{
-  const bool isExplicitMTS = (CU::isIntra(*tu.cu) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter(*tu.cu)) && isLuma(compID);
-  const bool isImplicitMTS = CU::isIntra(*tu.cu) && tu.cs->sps->getUseImplicitMTS() && isLuma(compID);
-  const bool isISP         = CU::isIntra(*tu.cu) && tu.cu->ispMode && isLuma(compID);
-  const bool isSBT         = CU::isInter(*tu.cu) && tu.cu->sbtInfo && isLuma(compID);
-
-  trTypeHor = DCT2;
-  trTypeVer = DCT2;
-
-  if (isImplicitMTS || isISP)
-  {
-    int  width       = tu.blocks[compID].width;
-    int  height      = tu.blocks[compID].height;
-    bool widthDstOk  = width  >= 4 && width  <= 16;
-    bool heightDstOk = height >= 4 && height <= 16;
-
-    if (widthDstOk)
-      trTypeHor = DST7;
-    if (heightDstOk)
-      trTypeVer = DST7;
-    return;
-  }
-
-  if( isSBT )
-  {
-    uint8_t sbtIdx = tu.cu->getSbtIdx();
-    uint8_t sbtPos = tu.cu->getSbtPos();
-
-    if( sbtIdx == SBT_VER_HALF || sbtIdx == SBT_VER_QUAD )
-    {
-      assert( tu.lwidth() <= MTS_INTER_MAX_CU_SIZE );
-      if( tu.lheight() > MTS_INTER_MAX_CU_SIZE )
-      {
-        trTypeHor = trTypeVer = DCT2;
-      }
-      else
-      {
-        if( sbtPos == SBT_POS0 )  { trTypeHor = DCT8;  trTypeVer = DST7; }
-        else                      { trTypeHor = DST7;  trTypeVer = DST7; }
-      }
-    }
-    else
-    {
-      assert( tu.lheight() <= MTS_INTER_MAX_CU_SIZE );
-      if( tu.lwidth() > MTS_INTER_MAX_CU_SIZE )
-      {
-        trTypeHor = trTypeVer = DCT2;
-      }
-      else
-      {
-        if( sbtPos == SBT_POS0 )  { trTypeHor = DST7;  trTypeVer = DCT8; }
-        else                      { trTypeHor = DST7;  trTypeVer = DST7; }
-      }
-    }
-    return;
-  }
-
-  if ( isExplicitMTS )
-  {
-    if ( tu.mtsIdx > 1 )
-    {
-      int indHor = ( tu.mtsIdx - 2 ) &  1;
-      int indVer = ( tu.mtsIdx - 2 ) >> 1;
-
-      trTypeHor = indHor ? DCT8 : DST7;
-      trTypeVer = indVer ? DCT8 : DST7;
-    }
-  }
-}
-#endif
 // other tools
 
 uint32_t getCtuAddr( const Position& pos, const PreCalcValues& pcv )
diff --git a/source/Lib/CommonLib/UnitTools.h b/source/Lib/CommonLib/UnitTools.h
index df6ddbe661..89815e85b3 100644
--- a/source/Lib/CommonLib/UnitTools.h
+++ b/source/Lib/CommonLib/UnitTools.h
@@ -223,9 +223,6 @@ namespace TU
 #if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
   void           getTransformTypeISP( const TransformUnit &tu, const ComponentID compID, int &typeH, int &typeV );
 #endif
-#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
-  void          getTrTypes         ( const TransformUnit &tu, const ComponentID compID, int &trTypeHor, int &trTypeVer);
-#endif
 
 }
 
-- 
GitLab