c++ floating-point integer comparison

Is it safe to assume 32-bit floats can be directly compared against each other if value fits the mantissa?


In a LeetCode problem about finding the least number of perfect squares that sum to an integer (the "perfect-squares" problem), using floats instead of ints made the code faster. If the integer values are guaranteed to be greater than or equal to 0 and less than 10000, is it safe to use floats instead?

Sample comparison:

if(n == i*i + j*j * 2)
    result3++;

if(n == i*i + k*k)
    result2++;

Both the int and the float versions passed all tests (with n, i, j, k all float or all int), but I'm still not sure whether the result could differ between CPUs (I don't know if LeetCode always uses the same one), compilers, or something else (like time?).

Link to problem and code: https://leetcode.com/problems/perfect-squares/description/

Code:

#include<iostream>
#include<cmath>
class Solution {
public:

static constexpr int simd =8;
using FAST_TYPE = short;
using MASK_TYPE = short;
    const int numSquares(const int n) const noexcept {
        if(n==2 || n==8)
            return 2;
        if(n==3 || n==6 || n==11)
            return 3;
        
        if((int)std::sqrt(n)*(int)std::sqrt(n) == n)
            return 1;

        FAST_TYPE found2 = 0;
        FAST_TYPE found3 = 0;
        FAST_TYPE found32 = 0;
        FAST_TYPE found33 = 0;
        FAST_TYPE found34 = 0;


        alignas(64)
        FAST_TYPE zeroSimd[simd];
        alignas(64)
        FAST_TYPE oneSimd[simd];
        alignas(64)
        FAST_TYPE found3Simd[simd];
        alignas(64)
        FAST_TYPE found3Simd2[simd];
        alignas(64)
        FAST_TYPE found3Simd3[simd];
        alignas(64)
        FAST_TYPE found3Simd4[simd];                
        alignas(64)
        FAST_TYPE mSimd[simd];
        alignas(64)
        FAST_TYPE kSimd[simd];
        alignas(64)
        FAST_TYPE k0Simd[simd];
        alignas(64)
        FAST_TYPE nSimd[simd];
        alignas(64)
        FAST_TYPE twoSimd[simd];
        alignas(64)
        FAST_TYPE threeSimd[simd];   
        alignas(64)
        FAST_TYPE iSimd[simd];     
        alignas(64)
        FAST_TYPE jSimd[simd];
        alignas(64)
        FAST_TYPE ijSimd[simd];           
        alignas(64)
        FAST_TYPE j2Simd[simd];        
        alignas(64)
        FAST_TYPE i2Simd[simd];      
        alignas(64)
        MASK_TYPE mask1Simd[simd];                    
        alignas(64)
        MASK_TYPE mask2Simd[simd];                    
        alignas(64)
        MASK_TYPE mask3Simd[simd];                    
        alignas(64)
        MASK_TYPE mask4Simd[simd];         
        alignas(64)
        FAST_TYPE sum1Simd[simd];                                              
        alignas(64)
        FAST_TYPE sum2Simd[simd];                                                      
        alignas(64)
        FAST_TYPE sum3Simd[simd];       
        alignas(64)
        FAST_TYPE mulSimd[simd];                                                         
        for(int i=0;i<simd;i++)
        {
            zeroSimd[i]=0;
            oneSimd[i]=1;
            found3Simd[i]=0;
            found3Simd2[i]=0;
            found3Simd3[i]=0;
            found3Simd4[i]=0;
            mSimd[i]=i;
            nSimd[i]=n;
            twoSimd[i]=2;
            threeSimd[i]=3;
            
        }
        for(int i=1+std::sqrt(n);i>=1;i--)
        {
            const FAST_TYPE i2 = i*i;
            const FAST_TYPE i22 = 2*i*i;            
            const FAST_TYPE i23 = 3*i*i;            
            #pragma GCC ivdep
            for(int m=0;m<simd;m++)
                iSimd[m]=i2;
            #pragma GCC ivdep
            for(int m=0;m<simd;m++)
                i2Simd[m]=i22;                
            found2 += (i22 == n);            
            found3+=(i23 == n);   
            for(int j=i-1;j>=1;j--)
            {
                const FAST_TYPE j2 = j*j;
                const FAST_TYPE j22 = 2*j*j;
                const FAST_TYPE j23 = 3*j*j;
                #pragma GCC ivdep
                for(int m=0;m<simd;m++)
                    jSimd[m]=j2;
                #pragma GCC ivdep
                for(int m=0;m<simd;m++)
                    j2Simd[m]=j22;       
                #pragma GCC ivdep
                for(int m=0;m<simd;m++)
                    ijSimd[m]=i2+j2;                                      
                found2+=(i2 + j2 == n);
                found3+=(i2 + j22 == n)+(i22 + j2 == n)+(j23 == n);        
                const int k32 = j-1 - ((j-1)%simd);  
                #pragma GCC unroll 2
                for(int k0=1;k0<=k32;k0+=simd) 
                {
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)
                        k0Simd[m]=k0;
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                 
                        kSimd[m] = k0Simd[m]+mSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                 
                        kSimd[m] = kSimd[m]*kSimd[m];


                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++) 
                        sum1Simd[m]=ijSimd[m] + kSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                             
                        mask1Simd[m]=sum1Simd[m] == nSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                  
                        found3Simd[m]=mask1Simd[m]?oneSimd[m]:found3Simd[m];


                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++) 
                        sum2Simd[m]=i2Simd[m] + kSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                             
                        mask2Simd[m]=(sum2Simd[m]==nSimd[m]);
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                  
                        found3Simd2[m]=mask2Simd[m]?oneSimd[m]:found3Simd2[m];

                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++) 
                        sum3Simd[m]=j2Simd[m] + kSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                             
                        mask3Simd[m]=(sum3Simd[m]==nSimd[m]);                        
                     #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                      
                        found3Simd3[m]=mask3Simd[m]?oneSimd[m]:found3Simd3[m];     


                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++) 
                        mulSimd[m]=threeSimd[m]*kSimd[m];
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)       
                        mask4Simd[m]=(mulSimd[m]==nSimd[m]);                    
                    #pragma GCC ivdep
                    for(int m=0;m<simd;m++)                                                                                                                                                                                                    
                        found3Simd4[m]=mask4Simd[m]?oneSimd[m]:found3Simd4[m];
                    
                }

                
                for(int k=k32;k<=j-1;k++)   
                {                   
                    const FAST_TYPE k2 = k*k;
                    found3+=(i2 + j2 + k2 ==n);
                    found32+=(i22 + k2 ==n);
                    found33+=(j22 + k2 ==n);
                    found34+=(3*k2 ==n);
                }
            }      
        }

        for(int i=0;i<simd;i++)
        {
            found3+=found3Simd[i];
            found32+=found3Simd2[i];
            found33+=found3Simd3[i];
            found34+=found3Simd4[i];
        }
        found3 += found32 + found33 + found34;
        if(found2)
            return 2;

        if(found3)
            return 3;

        return 4;
    }

};

int main()
{
    Solution s;
    for(int i=10;i<20;i++)
    {
        std::cout<<i<<" is equal to sum of "<<s.numSquares(i)<< " perfect squares"<<std::endl; 
    }
}

output:

10 is equal to sum of 2 perfect squares
11 is equal to sum of 3 perfect squares
12 is equal to sum of 3 perfect squares
13 is equal to sum of 2 perfect squares
14 is equal to sum of 3 perfect squares
15 is equal to sum of 4 perfect squares
16 is equal to sum of 1 perfect squares
17 is equal to sum of 2 perfect squares
18 is equal to sum of 2 perfect squares
19 is equal to sum of 3 perfect squares

Solution

  • Is it safe to assume 32-bit floats can be directly compared against each other if value fits the mantissa?

    This is not the right question; floating-point numbers can always be compared with each other, and the comparison will indicate they are equal if and only if they are equal.

    The right question is whether using floating-point arithmetic will calculate the results you desire.

    Per IEEE-754 and other floating-point specifications, each floating-point number represents one real number. That representation is exact, with no error. It is floating-point operations, not numbers, that approximate real arithmetic. When a floating-point operation is performed, the floating-point result is the real-number-arithmetic result rounded to the nearest value representable in the floating-point format (using a choice of rounding rules). If the real-number-arithmetic result is representable, it is the result; there will be no rounding.
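To make the "operations round, numbers do not" point concrete, here is a minimal sketch. It assumes float is IEEE-754 binary32 and double is binary64 (common, but not required by the C++ standard); addOneIsExact is a hypothetical helper name, not something from the question.

```cpp
#include <cassert>
#include <cmath>

// Returns true when the float addition x + 1.0f produced the exact
// real-number sum, i.e. no rounding took place.
// Assumption: float is IEEE-754 binary32 and double is binary64.
bool addOneIsExact(float x) {
    assert(std::isfinite(x));       // NaN and infinities are out of scope here
    volatile float sum = x + 1.0f;  // volatile: store the result as a real float
    const float stored = sum;
    // For the magnitudes of interest here, double holds both sides exactly,
    // so this comparison reveals whether the float addition rounded.
    return static_cast<double>(stored) == static_cast<double>(x) + 1.0;
}
```

Under those assumptions, addOneIsExact(9999.0f) should be true because 10000 is representable, while addOneIsExact(16777216.0f) should be false: 2^24 + 1 needs 25 significant bits, so the sum is rounded back to 2^24.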

    The C++ standard guarantees float can represent numbers with enough resolution to distinguish six-significant-digit decimal numerals throughout its range. Therefore, it can represent all integers up to 1,000,000 at least, more than the 10,000 range you requested.

    The operations you perform using integer values, such as 3*j*j, will be exact within this range.
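If you would rather check this than take it on trust, a brute-force comparison is cheap. The sketch below evaluates the expressions from the question in both float and int and counts disagreements; countFloatIntMismatches and its limit parameter are hypothetical names introduced for illustration.

```cpp
#include <cassert>

// Counts the (i, j) pairs in [0, limit] x [0, limit] for which evaluating
// i*i + j*j*2 or 3*j*j in float differs from evaluating it in int.
int countFloatIntMismatches(int limit) {
    // Keeps 3*limit*limit below 1,000,000, the range discussed above.
    assert(limit >= 0 && limit <= 500);
    int mismatches = 0;
    for (int i = 0; i <= limit; ++i) {
        for (int j = 0; j <= limit; ++j) {
            const float fi = static_cast<float>(i);
            const float fj = static_cast<float>(j);
            const float sumF = fi * fi + fj * fj * 2;  // float arithmetic
            const float tripleF = 3 * fj * fj;         // float arithmetic
            const int sumI = i * i + j * j * 2;        // int arithmetic
            const int tripleI = 3 * j * j;             // int arithmetic
            if (sumF != static_cast<float>(sumI) ||
                tripleF != static_cast<float>(tripleI)) {
                ++mismatches;
            }
        }
    }
    return mismatches;
}
```

A limit of 101 covers every i and j the question's loops can reach for n below 10000; a return value of 0 means float and int agreed everywhere.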

    This operation will not always be exact: (int)std::sqrt(n)*(int)std::sqrt(n).

    When n is not a square, it is of course impossible for sqrt to produce an exact result. However, that is fine for your purpose, as the return value will be truncated, and the computed product will not equal n, so the comparison (int)std::sqrt(n)*(int)std::sqrt(n) == n will evaluate to false, as desired.

    Where you have a problem is that some implementations of sqrt may not return an integer even when n is an integer square.

    If you have a good sqrt implementation, it will return an exact result, and the comparison of the product with n will evaluate to true, as desired. If you have a bad sqrt implementation, and it returns a value slightly less than the proper square root of n, then the comparison will evaluate to false, causing a problem in your program.
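One way to stop depending on the quality of sqrt is to round its result to the nearest integer instead of truncating it, and then verify the candidate with integer arithmetic. This is a different technique from the cast used in the question, shown only as a sketch; isPerfectSquare is a hypothetical helper name.

```cpp
#include <cassert>
#include <cmath>

// Perfect-square test that tolerates a sqrt result slightly below or
// slightly above the true root: round to nearest, then check in integers.
bool isPerfectSquare(int n) {
    if (n < 0) return false;
    const long long r = std::llround(std::sqrt(static_cast<double>(n)));
    assert(r >= 0);     // sanity check: the root of a non-negative value
    return r * r == n;  // long long product cannot overflow for int inputs
}
```

With truncation, a result such as 99.9999999 for n = 10000 becomes 99 and the test fails; with rounding it becomes 100 and the integer check confirms it.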

    Since you are only interested in the range up to 10,000, you can easily test sqrt on all squares in this range to see if it performs as desired.
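A minimal sketch of that exhaustive test follows. It applies the same truncating comparison as the question to every perfect square up to maxRoot*maxRoot; countSqrtFailures and maxRoot are hypothetical names.

```cpp
#include <cassert>
#include <cmath>

// Returns how many r in [0, maxRoot] fail the question's test
// (int)std::sqrt(n) * (int)std::sqrt(n) == n for the perfect square n = r*r.
int countSqrtFailures(int maxRoot) {
    assert(maxRoot >= 0 && maxRoot <= 46340);  // keeps r*r within a 32-bit int
    int failures = 0;
    for (int r = 0; r <= maxRoot; ++r) {
        const int n = r * r;
        const int root = static_cast<int>(std::sqrt(n));  // truncates, as in the question
        if (root * root != n) ++failures;
    }
    return failures;
}
```

Calling countSqrtFailures(100) covers all squares up to 10,000. A result of 0 on the target platform means the sqrt there behaves as desired for this range; it says nothing about other platforms or standard libraries.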