消息关闭
    暂无新消息!
最近在学习用spark 的GraphX框架实现近邻传播聚类算法(AP)的并行化,但是代码写好后,迭代次数设置>30,运行中就会报错java.lang.stackoverflowerror,我之前查了查,有可能是迭代次数过多导致lineage过长,但是checkpoint并没有效果,该报错还是报错。在迭代次数设置不多就可以跑成功。本人是spark1.6.0环境,local模式。求有相关经验的人看看,不胜感激!

算法主体代码:
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.graphx.{Graph, TripletFields, VertexId}

/**
  * Created by LCJ on 2017.3.19.
  */
class AP (
           val graphInput: Graph[VertexData, EdgeData],
           val lambda: Double, val iterations: Int, val threshold: Int,
           val a: Double, val r: Double
         )
  extends Serializable {

  private var graph = this.graphInput
  private val lam = this.lambda
  private val maxIterNum = this.iterations
  private val thresholdNum = this.threshold
  private val avaiInitial = this.a
  private val respInitial = this.r


  private def checkOutputPath(path: String): Unit = {
    val fs = FileSystem.get(new Configuration())
    if (fs.exists(new Path(path))) {
      fs.delete(new Path(path), true)
    }
  }


  private def balance(valuePrevious: Double, valueNow: Double): Double = {
    lam * valuePrevious + (1 - lam) * valueNow
  }


  private def getExemplars(g: Graph[VertexData, EdgeData]): Set[VertexId] = {
    g.aggregateMessages[(VertexId, Double)](
      sendMsg = s => s.sendToSrc((s.dstId, s.attr.avai + s.attr.resp)),
      mergeMsg = (a, b) => if (a._2 > b._2) a else b,
      TripletFields.EdgeOnly
    ).map(v => v._2._1).collect().toSet
  }


  def run(): Unit = {

    var prevG: Graph[VertexData, EdgeData] = null

    var centers = Set[VertexId]()

    var countForThreshold = 0   // 聚类中心不改变(算法收敛)迭代次数计数
    var flag = true
    var iterCount = 0   // 总迭代次数计数

    for (_ <- 1 to maxIterNum if flag) {
      // 必须先用prevG保留住对原来图的引用, 并在新图产生后, 快速将旧图彻底释放掉.
      // 否则, 十几轮迭代后, 会有内存泄漏问题, 很快耗光作业缓存空间
      prevG = graph

      // 更新r
      val updating_r = graph.aggregateMessages[Seq[Double]](
        sendMsg = s => s.sendToSrc(Seq(s.attr.similarity + s.attr.avai)),
        mergeMsg = (a, b) => a ++ b,
        TripletFields.EdgeOnly
      )

      val updated_r = Graph(updating_r, graph.edges)
        .mapTriplets(t => {
          val filtered = t.srcAttr.filter(_ != (t.attr.similarity + t.attr.avai))
          val pool =
            if (filtered.size < t.srcAttr.size - 1) filtered :+ (t.attr.similarity + t.attr.avai)
            else filtered
          val maxValue = if (pool.isEmpty) 0.0 else pool.max
          EdgeData(t.attr.similarity, t.attr.avai, balance(t.attr.resp, t.attr.similarity - maxValue))
        }, TripletFields.Src)

      graph = Graph.fromEdges(updated_r.edges, VertexData(avaiInitial, respInitial))

      // 更新a
      val updating_a = graph.aggregateMessages[Double](
        sendMsg = s => {
          if (s.srcId != s.dstId) s.sendToDst(math.max(s.attr.resp, 0.0))
          else s.sendToDst(s.attr.resp)
        },
        mergeMsg = (a, b) => a + b,
        TripletFields.EdgeOnly
      )

      val updated_a = Graph(updating_a, graph.edges)
        .mapTriplets(t => {
          if (t.srcId != t.dstId) {
            val a = balance(
              t.attr.avai,
              math.min(0.0, t.dstAttr - math.max(t.attr.resp, 0.0))
            )
            EdgeData(t.attr.similarity, a, t.attr.resp)
          }
          else {
            val a = balance(
              t.attr.avai,
              t.dstAttr - t.attr.resp
            )
            EdgeData(t.attr.similarity, a, t.attr.resp)
          }
        }, TripletFields.Dst)

      graph = Graph.fromEdges(updated_a.edges, VertexData(avaiInitial, respInitial)).persist()

      iterCount += 1

      // 每次更新r和a后判断聚类中心有无变化
      val centersTmp = getExemplars(graph)
      if (centers == centersTmp) {
        countForThreshold += 1
        if (countForThreshold == thresholdNum) {
          flag = false
          println("Break!")
        }
      }
      else {
        centers = centersTmp
        countForThreshold = 0
      }

      if (iterCount % 6 == 0) {
        graph.cache()
        graph.checkpoint()
        println(graph.numVertices)
      }

      prevG.unpersistVertices()
      prevG.edges.unpersist()

    }

    println("算法总迭代次数: " + iterCount)
    println("聚类中心不改变次数: " + countForThreshold)
    println("Exemplars: " + centers)

    // 确定每个点到聚类中心的分配情况
    val clusterInfo = graph.aggregateMessages[(VertexId, Double, Double)](
      sendMsg = s => s.sendToSrc((s.dstId, s.attr.similarity, s.attr.avai + s.attr.resp)),
      mergeMsg = (a, b) => if (a._3 > b._3) a else b,
      TripletFields.EdgeOnly
    ).persist()

    // 将点的分配存入文本
    checkOutputPath("member-exemplar")
    clusterInfo.mapValues(s => s._1).saveAsTextFile("member-exemplar")

    // 计算总误差平方和
    val WSSSE = clusterInfo.map(e => if (e._1 == e._2._1) 0.0 else math.pow(e._2._2, 2)).sum()
    println("WSSSE: " + WSSSE)

    clusterInfo.unpersist()
    graph.unpersist()

  }

}
报错位置在我红色的地方那个collect action操作。谢谢各位大神!

5个回答

︿ 1
If you gave me a case can be reproduced locally, I will take a look. Are you running Spark 1.6.3? 
︿ 1
检查下graph, 使用spark on yarn去执行, 让yarn帮助资源调配
然后有问题再进行调试