// File: fancadescala/tagless/src/main/scala/tf/bug/fancadetagless/PointerChain.scala
// (original listing metadata: 53 lines, 1.5 KiB, Scala)

package tf.bug.fancadetagless
import cats._
import cats.implicits._
import cats.data.State
import higherkindness.droste._
import higherkindness.droste.data._
import higherkindness.droste.scheme
/** A deduplicated graph of `E`-shaped nodes together with the entry points into it.
  *
  * When produced by `PointerChain.deduplicate`, `vertices(i)` is the node that was assigned
  * index `i`, with its children replaced by their own indices into `vertices`.
  *
  * @param vertices the node table, ordered by assigned index
  * @param roots    the index of each root, laid out in the container shape `T`
  */
case class PointerChain[T[_], E[_]](
  vertices: Vector[E[Int]],
  roots: T[Int]
)
/** Bookkeeping threaded through `PointerChain.findOrAddNode` while nodes are being interned.
  *
  * @param seen      every node interned so far, mapped to the index it was assigned
  * @param vertices  the interned nodes, most recently added first
  * @param lastIndex the index the next newly interned node will receive
  */
case class PointerState[F[_]](
  seen: Map[F[Int], Int],
  vertices: Vector[F[Int]],
  lastIndex: Int
)
object PointerChain {

  /** Interns every tree in `input` into one shared, deduplicated vertex table.
    *
    * Each tree is folded bottom-up with `scheme.cataM` using [[findOrAddNode]]. A single
    * `PointerState` (starting from `PointerState.empty`) is threaded through the whole
    * traversal of `input`. As a result, equal subtrees receive the same index, whether they
    * occur within one tree or across different roots.
    *
    * @param input the trees to deduplicate, held in a traversable container `R`
    * @return a chain whose `vertices` are ordered by assigned index and whose `roots` hold the
    *         index of each input tree, in the same shape as `input`
    */
  def deduplicate[R[_], F[_]](input: R[Fix[F]])(
    implicit rt: Traverse[R],
    rf: Traverse[F]
  ): PointerChain[R, F] = {
    val interning: State[PointerState[F], R[Int]] =
      input.traverse(scheme.cataM(findOrAddNode[F]))

    val (finalState, rootIndices) = interning.run(PointerState.empty[F]).value

    // findOrAddNode prepends each new node, so reversing restores ascending index order.
    PointerChain[R, F](finalState.vertices.reverse, rootIndices)
  }

  /** Monadic algebra that interns one node whose children have already been replaced by indices.
    *
    * If an equal node is already in `seen`, its existing index is returned and the state is
    * left untouched. Otherwise the node is given index `lastIndex` and recorded in `seen`. It is
    * also prepended to `vertices`, and `lastIndex` is incremented.
    *
    * NOTE(review): node equality is whatever `F[Int]`'s `equals`/`hashCode` provide, because
    * nodes are used as `Map` keys.
    */
  def findOrAddNode[F[_]]: AlgebraM[State[PointerState[F], *], F, Int] =
    AlgebraM { node =>
      State.get[PointerState[F]].flatMap { state =>
        state.seen.get(node).fold {
          val fresh = state.lastIndex
          val grown = PointerState[F](
            state.seen + (node -> fresh),
            node +: state.vertices,
            fresh + 1
          )
          State.set(grown).map(_ => fresh)
        }(State.pure[PointerState[F], Int](_))
      }
    }
}
object PointerState {

  /** The initial state: nothing interned yet, so the first new node gets index 0. */
  def empty[F[_]]: PointerState[F] =
    PointerState[F](seen = Map.empty, vertices = Vector.empty, lastIndex = 0)
}