Commit c02b1929 authored by Boris Mühmer's avatar Boris Mühmer
Browse files

debugging does work, but it has a race condition that fails in non-debug run

parent a6753820
Loading
Loading
Loading
Loading
+96 −15
Original line number Diff line number Diff line
@@ -2,8 +2,9 @@ package get

import (
	"bytes"
	"fmt"
	"io"
	"strings"
	"sync"
)

type Channel uint
@@ -14,28 +15,64 @@ const (
	Err Channel = 3
)

type Streams struct {
	match chan Match
	done  chan bool
}

type Expect struct {
	stdin  bytes.Buffer
	stdout bytes.Buffer
	stderr bytes.Buffer
	matches []Match
	wg     sync.WaitGroup
	ss     map[Channel]Streams
}

type Match struct {
	c Channel
	p string
	v string
}

func New() (*Expect, error) {
	matchOutStream := make(chan Match)
	startedOutStream := make(chan bool)
	doneOutStream := make(chan bool)
	ssOut := Streams{
		match: matchOutStream,
		done:  doneOutStream,
	}

	matchErrStream := make(chan Match)
	startedErrStream := make(chan bool)
	doneErrStream := make(chan bool)
	ssErr := Streams{
		match: matchErrStream,
		done:  doneErrStream,
	}

	ss := make(map[Channel]Streams)
	ss[Out] = ssOut
	ss[Err] = ssErr

	x := &Expect{
		matches: []Match{},
		ss: ss,
	}

	x.wg.Add(2)
	go process(&x.wg, &x.stdout, &x.stdin, startedOutStream, matchOutStream, doneOutStream)
	go process(&x.wg, &x.stderr, &x.stdin, startedErrStream, matchErrStream, doneErrStream)

	<-startedOutStream
	<-startedErrStream

	return x, nil
}

func (x *Expect) Close() error {
	for _, b := range x.ss {
		b.done <- true
	}
	x.wg.Wait()
	return nil
}

@@ -52,14 +89,58 @@ func (x *Expect) Stderr() io.Writer {
}

func (x *Expect) Wait(c Channel, match, text string) error {
	x.matches = append(x.matches, Match{
		c: c,
		p: match,
		v: text,
	})
	x.ss[c].match <- Match{p: match, v: text}
	return nil
}

	// dummy implementation
	x.stdin.WriteString(fmt.Sprintln(text))
func process(
	wg *sync.WaitGroup,
	in io.Reader,
	out io.Writer,
	startedStream chan<- bool,
	matchStream <-chan Match,
	doneStream <-chan bool,
) {
	defer wg.Done()
	matches := []Match{}

	return nil
	startedStream <- true

	for {
		select {
		case match, valid := <-matchStream:
			if !valid {
				continue
			}
			matches = append(matches, match)
		case done, valid := <-doneStream:
			if !valid {
				continue
			}
			if !done {
				continue
			}
			return
		default:
			bytes, err := io.ReadAll(in)
			switch err {
			case nil:
				if len(bytes) == 0 {
					continue
				}
				for _, m := range matches {
					s := string(bytes)
					if !strings.Contains(s, m.p) {
						continue
					}
					io.WriteString(out, m.v)
					break
				}
			//case io.EOF:
			//	return
			default:
				continue
			}
		}
	}
}
+53 −2
Original line number Diff line number Diff line
package get

import (
	"bytes"
	"io"
	"testing"

	"repositories.muehmer.net/bsmrgo/get/mock/passwd"
@@ -9,6 +11,8 @@ import (
const (
	lp1User = "user1"
	lp1Pass = "pass1"
	lp2User = "user2"
	lp2Pass = "pass2"
	lpNPass = "#S3cr3T!"
	lpXPass = "wr0ng"
)
@@ -20,11 +24,15 @@ func TestPasswdChange(t *testing.T) {
	}
	defer x.Close()

	x.Wait(Err, "Current password", lp1Pass)
	//x.Wait(Out, "Current password", lp1Pass)
	//x.Wait(Out, "New password", lpNPass)
	//x.Wait(Out, "Retype new password", lpNPass)

	x.Wait(Err, "Current password", lp2Pass)
	x.Wait(Err, "New password", lpNPass)
	x.Wait(Err, "Retype new password", lpNPass)

	p, err := passwd.New(x.Stdin(), x.Stdout(), x.Stderr(), lp1User)
	p, err := passwd.New(x.Stdin(), x.Stdout(), x.Stderr(), lp2User)
	if err != nil {
		t.Fatalf("New() failed with: %s", err)
	}
@@ -33,3 +41,46 @@ func TestPasswdChange(t *testing.T) {
		t.Fatalf("p.Run() failed with: %s", err)
	}
}

func TestReads(t *testing.T) {
	for _, c := range []struct {
		n string
		f func() io.Reader
	}{
		{n: "empty reader", f: func() io.Reader {
			return &bytes.Buffer{}
		}},
		{n: "one char", f: func() io.Reader {
			var bb bytes.Buffer
			bb.WriteString("x")
			return &bb
		}},
		{n: "hello w/o nl", f: func() io.Reader {
			var bb bytes.Buffer
			bb.WriteString("Hello, World!")
			return &bb
		}},
		{n: "hello with nl", f: func() io.Reader {
			var bb bytes.Buffer
			bb.WriteString("Hello, World!\n")
			return &bb
		}},
		{n: "multiline 1", f: func() io.Reader {
			var bb bytes.Buffer
			bb.WriteString("Line 1\nLine 2\nLine 3")
			return &bb
		}},
		{n: "multiline 2", f: func() io.Reader {
			var bb bytes.Buffer
			bb.WriteString("Line 1\nLine 2\nLine 3\n")
			return &bb
		}},
	} {
		bs, err := io.ReadAll(c.f())
		t.Logf("test: %s", c.n)
		t.Logf(" - len(bs): %d", len(bs))
		t.Logf(" - bs: %#v", bs)
		t.Logf(" - err: %s", err)
	}

}