/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_TSL_FRAMEWORK_FIXEDPOINT_MATMATPRODUCTAVX2_H_
#define XLA_TSL_FRAMEWORK_FIXEDPOINT_MATMATPRODUCTAVX2_H_

namespace Eigen {
namespace internal {

// AVX2 optimized implementation of Mat-Mat product.
// LHS is encoded using signed 16-bit integers.
// RHS is encoded using signed 16-bit integers.
#ifdef EIGEN_USE_OPTIMIZED_INT16_INT16_MAT_MAT_PRODUCT

// Define quantized traits
template <bool ConjLhs_, bool ConjRhs_>
class gebp_traits<QInt16, QInt16, ConjLhs_, ConjRhs_> {
 public:
  // Register blocking used by the QInt16 kernel: 16x16 output tiles
  // (mr x nr), with the depth dimension consumed kr = 4 values at a time.
  enum {
    nr = 16,
    mr = 16,
    kr = 4,
    // The custom kernel does not use per-iteration progress tracking.
    LhsProgress = -1,
    RhsProgress = -1
  };

  // int16 x int16 products are accumulated into int32 results.
  using LhsScalar = QInt16;
  using RhsScalar = QInt16;
  using ResScalar = QInt32;

  using LhsPacket = typename packet_traits<LhsScalar>::type;
  using LhsPacket4Packing = LhsPacket;
};

// Specialized blocking for quantized implementations.
// Used by TensorContractionThreadPool; every dimension is rounded up to a
// multiple of 16.
template <typename Index, int ShardingType>
class TensorContractionBlocking<QInt16, QInt16, QInt16, Index, ShardingType> {
 public:
  // Computes block sizes for an (m x k) * (k x n) contraction.
  //
  // Every dimension is first rounded up to a multiple of 16 (the kernel's
  // register block). The sharded dimension (n for ShardByCol, m otherwise)
  // is then divided across `num_threads` and rounded up again to a multiple
  // of 16. A non-empty problem never gets a block size of 0.
  TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1)
      : kc_(((k + 15) / 16) * 16),
        mc_(((m + 15) / 16) * 16),
        nc_(((n + 15) / 16) * 16) {
    eigen_assert(mc_ % 16 == 0);
    eigen_assert(kc_ % 16 == 0);
    if (!k || !m || !n) {
      return;
    }
    eigen_assert(num_threads > 0);

    if (ShardingType == ShardByCol) {
      eigen_assert(nc_ % 16 == 0);
      nc_ = (((nc_ / num_threads) + 15) / 16) * 16;
      // With more threads than columns, nc_ / num_threads truncates to 0 and
      // the rounded block would be empty; keep at least one 16-wide block.
      if (nc_ == 0) {
        nc_ = 16;
      }
    } else {
      eigen_assert(nc_ % 16 == 0);
      mc_ = (((mc_ / num_threads) + 15) / 16) * 16;
      // Same guard as above for the row dimension.
      if (mc_ == 0) {
        mc_ = 16;
      }
    }
  }

  EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
  EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
  EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }

 private:
  Index kc_;
  Index mc_;
  Index nc_;
};

// Specialized blocking for quantized implementations.
// Used by TensorContraction and GeneralMatrixMatrix, inputs are padded to
// multiples of 16.
template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<ColMajor, QInt16, QInt16, MaxRows, MaxCols, MaxDepth,
                          KcFactor, false>
    : public level3_blocking<QInt16, QInt16> {
 public:
  // Uses a single block covering the whole problem. Each dimension is
  // rounded up to the next multiple of 16. The thread count and the L3
  // blocking hint are ignored.
  gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth,
                      DenseIndex /*num_threads*/, bool /*l3_blocking*/) {
    this->m_mc = RoundUpTo16(rows);
    this->m_nc = RoundUpTo16(cols);
    this->m_kc = RoundUpTo16(depth);
    m_sizeA = this->m_mc * this->m_kc;
    m_sizeB = this->m_kc * this->m_nc;
  }

  // Lazily allocates the packed LHS buffer (mc * kc elements).
  void allocateA() {
    if (!this->m_blockA) {
      this->m_blockA = aligned_new<QInt16>(m_sizeA);
    }
  }

  // Lazily allocates the packed RHS buffer (kc * nc elements).
  void allocateB() {
    if (!this->m_blockB) {
      this->m_blockB = aligned_new<QInt16>(m_sizeB);
    }
  }

  void allocateAll() {
    allocateA();
    allocateB();
  }

  ~gemm_blocking_space() {
    aligned_delete(this->m_blockA, m_sizeA);
    aligned_delete(this->m_blockB, m_sizeB);
  }

 private:
  static DenseIndex RoundUpTo16(DenseIndex value) {
    return ((value + 15) / 16) * 16;
  }

  DenseIndex m_sizeA;  // Element count of the packed LHS buffer.
  DenseIndex m_sizeB;  // Element count of the packed RHS buffer.
};

// Below are the fully optimized versions that are correct only for sizes that
// are multiple of 16.  It is about a 10% performance benefit to keep these
// implementations separate.

// Arrange a block of the left input matrix in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ...
// A1 B1 C1 D1 E1 F1 G1 H1 ...
// A2 B2 C2 D2 E2 F2 G2 H2 ...
// A3 B3 C3 D3 E3 F3 G3 H3 ...
// A4 B4 C4 D4 E4 F4 G4 H4 ...
// A5 B5 C5 D5 E5 F5 G5 H5 ...
// A6 B6 C6 D6 E6 F6 G6 H6 ...
// A7 B7 C7 D7 E7 F7 G7 H7 ...
// A8 ...
// ...
//
// Packing with m = 8 yields row major output (A0 beside B0 in memory):
// A0 B0
// A1 B1
// A2 B2
// A3 B3
// A4 B4
// A5 B5
// A6 B6
// A7 B7
// ...
//
// The purpose is to collect m rows of size k.  Two elements of the same
// row are arranged contiguously because madd performs an adjacent addition
// in the kernel.

template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<QInt16, Index, DataMapper, Pack1, Pack2, QInt16, ColMajor,
                     Conjugate, PanelMode> {
  // Packs a `rows` x `depth` column-major LHS block into blockA.
  // `stride` and `offset` exist for interface compatibility only; the
  // definition asserts that both are 0.
  EIGEN_DONT_INLINE
  void operator()(QInt16* blockA, const DataMapper& lhs, Index depth,
                  Index rows, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void
gemm_pack_lhs<QInt16, Index, DataMapper, Pack1, Pack2, QInt16, ColMajor,
              Conjugate, PanelMode>::operator()(QInt16* blockA,
                                                const DataMapper& lhs,
                                                Index depth, Index rows,
                                                Index stride, Index offset) {
  // Strided / panel-mode packing is not supported.
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  typedef typename packet_traits<QInt16>::type Packet;

  // Only multiples of 16 are handled; no generic fallback is wired up yet.
  const bool unsupported_shape = (rows % 16 != 0) || (depth % 16 != 0);
  if (unsupported_shape) {
    eigen_assert(false &&
                 "only depths and rows that are a multiple of 16 are currently "
                 "supported");
  }

  __m256i* out = reinterpret_cast<__m256i*>(blockA);

  // Each inner iteration consumes a 16-row x 4-depth tile of the LHS and
  // emits four 256-bit vectors, in this order:
  //   (depth k, k+1)   rows 0-7,  (depth k+2, k+3) rows 0-7,
  //   (depth k, k+1)   rows 8-15, (depth k+2, k+3) rows 8-15.
  for (Index row = 0; row < rows; row += 16) {
    for (Index col = 0; col < depth; col += 4) {
      const __m256i v0 = lhs.template loadPacket<Packet>(row, col);
      const __m256i v1 = lhs.template loadPacket<Packet>(row, col + 1);
      const __m256i v2 = lhs.template loadPacket<Packet>(row, col + 2);
      const __m256i v3 = lhs.template loadPacket<Packet>(row, col + 3);

      // Interleave neighbouring depth columns element by element. Each
      // 32-bit lane then holds the pair that madd multiplies and sums.
      const __m256i v01_lo = _mm256_unpacklo_epi16(v0, v1);
      const __m256i v01_hi = _mm256_unpackhi_epi16(v0, v1);
      const __m256i v23_lo = _mm256_unpacklo_epi16(v2, v3);
      const __m256i v23_hi = _mm256_unpackhi_epi16(v2, v3);

      // unpack operates within 128-bit lanes; recombine the lanes so rows
      // 0-7 and rows 8-15 each land in a single vector.
      _mm256_store_si256(out + 0,
                         _mm256_permute2x128_si256(v01_lo, v01_hi, 0x20));
      _mm256_store_si256(out + 1,
                         _mm256_permute2x128_si256(v23_lo, v23_hi, 0x20));
      _mm256_store_si256(out + 2,
                         _mm256_permute2x128_si256(v01_lo, v01_hi, 0x31));
      _mm256_store_si256(out + 3,
                         _mm256_permute2x128_si256(v23_lo, v23_hi, 0x31));
      out += 4;
    }
  }
}

// Arrange a block of the right input matrix in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ...
// A1 B1 C1 D1 E1 F1 G1 H1 ...
// A2 B2 C2 D2 E2 F2 G2 H2 ...
// A3 B3 C3 D3 E3 F3 G3 H3 ...
// A4 B4 C4 D4 E4 F4 G4 H4 ...
// A5 B5 C5 D5 E5 F5 G5 H5 ...
// A6 B6 C6 D6 E6 F6 G6 H6 ...
// A7 B7 C7 D7 E7 F7 G7 H7 ...
// A8 ...
// ...
// Packing yields row major output (A0 beside A1 in memory):
// A0 A1 A2 A3 A4 A5 A6 A7
// B0 B1 B2 B3 B4 B5 B6 B7
// ...
//
// At least two elements of the same col are arranged contiguously because
// madd performs an adjacent addition in the kernel.  We save work by keeping
// 4 adjacent elements because kr = 4.
// The purpose is to collect n cols of size k.
template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
struct gemm_pack_rhs<QInt16, Index, DataMapper, nr, ColMajor, Conjugate,
                     PanelMode> {
  // Packs a `depth` x `cols` column-major RHS block into blockB.
  // `stride` and `offset` exist for interface compatibility only; the
  // definition asserts that both are 0.
  EIGEN_DONT_INLINE
  void operator()(QInt16* blockB, const DataMapper& rhs, Index depth,
                  Index cols, Index stride = 0, Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
EIGEN_DONT_INLINE void
gemm_pack_rhs<QInt16, Index, DataMapper, nr, ColMajor, Conjugate,
              PanelMode>::operator()(QInt16* blockB, const DataMapper& rhs,
                                     Index depth, Index cols, Index stride,
                                     Index offset) {
  // Strided / panel-mode packing is not supported.
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  typedef typename packet_traits<QInt16>::type Packet;

  // Only multiples of 16 are handled. A gemm_pack_rhs_any fallback is not
  // wired up for QInt16.
  const bool unsupported_shape = (cols % 16 != 0) || (depth % 16 != 0);
  if (unsupported_shape) {
    eigen_assert(false &&
                 "only depths and cols that are a multiple of 16 are currently "
                 "supported");
  }

  __m256i* out = reinterpret_cast<__m256i*>(blockB);

  // Every 16 (depth) x 16 (cols) tile becomes 16 consecutive vectors.
  // Vector 4 * q + g holds depth k + 4q .. k + 4q + 3 for columns
  // n + 4g .. n + 4g + 3, i.e. four int16 values per column.
  for (Index n = 0; n < cols; n += 16) {
    for (Index k = 0; k < depth; k += 16) {
      for (Index g = 0; g < 4; ++g) {
        const Index c = n + 4 * g;
        const __m256i v0 = rhs.template loadPacket<Packet>(k, c);
        const __m256i v1 = rhs.template loadPacket<Packet>(k, c + 1);
        const __m256i v2 = rhs.template loadPacket<Packet>(k, c + 2);
        const __m256i v3 = rhs.template loadPacket<Packet>(k, c + 3);

        // Pair up 64-bit chunks (4 depth values) of neighbouring columns.
        const __m256i v01_lo = _mm256_unpacklo_epi64(v0, v1);
        const __m256i v23_lo = _mm256_unpacklo_epi64(v2, v3);
        const __m256i v01_hi = _mm256_unpackhi_epi64(v0, v1);
        const __m256i v23_hi = _mm256_unpackhi_epi64(v2, v3);

        // Gather all four columns of one depth quartet into each vector.
        _mm256_store_si256(out + g,
                           _mm256_permute2x128_si256(v01_lo, v23_lo, 0x20));
        _mm256_store_si256(out + g + 4,
                           _mm256_permute2x128_si256(v01_hi, v23_hi, 0x20));
        _mm256_store_si256(out + g + 8,
                           _mm256_permute2x128_si256(v01_lo, v23_lo, 0x31));
        _mm256_store_si256(out + g + 12,
                           _mm256_permute2x128_si256(v01_hi, v23_hi, 0x31));
      }
      out += 16;
    }
  }
}

// Perform the actual multiplication on packed inputs
template <typename Index, typename DataMapper, int mr, int nr,
          bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<QInt16, QInt16, Index, DataMapper, mr, nr, ConjugateLhs,
                   ConjugateRhs> {
  using LinearMapper = typename DataMapper::LinearMapper;

  // Accumulates the product of packed blockA and blockB into `res`.
  // The definition asserts that alpha is 1 and that strides and offsets
  // keep their default values.
  EIGEN_DONT_INLINE void operator()(const DataMapper& res,
                                    const QInt16* blockA,
                                    const QInt16* blockB, Index rows,
                                    Index depth, Index cols, QInt32 alpha,
                                    Index strideA = -1, Index strideB = -1,
                                    Index offsetA = 0, Index offsetB = 0);
};

// Accumulates the product of the packed LHS (blockA, laid out by the QInt16
// gemm_pack_lhs above) and the packed RHS (blockB, laid out by the QInt16
// gemm_pack_rhs above) into res, i.e. res += A * B.
// alpha must be 1 and the stride/offset arguments must keep their defaults.
// Conjugation is rejected at compile time.
//
// Output is produced in 16x16 tiles. For each 16-column panel n and each
// 16-row panel m, _mm256_madd_epi16 multiplies int16 pairs and sums them
// into int32, four depth values per step, into a temporary tile (blockO).
// That tile is then added onto res.
template <typename Index, typename DataMapper, int mr, int nr,
          bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE void
gebp_kernel<QInt16, QInt16, Index, DataMapper, mr, nr, ConjugateLhs,
            ConjugateRhs>::operator()(const DataMapper& res,
                                      const QInt16* blockA,
                                      const QInt16* blockB, Index rows,
                                      Index depth, Index cols, QInt32 alpha,
                                      Index strideA, Index strideB,
                                      Index offsetA, Index offsetB) {
  EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  eigen_assert(alpha.value == 1);
  eigen_assert(strideA == -1);
  eigen_assert(strideB == -1);
  eigen_assert(offsetA == 0);
  eigen_assert(offsetB == 0);
  eigen_assert(rows > 0);
  eigen_assert(cols > 0);
  eigen_assert(depth > 0);
  eigen_assert(blockA);
  eigen_assert(blockB);

  // Use alternate function for weird sizes
  // NOTE(review): no fallback is wired up. With eigen_assert disabled,
  // execution continues into the multiple-of-16 path below. Confirm that
  // callers always pad rows/cols/depth to multiples of 16.
  if (rows % 16 != 0 || cols % 16 != 0 || depth % 16 != 0) {
    eigen_assert(
        false &&
        "only depths, cols and rows that are a multiple of 16 are currently "
        "supported");
    // gebp_kernel_any<QInt16, QInt16, Index, DataMapper, mr, nr, ConjugateLhs,
    // ConjugateRhs> gebp;
    // return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA,
    // strideB, offsetA, offsetB);
  }

  // Create result block
  // 16x16 int32 accumulator tile, stored as 32 vectors of 8 lanes. Vector
  // 2*j holds rows 0-7 of tile column j and vector 2*j+1 holds rows 8-15
  // (this matches the order used by the transfer loop below).
  QInt32* blockO = aligned_new<QInt32>(16 * 16);
  memset(blockO, 0, 16 * 16 * sizeof(QInt32));

  // Get vectorized pointers
  __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
  const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
  const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);

  // Loop over blocks of 16 columns
  for (Index n = 0; n < cols; n += 16) {
    // Reset index into blockA
    // (indexL then runs continuously across the row panels. The packed LHS
    // holds `depth` vectors per 16-row panel: 4 vectors per 4 depth values.)
    Index indexL = 0;
    // Loop over blocks of 16 rows
    for (Index m = 0; m < rows; m += 16) {
      // Reset index into blockB
      // (The packed RHS holds `depth` vectors per 16-column panel: 16 vectors
      // per 16 depth values.)
      Index indexR = n / 16 * depth;
      // Loop over blocks of 4 on depth
      for (Index k = 0; k < depth; k += 4) {
        // Load inputs
        // L_AD0 / L_AD8: rows 0-7 with depth pairs (k, k+1) / (k+2, k+3).
        // L_EH0 / L_EH8: rows 8-15 with the same depth pairs.
        __m256i L_AD0 = blockA_256[indexL++];
        __m256i L_AD8 = blockA_256[indexL++];
        __m256i L_EH0 = blockA_256[indexL++];
        __m256i L_EH8 = blockA_256[indexL++];

        // R_AH0 / R_AH4 / R_AH8 / R_AH12: tile columns 0-3 / 4-7 / 8-11 /
        // 12-15, each holding depth k..k+3 for its four columns (low 128 bits:
        // first two columns, high 128 bits: last two).
        __m256i R_AH0 = blockB_256[indexR++];
        __m256i R_AH4 = blockB_256[indexR++];
        __m256i R_AH8 = blockB_256[indexR++];
        __m256i R_AH12 = blockB_256[indexR++];

        // Declare variables used in COMPUTE_STEP
        __m256i P_32_A, P_32_B, P_32;

        // COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET): R_INPUT_A / R_INPUT_B
        // broadcast one RHS column's (k, k+1) / (k+2, k+3) pair. The madd
        // results over the four depth values are added into the accumulator
        // vectors of tile column OFFSET (rows 0-7, then rows 8-15).
#define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET)                         \
  P_32_A = _mm256_madd_epi16(R_INPUT_A, L_AD0);                            \
  P_32_B = _mm256_madd_epi16(R_INPUT_B, L_AD8);                            \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                 \
  _mm256_store_si256(                                                      \
      blockO_256 + 2 * OFFSET,                                             \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 2 * OFFSET), P_32)); \
                                                                           \
  P_32_A = _mm256_madd_epi16(R_INPUT_A, L_EH0);                            \
  P_32_B = _mm256_madd_epi16(R_INPUT_B, L_EH8);                            \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                 \
  _mm256_store_si256(                                                      \
      blockO_256 + 2 * OFFSET + 1,                                         \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 2 * OFFSET + 1), P_32));

        // Permute and shuffle to copy a single value across the entire vector
        // Then compute the multiplication
        // Replicate lower 128-bits of R_AH0 across both lanes
        __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00);
        // Copy first two elements of R_AH0 across entire vector
        __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        // Copy second two elements of R_AH0 across entire vector
        __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);

        COMPUTE_STEP(R_AD0, R_EH0, 0);
        __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 1);

        // Replicate upper 128-bits of R_AH0 across both lanes
        R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11);
        __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 2);
        __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 3);

        // Same broadcast pattern for tile columns 4-7, 8-11 and 12-15.
        R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 4);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 5);
        R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 6);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 7);

        R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 8);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 9);
        R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 10);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 11);

        R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 12);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 13);
        R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 14);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 15);

#undef COMPUTE_STEP
      }

      // Transfer the results to the result matrix
      // (adds the tile onto the values already stored in res).
      Index i = 0;
      for (Index j = n; j < n + 16; j++) {
        LinearMapper r0 = res.getLinearMapper(m, j);
        LinearMapper r1 = res.getLinearMapper(m + 8, j);
        typedef typename packet_traits<QInt32>::type Packet;
        r0.template storePacket<Packet>(
            0, _mm256_add_epi32(blockO_256[i++],
                                r0.template loadPacket<Packet>(0)));
        r1.template storePacket<Packet>(
            0, _mm256_add_epi32(blockO_256[i++],
                                r1.template loadPacket<Packet>(0)));
      }

      // Zero the result block so it can be reused
      memset(blockO, 0, 16 * 16 * sizeof(QInt32));
    }
  }
  aligned_delete(blockO, 16 * 16);
}

#endif

// AVX2 optimized implementation of Mat-Mat product.
// LHS is encoded using signed 8-bit integers.
// RHS is encoded using unsigned 8-bit integers.
#ifdef EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT

// Define quantized traits
template <bool ConjLhs_, bool ConjRhs_>
class gebp_traits<QInt8, QUInt8, ConjLhs_, ConjRhs_> {
 public:
  // Register blocking for the int8 x uint8 kernel: 32x32 output tiles
  // (mr x nr), with the depth dimension consumed kr = 8 values at a time.
  enum {
    nr = 32,
    mr = 32,
    kr = 8,
    // The custom kernel does not use per-iteration progress tracking.
    LhsProgress = -1,
    RhsProgress = -1
  };

  // Signed 8-bit LHS times unsigned 8-bit RHS, accumulated into int32.
  using LhsScalar = QInt8;
  using RhsScalar = QUInt8;
  using ResScalar = QInt32;

  using LhsPacket = typename packet_traits<LhsScalar>::type;
  using LhsPacket4Packing = LhsPacket;
};

// Specialized blocking for quantized implementations.
// Used by TensorContractionThreadPool, inputs must have dimensions that are
// multiples of 32.
template <typename ResScalar, typename Index, typename LeftTensor,
          typename left_nocontract_t, typename left_contract_t,
          bool left_inner_dim_contiguous, bool left_inner_dim_reordered,
          int LeftAlignment, typename RightTensor, typename right_nocontract_t,
          typename right_contract_t, bool right_inner_dim_contiguous,
          bool right_inner_dim_reordered, int RightAlignment, int ShardingType>
class TensorContractionBlocking<
    ResScalar,
    TensorContractionInputMapper<
        QInt8, Index, Lhs, LeftTensor, left_nocontract_t, left_contract_t, 32,
        left_inner_dim_contiguous, left_inner_dim_reordered, LeftAlignment>,
    TensorContractionInputMapper<QUInt8, Index, Rhs, RightTensor,
                                 right_nocontract_t, right_contract_t, 32,
                                 right_inner_dim_contiguous,
                                 right_inner_dim_reordered, RightAlignment>,
    Index, ShardingType> {
 public:
  typedef QInt8 LhsScalar;
  typedef QUInt8 RhsScalar;

  // Computes block sizes for an (m x k) * (k x n) contraction.
  //
  // m and k must already be multiples of 32. n must be too, except for the
  // n == 1 matrix-vector case when not sharding by column. The sharded
  // dimension is divided across `num_threads` and rounded up to a multiple
  // of 32. A non-empty problem never gets a block size of 0.
  TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1)
      : kc_(k), mc_(m), nc_(n) {
    eigen_assert(m % 32 == 0);
    eigen_assert(k % 32 == 0);
    if (!k || !m || !n) {
      return;
    }
    eigen_assert(num_threads > 0);

    if (ShardingType == ShardByCol) {
      eigen_assert(n % 32 == 0);
      nc_ = (((n / num_threads) + 31) / 32) * 32;
      // With more threads than columns, n / num_threads truncates to 0 and
      // the rounded block would be empty; keep at least one 32-wide block.
      if (nc_ == 0) {
        nc_ = 32;
      }
    } else {
      eigen_assert(n % 32 == 0 || n == 1);
      // Special case to avoid breaking the unimplemented matrix-vector case
      if (n == 1) {
        nc_ = 32;
      }
      mc_ = (((m / num_threads) + 31) / 32) * 32;
      // Same guard as above for the row dimension.
      if (mc_ == 0) {
        mc_ = 32;
      }
    }
  }

  EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
  EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
  EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }

 private:
  Index kc_;
  Index mc_;
  Index nc_;
};

// Specialized blocking for quantized implementations.
// Used by TensorContraction and GeneralMatrixMatrix, inputs are padded to
// multiples of 32.
template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<ColMajor, QInt8, QInt8, MaxRows, MaxCols, MaxDepth,
                          KcFactor, false>
    : public level3_blocking<QInt8, QInt8> {
 public:
  // Uses a single block covering the whole problem. Each dimension is
  // rounded up to the next multiple of 32. The thread count and the L3
  // blocking hint are ignored.
  gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth,
                      DenseIndex /*num_threads*/, bool /*l3_blocking*/) {
    this->m_mc = RoundUpTo32(rows);
    this->m_nc = RoundUpTo32(cols);
    this->m_kc = RoundUpTo32(depth);
    m_sizeA = this->m_mc * this->m_kc;
    m_sizeB = this->m_kc * this->m_nc;
  }

  // Lazily allocates the packed LHS buffer (mc * kc elements).
  void allocateA() {
    if (!this->m_blockA) {
      this->m_blockA = aligned_new<QInt8>(m_sizeA);
    }
  }

  // Lazily allocates the packed RHS buffer (kc * nc elements).
  void allocateB() {
    if (!this->m_blockB) {
      this->m_blockB = aligned_new<QInt8>(m_sizeB);
    }
  }

  void allocateAll() {
    allocateA();
    allocateB();
  }

  ~gemm_blocking_space() {
    aligned_delete(this->m_blockA, m_sizeA);
    aligned_delete(this->m_blockB, m_sizeB);
  }

 private:
  static DenseIndex RoundUpTo32(DenseIndex value) {
    return ((value + 31) / 32) * 32;
  }

  DenseIndex m_sizeA;  // Element count of the packed LHS buffer.
  DenseIndex m_sizeB;  // Element count of the packed RHS buffer.
};

template <int MaxRows, int MaxCols, int MaxDepth, int KcFactor>
class gemm_blocking_space<ColMajor, QInt8, QUInt8, MaxRows, MaxCols, MaxDepth,
                          KcFactor, false>
    : public level3_blocking<QInt8, QUInt8> {
 public:
  // Uses a single block covering the whole problem. Each dimension is
  // rounded up to the next multiple of 32. The thread count and the L3
  // blocking hint are ignored.
  gemm_blocking_space(DenseIndex rows, DenseIndex cols, DenseIndex depth,
                      DenseIndex /*num_threads*/, bool /*l3_blocking*/) {
    this->m_mc = RoundUpTo32(rows);
    this->m_nc = RoundUpTo32(cols);
    this->m_kc = RoundUpTo32(depth);
    m_sizeA = this->m_mc * this->m_kc;
    m_sizeB = this->m_kc * this->m_nc;
  }

  // Lazily allocates the packed signed LHS buffer (mc * kc elements).
  void allocateA() {
    if (!this->m_blockA) {
      this->m_blockA = aligned_new<QInt8>(m_sizeA);
    }
  }

  // Lazily allocates the packed unsigned RHS buffer (kc * nc elements).
  void allocateB() {
    if (!this->m_blockB) {
      this->m_blockB = aligned_new<QUInt8>(m_sizeB);
    }
  }

  void allocateAll() {
    allocateA();
    allocateB();
  }

  ~gemm_blocking_space() {
    aligned_delete(this->m_blockA, m_sizeA);
    aligned_delete(this->m_blockB, m_sizeB);
  }

 private:
  static DenseIndex RoundUpTo32(DenseIndex value) {
    return ((value + 31) / 32) * 32;
  }

  DenseIndex m_sizeA;  // Element count of the packed LHS buffer.
  DenseIndex m_sizeB;  // Element count of the packed RHS buffer.
};

// Alternate templates for any input sizes
// Generic LHS packer; only the QInt8 column-major specialization is
// declared here.
template <typename Scalar, typename Index, typename DataMapper, int Pack1,
          int Pack2, int StorageOrder, bool Conjugate = false,
          bool PanelMode = false>
struct gemm_pack_lhs_any;

template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
struct gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor,
                         Conjugate, PanelMode> {
  // Packs a `rows` x `depth` LHS block of arbitrary size into blockA.
  EIGEN_DONT_INLINE
  void operator()(QInt8* blockA, const DataMapper& lhs, Index depth,
                  Index rows, Index stride = 0, Index offset = 0);
};

// Generic RHS packer; only the QUInt8 column-major specialization is
// declared here.
template <typename Scalar, typename Index, typename DataMapper, int nr,
          int StorageOrder, bool Conjugate = false, bool PanelMode = false>
struct gemm_pack_rhs_any;

template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
struct gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
                         PanelMode> {
  // Packs a `depth` x `cols` RHS block of arbitrary size into blockB.
  EIGEN_DONT_INLINE
  void operator()(QUInt8* blockB, const DataMapper& rhs, Index depth,
                  Index cols, Index stride = 0, Index offset = 0);
};

// Generic GEBP kernel; only the QInt8 x QUInt8 specialization is declared
// here.
template <typename LhsScalar, typename RhsScalar, typename Index,
          typename DataMapper, int mr, int nr, bool ConjugateLhs = false,
          bool ConjugateRhs = false>
struct gebp_kernel_any;

template <typename Index, typename DataMapper, int mr, int nr,
          bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
                       ConjugateRhs> {
  using LinearMapper = typename DataMapper::LinearMapper;

  // Multiplies packed int8 LHS and uint8 RHS blocks of arbitrary size and
  // writes the int32 result through `res`.
  EIGEN_DONT_INLINE void operator()(const DataMapper& res,
                                    const QInt8* blockA,
                                    const QUInt8* blockB, Index rows,
                                    Index depth, Index cols, QInt32 alpha,
                                    Index strideA = -1, Index strideB = -1,
                                    Index offsetA = 0, Index offsetB = 0);
};

// Alternate implementations for any input sizes

// Packs a column-major QInt8 LHS panel of arbitrary size into blockA.
//
// Rows are processed in tiles of 32 and depth in slices of 8. Every 32x8
// tile emits eight 256-bit vectors: first depth k..k+3, then k+4..k+7, each
// as four vectors covering rows 0-7, 8-15, 16-23 and 24-31 with the four
// depth bytes of each row adjacent (see store_quad).
//
// Rows past the last multiple of 32 and depth past the last multiple of 8
// are zero-filled. After each full 32-row tile the output pointer is
// advanced (without writing) so the tile spans a depth rounded up to a
// multiple of 32; the trailing partial-row tile gets no such advance.
//
// Parameters:
//   blockA - destination; written with aligned 256-bit stores, so it must be
//            32-byte aligned.
//   lhs    - source mapper, read with loadPacket / element access.
//   depth, rows - panel dimensions.
//   stride, offset - must both be 0 (asserted).
template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void
gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor, Conjugate,
                  PanelMode>::operator()(QInt8* blockA, const DataMapper& lhs,
                                         Index depth, Index rows, Index stride,
                                         Index offset) {
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  typedef typename packet_traits<QInt8>::type Packet;

  // Get vector pointer
  __m256i* blockA_256 = reinterpret_cast<__m256i*>(blockA);

  // Largest tile multiples that fit, and what is left over.
  const Index rows_32 = (rows / 32) * 32;
  const Index depth_8 = (depth / 8) * 8;
  const Index tail_rows = rows - rows_32;    // in [0, 32)
  const Index tail_depth = depth - depth_8;  // in [0, 8)

  // Number of vectors to skip after each full 32-row tile so its depth
  // extent is rounded up to a multiple of 32 (one vector per depth unit).
  // Kept in Index (the original narrowed to int here).
  Index padding = 0;
  if (depth % 32 != 0) {
    const Index extra_depth = depth - (depth / 32) * 32;
    const Index extra_depth_8 = ((extra_depth + 7) / 8) * 8;
    padding = 32 - extra_depth_8;
  }

  // Interleaves four depth slices a..d (each holding 32 rows) so the four
  // bytes belonging to one row become adjacent, then stores four vectors
  // covering rows 0-7, 8-15, 16-23 and 24-31, in that order.
  auto store_quad = [&blockA_256](const __m256i& a, const __m256i& b,
                                  const __m256i& c, const __m256i& d) {
    // 8-bit interleave: per 128-bit lane, *_lo covers rows 0-7 / 16-23 and
    // *_hi covers rows 8-15 / 24-31.
    const __m256i ab_lo = _mm256_unpacklo_epi8(a, b);
    const __m256i ab_hi = _mm256_unpackhi_epi8(a, b);
    const __m256i cd_lo = _mm256_unpacklo_epi8(c, d);
    const __m256i cd_hi = _mm256_unpackhi_epi8(c, d);
    // 16-bit interleave: each 32-bit element is now one row's a,b,c,d bytes.
    const __m256i r0_r16 = _mm256_unpacklo_epi16(ab_lo, cd_lo);
    const __m256i r4_r20 = _mm256_unpackhi_epi16(ab_lo, cd_lo);
    const __m256i r8_r24 = _mm256_unpacklo_epi16(ab_hi, cd_hi);
    const __m256i r12_r28 = _mm256_unpackhi_epi16(ab_hi, cd_hi);
    // Cross the 128-bit lanes so each stored vector is 8 consecutive rows.
    _mm256_store_si256(blockA_256++,
                       _mm256_permute2x128_si256(r0_r16, r4_r20, 0x20));
    _mm256_store_si256(blockA_256++,
                       _mm256_permute2x128_si256(r8_r24, r12_r28, 0x20));
    _mm256_store_si256(blockA_256++,
                       _mm256_permute2x128_si256(r0_r16, r4_r20, 0x31));
    _mm256_store_si256(blockA_256++,
                       _mm256_permute2x128_si256(r8_r24, r12_r28, 0x31));
  };

  // Pack rows in sets of 32
  for (Index m = 0; m < rows_32; m += 32) {
    // Pack depth in sets of 8
    for (Index k = 0; k < depth_8; k += 8) {
      store_quad(lhs.template loadPacket<Packet>(m, k),
                 lhs.template loadPacket<Packet>(m, k + 1),
                 lhs.template loadPacket<Packet>(m, k + 2),
                 lhs.template loadPacket<Packet>(m, k + 3));
      store_quad(lhs.template loadPacket<Packet>(m, k + 4),
                 lhs.template loadPacket<Packet>(m, k + 5),
                 lhs.template loadPacket<Packet>(m, k + 6),
                 lhs.template loadPacket<Packet>(m, k + 7));
    }

    // Finish the k dimension: load the remaining slices, zero the rest.
    if (tail_depth > 0) {
      __m256i v[8];
      for (Index j = 0; j < 8; ++j) {
        if (j < tail_depth) {
          v[j] = lhs.template loadPacket<Packet>(m, depth_8 + j);
        } else {
          v[j] = _mm256_setzero_si256();
        }
      }
      store_quad(v[0], v[1], v[2], v[3]);
      store_quad(v[4], v[5], v[6], v[7]);
    }
    blockA_256 += padding;
  }

  // Finish the m dimension, padding with zeros
  if (tail_rows > 0) {
    // Fills v[0..7] with rows [rows_32, rows) of depth slices [k, k + count),
    // element by element so no full 32-row packet is loaded; every other
    // byte is zero.
    auto gather_tail = [&](Index k, Index count, __m256i* v) {
      for (Index j = 0; j < 8; ++j) {
        v[j] = _mm256_setzero_si256();
      }
      for (Index j = 0; j < count; ++j) {
        QInt8* dst = reinterpret_cast<QInt8*>(&v[j]);
        for (Index i = 0; i < tail_rows; ++i) {
          dst[i] = lhs(rows_32 + i, k + j);
        }
      }
    };

    __m256i v[8];
    for (Index k = 0; k < depth_8; k += 8) {
      gather_tail(k, 8, v);
      store_quad(v[0], v[1], v[2], v[3]);
      store_quad(v[4], v[5], v[6], v[7]);
    }
    if (tail_depth > 0) {
      gather_tail(depth_8, tail_depth, v);
      store_quad(v[0], v[1], v[2], v[3]);
      store_quad(v[4], v[5], v[6], v[7]);
    }
  }
}

// Packs a column-major QUInt8 RHS panel of arbitrary size into blockB.
//
// Columns are processed in tiles of 32 and depth in slices of 32. Each
// 32(depth) x 32(cols) tile occupies 32 vectors: vector 8*q + g holds depth
// bytes [8q, 8q + 8) of columns 4g..4g+3 (q in [0,4), g in [0,8)), see
// pack_step. Depth past the last multiple of 32 and columns past `cols`
// inside a group of four are zero-filled. For a partial column tile, groups
// beyond the last column are not written, but the pointer still advances
// over them so every tile spans 32 vectors (except the final depth tail of
// the partial column tile, which is last).
//
// Parameters:
//   blockB - destination; written with aligned 256-bit stores, so it must be
//            32-byte aligned.
//   rhs    - source mapper, read with loadPacket / element access.
//   depth, cols - panel dimensions.
//   stride, offset - must both be 0 (asserted).
template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
EIGEN_DONT_INLINE void
gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
                  PanelMode>::operator()(QUInt8* blockB, const DataMapper& rhs,
                                         Index depth, Index cols, Index stride,
                                         Index offset) {
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  typedef typename packet_traits<QUInt8>::type Packet;

  // Get vector pointer
  __m256i* blockB_256 = reinterpret_cast<__m256i*>(blockB);

  // Get even multiples of the dimensions
  const Index cols_32 = (cols / 32) * 32;
  const Index depth_32 = (depth / 32) * 32;

  // One packing step for four columns a..d (32 depth bytes each): writes the
  // 8-byte depth chunks 0-7, 8-15, 16-23, 24-31 of all four columns to
  // offsets +0, +8, +16, +24 and advances the output by one vector.
  // (Replaces the former PACK_STEP macro.)
  auto pack_step = [&blockB_256](const __m256i& a, const __m256i& b,
                                 const __m256i& c, const __m256i& d) {
    const __m256i ab_lo = _mm256_unpacklo_epi64(a, b);
    const __m256i cd_lo = _mm256_unpacklo_epi64(c, d);
    const __m256i ab_hi = _mm256_unpackhi_epi64(a, b);
    const __m256i cd_hi = _mm256_unpackhi_epi64(c, d);
    _mm256_store_si256(blockB_256,
                       _mm256_permute2x128_si256(ab_lo, cd_lo, 0x20));
    _mm256_store_si256(blockB_256 + 8,
                       _mm256_permute2x128_si256(ab_hi, cd_hi, 0x20));
    _mm256_store_si256(blockB_256 + 16,
                       _mm256_permute2x128_si256(ab_lo, cd_lo, 0x31));
    _mm256_store_si256(blockB_256 + 24,
                       _mm256_permute2x128_si256(ab_hi, cd_hi, 0x31));
    ++blockB_256;
  };

  // Packs the full 32-deep slice at depth k of columns [n, n + count);
  // count is clamped to 4 and missing columns are zero vectors.
  auto pack_load_step = [&](Index k, Index n, Index count) {
    __m256i R[4];
    for (Index j = 0; j < 4; ++j) {
      if (j < count) {
        R[j] = rhs.template loadPacket<Packet>(k, n + j);
      } else {
        R[j] = _mm256_setzero_si256();
      }
    }
    pack_step(R[0], R[1], R[2], R[3]);
  };

  // Packs the depth tail [depth_32, depth) of columns [n, n + count) element
  // by element; count is clamped to 4 and every unfilled byte is zero.
  auto pack_tail_step = [&](Index n, Index count) {
    __m256i R[4];
    for (int j = 0; j < 4; ++j) {
      R[j] = _mm256_setzero_si256();
    }
    const Index ncols = count < 4 ? count : 4;
    for (Index j = 0; j < ncols; ++j) {
      QUInt8* dst = reinterpret_cast<QUInt8*>(&R[j]);
      for (Index k = depth_32; k < depth; ++k) {
        dst[k - depth_32] = rhs(k, n + j);
      }
    }
    pack_step(R[0], R[1], R[2], R[3]);
  };

  // Pack cols in sets of 32
  for (Index n = 0; n < cols_32; n += 32) {
    // Pack depth in sets of 32
    for (Index k = 0; k < depth_32; k += 32) {
      for (Index c = n; c < n + 32; c += 4) {
        pack_step(rhs.template loadPacket<Packet>(k, c),
                  rhs.template loadPacket<Packet>(k, c + 1),
                  rhs.template loadPacket<Packet>(k, c + 2),
                  rhs.template loadPacket<Packet>(k, c + 3));
      }
      // 8 steps advanced by 8; skip the remaining 24 vectors of the tile.
      blockB_256 += 24;
    }

    if (depth_32 < depth) {
      for (Index c = n; c < n + 32; c += 4) {
        pack_tail_step(c, 4);
      }
      blockB_256 += 24;
    }
  }

  // Finish packing cols
  if (cols_32 < cols) {
    // Pack depth in sets of 32
    for (Index k = 0; k < depth_32; k += 32) {
      Index steps = 0;
      for (Index c = cols_32; c < cols; c += 4, ++steps) {
        pack_load_step(k, c, cols - c);
      }
      // Pad so the partial column tile still spans 32 vectors.
      blockB_256 += 32 - steps;
    }

    if (depth_32 < depth) {
      for (Index c = cols_32; c < cols; c += 4) {
        pack_tail_step(c, cols - c);
      }
    }
  }
}

// Generic AVX2 product kernel for packed QInt8 (LHS) x QUInt8 (RHS) blocks of
// arbitrary size.
//
// Accumulates res(i, j) += sum_k lhs(i, k) * rhs(k, j) over the `rows` x
// `cols` output, one 32x32 output tile at a time, using a 32x32 QInt32 scratch
// tile (blockO). `rows`, `cols` and `depth` are rounded up to multiples of 32
// when walking the packed blocks, so blockA/blockB are expected to be
// zero-padded to those sizes by the matching *_any packers
// (NOTE(review): padding of blockA assumed; confirm against
// gemm_pack_lhs_any).
//
// Only alpha == 1, strideA/strideB == -1, offsetA/offsetB == 0 and
// non-conjugated inputs are supported; all of these are asserted below.
template <typename Index, typename DataMapper, int mr, int nr,
          bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE void
gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
                ConjugateRhs>::operator()(const DataMapper& res,
                                          const QInt8* blockA,
                                          const QUInt8* blockB, Index rows,
                                          Index depth, Index cols, QInt32 alpha,
                                          Index strideA, Index strideB,
                                          Index offsetA, Index offsetB) {
  EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  eigen_assert(alpha.value == 1);
  eigen_assert(strideA == -1);
  eigen_assert(strideB == -1);
  eigen_assert(offsetA == 0);
  eigen_assert(offsetB == 0);
  eigen_assert(rows > 0);
  eigen_assert(cols > 0);
  eigen_assert(depth > 0);
  eigen_assert(blockA);
  eigen_assert(blockB);

  // Dimensions rounded up to the 32-element granularity of the packed blocks.
  Index rows_32 = ((rows + 31) / 32) * 32;
  Index cols_32 = ((cols + 31) / 32) * 32;
  Index depth_32 = ((depth + 31) / 32) * 32;

  // Create result block (32x32 scratch tile, column-major within the tile).
  ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0);
  memset(blockO, 0, 32 * 32 * sizeof(QInt32));

  // Get vectorized pointers
  __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
  const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
  const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);

  // Loop over blocks of 32 columns
  for (Index n = 0; n < cols_32; n += 32) {
    // Reset index into blockA
    Index indexL = 0;
    // Loop over blocks of 32 rows
    for (Index m = 0; m < rows_32; m += 32) {
      // Reset index into blockB
      Index indexR = n / 32 * depth_32;
      // Loop over blocks of 8 on depth
      for (Index k = 0; k < depth_32; k += 8) {
        // Load inputs
        __m256i L_AD0 = blockA_256[indexL++];
        __m256i L_AD8 = blockA_256[indexL++];
        __m256i L_AD16 = blockA_256[indexL++];
        __m256i L_AD24 = blockA_256[indexL++];
        __m256i L_EH0 = blockA_256[indexL++];
        __m256i L_EH8 = blockA_256[indexL++];
        __m256i L_EH16 = blockA_256[indexL++];
        __m256i L_EH24 = blockA_256[indexL++];
        __m256i R_AH0 = blockB_256[indexR++];
        __m256i R_AH4 = blockB_256[indexR++];
        __m256i R_AH8 = blockB_256[indexR++];
        __m256i R_AH12 = blockB_256[indexR++];
        __m256i R_AH16 = blockB_256[indexR++];
        __m256i R_AH20 = blockB_256[indexR++];
        __m256i R_AH24 = blockB_256[indexR++];
        __m256i R_AH28 = blockB_256[indexR++];

        // This constant is used with madd to convert 16 bit to 32 bit
        const __m256i ONE = _mm256_set1_epi32(0x00010001);

        // Declare variables used in COMPUTE_STEP
        __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32;

        // COMPUTE_STEP multiplies one broadcast RHS column (depth 0-3 in
        // R_INPUT_A, depth 4-7 in R_INPUT_B) against the 32 packed LHS rows
        // and accumulates into output column OFFSET of the scratch tile.
        // NOTE(review): _mm256_maddubs_epi16 saturates its pairwise u8*s8 sums
        // to int16, so extreme input values can saturate.
#define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET)                             \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD0);                             \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH0);                             \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET,                                                 \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET), P_32));     \
                                                                               \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD8);                             \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH8);                             \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET + 1,                                             \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 1), P_32)); \
                                                                               \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD16);                            \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH16);                            \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET + 2,                                             \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 2), P_32)); \
                                                                               \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD24);                            \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH24);                            \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET + 3,                                             \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 3), P_32));

        // Permute and shuffle to copy a single value across the entire vector
        // Then compute the multiplication
        __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00);
        __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 0);
        __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 1);
        R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11);
        __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 2);
        __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 3);

        R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 4);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 5);
        R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 6);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 7);

        R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 8);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 9);
        R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 10);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 11);

        R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 12);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 13);
        R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 14);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 15);

        R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 16);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 17);
        R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 18);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 19);

        R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 20);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 21);
        R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 22);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 23);

        R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 24);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 25);
        R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 26);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 27);

        R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 28);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 29);
        R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 30);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 31);

#undef COMPUTE_STEP
      }

      // Transfer the results to the result matrix. Both paths accumulate the
      // scratch tile into `res` so full and edge tiles behave the same way.
      if (m + 32 <= rows && n + 32 <= cols) {
        Index i = 0;
        for (Index j = n; j < n + 32; j++) {
          LinearMapper r0 = res.getLinearMapper(m, j);
          LinearMapper r1 = res.getLinearMapper(m + 8, j);
          LinearMapper r2 = res.getLinearMapper(m + 16, j);
          LinearMapper r3 = res.getLinearMapper(m + 24, j);
          typedef typename packet_traits<QInt32>::type Packet;
          r0.template storePacket<Packet>(
              0, _mm256_add_epi32(blockO_256[i++],
                                  r0.template loadPacket<Packet>(0)));
          r1.template storePacket<Packet>(
              0, _mm256_add_epi32(blockO_256[i++],
                                  r1.template loadPacket<Packet>(0)));
          r2.template storePacket<Packet>(
              0, _mm256_add_epi32(blockO_256[i++],
                                  r2.template loadPacket<Packet>(0)));
          r3.template storePacket<Packet>(
              0, _mm256_add_epi32(blockO_256[i++],
                                  r3.template loadPacket<Packet>(0)));
        }
      } else {
        // Edge tile: copy only the valid corner of the 32x32 scratch tile.
        // Both loops must stop at the tile boundary. Otherwise, when only one
        // dimension is ragged, (j - n) * 32 + (i - m) would run past the
        // 32 * 32 elements of blockO and write res outside this tile.
        const Index j_end = (n + 32 < cols) ? n + 32 : cols;
        const Index i_end = (m + 32 < rows) ? m + 32 : rows;
        for (Index j = n; j < j_end; j++) {
          for (Index i = m; i < i_end; i++) {
            res(i, j) += blockO[(j - n) * 32 + (i - m)];
          }
        }
      }

      // Zero the result block so it can be reused
      memset(blockO, 0, 32 * 32 * sizeof(QInt32));
    }
  }
}

// Below are the fully optimized versions, which are correct only for sizes that
// are multiples of 32.  Keeping these implementations separate from the
// generic ones above yields about a 10% performance benefit.

// Arrange a block of the left input matrix in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ...
// A1 B1 C1 D1 E1 F1 G1 H1 ...
// A2 B2 C2 D2 E2 F2 G2 H2 ...
// A3 B3 C3 D3 E3 F3 G3 H3 ...
// A4 B4 C4 D4 E4 F4 G4 H4 ...
// A5 B5 C5 D5 E5 F5 G5 H5 ...
// A6 B6 C6 D6 E6 F6 G6 H6 ...
// A7 B7 C7 D7 E7 F7 G7 H7 ...
// A8 ...
// ...
//
// Packing yields output (A0 beside B0 in memory):
// A0 B0 C0 D0
// A1 B1 C1 D1
// A2 B2 C2 D2
// A3 B3 C3 D3
// A4 B4 C4 D4
// A5 B5 C5 D5
// A6 B6 C6 D6
// A7 B7 C7 D7
// ...
// A31 B31 C31 D31
// E0 F0 G0 H0
// E1 F1 G1 H1
// E2 F2 G2 H2
// E3 F3 G3 H3
// E4 F4 G4 H4
// E5 F5 G5 H5
// E6 F6 G6 H6
// E7 F7 G7 H7
// ...
//
// Four elements of the same row are arranged contiguously because maddubs and
// madd both perform an adjacent addition in the kernel.
template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
struct gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2, QInt8, ColMajor,
                     Conjugate, PanelMode> {
  // Packs a `rows` x `depth` slice of the column-major QInt8 `lhs` into
  // `blockA`. The out-of-line definition only accepts stride == 0 and
  // offset == 0.
  EIGEN_DONT_INLINE void operator()(QInt8* blockA,
                                    const DataMapper& lhs,
                                    Index depth,
                                    Index rows,
                                    Index stride = 0,
                                    Index offset = 0);
};

template <typename Index, typename DataMapper, int Pack1, int Pack2,
          bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void
gemm_pack_lhs<QInt8, Index, DataMapper, Pack1, Pack2, QInt8, ColMajor,
              Conjugate, PanelMode>::operator()(QInt8* blockA,
                                                const DataMapper& lhs,
                                                Index depth, Index rows,
                                                Index stride, Index offset) {
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  typedef typename packet_traits<QInt8>::type Packet;

  // Shapes that are not multiples of 32 go through the generic packer.
  const bool multiple_of_32 = (rows % 32 == 0) && (depth % 32 == 0);
  if (!multiple_of_32) {
    gemm_pack_lhs_any<QInt8, Index, DataMapper, Pack1, Pack2, ColMajor,
                      Conjugate, PanelMode>
        lhs_pack;
    return lhs_pack(blockA, lhs, depth, rows, stride, offset);
  }

  // Packs the 32-row x 4-column slab whose top-left element is (row, col)
  // into four aligned vectors. Each vector covers 8 consecutive rows, and the
  // 4 column values of each row sit next to each other:
  //   dst[0] = rows 0-7, dst[1] = rows 8-15,
  //   dst[2] = rows 16-23, dst[3] = rows 24-31.
  auto pack_four_columns = [&lhs](Index row, Index col, __m256i* dst) {
    const __m256i col0 = lhs.template loadPacket<Packet>(row, col);
    const __m256i col1 = lhs.template loadPacket<Packet>(row, col + 1);
    const __m256i col2 = lhs.template loadPacket<Packet>(row, col + 2);
    const __m256i col3 = lhs.template loadPacket<Packet>(row, col + 3);

    // Interleave bytes of column pairs (within each 128-bit lane).
    const __m256i pair01_lo = _mm256_unpacklo_epi8(col0, col1);
    const __m256i pair01_hi = _mm256_unpackhi_epi8(col0, col1);
    const __m256i pair23_lo = _mm256_unpacklo_epi8(col2, col3);
    const __m256i pair23_hi = _mm256_unpackhi_epi8(col2, col3);

    // Interleave 16-bit pairs so each row's 4 bytes become contiguous.
    // Lane contents: lo_lo = rows 0-3 | 16-19, lo_hi = rows 4-7 | 20-23,
    //                hi_lo = rows 8-11 | 24-27, hi_hi = rows 12-15 | 28-31.
    const __m256i quad_lo_lo = _mm256_unpacklo_epi16(pair01_lo, pair23_lo);
    const __m256i quad_lo_hi = _mm256_unpackhi_epi16(pair01_lo, pair23_lo);
    const __m256i quad_hi_lo = _mm256_unpacklo_epi16(pair01_hi, pair23_hi);
    const __m256i quad_hi_hi = _mm256_unpackhi_epi16(pair01_hi, pair23_hi);

    // Cross the 128-bit lanes so that each stored vector holds 8
    // consecutive rows.
    _mm256_store_si256(dst + 0,
                       _mm256_permute2x128_si256(quad_lo_lo, quad_lo_hi, 0x20));
    _mm256_store_si256(dst + 1,
                       _mm256_permute2x128_si256(quad_hi_lo, quad_hi_hi, 0x20));
    _mm256_store_si256(dst + 2,
                       _mm256_permute2x128_si256(quad_lo_lo, quad_lo_hi, 0x31));
    _mm256_store_si256(dst + 3,
                       _mm256_permute2x128_si256(quad_hi_lo, quad_hi_hi, 0x31));
  };

  __m256i* out = reinterpret_cast<__m256i*>(blockA);

  // Each 32-row x 8-deep tile becomes 8 vectors: depth k..k+3 first, then
  // depth k+4..k+7.
  for (Index row = 0; row < rows; row += 32) {
    for (Index col = 0; col < depth; col += 8) {
      pack_four_columns(row, col, out);
      pack_four_columns(row, col + 4, out + 4);
      out += 8;
    }
  }
}

// Arrange a block of the right input matrix in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ...
// A1 B1 C1 D1 E1 F1 G1 H1 ...
// A2 B2 C2 D2 E2 F2 G2 H2 ...
// A3 B3 C3 D3 E3 F3 G3 H3 ...
// A4 B4 C4 D4 E4 F4 G4 H4 ...
// A5 B5 C5 D5 E5 F5 G5 H5 ...
// A6 B6 C6 D6 E6 F6 G6 H6 ...
// A7 B7 C7 D7 E7 F7 G7 H7 ...
// A8 ...
// ...
//
// Packing yields row major output (A0 beside A1 in memory):
// A0 A1 A2 A3 A4 A5 A6 A7
// B0 B1 B2 B3 B4 B5 B6 B7
// ...
//
// At least four elements of the same col are arranged contiguously because
// maddubs and madd both perform an adjacent addition in the kernel.  We can
// save work by leaving 8 adjacent elements because kr = 8.
template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
struct gemm_pack_rhs<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
                     PanelMode> {
  // Packs a `depth` x `cols` slice of the column-major QUInt8 `rhs` into
  // `blockB`. The out-of-line definition only accepts stride == 0 and
  // offset == 0.
  EIGEN_DONT_INLINE void operator()(QUInt8* blockB,
                                    const DataMapper& rhs,
                                    Index depth,
                                    Index cols,
                                    Index stride = 0,
                                    Index offset = 0);
};

template <typename Index, typename DataMapper, int nr, bool Conjugate,
          bool PanelMode>
EIGEN_DONT_INLINE void
gemm_pack_rhs<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
              PanelMode>::operator()(QUInt8* blockB, const DataMapper& rhs,
                                     Index depth, Index cols, Index stride,
                                     Index offset) {
  eigen_assert(stride == 0);
  eigen_assert(offset == 0);

  typedef typename packet_traits<QUInt8>::type Packet;

  // Shapes that are not multiples of 32 go through the generic packer.
  const bool multiple_of_32 = (cols % 32 == 0) && (depth % 32 == 0);
  if (!multiple_of_32) {
    gemm_pack_rhs_any<QUInt8, Index, DataMapper, nr, ColMajor, Conjugate,
                      PanelMode>
        rhs_pack;
    return rhs_pack(blockB, rhs, depth, cols, stride, offset);
  }

  __m256i* out = reinterpret_cast<__m256i*>(blockB);

  // Every 32-deep x 32-column tile is written as 32 vectors. Vector
  // 8 * s + g holds depth rows 8*s .. 8*s+7 of the four columns of group g,
  // as 8 contiguous bytes per column.
  for (Index col0 = 0; col0 < cols; col0 += 32) {
    for (Index row0 = 0; row0 < depth; row0 += 32) {
      for (Index group = 0; group < 8; ++group) {
        const Index col = col0 + 4 * group;
        const __m256i a = rhs.template loadPacket<Packet>(row0, col);
        const __m256i b = rhs.template loadPacket<Packet>(row0, col + 1);
        const __m256i c = rhs.template loadPacket<Packet>(row0, col + 2);
        const __m256i d = rhs.template loadPacket<Packet>(row0, col + 3);

        // Pair up 8-byte runs of neighbouring columns within each lane.
        const __m256i ab_lo = _mm256_unpacklo_epi64(a, b);
        const __m256i cd_lo = _mm256_unpacklo_epi64(c, d);
        const __m256i ab_hi = _mm256_unpackhi_epi64(a, b);
        const __m256i cd_hi = _mm256_unpackhi_epi64(c, d);

        // Cross lanes to gather the four columns of one 8-deep slice.
        __m256i* slot = out + group;
        _mm256_store_si256(slot,
                           _mm256_permute2x128_si256(ab_lo, cd_lo, 0x20));
        _mm256_store_si256(slot + 8,
                           _mm256_permute2x128_si256(ab_hi, cd_hi, 0x20));
        _mm256_store_si256(slot + 16,
                           _mm256_permute2x128_si256(ab_lo, cd_lo, 0x31));
        _mm256_store_si256(slot + 24,
                           _mm256_permute2x128_si256(ab_hi, cd_hi, 0x31));
      }
      out += 32;
    }
  }
}

// Multiply packed inputs: AVX2 gebp kernel specialization for a packed QInt8
// left-hand block and a packed QUInt8 right-hand block.
template <typename Index, typename DataMapper, int mr, int nr,
          bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
                   ConjugateRhs> {
  using LinearMapper = typename DataMapper::LinearMapper;

  // The out-of-line definition asserts alpha == 1, the default strides and
  // offsets, and non-conjugated inputs.
  EIGEN_DONT_INLINE
  void operator()(const DataMapper& res,
                  const QInt8* blockA,
                  const QUInt8* blockB,
                  Index rows,
                  Index depth,
                  Index cols,
                  QInt32 alpha,
                  Index strideA = -1,
                  Index strideB = -1,
                  Index offsetA = 0,
                  Index offsetB = 0);
};

template <typename Index, typename DataMapper, int mr, int nr,
          bool ConjugateLhs, bool ConjugateRhs>
EIGEN_DONT_INLINE void
gebp_kernel<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
            ConjugateRhs>::operator()(const DataMapper& res,
                                      const QInt8* blockA, const QUInt8* blockB,
                                      Index rows, Index depth, Index cols,
                                      QInt32 alpha, Index strideA,
                                      Index strideB, Index offsetA,
                                      Index offsetB) {
  EIGEN_STATIC_ASSERT(!ConjugateLhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  EIGEN_STATIC_ASSERT(!ConjugateRhs, YOU_MADE_A_PROGRAMMING_MISTAKE);
  eigen_assert(alpha.value == 1);
  eigen_assert(strideA == -1);
  eigen_assert(strideB == -1);
  eigen_assert(offsetA == 0);
  eigen_assert(offsetB == 0);
  eigen_assert(rows > 0);
  eigen_assert(cols > 0);
  eigen_assert(depth > 0);
  eigen_assert(blockA);
  eigen_assert(blockB);

  // Use alternate function for weird sizes
  if (rows % 32 != 0 || cols % 32 != 0 || depth % 32 != 0) {
    gebp_kernel_any<QInt8, QUInt8, Index, DataMapper, mr, nr, ConjugateLhs,
                    ConjugateRhs>
        gebp;
    return gebp(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB,
                offsetA, offsetB);
  }

  // Create result block
  QInt32* blockO = aligned_new<QInt32>(32 * 32);
  // Allocating the result block is about 5-10% faster than declaring stack
  // space.  It is unclear why this is the case.
  // ei_declare_aligned_stack_constructed_variable(QInt32, blockO, 32 * 32, 0);
  memset(blockO, 0, 32 * 32 * sizeof(QInt32));

  // Get vectorized pointers
  __m256i* blockO_256 = reinterpret_cast<__m256i*>(blockO);
  const __m256i* blockA_256 = reinterpret_cast<const __m256i*>(blockA);
  const __m256i* blockB_256 = reinterpret_cast<const __m256i*>(blockB);

  // Loop over blocks of 32 columns
  for (Index n = 0; n < cols; n += 32) {
    // Reset index into blockA
    Index indexL = 0;
    // Loop over blocks of 32 rows
    for (Index m = 0; m < rows; m += 32) {
      // Reset index into blockB
      Index indexR = n / 32 * depth;
      // Loop over blocks of 8 on depth
      for (Index k = 0; k < depth; k += 8) {
        // Load inputs
        __m256i L_AD0 = blockA_256[indexL++];
        __m256i L_AD8 = blockA_256[indexL++];
        __m256i L_AD16 = blockA_256[indexL++];
        __m256i L_AD24 = blockA_256[indexL++];
        __m256i L_EH0 = blockA_256[indexL++];
        __m256i L_EH8 = blockA_256[indexL++];
        __m256i L_EH16 = blockA_256[indexL++];
        __m256i L_EH24 = blockA_256[indexL++];
        __m256i R_AH0 = blockB_256[indexR++];
        __m256i R_AH4 = blockB_256[indexR++];
        __m256i R_AH8 = blockB_256[indexR++];
        __m256i R_AH12 = blockB_256[indexR++];
        __m256i R_AH16 = blockB_256[indexR++];
        __m256i R_AH20 = blockB_256[indexR++];
        __m256i R_AH24 = blockB_256[indexR++];
        __m256i R_AH28 = blockB_256[indexR++];

        // This constant is used with madd to convert 16 bit to 32 bit
        const __m256i ONE = _mm256_set1_epi32(0x00010001);

        // Declare variables used in COMPUTE_STEP
        __m256i P_16_A, P_16_B, P_32_A, P_32_B, P_32;

#define COMPUTE_STEP(R_INPUT_A, R_INPUT_B, OFFSET)                             \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD0);                             \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH0);                             \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET,                                                 \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET), P_32));     \
                                                                               \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD8);                             \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH8);                             \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET + 1,                                             \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 1), P_32)); \
                                                                               \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD16);                            \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH16);                            \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET + 2,                                             \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 2), P_32)); \
                                                                               \
  P_16_A = _mm256_maddubs_epi16(R_INPUT_A, L_AD24);                            \
  P_32_A = _mm256_madd_epi16(P_16_A, ONE);                                     \
  P_16_B = _mm256_maddubs_epi16(R_INPUT_B, L_EH24);                            \
  P_32_B = _mm256_madd_epi16(P_16_B, ONE);                                     \
  P_32 = _mm256_add_epi32(P_32_A, P_32_B);                                     \
  _mm256_store_si256(                                                          \
      blockO_256 + 4 * OFFSET + 3,                                             \
      _mm256_add_epi32(_mm256_load_si256(blockO_256 + 4 * OFFSET + 3), P_32));

        // Permute and shuffle to copy a single value across the entire vector
        // Then compute the multiplication
        __m256i R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x00);
        __m256i R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        __m256i R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 0);
        __m256i R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        __m256i R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 1);
        R_AH0_ = _mm256_permute2x128_si256(R_AH0, R_AH0, 0x11);
        __m256i R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        __m256i R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 2);
        __m256i R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        __m256i R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 3);

        R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 4);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 5);
        R_AH0_ = _mm256_permute2x128_si256(R_AH4, R_AH4, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 6);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 7);

        R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 8);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 9);
        R_AH0_ = _mm256_permute2x128_si256(R_AH8, R_AH8, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 10);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 11);

        R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 12);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 13);
        R_AH0_ = _mm256_permute2x128_si256(R_AH12, R_AH12, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 14);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 15);

        R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 16);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 17);
        R_AH0_ = _mm256_permute2x128_si256(R_AH16, R_AH16, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 18);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 19);

        R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 20);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 21);
        R_AH0_ = _mm256_permute2x128_si256(R_AH20, R_AH20, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 22);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 23);

        R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 24);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 25);
        R_AH0_ = _mm256_permute2x128_si256(R_AH24, R_AH24, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 26);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 27);

        R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x00);
        R_AD0 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH0 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD0, R_EH0, 28);
        R_AD1 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH1 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD1, R_EH1, 29);
        R_AH0_ = _mm256_permute2x128_si256(R_AH28, R_AH28, 0x11);
        R_AD2 = _mm256_shuffle_epi32(R_AH0_, 0x00);
        R_EH2 = _mm256_shuffle_epi32(R_AH0_, 0x55);
        COMPUTE_STEP(R_AD2, R_EH2, 30);
        R_AD3 = _mm256_shuffle_epi32(R_AH0_, 0xAA);
        R_EH3 = _mm256_shuffle_epi32(R_AH0_, 0xFF);
        COMPUTE_STEP(R_AD3, R_EH3, 31);

#undef COMPUTE_STEP
      }

      // Transfer the results to the result matrix
      Index i = 0;
      for (Index j = n; j < n + 32; j++) {
        LinearMapper r0 = res.getLinearMapper(m, j);
        LinearMapper r1 = res.getLinearMapper(m + 8, j);
        LinearMapper r2 = res.getLinearMapper(m + 16, j);
        LinearMapper r3 = res.getLinearMapper(m + 24, j);
        typedef typename packet_traits<QInt32>::type Packet;
        r0.template storePacket<Packet>(
            0, _mm256_add_epi32(blockO_256[i++],
                                r0.template loadPacket<Packet>(0)));
        r1.template storePacket<Packet>(
            0, _mm256_add_epi32(blockO_256[i++],
                                r1.template loadPacket<Packet>(0)));
        r2.template storePacket<Packet>(
            0, _mm256_add_epi32(blockO_256[i++],
                                r2.template loadPacket<Packet>(0)));
        r3.template storePacket<Packet>(
            0, _mm256_add_epi32(blockO_256[i++],
                                r3.template loadPacket<Packet>(0)));
      }

      // Zero the result block so it can be reused
      memset(blockO, 0, 32 * 32 * sizeof(QInt32));
    }
  }
  aligned_delete(blockO, 32 * 32);
}

#endif  // EIGEN_USE_OPTIMIZED_INT8_UINT8_MAT_MAT_PRODUCT

}  // namespace internal
}  // namespace Eigen

#endif  // XLA_TSL_FRAMEWORK_FIXEDPOINT_MATMATPRODUCTAVX2_H_
