
I'm working on a problem where I need to count, for each possible common difference D, the number of contiguous subarrays whose elements can be rearranged to form an arithmetic progression with common difference D.

Problem Description

Given an array arr[1..n] which is a permutation of 1..n, for each D from 1 to n, count the number of contiguous subarrays where the set of elements can be rearranged into an arithmetic progression with common difference D.

Example:


arr = [5, 1, 3, 2, 4]  (permutation of 1..5)

Expected results:

  • D = 1: 5 subarrays ([3,2], [1,3,2], [3,2,4], [1,3,2,4], [5, 1, 3, 2, 4])

  • D = 2: 3 subarrays ([5,1,3], [1,3], [2,4])

  • D = 3: 0 subarrays

  • D = 4: 1 subarray ([5,1])

  • D = 5: 0 subarrays

I've tried an O(n²) solution that checks all contiguous subarrays, but it's too slow for large n.
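For checking faster solutions against something simple, a roughly quadratic baseline could look like the sketch below. This is not the asker's code, which isn't shown, and the name count_by_difference is hypothetical. It assumes arr is a permutation of 1..n. For a fixed start i, it extends the window to the right while tracking the minimum, the maximum and the gcd g of all differences to arr[i]. If k distinct values, all congruent modulo D, span exactly (k - 1) * D, then they must be the progression lo, lo + D, ..., hi.

```python
from math import gcd
from typing import List


def count_by_difference(arr: List[int]) -> List[int]:
    """Return counts where counts[D] is the number of subarrays (length >= 2)
    whose elements can be rearranged into an AP with common difference D.

    Assumes arr is a permutation of 1..n, so all values are distinct.
    """
    n = len(arr)
    counts = [0] * (n + 1)  # index 0 is unused
    for i in range(n):
        lo = hi = arr[i]
        g = 0  # gcd of |arr[j] - arr[i]| over the current window
        for j in range(i + 1, n):
            lo = min(lo, arr[j])
            hi = max(hi, arr[j])
            g = gcd(g, abs(arr[j] - arr[i]))
            steps = j - i  # window length minus 1
            span = hi - lo
            # Distinct values force span >= steps, so d >= 1 below.
            if span % steps == 0:
                d = span // steps
                # All values congruent mod d and spread over steps * d:
                # the window holds exactly lo, lo + d, ..., hi.
                if g % d == 0:
                    counts[d] += 1
    return counts


print(count_by_difference([5, 1, 3, 2, 4])[1:])  # [5, 3, 0, 1, 0]
```

The gcd update makes this O(n² log n) rather than strictly O(n²), which is fine for a reference implementation on small inputs.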

  • You could analyse the prime factor decomposition of differences between subsequent array elements. If the difference between two elements is 12, for example, then they MUST be part of some subsequence with stride 12 and they MAY be part of subsequences with strides 1, 2, 3, 4, and 6. Basically an indexing exercise. Commented Nov 21 at 7:01
  • How large can n be? What's the time limit? Commented Nov 21 at 8:27
  • n ≤ 10⁵, and the time limit is 2 seconds. The array is a permutation of 1..n, so all values are distinct and consecutive integers. Commented Nov 21 at 13:42
  • The expected result for D = 2 seems to be missing the subsequence [1, 3]. Commented Nov 22 at 18:36
  • Yes, you are right. Commented Nov 23 at 12:31

2 Answers

TL;DR: an efficient algorithm based on factor decomposition of differences that relies on an O(n²) component for counting subsequences; it just barely avoids the timeout for worst-case input (full PoC available as a C# fiddle).

Each pair of consecutive values in the sequence has a certain absolute difference d (non-zero because all values are distinct). This means that the pair MUST be part of a contiguous subsequence with stride d and it MAY be part of contiguous subsequences whose strides are equal to the other factors of d. It also means that candidate subsequences with other strides cannot continue across the current pair and must necessarily terminate at the first value of the pair.

Ergo you can sweep across the input sequence looking at each pair of consecutive values in order to collect candidate subsequences as follows:

  • for each factor f of the absolute difference d between a[i] and a[i+1]:
    • if there is no open candidate sequence with stride f, create one with the tuple (a[i], i)
    • append the tuple (a[i+1], i+1) to the open candidate sequence with stride f
  • for each open candidate sequence whose stride is not a factor of d:
    • count the number of contiguous subsequences that form progressions
    • throw the candidate sequence away

After the last pair, process all candidate sequences that are still open in the same way.
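A minimal Python sketch of the sweep above, assuming the input is a permutation of 1..n (so every difference d is at most n - 1). The names divisor_lists and collect_candidates are hypothetical. Instead of counting on the spot, the sketch returns the finished candidates with 0-based indices so that the counting step can be looked at separately.

```python
from typing import Dict, List, Tuple

Candidate = List[Tuple[int, int]]  # (value, index) tuples


def divisor_lists(n: int) -> List[List[int]]:
    """divisors[m] lists every divisor of m for 1 <= m <= n (sieve style)."""
    divisors: List[List[int]] = [[] for _ in range(n + 1)]
    for d in range(1, n + 1):
        for m in range(d, n + 1, d):
            divisors[m].append(d)
    return divisors


def collect_candidates(a: List[int]) -> List[Tuple[int, Candidate]]:
    """Sweep over consecutive pairs and return (stride, tuples) candidates."""
    divisors = divisor_lists(len(a))
    open_seqs: Dict[int, Candidate] = {}
    closed: List[Tuple[int, Candidate]] = []
    for i in range(len(a) - 1):
        d = abs(a[i + 1] - a[i])
        factors = divisors[d]
        factor_set = set(factors)
        # Strides that do not divide d cannot continue across this pair,
        # so their candidate sequences end at a[i].
        for f in list(open_seqs):
            if f not in factor_set:
                closed.append((f, open_seqs.pop(f)))
        # Every factor f of d opens (if needed) and extends a candidate.
        for f in factors:
            if f not in open_seqs:
                open_seqs[f] = [(a[i], i)]
            open_seqs[f].append((a[i + 1], i + 1))
    # After the last pair, everything still open is finished as well.
    closed.extend(open_seqs.items())
    return closed
```

For the example [5, 1, 3, 2, 4], the returned candidates have strides 4, 2, 1 and 2, in that order, matching the factor rows in the diagram that follows.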

Here is the example sequence with the difference factors indicated below the gaps between the values, in separate rows per factor. This shows how the candidate sequences come into being:

5 1 3 2 4
 1 1 1 1
 2 2   2
 4

The tricky bit is the counting of the contiguous subsequences contained in a collected candidate sequence. This sounds a lot like the overall problem description again, but at this point we are in a position to mine the raw number ore efficiently.

Here is the first of the two candidate sequences for factor/stride 2. It contains tuples (value, index), but values and indices are shown in separate rows for clarity:

5 1 3
0 1 2

If the tuples of the candidate sequence are not sorted by value already (i.e. if they are not collected with the aid of a structure that orders them by value), sort them.

1 3 5
1 2 0

In this form it becomes possible to isolate unbroken runs of values and then mine those for compliant subsequences. The example sequence does not generate candidate sequences with separate runs, but some other sequence might generate a stride 2 candidate sequence like this one, with a gap between 3 and 7:

1 3 7 9
1 2 0 3
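Sorting a candidate by value and cutting it into unbroken runs could look like this. It is a sketch; split_into_runs is a hypothetical helper, and the tuples are (value, index) pairs as above.

```python
from typing import List, Tuple


def split_into_runs(stride: int,
                    candidate: List[Tuple[int, int]]) -> List[List[Tuple[int, int]]]:
    """Sort (value, index) tuples by value and cut wherever the value step
    is not exactly `stride`."""
    items = sorted(candidate)  # values are distinct, so this sorts by value
    runs: List[List[Tuple[int, int]]] = []
    for item in items:
        if runs and item[0] - runs[-1][-1][0] == stride:
            runs[-1].append(item)  # value chain continues
        else:
            runs.append([item])  # gap in the value chain: start a new run
    return runs


print(split_into_runs(2, [(7, 0), (1, 1), (3, 2), (9, 3)]))
# [[(1, 1), (3, 2)], [(7, 0), (9, 3)]]
```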

Candidate sequence processing

Remember your current position as the base position. Step through the tuples, tracking the index minima and maxima. If the value chain is unbroken and the index extrema indicate a compact range of the same size, then you have found a compliant subsequence; add 1 to your count and step forward. If the value chain is broken, remember the current position for later (as the potential start of a new unbroken run) and repeat the process starting at base position + 1. When you are done with this unbroken run of values, continue with the next one whose beginning you stored earlier.
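Here is a simplified sketch of this counting step, working on the index components of one unbroken run that is already sorted by value. Unlike the caterpillar, it does not remember break points. It simply restarts from every base position and counts each window of length at least 2 whose index extrema span exactly its size. count_compact_windows is a hypothetical name.

```python
from typing import List


def count_compact_windows(indices: List[int]) -> int:
    """Count windows of length >= 2 whose (distinct) index components form a
    gap-free range, i.e. max - min == window length - 1.

    Deliberately quadratic: every window is visited once.
    """
    count = 0
    for base in range(len(indices)):
        lo = hi = indices[base]
        for pos in range(base + 1, len(indices)):
            lo = min(lo, indices[pos])
            hi = max(hi, indices[pos])
            if hi - lo == pos - base:  # compact index range of the same size
                count += 1
    return count


print(count_compact_windows([1, 3, 2, 4, 0]))  # 5
```

The call uses the index row of the stride 1 candidate for the example sequence and finds the five stride 1 subarrays listed in the question.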

This caterpillar movement is quadratic in nature, and a test implementation clocked in at 1.6 seconds for the worst-case input (the ordered sequence 1 .. 50000, with N * (N - 1) / 2 = 1,249,975,000 subsequences for stride 1).

That is roughly 1 nanosecond per subsequence, meaning there is little speed-up potential left in this approach, or in any other approach that finds or counts all compliant subsequences individually, because that necessarily leads to the big bad O(n²).

However, all is not lost. Perhaps we have reduced the original, complex problem to a smaller one. The crucial (and so far O(n²)) operation is this:

  • given a sequence of distinct integers, count all the pairs of positions where the difference in position equals the difference between the minimal and maximal value between them, i.e. where the window length equals max - min + 1 (indicating that the values form a contiguous subrange)

Note that 'count' here means 'perform some sub-square algorithm that effectively computes the count without counting each pairing individually'.

As an illustration, here is the candidate sequence for stride 1 that is generated by the example sequence, sorted by value and separated into unbroken value runs (stride 1 only ever has exactly one such run, and it covers the whole input).

1 2 3 4 5 <- value components of the tuples
1 3 2 4 0 <- index components of the tuples, input for above subproblem

I do not have a solution at the moment, but this reduced problem may be easier to solve than the original one.

I have looked at the displacement of inversions as a promising avenue of research, but I quickly shelved that (consider 4 3 2 1 0, which loses no subsequences at all). Still in the running is looking at the absolute values of gaps in the index sequence (other than ±1). Each of those gaps causes a certain number of sequences to be lost and the reach of its effect grows with its size.

Last but not least, it may be possible to curb the worst quadratic excesses by finding unbroken runs in the index sequence (whether descending or ascending) and treating them wholesale, like a single hop that contributes run_length * (run_length - 1) / 2 counts instead of one. This would make the worst-case input of the original algorithm the easiest input of all, and so this little trick may just be enough to beat the time-out. ;-) Caveat: an index sequence like 0 3 1 4 2 5 ... does not contain any runs, so the trick cannot help in this case.
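A sketch of the run-finding part of that trick, assuming distinct index components (the name wholesale_runs is hypothetical). It only locates the maximal ±1 runs and sums their run_length * (run_length - 1) / 2 contributions. Windows that straddle several runs would still have to be counted by some other means.

```python
from typing import List, Tuple


def wholesale_runs(indices: List[int]) -> Tuple[List[Tuple[int, int]], int]:
    """Find maximal runs whose consecutive index components differ by +1 or -1
    and return them as (start, length) pairs plus their wholesale count.

    Every window inside such a run is compact, so a run of length L
    contributes L * (L - 1) // 2 windows of length >= 2 at once.
    Windows that straddle several runs are NOT counted here.
    """
    runs: List[Tuple[int, int]] = []
    start = 0
    for k in range(1, len(indices) + 1):
        # A run ends at the end of the input or at any step other than +/-1.
        # (With distinct indices, a +1 step cannot be followed by a -1 step.)
        if k == len(indices) or abs(indices[k] - indices[k - 1]) != 1:
            if k - start >= 2:
                runs.append((start, k - start))
            start = k
    return runs, sum(L * (L - 1) // 2 for _, L in runs)


print(wholesale_runs([0, 1, 2, 3, 4]))     # ([(0, 5)], 10)
print(wholesale_runs([0, 3, 1, 4, 2, 5]))  # ([], 0)
```

The first call reproduces the N * (N - 1) / 2 count of the ordered worst case in one step; the second is the caveat sequence, for which the trick finds nothing.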


3 Comments

This code runs 500,000 in 3 seconds on Ideone. I assume it would be faster on a local setup. ideone.com/dFED4v
That's amazing! It'll take me a while to work through it, though. ATM I see what it does but not what that means. The result for the random permutation certainly looks plausible; IIRC the expected value is vector length + 1 or something like that. The brute force reference looks suspect, however. Perhaps you could simply use my code and generate a couple of reference results; I haven't got around to making a decent set of fiendish test vectors yet. For the next couple of days I'm off to help the elves, though. ;-)
Brute force looks OK to me unless I’m missing something. It seems to me that it takes every subarray, makes a copy of it, sorts it, and then confirms that every difference between adjacent elements is the same.

Please see Python code below. The algorithm is mine. I got help coding it from Gemini and ChatGPT.

from collections import defaultdict
from bisect import bisect_left
from typing import List, Set, Tuple, DefaultDict
import random


class SplitSparseTable:
    def __init__(self, arr: List[int]):
        self.n = len(arr)
        if self.n == 0:
            return

        # 1. Precompute Logarithms for O(1) lookup
        self.logs = [0] * (self.n + 1)
        for i in range(2, self.n + 1):
            self.logs[i] = self.logs[i // 2] + 1
        
        self.K = self.logs[self.n]

        # 2. Initialize TWO separate tables
        # Using simple integers reduces memory overhead significantly compared to tuples
        self.min_st = [[0] * self.n for _ in range(self.K + 1)]
        self.max_st = [[0] * self.n for _ in range(self.K + 1)]

        # 3. Base Case (Level 0)
        for i in range(self.n):
            self.min_st[0][i] = arr[i]
            self.max_st[0][i] = arr[i]

        # 4. Build Tables
        for i in range(1, self.K + 1):
            current_len = 1 << i
            half_len = 1 << (i - 1)
            
            limit = self.n - current_len + 1
            
            # We iterate once, but update both tables independently
            for j in range(limit):
                # Update Min Table
                self.min_st[i][j] = min(
                    self.min_st[i - 1][j], 
                    self.min_st[i - 1][j + half_len]
                )
                
                # Update Max Table
                self.max_st[i][j] = max(
                    self.max_st[i - 1][j], 
                    self.max_st[i - 1][j + half_len]
                )

    def query(self, L: int, R: int) -> Tuple[int, int]:
        """
        Returns (min, max) in O(1) time.
        """
        if L > R or L < 0 or R >= self.n:
            raise ValueError(f"Invalid range [{L}, {R}]")

        # Get the precomputed log for the length
        length = R - L + 1
        k = self.logs[length]
        
        range_len = 1 << k
        
        # Query Min Table
        min_val = min(
            self.min_st[k][L], 
            self.min_st[k][R - range_len + 1]
        )
        
        # Query Max Table
        max_val = max(
            self.max_st[k][L], 
            self.max_st[k][R - range_len + 1]
        )

        return min_val, max_val


def get_perm(n: int) -> list[int]:
    s = list(range(1, n + 1))
    random.shuffle(s)
    return s


def brute_force(A: list[int]) -> int:
    result = 0
    n = len(A)
    for i in range(n):
        for j in range(i, n):
            subarray = A[i:j+1]
            
            if len(subarray) < 2:
                continue

            sorted_subarray = sorted(subarray)

            diff = sorted_subarray[1] - sorted_subarray[0]
            is_arithmetic = True
            
            for k in range(2, len(sorted_subarray)):
                if sorted_subarray[k] - sorted_subarray[k-1] != diff:
                    is_arithmetic = False
                    break

            if is_arithmetic:
                result += 1

    return result


def get_all_divisors_up_to_n(n: int) -> List[List[int]]:
    """
    Pre-calculates all divisors for every number from 1 up to n using a sieve-like method.
    Returns a list where index i contains a list of all divisors of i.
    """
    divisors = [[] for _ in range(n + 1)]
    for d in range(1, n + 1):
        for multiple in range(d, n + 1, d):
            divisors[multiple].append(d)
    return divisors


def map_values_to_indices(A: List[int]) -> List[int]:
    """
    Creates a map from the value of an element in array A to its index.
    
    Args:
        A: The input array of integers.

    Returns:
        A list whose value is the index in the original list
        when using the original value as index
    """
    result = [None] * len(A)

    for i, val in enumerate(A):
        result[val - 1] = i

    return result


def right_sweep(A: List[int], D: int, l: int, r: int, val_to_idx: List[int],  min_max_table) -> DefaultDict[int, List[int]]:
    """
    Sweeps from the middle index M towards the right boundary r. For each
    position curr, the window [L, R] is grown from [M, curr] until its values
    form an AP with stride D; the sweep stops early if a required value lies
    outside [l, r].

    Args:
        A: The input array.
        D: The arithmetic difference (stride).
        l: The left boundary of the current D&C segment.
        r: The right boundary of the current D&C segment.
        val_to_idx: Map of array values to their indices.
        min_max_table: Sparse table for min and max lookup.

    Returns:
        A map from each left boundary L (L <= M) of an AP window containing M
        to the list of that window's right boundaries R, in increasing order.
    """
    M = (l + r) // 2
    res: DefaultDict[int, List[int]] = defaultdict(list)
    
    L = M
    R = M
    min_val = A[M]
    max_val = A[M]
    
    curr = M + 1
    while curr <= r:
        # If curr is already inside the current AP window [L, R] due to previous expansions, skip it
        if curr <= R:
            curr += 1
            continue

        # Step 1: Tentatively expand R to the current index
        R = curr
        min_cand, max_cand = min_max_table.query(L, R)

        while min_cand < min_val or max_cand > max_val:
            if min_cand < min_val:
                # Check all values required to bridge the gap from new min to old min
                for needed_val in range(min_cand + D, min_val, D):
                    idx = val_to_idx[needed_val - 1]
                    if idx < l or idx > r:
                        return res
                    L = min(L, idx)
                    R = max(R, idx)
                min_val = min_cand
            elif max_cand > max_val:
                # Check all values required to bridge the gap from old max to new max
                for needed_val in range(max_val + D, max_cand, D):
                    idx = val_to_idx[needed_val - 1]
                    if idx < l or idx > r:
                        return res
                    L = min(L, idx)
                    R = max(R, idx)
                max_val = max_cand
            min_cand, max_cand = min_max_table.query(L, R)

        if not res[L] or res[L][-1] != R:
            res[L].append(R)
        
        curr += 1
        
    return res


def left_sweep(A: List[int], D: int, l: int, r: int,
               val_to_idx: List[int],
               min_max_table,
               rs_map: DefaultDict[int, List[int]]) -> int:
    """
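    Sweeps from M towards the left boundary l, growing the window [L, R]
    (which always spans M and M + 1) with the same expansion rules as
    right_sweep, and returns the number of APs with stride D that cross the
    M/M+1 boundary. rs_map (produced by right_sweep) supplies the right
    boundaries used to count extensions of each left window.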
    """
    M = (l + r) // 2
    total_count = 0
    valid_L_keys: Set[int] = set()

    if M + 1 > r:
        return 0

    # Initialize sweep window at [M+1, M+1]
    L = M + 1
    R = M + 1
    min_val = A[R]
    max_val = A[R]

    curr = M
    while curr >= l:
        # Step 1: Add new element on the left
        L = curr
        min_cand, max_cand = min_max_table.query(L, R)

        while min_cand < min_val or max_cand > max_val:
            if min_cand < min_val:
                # Need to bridge values from new <-> old min
                for needed_val in range(min_cand + D, min_val, D):
                    idx = val_to_idx[needed_val - 1]
                    if idx < l or idx > r:
                        return total_count
                    L = min(L, idx)
                    R = max(R, idx)
                min_val = min_cand
            elif max_cand > max_val:
                # Need to bridge values from old <-> new max
                for needed_val in range(max_val + D, max_cand, D):
                    idx = val_to_idx[needed_val - 1]
                    if idx < l or idx > r:
                        return total_count
                    L = min(L, idx)
                    R = max(R, idx)
                max_val = max_cand
            min_cand, max_cand = min_max_table.query(L, R)

        # Record L as a valid left boundary
        if L in rs_map:
            valid_L_keys.add(L)

        total_count += 1

        # Count right extensions recorded in rs_map
        for L_key in valid_L_keys:
            R_list = rs_map[L_key]
            start_index = bisect_left(R_list, R + 1)
            total_count += len(R_list) - start_index

        curr = L - 1

    return total_count

def count_arithmetic_progressions_iterative(A: List[int], D: int, val_to_idx: List[int], min_max_table, l_start: int, r_end: int) -> int:
    """
    Implements the core divide-and-conquer logic iteratively using a stack.

    Args:
        A: The input array.
        D: The arithmetic difference (stride).
        val_to_idx: Map of array values to their indices.
        min_max_table: Sparse table for min and max lookup.
        l_start: The starting index of the segment.
        r_end: The ending index of the segment.

    Returns:
        The total count of APs within the segment [l_start, r_end].
    """
    if l_start >= r_end:
        return 0
        
    total_count = 0
    stack: List[Tuple[int, int]] = [(l_start, r_end)]
    
    while stack:
        l, r = stack.pop()
        
        if l >= r:
            continue

        M = (l + r) // 2

        # 1. Calculate the map of valid right extensions for APs starting at M (right half)
        rs_map = right_sweep(A, D, l, r, val_to_idx, min_max_table)
        
        # 2. Calculate the count of APs that cross the boundary M/M+1 (combination step)
        cross_count = left_sweep(A, D, l, r, val_to_idx, min_max_table, rs_map)
        
        total_count += cross_count
        
        # 3. Add subproblems to the stack
        stack.append((l, M))
        stack.append((M + 1, r))
        
    return total_count


def f(A: List[int]) -> int:
    n = len(A)

    # Pre-calculate value-to-index map once
    val_to_idx = map_values_to_indices(A)
 
    total_count = 0

    divisor_table = get_all_divisors_up_to_n(n)
    min_max_table = SplitSparseTable(A)

    ranges: DefaultDict[int, int] = defaultdict(int)

    for i in range(1, n):
        diff = abs(A[i] - A[i-1])
        current_divisors = divisor_table[diff]

        # Close any ranges whose stride no longer divides the current diff
        for stride in list(ranges.keys()):
            if stride not in current_divisors:
                start = ranges[stride]
                total_count += count_arithmetic_progressions_iterative(A, stride, val_to_idx, min_max_table, start, i - 1)
                del ranges[stride]

        # Open new ranges for each divisor D
        for D in current_divisors:
            if D not in ranges:
                ranges[D] = i - 1

    # Close all still-open ranges
    for D in list(ranges.keys()):
        start = ranges[D]
        total_count += count_arithmetic_progressions_iterative(A, D, val_to_idx, min_max_table, start, n - 1)

    return total_count


if __name__ == "__main__":
    num_tests = 100
    n = 20

    for _ in range(num_tests):
        A = get_perm(n)
        brute = brute_force(A)
        ff = f(A)
        if brute != ff:
            print(brute, ff, A)
            break

    print("Done.")

4 Comments

This looks really sophisticated, but trying to understand it makes my head swim. ;-) What are the timings for random and worst-case inputs of length 5, 50, 500, 5000 and 50000?
Timing for 500,000 was 41-42 seconds for the two times I ran it. Nowhere near desirable, but lower complexity than O(n^2). (50,000 was around 3 seconds, but I think you meant 10^5, the question's limit.)
What is a "worst-case" input? An ordered sequence does not appear to be worst-case because there is only one sequence of GCD 1 for adjacent differences. I get about 15 seconds for 500,000 (10^5).
There should be a way to avoid some, maybe much of the iteration in the range expansion steps by direct lookup.
