Fix false positive in ContainsAll function
As the `ContainsAll` is working with a match counter, it could return a false positive when the `haystack` slice contains duplicate elements. This can be checked with the included testing scenario, with `haystack = [1, 1]` and `needles = [1, 2]`. Iterating over the haystack to check for items to be present in needles will increase the match counter to 2, even if `2` is not present in the first slice.
This commit is contained in:
		
							parent
							
								
									f4e6874ff0
								
							
						
					
					
						commit
						ebd4367dda
					
				
					 2 changed files with 13 additions and 9 deletions
				
			
		
							
								
								
									
										12
									
								
								util/util.go
									
										
									
									
									
								
							
							
						
						
									
										12
									
								
								util/util.go
									
										
									
									
									
								
							|  | @ -6,7 +6,6 @@ import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"golang.org/x/time/rate" |  | ||||||
| 	"io" | 	"io" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
|  | @ -17,6 +16,8 @@ import ( | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"golang.org/x/time/rate" | ||||||
|  | 
 | ||||||
| 	"github.com/gabriel-vasile/mimetype" | 	"github.com/gabriel-vasile/mimetype" | ||||||
| 	"golang.org/x/term" | 	"golang.org/x/term" | ||||||
| ) | ) | ||||||
|  | @ -67,15 +68,12 @@ func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool { | ||||||
| 
 | 
 | ||||||
| // ContainsAll returns true if all needles are contained in haystack | // ContainsAll returns true if all needles are contained in haystack | ||||||
| func ContainsAll[T comparable](haystack []T, needles []T) bool { | func ContainsAll[T comparable](haystack []T, needles []T) bool { | ||||||
| 	matches := 0 |  | ||||||
| 	for _, s := range haystack { |  | ||||||
| 	for _, needle := range needles { | 	for _, needle := range needles { | ||||||
| 			if s == needle { | 		if !Contains(haystack, needle) { | ||||||
| 				matches++ | 			return false | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	} | 	return true | ||||||
| 	return matches == len(needles) |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SplitNoEmpty splits a string using strings.Split, but filters out empty strings | // SplitNoEmpty splits a string using strings.Split, but filters out empty strings | ||||||
|  |  | ||||||
|  | @ -2,7 +2,6 @@ package util | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"golang.org/x/time/rate" |  | ||||||
| 	"io" | 	"io" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
| 	"os" | 	"os" | ||||||
|  | @ -11,6 +10,8 @@ import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"golang.org/x/time/rate" | ||||||
|  | 
 | ||||||
| 	"github.com/stretchr/testify/require" | 	"github.com/stretchr/testify/require" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -49,6 +50,11 @@ func TestContains(t *testing.T) { | ||||||
| 	require.False(t, Contains(s, 3)) | 	require.False(t, Contains(s, 3)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestContainsAll(t *testing.T) { | ||||||
|  | 	require.True(t, ContainsAll([]int{1, 2, 3}, []int{2, 3})) | ||||||
|  | 	require.False(t, ContainsAll([]int{1, 1}, []int{1, 2})) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestContainsIP(t *testing.T) { | func TestContainsIP(t *testing.T) { | ||||||
| 	require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.1.1.1"))) | 	require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.1.1.1"))) | ||||||
| 	require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fd12:1234:5678::9876"))) | 	require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fd12:1234:5678::9876"))) | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue