mirror of
https://github.com/espressif/esp-matter.git
synced 2026-04-27 19:13:13 +00:00
339 lines
13 KiB
Python
339 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright 2022 Espressif Systems (Shanghai) PTE LTD
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Contains utilitiy functions for validating argument, certs/keys conversion, etc.
|
|
"""
|
|
|
|
import sys
|
|
import enum
|
|
import logging
|
|
import subprocess
|
|
from bitarray import bitarray
|
|
from bitarray.util import ba2int
|
|
import cryptography.hazmat.backends
|
|
import cryptography.x509
|
|
|
|
|
|
ROTATING_DEVICE_ID_UNIQUE_ID_LEN_BITS = 128
|
|
SERIAL_NUMBER_LEN = 16
|
|
|
|
|
|
# Lengths for manual pairing codes and qrcode
|
|
SHORT_MANUALCODE_LEN = 11
|
|
LONG_MANUALCODE_LEN = 21
|
|
QRCODE_LEN = 22
|
|
|
|
|
|
INVALID_PASSCODES = [00000000, 11111111, 22222222, 33333333, 44444444, 55555555,
|
|
66666666, 77777777, 88888888, 99999999, 12345678, 87654321]
|
|
|
|
|
|
class CalendarTypes(enum.Enum):
|
|
Buddhist = 0
|
|
Chinese = 1
|
|
Coptic = 2
|
|
Ethiopian = 3
|
|
Gregorian = 4
|
|
Hebrew = 5
|
|
Indian = 6
|
|
Islamic = 7
|
|
Japanese = 8
|
|
Korean = 9
|
|
Persian = 10
|
|
Taiwanese = 11
|
|
|
|
|
|
def vid_pid_str(vid, pid):
|
|
return '_'.join([hex(vid)[2:], hex(pid)[2:]])
|
|
|
|
|
|
def disc_pin_str(discriminator, passcode):
|
|
return '_'.join([hex(discriminator)[2:], hex(passcode)[2:]])
|
|
|
|
|
|
# Checks if the input string is a valid hex string
|
|
def ishex(s):
|
|
try:
|
|
n = int(s, 16)
|
|
return True
|
|
except ValueError:
|
|
return False
|
|
|
|
|
|
# Validate the input string length against the min and max length
|
|
def check_str_range(s, min_len, max_len, name):
|
|
if s and ((len(s) < min_len) or (len(s) > max_len)):
|
|
logging.error('%s must be between %d and %d characters', name, min_len, max_len)
|
|
sys.exit(1)
|
|
|
|
|
|
# Validate the input integer range
|
|
def check_int_range(value, min_value, max_value, name):
|
|
if value and ((value < min_value) or (value > max_value)):
|
|
logging.error('%s is out of range, should be in range [%d, %d]', name, min_value, max_value)
|
|
sys.exit(1)
|
|
|
|
|
|
# Validates discriminator and passcode
|
|
def validate_commissionable_data(args):
|
|
check_int_range(args.discriminator, 0x0000, 0x0FFF, 'Discriminator')
|
|
if args.passcode is not None:
|
|
if ((args.passcode < 0x0000001 and args.passcode > 0x5F5E0FE) or (args.passcode in INVALID_PASSCODES)):
|
|
logging.error('Invalid passcode' + str(args.passcode))
|
|
sys.exit(1)
|
|
|
|
|
|
# Validate the device instance information
|
|
def validate_device_instance_info(args):
|
|
check_int_range(args.product_id, 0x0000, 0xFFFF, 'Product id')
|
|
check_int_range(args.vendor_id, 0x0000, 0xFFFF, 'Vendor id')
|
|
check_int_range(args.hw_ver, 0x0000, 0xFFFF, 'Hardware version')
|
|
check_str_range(args.serial_num, 1, SERIAL_NUMBER_LEN, 'Serial number')
|
|
check_str_range(args.vendor_name, 1, 32, 'Vendor name')
|
|
check_str_range(args.product_name, 1, 32, 'Product name')
|
|
check_str_range(args.hw_ver_str, 1, 64, 'Hardware version string')
|
|
check_str_range(args.mfg_date, 8, 16, 'Manufacturing date')
|
|
check_str_range(args.rd_id_uid, 32, 32, 'Rotating device Unique id')
|
|
|
|
|
|
# Validate the device information: calendar types and fixed labels
|
|
def validate_device_info(args):
|
|
# Validate the input calendar types
|
|
if args.calendar_types is not None:
|
|
if not (set(args.calendar_types) <= set(CalendarTypes.__members__)):
|
|
invalid_types = set(args.calendar_types).union(set(CalendarTypes.__members__)) - set(CalendarTypes.__members__)
|
|
logging.error('Unknown calendar type/s: %s', invalid_types)
|
|
logging.error('Supported calendar types: %s', ', '.join(CalendarTypes.__members__))
|
|
sys.exit(1)
|
|
|
|
if args.fixed_labels is not None:
|
|
for fl in args.fixed_labels:
|
|
_l = fl.split('/')
|
|
if len(_l) != 3:
|
|
logging.error('Invalid fixed label: %s', fl)
|
|
sys.exit(1)
|
|
|
|
if not (ishex(_l[0]) and (len(_l[1]) > 0 and len(_l[1]) < 16) and (len(_l[2]) > 0 and len(_l[2]) < 16)):
|
|
logging.error('Invalid fixed label: %s', fl)
|
|
sys.exit(1)
|
|
|
|
|
|
# Validates the attestation related arguments
|
|
def validate_attestation_info(args):
|
|
# DAC key and DAC cert both should be present or none
|
|
if (args.dac_key is not None) != (args.dac_cert is not None):
|
|
logging.error("dac_key and dac_cert should be both present or none")
|
|
sys.exit(1)
|
|
else:
|
|
# Make sure PAI certificate is present if DAC is present
|
|
if (args.dac_key is not None) and (args.pai is False):
|
|
logging.error('Please provide PAI certificate along with DAC certificate and DAC key')
|
|
sys.exit(1)
|
|
|
|
# Validate the input certificate type, if DAC is not present
|
|
if args.dac_key is None and args.dac_cert is None:
|
|
if args.paa:
|
|
logging.info('Input Root certificate type PAA')
|
|
elif args.pai:
|
|
logging.info('Input Root certificate type PAI')
|
|
else:
|
|
logging.info('Do not include the device attestation certificates and keys in partition binaries')
|
|
|
|
# Check if Key and certificate are present
|
|
if (args.paa or args.pai) and (args.key is None or args.cert is None):
|
|
logging.error('CA key and certificate are required to generate DAC key and certificate')
|
|
sys.exit(1)
|
|
|
|
|
|
# Validates few basic cluster related arguments: product-label and product-url
|
|
def validate_basic_cluster_info(args):
|
|
check_str_range(args.product_label, 1, 64, 'Product Label')
|
|
check_str_range(args.product_url, 1, 256, 'Product URL')
|
|
|
|
|
|
# Validates the input arguments, this calls the above functions
|
|
def validate_args(args):
|
|
# csv and mcsv both should present or none
|
|
if (args.csv is not None) != (args.mcsv is not None):
|
|
logging.error("csv and mcsv should be both present or none")
|
|
sys.exit(1)
|
|
else:
|
|
# Read the number of lines in mcsv file
|
|
if args.mcsv is not None:
|
|
with open(args.mcsv, 'r') as f:
|
|
lines = sum(1 for line in f)
|
|
|
|
# Subtract 1 for the header line
|
|
args.count = lines - 1
|
|
|
|
validate_commissionable_data(args)
|
|
validate_device_instance_info(args)
|
|
validate_device_info(args)
|
|
validate_attestation_info(args)
|
|
validate_basic_cluster_info(args)
|
|
|
|
# If discriminator/passcode/DAC/serial_number/rotating_device_id is present
|
|
# then we are restricting the number of partitions to 1
|
|
if (args.discriminator is not None
|
|
or args.passcode is not None
|
|
or args.dac_key is not None
|
|
or args.serial_num is not None
|
|
or args.rd_id_uid is not None):
|
|
if args.count > 1:
|
|
logging.error('Number of partitions should be 1 when discriminator or passcode or DAC or serial number or rotating device id is present')
|
|
sys.exit(1)
|
|
|
|
logging.info('Number of manufacturing NVS images to generate: {}'.format(args.count))
|
|
|
|
|
|
# Supported Calendar types is stored as a bit array in one uint32_t.
|
|
def calendar_types_to_uint32(calendar_types):
|
|
# In validate_device_info() we have already verified that the calendar types are valid
|
|
result = bitarray(32, endian='little')
|
|
result.setall(0)
|
|
for calendar_type in calendar_types:
|
|
result[CalendarTypes[calendar_type].value] = 1
|
|
return ba2int(result)
|
|
|
|
|
|
# get_fixed_label_dict() converts the list of strings to per endpoint dictionaries.
|
|
# example input : ['0/orientation/up', '1/orientation/down', '2/orientation/down']
|
|
# example outout : {'0': [{'orientation': 'up'}], '1': [{'orientation': 'down'}], '2': [{'orientation': 'down'}]}
|
|
def get_fixed_label_dict(fixed_labels):
|
|
fl_dict = {}
|
|
for fl in fixed_labels:
|
|
_l = fl.split('/')
|
|
|
|
if len(_l) != 3:
|
|
logging.error('Invalid fixed label: %s', fl)
|
|
sys.exit(1)
|
|
|
|
if not (ishex(_l[0]) and (len(_l[1]) > 0 and len(_l[1]) < 16) and (len(_l[2]) > 0 and len(_l[2]) < 16)):
|
|
logging.error('Invalid fixed label: %s', fl)
|
|
sys.exit(1)
|
|
|
|
if _l[0] not in fl_dict.keys():
|
|
fl_dict[_l[0]] = list()
|
|
|
|
fl_dict[_l[0]].append({_l[1]: _l[2]})
|
|
|
|
return fl_dict
|
|
|
|
|
|
# Convert the certificate in PEM format to DER format
|
|
def convert_x509_cert_from_pem_to_der(pem_file, out_der_file):
|
|
with open(pem_file, 'rb') as f:
|
|
pem_data = f.read()
|
|
|
|
pem_cert = cryptography.x509.load_pem_x509_certificate(pem_data, cryptography.hazmat.backends.default_backend())
|
|
der_cert = pem_cert.public_bytes(cryptography.hazmat.primitives.serialization.Encoding.DER)
|
|
|
|
with open(out_der_file, 'wb') as f:
|
|
f.write(der_cert)
|
|
|
|
|
|
# Generate the Public and Private key pair binaries
|
|
def generate_keypair_bin(pem_file, out_privkey_bin, out_pubkey_bin):
|
|
with open(pem_file, 'rb') as f:
|
|
pem_data = f.read()
|
|
|
|
key_pem = cryptography.hazmat.primitives.serialization.load_pem_private_key(pem_data, None)
|
|
private_number_val = key_pem.private_numbers().private_value
|
|
public_number_x = key_pem.public_key().public_numbers().x
|
|
public_number_y = key_pem.public_key().public_numbers().y
|
|
public_key_first_byte = 0x04
|
|
|
|
with open(out_privkey_bin, 'wb') as f:
|
|
f.write(private_number_val.to_bytes(32, byteorder='big'))
|
|
|
|
with open(out_pubkey_bin, 'wb') as f:
|
|
f.write(public_key_first_byte.to_bytes(1, byteorder='big'))
|
|
f.write(public_number_x.to_bytes(32, byteorder='big'))
|
|
f.write(public_number_y.to_bytes(32, byteorder='big'))
|
|
|
|
|
|
def execute_cmd(cmd):
|
|
logging.debug('Executing Command: {}'.format(cmd))
|
|
status = subprocess.run(cmd, capture_output=True)
|
|
|
|
try:
|
|
status.check_returncode()
|
|
except subprocess.CalledProcessError as e:
|
|
if status.stderr:
|
|
logging.error('[stderr]: {}'.format(status.stderr.decode('utf-8').strip()))
|
|
logging.error('Command failed with error: {}'.format(e))
|
|
sys.exit(1)
|
|
|
|
|
|
def get_manualcode_args(vid, pid, flow, discriminator, passcode):
|
|
payload_args = list()
|
|
payload_args.append('--discriminator')
|
|
payload_args.append(str(discriminator))
|
|
payload_args.append('--setup-pin-code')
|
|
payload_args.append(str(passcode))
|
|
payload_args.append('--version')
|
|
payload_args.append('0')
|
|
payload_args.append('--vendor-id')
|
|
payload_args.append(str(vid))
|
|
payload_args.append('--product-id')
|
|
payload_args.append(str(pid))
|
|
payload_args.append('--commissioning-mode')
|
|
payload_args.append(str(flow))
|
|
return payload_args
|
|
|
|
|
|
def get_qrcode_args(vid, pid, flow, discriminator, passcode, disc_mode):
|
|
payload_args = get_manualcode_args(vid, pid, flow, discriminator, passcode)
|
|
payload_args.append('--rendezvous')
|
|
payload_args.append(str(1 << disc_mode))
|
|
return payload_args
|
|
|
|
|
|
def get_chip_qrcode(chip_tool, vid, pid, flow, discriminator, passcode, disc_mode):
|
|
payload_args = get_qrcode_args(vid, pid, flow, discriminator, passcode, disc_mode)
|
|
cmd_args = [chip_tool, 'payload', 'generate-qrcode']
|
|
cmd_args.extend(payload_args)
|
|
data = subprocess.check_output(cmd_args)
|
|
|
|
# Command output is as below:
|
|
# \x1b[0;32m[1655386003372] [23483:7823617] CHIP: [TOO] QR Code: MT:Y.K90-WB010E7648G00\x1b[0m
|
|
return data.decode('utf-8').split('QR Code: ')[1][:QRCODE_LEN]
|
|
|
|
|
|
def get_chip_manualcode(chip_tool, vid, pid, flow, discriminator, passcode):
|
|
payload_args = get_manualcode_args(vid, pid, flow, discriminator, passcode)
|
|
cmd_args = [chip_tool, 'payload', 'generate-manualcode']
|
|
cmd_args.extend(payload_args)
|
|
data = subprocess.check_output(cmd_args)
|
|
|
|
# Command output is as below:
|
|
# \x1b[0;32m[1655386909774] [24424:7837894] CHIP: [TOO] Manual Code: 749721123365521327689\x1b[0m\n
|
|
# OR
|
|
# \x1b[0;32m[1655386926028] [24458:7838229] CHIP: [TOO] Manual Code: 34972112338\x1b[0m\n
|
|
# Length of manual code depends on the commissioning flow:
|
|
# For standard commissioning flow it is 11 digits
|
|
# For User-intent and custom commissioning flow it is 21 digits
|
|
manual_code_len = LONG_MANUALCODE_LEN if flow else SHORT_MANUALCODE_LEN
|
|
ret = data.decode('utf-8').split('Manual Code: ')[1][:manual_code_len]
|
|
# For 11-digits manual code, use the format 'XXXX-XXX-XXXX'
|
|
# TODO: change the format of 21-digits maunal code
|
|
if manual_code_len == SHORT_MANUALCODE_LEN:
|
|
ret = ret[:4] + '-' + ret[4:7] + '-' + ret[7:]
|
|
elif manual_code_len == LONG_MANUALCODE_LEN:
|
|
ret = '"' + ret[:4] + '-' + ret[4:7] + '-' + ret[7:11] + '\n' + ret[11:15] + '-' + ret[15:18] + '-' + ret[18:20] + '-' + ret[20:21] + '"'
|
|
return ret
|