diff --git a/internal/flags/safe_relative_path.go b/internal/flags/safe_relative_path.go index 4a5f78df..dfd3b9ea 100644 --- a/internal/flags/safe_relative_path.go +++ b/internal/flags/safe_relative_path.go @@ -41,7 +41,10 @@ func (p *SafeRelativePath) Set(str string) error { return fmt.Errorf("invalid relative path '%s': %w", cleanP, err) } // NB: required, as a secure join of "./" will result in "." - cleanP = fmt.Sprintf("./%s", strings.TrimPrefix(cleanP, ".")) + if cleanP == "." { + cleanP = "" + } + cleanP = fmt.Sprintf("./%s", cleanP) *p = SafeRelativePath(cleanP) return nil } diff --git a/internal/flags/safe_relative_path_test.go b/internal/flags/safe_relative_path_test.go index 325bd658..5c74ca28 100644 --- a/internal/flags/safe_relative_path_test.go +++ b/internal/flags/safe_relative_path_test.go @@ -37,6 +37,10 @@ func TestRelativePath_Set(t *testing.T) { {"traversing absolute path", "/foo/../bar", "./bar", false}, {"traversing overflowing absolute path", "/foo/../../../bar", "./bar", false}, {"empty", "", "./", false}, + {"relative empty path", "./", "./", false}, + {"double relative empty path", "././", "./", false}, + {"dot path", ".foo", "./.foo", false}, + {"relative dot path", "./.foo", "./.foo", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {