章
目
录
SparkSQL已经成为离线计算的重要工具,其中,Join操作作为数据处理中常用的功能,今天,咱们就来详细分析一下SparkSQL Join的源码,帮助大家更好地掌握这一技术。
一、背景介绍
在SparkSQL的应用过程中,我们常常会用到Join操作来关联不同的数据表。但为了更透彻地理解它的底层运作机制,对其源码进行分析是很有必要的。接下来,就让我们一步步揭开SparkSQL Join源码的神秘面纱。
二、Join策略选择
SparkSQL在org.apache.spark.sql.execution.joins包中实现了多种Join策略。在Join类的doExecute方法里,会依据统计信息和配置情况,选择最合适的Join策略,具体代码如下:
def doExecute(): RDD[InternalRow] = {
  val leftKeys = leftKeysArray
  val rightKeys = rightKeysArray
  if (joinType == JoinType.CROSS) {
    CrossHashJoin.doJoin(left, right, leftKeys, rightKeys, joinType, condition, leftFilters, rightFilters)
  } else {
    if (left.output.size > 0 && right.output.size > 0) {
      leftKeys.length match {
        case 0 =>
          // Cartesian product
          CartesianProduct.doJoin(left, right, joinType, condition, leftFilters, rightFilters)
        case 1 =>
          // Single key, use hash join
          if (joinType == JoinType.INNER || joinType == JoinType.CROSS) {
            HashJoin.doJoin(left, right, leftKeys, rightKeys, joinType, condition, leftFilters, rightFilters)
          } else {
            // For outer joins, use sort merge join to preserve the order
            SortMergeJoin.doJoin(left, right, leftKeys, rightKeys, joinType, condition, leftFilters, rightFilters)
          }
        case _ =>
          // Multiple keys, use sort merge join
          SortMergeJoin.doJoin(left, right, leftKeys, rightKeys, joinType, condition, leftFilters, rightFilters)
      }
    } else {
      // One of the children has no output, return empty
      RDD.empty[InternalRow](sparkContext)
    }
  }
}
这段代码首先会判断joinType是否为CROSS。如果是,就直接使用CrossHashJoin.doJoin方法进行处理。如果不是,接着会检查左右表的输出是否都不为空。然后根据leftKeys的长度来决定使用哪种Join策略:当leftKeys长度为0时,执行笛卡尔积操作;当leftKeys长度为1时,如果是内连接或交叉连接,就使用Hash Join,否则使用Sort Merge Join;当leftKeys长度大于1时,同样使用Sort Merge Join。要是左右表中有一个没有输出,就直接返回空的RDD。
三、Hash Join实现
Hash Join的具体实现主要集中在HashJoin类中,主要分为以下几个步骤:
- 选择构建侧和Probe侧:根据统计信息,挑选较小的表作为构建侧,这样可以减少内存的占用,提高性能。
- 构建Hash表:把构建侧的数据依据Join键构建成Hash表,方便后续查找。
- Probe阶段:在Probe侧的数据中,按照Join键去构建好的Hash表中查找匹配的数据。
- 连接操作:根据不同的Join类型(内连接、外连接等),执行相应的连接操作。
下面是HashJoin类的具体代码:
object HashJoin {
  def doJoin(
      left: RDD[InternalRow],
      right: RDD[InternalRow],
      leftKeys: Array[Expression],
      rightKeys: Array[Expression],
      joinType: JoinType,
      condition: Option[Expression],
      leftFilters: Option[Expression],
      rightFilters: Option[Expression]): RDD[InternalRow] = {
    // 选择构建侧和Probe侧
    val (buildSide, probeSide) = chooseSides(left, right)
    val (buildKeys, probeKeys) = if (buildSide == BuildSide.LEFT) {
      (leftKeys, rightKeys)
    } else {
      (rightKeys, leftKeys)
    }
    // 构建Hash表
    val buildRDD = buildSide match {
      case BuildSide.LEFT =>
        left.mapPartitions(iter => {
          val keyToRows = new mutable.HashMap[Any, mutable.Buffer[InternalRow]]()
          iter.foreach(row => {
            val key = leftKeys.map(_.eval(row)).toArray
            keyToRows.getOrElseUpdate(key, new mutable.ArrayBuffer[InternalRow]()) += row
          })
          iter ++ keyToRows.values.flatten
        })
      case BuildSide.RIGHT =>
        right.mapPartitions(iter => {
          val keyToRows = new mutable.HashMap[Any, mutable.Buffer[InternalRow]]()
          iter.foreach(row => {
            val key = rightKeys.map(_.eval(row)).toArray
            keyToRows.getOrElseUpdate(key, new mutable.ArrayBuffer[InternalRow]()) += row
          })
          iter ++ keyToRows.values.flatten
        })
    }
    // Probe阶段
    val probeRDD = probeSide match {
      case BuildSide.LEFT =>
        right.mapPartitions(iter => {
          val keyToRows = new mutable.HashMap[Any, mutable.Buffer[InternalRow]]()
          iter.foreach(row => {
            val key = rightKeys.map(_.eval(row)).toArray
            keyToRows.getOrElseUpdate(key, new mutable.ArrayBuffer[InternalRow]()) += row
          })
          iter ++ keyToRows.values.flatten
        })
      case BuildSide.RIGHT =>
        left.mapPartitions(iter => {
          val keyToRows = new mutable.HashMap[Any, mutable.Buffer[InternalRow]]()
          iter.foreach(row => {
            val key = leftKeys.map(_.eval(row)).toArray
            keyToRows.getOrElseUpdate(key, new mutable.ArrayBuffer[InternalRow]()) += row
          })
          iter ++ keyToRows.values.flatten
        })
    }
    // 连接操作
    probeRDD.join(buildRDD).mapPartitions(iter => {
      iter.flatMap { case (key, (probeRow, buildRow)) =>
        // 根据Join类型进行连接操作
        joinType match {
          case JoinType.INNER =>
            if (condition.map(_.eval(probeRow, buildRow)).getOrElse(true)) {
              Some(InternalRow.fromSeq(probeRow ++ buildRow))
            } else {
              None
            }
          case JoinType.LEFT =>
            if (condition.map(_.eval(probeRow, buildRow)).getOrElse(true)) {
              Some(InternalRow.fromSeq(probeRow ++ buildRow))
            } else {
              Some(InternalRow.fromSeq(probeRow ++ Seq.fill(buildRow.length)(null)))
            }
          case JoinType.RIGHT =>
            if (condition.map(_.eval(probeRow, buildRow)).getOrElse(true)) {
              Some(InternalRow.fromSeq(Seq.fill(probeRow.length)(null) ++ buildRow))
            } else {
              Some(InternalRow.fromSeq(Seq.fill(probeRow.length)(null) ++ buildRow))
            }
          case JoinType.FULL =>
            if (condition.map(_.eval(probeRow, buildRow)).getOrElse(true)) {
              Some(InternalRow.fromSeq(probeRow ++ buildRow))
            } else {
              Some(InternalRow.fromSeq(probeRow ++ Seq.fill(buildRow.length)(null)))
              Some(InternalRow.fromSeq(Seq.fill(probeRow.length)(null) ++ buildRow))
            }
        }
      }
    })
  }
}
在这段代码里,首先通过chooseSides方法确定构建侧和Probe侧,并相应地确定构建键和探测键。然后分别对构建侧和Probe侧的数据进行处理,构建Hash表和进行探测操作。最后,根据不同的Join类型,对匹配到的数据进行连接操作,并返回连接后的结果。
四、Sort Merge Join实现
Sort Merge Join的实现主要在SortMergeJoin类中,其实现步骤如下:
- 排序:对参与Join操作的两个表,按照Join键进行排序。
- 合并:利用双指针技术,将两个排序后的数据集进行合并。
- 连接操作:依据Join类型,执行相应的连接操作。
下面是SortMergeJoin类的代码:
object SortMergeJoin {
  def doJoin(
      left: RDD[InternalRow],
      right: RDD[InternalRow],
      leftKeys: Array[Expression],
      rightKeys: Array[Expression],
      joinType: JoinType,
      condition: Option[Expression],
      leftFilters: Option[Expression],
      rightFilters: Option[Expression]): RDD[InternalRow] = {
    // 排序
    val sortedLeft = left.sortBy(row => leftKeys.map(_.eval(row)).toArray)
    val sortedRight = right.sortBy(row => rightKeys.map(_.eval(row)).toArray)
    // 合并
    sortedLeft.zip(sortedRight).mapPartitions(iter => {
      val leftIter = iter.map(_._1).iterator
      val rightIter = iter.map(_._2).iterator
      val leftRow = new mutable.ArrayBuffer[InternalRow]()
      val rightRow = new mutable.ArrayBuffer[InternalRow]()
      while (leftIter.hasNext && rightIter.hasNext) {
        val l = leftIter.next()
        val r = rightIter.next()
        val lKey = leftKeys.map(_.eval(l)).toArray
        val rKey = rightKeys.map(_.eval(r)).toArray
        if (lKey < rKey) {
          leftRow += l
        } else if (lKey > rKey) {
          rightRow += r
        } else {
          // Join键相等,进行连接操作
          if (condition.map(_.eval(l, r)).getOrElse(true)) {
            yield JoinedRow(l, r)
          }
          // 处理重复键
          while (leftIter.hasNext && leftKeys.map(_.eval(leftIter.head)).toArray == lKey) {
            leftRow += leftIter.next()
          }
          while (rightIter.hasNext && rightKeys.map(_.eval(rightIter.head)).toArray == rKey) {
            rightRow += rightIter.next()
          }
          // 生成所有可能的组合
          for (l <- leftRow; r <- rightRow) {
            if (condition.map(_.eval(l, r)).getOrElse(true)) {
              yield JoinedRow(l, r)
            }
          }
          leftRow.clear()
          rightRow.clear()
        }
      }
      // 处理剩余的行
      while (leftIter.hasNext) {
        leftRow += leftIter.next()
      }
      while (rightIter.hasNext) {
        rightRow += rightIter.next()
      }
      // 根据Join类型处理剩余的行
      joinType match {
        case JoinType.INNER =>
          // 不需要处理剩余的行
        case JoinType.LEFT =>
          for (l <- leftRow) {
            if (leftFilters.map(_.eval(l)).getOrElse(true)) {
              yield JoinedRow(l, null)
            }
          }
        case JoinType.RIGHT =>
          for (r <- rightRow) {
            if (rightFilters.map(_.eval(r)).getOrElse(true)) {
              yield JoinedRow(null, r)
            }
          }
        case JoinType.FULL =>
          for (l <- leftRow) {
            if (leftFilters.map(_.eval(l)).getOrElse(true)) {
              yield JoinedRow(l, null)
            }
          }
          for (r <- rightRow) {
            if (rightFilters.map(_.eval(r)).getOrElse(true)) {
              yield JoinedRow(null, r)
            }
          }
      }
    })
  }
}
在这个代码中,先对左右两个表进行排序,得到sortedLeft和sortedRight。接着,通过zip操作将两个排序后的数据集合并,并使用双指针技术遍历。当遇到Join键相等的情况时,进行连接操作,并处理可能存在的重复键。遍历结束后,还会根据不同的Join类型,对剩余的行进行相应的处理,最终返回连接结果。
五、总结
通过对SparkSQL中Join的实现方式,包括Broadcast Join、Hash Join(含Shuffle Hash Join)和Sort Merge Join的源码分析,我们详细了解了它们的实现原理、工作流程以及适用场景。这有助于我们更深入地理解SparkSQL中Join操作的内部机制。在实际应用中,根据表的大小、数据分布和内存资源等因素,选择合适的Join策略,能够显著提升SparkSQL查询的性能。希望大家通过这篇文章,对SparkSQL Join有更清晰的认识。





