Spark Shuffle Module: Analysis of the Shuffle Read Process

xiaoxiao, 2026-04-08

Before reading this article, please read "Spark Sort Based Shuffle Memory Analysis" first.

The Spark Shuffle Read call stack is as follows:

1. org.apache.spark.rdd.ShuffledRDD#compute()
2. org.apache.spark.shuffle.ShuffleManager#getReader()
3. org.apache.spark.shuffle.hash.HashShuffleReader#read()
4. org.apache.spark.storage.ShuffleBlockFetcherIterator#initialize()
5. org.apache.spark.storage.ShuffleBlockFetcherIterator#splitLocalRemoteBlocks()
   org.apache.spark.storage.ShuffleBlockFetcherIterator#sendRequest()
   org.apache.spark.storage.ShuffleBlockFetcherIterator#fetchLocalBlocks()

The classes and methods involved when fetchLocalBlocks() runs:

6. org.apache.spark.storage.BlockManager#getBlockData()
   org.apache.spark.shuffle.ShuffleManager#shuffleBlockResolver()

ShuffleManager has two main implementations.

For Hash Shuffle, the call goes to org.apache.spark.shuffle.hash.HashShuffleManager#shuffleBlockResolver(). That method returns an org.apache.spark.shuffle.FileShuffleBlockResolver, and FileShuffleBlockResolver#getBlockData() then returns the block data.

For Sort Shuffle, the call goes to org.apache.spark.shuffle.sort.SortShuffleManager#shuffleBlockResolver(). That method returns an org.apache.spark.shuffle.IndexShuffleBlockResolver, and IndexShuffleBlockResolver#getBlockData() then returns the block data.

The classes and methods involved when org.apache.spark.storage.ShuffleBlockFetcherIterator#sendRequest() runs:

7. org.apache.spark.network.shuffle.ShuffleClient#fetchBlocks()

org.apache.spark.network.shuffle.ShuffleClient has two subclasses, ExternalShuffleClient and BlockTransferService. org.apache.spark.network.BlockTransferService in turn has two subclasses, NettyBlockTransferService and NioBlockTransferService, which implement two different ways of fetching remote block data. As of Spark 1.5.2, NioBlockTransferService is deprecated and will be removed in a later release.

The methods are described below in call-stack order. Only the overall flow is covered here; the details will be discussed later.

ShuffledRDD#compute()

When a task runs, it calls ShuffledRDD's compute() method, shown below:

    // org.apache.spark.rdd.ShuffledRDD#compute()
    override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
      val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
      // Obtain the reader via org.apache.spark.shuffle.ShuffleManager#getReader().
      // Whether Sort Shuffle or Hash Shuffle is in use, the reader is
      // org.apache.spark.shuffle.hash.HashShuffleReader
      SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
        .read()
        .asInstanceOf[Iterator[(K, C)]]
    }

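The core logic is short:

1. Call ShuffleManager#getReader() to obtain a HashShuffleReader for the partition range [split.index, split.index + 1), so each task reads exactly its own reduce partition.
2. Call HashShuffleReader#read() to read the shuffle data written by the ShuffleMapTasks of the previous stage.

Note that both Hash Shuffle and Sort Shuffle use HashShuffleReader.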
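Because each reduce task reads one partition from every map output, its input can be pictured as a list of blocks grouped by the executor that holds them. This is the shape of the blocksByAddress argument that ShuffleBlockFetcherIterator receives.

The sketch below is a plain-Scala approximation under these assumptions:

- MapStatusInfo, ShuffleBlockRef and blocksByExecutor are hypothetical stand-ins, not Spark's MapStatus, ShuffleBlockId or MapOutputTracker API.
- Executors are plain strings rather than BlockManagerIds.
- Zero-sized blocks are kept here and filtered later, when local and remote blocks are split.

```scala
// Hypothetical, simplified view of one map task's output: which executor wrote it
// and how many bytes it produced for each reduce partition.
final case class MapStatusInfo(mapId: Int, executorId: String, partitionSizes: Array[Long])

// Hypothetical block reference; the name follows Spark's shuffle block naming scheme.
final case class ShuffleBlockRef(shuffleId: Int, mapId: Int, reduceId: Int) {
  def name: String = s"shuffle_${shuffleId}_${mapId}_$reduceId"
}

object ReduceInputs {
  /** Blocks one reduce task must read, grouped by executor, with their sizes in bytes. */
  def blocksByExecutor(
      shuffleId: Int,
      reduceId: Int,
      statuses: Seq[MapStatusInfo]): Map[String, Seq[(ShuffleBlockRef, Long)]] =
    statuses
      .map(s => (s.executorId, (ShuffleBlockRef(shuffleId, s.mapId, reduceId), s.partitionSizes(reduceId))))
      .groupBy(_._1) // groupBy keeps the original map order inside each group
      .map { case (exec, pairs) => exec -> pairs.map(_._2) }
}
```

Each map task contributes exactly one block per reduce partition, so the number of blocks a reduce task sees equals the number of map tasks, zero-sized ones included.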

HashShuffleReader#read()

The source of HashShuffleReader#read() is:

    /** Read the combined key-values for this reduce task */
    override def read(): Iterator[Product2[K, C]] = {
      // Create a ShuffleBlockFetcherIterator; its constructor calls initialize(),
      // which runs splitLocalRemoteBlocks() to decide how each block is read:
      // remote blocks are fetched with sendRequest(),
      // local blocks with fetchLocalBlocks()
      val blockFetcherItr = new ShuffleBlockFetcherIterator(
        context,
        blockManager.shuffleClient,
        blockManager,
        mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition),
        // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
        SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

      // Wrap the streams for compression based on configuration
      val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
        blockManager.wrapForCompression(blockId, inputStream)
      }

      val ser = Serializer.getSerializer(dep.serializer)
      val serializerInstance = ser.newInstance()

      // Create a key/value iterator for each stream
      val recordIter = wrappedStreams.flatMap { wrappedStream =>
        // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
        // NextIterator. The NextIterator makes sure that close() is called on the
        // underlying InputStream when all records have been read.
        serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
      }

      // Update the context task metrics for each record read.
      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
      val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
        recordIter.map(record => {
          readMetrics.incRecordsRead(1)
          record
        }),
        context.taskMetrics().updateShuffleReadMetrics())

      // An interruptible iterator must be used here in order to support task cancellation
      val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

      val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
        if (dep.mapSideCombine) {
          // Records were already combined on the map side: merge the combiners
          val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
          dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
        } else {
          // No map-side combine: aggregate the raw values on the reduce side
          val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
          dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
        }
      } else {
        require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
        interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
      }

      // Sort the output if a key ordering is defined
      dep.keyOrdering match {
        case Some(keyOrd: Ordering[K]) =>
          // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
          // the ExternalSorter won't spill to disk.
          val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
          sorter.insertAll(aggregatedIter)
          context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
          context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
          context.internalMetricsToAccumulators(
            InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
          sorter.iterator
        case None =>
          aggregatedIter
      }
    }
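The aggregation and sorting at the end of read() can be illustrated without Spark.

The sketch makes these assumptions:

- SimpleAggregator is a hypothetical stand-in carrying the same three functions as Spark's Aggregator (createCombiner, mergeValue, mergeCombiners).
- It uses an in-memory mutable.HashMap where Spark can use a spillable map.
- The final sortBy stands in for ExternalSorter.

ReduceSide.process mirrors only the branch structure of read():

1. Merge combiners when mapSideCombine is set.
2. Otherwise, combine raw values.
3. Pass records through when there is no aggregator.
4. Sort only if a key ordering is defined.

```scala
import scala.collection.mutable

// Hypothetical stand-in for Spark's Aggregator: same three functions, no spilling
final case class SimpleAggregator[K, V, C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C) {

  // No map-side combine: build combiners from raw values (cf. combineValuesByKey)
  def combineValuesByKey(iter: Iterator[(K, V)]): Iterator[(K, C)] = {
    val combiners = mutable.HashMap.empty[K, C]
    iter.foreach { case (k, v) =>
      combiners(k) = combiners.get(k) match {
        case Some(c) => mergeValue(c, v)
        case None    => createCombiner(v)
      }
    }
    combiners.iterator
  }

  // Map side already combined: merge partial combiners from different map tasks
  // (cf. combineCombinersByKey)
  def combineCombinersByKey(iter: Iterator[(K, C)]): Iterator[(K, C)] = {
    val combiners = mutable.HashMap.empty[K, C]
    iter.foreach { case (k, c) =>
      combiners(k) = combiners.get(k).map(existing => mergeCombiners(existing, c)).getOrElse(c)
    }
    combiners.iterator
  }
}

object ReduceSide {
  // Same decision order as read(): aggregate if possible, then sort if an ordering exists
  def process[K, V, C](
      records: Iterator[(K, Any)],
      aggregator: Option[SimpleAggregator[K, V, C]],
      mapSideCombine: Boolean,
      keyOrdering: Option[Ordering[K]]): Iterator[(K, Any)] = {
    val aggregated: Iterator[(K, Any)] = aggregator match {
      case Some(agg) if mapSideCombine =>
        agg.combineCombinersByKey(records.asInstanceOf[Iterator[(K, C)]])
      case Some(agg) =>
        agg.combineValuesByKey(records.asInstanceOf[Iterator[(K, V)]])
      case None =>
        require(!mapSideCombine, "Map-side combine without Aggregator specified!")
        records
    }
    keyOrdering match {
      case Some(ord) => aggregated.toVector.sortBy(_._1)(ord).iterator // in-memory sort
      case None      => aggregated
    }
  }
}
```

The two aggregation paths differ only in what arrives from the map side. Raw values need createCombiner and mergeValue. Pre-combined values need only mergeCombiners.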

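ShuffleBlockFetcherIterator#splitLocalRemoteBlocks()

splitLocalRemoteBlocks() decides how each block is read. It keeps two lists:

- localBlocks records the IDs of blocks held by this executor. The check compares executor IDs, not hosts.
- remoteBlocks records the IDs of all blocks on other executors.

Remote blocks are grouped into FetchRequests of roughly targetRequestSize = maxBytesInFlight / 5 bytes each. A request is closed as soon as its accumulated size reaches that target, so it can overshoot the target, and a single large block can even exceed maxBytesInFlight. The requests are collected in remoteRequests, an ArrayBuffer[FetchRequest], which the method returns.

The source of splitLocalRemoteBlocks() is: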

    private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
      // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
      // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
      // nodes, rather than blocking on reading output from one node.
      // maxBytesInFlight caps the total size of outstanding remote fetches (default 48 MB),
      // set via SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024
      val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
      logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

      // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
      // at most maxBytesInFlight in order to limit the amount of data in flight.
      val remoteRequests = new ArrayBuffer[FetchRequest]

      // Tracks total number of blocks (including zero sized blocks)
      var totalBlocks = 0
      for ((address, blockInfos) <- blocksByAddress) {
        totalBlocks += blockInfos.size
        // The blocks are local, i.e. held by this executor
        if (address.executorId == blockManager.blockManagerId.executorId) {
          // Filter out zero-sized blocks and record the IDs of the local blocks
          localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
          numBlocksToFetch += localBlocks.size
        } else {
          // The blocks are on a remote executor
          val iterator = blockInfos.iterator
          var curRequestSize = 0L
          var curBlocks = new ArrayBuffer[(BlockId, Long)]
          while (iterator.hasNext) {
            val (blockId, size) = iterator.next()
            // Skip empty blocks
            if (size > 0) {
              curBlocks += ((blockId, size))
              // Record the IDs of the remote blocks
              remoteBlocks += blockId
              numBlocksToFetch += 1
              curRequestSize += size
            } else if (size < 0) {
              throw new BlockException(blockId, "Negative block size " + size)
            }
            if (curRequestSize >= targetRequestSize) {
              // Add this FetchRequest
              remoteRequests += new FetchRequest(address, curBlocks)
              curBlocks = new ArrayBuffer[(BlockId, Long)]
              logDebug(s"Creating fetch request of $curRequestSize at $address")
              curRequestSize = 0
            }
          }
          // Add in the final request
          if (curBlocks.nonEmpty) {
            remoteRequests += new FetchRequest(address, curBlocks)
          }
        }
      }
      logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
      remoteRequests
    }
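The splitting rule is easy to replay outside Spark. This sketch keeps the same logic as splitLocalRemoteBlocks():

- skip zero-sized blocks;
- reject negative sizes;
- close a request once it reaches maxBytesInFlight / 5.

It differs from the original in a few assumed ways: executor and block IDs are plain strings, Request and SplitBlocks are hypothetical names, and require (which throws IllegalArgumentException) replaces BlockException.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical simplified FetchRequest: target executor plus (blockId, size) pairs
final case class Request(executorId: String, blocks: Seq[(String, Long)]) {
  val size: Long = blocks.map(_._2).sum
}

object SplitBlocks {
  /** Returns (local non-empty block IDs, remote fetch requests). */
  def split(
      localExecutorId: String,
      blocksByAddress: Seq[(String, Seq[(String, Long)])],
      maxBytesInFlight: Long): (Seq[String], Seq[Request]) = {
    // Same target as the Spark code: allows parallel fetches from up to 5 nodes
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    val local = ArrayBuffer.empty[String]
    val requests = ArrayBuffer.empty[Request]
    for ((executorId, blockInfos) <- blocksByAddress) {
      if (executorId == localExecutorId) {
        // Local blocks: only the empty ones are dropped
        local ++= blockInfos.filter(_._2 != 0).map(_._1)
      } else {
        var cur = ArrayBuffer.empty[(String, Long)]
        var curSize = 0L
        for ((blockId, size) <- blockInfos) {
          require(size >= 0, s"Negative block size $size for $blockId")
          if (size > 0) {
            cur += ((blockId, size))
            curSize += size
          }
          // Close the request as soon as it reaches the target size
          if (curSize >= targetRequestSize) {
            requests += Request(executorId, cur.toList)
            cur = ArrayBuffer.empty[(String, Long)]
            curSize = 0L
          }
        }
        // Whatever is left becomes the last request for this executor
        if (cur.nonEmpty) requests += Request(executorId, cur.toList)
      }
    }
    (local.toList, requests.toList)
  }
}
```

For example, with maxBytesInFlight = 50 the target is 10 bytes:

- remote blocks of 6 and 5 bytes end up in one 11-byte request;
- a trailing 3-byte block becomes a request of its own;
- a lone 25-byte block exceeds the target by itself.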

ShuffleBlockFetcherIterator#fetchLocalBlocks()

fetchLocalBlocks() reads the local blocks by calling BlockManager#getBlockData(). Its source is:

    private[this] def fetchLocalBlocks() {
      val iter = localBlocks.iterator
      while (iter.hasNext) {
        val blockId = iter.next()
        try {
          // Call BlockManager#getBlockData()
          val buf = blockManager.getBlockData(blockId)
          shuffleMetrics.incLocalBlocksFetched(1)
          shuffleMetrics.incLocalBytesRead(buf.size)
          buf.retain()
          results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
        } catch {
          case e: Exception =>
            // If we see an exception, stop immediately.
            logError(s"Error occurred while fetching local blocks", e)
            results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
            return
        }
      }
    }
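fetchLocalBlocks() pushes each local buffer into the same results queue that remote fetches use, and stops at the first failure. The sketch below models only that control flow, under these assumptions:

- FetchResult, FetchSuccess and FetchFailure are hypothetical counterparts of SuccessFetchResult and FailureFetchResult.
- readLocal is a hypothetical function standing in for BlockManager#getBlockData().
- Byte arrays replace ManagedBuffer, so there is no reference counting.

```scala
import java.util.concurrent.LinkedBlockingQueue

// Hypothetical counterparts of SuccessFetchResult / FailureFetchResult
sealed trait FetchResult
final case class FetchSuccess(blockId: String, bytes: Array[Byte]) extends FetchResult
final case class FetchFailure(blockId: String, error: Throwable) extends FetchResult

object LocalFetch {
  def fetchLocalBlocks(
      localBlocks: Seq[String],
      readLocal: String => Array[Byte],
      results: LinkedBlockingQueue[FetchResult]): Unit = {
    val iter = localBlocks.iterator
    var failed = false
    while (iter.hasNext && !failed) {
      val blockId = iter.next()
      try {
        results.put(FetchSuccess(blockId, readLocal(blockId)))
      } catch {
        case e: Exception =>
          // As in the Spark code: report the failure, then stop reading local blocks
          results.put(FetchFailure(blockId, e))
          failed = true
      }
    }
  }
}
```

Because of the early exit, a failure on one block means the blocks after it are never queued.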

The source of BlockManager#getBlockData() is:

    override def getBlockData(blockId: BlockId): ManagedBuffer = {
      if (blockId.isShuffle) {
        // First call ShuffleManager#shuffleBlockResolver to get the ShuffleBlockResolver,
        // then call its getBlockData() method
        shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
      } else {
        val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
          .asInstanceOf[Option[ByteBuffer]]
        if (blockBytesOpt.isDefined) {
          val buffer = blockBytesOpt.get
          new NioManagedBuffer(buffer)
        } else {
          throw new BlockNotFoundException(blockId.toString)
        }
      }
    }
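getBlockData() branches on blockId.isShuffle. Block IDs also travel as strings: sendRequest() sends blockId.toString and rebuilds each ID with BlockId(blockId). A shuffle block is named shuffle_<shuffleId>_<mapId>_<reduceId>.

The sketch below builds and parses that form. BlockNames and ParsedShuffleBlock are hypothetical helpers. Spark's own BlockId parsing recognizes many more ID kinds than the shuffle case shown here.

```scala
// Hypothetical parsed form of a shuffle block name
final case class ParsedShuffleBlock(shuffleId: Int, mapId: Int, reduceId: Int)

object BlockNames {
  // A regex extractor only matches when the whole string matches
  private val Shuffle = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r

  def shuffleName(shuffleId: Int, mapId: Int, reduceId: Int): String =
    s"shuffle_${shuffleId}_${mapId}_$reduceId"

  /** Some(ids) for shuffle block names, None for anything else (e.g. "rdd_2_5"). */
  def parseShuffleBlock(name: String): Option[ParsedShuffleBlock] = name match {
    case Shuffle(s, m, r) => Some(ParsedShuffleBlock(s.toInt, m.toInt, r.toInt))
    case _                => None
  }
}
```

The name carries everything a resolver needs. shuffleId and mapId select the map output, and reduceId selects the partition within it.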

org.apache.spark.shuffle.ShuffleManager#shuffleBlockResolver() returns the matching ShuffleBlockResolver:

- org.apache.spark.shuffle.FileShuffleBlockResolver for Hash Shuffle;
- org.apache.spark.shuffle.IndexShuffleBlockResolver for Sort Shuffle.

The resolver's getBlockData() then returns a FileSegmentManagedBuffer, a (file, offset, length) view of a segment of a shuffle file. The source of FileShuffleBlockResolver#getBlockData() is:

    override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
      // Files produced by Hash Shuffle's Shuffle Consolidate Files mechanism
      if (consolidateShuffleFiles) {
        // Search all file groups associated with this shuffle.
        val shuffleState = shuffleStates(blockId.shuffleId)
        val iter = shuffleState.allFileGroups.iterator
        while (iter.hasNext) {
          val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId)
          if (segmentOpt.isDefined) {
            val segment = segmentOpt.get
            return new FileSegmentManagedBuffer(
              transportConf, segment.file, segment.offset, segment.length)
          }
        }
        throw new IllegalStateException("Failed to find shuffle block: " + blockId)
      } else {
        // Files produced by plain Hash Shuffle: one file per block, read in full
        val file = blockManager.diskBlockManager.getFile(blockId)
        new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
      }
    }
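With consolidation enabled, several map tasks on an executor share one group of per-reducer files and append their output to them. A block is therefore located as an (offset, length) segment inside a shared file rather than as a whole file.

The sketch below is a simplified, hypothetical model of that bookkeeping:

- FileGroup, recordMapOutput, segmentFor and ConsolidatedLookup are illustrative names; segmentFor plays the role of getFileSegmentFor.
- Only offsets are tracked; no real files are written.

```scala
import scala.collection.mutable

// Hypothetical model: one group = one file per reducer, shared by several map tasks
final class FileGroup(numReducers: Int) {
  private val fileLengths = Array.fill(numReducers)(0L)
  // mapId -> (offset, length) of that map task's output in each reducer's file
  private val segments = mutable.HashMap.empty[Int, Array[(Long, Long)]]

  /** Appends one map task's per-reducer output sizes to the group's files. */
  def recordMapOutput(mapId: Int, sizes: Array[Long]): Unit = {
    require(sizes.length == numReducers, "one size per reducer expected")
    val segs = Array.tabulate(numReducers) { r =>
      val seg = (fileLengths(r), sizes(r)) // starts where the file currently ends
      fileLengths(r) += sizes(r)
      seg
    }
    segments(mapId) = segs
  }

  /** (offset, length) of block (mapId, reduceId) if this group holds it. */
  def segmentFor(mapId: Int, reduceId: Int): Option[(Long, Long)] =
    segments.get(mapId).map(_(reduceId))
}

object ConsolidatedLookup {
  // Like the loop in getBlockData(): search every file group of the shuffle
  def find(groups: Seq[FileGroup], mapId: Int, reduceId: Int): (Long, Long) =
    groups.iterator
      .map(_.segmentFor(mapId, reduceId))
      .collectFirst { case Some(seg) => seg }
      .getOrElse(throw new IllegalStateException(
        s"Failed to find shuffle block: map $mapId, reduce $reduceId"))
}
```

The search in find() corresponds to the while loop over allFileGroups, including the IllegalStateException when no group holds the block.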

The source of IndexShuffleBlockResolver#getBlockData() is:

    override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
      // The block is actually going to be a range of a single map output file for this map, so
      // find out the consolidated file, then the offset within that from our index
      // Use shuffleId and mapId to locate the index file
      val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

      val in = new DataInputStream(new FileInputStream(indexFile))
      try {
        // Skip to this block's entry: each offset is an 8-byte long
        ByteStreams.skipFully(in, blockId.reduceId * 8)
        // Start offset of the block in the data file
        val offset = in.readLong()
        // End offset of the block (the start of the next one)
        val nextOffset = in.readLong()
        // Return the file segment
        new FileSegmentManagedBuffer(
          transportConf,
          getDataFile(blockId.shuffleId, blockId.mapId),
          offset,
          nextOffset - offset)
      } finally {
        in.close()
      }
    }
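The index lookup follows from the index file's layout. For one map task the file stores numPartitions + 1 longs: a leading 0, then the running total of the partition lengths. Reduce partition r therefore occupies [offsets(r), offsets(r + 1)) of the data file.

The sketch below writes such an index into memory and reads a segment back with the same steps as the code above: skip reduceId * 8 bytes, then read two longs. It assumes the following:

- writeIndex and segmentFor are hypothetical helpers.
- DataInputStream.skipBytes with a length check stands in for Guava's ByteStreams.skipFully.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object IndexFileSketch {
  /** Index layout: 0, then the running total of partition lengths (numPartitions + 1 longs). */
  def writeIndex(partitionLengths: Array[Long]): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    try {
      var offset = 0L
      out.writeLong(offset)
      for (length <- partitionLengths) {
        offset += length
        out.writeLong(offset)
      }
    } finally {
      out.close()
    }
    bytes.toByteArray
  }

  /** (offset, length) of reduceId's block in the data file, read like getBlockData() does. */
  def segmentFor(index: Array[Byte], reduceId: Int): (Long, Long) = {
    val in = new DataInputStream(new ByteArrayInputStream(index))
    try {
      // Each entry is one 8-byte long, so entry reduceId starts at byte reduceId * 8
      val toSkip = reduceId * 8
      require(in.skipBytes(toSkip) == toSkip, s"Index too short for reduceId $reduceId")
      val offset = in.readLong()
      val nextOffset = in.readLong()
      (offset, nextOffset - offset)
    } finally {
      in.close()
    }
  }
}
```

Take partition lengths 100, 0 and 50 as an example. The index holds 0, 100, 100, 150. Reading reduceId = 2 skips 16 bytes and reads 100 and 150, which gives a 50-byte segment at offset 100. The empty partition (reduceId = 1) yields a zero-length segment.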

ShuffleBlockFetcherIterator#sendRequest()

sendRequest() fetches data from remote executors:

    private[this] def sendRequest(req: FetchRequest) {
      logDebug("Sending request for %d blocks (%s) from %s".format(
        req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
      bytesInFlight += req.size

      // so we can look up the size of each blockID
      val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
      val blockIds = req.blocks.map(_._1.toString)

      val address = req.address
      // Fetch the data with ShuffleClient#fetchBlocks().
      // There are two kinds of ShuffleClient: ExternalShuffleClient and BlockTransferService;
      // BlockTransferService is the default
      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
        new BlockFetchingListener {
          override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
            // Only add the buffer to results queue if the iterator is not zombie,
            // i.e. cleanup() has not been called yet.
            if (!isZombie) {
              // Increment the ref count because we need to pass this to a different thread.
              // This needs to be released after use.
              buf.retain()
              results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
              shuffleMetrics.incRemoteBytesRead(buf.size)
              shuffleMetrics.incRemoteBlocksFetched(1)
            }
            logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
          }

          override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
            logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
            results.put(new FailureFetchResult(BlockId(blockId), address, e))
          }
        }
      )
    }
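sendRequest() increments bytesInFlight, but the decision of when to call it is made elsewhere in ShuffleBlockFetcherIterator. This sketch assumes the Spark 1.5 loop in initialize() and next() behaves as follows:

- It keeps sending queued requests while bytesInFlight == 0 || bytesInFlight + nextRequest.size <= maxBytesInFlight.
- It decreases bytesInFlight as fetched blocks are consumed.

Under that assumption, the throttling can be simulated as below. ThrottleSim is hypothetical, and it frees a whole request at a time where Spark frees bytes per block.

```scala
import scala.collection.mutable

object ThrottleSim {
  /** Request sizes sent in each round. A round sends as much as the budget allows;
    * then the oldest in-flight request completes and frees its bytes. */
  def rounds(requestSizes: Seq[Long], maxBytesInFlight: Long): List[List[Long]] = {
    val pending = mutable.Queue.empty[Long]
    pending ++= requestSizes
    val inFlight = mutable.Queue.empty[Long]
    var bytesInFlight = 0L
    val result = mutable.ListBuffer.empty[List[Long]]
    while (pending.nonEmpty) {
      val sent = mutable.ListBuffer.empty[Long]
      // Assumed send condition: one request is always allowed when nothing is in flight
      while (pending.nonEmpty &&
        (bytesInFlight == 0 || bytesInFlight + pending.head <= maxBytesInFlight)) {
        val size = pending.dequeue()
        bytesInFlight += size // the bookkeeping sendRequest() does
        inFlight.enqueue(size)
        sent += size
      }
      if (sent.nonEmpty) result += sent.toList
      // Simplification: the oldest request completes as a whole
      if (inFlight.nonEmpty) bytesInFlight -= inFlight.dequeue()
    }
    result.toList
  }
}
```

The bytesInFlight == 0 clause is what lets a single request larger than maxBytesInFlight go out at all. Without it, such a request would never be sent.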

As the code shows, remote blocks are fetched through shuffleClient.fetchBlocks(). The ShuffleClient implementation used is either ExternalShuffleClient or a BlockTransferService (NettyBlockTransferService or NioBlockTransferService). The choice is made where shuffleClient is defined, in org.apache.spark.storage.BlockManager:

    // shuffleClient as defined in org.apache.spark.storage.BlockManager
    private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
      // Fetch remote blocks through the external shuffle service (ExternalShuffleClient)
      val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
      new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
        securityManager.isSaslEncryptionEnabled())
    } else {
      // Fetch remote blocks with NettyBlockTransferService or NioBlockTransferService
      blockTransferService
    }

The blockTransferService used above is initialized in SparkEnv:

    // Initialization of blockTransferService in org.apache.spark.SparkEnv
    val blockTransferService =
      conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
        case "netty" =>
          new NettyBlockTransferService(conf, securityManager, numUsableCores)
        case "nio" =>
          logWarning("NIO-based block transfer service is deprecated, " +
            "and will be removed in Spark 1.6.0.")
          new NioBlockTransferService(conf, securityManager)
      }
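Taken together, the BlockManager and SparkEnv snippets decide which client fetches remote blocks. The function below restates that decision over a plain Map[String, String] of settings. It is a sketch with these assumptions:

- chooseShuffleClient is hypothetical and returns class names instead of constructing clients.
- externalShuffleServiceEnabled is read from spark.shuffle.service.enabled (default false).
- Unknown transfer-service names are rejected explicitly, where the original match would throw a MatchError.

```scala
object ShuffleClientChoice {
  def chooseShuffleClient(conf: Map[String, String]): String = {
    // Assumed setting behind BlockManager's externalShuffleServiceEnabled
    val externalEnabled = conf.getOrElse("spark.shuffle.service.enabled", "false").toBoolean
    if (externalEnabled) {
      "ExternalShuffleClient"
    } else {
      // Mirrors the match in SparkEnv; "netty" is the default
      conf.getOrElse("spark.shuffle.blockTransferService", "netty").toLowerCase match {
        case "netty" => "NettyBlockTransferService"
        case "nio"   => "NioBlockTransferService" // deprecated
        case other   => throw new IllegalArgumentException(s"Unknown block transfer service: $other")
      }
    }
  }
}
```

The external shuffle service takes precedence. Once it is enabled, spark.shuffle.blockTransferService no longer affects which client fetches shuffle blocks.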