11#encoding:utf-8
22#Created by Liang Sun on May, 6, 2013
33class SegmentTree (object ):
4- def _init (self , start , end ):
5- self .data [(start , end )] = 0
6- if start < end :
7- mid = start + (end - start ) / 2
8- self ._init (start , mid )
9- self ._init (mid + 1 , end )
10-
114 def __init__ (self , start , end ):
125 self .start = start
136 self .end = end
14- self .data = {}
7+ self .max_value = {}
8+ self .sum_value = {}
9+ self .len_value = {}
1510 self ._init (start , end )
1611
1712 def add (self , start , end , weight = 1 ):
@@ -20,9 +15,31 @@ def add(self, start, end, weight=1):
2015 self ._add (start , end , weight , self .start , self .end )
2116 return True
2217
18+ def query_max (self , start , end ):
19+ return self ._query_max (start , end , self .start , self .end )
20+
21+ def query_sum (self , start , end ):
22+ return self ._query_sum (start , end , self .start , self .end )
23+
24+ def query_len (self , start , end ):
25+ return self ._query_len (start , end , self .start , self .end )
26+
27+ """"""
28+ def _init (self , start , end ):
29+ self .max_value [(start , end )] = 0
30+ self .sum_value [(start , end )] = 0
31+ self .len_value [(start , end )] = 0
32+ if start < end :
33+ mid = start + (end - start ) / 2
34+ self ._init (start , mid )
35+ self ._init (mid + 1 , end )
36+
2337 def _add (self , start , end , weight , in_start , in_end ):
38+ key = (in_start , in_end )
2439 if in_start == in_end :
25- self .data [(in_start , in_end )] += weight
40+ self .max_value [key ] += weight
41+ self .sum_value [key ] += weight
42+ self .len_value [key ] = 1 if self .sum_value [key ] > 0 else 0
2643 return
2744
2845 mid = in_start + (in_end - in_start ) / 2
@@ -33,22 +50,49 @@ def _add(self, start, end, weight, in_start, in_end):
3350 else :
3451 self ._add (start , mid , weight , in_start , mid )
3552 self ._add (mid + 1 , end , weight , mid + 1 , in_end )
36- self .data [(in_start , in_end )] = max (self .data [(in_start , mid )], self .data [(mid + 1 , in_end )])
53+ self .max_value [key ] = max (self .max_value [(in_start , mid )], self .max_value [(mid + 1 , in_end )])
54+ self .sum_value [key ] = self .sum_value [(in_start , mid )] + self .sum_value [(mid + 1 , in_end )]
55+ self .len_value [key ] = self .len_value [(in_start , mid )] + self .len_value [(mid + 1 , in_end )]
3756
38- def query (self , start , end ):
39- return self ._query (start , end , self .start , self .end )
40-
41- def _query (self , start , end , in_start , in_end ):
57+ def _query_max (self , start , end , in_start , in_end ):
4258 if start == in_start and end == in_end :
43- ans = self .data [(start , end )]
59+ ans = self .max_value [(start , end )]
4460 else :
4561 mid = in_start + (in_end - in_start ) / 2
4662 if mid >= end :
47- ans = self ._query (start , end , in_start , mid )
63+ ans = self ._query_max (start , end , in_start , mid )
4864 elif mid + 1 <= start :
49- ans = self ._query (start , end , mid + 1 , in_end )
65+ ans = self ._query_max (start , end , mid + 1 , in_end )
5066 else :
51- ans = max (self ._query (start , mid , in_start , mid ),
52- self ._query (mid + 1 , end , mid + 1 , in_end ))
67+ ans = max (self ._query_max (start , mid , in_start , mid ),
68+ self ._query_max (mid + 1 , end , mid + 1 , in_end ))
5369 #print start, end, in_start, in_end, ans
5470 return ans
71+
72+ def _query_sum (self , start , end , in_start , in_end ):
73+ if start == in_start and end == in_end :
74+ ans = self .sum_value [(start , end )]
75+ else :
76+ mid = in_start + (in_end - in_start ) / 2
77+ if mid >= end :
78+ ans = self ._query_sum (start , end , in_start , mid )
79+ elif mid + 1 <= start :
80+ ans = self ._query_sum (start , end , mid + 1 , in_end )
81+ else :
82+ ans = self ._query_sum (start , mid , in_start , mid ) + self ._query_sum (mid + 1 , end , mid + 1 , in_end )
83+ return ans
84+
85+ def _query_len (self , start , end , in_start , in_end ):
86+ if start == in_start and end == in_end :
87+ ans = self .len_value [(start , end )]
88+ else :
89+ mid = in_start + (in_end - in_start ) / 2
90+ if mid >= end :
91+ ans = self ._query_len (start , end , in_start , mid )
92+ elif mid + 1 <= start :
93+ ans = self ._query_len (start , end , mid + 1 , in_end )
94+ else :
95+ ans = self ._query_len (start , mid , in_start , mid ) + self ._query_len (mid + 1 , end , mid + 1 , in_end )
96+
97+ print start , end , in_start , in_end , ans
98+ return ans
0 commit comments