Haskellでラマヌジャン数

見た目がきれいなバージョンがあります。
https://qiita.com/zeal1483/private/69689dd8a68903932170

第4回『Haskellによる関数プログラミングの思考法』読書会に参加してきました。
https://sampou.connpass.com/event/66242/
そこで、ラマヌジャン数をhaskellで求める方法として以下が上がりました。
http://takeichi.ipl-lab.org/~onoue/pub/jssst01.pdf

10 -- Finding Ramanujan numbers
11
12 ramanujan :: [((Int,Int), (Int,Int))]
13 ramanujan = [(x,y)|(x,y)<-zip s (tail s),
14     sumcubes x == sumcubes y]
15   where s = fsort sumcubes 1
16
17 sumcubes :: (Int,Int) -> Int
18 sumcubes (a,b) = a^3 + b^3
20
21 fsort :: ((Int,Int) -> Int) -> Int
22 -> [(Int,Int)]
23 fsort r k
24 = (k,k) : fmerge r [(k,b)|b<-[k+1..]]
25 (fsort r (k+1))
26
27 {-
28 Two argument lists should be infinite,
29 so there is no definition for null list.
30 -}
31
32 fmerge :: (a->Int) -> [a] -> [a] -> [a]
33 fmerge r (x:xs) (y:ys)
34   | r x <= r y = x : fmerge r xs (y:ys)
35   | otherwise = y : fmerge r (x:xs) ys

しかし、下の様なナイーブな方法よりは、早いですが
1000個のラマヌジャン数を求めるなどすると、遅いことがわかります。

rmNum n = [(e,a,b,c,d)|e<-[28..2*n^3]
                      ,a<-[1..n]
                      ,b<-[a+1..n]
                      ,e == a^3 + b^3
                      ,c<-[a+1..n]
                      ,d<-[c..n]
                      ,e == c^3 + d^3]

遅い理由は、下記のフィボナッチ数列の求め方が遅い理由と同じように思えます。

fib n = fibs !! n
fibs = 0 : 1 : zipWith (+) fibs (tail fibs)

Haskellの神話
http://d.hatena.ne.jp/kazu-yamamoto/20100624/1277348961
フィボナッチ数列の遅い理由は、上記リンクによると
計算(サンク?、算術木?)が消せないこととあります。

(この遅い理由が、計算が消せないとなっていることについては
自分には、なぜそうなのかが、明瞭には分からないです。
空間計算量や、時間計算量ならば、直感的に分かります。

例えば、メモリの専有領域の問題なら、限界を超えると
スワップが発生し、ディスクIOが保存と取込で2重に発生するから
速いメモリとのやり取りから、遅いディスクとのやり取りが加わり
そのスピード差から、圧倒的に速度が違うことは、明らかです。

また、時間計算量の話は、計算に物理的に時間がかかるので
これも、明らか。
最初のナイーブな書き方で遅い理由だと思われる。

計算が消せないと、なぜ遅くなるのかのメカニズムが見えない。)

これとは別に問題点として
a^2+b^3=c^3+d^3=e^3+f^3=kとなるような
要件を満たす(a,b),(c,d),(e,f)とkが存在することがあります。
例えば、k=87539319,(167,436),(228,423),(255,414)
連続する2つの組で捉える方法では、これを捉えられません。
(これは、groupByやgroupを使用することで、簡単に対応できます。)

読書会の時に、山下さんが、言われた
x+y=kをk=[1..]として進めていく方法の方が
実装してみると明らかに速く
1000個のラマヌジャン数を求めても
苦にはならないくらい十分にスピードがあります。

ここで、実装に入る前に、実装方針の説明をしてみます。

単に、x+y=kをk=[1..]で進めていくと、大小関係がわからない。
(読書会中に、議論になっていたとおもう)

そこで、区間を(1,2]、(2,4]、(4,8]、(8,16]、(16,32]...と区切っていく
その内の(n,2n]と(2n,4n]を考える。
そうすると、下図のように大小関係がはっきりする。

図の(n,2n]上のラマヌジャン数を考えると
点線x^3+y^3=n〜x^3+y^3=2nの範囲を考えれば良い

A:∫(x,y)dk, x+y=k、ただし積分範囲( n,2n]
B:∫(x,y)dk, x+y=k、ただし積分範囲(2n,4n]
f(x,y)=x^3+y^3
とおく

{(x,y)|(x,y)∈A∪B} <= MAX( {f(x,y)|(x,y)∈A} )
となる(x,y)を取れば、図の②を上限とする領域がとれる。

下限の④については、Aの1つ下の階層O(オー)を考えればよい。

O:∫(x,y)dk, x+y=k、ただし積分範囲(n/2,n] (nが偶数の場合)
O:x+y=0、x^3+y^3=0として、下限値0をとる (nが奇数、すなわち1の場合)

これに対して、下記のように取れば、下限を押さえられる
{(x,y)|(x,y)∈A∪B} > MAX( {f(x,y)|(x,y)∈O} )


以上から、上図の斜線部を1単位として

n = [2^k| k<-[1..]]

を展開して、足し合わせることによって
ラマヌジャン数の探索範囲を敷き詰めることができる。

実装は、以下の通り。
makeSearchAreaで、上図の斜線部を求めている。

import Data.List

-----------------------------------------
-- x + y = n 上の正数値列挙
-- > line 8
-- > [(1,7),(2,6),(3,5),(4,4)]
-----------------------------------------
line n = [(a,n-a)|a<-[1..n-1],a<=n-a]

------------------------------------------------------------------
-- 高さ(a^3 + b^3)付与
-- > estLine 8
-- > [(344,1,7),(224,2,6),(152,3,5),(128,4,4)]
------------------------------------------------------------------
estLine n = map cubeNumNum (line n)
  where cubeNumNum (a,b) = (sumCubes (a,b), a, b)

------------------------------------------------------------------
-- 2n〜4nの区間の等高線を束ねて面にする
-- (これを等高線の層と言う事にする)
-- > sekibun (2, 4)
-- > [(9,1,2),(16,2,2),(28,1,3)]
-- s: 2^n (2の指数乗)
-- e: 2*s
------------------------------------------------------------------
sekibun (s, e) = sort (concatMap estLine [s+1..e])

------------------------------------------------------------------
-- 探索領域をつくる
-- > makeSearchArea 1
-- > [(9,1,2),(16,2,2),(28,1,3)]
-- > makeSearchArea 2
-- > [(35,2,3),(54,3,3),(65,1,4),(72,2,4),(91,3,4),(126,1,5),(128,4,4),(133,2,5),(152,3,5),(189,4,5),(217,1,6),(224,2,6),(243,3,6),(250,5,5),(280,4,6),(341,5,6),(344,1,7)]
------------------------------------------------------------------
makeSearchArea n = filter (\(x,y,z) -> min < x && x <= max) concat2Part
  where -- 上層・中層・下層
        fstPart = sekibun (n,   2*n)
        mdlPart = sekibun (2*n, 4*n)
        lstPart = sekibun (4*n, 8*n)
        -- 上界・下界
        min = fst3 (last fstPart)
        max = fst3 (last mdlPart)
        -- 連続した上位2つの等高線の層を束ね、整列する
        concat2Part = sort (mdlPart ++ lstPart)

------------------------------------------------------------------
-- ラマニュジャン数(等高線の層の指定)
-- > rmNumRange 2
-- > []
-- > rmNumRange 4
-- > [[(1729,1,12),(1729,9,10)]]
-- > rmNumRange 8
-- > [[(4104,2,16),(4104,9,15)],[(13832,2,24),(13832,18,20)],[(20683,10,27),(20683,19,24)]]
------------------------------------------------------------------
rmNumRange n = getRmNums (makeSearchArea n)

------------------------------------------------------------------
-- searchArea上のラマヌジャン数取得
------------------------------------------------------------------
getRmNums searchArea = filter isDuplicate (groupBy eq searchArea)
  where -- 等高線上の点か ?
        eq p q = fst3 p == fst3 q
        -- 複数要素があるか ?
        isDuplicate xs = length xs > 1

------------------------------------------------------------------
-- ラマニュジャン数
-- > take 1 rmNums
-- > [[(1729,1,12),(1729,9,10)]]
-- > take 2 rmNums
-- > [[(1729,1,12),(1729,9,10)],[(4104,2,16),(4104,9,15)]]
-- > take 3 rmNums
-- > [[(1729,1,12),(1729,9,10)],[(4104,2,16),(4104,9,15)],[(13832,2,24),(13832,18,20)]]
------------------------------------------------------------------
rmNums = concatMap rmNumRange ninoshisu



------------------------------------------------------------------
-- 補助関数
-- 2の指数の系列を生成
-- > ninoshisu  8
-- > [2,4,8,16,32,64,128,256...]
------------------------------------------------------------------
ninoshisu = map (\m -> 2^m) [1..]

------------------------------------------------------------------
-- 補助関数
------------------------------------------------------------------
fst3 (x,y,z) = x
sumCubes (a,b) = a^3 + b^3

上層・中層・下層の3つを、それぞれで求めているので
無駄に見えるかもしれない。

  where -- 上層・中層・下層
        fstPart = sekibun (n,   2*n)
        mdlPart = sekibun (2*n, 4*n)
        lstPart = sekibun (4*n, 8*n)

iterateを使えば、上層を、それぞれで求めることになるが
体感的には、速度は上がらないようです。

import Data.List

-----------------------------------------
-- x + y = n 上の正数値列挙
-- > line 8
-- > [(1,7),(2,6),(3,5),(4,4)]
-----------------------------------------
line n = [(a,n-a)|a<-[1..n-1],a<=n-a]

------------------------------------------------------------------
-- 高さ(a^3 + b^3)付与
-- > estLine 8
-- > [(344,1,7),(224,2,6),(152,3,5),(128,4,4)]
------------------------------------------------------------------
estLine n = map cubeNumNum (line n)
  where cubeNumNum (a,b) = (sumCubes (a,b), a, b)

------------------------------------------------------------------
-- 等高線を束ねて、面にする(これを等高線の層と言う事にする)
-- > sekibun (2, 4)
-- > [(9,1,2),(16,2,2),(28,1,3)]
------------------------------------------------------------------
sekibun (s, e) = sort (concatMap estLine [s+1..e])

------------------------------------------------------------------
-- 次のラマヌジャン数取得
-- > get6th $ next (1, 0, 0, [], [], [])
-- > []
-- > get6th $ next $ next (1, 0, 0, [], [], [])
-- > []
-- > get6th $ next $ next $ next (1, 0, 0, [], [], [])
-- > [[(1729,1,12),(1729,9,10)]]
-- > get6th $ next $ next $ next $ next (1, 0, 0, [], [], [])
-- > [[(4104,2,16),(4104,9,15)],[(13832,2,24),(13832,18,20)],[(20683,10,27),(20683,19,24)]]
------------------------------------------------------------------
next (n, min, max, fstPnl, mdlPnl, _) = (2*n, max, nextMax, mdlPnl, lstPnl, rmNums)
  where lstPnl = sekibun (4*n, 8*n)
        -- 次の上界
        nextMax = fst3 (last lstPnl)
        -- x^3 + y^3の上下の、上界の等高線を考慮して検索エリア作成
        searchArea = makeSearchArea (min, max) mdlPnl lstPnl
        -- searchArea上のラマヌジャン数
        rmNums = getRmNums searchArea

------------------------------------------------------------------
-- 補助関数
-- 上位2層を結合して、区間で絞り込む
------------------------------------------------------------------
makeSearchArea (min, max) mdlPnl lstPnl = filter p (sort (mdlPnl ++ lstPnl))
  where p (x,y,z) = min < x && x <= max

------------------------------------------------------------------
-- 補助関数
-- searchArea上のラマヌジャン数取得
------------------------------------------------------------------
getRmNums searchArea = filter isDuplicate (groupBy eq searchArea)
  where -- 等高線上の点か ?
        eq p q = fst3 p == fst3 q
        -- 複数要素があるか ?
        isDuplicate xs = length xs > 1

------------------------------------------------------------------
-- ラマニュジャン数
-- > take 1 rmNums
-- > [[(1729,1,12),(1729,9,10)]]
-- > take 2 rmNums
-- > [[(1729,1,12),(1729,9,10)],[(4104,2,16),(4104,9,15)]]
-- > take 3 rmNums
-- > [[(1729,1,12),(1729,9,10)],[(4104,2,16),(4104,9,15)],[(13832,2,24),(13832,18,20)]]
------------------------------------------------------------------
rmNums = concatMap get6th $ drop 3 $ iterate next (1, 0, 0, [], [], [])
get6th (n, min, max, fstPnl, mdlPnl, rmNums) = rmNums

------------------------------------------------------------------
-- 補助関数
------------------------------------------------------------------
fst3 (x,y,z) = x
sumCubes (a,b) = a^3 + b^3