test_random.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import random
  2. import unittest
  3. from math import exp, sqrt
  4. class Test_Distributions(unittest.TestCase):
  5. def test_seeding(self):
  6. # Python's seeding is different so we can't hard code
  7. random.seed(1)
  8. a = random.uniform(1,5)
  9. random.seed(1)
  10. b = random.uniform(1,5)
  11. self.assertEqual(a, b)
  12. def _excluded_test_avg_std(self):
  13. # Use integration to test distribution average and standard deviation.
  14. # Only works for distributions which do not consume variates in pairs
  15. #g = random.Random()
  16. N = 5000
  17. x = [i/float(N) for i in xrange(1,N)]
  18. for variate, args, mu, sigmasqrd in [
  19. (random.uniform, (1.0,10.0), (10.0+1.0)/2, (10.0-1.0)**2/12),
  20. (random.gauss, (-5.0,2.0), -5.0, 2.0**2),
  21. (random.normalvariate, (2.0,0.8), 2.0, 0.8**2),
  22. (random.lognormvariate, (-1.0,0.5), exp(-1.0 + 0.5**2)/2.0, (exp(0.5**2) - 1) * exp(-2.0 + 0.5**2)),
  23. (random.expovariate, (0.4,), 1.0/0.4, 1.0/0.4**2),
  24. (random.triangular, (0.0, 1.0, 1.0/3.0), 4.0/9.0, 7.0/9.0/18.0) #,
  25. # (g.expovariate, (1.5,), 1/1.5, 1/1.5**2),
  26. # (g.paretovariate, (5.0,), 5.0/(5.0-1),
  27. # 5.0/((5.0-1)**2*(5.0-2))),
  28. # (g.weibullvariate, (1.0, 3.0), gamma(1+1/3.0),
  29. # gamma(1+2/3.0)-gamma(1+1/3.0)**2)
  30. ]:
  31. #g.random = x[:].pop
  32. y = []
  33. for i in xrange(len(x)):
  34. try:
  35. y.append(variate(*args))
  36. except IndexError:
  37. pass
  38. s1 = s2 = 0
  39. for e in y:
  40. s1 += e
  41. s2 += (e - mu) ** 2
  42. N = len(y)
  43. # Reduced precision from two to zero decimal places
  44. self.assertAlmostEqual(s1/N, mu, 0)
  45. self.assertAlmostEqual(s2/(N-1), sigmasqrd, 0)
  46. def test_sample_frequency(self):
  47. N = 1000
  48. population = [-2, -1, 0, 1, 2]
  49. hist = {}
  50. for i in xrange(N):
  51. sampled = random.sample(population, 2)
  52. key = ','.join(str(x) for x in sampled)
  53. hist[key] = hist.get(key, 0) + 1
  54. # There are m * (m-1) ways to pick an ordered pair. The
  55. # observed number of occurrences of a pair follows a
  56. # Binomial(N, 1/(m*(m-1))) distribution.
  57. m = len(population)
  58. p = 1.0 / (m*(m-1))
  59. mean = N*p
  60. stddev = sqrt(N*p*(1-p))
  61. low = mean - 4*stddev
  62. high = mean + 4*stddev
  63. for a in population:
  64. for b in population:
  65. if a != b:
  66. key = '%s,%s' % (a, b)
  67. observed = hist.get(key, 0)
  68. self.assertLess(low, observed, 'Sample %s' % key)
  69. self.assertGreater(high, observed, 'Sample %s' % key)
  70. def test_sample_tuple(self):
  71. population = (1, 2, 3, 4)
  72. sampled = random.sample(population, 3)
  73. self.assertEqual(len(sampled), 3)
  74. for x in sampled:
  75. self.assertIn(x, population)
  76. def test_sample_set(self):
  77. population = set(range(20))
  78. sampled = random.sample(population, 10)
  79. self.assertEqual(len(sampled), 10)
  80. for x in sampled:
  81. self.assertIn(x, population)
  82. def test_sample_dict(self):
  83. population = {"one": 1, "two": 2, "three": 3}
  84. sampled = random.sample(population, 2)
  85. self.assertEqual(len(sampled), 2)
  86. for x in sampled:
  87. self.assertIn(x, population.keys())
  88. def test_sample_empty(self):
  89. sampled = random.sample([], 0)
  90. self.assertEqual(sampled, [])
  91. def test_sample_all(self):
  92. population = "ABCDEF"
  93. sampled = random.sample(population, len(population))
  94. self.assertEqual(set(sampled), set(population))
  95. def test_sample_one_too_many(self):
  96. self.assertRaises(ValueError, random.sample, range(4), 5)
  97. if __name__ == '__main__':
  98. unittest.main()