Files
esp-matter/tools/mfg_tool/utils.py
T

333 lines
12 KiB
Python

#!/usr/bin/env python3
# Copyright 2022 Espressif Systems (Shanghai) PTE LTD
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions for validating arguments, converting certificates/keys, etc.
"""
import sys
import enum
import logging
import subprocess
from bitarray import bitarray
from bitarray.util import ba2int
import cryptography.hazmat.backends
import cryptography.x509
# Length, in bits, of the rotating device ID unique ID.
ROTATING_DEVICE_ID_UNIQUE_ID_LEN_BITS = 128
# Maximum serial number length in characters.
SERIAL_NUMBER_LEN = 16

# Lengths (in characters) of the onboarding payload strings cut from chip-tool output.
SHORT_MANUALCODE_LEN = 11  # manual pairing code, standard commissioning flow
LONG_MANUALCODE_LEN = 21   # manual pairing code, user-intent / custom flows
QRCODE_LEN = 22            # QR code payload, including the 'MT:' prefix

# Setup passcodes that must be rejected: every repeated-digit value
# (00000000, 11111111, ..., 99999999) plus the two trivial digit sequences.
INVALID_PASSCODES = [digit * 11111111 for digit in range(10)] + [12345678, 87654321]
class CalendarTypes(enum.Enum):
    """Calendar types that may be listed as supported by the device.

    Members are looked up by name (e.g. ``CalendarTypes['Gregorian']``); each
    member's value is the bit index it occupies in the uint32 bitmap built by
    calendar_types_to_uint32(). Member order is significant: it is the order
    shown to the user when an unknown type is reported.
    """
    Buddhist = 0
    Chinese = 1
    Coptic = 2
    Ethiopian = 3
    Gregorian = 4
    Hebrew = 5
    Indian = 6
    Islamic = 7
    Japanese = 8
    Korean = 9
    Persian = 10
    Taiwanese = 11
def vid_pid_str(vid, pid):
    """Return '<vid>_<pid>' with both ids as lowercase hex without the '0x' prefix."""
    vid_hex = hex(vid)[2:]
    pid_hex = hex(pid)[2:]
    return '{}_{}'.format(vid_hex, pid_hex)
def disc_pin_str(discriminator, passcode):
    """Return '<discriminator>_<passcode>' with both values as lowercase hex without '0x'."""
    parts = (hex(discriminator)[2:], hex(passcode)[2:])
    return '_'.join(parts)
def ishex(s):
    """Return True if ``s`` parses as a base-16 integer, otherwise False.

    Accepts whatever ``int(s, 16)`` accepts (optional sign, optional '0x'
    prefix, surrounding whitespace, underscores between digits). Non-string
    input such as None returns False instead of raising TypeError.
    """
    try:
        int(s, 16)
    except (TypeError, ValueError):
        return False
    return True
def check_str_range(s, min_len, max_len, name):
    """Exit with status 1 if ``len(s)`` is outside ``[min_len, max_len]``.

    Args:
        s: the string to check; None means "not provided" and is skipped.
        min_len: minimum allowed length (inclusive).
        max_len: maximum allowed length (inclusive).
        name: human-readable field name used in the error message.

    An empty string is checked like any other value, so it is rejected
    whenever ``min_len`` is greater than 0.
    """
    if s is not None and not (min_len <= len(s) <= max_len):
        logging.error('%s must be between %d and %d characters', name, min_len, max_len)
        sys.exit(1)
def check_int_range(value, min_value, max_value, name):
    """Exit with status 1 if ``value`` is outside ``[min_value, max_value]``.

    Args:
        value: the integer to check; None means "not provided" and is skipped.
        min_value: lowest allowed value (inclusive).
        max_value: highest allowed value (inclusive).
        name: human-readable field name used in the error message.

    Zero is checked like any other value, so a range that excludes 0
    rejects it.
    """
    if value is not None and not (min_value <= value <= max_value):
        logging.error('%s is out of range, should be in range [%d, %d]', name, min_value, max_value)
        sys.exit(1)
def validate_commissionable_data(args):
    """Validate the discriminator and the setup passcode; exit with status 1 on error.

    - ``args.discriminator`` must be within 0x000..0xFFF (12 bits).
    - ``args.passcode``, when not None, must be within 0x0000001..0x5F5E0FE
      (1..99999998) and must not appear in INVALID_PASSCODES.
    """
    check_int_range(args.discriminator, 0x0000, 0x0FFF, 'Discriminator')
    if args.passcode is not None:
        # Violating either bound makes the passcode invalid, hence `or`.
        out_of_range = args.passcode < 0x0000001 or args.passcode > 0x5F5E0FE
        if out_of_range or args.passcode in INVALID_PASSCODES:
            logging.error('Invalid passcode: %s', args.passcode)
            sys.exit(1)
def validate_device_instance_info(args):
    """Validate the per-device identity fields, exiting with status 1 on the first violation.

    The numeric ids are range-checked with check_int_range() and the string
    fields are length-checked with check_str_range(), in the order listed
    below. How an unset field is treated is decided by those helpers.
    """
    # (attribute on args, label) -- all three must fit in 16 bits
    sixteen_bit_fields = (
        ('product_id', 'Product id'),
        ('vendor_id', 'Vendor id'),
        ('hw_ver', 'Hardware version'),
    )
    for attr, label in sixteen_bit_fields:
        check_int_range(getattr(args, attr), 0x0000, 0xFFFF, label)

    # (attribute on args, min length, max length, label)
    string_fields = (
        ('serial_num', 1, SERIAL_NUMBER_LEN, 'Serial number'),
        ('vendor_name', 1, 32, 'Vendor name'),
        ('product_name', 1, 32, 'Product name'),
        ('hw_ver_str', 1, 64, 'Hardware version string'),
        ('mfg_date', 8, 16, 'Manufacturing date'),
        # Exactly 32 characters; presumably hex encoding the
        # ROTATING_DEVICE_ID_UNIQUE_ID_LEN_BITS (128) bit id -- TODO confirm.
        ('rd_id_uid', 32, 32, 'Rotating device Unique id'),
    )
    for attr, min_len, max_len, label in string_fields:
        check_str_range(getattr(args, attr), min_len, max_len, label)
def validate_device_info(args):
    """Validate the optional calendar types and fixed labels; exit with status 1 on error.

    - ``args.calendar_types`` (if not None): every entry must be a member
      name of CalendarTypes.
    - ``args.fixed_labels`` (if not None): every entry must have the form
      '<endpoint>/<label>/<value>', where endpoint parses as hex and label
      and value are each 1-15 characters long.
    """
    if args.calendar_types is not None:
        supported = set(CalendarTypes.__members__)
        invalid_types = set(args.calendar_types) - supported
        if invalid_types:
            # Sort for a deterministic, readable message (a raw set prints in arbitrary order).
            logging.error('Unknown calendar type/s: %s', ', '.join(sorted(str(t) for t in invalid_types)))
            logging.error('Supported calendar types: %s', ', '.join(CalendarTypes.__members__))
            sys.exit(1)

    if args.fixed_labels is not None:
        for fl in args.fixed_labels:
            parts = fl.split('/')
            # The length check comes first so parts[0..2] are safe to index.
            # NOTE(review): the limit is < 16 characters; confirm whether
            # 16-character labels/values should be allowed.
            well_formed = (len(parts) == 3
                           and ishex(parts[0])
                           and 0 < len(parts[1]) < 16
                           and 0 < len(parts[2]) < 16)
            if not well_formed:
                logging.error('Invalid fixed label: %s', fl)
                sys.exit(1)
def validate_attestation_info(args):
    """Validate the attestation key/certificate arguments; exit with status 1 on error.

    Two configurations are accepted:

    * DAC key and DAC cert both given: ``args.pai`` must not be False (the
      accompanying certificate must be the PAI).
    * Neither DAC key nor DAC cert given: ``args.paa`` or ``args.pai`` must be
      set, and both ``args.key`` and ``args.cert`` are required.
    """
    has_dac_key = args.dac_key is not None
    has_dac_cert = args.dac_cert is not None

    # The DAC key and DAC cert only make sense together.
    if has_dac_key != has_dac_cert:
        logging.error("dac_key and dac_cert should be both present or none")
        sys.exit(1)

    if has_dac_key:
        # A supplied DAC must be accompanied by its PAI certificate.
        if args.pai is False:
            logging.error('Please provide PAI certificate along with DAC certificate and DAC key')
            sys.exit(1)
        return

    # No DAC: the input root certificate type must be declared ...
    if not args.paa and not args.pai:
        logging.error('Either PAA or PAI certificate is required')
        sys.exit(1)
    logging.info('Input Root certificate type PAA' if args.paa else 'Input Root certificate type PAI')

    # ... and both its key and certificate must be provided.
    if args.key is None or args.cert is None:
        logging.error('PAA key and certificate are required')
        sys.exit(1)
def validate_basic_cluster_info(args):
    """Validate the basic-cluster strings: product label (1-64 chars) and product URL (1-256 chars).

    Length violations are reported by check_str_range(), which exits with status 1.
    """
    # (attribute on args, max length, label); the minimum is 1 for both
    limits = (
        ('product_label', 64, 'Product Label'),
        ('product_url', 256, 'Product URL'),
    )
    for attr, max_len, label in limits:
        check_str_range(getattr(args, attr), 1, max_len, label)
def validate_args(args):
    """Validate all command-line arguments, exiting with status 1 on the first error.

    Runs every validate_* helper in this module. Side effect: when a
    multi-value CSV (``args.mcsv``) is given, ``args.count`` is overwritten
    with the number of data rows in that file.
    """
    # csv and mcsv must be given together or not at all.
    if (args.csv is not None) != (args.mcsv is not None):
        logging.error("csv and mcsv should be both present or none")
        sys.exit(1)

    if args.mcsv is not None:
        # One image per data row: count non-blank lines (so a trailing empty
        # line doesn't add a phantom image) and drop the header line.
        # Binary mode: only line boundaries matter, so no decoding is needed.
        with open(args.mcsv, 'rb') as f:
            lines = sum(1 for line in f if line.strip())
        args.count = lines - 1

    validate_commissionable_data(args)
    validate_device_instance_info(args)
    validate_device_info(args)
    validate_attestation_info(args)
    validate_basic_cluster_info(args)

    # If any of these single-device values is given, only one partition may
    # be generated.
    has_single_device_value = (args.discriminator is not None
                               or args.passcode is not None
                               or args.dac_key is not None
                               or args.serial_num is not None
                               or args.rd_id_uid is not None)
    if has_single_device_value and args.count > 1:
        logging.error('Number of partitions should be 1 when discriminator or passcode or DAC or serial number or rotating device id is present')
        sys.exit(1)

    logging.info('Number of manufacturing NVS images to generate: %s', args.count)
def calendar_types_to_uint32(calendar_types):
    """Pack calendar type names into a uint32 bitmap.

    Bit ``CalendarTypes[name].value`` is set for every name in
    ``calendar_types`` (bit 0 is the least-significant bit); duplicates are
    harmless. The names are expected to have been validated by
    validate_device_info() already; an unknown name raises KeyError.
    """
    bitmap = 0
    for name in calendar_types:
        bitmap |= 1 << CalendarTypes[name].value
    return bitmap
def get_fixed_label_dict(fixed_labels):
    """Group '<endpoint>/<label>/<value>' strings into per-endpoint lists.

    Example input:  ['0/orientation/up', '1/orientation/down', '2/orientation/down']
    Example output: {'0': [{'orientation': 'up'}], '1': [{'orientation': 'down'}],
                     '2': [{'orientation': 'down'}]}

    Endpoint keys are kept as the original strings, and entries keep their
    input order within each endpoint. Exits with status 1 on a malformed
    entry: not exactly three '/'-separated fields, a non-hex endpoint, or a
    label/value that is not 1-15 characters long.
    """
    grouped = {}
    for entry in fixed_labels:
        fields = entry.split('/')
        well_formed = (len(fields) == 3
                       and ishex(fields[0])
                       and 0 < len(fields[1]) < 16
                       and 0 < len(fields[2]) < 16)
        if not well_formed:
            logging.error('Invalid fixed label: %s', entry)
            sys.exit(1)
        endpoint, label, value = fields
        grouped.setdefault(endpoint, []).append({label: value})
    return grouped
def convert_x509_cert_from_pem_to_der(pem_file, out_der_file):
    """Convert a PEM-encoded X.509 certificate file into a DER-encoded file.

    Args:
        pem_file: path of the input PEM certificate.
        out_der_file: path the DER bytes are written to (overwritten).

    Errors raised by cryptography while parsing the PEM data are not caught here.
    """
    # Import explicitly: the module-level imports never import this submodule,
    # so it was only reachable as a side effect of `import cryptography.x509`.
    from cryptography.hazmat.primitives import serialization

    with open(pem_file, 'rb') as f:
        pem_data = f.read()
    cert = cryptography.x509.load_pem_x509_certificate(pem_data, cryptography.hazmat.backends.default_backend())
    der_cert = cert.public_bytes(serialization.Encoding.DER)
    with open(out_der_file, 'wb') as f:
        f.write(der_cert)
def generate_keypair_bin(pem_file, out_privkey_bin, out_pubkey_bin):
    """Write raw private/public key binaries from an unencrypted PEM EC P-256 key.

    Args:
        pem_file: path to an unencrypted PEM private key (secp256r1 / P-256).
        out_privkey_bin: output path for the 32-byte big-endian private scalar.
        out_pubkey_bin: output path for the 65-byte uncompressed public point
            (0x04 || X || Y, each coordinate 32 bytes big-endian).

    Exits with status 1 if the key is not an EC P-256 key. The fixed 32-byte
    encoding below only fits that curve: other curves either overflow
    ``to_bytes`` or would be written silently as a wrong-curve key.
    """
    # Import explicitly rather than relying on the module-level
    # `import cryptography.x509` having loaded these submodules.
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import ec

    with open(pem_file, 'rb') as f:
        pem_data = f.read()
    key = serialization.load_pem_private_key(pem_data, None)

    if not (isinstance(key, ec.EllipticCurvePrivateKey) and isinstance(key.curve, ec.SECP256R1)):
        logging.error('%s is not an EC P-256 (secp256r1) private key', pem_file)
        sys.exit(1)

    private_value = key.private_numbers().private_value
    public_numbers = key.public_key().public_numbers()

    with open(out_privkey_bin, 'wb') as f:
        f.write(private_value.to_bytes(32, byteorder='big'))
    with open(out_pubkey_bin, 'wb') as f:
        # 0x04 prefix marks an uncompressed point: 0x04 || X || Y
        f.write(b'\x04')
        f.write(public_numbers.x.to_bytes(32, byteorder='big'))
        f.write(public_numbers.y.to_bytes(32, byteorder='big'))
def execute_cmd(cmd):
    """Run ``cmd`` (an argv list) and exit with status 1 if it fails.

    stdout and stderr are captured. On a non-zero exit status, the captured
    stderr (if any) and the CalledProcessError are logged before exiting.
    """
    logging.debug('Executing Command: %s', cmd)
    status = subprocess.run(cmd, capture_output=True)
    try:
        status.check_returncode()
    except subprocess.CalledProcessError as e:
        if status.stderr:
            # Tools may emit non-UTF-8 bytes; a strict decode would raise
            # UnicodeDecodeError here and hide the actual command failure.
            logging.error('[stderr]: %s', status.stderr.decode('utf-8', errors='replace').strip())
        logging.error('Command failed with error: %s', e)
        sys.exit(1)
def get_manualcode_args(vid, pid, flow, discriminator, passcode):
    """Build the option list for ``chip-tool payload generate-manualcode``.

    Every value is stringified; the payload version is always 0. Returns a
    new list of alternating flag/value strings.
    """
    options = (
        ('--discriminator', discriminator),
        ('--setup-pin-code', passcode),
        ('--version', 0),
        ('--vendor-id', vid),
        ('--product-id', pid),
        ('--commissioning-mode', flow),
    )
    return [token for flag, value in options for token in (flag, str(value))]
def get_qrcode_args(vid, pid, flow, discriminator, passcode, disc_mode):
    """Build the option list for ``chip-tool payload generate-qrcode``.

    Same options as get_manualcode_args(), plus ``--rendezvous`` set to the
    bitmask with only bit ``disc_mode`` set.
    """
    rendezvous_mask = 1 << disc_mode
    return [*get_manualcode_args(vid, pid, flow, discriminator, passcode),
            '--rendezvous', str(rendezvous_mask)]
def get_chip_qrcode(chip_tool, vid, pid, flow, discriminator, passcode, disc_mode):
    """Generate the onboarding QR code string by running ``chip_tool payload generate-qrcode``.

    Returns the first QRCODE_LEN characters following 'QR Code: ' in the
    command output. A non-zero chip-tool exit status raises
    subprocess.CalledProcessError. Exits with status 1 if the output contains
    no 'QR Code: ' marker.
    """
    cmd_args = [chip_tool, 'payload', 'generate-qrcode']
    cmd_args.extend(get_qrcode_args(vid, pid, flow, discriminator, passcode, disc_mode))
    output = subprocess.check_output(cmd_args).decode('utf-8')
    # Expected output (ANSI colour codes included), e.g.:
    # \x1b[0;32m[1655386003372] [23483:7823617] CHIP: [TOO] QR Code: MT:Y.K90-WB010E7648G00\x1b[0m
    _, marker, rest = output.partition('QR Code: ')
    if not marker:
        logging.error('QR code not found in chip-tool output: %s', output.strip())
        sys.exit(1)
    return rest[:QRCODE_LEN]
def get_chip_manualcode(chip_tool, vid, pid, flow, discriminator, passcode):
    """Generate the manual pairing code by running ``chip_tool payload generate-manualcode``.

    Returns SHORT_MANUALCODE_LEN (11) characters when ``flow`` is falsy
    (standard commissioning flow) and LONG_MANUALCODE_LEN (21) otherwise
    (user-intent / custom flows), taken after 'Manual Code: ' in the output.
    A non-zero chip-tool exit status raises subprocess.CalledProcessError.
    Exits with status 1 if the output contains no 'Manual Code: ' marker.
    """
    cmd_args = [chip_tool, 'payload', 'generate-manualcode']
    cmd_args.extend(get_manualcode_args(vid, pid, flow, discriminator, passcode))
    output = subprocess.check_output(cmd_args).decode('utf-8')
    # Expected output (ANSI colour codes included), e.g.:
    # \x1b[0;32m[1655386909774] [24424:7837894] CHIP: [TOO] Manual Code: 749721123365521327689\x1b[0m\n
    # \x1b[0;32m[1655386926028] [24458:7838229] CHIP: [TOO] Manual Code: 34972112338\x1b[0m\n
    _, marker, rest = output.partition('Manual Code: ')
    if not marker:
        logging.error('Manual code not found in chip-tool output: %s', output.strip())
        sys.exit(1)
    manual_code_len = LONG_MANUALCODE_LEN if flow else SHORT_MANUALCODE_LEN
    return rest[:manual_code_len]