Please see the Python code below. The algorithm is my own; I got help writing the code from Gemini and ChatGPT.
import heapq
import math
import random
from bisect import bisect_left
from collections import defaultdict
from typing import List, Dict, Set, Tuple, DefaultDict, Optional
class SplitSparseTable:
    """Static range-minimum / range-maximum structure with O(1) queries.

    Two sparse tables are kept side by side: ``min_st[k][j]`` is the minimum
    and ``max_st[k][j]`` the maximum of ``arr[j : j + 2**k]`` for every ``j``
    where that window fits. Slots past the last valid ``j`` on a level stay 0;
    ``query`` never reads them.
    """

    def __init__(self, arr: List[int]):
        self.n = len(arr)
        if not self.n:
            # Nothing to index; query() rejects every range when n == 0.
            return

        # logs[x] == floor(log2(x)) for x >= 1; logs[0] is a placeholder.
        self.logs = [0] + [x.bit_length() - 1 for x in range(1, self.n + 1)]
        self.K = self.logs[self.n]

        # Level 0 is the array itself; one plain int per slot keeps memory low.
        self.min_st = [list(arr)]
        self.max_st = [list(arr)]
        for level in range(1, self.K + 1):
            half = 1 << (level - 1)
            valid = self.n - (1 << level) + 1
            padding = [0] * (self.n - valid)
            lows = self.min_st[-1]
            highs = self.max_st[-1]
            # Each level-k window is the union of two adjacent level-(k-1) windows.
            self.min_st.append(
                [min(lows[j], lows[j + half]) for j in range(valid)] + padding
            )
            self.max_st.append(
                [max(highs[j], highs[j + half]) for j in range(valid)] + padding
            )

    def query(self, L: int, R: int) -> Tuple[int, int]:
        """Return ``(min, max)`` of ``arr[L..R]`` (inclusive) in O(1).

        Raises:
            ValueError: unless ``0 <= L <= R < n``.
        """
        if not (0 <= L <= R < self.n):
            raise ValueError(f"Invalid range [{L}, {R}]")
        k = self.logs[R - L + 1]
        # Two (possibly overlapping) 2**k windows cover [L, R] exactly.
        tail = R - (1 << k) + 1
        lows = self.min_st[k]
        highs = self.max_st[k]
        return min(lows[L], lows[tail]), max(highs[L], highs[tail])
def get_perm(n: int) -> list[int]:
    """Return the integers 1..n shuffled in place with ``random.shuffle``."""
    perm = [*range(1, n + 1)]
    random.shuffle(perm)
    return perm
def brute_force(A: list[int]) -> int:
    """Reference oracle: count subarrays of length >= 2 whose sorted values
    have a constant step between neighbours.

    Sorts every subarray independently, so it is O(n^3 log n); use only for
    small cross-checks.
    """
    count = 0
    size = len(A)
    for start in range(size):
        # stop is exclusive; start + 2 gives the shortest (length-2) window.
        for stop in range(start + 2, size + 1):
            window = sorted(A[start:stop])
            step = window[1] - window[0]
            if all(b - a == step for a, b in zip(window, window[1:])):
                count += 1
    return count
def get_all_divisors_up_to_n(n: int) -> List[List[int]]:
    """Return a table whose entry ``i`` lists the divisors of ``i`` in increasing order.

    Entries 0..n are present and entry 0 is empty. Built as a sieve: each
    candidate ``d`` is appended to all of its multiples, so total work is
    about n * H(n), i.e. O(n log n).
    """
    table: List[List[int]] = [[] for _ in range(n + 1)]
    for d in range(1, n + 1):
        for k in range(1, n // d + 1):
            table[d * k].append(d)
    return table
def map_values_to_indices(A: List[int]) -> List[int]:
    """Return the inverse permutation of ``A``.

    ``result[v - 1]`` is the (0-based) index at which value ``v`` occurs in
    ``A``. ``A`` must be a permutation of ``1..len(A)``. This is checked
    explicitly: a value such as 0 would otherwise silently land in
    ``result[-1]``, and a duplicate would leave an unfilled slot that only
    fails much later when looked up.

    Args:
        A: A permutation of 1..len(A).

    Returns:
        List mapping ``value - 1`` to that value's index in ``A``.

    Raises:
        ValueError: if ``A`` is not a permutation of 1..len(A).
    """
    n = len(A)
    result: List[int] = [-1] * n
    for i, val in enumerate(A):
        if not 1 <= val <= n or result[val - 1] != -1:
            raise ValueError(
                f"A must be a permutation of 1..{n}; bad value {val!r} at index {i}"
            )
        result[val - 1] = i
    return result
def _grow_to_ap_closure(
    D: int,
    l: int,
    r: int,
    L: int,
    R: int,
    min_val: int,
    max_val: int,
    val_to_idx: List[int],
    min_max_table,
) -> Optional[Tuple[int, int, int, int]]:
    """Grow the window [L, R] until its values form a gap-free AP with step D.

    ``min_val``/``max_val`` are the extremes of the window before it was
    widened to [L, R]. Callers maintain that every value from ``min_val`` to
    ``max_val`` (step D) already has its index inside [L, R]. Whenever the
    window shows a new extreme, the indices of all step-D values strictly
    between the old and new extreme are pulled in. That can expose further
    extremes, so the loop repeats until min/max stop changing.

    Relies on every value in A[l..r] being congruent modulo D; f() only hands
    out ranges where D divides every adjacent difference.

    Returns:
        ``(L, R, min_val, max_val)`` of the grown window, or ``None`` if a
        required value lies outside [l, r]. Any larger starting window needs
        that same value, so ``None`` also rules out every later sweep step.
    """
    lo_cand, hi_cand = min_max_table.query(L, R)
    while lo_cand < min_val or hi_cand > max_val:
        if lo_cand < min_val:
            # Values strictly between the new min and the old min.
            needed = range(lo_cand + D, min_val, D)
            min_val = lo_cand
        else:
            # Values strictly between the old max and the new max.
            needed = range(max_val + D, hi_cand, D)
            max_val = hi_cand
        for value in needed:
            idx = val_to_idx[value - 1]
            if idx < l or idx > r:
                return None
            L = min(L, idx)
            R = max(R, idx)
        lo_cand, hi_cand = min_max_table.query(L, R)
    return L, R, min_val, max_val


def right_sweep(A: List[int], D: int, l: int, r: int, val_to_idx: List[int], min_max_table) -> DefaultDict[int, List[int]]:
    """Record the minimal AP windows containing [M, b] for b in (M, r].

    ``M = (l + r) // 2``. Sweeping b rightwards, the smallest window that
    contains [M, b] and whose values form a gap-free AP with step D is grown
    incrementally. These windows are nested, so each step only extends the
    previous one. Positions b already inside the current window share its
    window and are skipped.

    Args:
        A: The input array.
        D: The arithmetic difference (stride).
        l: Left boundary of the current divide-and-conquer segment.
        r: Right boundary of the current divide-and-conquer segment.
        val_to_idx: Map from ``value - 1`` to that value's index in A.
        min_max_table: Object whose ``query(L, R)`` returns ``(min, max)`` of A[L..R].

    Returns:
        Map from a window's left end L (possibly left of M) to the increasing
        list of right ends R of the windows found; each R appears under exactly
        one L. The sweep stops once a required value lies outside [l, r].
    """
    M = (l + r) // 2
    res: DefaultDict[int, List[int]] = defaultdict(list)
    L = R = M
    min_val = max_val = A[M]
    curr = M + 1
    while curr <= r:
        grown = _grow_to_ap_closure(
            D, l, r, L, curr, min_val, max_val, val_to_idx, min_max_table
        )
        if grown is None:
            return res
        L, R, min_val, max_val = grown
        # R strictly increases between recordings, so no duplicate check is needed.
        res[L].append(R)
        # Every position up to R is already inside this window; jump past it.
        curr = R + 1
    return res


def left_sweep(A: List[int], D: int, l: int, r: int,
               val_to_idx: List[int],
               min_max_table,
               rs_map: DefaultDict[int, List[int]]) -> int:
    """Count AP subarrays [a, b] of A[l..r] that cross the midpoint (a <= M < b).

    ``M = (l + r) // 2``. Sweeping a leftwards from M, the minimal AP window
    [a, R1] containing [a, M + 1] is grown incrementally. A start position
    whose window reaches further left than itself cannot start an AP and is
    skipped.

    For a valid start a, the crossing APs are [a, b] with either:
      * b == R1 (the window itself), or
      * b > R1, where b is a right end that ``right_sweep`` recorded under a
        key L2 >= a. Two APs with the same step that overlap on [M, M + 1]
        unite into an AP.

    Args:
        A: The input array.
        D: The arithmetic difference (stride).
        l: Left boundary of the current divide-and-conquer segment.
        r: Right boundary of the current divide-and-conquer segment.
        val_to_idx: Map from ``value - 1`` to that value's index in A.
        min_max_table: Object whose ``query(L, R)`` returns ``(min, max)`` of A[L..R].
        rs_map: Output of ``right_sweep`` for the same (A, D, l, r).

    Returns:
        Number of crossing AP subarrays.
    """
    M = (l + r) // 2
    if M + 1 > r:
        return 0
    total_count = 0
    L = R = M + 1
    min_val = max_val = A[R]
    # Min-heap of right ends from rs_map keys reached so far (all keys >= the
    # current start). R never shrinks, so an entry popped as <= R is never
    # needed again; each right end is pushed and popped at most once.
    pending: List[int] = []
    curr = M
    while curr >= l:
        grown = _grow_to_ap_closure(
            D, l, r, curr, R, min_val, max_val, val_to_idx, min_max_table
        )
        if grown is None:
            return total_count
        L, R, min_val, max_val = grown
        # [L, R] is itself an AP crossing M/M+1, whether or not right_sweep
        # recorded a window that starts at L.
        total_count += 1
        for b in rs_map.get(L, ()):
            heapq.heappush(pending, b)
        while pending and pending[0] <= R:
            heapq.heappop(pending)
        # Every remaining b > R pairs with start L to form a crossing AP.
        total_count += len(pending)
        curr = L - 1
    return total_count
def count_arithmetic_progressions_iterative(A: List[int], D: int, val_to_idx: List[int], min_max_table, l_start: int, r_end: int) -> int:
    """Count AP subarrays with stride D inside A[l_start..r_end] by divide and conquer.

    Uses an explicit work list instead of recursion. Each segment [lo, hi]
    contributes the subarrays crossing its midpoint, found by ``right_sweep``
    followed by ``left_sweep``. Its two halves are then queued.

    Args:
        A: The input array.
        D: The arithmetic difference (stride).
        val_to_idx: Map from ``value - 1`` to that value's index in A.
        min_max_table: Object whose ``query(L, R)`` returns ``(min, max)`` of A[L..R].
        l_start: First index of the segment.
        r_end: Last index of the segment (inclusive).

    Returns:
        Sum of the crossing counts over all segments.
    """
    if l_start >= r_end:
        return 0
    total = 0
    work: List[Tuple[int, int]] = [(l_start, r_end)]
    while work:
        lo, hi = work.pop()
        if lo >= hi:
            continue  # a single position has nothing crossing a midpoint
        mid = (lo + hi) // 2
        right_ends = right_sweep(A, D, lo, hi, val_to_idx, min_max_table)
        total += left_sweep(A, D, lo, hi, val_to_idx, min_max_table, right_ends)
        work += [(lo, mid), (mid + 1, hi)]
    return total
def f(A: List[int]) -> int:
    """Count subarrays of length >= 2 whose sorted values form an arithmetic progression.

    ``A`` is expected to be a permutation of 1..len(A), as required by
    ``map_values_to_indices``. In an AP subarray with common difference D,
    every adjacent difference is a multiple of D. So for each D the array is
    cut into maximal runs where D divides every adjacent difference, and each
    run is handed to the divide-and-conquer counter for that D.
    """
    size = len(A)
    val_to_idx = map_values_to_indices(A)
    divisor_table = get_all_divisors_up_to_n(size)
    min_max_table = SplitSparseTable(A)
    run_start: Dict[int, int] = {}
    total = 0
    for i in range(1, size):
        divisors = divisor_table[abs(A[i] - A[i - 1])]
        still_valid = set(divisors)
        # Close every run whose stride does not divide the new adjacent difference.
        for stride in [s for s in run_start if s not in still_valid]:
            total += count_arithmetic_progressions_iterative(
                A, stride, val_to_idx, min_max_table, run_start.pop(stride), i - 1
            )
        # Open a run at i - 1 for each divisor that is not already open.
        for stride in divisors:
            run_start.setdefault(stride, i - 1)
    # Flush the runs that extend to the end of the array.
    for stride, start in run_start.items():
        total += count_arithmetic_progressions_iterative(
            A, stride, val_to_idx, min_max_table, start, size - 1
        )
    return total
if __name__ == "__main__":
    # Randomized cross-check of f() against the brute-force oracle.
    # Guarded so that importing this module does not run 100 O(n^3 log n) checks.
    num_tests = 100
    n = 20
    for _ in range(num_tests):
        A = get_perm(n)
        brute = brute_force(A)
        ff = f(A)
        if brute != ff:
            # Report the first counterexample: expected, got, input.
            print(brute, ff, A)
            break
    print("Done.")
nbe? What's the time limit? `D = 2` seems to be missing the subsequence `[1, 3]`.