Skip to content

Commit c809dee

Browse files
authored
refactor: clean up code after Plutus versioning (#156)
- Fix bounds checking in constructor data builtin - Remove unused return values from encoder methods - Remove unused parameter from pretty printing - Simplify type inference in encoding calls Signed-off-by: Chris Gianelloni <wolf31o2@blinklabs.io>
1 parent 0c05257 commit c809dee

File tree

8 files changed

+52
-31
lines changed

8 files changed

+52
-31
lines changed

cek/builtins.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,14 +1551,18 @@ func constrData[T syn.Eval](m *Machine[T], b *Builtin[T]) (Value[T], error) {
15511551
dataList = append(dataList, itemData.Inner)
15521552
}
15531553

1554-
tag := arg1.Uint64()
1555-
if tag > math.MaxUint {
1554+
if arg1.BitLen() > 64 {
15561555
return nil, errors.New("constructor tag too large")
15571556
}
1557+
tag64 := arg1.Uint64()
1558+
if tag64 > uint64(math.MaxUint) {
1559+
return nil, errors.New("constructor tag too large")
1560+
}
1561+
tag := uint(tag64)
15581562

15591563
value := &Constant{&syn.Data{
15601564
Inner: &data.Constr{
1561-
Tag: uint(tag),
1565+
Tag: tag,
15621566
Fields: dataList,
15631567
},
15641568
}}

cek/machine.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ type Machine[T syn.Eval] struct {
2020
unbudgetedSteps [10]uint32
2121
}
2222

23-
func NewMachine[T syn.Eval](version [3]uint32, slippage uint32, costs ...CostModel) *Machine[T] {
23+
func NewMachine[T syn.Eval](
24+
version [3]uint32,
25+
slippage uint32,
26+
costs ...CostModel,
27+
) *Machine[T] {
2428
var costModel CostModel
2529
if len(costs) > 0 {
2630
costModel = costs[0]
@@ -41,7 +45,10 @@ func NewMachine[T syn.Eval](version [3]uint32, slippage uint32, costs ...CostMod
4145
}
4246

4347
// NewMachineWithVersionCosts creates a machine with version-appropriate cost models
44-
func NewMachineWithVersionCosts[T syn.Eval](version [3]uint32, slippage uint32) *Machine[T] {
48+
func NewMachineWithVersionCosts[T syn.Eval](
49+
version [3]uint32,
50+
slippage uint32,
51+
) *Machine[T] {
4552
costModel := GetCostModel(version)
4653
return NewMachine[T](version, slippage, costModel)
4754
}

cmd/play/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func main() {
9191
log.Fatalf("parse error: %v\n\n", err)
9292
}
9393

94-
prettyProgram := syn.Pretty[syn.Name](program)
94+
prettyProgram := syn.Pretty(program)
9595

9696
_ = os.WriteFile(filename, []byte(prettyProgram), 0o600)
9797

syn/encode_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func TestEncodeDecodeConstant(t *testing.T) {
1313
term := &Constant{Con: constant}
1414

1515
// Encode the term
16-
encoded, err := Encode[DeBruijn](&Program[DeBruijn]{
16+
encoded, err := Encode(&Program[DeBruijn]{
1717
Version: [3]uint32{1, 0, 0},
1818
Term: term,
1919
})
@@ -63,7 +63,7 @@ func TestEncodeDecodeBuiltin(t *testing.T) {
6363
original := &Builtin{DefaultFunction: tt.fn}
6464

6565
// Encode the builtin term
66-
encoded, err := Encode[DeBruijn](&Program[DeBruijn]{
66+
encoded, err := Encode(&Program[DeBruijn]{
6767
Version: [3]uint32{1, 0, 0},
6868
Term: original,
6969
})
@@ -117,11 +117,11 @@ func TestEncodeDecodeConstantTerm(t *testing.T) {
117117
constant: &Unit{},
118118
},
119119
{
120-
name: "constant_bool_true",
120+
name: "constant_bool_true",
121121
constant: &Bool{Inner: true},
122122
},
123123
{
124-
name: "constant_bool_false",
124+
name: "constant_bool_false",
125125
constant: &Bool{Inner: false},
126126
},
127127
}

syn/flat_decode.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ func Decode[T Binder](bytes []byte) (*Program[T], error) {
3535
return nil, err
3636
}
3737

38-
if major > math.MaxUint32 || minor > math.MaxUint32 || patch > math.MaxUint32 {
38+
if major > math.MaxUint32 || minor > math.MaxUint32 ||
39+
patch > math.MaxUint32 {
3940
return nil, errors.New("version numbers too large")
4041
}
4142

@@ -535,8 +536,19 @@ func (d *decoder) dropBits(numBits uint) {
535536
d.pos += int(allUsedBits / 8)
536537
}
537538

539+
// Ensures the buffer has the required bits passed in by required_bits.
540+
// Throws a NotEnoughBits error if there are less bits remaining in the
541+
// buffer than requiredBits.
542+
// Throws a BitsOverflow error if bits is more than MaxInt64
538543
func (d *decoder) ensureBits(requiredBits uint) error {
539-
if int64(requiredBits) > int64((len(d.buffer)-d.pos)*8)-d.usedBits { //nolint:gosec
544+
if requiredBits > math.MaxInt64 {
545+
return fmt.Errorf("BitsOverflow(%d)", requiredBits)
546+
}
547+
if int64(
548+
requiredBits,
549+
) > int64(
550+
(len(d.buffer)-d.pos)*8,
551+
)-d.usedBits { //nolint:gosec
540552
return fmt.Errorf("NotEnoughBits(%d)", requiredBits)
541553
} else {
542554
return nil

syn/flat_encode.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ func (e *encoder) word(c uint) *encoder {
230230
// number of bits to use is greater than unused bits. Expects that
231231
// number of bits to use is greater than or equal to required bits by the
232232
// value. The param num_bits is i64 to match unused_bits type.
233-
func (e *encoder) bits(numBits byte, val byte) *encoder {
233+
func (e *encoder) bits(numBits byte, val byte) {
234234
if numBits == 1 && val == 0 {
235235
e.zero()
236236
} else if numBits == 1 && val == 1 {
@@ -264,8 +264,6 @@ func (e *encoder) bits(numBits byte, val byte) *encoder {
264264
e.usedBits = used
265265
}
266266
}
267-
268-
return e
269267
}
270268

271269
// A filler amount of end 0's followed by a 1 at the end of a byte.
@@ -306,13 +304,11 @@ func (e *encoder) one() *encoder {
306304
// Write the current byte out to the buffer and begin next byte to write
307305
// out. Add current byte to the buffer and set current byte and used
308306
// bits to 0.
309-
func (e *encoder) nextWord() *encoder {
307+
func (e *encoder) nextWord() {
310308
e.buffer = append(e.buffer, e.currentByte)
311309

312310
e.currentByte = 0
313311
e.usedBits = 0
314-
315-
return e
316312
}
317313

318314
// Encode a string.

syn/pretty.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ func (pp *PrettyPrinter) decreaseIndent() {
6060

6161
// PrettyPrintTerm formats a Term[Name] to a string
6262
func prettyPrintTerm[T Binder](pp *PrettyPrinter, term Term[T]) string {
63-
printTerm[T](pp, term, true)
64-
63+
printTerm[T](pp, term)
6564
return pp.builder.String()
6665
}
6766

@@ -89,7 +88,7 @@ func printProgram[T Binder](pp *PrettyPrinter, prog *Program[T]) {
8988
pp.increaseIndent()
9089
pp.writeIndent()
9190

92-
printTerm[T](pp, prog.Term, false)
91+
printTerm[T](pp, prog.Term)
9392

9493
pp.decreaseIndent()
9594
pp.write("\n")
@@ -100,7 +99,7 @@ func printProgram[T Binder](pp *PrettyPrinter, prog *Program[T]) {
10099
}
101100

102101
// printTerm dispatches to the appropriate term printing method
103-
func printTerm[T Binder](pp *PrettyPrinter, term Term[T], isTopLevel bool) {
102+
func printTerm[T Binder](pp *PrettyPrinter, term Term[T]) {
104103
switch t := term.(type) {
105104
case *Var[T]:
106105
pp.write(t.Name.TextName())
@@ -113,7 +112,7 @@ func printTerm[T Binder](pp *PrettyPrinter, term Term[T], isTopLevel bool) {
113112
pp.increaseIndent()
114113
pp.writeIndent()
115114

116-
printTerm[T](pp, t.Body, false)
115+
printTerm[T](pp, t.Body)
117116

118117
pp.decreaseIndent()
119118
pp.write("\n")
@@ -127,7 +126,7 @@ func printTerm[T Binder](pp *PrettyPrinter, term Term[T], isTopLevel bool) {
127126
pp.increaseIndent()
128127
pp.writeIndent()
129128

130-
printTerm[T](pp, t.Term, false)
129+
printTerm[T](pp, t.Term)
131130

132131
pp.decreaseIndent()
133132
pp.write("\n")
@@ -141,7 +140,7 @@ func printTerm[T Binder](pp *PrettyPrinter, term Term[T], isTopLevel bool) {
141140
pp.increaseIndent()
142141
pp.writeIndent()
143142

144-
printTerm[T](pp, t.Term, false)
143+
printTerm[T](pp, t.Term)
145144

146145
pp.decreaseIndent()
147146
pp.write("\n")
@@ -155,12 +154,12 @@ func printTerm[T Binder](pp *PrettyPrinter, term Term[T], isTopLevel bool) {
155154
pp.increaseIndent()
156155
pp.writeIndent()
157156

158-
printTerm[T](pp, t.Function, false)
157+
printTerm[T](pp, t.Function)
159158

160159
pp.write("\n")
161160
pp.writeIndent()
162161

163-
printTerm[T](pp, t.Argument, false)
162+
printTerm[T](pp, t.Argument)
164163

165164
pp.decreaseIndent()
166165
pp.write("\n")
@@ -183,7 +182,7 @@ func printTerm[T Binder](pp *PrettyPrinter, term Term[T], isTopLevel bool) {
183182
for _, field := range t.Fields {
184183
pp.writeIndent()
185184

186-
printTerm[T](pp, field, false)
185+
printTerm[T](pp, field)
187186

188187
pp.write("\n")
189188
}
@@ -199,7 +198,7 @@ func printTerm[T Binder](pp *PrettyPrinter, term Term[T], isTopLevel bool) {
199198
case *Case[T]:
200199
pp.write("(case ")
201200

202-
printTerm[T](pp, t.Constr, false)
201+
printTerm[T](pp, t.Constr)
203202

204203
if len(t.Branches) > 0 {
205204
pp.write("\n")
@@ -208,7 +207,7 @@ func printTerm[T Binder](pp *PrettyPrinter, term Term[T], isTopLevel bool) {
208207
for _, branch := range t.Branches {
209208
pp.writeIndent()
210209

211-
printTerm[T](pp, branch, false)
210+
printTerm[T](pp, branch)
212211

213212
pp.write("\n")
214213
}

tests/conformance_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,10 @@ func TestConformance(t *testing.T) {
221221

222222
// Evaluate program
223223

224-
machine := cek.NewMachine[syn.DeBruijn](dProgram.Version, 200)
224+
machine := cek.NewMachine[syn.DeBruijn](
225+
dProgram.Version,
226+
200,
227+
)
225228

226229
result, err := machine.Run(dProgram.Term)
227230
if err != nil {

0 commit comments

Comments
 (0)