# whereblocks.py (3.4 KB)

from __future__ import print_function

import cProfile
import functools
import inspect
import itertools
import time

import numpy as np
import numexpr as ne
import bcolz
import numba
import pyorcy

from cython_condition import condition as condition_cython_
  12. print("numexpr version:", ne.__version__)
  13. bcolz.defaults.cparams['shuffle'] = bcolz.SHUFFLE
  14. #bcolz.defaults.cparams['shuffle'] = bcolz.BITSHUFFLE
  15. bcolz.defaults.cparams['cname'] = 'blosclz'
  16. #bcolz.defaults.cparams['cname'] = 'lz4'
  17. bcolz.defaults.cparams['clevel'] = 5
  18. #bcolz.defaults.vm = "dask"
  19. #bcolz.defaults.vm = "python"
  20. bcolz.defaults.vm = "numexpr"
  21. N = 1e8
  22. LMAX = 1e3
  23. npa = np.arange(N, dtype=np.float64)
  24. npb = np.arange(N, dtype=np.float64)
  25. ct = bcolz.ctable([npa, npb], names=["a", "b"])
  26. def do_cprofile(func):
  27. def profiled_func(*args, **kwargs):
  28. profile = cProfile.Profile()
  29. try:
  30. profile.enable()
  31. result = func(*args, **kwargs)
  32. profile.disable()
  33. return result
  34. finally:
  35. profile.print_stats(sort='cumulative')
  36. return profiled_func
  37. def timefunc(f):
  38. def f_timer(*args, **kwargs):
  39. start = time.time()
  40. result = f(*args, **kwargs)
  41. end = time.time()
  42. print(f.__name__, 'took', round(end - start, 3), 'sec')
  43. return result
  44. return f_timer
  45. @timefunc
  46. def where_numpy():
  47. return sum(npa[i] for i in np.where((npa > 5) & (npb < LMAX))[0])
  48. @timefunc
  49. def where_numexpr():
  50. return sum(npa[i] for i in np.where(
  51. ne.evaluate('(npa > 5) & (npb < LMAX)'))[0])
  52. @timefunc
  53. #@do_cprofile
  54. def bcolz_where():
  55. return sum(r.a for r in ct.where("(a > 5) & (b < LMAX)"))
  56. @timefunc
  57. #@do_cprofile
  58. def bcolz_where_numpy():
  59. return sum(r.a for r in ct.where("(npa > 5) & (npb < LMAX)"))
  60. @timefunc
  61. #@do_cprofile
  62. def bcolz_where_numexpr():
  63. return sum(r.a for r in ct.where(ne.evaluate("(npa > 5) & (npb < LMAX)")))
  64. @timefunc
  65. #@do_cprofile
  66. def whereblocks():
  67. sum = 0.
  68. for r in ct.whereblocks("(a > 5) & (b < LMAX)", blen=None):
  69. sum += r['a'].sum()
  70. return sum
  71. @timefunc
  72. #@do_cprofile
  73. def fetchwhere_bcolz():
  74. return ct.fetchwhere("(a > 5) & (b < LMAX)", out_flavor='bcolz')['a'].sum()
  75. @timefunc
  76. #@do_cprofile
  77. def fetchwhere_numpy():
  78. return ct.fetchwhere("(a > 5) & (b < LMAX)", out_flavor='numpy')['a'].sum()
  79. @timefunc
  80. #@do_cprofile
  81. def fetchwhere_dask():
  82. result = ct.fetchwhere("(a > 5) & (b < LMAX)", vm="dask")['a'].sum()
  83. return result
  84. #@numba.jit
  85. def condition(a, b):
  86. return (a > 5) & (b < LMAX)
  87. def condition_cython(a, b):
  88. pyorcy.USE_CYTHON = True
  89. #pyorcy.VERBOSE = True
  90. return condition_cython_(a, b)
  91. @timefunc
  92. #@do_cprofile
  93. def fetchwhere_func():
  94. #result = ct.fetchwhere(condition)['a'].sum()
  95. result = ct.fetchwhere(lambda a,b: (a > 5) & (b < LMAX))['a'].sum()
  96. return result
  97. @timefunc
  98. #@do_cprofile
  99. def fetchwhere_func_cython():
  100. result = ct.fetchwhere(condition_cython)['a'].sum()
  101. return result
  102. print(repr(ct))
  103. a0 = where_numpy()
  104. print("a0:", a0)
  105. # a1 = where_numexpr()
  106. # assert a0 == a1
  107. # a1 = bcolz_where()
  108. # assert a0 == a1
  109. # a1 = bcolz_where_numpy()
  110. # assert a0 == a1
  111. # a1 = bcolz_where_numexpr()
  112. # assert a0 == a1
  113. # a1 = whereblocks()
  114. # assert a0 == a1
  115. a1 = fetchwhere_bcolz()
  116. assert a0 == a1
  117. # a1 = fetchwhere_numpy()
  118. # assert a0 == a1
  119. a1 = fetchwhere_dask()
  120. assert a0 == a1
  121. a1 = fetchwhere_func()
  122. assert a0 == a1
  123. a1 = fetchwhere_func_cython()
  124. assert a0 == a1