whereblocks.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. from __future__ import print_function
  2. import itertools
  3. import numpy as np
  4. import numexpr as ne
  5. import bcolz
  6. import time
  7. import cProfile
  8. import inspect
  9. print("numexpr version:", ne.__version__)
  10. bcolz.defaults.cparams['shuffle'] = bcolz.SHUFFLE
  11. #bcolz.defaults.cparams['shuffle'] = bcolz.BITSHUFFLE
  12. bcolz.defaults.cparams['cname'] = 'blosclz'
  13. #bcolz.defaults.cparams['cname'] = 'lz4'
  14. bcolz.defaults.cparams['clevel'] = 5
  15. #bcolz.defaults.vm = "dask"
  16. #bcolz.defaults.vm = "python"
  17. bcolz.defaults.vm = "numexpr"
  18. N = 1e8
  19. LMAX = 1e3
  20. npa = np.arange(N)
  21. npb = np.arange(N)
  22. ct = bcolz.ctable([npa, npb], names=["a", "b"])
  23. def do_cprofile(func):
  24. def profiled_func(*args, **kwargs):
  25. profile = cProfile.Profile()
  26. try:
  27. profile.enable()
  28. result = func(*args, **kwargs)
  29. profile.disable()
  30. return result
  31. finally:
  32. profile.print_stats(sort='cumulative')
  33. return profiled_func
  34. def timefunc(f):
  35. def f_timer(*args, **kwargs):
  36. start = time.time()
  37. result = f(*args, **kwargs)
  38. end = time.time()
  39. print(f.__name__, 'took', round(end - start, 3), 'sec')
  40. return result
  41. return f_timer
  42. @timefunc
  43. def where_numpy():
  44. return sum(npa[i] for i in np.where((npa > 5) & (npb < LMAX))[0])
  45. @timefunc
  46. def where_numexpr():
  47. return sum(npa[i] for i in np.where(
  48. ne.evaluate('(npa > 5) & (npb < LMAX)'))[0])
  49. @timefunc
  50. #@do_cprofile
  51. def bcolz_where():
  52. return sum(r.a for r in ct.where("(a > 5) & (b < LMAX)"))
  53. @timefunc
  54. #@do_cprofile
  55. def bcolz_where_numpy():
  56. return sum(r.a for r in ct.where("(npa > 5) & (npb < LMAX)"))
  57. @timefunc
  58. #@do_cprofile
  59. def bcolz_where_numexpr():
  60. return sum(r.a for r in ct.where(ne.evaluate("(npa > 5) & (npb < LMAX)")))
  61. @timefunc
  62. #@do_cprofile
  63. def whereblocks():
  64. sum = 0.
  65. for r in ct.whereblocks("(a > 5) & (b < LMAX)", blen=None):
  66. sum += r['a'].sum()
  67. return sum
  68. @timefunc
  69. #@do_cprofile
  70. def fetchwhere_bcolz():
  71. return ct.fetchwhere("(a > 5) & (b < LMAX)", out_flavor='bcolz')['a'].sum()
  72. @timefunc
  73. #@do_cprofile
  74. def fetchwhere_numpy():
  75. return ct.fetchwhere("(a > 5) & (b < LMAX)", out_flavor='numpy')['a'].sum()
  76. @timefunc
  77. #@do_cprofile
  78. def fetchwhere_dask():
  79. result = ct.fetchwhere("(a > 5) & (b < LMAX)", vm="dask")['a'].sum()
  80. return result
  81. print(repr(ct))
  82. a0 = where_numpy()
  83. print("a0:", a0)
  84. a1 = where_numexpr()
  85. assert a0 == a1
  86. a1 = bcolz_where()
  87. assert a0 == a1
  88. a1 = bcolz_where_numpy()
  89. assert a0 == a1
  90. a1 = bcolz_where_numexpr()
  91. assert a0 == a1
  92. a1 = whereblocks()
  93. assert a0 == a1
  94. a1 = fetchwhere_bcolz()
  95. assert a0 == a1
  96. a1 = fetchwhere_numpy()
  97. assert a0 == a1
  98. a1 = fetchwhere_dask()
  99. assert a0 == a1