TLE on USACO "Subset Equality"

For Silver Problem 2 on the 2022 USACO Open (http://www.usaco.org/index.php?page=viewproblem2&cpid=1231), my solution times out even though the number of operations it does seem to fit within the time limit.

I was testing my solution, and I found that this piece of code is what causes it to time out:

for (int i=0; i<18; i++){
 for (int j=i+1; j<18; j++){
  if (check(i, j)) eq[(1<<i)+(1<<j)]=1;
  else eq[(1<<i)+(1<<j)]=0;
 }
}

where I defined the function check as :

 bool check(int i, int j){
  if (posS[i].size()!=posT[i].size() || posS[j].size()!=posT[j].size()) return false;
  string s1="", t1="";
  int pi=0, pj=0;
  for (int k=0; k<posS[i].size()+posT[j].size(); k++){
    if (pi<posS[i].size() && pj<posS[j].size()){
      if (posS[i][pi]<posS[j][pj]) s1=s1+let[i], pi++;
      else s1=s1+let[j], pj++;
    }
    else if (pi>=posS[i].size()) s1=s1+let[j], pj++;
    else s1=s1+let[i], pi++;
  }
  pi=0, pj=0;
  for (int k=0; k<posT[i].size()+posT[j].size(); k++){
    if (pi<posT[i].size() && pj<posT[j].size()){
      if (posT[i][pi]<posT[j][pj]) t1=t1+let[i], pi++;
      else t1=t1+let[j], pj++;
    }
    else if (pi>=posT[i].size()) t1=t1+let[j], pj++;
    else t1=t1+let[i], pi++;
  }
  return(s1==t1);
}

Basically, at the beginning I iterated through both strings s and t once, and kept track of the positions of each letter (with let[0]=β€˜a’, let[1]=β€˜b’, …, let[17]=β€˜r’). This takes O(|s|+|t|).

Then, I iterated through all pairs of letters: the check function uses two-pointers to reconstruct s, t when restricted to that pair of letters. But since there are 153 pairs, doing the check function on each pair should take a total of at most O(153*(|s|+|t|)) since there are that many pairs (and actually it should take a lot less time than that due to amortization). Could someone please help me figure out why the solution is timing out?

(The full code is below if anyone needs it)

#include <iostream>
#include <bits/stdc++.h>
#include <fstream>
#include <string>
#include <cstdio>
#include <cstring>
#include <algorithm>  
#include <math.h>
#include <numeric>
using namespace std;
using ll=long long;

int eq[1<<18];
vector<int> posS[18], posT[18];
char let[18]={'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r'};

bool check(int i, int j){
  if (posS[i].size()!=posT[i].size() || posS[j].size()!=posT[j].size()) return false;
  string s1="", t1="";
  int pi=0, pj=0;
  for (int k=0; k<posS[i].size()+posT[j].size(); k++){
    if (pi<posS[i].size() && pj<posS[j].size()){
      if (posS[i][pi]<posS[j][pj]) s1=s1+let[i], pi++;
      else s1=s1+let[j], pj++;
    }
    else if (pi>=posS[i].size()) s1=s1+let[j], pj++;
    else s1=s1+let[i], pi++;
  }
  pi=0, pj=0;
  for (int k=0; k<posT[i].size()+posT[j].size(); k++){
    if (pi<posT[i].size() && pj<posT[j].size()){
      if (posT[i][pi]<posT[j][pj]) t1=t1+let[i], pi++;
      else t1=t1+let[j], pj++;
    }
    else if (pi>=posT[i].size()) t1=t1+let[j], pj++;
    else t1=t1+let[i], pi++;
  }
  // cout<<s1<<" "<<t1<<"\n";
  return(s1==t1);
}

int main() {
  ios_base::sync_with_stdio(false); 
  cin.tie(0);
  // ifstream cin("file.in");
  string s, t;
  int Q;
  cin>>s>>t>>Q;
  eq[0]=0;
  
  for (int i=0; i<s.size(); i++) posS[(int)(s[i]-'a')].push_back(i);
  for (int i=0; i<t.size(); i++) posT[(int)(t[i]-'a')].push_back(i);
  for (int i=0; i<18; i++){
    if (posS[i].size()==posT[i].size()) eq[1<<i]=1;
    else eq[1<<i]=0;
  }
  for (int i=0; i<18; i++){
    for (int j=i+1; j<18; j++){
      if (check(i, j)) eq[(1<<i)+(1<<j)]=1;
      else eq[(1<<i)+(1<<j)]=0;
    }
  }
  vector<int> subs[19];
  for (int mask=0; mask<(1<<18); mask++){
    vector<int> v;
    for (int i=0; i<18; i++){
      if (mask>>i&1) v.push_back(i);
    }
    subs[v.size()].push_back(mask);
  }
  for (int i=3; i<=18; i++){
    for (int mask: subs[i]){
      vector<int> v;
      for (int j=0; j<18; j++){
        if (mask>>j&1) v.push_back(j);
      }
      int works=1;
      for (int i: v){
        if (eq[mask-(1<<i)]==0) works=0;
      }
      eq[mask]=works;
    }
  }
  for (int i=0; i<Q; i++){
    string s0;
    cin>>s0;
    int mask=0;
    for (char c: s0) mask+=(1<<(int)(c-'a'));
    if (eq[mask]==1) cout<<"Y";
    else cout<<"N";
  }
}

Repeated string concatenation is \mathcal{O}(N^2).

Here’s another helpful link:

1 Like

Ohhhhh that makes sense, before I always thought that string concatenation was O(1) per operation. I replaced the string concatenation with comparing two char arrays and it works now. Thank you!